import torch
from transformers import AutoModel, AutoTokenizer, FlaxAutoModel
from wechsel import WECHSEL, load_embeddings

# English source checkpoint and tokenizer, plus the Finnish tokenizer trained beforehand.
source_tokenizer = AutoTokenizer.from_pretrained("roberta-large")
model = AutoModel.from_pretrained("roberta-large")
target_tokenizer = AutoTokenizer.from_pretrained("./")

# Set up WECHSEL with pretrained English and Finnish word embeddings
# and the English-Finnish bilingual dictionary shipped with the library.
wechsel = WECHSEL(
    load_embeddings("en"),
    load_embeddings("fi"),
    bilingual_dictionary="finnish"
)

# Compute embeddings for the Finnish vocabulary from the English input embeddings.
target_embeddings, info = wechsel.apply(
    source_tokenizer,
    target_tokenizer,
    model.get_input_embeddings().weight.detach().numpy(),
)

# Replace the model's input embeddings with the transferred Finnish embeddings.
model.get_input_embeddings().weight.data = torch.from_numpy(target_embeddings)
model.config.vocab_size = target_embeddings.shape[0]  # keep the config in sync in case the new vocabulary size differs

# Save the PyTorch checkpoint and convert it to Flax.
model.save_pretrained("./")

flax_model = FlaxAutoModel.from_pretrained("./", from_pt=True)
flax_model.save_pretrained("./")
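
# Optional sanity check, a minimal sketch that is not part of the original recipe
# (variable names here are illustrative): reload the saved checkpoint and confirm
# that the transferred embedding matrix lines up with the Finnish tokenizer's vocabulary.
reloaded_tokenizer = AutoTokenizer.from_pretrained("./")
reloaded_model = AutoModel.from_pretrained("./")
assert reloaded_model.get_input_embeddings().weight.shape[0] == len(reloaded_tokenizer)
# The `info` object returned by wechsel.apply can also be inspected for details of the transfer.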