"""Initialise a Finnish model from roberta-large via WECHSEL embedding transfer.

Loads the English ``roberta-large`` masked-LM, maps its input embeddings onto
the target tokenizer stored in the current directory using WECHSEL ("en" ->
"fi" word embeddings plus the "finnish" bilingual dictionary), installs the
transferred embeddings in the model and saves the model to the current
directory.
"""
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer, FlaxAutoModelForMaskedLM
from datasets import load_dataset
from wechsel import WECHSEL, load_embeddings

source_tokenizer = AutoTokenizer.from_pretrained("roberta-large")
model = AutoModelForMaskedLM.from_pretrained("roberta-large")
# The target tokenizer is read from the current directory, which is also where
# the converted model is written below.
target_tokenizer = AutoTokenizer.from_pretrained("./")

wechsel = WECHSEL(
    load_embeddings("en"),
    load_embeddings("fi"),
    bilingual_dictionary="finnish",
)

target_embeddings, info = wechsel.apply(
    source_tokenizer,
    target_tokenizer,
    model.get_input_embeddings().weight.detach().numpy(),
)

# The target vocabulary need not be the same size as roberta-large's.
# Overwriting ``weight.data`` alone would leave ``config.vocab_size`` and the
# LM-head bias at the old size, so the saved checkpoint would not match its
# config. ``resize_token_embeddings`` updates the config and the tied output
# embeddings consistently. The new rows are then copied in place.
model.resize_token_embeddings(target_embeddings.shape[0])
with torch.no_grad():
    model.get_input_embeddings().weight.copy_(
        torch.from_numpy(target_embeddings).to(torch.float32)
    )
# NOTE(review): the LM-head bias still holds roberta-large's per-token values
# for the rows that survived the resize; consider zeroing it before training.

model.save_pretrained("./")

# Optional: export a Flax checkpoint from the PyTorch weights just saved.
# flax_model = FlaxAutoModelForMaskedLM.from_pretrained("./", from_pt=True)
# flax_model.save_pretrained("./")