guymorlan commited on
Commit
df677ba
1 Parent(s): 2e03cd0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -7
app.py CHANGED
@@ -10,7 +10,8 @@ dialects = {"Palestinian/Jordanian": "P", "Syrian": "S", "Lebanese": "L", "Egypt
10
  # translator_en2ar = pipeline(task="translation", model="guymorlan/English2Dialect")
11
  translator_en2ar = MarianMTModel.from_pretrained("guymorlan/English2Dialect", output_attentions=True)
12
  tokenizer_en2ar = AutoTokenizer.from_pretrained("guymorlan/English2Dialect")
13
- translator_ar2en = pipeline(task="translation", model="guymorlan/Shami2English")
 
14
  transliterator = pipeline(task="translation", model="guymorlan/DialectTransliterator")
15
 
16
  speech_config = speechsdk.SpeechConfig(subscription=os.environ.get('SPEECH_KEY'), region=os.environ.get('SPEECH_REGION'))
@@ -28,7 +29,7 @@ def generate_diverging_colors(num_colors, palette='Set3'): # courtesy of ChatGPT
28
  return colors_hex
29
 
30
 
31
- def align_words(outputs, tokenizer, encoder_input_ids, decoder_input_ids, threshold=0.4):
32
  alignment = []
33
  for i, tok in enumerate(outputs.cross_attentions[2][0][7]):
34
  alignment.append([[i], (tok > threshold).nonzero().squeeze(-1).tolist()])
@@ -93,7 +94,7 @@ def align_words(outputs, tokenizer, encoder_input_ids, decoder_input_ids, thresh
93
 
94
  srchtml = []
95
  for i, token in enumerate(encoder_input_ids[0]):
96
- if i == 0:
97
  continue
98
  if f"trg_{i}" in colordict:
99
  label = f"trg_{i}"
@@ -158,13 +159,42 @@ def translate_english(input_text, include):
158
 
159
  return palhtml, pal_out, sy_out, lb_out, eg_out
160
 
161
- def translate_arabic(input_text):
162
  if not input_text:
163
  return ""
164
 
165
- result = translator_ar2en([input_text])
166
- return result[0]["translation_text"]
 
 
 
 
 
167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
  def get_audio(input_text):
170
  audio_config = speechsdk.audio.AudioOutputConfig(filename=f"{input_text}.wav")
@@ -244,6 +274,7 @@ with gr.Blocks(title = "Levantine Arabic Translator", css=css, theme="default")
244
  input_text.submit(translate_english, inputs=[input_text, include], outputs=[pal_html, pal, sy, lb, eg],scroll_to_output=True)
245
  pal.change(get_transliteration, inputs=[pal, include], outputs=[pal_translit]);
246
  include.change(toggle_visibility, inputs=[include], outputs=[pal_translit, sy, lb, eg])
 
247
  with gr.Tab('Ar > En'):
248
  with gr.Row():
249
  with gr.Column():
@@ -252,8 +283,12 @@ with gr.Blocks(title = "Levantine Arabic Translator", css=css, theme="default")
252
  btn = gr.Button("Translate", label="Translate")
253
  gr.Markdown("Built by [Guy Mor-Lan](mailto:guy.mor@mail.huji.ac.il).")
254
  with gr.Column():
255
- eng = gr.Textbox(label="English", lines=1, elem_id="liter")
 
 
 
256
  btn.click(translate_arabic,inputs=input_text, outputs=[eng])
 
257
  with gr.Tab("Transliterate"):
258
  with gr.Row():
259
  with gr.Column():
 
10
  # translator_en2ar = pipeline(task="translation", model="guymorlan/English2Dialect")
11
  translator_en2ar = MarianMTModel.from_pretrained("guymorlan/English2Dialect", output_attentions=True)
12
  tokenizer_en2ar = AutoTokenizer.from_pretrained("guymorlan/English2Dialect")
13
+ translator_ar2en = MarianMTModel.from_pretrained("guymorlan/Shami2English", output_attentions=True)
14
+ tokenizer_ar2en = AutoTokenizer.from_pretrained("guymorlan/Shami2English")
15
  transliterator = pipeline(task="translation", model="guymorlan/DialectTransliterator")
16
 
17
  speech_config = speechsdk.SpeechConfig(subscription=os.environ.get('SPEECH_KEY'), region=os.environ.get('SPEECH_REGION'))
 
29
  return colors_hex
30
 
31
 
32
+ def align_words(outputs, tokenizer, encoder_input_ids, decoder_input_ids, threshold=0.4, skip_first_src=True):
33
  alignment = []
34
  for i, tok in enumerate(outputs.cross_attentions[2][0][7]):
35
  alignment.append([[i], (tok > threshold).nonzero().squeeze(-1).tolist()])
 
94
 
95
  srchtml = []
96
  for i, token in enumerate(encoder_input_ids[0]):
97
+ if skip_first_src and i == 0:
98
  continue
99
  if f"trg_{i}" in colordict:
100
  label = f"trg_{i}"
 
159
 
160
  return palhtml, pal_out, sy_out, lb_out, eg_out
161
 
162
+ def translate_arabic(input_text, include=["Colorize"]):
163
  if not input_text:
164
  return ""
165
 
166
+ input_tokens = tokenizer_ar2en(input_text, return_tensors="pt").input_ids
167
+ # print(input_tokens)
168
+ outputs = translator_ar2en.generate(input_tokens)
169
+ # print(outputs)
170
+
171
+ encoder_input_ids = input_tokens[0].unsqueeze(0)
172
+ decoder_input_ids = outputs[0].unsqueeze(0)
173
 
174
+ decoded = tokenizer_en2ar.batch_decode(outputs, skip_special_tokens=True)
175
+ # print(decoded)
176
+
177
+ print(include)
178
+ if "Colorize" in include:
179
+ html_outputs = translator_ar2en(input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids)
180
+
181
+ # set dynamic threshold
182
+ # print(input_tokens, input_tokens.shape)
183
+ if input_tokens.shape[1] < 20:
184
+ threshold = 0.1
185
+ elif input_tokens.shape[1] < 30:
186
+ threshold = 0.01
187
+ else:
188
+ threshold = 0.05
189
+
190
+ print("threshold", threshold)
191
+
192
+ srchtml, tgthtml = align_words(html_outputs, tokenizer_ar2en, encoder_input_ids, decoder_input_ids, threshold, skip_first_src=False)
193
+ enhtml = f"<div style='direction: rtl'>{srchtml}</div><br><br><div>{tgthtml}</div>"
194
+ else:
195
+ enhtml = f"<div style='font-size: 30px;'>{decoded[0]}</div>"
196
+
197
+ return enhtml
198
 
199
  def get_audio(input_text):
200
  audio_config = speechsdk.audio.AudioOutputConfig(filename=f"{input_text}.wav")
 
274
  input_text.submit(translate_english, inputs=[input_text, include], outputs=[pal_html, pal, sy, lb, eg],scroll_to_output=True)
275
  pal.change(get_transliteration, inputs=[pal, include], outputs=[pal_translit]);
276
  include.change(toggle_visibility, inputs=[include], outputs=[pal_translit, sy, lb, eg])
277
+
278
  with gr.Tab('Ar > En'):
279
  with gr.Row():
280
  with gr.Column():
 
283
  btn = gr.Button("Translate", label="Translate")
284
  gr.Markdown("Built by [Guy Mor-Lan](mailto:guy.mor@mail.huji.ac.il).")
285
  with gr.Column():
286
+ with gr.Box(label = "English"):
287
+ gr.Markdown("English")
288
+ with gr.Box():
289
+ eng = gr.HTML("<br>", label="English", elem_id="main")
290
  btn.click(translate_arabic,inputs=input_text, outputs=[eng])
291
+
292
  with gr.Tab("Transliterate"):
293
  with gr.Row():
294
  with gr.Column():