"""Convert an OpenAI GPT-2 TensorFlow checkpoint into a PyTorch model and configuration file."""

import argparse

import torch

from transformers import GPT2Config, GPT2Model, load_tf_weights_in_gpt2
from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME
from transformers.utils import logging


logging.set_verbosity_info()


def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path):
    """Load a GPT-2 TensorFlow checkpoint into a GPT2Model and save PyTorch weights plus a config file."""
    # Build the model from the provided config file, or fall back to the default GPT-2 configuration.
    if gpt2_config_file == "":
        config = GPT2Config()
    else:
        config = GPT2Config.from_json_file(gpt2_config_file)
    model = GPT2Model(config)

    # Copy the TensorFlow checkpoint weights into the PyTorch model.
    load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path)

    # Save the converted weights and the configuration under the requested output folder.
    pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
    pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
    print(f"Save PyTorch model to {pytorch_weights_dump_path}")
    torch.save(model.state_dict(), pytorch_weights_dump_path)
    print(f"Save configuration file to {pytorch_config_dump_path}")
    with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
        f.write(config.to_json_string())
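
# The conversion can also be driven directly from Python; the paths here are illustrative placeholders
# (an empty config file argument selects the default GPT2Config):
#   convert_gpt2_checkpoint_to_pytorch("/path/to/tf_checkpoint", "", "/path/to/output_dir")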


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--gpt2_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint."
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    parser.add_argument(
        "--gpt2_config_file",
        default="",
        type=str,
        help="An optional config json file corresponding to the pre-trained OpenAI GPT-2 model. \n"
        "This specifies the model architecture.",
    )
    args = parser.parse_args()
    convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, args.gpt2_config_file, args.pytorch_dump_folder_path)