|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | import os | 
					
						
						|  | from typing import List, Optional | 
					
						
						|  |  | 
					
						
						|  | import pytest | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | WORLD_SIZE_OPTIONS = (1, 2) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | pytest_plugins = [ | 
					
						
						|  |  | 
					
						
						|  | 'tests.fixtures.fixtures', | 
					
						
						|  | ] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def _get_world_size(item: pytest.Item): | 
					
						
						|  | """Returns the world_size of a test, defaults to 1.""" | 
					
						
						|  | _default = pytest.mark.world_size(1).mark | 
					
						
						|  | return item.get_closest_marker('world_size', default=_default).args[0] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def _get_option( | 
					
						
						|  | config: pytest.Config, | 
					
						
						|  | name: str, | 
					
						
						|  | default: Optional[str] = None, | 
					
						
						|  | ) -> str: | 
					
						
						|  | val = config.getoption(name) | 
					
						
						|  | if val is not None: | 
					
						
						|  | assert isinstance(val, str) | 
					
						
						|  | return val | 
					
						
						|  | val = config.getini(name) | 
					
						
						|  | if val == []: | 
					
						
						|  | val = None | 
					
						
						|  | if val is None: | 
					
						
						|  | if default is None: | 
					
						
						|  | pytest.fail(f'Config option {name} is not specified but is required',) | 
					
						
						|  | val = default | 
					
						
						|  | assert isinstance(val, str) | 
					
						
						|  | return val | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def _add_option( | 
					
						
						|  | parser: pytest.Parser, | 
					
						
						|  | name: str, | 
					
						
						|  | help: str, | 
					
						
						|  | choices: Optional[list[str]] = None, | 
					
						
						|  | ): | 
					
						
						|  | parser.addoption( | 
					
						
						|  | f'--{name}', | 
					
						
						|  | default=None, | 
					
						
						|  | type=str, | 
					
						
						|  | choices=choices, | 
					
						
						|  | help=help, | 
					
						
						|  | ) | 
					
						
						|  | parser.addini( | 
					
						
						|  | name=name, | 
					
						
						|  | help=help, | 
					
						
						|  | type='string', | 
					
						
						|  | default=None, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def pytest_collection_modifyitems( | 
					
						
						|  | config: pytest.Config, | 
					
						
						|  | items: List[pytest.Item], | 
					
						
						|  | ) -> None: | 
					
						
						|  | """Filter tests by world_size (for multi-GPU tests)""" | 
					
						
						|  | world_size = int(os.environ.get('WORLD_SIZE', '1')) | 
					
						
						|  | print(f'world_size={world_size}') | 
					
						
						|  |  | 
					
						
						|  | conditions = [ | 
					
						
						|  | lambda item: _get_world_size(item) == world_size, | 
					
						
						|  | ] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | remaining = [] | 
					
						
						|  | deselected = [] | 
					
						
						|  | for item in items: | 
					
						
						|  | if all(condition(item) for condition in conditions): | 
					
						
						|  | remaining.append(item) | 
					
						
						|  | else: | 
					
						
						|  | deselected.append(item) | 
					
						
						|  |  | 
					
						
						|  | if deselected: | 
					
						
						|  | config.hook.pytest_deselected(items=deselected) | 
					
						
						|  | items[:] = remaining | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def pytest_addoption(parser: pytest.Parser) -> None: | 
					
						
						|  | _add_option( | 
					
						
						|  | parser, | 
					
						
						|  | 'seed', | 
					
						
						|  | help="""\ | 
					
						
						|  | Rank zero seed to use. `reproducibility.seed_all(seed + dist.get_global_rank())` will be invoked | 
					
						
						|  | before each test.""", | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def pytest_sessionfinish(session: pytest.Session, exitstatus: int): | 
					
						
						|  | if exitstatus == 5: | 
					
						
						|  | session.exitstatus = 0 | 
					
						
						|  |  |