# coding=utf-8 # Copyright 2023 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import importlib import os import sys # This is required to make the module import works (when the python process is running from the root of the repo) sys.path.append(".") r""" The argument `test_file` in this file refers to a model test file. This should be a string of the from `tests/models/*/test_modeling_*.py`. """ def get_module_path(test_file): """Return the module path of a model test file.""" components = test_file.split(os.path.sep) if components[0:2] != ["tests", "models"]: raise ValueError( "`test_file` should start with `tests/models/` (with `/` being the OS specific path separator). Got " f"{test_file} instead." ) test_fn = components[-1] if not test_fn.endswith("py"): raise ValueError(f"`test_file` should be a python file. Got {test_fn} instead.") if not test_fn.startswith("test_modeling_"): raise ValueError( f"`test_file` should point to a file name of the form `test_modeling_*.py`. Got {test_fn} instead." ) components = components[:-1] + [test_fn.replace(".py", "")] test_module_path = ".".join(components) return test_module_path def get_test_module(test_file): """Get the module of a model test file.""" test_module_path = get_module_path(test_file) test_module = importlib.import_module(test_module_path) return test_module def get_tester_classes(test_file): """Get all classes in a model test file whose names ends with `ModelTester`.""" tester_classes = [] test_module = get_test_module(test_file) for attr in dir(test_module): if attr.endswith("ModelTester"): tester_classes.append(getattr(test_module, attr)) # sort with class names return sorted(tester_classes, key=lambda x: x.__name__) def get_test_classes(test_file): """Get all [test] classes in a model test file with attribute `all_model_classes` that are non-empty. These are usually the (model) test classes containing the (non-slow) tests to run and are subclasses of one of the classes `ModelTesterMixin`, `TFModelTesterMixin` or `FlaxModelTesterMixin`, as well as a subclass of `unittest.TestCase`. Exceptions include `RagTestMixin` (and its subclasses). """ test_classes = [] test_module = get_test_module(test_file) for attr in dir(test_module): attr_value = getattr(test_module, attr) # (TF/Flax)ModelTesterMixin is also an attribute in specific model test module. Let's exclude them by checking # `all_model_classes` is not empty (which also excludes other special classes). model_classes = getattr(attr_value, "all_model_classes", []) if len(model_classes) > 0: test_classes.append(attr_value) # sort with class names return sorted(test_classes, key=lambda x: x.__name__) def get_model_classes(test_file): """Get all model classes that appear in `all_model_classes` attributes in a model test file.""" test_classes = get_test_classes(test_file) model_classes = set() for test_class in test_classes: model_classes.update(test_class.all_model_classes) # sort with class names return sorted(model_classes, key=lambda x: x.__name__) def get_model_tester_from_test_class(test_class): """Get the model tester class of a model test class.""" test = test_class() if hasattr(test, "setUp"): test.setUp() model_tester = None if hasattr(test, "model_tester"): # `(TF/Flax)ModelTesterMixin` has this attribute default to `None`. Let's skip this case. if test.model_tester is not None: model_tester = test.model_tester.__class__ return model_tester def get_test_classes_for_model(test_file, model_class): """Get all [test] classes in `test_file` that have `model_class` in their `all_model_classes`.""" test_classes = get_test_classes(test_file) target_test_classes = [] for test_class in test_classes: if model_class in test_class.all_model_classes: target_test_classes.append(test_class) # sort with class names return sorted(target_test_classes, key=lambda x: x.__name__) def get_tester_classes_for_model(test_file, model_class): """Get all model tester classes in `test_file` that are associated to `model_class`.""" test_classes = get_test_classes_for_model(test_file, model_class) tester_classes = [] for test_class in test_classes: tester_class = get_model_tester_from_test_class(test_class) if tester_class is not None: tester_classes.append(tester_class) # sort with class names return sorted(tester_classes, key=lambda x: x.__name__) def get_test_to_tester_mapping(test_file): """Get a mapping from [test] classes to model tester classes in `test_file`. This uses `get_test_classes` which may return classes that are NOT subclasses of `unittest.TestCase`. """ test_classes = get_test_classes(test_file) test_tester_mapping = {test_class: get_model_tester_from_test_class(test_class) for test_class in test_classes} return test_tester_mapping def get_model_to_test_mapping(test_file): """Get a mapping from model classes to test classes in `test_file`.""" model_classes = get_model_classes(test_file) model_test_mapping = { model_class: get_test_classes_for_model(test_file, model_class) for model_class in model_classes } return model_test_mapping def get_model_to_tester_mapping(test_file): """Get a mapping from model classes to model tester classes in `test_file`.""" model_classes = get_model_classes(test_file) model_to_tester_mapping = { model_class: get_tester_classes_for_model(test_file, model_class) for model_class in model_classes } return model_to_tester_mapping def to_json(o): """Make the information succinct and easy to read. Avoid the full class representation like `` when displaying the results. Instead, we use class name (`BertForMaskedLM`) for the readability. """ if isinstance(o, str): return o elif isinstance(o, type): return o.__name__ elif isinstance(o, (list, tuple)): return [to_json(x) for x in o] elif isinstance(o, dict): return {to_json(k): to_json(v) for k, v in o.items()} else: return o