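"""Convert a fairseq IndicTrans2 (IT2) checkpoint to the Hugging Face format.

Example invocation (the script filename is illustrative; the paths are the
argparse defaults below):

    python convert_indictrans_checkpoint.py \
        --fairseq_path indic-en/model/checkpoint_best.pt \
        --pytorch_dump_folder_path indic-en/hf_model
"""
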
import argparse

import torch
import torch.nn as nn

from configuration_indictrans import IndicTransConfig
from modeling_indictrans import IndicTransForConditionalGeneration


def remove_ignore_keys_(state_dict):
    """Drop fairseq bookkeeping keys that have no counterpart in the HF model."""
    ignore_keys = [
        "encoder.version",
        "decoder.version",
        "model.encoder.version",
        "model.decoder.version",
        "_float_tensor",
        "encoder.embed_positions._float_tensor",
        "decoder.embed_positions._float_tensor",
    ]
    for k in ignore_keys:
        state_dict.pop(k, None)


def make_linear_from_emb(emb):
    """Wrap a (vocab_size, emb_size) weight tensor in a bias-free Linear layer."""
    vocab_size, emb_size = emb.shape
    # nn.Linear(in_features, out_features) stores its weight as (out, in), so a
    # layer mapping emb_size -> vocab_size matches the (vocab_size, emb_size)
    # weight assigned below.
    lin_layer = nn.Linear(emb_size, vocab_size, bias=False)
    lin_layer.weight.data = emb.data
    return lin_layer


def convert_fairseq_IT2_checkpoint_from_disk(checkpoint_path):
    """Load a fairseq IT2 checkpoint from disk and rebuild it as an HF model."""
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    # fairseq stores the model config under "args" (older releases) or under
    # "cfg" -> "model" (newer releases); fall back accordingly.
    args = checkpoint.get("args") or checkpoint["cfg"]["model"]
    state_dict = checkpoint["model"]
    remove_ignore_keys_(state_dict)
    encoder_vocab_size = state_dict["encoder.embed_tokens.weight"].shape[0]
    decoder_vocab_size = state_dict["decoder.embed_tokens.weight"].shape[0]

    # Map the fairseq hyperparameters onto the HF configuration.
    config = IndicTransConfig(
        encoder_vocab_size=encoder_vocab_size,
        decoder_vocab_size=decoder_vocab_size,
        max_source_positions=args.max_source_positions,
        max_target_positions=args.max_target_positions,
        encoder_layers=args.encoder_layers,
        decoder_layers=args.decoder_layers,
        layernorm_embedding=args.layernorm_embedding,
        encoder_normalize_before=args.encoder_normalize_before,
        decoder_normalize_before=args.decoder_normalize_before,
        encoder_attention_heads=args.encoder_attention_heads,
        decoder_attention_heads=args.decoder_attention_heads,
        encoder_ffn_dim=args.encoder_ffn_embed_dim,
        decoder_ffn_dim=args.decoder_ffn_embed_dim,
        encoder_embed_dim=args.encoder_embed_dim,
        decoder_embed_dim=args.decoder_embed_dim,
        encoder_layerdrop=args.encoder_layerdrop,
        decoder_layerdrop=args.decoder_layerdrop,
        dropout=args.dropout,
        attention_dropout=args.attention_dropout,
        activation_dropout=args.activation_dropout,
        activation_function=args.activation_fn,
        share_decoder_input_output_embed=args.share_decoder_input_output_embed,
        scale_embedding=not args.no_scale_embedding,
    )

    model = IndicTransForConditionalGeneration(config)
    # strict=False: the fairseq state dict and the HF module tree do not match
    # key-for-key (e.g. lm_head, which is patched in below when untied).
    model.model.load_state_dict(state_dict, strict=False)
    if not args.share_decoder_input_output_embed:
        # With untied embeddings, fairseq keeps a separate output projection;
        # wrap it in a Linear layer to serve as the HF lm_head.
        model.lm_head = make_linear_from_emb(
            state_dict["decoder.output_projection.weight"]
        )
    print(model)
    return model


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--fairseq_path",
        default="indic-en/model/checkpoint_best.pt",
        type=str,
        help="Path to the fairseq checkpoint (model.pt) on the local filesystem.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default="indic-en/hf_model",
        type=str,
        help="Path to the output PyTorch model.",
    )

    args = parser.parse_args()
    model = convert_fairseq_IT2_checkpoint_from_disk(args.fairseq_path)
    model.save_pretrained(args.pytorch_dump_folder_path)
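

# After conversion, a quick round-trip check might look like the sketch below.
# This assumes the dump folder also contains the custom configuration_indictrans
# and modeling_indictrans files so that trust_remote_code can resolve them:
#
#   from transformers import AutoModelForSeq2SeqLM
#   hf_model = AutoModelForSeq2SeqLM.from_pretrained(
#       "indic-en/hf_model", trust_remote_code=True
#   )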