Spaces:
Runtime error
Runtime error
Commit
•
a7233a3
1
Parent(s):
cccae94
add translation task
Browse files
main.py
CHANGED
@@ -1,17 +1,42 @@
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
def predict(inp):
|
6 |
-
|
7 |
-
counter += 1
|
8 |
-
return str(counter)
|
9 |
|
10 |
def run():
|
11 |
demo = gr.Interface(
|
12 |
fn=predict,
|
13 |
-
inputs=gr.inputs.Textbox(label="
|
14 |
-
outputs=gr.outputs.Textbox(label="
|
15 |
)
|
16 |
|
17 |
demo.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
1 |
import gradio as gr
|
2 |
+
import sys
|
3 |
+
import tensorflow_text as tftxt
|
4 |
+
import tensorflow as tf
|
5 |
+
from official.core import exp_factory, task_factory
|
6 |
+
from official.nlp.configs import wmt_transformer_experiments as wmt_te
|
7 |
|
8 |
+
tokenizer= tftxt.SentencepieceTokenizer(
|
9 |
+
model=tf.io.gfile.GFile("/home/user/app/pretrained_v2/sentencepiece_en_tr.model", "rb").read(),
|
10 |
+
add_eos=True)
|
11 |
+
|
12 |
+
sys.path = ['/home/user/app/nmt-en-tr/datasets', '/home/user/app/nmt-en-tr/models'] + sys.path
|
13 |
+
# sys.path = ['/root/.local/lib/python3.10/site-packages', '/root/.local/bin'] + sys.path
|
14 |
+
|
15 |
+
task_config = exp_factory.get_exp_config('transformer_tr_en_blended/base').task
|
16 |
+
task_config.sentencepiece_model_path = 'pretrained_v2/sentencepiece_en_tr.model'
|
17 |
+
|
18 |
+
translation_task = task_factory.get_task(task_config)
|
19 |
+
model_en_tr = translation_task.build_model()
|
20 |
+
# model_tr_en = translation_task.build_model() # we can use the same task
|
21 |
+
|
22 |
+
def translate(input_text, model):
|
23 |
+
tokenized = tokenizer.tokenize(input_text)
|
24 |
+
translated = model({'inputs' : tf.reshape(tokenized, [1, -1])})
|
25 |
+
return tokenizer.detokenize(translated['outputs']).numpy()[0].decode('utf-8')
|
26 |
+
|
27 |
+
ignore = translate("test", model_en_tr)
|
28 |
+
# ignore = translate("test", model_tr_en)
|
29 |
+
|
30 |
+
model_en_tr.load_weights("pretrained_v2/en_tr/en_tr")
|
31 |
|
32 |
def predict(inp):
|
33 |
+
return translate(inp, model_en_tr)
|
|
|
|
|
34 |
|
35 |
def run():
|
36 |
demo = gr.Interface(
|
37 |
fn=predict,
|
38 |
+
inputs=gr.inputs.Textbox(label="English"),
|
39 |
+
outputs=gr.outputs.Textbox(label="Turkish"),
|
40 |
)
|
41 |
|
42 |
demo.launch(server_name="0.0.0.0", server_port=7860)
|