semishawn committed on
Commit
8236645
1 Parent(s): 4864553

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -19
app.py CHANGED
@@ -4,38 +4,34 @@ import torch
4
  model = AutoModelForSeq2SeqLM.from_pretrained("Jayyydyyy/m2m100_418m_tokipona")
5
  tokenizer = AutoTokenizer.from_pretrained("facebook/m2m100_418M")
6
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
7
  LANG_CODES = {
8
- "English":"en",
9
- "toki pona":"tl"
10
  }
11
 
12
  def translate(text, src_lang, tgt_lang, candidates:int):
13
- """
14
- Translate the text from source lang to target lang
15
- """
16
-
17
  src = LANG_CODES.get(src_lang)
18
  tgt = LANG_CODES.get(tgt_lang)
19
 
20
  tokenizer.src_lang = src
21
  tokenizer.tgt_lang = tgt
22
 
23
- ins = tokenizer(text, return_tensors='pt').to(device)
24
 
25
  gen_args = {
26
- 'return_dict_in_generate': True,
27
- 'output_scores': True,
28
- 'output_hidden_states': True,
29
- 'length_penalty': 0.0, # don't encourage longer or shorter output,
30
- 'num_return_sequences': candidates,
31
- 'num_beams':candidates,
32
- 'forced_bos_token_id': tokenizer.lang_code_to_id[tgt]
33
- }
34
-
35
 
36
  outs = model.generate(**{**ins, **gen_args})
37
- output = tokenizer.batch_decode(outs.sequences, skip_special_tokens=True)
38
-
39
- return '\n'.join(output)
40
 
41
  print(translate("Hello!", "English", "toki pona", 1))
 
4
# Pick the compute device first: translate() moves its tokenized inputs to
# this device, so the model must live there too.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Fine-tuned m2m100-418M checkpoint for toki pona; the tokenizer comes from
# the base facebook model (presumably the fine-tune keeps the base
# vocabulary — verify against the model card).
# Fix: `.to(device)` — previously the model stayed on CPU while translate()
# sent inputs to `device`, which makes generate() fail on a CUDA machine.
model = AutoModelForSeq2SeqLM.from_pretrained("Jayyydyyy/m2m100_418m_tokipona").to(device)
tokenizer = AutoTokenizer.from_pretrained("facebook/m2m100_418M")

# Human-readable language name -> m2m100 language code.
# NOTE(review): "tl" (normally Tagalog's code) appears repurposed for
# toki pona by the fine-tune — confirm against the model card.
LANG_CODES = {
    "English": "en",
    "toki pona": "tl",
}
 
13
def translate(text, src_lang, tgt_lang, candidates: int):
    """Translate *text* between English and toki pona with the m2m100 model.

    Args:
        text: Input string to translate.
        src_lang: Human-readable source language name (a LANG_CODES key).
        tgt_lang: Human-readable target language name (a LANG_CODES key).
        candidates: Number of beams and returned sequences.

    Returns:
        The raw ``model.generate`` output (``return_dict_in_generate=True``),
        exposing ``.sequences`` plus scores and hidden states.
    """
    src = LANG_CODES.get(src_lang)
    tgt = LANG_CODES.get(tgt_lang)

    # m2m100 is many-to-many: the tokenizer needs both direction codes.
    tokenizer.src_lang = src
    tokenizer.tgt_lang = tgt

    ins = tokenizer(text, return_tensors="pt").to(device)

    gen_args = {
        "return_dict_in_generate": True,
        "output_scores": True,
        "output_hidden_states": True,
        "length_penalty": 0.0,  # don't encourage longer or shorter output
        "num_return_sequences": candidates,
        "num_beams": candidates,
        # Force the decoder to start in the target language.
        "forced_bos_token_id": tokenizer.lang_code_to_id[tgt],
    }

    outs = model.generate(**{**ins, **gen_args})
    return outs
    # NOTE(review): earlier behavior decoded to text; kept for reference.
    # output = tokenizer.batch_decode(outs.sequences, skip_special_tokens=True)
    # return "\n".join(output)

print(translate("Hello!", "English", "toki pona", 1))