gorkemozkaya commited on
Commit
54177cb
1 Parent(s): e35cceb

reversed translation direction

Browse files
Files changed (1) hide show
  1. main.py +8 -8
main.py CHANGED
@@ -17,27 +17,27 @@ task_config = exp_factory.get_exp_config('transformer_tr_en_blended/base').task
17
  task_config.sentencepiece_model_path = '/code/pretrained_v2/sentencepiece_en_tr.model'
18
 
19
  translation_task = task_factory.get_task(task_config)
20
- model_en_tr = translation_task.build_model()
21
- # model_tr_en = translation_task.build_model() # we can use the same task
22
 
23
  def translate(input_text, model):
24
  tokenized = tokenizer.tokenize(input_text)
25
  translated = model({'inputs' : tf.reshape(tokenized, [1, -1])})
26
  return tokenizer.detokenize(translated['outputs']).numpy()[0].decode('utf-8')
27
 
28
- ignore = translate("test", model_en_tr)
29
- # ignore = translate("test", model_tr_en)
30
 
31
- model_en_tr.load_weights("/code/pretrained_v2/en_tr/en_tr")
32
 
33
  def predict(inp):
34
- return translate(inp, model_en_tr)
35
 
36
  def run():
37
  demo = gr.Interface(
38
  fn=predict,
39
- inputs=gr.inputs.Textbox(label="English"),
40
- outputs=gr.outputs.Textbox(label="Turkish"),
41
  )
42
 
43
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
17
  task_config.sentencepiece_model_path = '/code/pretrained_v2/sentencepiece_en_tr.model'
18
 
19
  translation_task = task_factory.get_task(task_config)
20
+ # model_en_tr = translation_task.build_model()
21
+ model_tr_en = translation_task.build_model() # we can use the same task
22
 
23
  def translate(input_text, model):
24
  tokenized = tokenizer.tokenize(input_text)
25
  translated = model({'inputs' : tf.reshape(tokenized, [1, -1])})
26
  return tokenizer.detokenize(translated['outputs']).numpy()[0].decode('utf-8')
27
 
28
+ # ignore = translate("test", model_en_tr)
29
+ ignore = translate("test", model_tr_en)
30
 
31
+ model_tr_en.load_weights("/code/pretrained_v2/tr_en/tr_en")
32
 
33
  def predict(inp):
34
+ return translate(inp, model_tr_en)
35
 
36
  def run():
37
  demo = gr.Interface(
38
  fn=predict,
39
+ inputs=gr.inputs.Textbox(label="Turkish"),
40
+ outputs=gr.outputs.Textbox(label="English"),
41
  )
42
 
43
  demo.launch(server_name="0.0.0.0", server_port=7860)