# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Convert pytorch checkpoints to TensorFlow""" | |

import argparse
import os

from . import (
    ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    BART_PRETRAINED_MODEL_ARCHIVE_LIST,
    BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
    DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
    DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
    DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
    ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
    FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
    LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
    LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
    T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
    TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
    WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
    XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
    XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
    XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
    AlbertConfig,
    BartConfig,
    BertConfig,
    CamembertConfig,
    CTRLConfig,
    DistilBertConfig,
    DPRConfig,
    ElectraConfig,
    FlaubertConfig,
    GPT2Config,
    LayoutLMConfig,
    LxmertConfig,
    OpenAIGPTConfig,
    RobertaConfig,
    T5Config,
    TFAlbertForPreTraining,
    TFBartForConditionalGeneration,
    TFBertForPreTraining,
    TFBertForQuestionAnswering,
    TFBertForSequenceClassification,
    TFCamembertForMaskedLM,
    TFCTRLLMHeadModel,
    TFDistilBertForMaskedLM,
    TFDistilBertForQuestionAnswering,
    TFDPRContextEncoder,
    TFDPRQuestionEncoder,
    TFDPRReader,
    TFElectraForPreTraining,
    TFFlaubertWithLMHeadModel,
    TFGPT2LMHeadModel,
    TFLayoutLMForMaskedLM,
    TFLxmertForPreTraining,
    TFLxmertVisualFeatureEncoder,
    TFOpenAIGPTLMHeadModel,
    TFRobertaForMaskedLM,
    TFRobertaForSequenceClassification,
    TFT5ForConditionalGeneration,
    TFTransfoXLLMHeadModel,
    TFWav2Vec2Model,
    TFXLMRobertaForMaskedLM,
    TFXLMWithLMHeadModel,
    TFXLNetLMHeadModel,
    TransfoXLConfig,
    Wav2Vec2Config,
    XLMConfig,
    XLMRobertaConfig,
    XLNetConfig,
    is_torch_available,
    load_pytorch_checkpoint_in_tf2_model,
)
from .utils import CONFIG_NAME, WEIGHTS_NAME, cached_file, logging

if is_torch_available():
    import numpy as np
    import torch

    # PyTorch model classes are only imported when torch is available; Wav2Vec2Model
    # is moved under this guard to match the other PyTorch classes.
    from . import (
        AlbertForPreTraining,
        BartForConditionalGeneration,
        BertForPreTraining,
        BertForQuestionAnswering,
        BertForSequenceClassification,
        CamembertForMaskedLM,
        CTRLLMHeadModel,
        DistilBertForMaskedLM,
        DistilBertForQuestionAnswering,
        DPRContextEncoder,
        DPRQuestionEncoder,
        DPRReader,
        ElectraForPreTraining,
        FlaubertWithLMHeadModel,
        GPT2LMHeadModel,
        LayoutLMForMaskedLM,
        LxmertForPreTraining,
        LxmertVisualFeatureEncoder,
        OpenAIGPTLMHeadModel,
        RobertaForMaskedLM,
        RobertaForSequenceClassification,
        T5ForConditionalGeneration,
        TransfoXLLMHeadModel,
        Wav2Vec2Model,
        XLMRobertaForMaskedLM,
        XLMWithLMHeadModel,
        XLNetLMHeadModel,
    )

logging.set_verbosity_info()

# Each entry maps a model type to a 4-tuple:
#   (config class, TF model class, PyTorch model class, archive map/list of shortcut names).
# Note: the original table packed several TF/PyTorch heads into single entries for BART,
# RoBERTa and DPR, which did not match the 4-way unpacking used below. Here every entry
# is normalized to one TF/PyTorch class pair, and DPR is split into one entry per
# component (the key names for the DPR components are an editorial choice).
MODEL_CLASSES = {
    "bart": (
        BartConfig,
        TFBartForConditionalGeneration,
        BartForConditionalGeneration,
        BART_PRETRAINED_MODEL_ARCHIVE_LIST,
    ),
    "bert": (
        BertConfig,
        TFBertForPreTraining,
        BertForPreTraining,
        BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "bert-large-uncased-whole-word-masking-finetuned-squad": (
        BertConfig,
        TFBertForQuestionAnswering,
        BertForQuestionAnswering,
        BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "bert-large-cased-whole-word-masking-finetuned-squad": (
        BertConfig,
        TFBertForQuestionAnswering,
        BertForQuestionAnswering,
        BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "bert-base-cased-finetuned-mrpc": (
        BertConfig,
        TFBertForSequenceClassification,
        BertForSequenceClassification,
        BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "dpr-question_encoder": (
        DPRConfig,
        TFDPRQuestionEncoder,
        DPRQuestionEncoder,
        DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
    ),
    "dpr-ctx_encoder": (
        DPRConfig,
        TFDPRContextEncoder,
        DPRContextEncoder,
        DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
    ),
    "dpr-reader": (
        DPRConfig,
        TFDPRReader,
        DPRReader,
        DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
    ),
    "gpt2": (
        GPT2Config,
        TFGPT2LMHeadModel,
        GPT2LMHeadModel,
        GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "xlnet": (
        XLNetConfig,
        TFXLNetLMHeadModel,
        XLNetLMHeadModel,
        XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "xlm": (
        XLMConfig,
        TFXLMWithLMHeadModel,
        XLMWithLMHeadModel,
        XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "xlm-roberta": (
        XLMRobertaConfig,
        TFXLMRobertaForMaskedLM,
        XLMRobertaForMaskedLM,
        XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "transfo-xl": (
        TransfoXLConfig,
        TFTransfoXLLMHeadModel,
        TransfoXLLMHeadModel,
        TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "openai-gpt": (
        OpenAIGPTConfig,
        TFOpenAIGPTLMHeadModel,
        OpenAIGPTLMHeadModel,
        OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "roberta": (
        RobertaConfig,
        TFRobertaForMaskedLM,
        RobertaForMaskedLM,
        ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "layoutlm": (
        LayoutLMConfig,
        TFLayoutLMForMaskedLM,
        LayoutLMForMaskedLM,
        LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
    ),
    "roberta-large-mnli": (
        RobertaConfig,
        TFRobertaForSequenceClassification,
        RobertaForSequenceClassification,
        ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "camembert": (
        CamembertConfig,
        TFCamembertForMaskedLM,
        CamembertForMaskedLM,
        CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "flaubert": (
        FlaubertConfig,
        TFFlaubertWithLMHeadModel,
        FlaubertWithLMHeadModel,
        FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "distilbert": (
        DistilBertConfig,
        TFDistilBertForMaskedLM,
        DistilBertForMaskedLM,
        DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "distilbert-base-distilled-squad": (
        DistilBertConfig,
        TFDistilBertForQuestionAnswering,
        DistilBertForQuestionAnswering,
        DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "lxmert": (
        LxmertConfig,
        TFLxmertForPreTraining,
        LxmertForPreTraining,
        LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "lxmert-visual-feature-encoder": (
        LxmertConfig,
        TFLxmertVisualFeatureEncoder,
        LxmertVisualFeatureEncoder,
        LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "ctrl": (
        CTRLConfig,
        TFCTRLLMHeadModel,
        CTRLLMHeadModel,
        CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "albert": (
        AlbertConfig,
        TFAlbertForPreTraining,
        AlbertForPreTraining,
        ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "t5": (
        T5Config,
        TFT5ForConditionalGeneration,
        T5ForConditionalGeneration,
        T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "electra": (
        ElectraConfig,
        TFElectraForPreTraining,
        ElectraForPreTraining,
        ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "wav2vec2": (
        Wav2Vec2Config,
        TFWav2Vec2Model,
        Wav2Vec2Model,
        WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
}
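
# For example, MODEL_CLASSES["bert"] yields
# (BertConfig, TFBertForPreTraining, BertForPreTraining, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP):
# everything the converters below need to build the TF model, reload the PyTorch
# weights into it, and resolve shortcut names to downloadable files.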


def convert_pt_checkpoint_to_tf(
    model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True
):
    if model_type not in MODEL_CLASSES:
        raise ValueError(f"Unrecognized model type, should be one of {list(MODEL_CLASSES.keys())}.")

    config_class, model_class, pt_model_class, aws_config_map = MODEL_CLASSES[model_type]

    # Initialise the TF model
    if config_file in aws_config_map:
        config_file = cached_file(config_file, CONFIG_NAME, force_download=not use_cached_models)
    config = config_class.from_json_file(config_file)
    config.output_hidden_states = True
    config.output_attentions = True
    print(f"Building TensorFlow model from configuration: {config}")
    tf_model = model_class(config)

    # Resolve the PyTorch checkpoint path if a shortcut name was given
    # (plain `in` works for both archive maps and archive lists)
    if pytorch_checkpoint_path in aws_config_map:
        pytorch_checkpoint_path = cached_file(
            pytorch_checkpoint_path, WEIGHTS_NAME, force_download=not use_cached_models
        )

    # Load the PyTorch checkpoint into the TF 2.0 model
    tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)

    if compare_with_pt_model:
        tfo = tf_model(tf_model.dummy_inputs, training=False)  # build the network

        state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu")
        pt_model = pt_model_class.from_pretrained(
            pretrained_model_name_or_path=None, config=config, state_dict=state_dict
        )

        with torch.no_grad():
            pto = pt_model(**pt_model.dummy_inputs)

        np_pt = pto[0].numpy()
        np_tf = tfo[0].numpy()
        diff = np.amax(np.abs(np_pt - np_tf))
        print(f"Max absolute difference between models outputs {diff}")
        assert diff <= 2e-2, f"Error, model absolute difference is >2e-2: {diff}"

    # Save the TF model
    print(f"Save TensorFlow model to {tf_dump_path}")
    tf_model.save_weights(tf_dump_path, save_format="h5")
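
# A minimal usage sketch (the local paths below are hypothetical), converting a BERT
# checkpoint on disk and verifying the TF port against the original PyTorch weights:
#
#   convert_pt_checkpoint_to_tf(
#       model_type="bert",
#       pytorch_checkpoint_path="./bert/pytorch_model.bin",
#       config_file="./bert/config.json",
#       tf_dump_path="./bert/tf_model.h5",
#       compare_with_pt_model=True,
#   )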


def convert_all_pt_checkpoints_to_tf(
    args_model_type,
    tf_dump_path,
    model_shortcut_names_or_path=None,
    config_shortcut_names_or_path=None,
    compare_with_pt_model=False,
    use_cached_models=False,
    remove_cached_files=False,
    only_convert_finetuned_models=False,
):
    if args_model_type is None:
        model_types = list(MODEL_CLASSES.keys())
    else:
        model_types = [args_model_type]

    for j, model_type in enumerate(model_types, start=1):
        print("=" * 100)
        print(f" Converting model type {j}/{len(model_types)}: {model_type}")
        print("=" * 100)
        if model_type not in MODEL_CLASSES:
            raise ValueError(f"Unrecognized model type {model_type}, should be one of {list(MODEL_CLASSES.keys())}.")

        config_class, model_class, pt_model_class, aws_config_map = MODEL_CLASSES[model_type]

        if model_shortcut_names_or_path is None:
            # list() works for both archive maps (dicts) and archive lists
            model_shortcut_names_or_path = list(aws_config_map)
        if config_shortcut_names_or_path is None:
            config_shortcut_names_or_path = model_shortcut_names_or_path

        for i, (model_shortcut_name, config_shortcut_name) in enumerate(
            zip(model_shortcut_names_or_path, config_shortcut_names_or_path), start=1
        ):
            print("-" * 100)
            if "-squad" in model_shortcut_name or "-mrpc" in model_shortcut_name or "-mnli" in model_shortcut_name:
                if not only_convert_finetuned_models:
                    print(f" Skipping finetuned checkpoint {model_shortcut_name}")
                    continue
                model_type = model_shortcut_name
            elif only_convert_finetuned_models:
                print(f" Skipping non-finetuned checkpoint {model_shortcut_name}")
                continue
            print(
                f" Converting checkpoint {i}/{len(model_shortcut_names_or_path)}: {model_shortcut_name} - model_type {model_type}"
            )
            print("-" * 100)

            if config_shortcut_name in aws_config_map:
                config_file = cached_file(config_shortcut_name, CONFIG_NAME, force_download=not use_cached_models)
            else:
                config_file = config_shortcut_name
            if model_shortcut_name in aws_config_map:
                model_file = cached_file(model_shortcut_name, WEIGHTS_NAME, force_download=not use_cached_models)
            else:
                model_file = model_shortcut_name

            if os.path.isfile(model_shortcut_name):
                model_shortcut_name = "converted_model"

            convert_pt_checkpoint_to_tf(
                model_type=model_type,
                pytorch_checkpoint_path=model_file,
                config_file=config_file,
                tf_dump_path=os.path.join(tf_dump_path, model_shortcut_name + "-tf_model.h5"),
                compare_with_pt_model=compare_with_pt_model,
            )
            if remove_cached_files:
                os.remove(config_file)
                os.remove(model_file)
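
# A sketch of batch conversion (the output directory is hypothetical and must exist):
# convert every "albert" shortcut checkpoint known to the archive map into
# ./tf_checkpoints, one .h5 file per checkpoint:
#
#   convert_all_pt_checkpoints_to_tf(
#       "albert",
#       "./tf_checkpoints",
#       compare_with_pt_model=False,
#   )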
if __name__ == "__main__": | |
parser = argparse.ArgumentParser() | |
# Required parameters | |
parser.add_argument( | |
"--tf_dump_path", default=None, type=str, required=True, help="Path to the output Tensorflow dump file." | |
) | |
parser.add_argument( | |
"--model_type", | |
default=None, | |
type=str, | |
help=( | |
f"Model type selected in the list of {list(MODEL_CLASSES.keys())}. If not given, will download and " | |
"convert all the models from AWS." | |
), | |
) | |
parser.add_argument( | |
"--pytorch_checkpoint_path", | |
default=None, | |
type=str, | |
help=( | |
"Path to the PyTorch checkpoint path or shortcut name to download from AWS. " | |
"If not given, will download and convert all the checkpoints from AWS." | |
), | |
) | |
parser.add_argument( | |
"--config_file", | |
default=None, | |
type=str, | |
help=( | |
"The config json file corresponding to the pre-trained model. \n" | |
"This specifies the model architecture. If not given and " | |
"--pytorch_checkpoint_path is not given or is a shortcut name " | |
"use the configuration associated to the shortcut name on the AWS" | |
), | |
) | |
parser.add_argument( | |
"--compare_with_pt_model", action="store_true", help="Compare Tensorflow and PyTorch model predictions." | |
) | |
parser.add_argument( | |
"--use_cached_models", | |
action="store_true", | |
help="Use cached models if possible instead of updating to latest checkpoint versions.", | |
) | |
parser.add_argument( | |
"--remove_cached_files", | |
action="store_true", | |
help="Remove pytorch models after conversion (save memory when converting in batches).", | |
) | |
parser.add_argument("--only_convert_finetuned_models", action="store_true", help="Only convert finetuned models.") | |
args = parser.parse_args() | |
# if args.pytorch_checkpoint_path is not None: | |
# convert_pt_checkpoint_to_tf(args.model_type.lower(), | |
# args.pytorch_checkpoint_path, | |
# args.config_file if args.config_file is not None else args.pytorch_checkpoint_path, | |
# args.tf_dump_path, | |
# compare_with_pt_model=args.compare_with_pt_model, | |
# use_cached_models=args.use_cached_models) | |
# else: | |
convert_all_pt_checkpoints_to_tf( | |
args.model_type.lower() if args.model_type is not None else None, | |
args.tf_dump_path, | |
model_shortcut_names_or_path=[args.pytorch_checkpoint_path] | |
if args.pytorch_checkpoint_path is not None | |
else None, | |
config_shortcut_names_or_path=[args.config_file] if args.config_file is not None else None, | |
compare_with_pt_model=args.compare_with_pt_model, | |
use_cached_models=args.use_cached_models, | |
remove_cached_files=args.remove_cached_files, | |
only_convert_finetuned_models=args.only_convert_finetuned_models, | |
) | |
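
# Example invocation (hypothetical local paths; assuming this script is saved as
# convert_pytorch_checkpoint_to_tf2.py): convert a single BERT checkpoint and check
# the TF outputs against the PyTorch model. --tf_dump_path is treated as a directory
# here, and the converted weights land in <tf_dump_path>/converted_model-tf_model.h5.
#
#   python convert_pytorch_checkpoint_to_tf2.py \
#       --model_type bert \
#       --pytorch_checkpoint_path ./bert/pytorch_model.bin \
#       --config_file ./bert/config.json \
#       --tf_dump_path ./tf_dump \
#       --compare_with_pt_model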