gorkemozkaya commited on
Commit
a7233a3
1 Parent(s): cccae94

add translation task

Browse files
Files changed (1) hide show
  1. main.py +31 -6
main.py CHANGED
@@ -1,17 +1,42 @@
1
  import gradio as gr
 
 
 
 
 
2
 
3
- counter = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  def predict(inp):
6
- global counter
7
- counter += 1
8
- return str(counter)
9
 
10
  def run():
11
  demo = gr.Interface(
12
  fn=predict,
13
- inputs=gr.inputs.Textbox(label="Input Text"),
14
- outputs=gr.outputs.Textbox(label="Output Text"),
15
  )
16
 
17
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
  import gradio as gr
2
+ import sys
3
+ import tensorflow_text as tftxt
4
+ import tensorflow as tf
5
+ from official.core import exp_factory, task_factory
6
+ from official.nlp.configs import wmt_transformer_experiments as wmt_te
7
 
8
+ tokenizer= tftxt.SentencepieceTokenizer(
9
+ model=tf.io.gfile.GFile("/home/user/app/pretrained_v2/sentencepiece_en_tr.model", "rb").read(),
10
+ add_eos=True)
11
+
12
+ sys.path = ['/home/user/app/nmt-en-tr/datasets', '/home/user/app/nmt-en-tr/models'] + sys.path
13
+ # sys.path = ['/root/.local/lib/python3.10/site-packages', '/root/.local/bin'] + sys.path
14
+
15
+ task_config = exp_factory.get_exp_config('transformer_tr_en_blended/base').task
16
+ task_config.sentencepiece_model_path = 'pretrained_v2/sentencepiece_en_tr.model'
17
+
18
+ translation_task = task_factory.get_task(task_config)
19
+ model_en_tr = translation_task.build_model()
20
+ # model_tr_en = translation_task.build_model() # we can use the same task
21
+
22
+ def translate(input_text, model):
23
+ tokenized = tokenizer.tokenize(input_text)
24
+ translated = model({'inputs' : tf.reshape(tokenized, [1, -1])})
25
+ return tokenizer.detokenize(translated['outputs']).numpy()[0].decode('utf-8')
26
+
27
+ ignore = translate("test", model_en_tr)
28
+ # ignore = translate("test", model_tr_en)
29
+
30
+ model_en_tr.load_weights("pretrained_v2/en_tr/en_tr")
31
 
32
  def predict(inp):
33
+ return translate(inp, model_en_tr)
 
 
34
 
35
  def run():
36
  demo = gr.Interface(
37
  fn=predict,
38
+ inputs=gr.inputs.Textbox(label="English"),
39
+ outputs=gr.outputs.Textbox(label="Turkish"),
40
  )
41
 
42
  demo.launch(server_name="0.0.0.0", server_port=7860)