# UL2-nemo-conversion / convert_finnish_ul2_model.py
# Author: Faton Rekathati
# Commit: "UL2 conversion instructions" (0fd282e)
import os
import torch
from convert_nemo_ul2_checkpoint import convert_nemo_to_hf
from transformers import T5ForConditionalGeneration, AutoTokenizer
# Convert a NeMo UL2 state dict back to Hugging Face T5 weights, store it next
# to the original model's config/tokenizer, and verify that the round trip
# reproduces the original Hugging Face weights.

# Hub id of the reference HF model and the local folders used by this script.
hf_model_name = "Finnish-NLP/ul2-base-nl36-finnish"
nemo_state_dict_path = "ul2-base-nl36-finnish/nemo_state_dict.pt"
hf_output_dir = "ul2-base-nl36-finnish/hf_t5_ul2"

#### Step 1: Convert the original HF model which was converted to NEMO back to HF weights
# NOTE: torch.load unpickles the file; only run this on a checkpoint you trust.
# map_location="cpu" lets a checkpoint that was saved from a GPU load on a
# machine without one; the weights are only re-saved and compared here.
nemo_weights = torch.load(nemo_state_dict_path, map_location="cpu")
hf_weights = convert_nemo_to_hf(nemo_weights)

#### Step 2: Load original HF model and save its config/tokenizer in local folder
hf_model = T5ForConditionalGeneration.from_pretrained(hf_model_name)
tokenizer = AutoTokenizer.from_pretrained(hf_model_name)
# Save tokenizer in ul2-base-nl36-finnish
tokenizer.save_pretrained(hf_output_dir)
# Save config in ul2-base-nl36-finnish
hf_model.config.save_pretrained(hf_output_dir)

#### Step 3: Save our converted weights to the local folder
# Save converted model weights in ul2-base-nl36-finnish
torch.save(hf_weights, os.path.join(hf_output_dir, "pytorch_model.bin"))

#### Step 4: Load the converted model from local folder and check whether weights are the same
converted_model = T5ForConditionalGeneration.from_pretrained(hf_output_dir)

# Fetch each state dict once instead of rebuilding it on every loop iteration.
reference_state = hf_model.state_dict()
converted_state = converted_model.state_dict()

equal = []
for key in reference_state:
    # Compare once per parameter and reuse the result for both the
    # per-key report and the overall summary.
    is_close = torch.allclose(reference_state[key], converted_state[key])
    print(key)
    print(is_close)
    equal.append(is_close)

print(f"All weights are equal: {all(equal)}")