Spaces:
Running
on
Zero
Running
on
Zero
cointegrated
commited on
Commit
•
2a62da0
1
Parent(s):
d0ffdbf
add punctuation normalization and load the tokenizer only once
Browse files- app.py +14 -7
- requirements.txt +2 -1
app.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import spaces
|
2 |
import gradio as gr
|
|
|
3 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
4 |
from flores import code_mapping
|
5 |
import platform
|
@@ -28,12 +29,11 @@ def load_model():
|
|
28 |
model = load_model()
|
29 |
|
30 |
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
)
|
35 |
-
return tokenizer
|
36 |
|
|
|
37 |
|
38 |
# cache function
|
39 |
@lru_cache(maxsize=100)
|
@@ -44,10 +44,17 @@ def translate(text: str, src_lang: str, tgt_lang: str):
|
|
44 |
raise gr.Error("The target language is empty! Please choose it in the dropdown list.")
|
45 |
return _translate(text, src_lang, tgt_lang)
|
46 |
|
|
|
47 |
# Only assign GPU if cache not used
|
48 |
@spaces.GPU
|
49 |
def _translate(text: str, src_lang: str, tgt_lang: str):
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
paragraphs = text.split("\n")
|
53 |
translated_paragraphs = []
|
@@ -66,7 +73,7 @@ def _translate(text: str, src_lang: str, tgt_lang: str):
|
|
66 |
)
|
67 |
translated_chunk = model.generate(
|
68 |
input_ids=torch.tensor([input_tokens]).to(device),
|
69 |
-
forced_bos_token_id=tokenizer.convert_tokens_to_ids(
|
70 |
max_length=len(input_tokens) + 50,
|
71 |
num_return_sequences=1,
|
72 |
num_beams=5,
|
|
|
1 |
import spaces
|
2 |
import gradio as gr
|
3 |
+
from sacremoses import MosesPunctNormalizer
|
4 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
5 |
from flores import code_mapping
|
6 |
import platform
|
|
|
29 |
model = load_model()
|
30 |
|
31 |
|
32 |
+
# Loading the tokenizer once, because re-loading it takes about 1.5 seconds each time
|
33 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
34 |
+
|
|
|
|
|
35 |
|
36 |
+
punct_normalizer = MosesPunctNormalizer(lang="en")
|
37 |
|
38 |
# cache function
|
39 |
@lru_cache(maxsize=100)
|
|
|
44 |
raise gr.Error("The target language is empty! Please choose it in the dropdown list.")
|
45 |
return _translate(text, src_lang, tgt_lang)
|
46 |
|
47 |
+
|
48 |
# Only assign GPU if cache not used
|
49 |
@spaces.GPU
|
50 |
def _translate(text: str, src_lang: str, tgt_lang: str):
|
51 |
+
src_code = code_mapping[src_lang]
|
52 |
+
tgt_code = code_mapping[tgt_lang]
|
53 |
+
tokenizer.src_lang = src_code
|
54 |
+
tokenizer.tgt_lang = tgt_code
|
55 |
+
|
56 |
+
# normalizing the punctuation first
|
57 |
+
text = punct_normalizer.normalize(text)
|
58 |
|
59 |
paragraphs = text.split("\n")
|
60 |
translated_paragraphs = []
|
|
|
73 |
)
|
74 |
translated_chunk = model.generate(
|
75 |
input_ids=torch.tensor([input_tokens]).to(device),
|
76 |
+
forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_code),
|
77 |
max_length=len(input_tokens) + 50,
|
78 |
num_return_sequences=1,
|
79 |
num_beams=5,
|
requirements.txt
CHANGED
@@ -3,4 +3,5 @@ transformers
|
|
3 |
torch
|
4 |
gradio==4.32.2
|
5 |
spaces
|
6 |
-
nltk
|
|
|
|
3 |
torch
|
4 |
gradio==4.32.2
|
5 |
spaces
|
6 |
+
nltk
|
7 |
+
sacremoses
|