# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import shutil
import warnings
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import List

from ..utils import logging
from . import BaseTransformersCLICommand


try:
    from cookiecutter.main import cookiecutter

    _has_cookiecutter = True
except ImportError:
    _has_cookiecutter = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def add_new_model_command_factory(args: Namespace):
    return AddNewModelCommand(args.testing, args.testing_file, path=args.path)


class AddNewModelCommand(BaseTransformersCLICommand):
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        add_new_model_parser = parser.add_parser("add-new-model")
        add_new_model_parser.add_argument("--testing", action="store_true", help="If in testing mode.")
        add_new_model_parser.add_argument("--testing_file", type=str, help="Configuration file on which to run.")
        add_new_model_parser.add_argument(
            "--path", type=str, help="Path to cookiecutter. Should only be used for testing purposes."
        )
        add_new_model_parser.set_defaults(func=add_new_model_command_factory)

    def __init__(self, testing: bool, testing_file: str, path=None, *args):
        self._testing = testing
        self._testing_file = testing_file
        self._path = path

    def run(self):
        warnings.warn(
            "The command `transformers-cli add-new-model` is deprecated and will be removed in v5 of Transformers. "
            "It is not actively maintained anymore, so might give a result that won't pass all tests and quality "
            "checks, you should use `transformers-cli add-new-model-like` instead."
        )
        if not _has_cookiecutter:
            raise ImportError(
                "Model creation dependencies are required to use the `add_new_model` command. Install them by running "
                "the following at the root of your `transformers` clone:\n\n\t$ pip install -e .[modelcreation]\n"
            )

        # Ensure that there is no other `cookiecutter-template-xxx` directory in the current working directory
        directories = [directory for directory in os.listdir() if "cookiecutter-template-" == directory[:22]]
        if len(directories) > 0:
            raise ValueError(
                "Several directories starting with `cookiecutter-template-` in current working directory. "
                "Please clean your directory by removing all folders starting with `cookiecutter-template-` or "
                "change your working directory."
            )
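
        # Resolve the root of the `transformers` checkout; when `--path` is given (tests only), derive it from that path instead.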
        path_to_transformer_root = (
            Path(__file__).parent.parent.parent.parent if self._path is None else Path(self._path).parent.parent
        )
        path_to_cookiecutter = path_to_transformer_root / "templates" / "adding_a_new_model"

        # Execute cookiecutter
        if not self._testing:
            cookiecutter(str(path_to_cookiecutter))
        else:
            with open(self._testing_file, "r") as configuration_file:
                testing_configuration = json.load(configuration_file)

            cookiecutter(
                str(path_to_cookiecutter if self._path is None else self._path),
                no_input=True,
                extra_context=testing_configuration,
            )

        directory = [directory for directory in os.listdir() if "cookiecutter-template-" in directory[:22]][0]

        # Retrieve configuration
        with open(directory + "/configuration.json", "r") as configuration_file:
            configuration = json.load(configuration_file)

        lowercase_model_name = configuration["lowercase_modelname"]
        generate_tensorflow_pytorch_and_flax = configuration["generate_tensorflow_pytorch_and_flax"]
        os.remove(f"{directory}/configuration.json")

        output_pytorch = "PyTorch" in generate_tensorflow_pytorch_and_flax
        output_tensorflow = "TensorFlow" in generate_tensorflow_pytorch_and_flax
        output_flax = "Flax" in generate_tensorflow_pytorch_and_flax

        model_dir = f"{path_to_transformer_root}/src/transformers/models/{lowercase_model_name}"
        os.makedirs(model_dir, exist_ok=True)
        os.makedirs(f"{path_to_transformer_root}/tests/models/{lowercase_model_name}", exist_ok=True)

        # Tests require submodules as they have parent imports
        with open(f"{path_to_transformer_root}/tests/models/{lowercase_model_name}/__init__.py", "w"):
            pass

        shutil.move(
            f"{directory}/__init__.py",
            f"{model_dir}/__init__.py",
        )
        shutil.move(
            f"{directory}/configuration_{lowercase_model_name}.py",
            f"{model_dir}/configuration_{lowercase_model_name}.py",
        )
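
        # Helper that strips `# Copied from transformers.` marker lines from a generated file before it is moved into place.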
        def remove_copy_lines(path):
            with open(path, "r") as f:
                lines = f.readlines()
            with open(path, "w") as f:
                for line in lines:
                    if "# Copied from transformers." not in line:
                        f.write(line)

        if output_pytorch:
            if not self._testing:
                remove_copy_lines(f"{directory}/modeling_{lowercase_model_name}.py")

            shutil.move(
                f"{directory}/modeling_{lowercase_model_name}.py",
                f"{model_dir}/modeling_{lowercase_model_name}.py",
            )

            shutil.move(
                f"{directory}/test_modeling_{lowercase_model_name}.py",
                f"{path_to_transformer_root}/tests/models/{lowercase_model_name}/test_modeling_{lowercase_model_name}.py",
            )
        else:
            os.remove(f"{directory}/modeling_{lowercase_model_name}.py")
            os.remove(f"{directory}/test_modeling_{lowercase_model_name}.py")

        if output_tensorflow:
            if not self._testing:
                remove_copy_lines(f"{directory}/modeling_tf_{lowercase_model_name}.py")

            shutil.move(
                f"{directory}/modeling_tf_{lowercase_model_name}.py",
                f"{model_dir}/modeling_tf_{lowercase_model_name}.py",
            )

            shutil.move(
                f"{directory}/test_modeling_tf_{lowercase_model_name}.py",
                f"{path_to_transformer_root}/tests/models/{lowercase_model_name}/test_modeling_tf_{lowercase_model_name}.py",
            )
        else:
            os.remove(f"{directory}/modeling_tf_{lowercase_model_name}.py")
            os.remove(f"{directory}/test_modeling_tf_{lowercase_model_name}.py")

        if output_flax:
            if not self._testing:
                remove_copy_lines(f"{directory}/modeling_flax_{lowercase_model_name}.py")

            shutil.move(
                f"{directory}/modeling_flax_{lowercase_model_name}.py",
                f"{model_dir}/modeling_flax_{lowercase_model_name}.py",
            )

            shutil.move(
                f"{directory}/test_modeling_flax_{lowercase_model_name}.py",
                f"{path_to_transformer_root}/tests/models/{lowercase_model_name}/test_modeling_flax_{lowercase_model_name}.py",
            )
        else:
            os.remove(f"{directory}/modeling_flax_{lowercase_model_name}.py")
            os.remove(f"{directory}/test_modeling_flax_{lowercase_model_name}.py")

        shutil.move(
            f"{directory}/{lowercase_model_name}.md",
            f"{path_to_transformer_root}/docs/source/en/model_doc/{lowercase_model_name}.md",
        )

        shutil.move(
            f"{directory}/tokenization_{lowercase_model_name}.py",
            f"{model_dir}/tokenization_{lowercase_model_name}.py",
        )

        shutil.move(
            f"{directory}/tokenization_fast_{lowercase_model_name}.py",
            f"{model_dir}/tokenization_{lowercase_model_name}_fast.py",
        )

        from os import fdopen, remove
        from shutil import copymode, move
        from tempfile import mkstemp

        def replace(original_file: str, line_to_copy_below: str, lines_to_copy: List[str]):
            # Create temp file
            fh, abs_path = mkstemp()
            line_found = False
            with fdopen(fh, "w") as new_file:
                with open(original_file) as old_file:
                    for line in old_file:
                        new_file.write(line)
                        if line_to_copy_below in line:
                            line_found = True
                            for line_to_copy in lines_to_copy:
                                new_file.write(line_to_copy)

            if not line_found:
                raise ValueError(f"Line {line_to_copy_below} was not found in file.")

            # Copy the file permissions from the old file to the new file
            copymode(original_file, abs_path)
            # Remove original file
            remove(original_file)
            # Move new file
            move(abs_path, original_file)

        def skip_units(line):
            return (
                ("generating PyTorch" in line and not output_pytorch)
                or ("generating TensorFlow" in line and not output_tensorflow)
                or ("generating Flax" in line and not output_flax)
            )
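
        # Parse the generated `to_replace_*.py` data file: snippets are delimited by `# To replace in:`, `# Below:` and
        # `# End.` markers, and each snippet is spliced below its target line in the existing source files via `replace`.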
in line and "##" not in line: if not skip_file and not skip_snippet: replace(file_to_replace_in, line_to_copy_below, lines_to_copy) lines_to_copy = [] elif "# Replace with" in line and "##" not in line: lines_to_copy = [] elif "##" not in line: lines_to_copy.append(line) remove(path_to_datafile) replace_in_files(f"{directory}/to_replace_{lowercase_model_name}.py") os.rmdir(directory)