| | |
| | import sys |
| | from collections.abc import Iterable |
| | from runpy import run_path |
| | from shlex import split |
| | from typing import Any, Callable, Dict, List, Optional, Union |
| | from unittest.mock import patch |
| |
|
| | from torch.nn import GroupNorm, LayerNorm |
| | from torch.testing import assert_allclose as _assert_allclose |
| |
|
| | from mmengine.utils import digit_version |
| | from mmengine.utils.dl_utils import TORCH_VERSION |
| | from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm |
| |
|
| |
|
def assert_allclose(
    actual: Any,
    expected: Any,
    rtol: Optional[float] = None,
    atol: Optional[float] = None,
    equal_nan: bool = True,
    msg: Optional[Union[str, Callable]] = '',
) -> None:
    """Assert that ``actual`` and ``expected`` are close.

    Thin wrapper around ``torch.testing.assert_allclose`` that only forwards
    ``msg`` when the installed backend accepts it.

    Args:
        actual (Any): Actual input.
        expected (Any): Expected input.
        rtol (Optional[float]): Relative tolerance. If specified, ``atol``
            must also be specified. If omitted, the wrapped function picks
            a default.
        atol (Optional[float]): Absolute tolerance. If specified, ``rtol``
            must also be specified. If omitted, the wrapped function picks
            a default.
        equal_nan (bool): If ``True``, two ``NaN`` values will be considered
            equal.
        msg (Optional[Union[str, Callable]]): Optional error message to use
            if the values of corresponding tensors mismatch. Ignored when
            running on parrots or on PyTorch < 1.6.
    """
    # ``msg`` is only passed through for non-parrots PyTorch >= 1.6. The
    # version is parsed only when the backend is not parrots.
    forwards_msg = ('parrots' not in TORCH_VERSION
                    and digit_version(TORCH_VERSION) >= digit_version('1.6'))

    compare_kwargs = dict(rtol=rtol, atol=atol, equal_nan=equal_nan)
    if forwards_msg:
        compare_kwargs['msg'] = msg

    _assert_allclose(actual, expected, **compare_kwargs)
| |
|
| |
|
def check_python_script(cmd):
    """Run a python script command line in the current process.

    The script is executed with ``__name__ == '__main__'`` and with
    ``sys.argv`` temporarily replaced by the parsed command. Unlike
    ``os.system``, the code executes in the current process, so it can be
    tracked by coverage tools. Two command forms are supported:

    - ./tests/data/scripts/hello.py zz
    - python tests/data/scripts/hello.py zz

    Args:
        cmd (str): Command line to run, split with ``shlex.split``.
    """
    argv = split(cmd)
    # Drop a leading interpreter name so the script path comes first.
    if argv[0] == 'python':
        argv = argv[1:]
    script_path = argv[0]

    # ``sys.argv`` is restored when the ``with`` block exits, even on error.
    with patch.object(sys, 'argv', argv):
        run_path(script_path, run_name='__main__')
| |
|
| |
|
| | def _any(judge_result): |
| | """Since built-in ``any`` works only when the element of iterable is not |
| | iterable, implement the function.""" |
| | if not isinstance(judge_result, Iterable): |
| | return judge_result |
| |
|
| | try: |
| | for element in judge_result: |
| | if _any(element): |
| | return True |
| | except TypeError: |
| | |
| | if judge_result: |
| | return True |
| | return False |
| |
|
| |
|
def assert_dict_contains_subset(dict_obj: Dict[Any, Any],
                                expected_subset: Dict[Any, Any]) -> bool:
    """Tell whether every item of ``expected_subset`` appears in ``dict_obj``.

    Args:
        dict_obj (Dict[Any, Any]): Dict object to be checked.
        expected_subset (Dict[Any, Any]): Items expected to be contained in
            ``dict_obj``.

    Returns:
        bool: ``True`` if each key of ``expected_subset`` is present in
        ``dict_obj`` with an equal value, ``False`` otherwise.
    """
    for key, expected_value in expected_subset.items():
        if key not in dict_obj:
            return False
        # ``!=`` may return a nested or element-wise result, so reduce it
        # with ``_any`` rather than relying on plain truthiness.
        if _any(dict_obj[key] != expected_value):
            return False
    return True
| |
|
| |
|
def assert_attrs_equal(obj: Any, expected_attrs: Dict[str, Any]) -> bool:
    """Tell whether ``obj`` carries all the expected attribute values.

    Args:
        obj (object): Object to be checked.
        expected_attrs (Dict[str, Any]): Mapping from attribute name to the
            value that attribute is expected to have.

    Returns:
        bool: ``True`` if every listed attribute exists on ``obj`` and
        equals its expected value, ``False`` otherwise.
    """
    for attr_name, expected_value in expected_attrs.items():
        if not hasattr(obj, attr_name):
            return False
        # ``!=`` may return a nested or element-wise result, so reduce it
        # with ``_any`` rather than relying on plain truthiness.
        if _any(getattr(obj, attr_name) != expected_value):
            return False
    return True
| |
|
| |
|
def assert_dict_has_keys(obj: Dict[str, Any],
                         expected_keys: List[str]) -> bool:
    """Tell whether ``obj`` has all of the ``expected_keys``.

    Args:
        obj (Dict[str, Any]): Dict to be checked.
        expected_keys (List[str]): Keys expected to be among the keys of
            ``obj``.

    Returns:
        bool: ``True`` if no expected key is missing from ``obj``.
    """
    missing_keys = set(expected_keys) - set(obj.keys())
    return not missing_keys
| |
|
| |
|
def assert_keys_equal(result_keys: List[str], target_keys: List[str]) -> bool:
    """Tell whether two key collections hold the same set of keys.

    Order and duplicates are ignored because both inputs are compared as
    sets.

    Args:
        result_keys (List[str]): Result keys to be checked.
        target_keys (List[str]): Target keys to compare against.

    Returns:
        bool: ``True`` if both collections contain exactly the same keys.
    """
    # An empty symmetric difference means neither side has an extra key.
    mismatched_keys = set(result_keys) ^ set(target_keys)
    return not mismatched_keys
| |
|
| |
|
def assert_is_norm_layer(module) -> bool:
    """Tell whether ``module`` is a normalization layer.

    Args:
        module (nn.Module): The module to be checked.

    Returns:
        bool: ``True`` if ``module`` is an instance of ``_BatchNorm``,
        ``_InstanceNorm``, ``GroupNorm`` or ``LayerNorm``.
    """
    return isinstance(module,
                      (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))
| |
|
| |
|
def assert_params_all_zeros(module) -> bool:
    """Tell whether the weight and bias of ``module`` are all zeros.

    A module whose ``bias`` attribute is absent or ``None`` is judged on its
    weight alone.

    Args:
        module (nn.Module): The module to be checked. It must expose a
            ``weight`` parameter.

    Returns:
        bool: ``True`` if the weight, and the bias when present, are all
        close to zero.
    """

    def _is_all_zero(tensor) -> bool:
        # Compare against a zero tensor built from ``tensor`` itself.
        return tensor.allclose(tensor.new_zeros(tensor.size()))

    weight_is_zero = _is_all_zero(module.weight.data)

    bias = getattr(module, 'bias', None)
    bias_is_zero = True if bias is None else _is_all_zero(bias.data)

    return weight_is_zero and bias_is_zero
| |
|