potsawee committed
Commit 6acc418 · Parent: 3bfc23d

use sentence split for translation

Files changed (2):
  1. app.py +28 -13
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,5 +1,6 @@
 import gradio as gr
 import random
+import spacy
 import torch
 from transformers import MT5Tokenizer, MT5ForConditionalGeneration
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -11,32 +12,46 @@ translator.eval()
 summarizer.eval()
 translator.to(device)
 summarizer.to(device)
-
+nlp = spacy.load("en_core_web_sm")
 
 def generate_output(
     task,
     text,
 ):
-    inputs = tokenizer(
-        [text],
-        padding="longest",
-        max_length=1024,
-        truncation=True,
-        return_tensors="pt",
-    ).to(device)
     if task == 'Translation':
-        outputs = translator.generate(
-            **inputs,
-            max_new_tokens=256,
-        )
+        sentences = [sent.text.strip() for sent in nlp(text).sents]  # List[str]
+        gen_texts = []
+        for sentence in sentences:
+            inputs = tokenizer(
+                [sentence],
+                padding="longest",
+                max_length=1024,
+                truncation=True,
+                return_tensors="pt",
+            ).to(device)
+            outputs = translator.generate(
+                **inputs,
+                max_new_tokens=256,
+            )
+            gen_text_ = tokenizer.decode(outputs[0], skip_special_tokens=True)
+            gen_texts.append(gen_text_)
+        return " ".join(gen_texts)
+
     elif task == 'Summarization':
+        inputs = tokenizer(
+            [text],
+            padding="longest",
+            max_length=1024,
+            truncation=True,
+            return_tensors="pt",
+        ).to(device)
         outputs = summarizer.generate(
             **inputs,
             max_new_tokens=256,
         )
+        gen_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
     else:
         raise ValueError("task undefined!")
-    gen_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
     return gen_text
 
 TASKS = ["Translation", "Summarization"]
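Why sentence splitting helps: the old code fed the entire input to a single translator.generate() call, so anything past the 1024-token truncation limit was silently dropped and the whole translation was capped at 256 new tokens; the new code translates sentence by sentence, so those limits apply per sentence instead. A minimal standalone sketch of the new translation path (the checkpoint name "google/mt5-small" and the translate() wrapper are illustrative assumptions, not part of this commit):

import spacy
import torch
from transformers import MT5Tokenizer, MT5ForConditionalGeneration

# Assumed checkpoint for illustration; the app loads its own fine-tuned models.
MODEL_NAME = "google/mt5-small"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = MT5Tokenizer.from_pretrained(MODEL_NAME)
translator = MT5ForConditionalGeneration.from_pretrained(MODEL_NAME).to(device).eval()
nlp = spacy.load("en_core_web_sm")  # English sentence segmenter

def translate(text: str) -> str:
    # Segment the input, then run one generate() call per sentence so the
    # 1024-token truncation applies per sentence, not per document.
    sentences = [sent.text.strip() for sent in nlp(text).sents]
    gen_texts = []
    for sentence in sentences:
        inputs = tokenizer(
            [sentence],
            padding="longest",
            max_length=1024,
            truncation=True,
            return_tensors="pt",
        ).to(device)
        with torch.no_grad():
            outputs = translator.generate(**inputs, max_new_tokens=256)
        gen_texts.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
    return " ".join(gen_texts)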
requirements.txt CHANGED
@@ -1,3 +1,4 @@
 torch>=1.10
 transformers>=4.11.3
 sentencepiece
+spacy
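Note: listing spacy in requirements.txt installs the library only; the en_core_web_sm model that app.py loads is distributed separately. If spacy.load() raises OSError at startup, one common workaround (an assumption about this deployment, not part of the commit) is to download the model on first run:

import spacy

try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    # Model package not installed yet; fetch it once, then load.
    from spacy.cli import download
    download("en_core_web_sm")
    nlp = spacy.load("en_core_web_sm")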