from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Fine-tuned M2M100 checkpoint for English <-> toki pona; the base model's tokenizer is reused.
model = AutoModelForSeq2SeqLM.from_pretrained("Jayyydyyy/m2m100_418m_tokipona").to(device)
tokenizer = AutoTokenizer.from_pretrained("facebook/m2m100_418M")

LANG_CODES = {
	"English": "en",
	"toki pona": "tl"  # toki pona has no M2M100 code, so an existing code slot is reused
}

def translate(text, src_lang, tgt_lang, candidates: int):
	"""Translate text between English and toki pona, returning `candidates` beam-search hypotheses."""
	src = LANG_CODES.get(src_lang)
	tgt = LANG_CODES.get(tgt_lang)

	tokenizer.src_lang = src
	tokenizer.tgt_lang = tgt

	ins = tokenizer(text, return_tensors="pt").to(device)

	gen_args = {
		"return_dict_in_generate": True,
		"output_scores": True,
		"output_hidden_states": True,
		"length_penalty": 0.0,  # don't encourage longer or shorter output
		"num_return_sequences": candidates,
		"num_beams": candidates,
		"forced_bos_token_id": tokenizer.lang_code_to_id[tgt]  # force decoding into the target language
	}

	outs = model.generate(**{**ins, **gen_args})
	output = tokenizer.batch_decode(outs.sequences, skip_special_tokens=True)
	return "\n".join(output)

print(translate("Hello!", "English", "toki pona", 1))