# coding=utf-8 # Copyright 2023 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """A script to add and/or update the attribute `pipeline_model_mapping` in model test files. This script will be (mostly) used in the following 2 situations: - run within a (scheduled) CI job to: - check if model test files in the library have updated `pipeline_model_mapping`, - and/or update test files and (possibly) open a GitHub pull request automatically - being run by a `transformers` member to quickly check and update some particular test file(s) This script is **NOT** intended to be run (manually) by community contributors. """ import argparse import glob import inspect import os import re import unittest from get_test_info import get_test_classes from tests.test_pipeline_mixin import pipeline_test_mapping PIPELINE_TEST_MAPPING = {} for task, _ in pipeline_test_mapping.items(): PIPELINE_TEST_MAPPING[task] = {"pt": None, "tf": None} # DO **NOT** add item to this set (unless the reason is approved) TEST_FILE_TO_IGNORE = { "tests/models/esm/test_modeling_esmfold.py", # The pipeline test mapping is added to `test_modeling_esm.py` } def get_framework(test_class): """Infer the framework from the test class `test_class`.""" if "ModelTesterMixin" in [x.__name__ for x in test_class.__bases__]: return "pt" elif "TFModelTesterMixin" in [x.__name__ for x in test_class.__bases__]: return "tf" elif "FlaxModelTesterMixin" in [x.__name__ for x in test_class.__bases__]: return "flax" else: return None def get_mapping_for_task(task, framework): """Get mappings defined in `XXXPipelineTests` for the task `task`.""" # Use the cached results if PIPELINE_TEST_MAPPING[task].get(framework, None) is not None: return PIPELINE_TEST_MAPPING[task][framework] pipeline_test_class = pipeline_test_mapping[task]["test"] mapping = None if framework == "pt": mapping = getattr(pipeline_test_class, "model_mapping", None) elif framework == "tf": mapping = getattr(pipeline_test_class, "tf_model_mapping", None) if mapping is not None: mapping = dict(mapping.items()) # cache the results PIPELINE_TEST_MAPPING[task][framework] = mapping return mapping def get_model_for_pipeline_test(test_class, task): """Get the model architecture(s) related to the test class `test_class` for a pipeline `task`.""" framework = get_framework(test_class) if framework is None: return None mapping = get_mapping_for_task(task, framework) if mapping is None: return None config_classes = list({model_class.config_class for model_class in test_class.all_model_classes}) if len(config_classes) != 1: raise ValueError("There should be exactly one configuration class from `test_class.all_model_classes`.") # This could be a list/tuple of model classes, but it's rare. model_class = mapping.get(config_classes[0], None) if isinstance(model_class, (tuple, list)): model_class = sorted(model_class, key=lambda x: x.__name__) return model_class def get_pipeline_model_mapping(test_class): """Get `pipeline_model_mapping` for `test_class`.""" mapping = [(task, get_model_for_pipeline_test(test_class, task)) for task in pipeline_test_mapping] mapping = sorted([(task, model) for task, model in mapping if model is not None], key=lambda x: x[0]) return dict(mapping) def get_pipeline_model_mapping_string(test_class): """Get `pipeline_model_mapping` for `test_class` as a string (to be added to the test file). This will be a 1-line string. After this is added to a test file, `make style` will format it beautifully. """ framework = get_framework(test_class) if framework == "pt": framework = "torch" default_value = "{}" mapping = get_pipeline_model_mapping(test_class) if len(mapping) == 0: return "" texts = [] for task, model_classes in mapping.items(): if isinstance(model_classes, (tuple, list)): # A list/tuple of model classes value = "(" + ", ".join([x.__name__ for x in model_classes]) + ")" else: # A single model class value = model_classes.__name__ texts.append(f'"{task}": {value}') text = "{" + ", ".join(texts) + "}" text = f"pipeline_model_mapping = {text} if is_{framework}_available() else {default_value}" return text def is_valid_test_class(test_class): """Restrict to `XXXModelTesterMixin` and should be a subclass of `unittest.TestCase`.""" base_class_names = {"ModelTesterMixin", "TFModelTesterMixin", "FlaxModelTesterMixin"} if not issubclass(test_class, unittest.TestCase): return False return len(base_class_names.intersection([x.__name__ for x in test_class.__bases__])) > 0 def find_test_class(test_file): """Find a test class in `test_file` to which we will add `pipeline_model_mapping`.""" test_classes = [x for x in get_test_classes(test_file) if is_valid_test_class(x)] target_test_class = None for test_class in test_classes: # If a test class has defined `pipeline_model_mapping`, let's take it if getattr(test_class, "pipeline_model_mapping", None) is not None: target_test_class = test_class break # Take the test class with the shortest name (just a heuristic) if target_test_class is None and len(test_classes) > 0: target_test_class = sorted(test_classes, key=lambda x: (len(x.__name__), x.__name__))[0] return target_test_class def find_block_ending(lines, start_idx, indent_level): end_idx = start_idx for idx, line in enumerate(lines[start_idx:]): indent = len(line) - len(line.lstrip()) if idx == 0 or indent > indent_level or (indent == indent_level and line.strip() == ")"): end_idx = start_idx + idx elif idx > 0 and indent <= indent_level: # Outside the definition block of `pipeline_model_mapping` break return end_idx def add_pipeline_model_mapping(test_class, overwrite=False): """Add `pipeline_model_mapping` to `test_class`.""" if getattr(test_class, "pipeline_model_mapping", None) is not None: if not overwrite: return "", -1 line_to_add = get_pipeline_model_mapping_string(test_class) if len(line_to_add) == 0: return "", -1 line_to_add = line_to_add + "\n" # The code defined the class `test_class` class_lines, class_start_line_no = inspect.getsourcelines(test_class) # `inspect` gives the code for an object, including decorator(s) if any. # We (only) need the exact line of the class definition. for idx, line in enumerate(class_lines): if line.lstrip().startswith("class "): class_lines = class_lines[idx:] class_start_line_no += idx break class_end_line_no = class_start_line_no + len(class_lines) - 1 # The index in `class_lines` that starts the definition of `all_model_classes`, `all_generative_model_classes` or # `pipeline_model_mapping`. This assumes they are defined in such order, and we take the start index of the last # block that appears in a `test_class`. start_idx = None # The indent level of the line at `class_lines[start_idx]` (if defined) indent_level = 0 # To record if `pipeline_model_mapping` is found in `test_class`. def_line = None for idx, line in enumerate(class_lines): if line.strip().startswith("all_model_classes = "): indent_level = len(line) - len(line.lstrip()) start_idx = idx elif line.strip().startswith("all_generative_model_classes = "): indent_level = len(line) - len(line.lstrip()) start_idx = idx elif line.strip().startswith("pipeline_model_mapping = "): indent_level = len(line) - len(line.lstrip()) start_idx = idx def_line = line break if start_idx is None: return "", -1 # Find the ending index (inclusive) of the above found block. end_idx = find_block_ending(class_lines, start_idx, indent_level) # Extract `is_xxx_available()` from existing blocks: some models require specific libraries like `timm` and use # `is_timm_available()` instead of `is_torch_available()`. # Keep leading and trailing whitespaces r = re.compile(r"\s(is_\S+?_available\(\))\s") for line in class_lines[start_idx : end_idx + 1]: backend_condition = r.search(line) if backend_condition is not None: # replace the leading and trailing whitespaces to the space character " ". target = " " + backend_condition[0][1:-1] + " " line_to_add = r.sub(target, line_to_add) break if def_line is None: # `pipeline_model_mapping` is not defined. The target index is set to the ending index (inclusive) of # `all_model_classes` or `all_generative_model_classes`. target_idx = end_idx else: # `pipeline_model_mapping` is defined. The target index is set to be one **BEFORE** its start index. target_idx = start_idx - 1 # mark the lines of the currently existing `pipeline_model_mapping` to be removed. for idx in range(start_idx, end_idx + 1): # These lines are going to be removed before writing to the test file. class_lines[idx] = None # noqa # Make sure the test class is a subclass of `PipelineTesterMixin`. parent_classes = [x.__name__ for x in test_class.__bases__] if "PipelineTesterMixin" not in parent_classes: # Put `PipelineTesterMixin` just before `unittest.TestCase` _parent_classes = [x for x in parent_classes if x != "TestCase"] + ["PipelineTesterMixin"] if "TestCase" in parent_classes: # Here we **assume** the original string is always with `unittest.TestCase`. _parent_classes.append("unittest.TestCase") parent_classes = ", ".join(_parent_classes) for idx, line in enumerate(class_lines): # Find the ending of the declaration of `test_class` if line.strip().endswith("):"): # mark the lines of the declaration of `test_class` to be removed for _idx in range(idx + 1): class_lines[_idx] = None # noqa break # Add the new, one-line, class declaration for `test_class` class_lines[0] = f"class {test_class.__name__}({parent_classes}):\n" # Add indentation line_to_add = " " * indent_level + line_to_add # Insert `pipeline_model_mapping` to `class_lines`. # (The line at `target_idx` should be kept by definition!) class_lines = class_lines[: target_idx + 1] + [line_to_add] + class_lines[target_idx + 1 :] # Remove the lines that are marked to be removed class_lines = [x for x in class_lines if x is not None] # Move from test class to module (in order to write to the test file) module_lines = inspect.getsourcelines(inspect.getmodule(test_class))[0] # Be careful with the 1-off between line numbers and array indices module_lines = module_lines[: class_start_line_no - 1] + class_lines + module_lines[class_end_line_no:] code = "".join(module_lines) moddule_file = inspect.getsourcefile(test_class) with open(moddule_file, "w", encoding="UTF-8", newline="\n") as fp: fp.write(code) return line_to_add def add_pipeline_model_mapping_to_test_file(test_file, overwrite=False): """Add `pipeline_model_mapping` to `test_file`.""" test_class = find_test_class(test_file) if test_class: add_pipeline_model_mapping(test_class, overwrite=overwrite) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--test_file", type=str, help="A path to the test file, starting with the repository's `tests` directory." ) parser.add_argument( "--all", action="store_true", help="If to check and modify all test files.", ) parser.add_argument( "--overwrite", action="store_true", help="If to overwrite a test class if it has already defined `pipeline_model_mapping`.", ) args = parser.parse_args() if not args.all and not args.test_file: raise ValueError("Please specify either `test_file` or pass `--all` to check/modify all test files.") elif args.all and args.test_file: raise ValueError("Only one of `--test_file` and `--all` could be specified.") test_files = [] if args.test_file: test_files = [args.test_file] else: pattern = os.path.join("tests", "models", "**", "test_modeling_*.py") for test_file in glob.glob(pattern): # `Flax` is not concerned at this moment if not test_file.startswith("test_modeling_flax_"): test_files.append(test_file) for test_file in test_files: if test_file in TEST_FILE_TO_IGNORE: print(f"[SKIPPED] {test_file} is skipped as it is in `TEST_FILE_TO_IGNORE` in the file {__file__}.") continue add_pipeline_model_mapping_to_test_file(test_file, overwrite=args.overwrite)