# Copyright 2022 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect import os from argparse import ArgumentParser, Namespace from importlib import import_module import huggingface_hub import numpy as np from packaging import version from .. import ( FEATURE_EXTRACTOR_MAPPING, IMAGE_PROCESSOR_MAPPING, PROCESSOR_MAPPING, TOKENIZER_MAPPING, AutoConfig, AutoFeatureExtractor, AutoImageProcessor, AutoProcessor, AutoTokenizer, is_datasets_available, is_tf_available, is_torch_available, ) from ..utils import TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, logging from . import BaseTransformersCLICommand if is_tf_available(): import tensorflow as tf tf.config.experimental.enable_tensor_float_32_execution(False) if is_torch_available(): import torch if is_datasets_available(): from datasets import load_dataset MAX_ERROR = 5e-5 # larger error tolerance than in our internal tests, to avoid flaky user-facing errors def convert_command_factory(args: Namespace): """ Factory function used to convert a model PyTorch checkpoint in a TensorFlow 2 checkpoint. Returns: ServeCommand """ return PTtoTFCommand( args.model_name, args.local_dir, args.max_error, args.new_weights, args.no_pr, args.push, args.extra_commit_description, args.override_model_class, ) class PTtoTFCommand(BaseTransformersCLICommand): @staticmethod def register_subcommand(parser: ArgumentParser): """ Register this command to argparse so it's available for the transformer-cli Args: parser: Root parser to register command-specific arguments """ train_parser = parser.add_parser( "pt-to-tf", help=( "CLI tool to run convert a transformers model from a PyTorch checkpoint to a TensorFlow checkpoint." " Can also be used to validate existing weights without opening PRs, with --no-pr." ), ) train_parser.add_argument( "--model-name", type=str, required=True, help="The model name, including owner/organization, as seen on the hub.", ) train_parser.add_argument( "--local-dir", type=str, default="", help="Optional local directory of the model repository. Defaults to /tmp/{model_name}", ) train_parser.add_argument( "--max-error", type=float, default=MAX_ERROR, help=( f"Maximum error tolerance. Defaults to {MAX_ERROR}. This flag should be avoided, use at your own risk." ), ) train_parser.add_argument( "--new-weights", action="store_true", help="Optional flag to create new TensorFlow weights, even if they already exist.", ) train_parser.add_argument( "--no-pr", action="store_true", help="Optional flag to NOT open a PR with converted weights." ) train_parser.add_argument( "--push", action="store_true", help="Optional flag to push the weights directly to `main` (requires permissions)", ) train_parser.add_argument( "--extra-commit-description", type=str, default="", help="Optional additional commit description to use when opening a PR (e.g. to tag the owner).", ) train_parser.add_argument( "--override-model-class", type=str, default=None, help="If you think you know better than the auto-detector, you can specify the model class here. " "Can be either an AutoModel class or a specific model class like BertForSequenceClassification.", ) train_parser.set_defaults(func=convert_command_factory) @staticmethod def find_pt_tf_differences(pt_outputs, tf_outputs): """ Compares the TensorFlow and PyTorch outputs, returning a dictionary with all tensor differences. """ # 1. All output attributes must be the same pt_out_attrs = set(pt_outputs.keys()) tf_out_attrs = set(tf_outputs.keys()) if pt_out_attrs != tf_out_attrs: raise ValueError( f"The model outputs have different attributes, aborting. (Pytorch: {pt_out_attrs}, TensorFlow:" f" {tf_out_attrs})" ) # 2. For each output attribute, computes the difference def _find_pt_tf_differences(pt_out, tf_out, differences, attr_name=""): # If the current attribute is a tensor, it is a leaf and we make the comparison. Otherwise, we will dig in # recursivelly, keeping the name of the attribute. if isinstance(pt_out, torch.Tensor): tensor_difference = np.max(np.abs(pt_out.numpy() - tf_out.numpy())) differences[attr_name] = tensor_difference else: root_name = attr_name for i, pt_item in enumerate(pt_out): # If it is a named attribute, we keep the name. Otherwise, just its index. if isinstance(pt_item, str): branch_name = root_name + pt_item tf_item = tf_out[pt_item] pt_item = pt_out[pt_item] else: branch_name = root_name + f"[{i}]" tf_item = tf_out[i] differences = _find_pt_tf_differences(pt_item, tf_item, differences, branch_name) return differences return _find_pt_tf_differences(pt_outputs, tf_outputs, {}) def __init__( self, model_name: str, local_dir: str, max_error: float, new_weights: bool, no_pr: bool, push: bool, extra_commit_description: str, override_model_class: str, *args, ): self._logger = logging.get_logger("transformers-cli/pt_to_tf") self._model_name = model_name self._local_dir = local_dir if local_dir else os.path.join("/tmp", model_name) self._max_error = max_error self._new_weights = new_weights self._no_pr = no_pr self._push = push self._extra_commit_description = extra_commit_description self._override_model_class = override_model_class def get_inputs(self, pt_model, tf_dummy_inputs, config): """ Returns the right inputs for the model, based on its signature. """ def _get_audio_input(): ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") speech_samples = ds.sort("id").select(range(2))[:2]["audio"] raw_samples = [x["array"] for x in speech_samples] return raw_samples model_config_class = type(pt_model.config) if model_config_class in PROCESSOR_MAPPING: processor = AutoProcessor.from_pretrained(self._local_dir) if model_config_class in TOKENIZER_MAPPING and processor.tokenizer.pad_token is None: processor.tokenizer.pad_token = processor.tokenizer.eos_token elif model_config_class in IMAGE_PROCESSOR_MAPPING: processor = AutoImageProcessor.from_pretrained(self._local_dir) elif model_config_class in FEATURE_EXTRACTOR_MAPPING: processor = AutoFeatureExtractor.from_pretrained(self._local_dir) elif model_config_class in TOKENIZER_MAPPING: processor = AutoTokenizer.from_pretrained(self._local_dir) if processor.pad_token is None: processor.pad_token = processor.eos_token else: raise ValueError(f"Unknown data processing type (model config type: {model_config_class})") model_forward_signature = set(inspect.signature(pt_model.forward).parameters.keys()) processor_inputs = {} if "input_ids" in model_forward_signature: processor_inputs.update( { "text": ["Hi there!", "I am a batch with more than one row and different input lengths."], "padding": True, "truncation": True, } ) if "pixel_values" in model_forward_signature: sample_images = load_dataset("cifar10", "plain_text", split="test")[:2]["img"] processor_inputs.update({"images": sample_images}) if "input_features" in model_forward_signature: feature_extractor_signature = inspect.signature(processor.feature_extractor).parameters # Pad to the largest input length by default but take feature extractor default # padding value if it exists e.g. "max_length" and is not False or None if "padding" in feature_extractor_signature: default_strategy = feature_extractor_signature["padding"].default if default_strategy is not False and default_strategy is not None: padding_strategy = default_strategy else: padding_strategy = True else: padding_strategy = True processor_inputs.update({"audio": _get_audio_input(), "padding": padding_strategy}) if "input_values" in model_forward_signature: # Wav2Vec2 audio input processor_inputs.update({"audio": _get_audio_input(), "padding": True}) pt_input = processor(**processor_inputs, return_tensors="pt") tf_input = processor(**processor_inputs, return_tensors="tf") # Extra input requirements, in addition to the input modality if ( config.is_encoder_decoder or (hasattr(pt_model, "encoder") and hasattr(pt_model, "decoder")) or "decoder_input_ids" in tf_dummy_inputs ): decoder_input_ids = np.asarray([[1], [1]], dtype=int) * (pt_model.config.decoder_start_token_id or 0) pt_input.update({"decoder_input_ids": torch.tensor(decoder_input_ids)}) tf_input.update({"decoder_input_ids": tf.convert_to_tensor(decoder_input_ids)}) return pt_input, tf_input def run(self): # hub version 0.9.0 introduced the possibility of programmatically opening PRs with normal write tokens. if version.parse(huggingface_hub.__version__) < version.parse("0.9.0"): raise ImportError( "The huggingface_hub version must be >= 0.9.0 to use this command. Please update your huggingface_hub" " installation." ) else: from huggingface_hub import Repository, create_commit from huggingface_hub._commit_api import CommitOperationAdd # Fetch remote data repo = Repository(local_dir=self._local_dir, clone_from=self._model_name) # Load config and get the appropriate architecture -- the latter is needed to convert the head's weights config = AutoConfig.from_pretrained(self._local_dir) architectures = config.architectures if self._override_model_class is not None: if self._override_model_class.startswith("TF"): architectures = [self._override_model_class[2:]] else: architectures = [self._override_model_class] try: pt_class = getattr(import_module("transformers"), architectures[0]) except AttributeError: raise ValueError(f"Model class {self._override_model_class} not found in transformers.") try: tf_class = getattr(import_module("transformers"), "TF" + architectures[0]) except AttributeError: raise ValueError(f"TF model class TF{self._override_model_class} not found in transformers.") elif architectures is None: # No architecture defined -- use auto classes pt_class = getattr(import_module("transformers"), "AutoModel") tf_class = getattr(import_module("transformers"), "TFAutoModel") self._logger.warning("No detected architecture, using AutoModel/TFAutoModel") else: # Architecture defined -- use it if len(architectures) > 1: raise ValueError(f"More than one architecture was found, aborting. (architectures = {architectures})") self._logger.warning(f"Detected architecture: {architectures[0]}") pt_class = getattr(import_module("transformers"), architectures[0]) try: tf_class = getattr(import_module("transformers"), "TF" + architectures[0]) except AttributeError: raise AttributeError(f"The TensorFlow equivalent of {architectures[0]} doesn't exist in transformers.") # Check the TF dummy inputs to see what keys we need in the forward pass tf_from_pt_model = tf_class.from_config(config) tf_dummy_inputs = tf_from_pt_model.dummy_inputs del tf_from_pt_model # Try to keep only one model in memory at a time # Load the model and get some basic inputs pt_model = pt_class.from_pretrained(self._local_dir) pt_model.eval() pt_input, tf_input = self.get_inputs(pt_model, tf_dummy_inputs, config) with torch.no_grad(): pt_outputs = pt_model(**pt_input, output_hidden_states=True) del pt_model # will no longer be used, and may have a large memory footprint tf_from_pt_model = tf_class.from_pretrained(self._local_dir, from_pt=True) tf_from_pt_outputs = tf_from_pt_model(**tf_input, output_hidden_states=True, training=False) # Confirms that cross loading PT weights into TF worked. crossload_differences = self.find_pt_tf_differences(pt_outputs, tf_from_pt_outputs) output_differences = {k: v for k, v in crossload_differences.items() if "hidden" not in k} hidden_differences = {k: v for k, v in crossload_differences.items() if "hidden" in k} if len(output_differences) == 0 and architectures is not None: raise ValueError( f"Something went wrong -- the config file has architectures ({architectures}), but no model head" " output was found. All outputs start with 'hidden'" ) max_crossload_output_diff = max(output_differences.values()) if output_differences else 0.0 max_crossload_hidden_diff = max(hidden_differences.values()) if max_crossload_output_diff > self._max_error or max_crossload_hidden_diff > self._max_error: raise ValueError( "The cross-loaded TensorFlow model has different outputs, something went wrong!\n" + f"\nList of maximum output differences above the threshold ({self._max_error}):\n" + "\n".join([f"{k}: {v:.3e}" for k, v in output_differences.items() if v > self._max_error]) + f"\n\nList of maximum hidden layer differences above the threshold ({self._max_error}):\n" + "\n".join([f"{k}: {v:.3e}" for k, v in hidden_differences.items() if v > self._max_error]) ) # Save the weights in a TF format (if needed) and confirms that the results are still good tf_weights_path = os.path.join(self._local_dir, TF2_WEIGHTS_NAME) tf_weights_index_path = os.path.join(self._local_dir, TF2_WEIGHTS_INDEX_NAME) if (not os.path.exists(tf_weights_path) and not os.path.exists(tf_weights_index_path)) or self._new_weights: tf_from_pt_model.save_pretrained(self._local_dir) del tf_from_pt_model # will no longer be used, and may have a large memory footprint tf_model = tf_class.from_pretrained(self._local_dir) tf_outputs = tf_model(**tf_input, output_hidden_states=True) conversion_differences = self.find_pt_tf_differences(pt_outputs, tf_outputs) output_differences = {k: v for k, v in conversion_differences.items() if "hidden" not in k} hidden_differences = {k: v for k, v in conversion_differences.items() if "hidden" in k} if len(output_differences) == 0 and architectures is not None: raise ValueError( f"Something went wrong -- the config file has architectures ({architectures}), but no model head" " output was found. All outputs start with 'hidden'" ) max_conversion_output_diff = max(output_differences.values()) if output_differences else 0.0 max_conversion_hidden_diff = max(hidden_differences.values()) if max_conversion_output_diff > self._max_error or max_conversion_hidden_diff > self._max_error: raise ValueError( "The converted TensorFlow model has different outputs, something went wrong!\n" + f"\nList of maximum output differences above the threshold ({self._max_error}):\n" + "\n".join([f"{k}: {v:.3e}" for k, v in output_differences.items() if v > self._max_error]) + f"\n\nList of maximum hidden layer differences above the threshold ({self._max_error}):\n" + "\n".join([f"{k}: {v:.3e}" for k, v in hidden_differences.items() if v > self._max_error]) ) commit_message = "Update TF weights" if self._new_weights else "Add TF weights" if self._push: repo.git_add(auto_lfs_track=True) repo.git_commit(commit_message) repo.git_push(blocking=True) # this prints a progress bar with the upload self._logger.warning(f"TF weights pushed into {self._model_name}") elif not self._no_pr: self._logger.warning("Uploading the weights into a new PR...") commit_descrition = ( "Model converted by the [`transformers`' `pt_to_tf`" " CLI](https://github.com/huggingface/transformers/blob/main/src/transformers/commands/pt_to_tf.py). " "All converted model outputs and hidden layers were validated against its PyTorch counterpart.\n\n" f"Maximum crossload output difference={max_crossload_output_diff:.3e}; " f"Maximum crossload hidden layer difference={max_crossload_hidden_diff:.3e};\n" f"Maximum conversion output difference={max_conversion_output_diff:.3e}; " f"Maximum conversion hidden layer difference={max_conversion_hidden_diff:.3e};\n" ) if self._max_error > MAX_ERROR: commit_descrition += ( f"\n\nCAUTION: The maximum admissible error was manually increased to {self._max_error}!" ) if self._extra_commit_description: commit_descrition += "\n\n" + self._extra_commit_description # sharded model -> adds all related files (index and .h5 shards) if os.path.exists(tf_weights_index_path): operations = [ CommitOperationAdd(path_in_repo=TF2_WEIGHTS_INDEX_NAME, path_or_fileobj=tf_weights_index_path) ] for shard_path in tf.io.gfile.glob(self._local_dir + "/tf_model-*.h5"): operations += [ CommitOperationAdd(path_in_repo=os.path.basename(shard_path), path_or_fileobj=shard_path) ] else: operations = [CommitOperationAdd(path_in_repo=TF2_WEIGHTS_NAME, path_or_fileobj=tf_weights_path)] hub_pr_url = create_commit( repo_id=self._model_name, operations=operations, commit_message=commit_message, commit_description=commit_descrition, repo_type="model", create_pr=True, ).pr_url self._logger.warning(f"PR open in {hub_pr_url}")