davanstrien HF staff commited on
Commit
e72a9c0
1 Parent(s): 5e1003d

Refactor translation function to support paragraph translation

Browse files
Files changed (1) hide show
  1. app.py +30 -20
app.py CHANGED
@@ -34,26 +34,36 @@ def load_tokenizer(src_lang, tgt_lang):
34
  def translate(text: str, src_lang: str, tgt_lang: str):
35
  tokenizer = load_tokenizer(src_lang, tgt_lang)
36
 
37
- sentences = nltk.sent_tokenize(text)
38
- translated_sentences = []
39
-
40
- for sentence in sentences:
41
- input_tokens = (
42
- tokenizer(sentence, return_tensors="pt").input_ids[0].cpu().numpy().tolist()
43
- )
44
- translated_chunk = model.generate(
45
- input_ids=torch.tensor([input_tokens]).to(device),
46
- forced_bos_token_id=tokenizer.lang_code_to_id[code_mapping[tgt_lang]],
47
- max_length=len(input_tokens) + 50,
48
- num_return_sequences=1,
49
- )
50
- translated_chunk = tokenizer.decode(
51
- translated_chunk[0], skip_special_tokens=True
52
- )
53
- translated_sentences.append(translated_chunk)
54
-
55
- translated_text = " ".join(translated_sentences)
56
- return translated_text
 
 
 
 
 
 
 
 
 
 
57
 
58
 
59
  description = """
 
34
  def translate(text: str, src_lang: str, tgt_lang: str):
35
  tokenizer = load_tokenizer(src_lang, tgt_lang)
36
 
37
+ paragraphs = text.split("\n")
38
+ translated_paragraphs = []
39
+
40
+ for paragraph in paragraphs:
41
+ sentences = nltk.sent_tokenize(paragraph)
42
+ translated_sentences = []
43
+
44
+ for sentence in sentences:
45
+ input_tokens = (
46
+ tokenizer(sentence, return_tensors="pt")
47
+ .input_ids[0]
48
+ .cpu()
49
+ .numpy()
50
+ .tolist()
51
+ )
52
+ translated_chunk = model.generate(
53
+ input_ids=torch.tensor([input_tokens]).to(device),
54
+ forced_bos_token_id=tokenizer.lang_code_to_id[code_mapping[tgt_lang]],
55
+ max_length=len(input_tokens) + 50,
56
+ num_return_sequences=1,
57
+ )
58
+ translated_chunk = tokenizer.decode(
59
+ translated_chunk[0], skip_special_tokens=True
60
+ )
61
+ translated_sentences.append(translated_chunk)
62
+
63
+ translated_paragraph = " ".join(translated_sentences)
64
+ translated_paragraphs.append(translated_paragraph)
65
+
66
+ return "\n".join(translated_paragraphs)
67
 
68
 
69
  description = """