Grounded-Segment-Anything / transformers_4_35_0 /convert_pytorch_checkpoint_to_tf2.py
liuyizhang
add transformers_4_35_0
1ce5e18
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Convert pytorch checkpoints to TensorFlow"""
import argparse
import os
from . import (
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
BART_PRETRAINED_MODEL_ARCHIVE_LIST,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
AlbertConfig,
BartConfig,
BertConfig,
CamembertConfig,
CTRLConfig,
DistilBertConfig,
DPRConfig,
ElectraConfig,
FlaubertConfig,
GPT2Config,
LayoutLMConfig,
LxmertConfig,
OpenAIGPTConfig,
RobertaConfig,
T5Config,
TFAlbertForPreTraining,
TFBartForConditionalGeneration,
TFBartForSequenceClassification,
TFBertForPreTraining,
TFBertForQuestionAnswering,
TFBertForSequenceClassification,
TFCamembertForMaskedLM,
TFCTRLLMHeadModel,
TFDistilBertForMaskedLM,
TFDistilBertForQuestionAnswering,
TFDPRContextEncoder,
TFDPRQuestionEncoder,
TFDPRReader,
TFElectraForPreTraining,
TFFlaubertWithLMHeadModel,
TFGPT2LMHeadModel,
TFLayoutLMForMaskedLM,
TFLxmertForPreTraining,
TFLxmertVisualFeatureEncoder,
TFOpenAIGPTLMHeadModel,
TFRobertaForCausalLM,
TFRobertaForMaskedLM,
TFRobertaForSequenceClassification,
TFT5ForConditionalGeneration,
TFTransfoXLLMHeadModel,
TFWav2Vec2Model,
TFXLMRobertaForMaskedLM,
TFXLMWithLMHeadModel,
TFXLNetLMHeadModel,
TransfoXLConfig,
Wav2Vec2Config,
Wav2Vec2Model,
XLMConfig,
XLMRobertaConfig,
XLNetConfig,
is_torch_available,
load_pytorch_checkpoint_in_tf2_model,
)
from .utils import CONFIG_NAME, WEIGHTS_NAME, cached_file, logging
if is_torch_available():
import numpy as np
import torch
from . import (
AlbertForPreTraining,
BartForConditionalGeneration,
BertForPreTraining,
BertForQuestionAnswering,
BertForSequenceClassification,
CamembertForMaskedLM,
CTRLLMHeadModel,
DistilBertForMaskedLM,
DistilBertForQuestionAnswering,
DPRContextEncoder,
DPRQuestionEncoder,
DPRReader,
ElectraForPreTraining,
FlaubertWithLMHeadModel,
GPT2LMHeadModel,
LayoutLMForMaskedLM,
LxmertForPreTraining,
LxmertVisualFeatureEncoder,
OpenAIGPTLMHeadModel,
RobertaForMaskedLM,
RobertaForSequenceClassification,
T5ForConditionalGeneration,
TransfoXLLMHeadModel,
XLMRobertaForMaskedLM,
XLMWithLMHeadModel,
XLNetLMHeadModel,
)
# Set the library logger to the "info" verbosity level for the duration of the script.
logging.set_verbosity_info()
# Conversion registry, keyed either by a model type ("bert", "gpt2", ...) or by the name of a specific
# fine-tuned checkpoint ("bert-base-cased-finetuned-mrpc", "roberta-large-mnli", ...).
#
# Most entries are 4-tuples laid out as:
#   (config class, TensorFlow model class, PyTorch model class, archive of pretrained shortcut names)
#
# NOTE(review): "bart" and "roberta" (5 elements) and "dpr" (10 elements) list several model classes
# and/or several archives, so code consuming this table must not assume a fixed tuple length.
# NOTE(review): the archive is a *_CONFIG_ARCHIVE_MAP for most entries but a *_MODEL_ARCHIVE_LIST for
# "bart", "dpr" and "layoutlm"; presumably maps are dicts and lists are plain lists -- confirm before
# calling dict-only methods such as .keys() on them.
# NOTE(review): most PyTorch classes referenced here are imported only when is_torch_available() is
# true, so building this dict presumably raises NameError when torch is missing -- confirm.
MODEL_CLASSES = {
# 5 elements: two TensorFlow heads precede the PyTorch class.
"bart": (
BartConfig,
TFBartForConditionalGeneration,
TFBartForSequenceClassification,
BartForConditionalGeneration,
BART_PRETRAINED_MODEL_ARCHIVE_LIST,
),
"bert": (
BertConfig,
TFBertForPreTraining,
BertForPreTraining,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
# Fine-tuned BERT checkpoints are keyed by checkpoint name so they can use a task-specific head.
"bert-large-uncased-whole-word-masking-finetuned-squad": (
BertConfig,
TFBertForQuestionAnswering,
BertForQuestionAnswering,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"bert-large-cased-whole-word-masking-finetuned-squad": (
BertConfig,
TFBertForQuestionAnswering,
BertForQuestionAnswering,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"bert-base-cased-finetuned-mrpc": (
BertConfig,
TFBertForSequenceClassification,
BertForSequenceClassification,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
# 10 elements: three TensorFlow classes, three PyTorch classes, three archives.
"dpr": (
DPRConfig,
TFDPRQuestionEncoder,
TFDPRContextEncoder,
TFDPRReader,
DPRQuestionEncoder,
DPRContextEncoder,
DPRReader,
DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
),
"gpt2": (
GPT2Config,
TFGPT2LMHeadModel,
GPT2LMHeadModel,
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"xlnet": (
XLNetConfig,
TFXLNetLMHeadModel,
XLNetLMHeadModel,
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"xlm": (
XLMConfig,
TFXLMWithLMHeadModel,
XLMWithLMHeadModel,
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"xlm-roberta": (
XLMRobertaConfig,
TFXLMRobertaForMaskedLM,
XLMRobertaForMaskedLM,
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"transfo-xl": (
TransfoXLConfig,
TFTransfoXLLMHeadModel,
TransfoXLLMHeadModel,
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"openai-gpt": (
OpenAIGPTConfig,
TFOpenAIGPTLMHeadModel,
OpenAIGPTLMHeadModel,
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
# 5 elements: two TensorFlow heads precede the PyTorch class.
"roberta": (
RobertaConfig,
TFRobertaForCausalLM,
TFRobertaForMaskedLM,
RobertaForMaskedLM,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"layoutlm": (
LayoutLMConfig,
TFLayoutLMForMaskedLM,
LayoutLMForMaskedLM,
LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
),
"roberta-large-mnli": (
RobertaConfig,
TFRobertaForSequenceClassification,
RobertaForSequenceClassification,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"camembert": (
CamembertConfig,
TFCamembertForMaskedLM,
CamembertForMaskedLM,
CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"flaubert": (
FlaubertConfig,
TFFlaubertWithLMHeadModel,
FlaubertWithLMHeadModel,
FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"distilbert": (
DistilBertConfig,
TFDistilBertForMaskedLM,
DistilBertForMaskedLM,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"distilbert-base-distilled-squad": (
DistilBertConfig,
TFDistilBertForQuestionAnswering,
DistilBertForQuestionAnswering,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"lxmert": (
LxmertConfig,
TFLxmertForPreTraining,
LxmertForPreTraining,
LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"lxmert-visual-feature-encoder": (
LxmertConfig,
TFLxmertVisualFeatureEncoder,
LxmertVisualFeatureEncoder,
LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"ctrl": (
CTRLConfig,
TFCTRLLMHeadModel,
CTRLLMHeadModel,
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"albert": (
AlbertConfig,
TFAlbertForPreTraining,
AlbertForPreTraining,
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"t5": (
T5Config,
TFT5ForConditionalGeneration,
T5ForConditionalGeneration,
T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"electra": (
ElectraConfig,
TFElectraForPreTraining,
ElectraForPreTraining,
ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"wav2vec2": (
Wav2Vec2Config,
TFWav2Vec2Model,
Wav2Vec2Model,
WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
}
def convert_pt_checkpoint_to_tf(
    model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True
):
    """Convert a single PyTorch checkpoint into TensorFlow 2 weights saved in HDF5 format.

    Args:
        model_type: Key of ``MODEL_CLASSES`` selecting the config and model classes to use.
        pytorch_checkpoint_path: Path to the PyTorch weights, or a shortcut name found in one of the
            archives of the selected ``MODEL_CLASSES`` entry (resolved through ``cached_file``).
        config_file: Path to the JSON configuration, or a shortcut name (resolved the same way).
        tf_dump_path: Destination handed to ``tf_model.save_weights(..., save_format="h5")``.
        compare_with_pt_model: When true, run both models on their ``dummy_inputs`` and compare the
            first output of each.
        use_cached_models: When false, shortcut names are fetched with ``force_download=True``.

    Raises:
        ValueError: If ``model_type`` is unknown, or the selected entry lacks a usable model class.
        AssertionError: If the comparison is requested and the maximum absolute difference between
            the two first outputs is not ``<= 2e-2``.
    """
    if model_type not in MODEL_CLASSES:
        raise ValueError(f"Unrecognized model type, should be one of {list(MODEL_CLASSES.keys())}.")

    # MODEL_CLASSES entries do not all have the same length (some register several heads and several
    # archives), so split the entry by kind instead of unpacking a fixed number of values.
    config_class, *candidates = MODEL_CLASSES[model_type]
    archives = [item for item in candidates if isinstance(item, (dict, list, tuple))]
    heads = [item for item in candidates if not isinstance(item, (dict, list, tuple))]
    # Every TensorFlow class registered in MODEL_CLASSES carries the "TF" name prefix; the first
    # class of each framework is used, which matches the layout of the regular 4-element entries.
    tf_heads = [head for head in heads if getattr(head, "__name__", "").startswith("TF")]
    pt_heads = [head for head in heads if not getattr(head, "__name__", "").startswith("TF")]
    if not tf_heads:
        raise ValueError(f"No TensorFlow model class is registered for model type {model_type}.")
    model_class = tf_heads[0]
    # Iterating works for both dict archives (yields the keys) and list archives.
    known_shortcuts = {name for archive in archives for name in archive}

    # Initialise TF model
    if config_file in known_shortcuts:
        config_file = cached_file(config_file, CONFIG_NAME, force_download=not use_cached_models)
    config = config_class.from_json_file(config_file)
    config.output_hidden_states = True
    config.output_attentions = True
    print(f"Building TensorFlow model from configuration: {config}")
    tf_model = model_class(config)

    # Resolve the PyTorch checkpoint when it was given as a shortcut name
    if pytorch_checkpoint_path in known_shortcuts:
        pytorch_checkpoint_path = cached_file(
            pytorch_checkpoint_path, WEIGHTS_NAME, force_download=not use_cached_models
        )
    # Load PyTorch checkpoint in tf2 model:
    tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)

    if compare_with_pt_model:
        if not pt_heads:
            raise ValueError(f"No PyTorch model class is registered for model type {model_type}.")
        pt_model_class = pt_heads[0]

        tfo = tf_model(tf_model.dummy_inputs, training=False)  # build the network

        # NOTE(review): torch.load unpickles the file, which can execute arbitrary code; only run
        # this comparison on checkpoints from a trusted source.
        state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu")
        pt_model = pt_model_class.from_pretrained(
            pretrained_model_name_or_path=None, config=config, state_dict=state_dict
        )

        with torch.no_grad():
            pto = pt_model(**pt_model.dummy_inputs)

        np_pt = pto[0].numpy()
        np_tf = tfo[0].numpy()
        diff = np.amax(np.abs(np_pt - np_tf))
        print(f"Max absolute difference between models outputs {diff}")
        # Raised explicitly (not via `assert`) so the check survives `python -O`; written as
        # `not diff <= ...` so that a NaN difference is still reported as a failure.
        if not diff <= 2e-2:
            raise AssertionError(f"Error, model absolute difference is >2e-2: {diff}")

    # Save the TensorFlow weights
    print(f"Save TensorFlow model to {tf_dump_path}")
    tf_model.save_weights(tf_dump_path, save_format="h5")
def convert_all_pt_checkpoints_to_tf(
    args_model_type,
    tf_dump_path,
    model_shortcut_names_or_path=None,
    config_shortcut_names_or_path=None,
    compare_with_pt_model=False,
    use_cached_models=False,
    remove_cached_files=False,
    only_convert_finetuned_models=False,
):
    """Convert a batch of PyTorch checkpoints to TensorFlow 2 ``.h5`` weight files.

    Args:
        args_model_type: Key of ``MODEL_CLASSES`` to convert, or ``None`` to go through every key.
        tf_dump_path: Directory receiving the ``<checkpoint name>-tf_model.h5`` files.
        model_shortcut_names_or_path: Checkpoint shortcut names or local paths. When ``None``, the
            names found in the archives of each model type's ``MODEL_CLASSES`` entry are used.
        config_shortcut_names_or_path: Configuration shortcut names or local paths, paired with the
            checkpoints in order. When ``None``, the checkpoint names are reused.
        compare_with_pt_model: Forwarded to ``convert_pt_checkpoint_to_tf``.
        use_cached_models: When false, shortcut names are fetched with ``force_download=True``.
        remove_cached_files: When true, delete the files fetched through ``cached_file`` once the
            checkpoint is converted. Files given as local paths are never deleted.
        only_convert_finetuned_models: When true, convert only checkpoints whose name contains
            "-squad", "-mrpc" or "-mnli"; when false, skip exactly those.

    Raises:
        ValueError: If a model type is not a key of ``MODEL_CLASSES``.
    """
    model_types = list(MODEL_CLASSES.keys()) if args_model_type is None else [args_model_type]

    for j, model_type in enumerate(model_types, start=1):
        print("=" * 100)
        print(f" Converting model type {j}/{len(model_types)}: {model_type}")
        print("=" * 100)
        if model_type not in MODEL_CLASSES:
            raise ValueError(f"Unrecognized model type {model_type}, should be one of {list(MODEL_CLASSES.keys())}.")

        # Only the archives of the entry are needed here. Entries have varying lengths and hold either
        # dict or list archives, so collect them by kind; iterating yields the shortcut names for both.
        archives = [item for item in MODEL_CLASSES[model_type] if isinstance(item, (dict, list, tuple))]
        archive_names = list(dict.fromkeys(name for archive in archives for name in archive))
        known_shortcuts = set(archive_names)

        # Resolve the defaults into per-model-type locals: overwriting the arguments would make every
        # later model type reuse the checkpoint names of the first one.
        model_names = archive_names if model_shortcut_names_or_path is None else list(model_shortcut_names_or_path)
        config_names = model_names if config_shortcut_names_or_path is None else list(config_shortcut_names_or_path)
        checkpoints = list(zip(model_names, config_names))

        for i, (model_shortcut_name, config_shortcut_name) in enumerate(checkpoints, start=1):
            print("-" * 100)
            # A fine-tuned checkpoint is converted with the MODEL_CLASSES entry named after it; keep
            # that override local so it does not stick to the following checkpoints.
            checkpoint_model_type = model_type
            if "-squad" in model_shortcut_name or "-mrpc" in model_shortcut_name or "-mnli" in model_shortcut_name:
                if not only_convert_finetuned_models:
                    print(f" Skipping finetuned checkpoint {model_shortcut_name}")
                    continue
                checkpoint_model_type = model_shortcut_name
            elif only_convert_finetuned_models:
                print(f" Skipping not finetuned checkpoint {model_shortcut_name}")
                continue
            print(
                f" Converting checkpoint {i}/{len(checkpoints)}: {model_shortcut_name} - model_type {checkpoint_model_type}"
            )
            print("-" * 100)

            # Track what was fetched so that clean-up never touches user-supplied local files.
            fetched_files = []
            if config_shortcut_name in known_shortcuts:
                config_file = cached_file(config_shortcut_name, CONFIG_NAME, force_download=not use_cached_models)
                fetched_files.append(config_file)
            else:
                config_file = config_shortcut_name

            if model_shortcut_name in known_shortcuts:
                model_file = cached_file(model_shortcut_name, WEIGHTS_NAME, force_download=not use_cached_models)
                fetched_files.append(model_file)
            else:
                model_file = model_shortcut_name

            # A local checkpoint path is not a usable file stem, so fall back to a fixed name.
            output_name = "converted_model" if os.path.isfile(model_shortcut_name) else model_shortcut_name
            dump_file = os.path.join(tf_dump_path, output_name + "-tf_model.h5")
            # Checkpoint names may contain "/" and the dump directory may not exist yet.
            dump_dir = os.path.dirname(dump_file)
            if dump_dir:
                os.makedirs(dump_dir, exist_ok=True)

            convert_pt_checkpoint_to_tf(
                model_type=checkpoint_model_type,
                pytorch_checkpoint_path=model_file,
                config_file=config_file,
                tf_dump_path=dump_file,
                compare_with_pt_model=compare_with_pt_model,
                use_cached_models=use_cached_models,
            )
            if remove_cached_files:
                for fetched_file in fetched_files:
                    os.remove(fetched_file)
if __name__ == "__main__":
    # Command-line entry point. Every invocation, including the conversion of a single checkpoint,
    # is routed through convert_all_pt_checkpoints_to_tf.
    cli = argparse.ArgumentParser()

    # Options taking a string value, registered in this order. Only --tf_dump_path is required.
    value_options = (
        (
            "--tf_dump_path",
            {"required": True, "help": "Path to the output Tensorflow dump file."},
        ),
        (
            "--model_type",
            {
                "help": (
                    f"Model type selected in the list of {list(MODEL_CLASSES.keys())}. If not given, will download and "
                    "convert all the models from AWS."
                )
            },
        ),
        (
            "--pytorch_checkpoint_path",
            {
                "help": (
                    "Path to the PyTorch checkpoint path or shortcut name to download from AWS. "
                    "If not given, will download and convert all the checkpoints from AWS."
                )
            },
        ),
        (
            "--config_file",
            {
                "help": (
                    "The config json file corresponding to the pre-trained model. \n"
                    "This specifies the model architecture. If not given and "
                    "--pytorch_checkpoint_path is not given or is a shortcut name "
                    "use the configuration associated to the shortcut name on the AWS"
                )
            },
        ),
    )
    for option, extra in value_options:
        cli.add_argument(option, default=None, type=str, **extra)

    # Boolean switches, all off unless present on the command line.
    switch_options = (
        ("--compare_with_pt_model", "Compare Tensorflow and PyTorch model predictions."),
        (
            "--use_cached_models",
            "Use cached models if possible instead of updating to latest checkpoint versions.",
        ),
        (
            "--remove_cached_files",
            "Remove pytorch models after conversion (save memory when converting in batches).",
        ),
        ("--only_convert_finetuned_models", "Only convert finetuned models."),
    )
    for option, description in switch_options:
        cli.add_argument(option, action="store_true", help=description)

    args = cli.parse_args()

    # Normalise the optional values into the shapes the batch converter expects: a lower-cased
    # model type and one-element lists for an explicit checkpoint / configuration.
    selected_model_type = None if args.model_type is None else args.model_type.lower()
    checkpoint_names = None if args.pytorch_checkpoint_path is None else [args.pytorch_checkpoint_path]
    config_names = None if args.config_file is None else [args.config_file]

    convert_all_pt_checkpoints_to_tf(
        selected_model_type,
        args.tf_dump_path,
        model_shortcut_names_or_path=checkpoint_names,
        config_shortcut_names_or_path=config_names,
        compare_with_pt_model=args.compare_with_pt_model,
        use_cached_models=args.use_cached_models,
        remove_cached_files=args.remove_cached_files,
        only_convert_finetuned_models=args.only_convert_finetuned_models,
    )