"""Initialize a Finnish RoBERTa checkpoint from English roberta-large via WECHSEL.

Loads the source model/tokenizer, computes target-language input embeddings with
WECHSEL (Warm-starting via Effective CHoice of Embedding Layer), writes them into
the model's input-embedding matrix, and saves the result to the current directory.
Expects a target tokenizer to already exist at "./".
"""

import torch
from datasets import load_dataset  # unused here; kept for downstream use
from transformers import AutoModelForMaskedLM, AutoTokenizer, FlaxAutoModelForMaskedLM
from wechsel import WECHSEL, load_embeddings

source_tokenizer = AutoTokenizer.from_pretrained("roberta-large")
model = AutoModelForMaskedLM.from_pretrained("roberta-large")
# Target tokenizer must have been trained/saved to "./" beforehand.
target_tokenizer = AutoTokenizer.from_pretrained("./")

wechsel = WECHSEL(
    load_embeddings("en"),
    load_embeddings("fi"),
    bilingual_dictionary="finnish",
)

# Map source-vocabulary embeddings to the target vocabulary.
target_embeddings, info = wechsel.apply(
    source_tokenizer,
    target_tokenizer,
    model.get_input_embeddings().weight.detach().numpy(),
)

# NOTE(review): this overwrites the embedding weights in place but does not call
# model.resize_token_embeddings(), so config.vocab_size may be stale if the
# target vocabulary size differs from roberta-large's — confirm before training.
model.get_input_embeddings().weight.data = torch.from_numpy(target_embeddings).to(
    torch.float32
)

model.save_pretrained("./")

# Optionally export a Flax copy of the converted checkpoint:
# flax_model = FlaxAutoModelForMaskedLM.from_pretrained("./", from_pt=True)
# flax_model.save_pretrained("./")