Skip to content

Commit

Permalink
add a option to disable the plugin (#25)
Browse files Browse the repository at this point in the history
This is especially useful when debugging potential incompatibilites with the vanilla collection
  • Loading branch information
pmeier authored May 25, 2021
1 parent cd6e565 commit b551426
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 1 deletion.
14 changes: 13 additions & 1 deletion pytest_pytorch/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
TORCH_AVAILABLE = False

warnings.warn(
"Disabling the plugin 'pytest-pytorch', because 'torch' could not be imported."
"Disabling the `pytest-pytorch` plugin, because 'torch' could not be imported."
)


Expand Down Expand Up @@ -87,10 +87,22 @@ def collect(self):
yield from super().collect()


def pytest_addoption(parser, pluginmanager):
parser.addoption(
"--disable-pytest-pytorch",
action="store_true",
help="Disable the `pytest-pytorch` plugin",
)
return None


def pytest_pycollect_makeitem(collector, name, obj):
if not TORCH_AVAILABLE:
return None

if collector.config.getoption("disable_pytest_pytorch"):
return None

try:
if not issubclass(obj, TestCaseTemplate) or obj is TestCaseTemplate:
return None
Expand Down
15 changes: 15 additions & 0 deletions tests/assets/test_disabled.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import TestCase


class TestFoo(TestCase):
def test_bar(self, device):
pass


instantiate_device_type_tests(TestFoo, globals(), only_for="cpu")


class TestSpam(TestCase):
def test_ham(self):
pass
4 changes: 4 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ def collect_tests(testdir):
def collect_tests_(file: str, cmds: str):
testdir.copy_example(file)
result = testdir.runpytest("--quiet", "--collect-only", *cmds)

if result.outlines[-1].startswith("no tests collected"):
return set()

assert result.ret == pytest.ExitCode.OK

collection = set()
Expand Down
7 changes: 7 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import pytest


@pytest.mark.parametrize("option", ["--disable-pytest-pytorch"])
def test_disable_pytest_pytorch(testdir, option):
result = testdir.runpytest("--help")
assert option in "\n".join(result.outlines)
42 changes: 42 additions & 0 deletions tests/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,45 @@ def test_op_infos(collect_tests, file, cmds, selection):
def test_nested_names(collect_tests, file, cmds, selection):
collection = collect_tests(file, cmds)
assert collection == selection


@make_parametrization(
Config(
selection=(
"::TestFooCPU::test_bar_cpu",
"::TestSpam::test_ham",
),
),
Config(
new_cmds="::TestFoo",
selection=(),
),
Config(
new_cmds="::TestFoo::test_bar",
selection=(),
),
Config(
new_cmds="::TestFooCPU",
legacy_cmds=("-k", "TestFoo"),
selection=("::TestFooCPU::test_bar_cpu",),
),
Config(
new_cmds="::TestFooCPU::test_bar_cpu",
legacy_cmds=("-k", "TestFoo and test_bar"),
selection=("::TestFooCPU::test_bar_cpu",),
),
Config(
new_cmds="::TestSpam",
legacy_cmds=("-k", "TestSpam"),
selection=("::TestSpam::test_ham",),
),
Config(
new_cmds="::TestSpam::test_ham",
legacy_cmds=("-k", "TestSpam and test_ham"),
selection=("::TestSpam::test_ham",),
),
file="test_disabled.py",
)
def test_disabled(collect_tests, file, cmds, selection):
collection = collect_tests(file, ("--disable-pytest-pytorch", *cmds))
assert collection == selection

0 comments on commit b551426

Please sign in to comment.