guymorlan commited on
Commit
c1d1323
โ€ข
1 Parent(s): 4f1aa68

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +190 -30
app.py CHANGED
@@ -1,19 +1,113 @@
1
  import gradio as gr
2
- from transformers import pipeline
3
  import os
4
  import azure.cognitiveservices.speech as speechsdk
 
 
5
 
6
  dialects = {"Palestinian/Jordanian": "P", "Syrian": "S", "Lebanese": "L", "Egyptian": "E"}
7
 
8
- translator_en2ar = pipeline(task="translation", model="guymorlan/English2Dialect")
 
 
9
  translator_ar2en = pipeline(task="translation", model="guymorlan/Shami2English")
10
  transliterator = pipeline(task="translation", model="guymorlan/DialectTransliterator")
11
 
12
  speech_config = speechsdk.SpeechConfig(subscription=os.environ.get('SPEECH_KEY'), region=os.environ.get('SPEECH_REGION'))
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  def translate_english(input_text, include):
15
  if not input_text:
16
- return "", "", "", ""
17
 
18
  inputs = [f"{val} {input_text}" for val in dialects.values()]
19
 
@@ -26,14 +120,43 @@ def translate_english(input_text, include):
26
  if not sy:
27
  inputs.pop()
28
 
29
- result = translator_en2ar(inputs)
30
-
31
- pal_out = result[0]["translation_text"]
32
- sy_out = result[1]["translation_text"] if sy else ""
33
- lb_out = result[1 + sy]["translation_text"] if lb else ""
34
- eg_out = result[1 + sy + lb]["translation_text"] if eg else ""
 
35
 
36
- return pal_out, sy_out, lb_out, eg_out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  def translate_arabic(input_text):
39
  if not input_text:
@@ -50,64 +173,101 @@ def get_audio(input_text):
50
  speech_synthesis_result = speech_synthesizer.speak_text_async(input_text).get()
51
  return f"{input_text}.wav"
52
 
53
- def get_transliteration(input_text, include=["Transliteration"]):
54
- if "Transliteration" not in include:
55
  return ""
56
  result = transliterator([input_text])
57
  return result[0]["translation_text"]
58
 
59
 
 
 
60
  css = """
61
  #liter textarea, #trans textarea { font-size: 25px;}
62
- #trans textarea { direction: rtl; };
 
 
 
 
 
 
 
 
 
 
 
 
63
  """
64
 
65
- with gr.Blocks(title = "English to Levantine Arabic", css=css, theme="default") as demo:
66
- gr.Markdown("# Levantine Arabic Translator")
67
- with gr.Tab('En -> Ar'):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  with gr.Row():
69
  with gr.Column():
70
- input_text = gr.Textbox(label="Input", placeholder="Enter English text", lines=1)
71
- gr.Examples(["I wanted to go to the store yesterday, but it rained", "How are you feeling today?", "Let's drink coffee"], input_text)
72
  btn = gr.Button("Translate", label="Translate")
73
  with gr.Row():
74
- include = gr.CheckboxGroup(["Transliteration", "Syrian", "Lebanese", "Egyptian"],
75
  label="Disable features to speed up translation",
76
- value=["Transliteration", "Syrian", "Lebanese", "Egyptian"])
77
  gr.Markdown("Built by [Guy Mor-Lan](mailto:guy.mor@mail.huji.ac.il). Pronunciation model is specifically tailored to urban Palestinian Arabic. Text-to-speech uses Microsoft Azure's API and may provide different result from the transliterated pronunciation.")
78
 
79
  with gr.Column():
80
- pal = gr.Textbox(lines=1, label="Palestinian", elem_id="trans")
81
- pal_translit = gr.Textbox(lines=1, label="Palestinian Pronunciation", elem_id="liter")
82
- sy = gr.Textbox(lines=1, label="Syrian", elem_id="trans")
83
- lb = gr.Textbox(lines=1, label="Lebanese", elem_id="trans")
 
 
 
 
84
  eg = gr.Textbox(lines=1, label="Egyptian", elem_id="trans")
85
  with gr.Row():
86
  audio = gr.Audio(label="Audio - Palestinian", interactive=False)
87
  audio_button = gr.Button("Get Audio", label="Click Here to Get Audio")
88
  audio_button.click(get_audio, inputs=[pal], outputs=[audio])
89
- btn.click(translate_english,inputs=[input_text, include], outputs=[pal, sy, lb, eg])
90
- input_text.submit(translate_english, inputs=[input_text, include], outputs=[pal, sy, lb, eg])
91
- pal.change(get_transliteration, inputs=[pal, include], outputs=[pal_translit])
92
- with gr.Tab('Ar -> En'):
 
 
93
  with gr.Row():
94
  with gr.Column():
95
  input_text = gr.Textbox(label="Input", placeholder="Enter Levantine Arabic text", lines=1, elem_id="trans")
96
- gr.Examples(["ุฎู„ูŠู†ุง ู†ุฏูˆุฑ ุนู„ู‰ ู…ุทุนู… ุชุงู†ูŠ", "ูƒุงู† ุจุฏูŠ ุงูˆูƒู„ ุงุดูŠ ู‚ุจู„ ู…ุง ู†ุฑูˆุญ"], input_text)
97
  btn = gr.Button("Translate", label="Translate")
98
  gr.Markdown("Built by [Guy Mor-Lan](mailto:guy.mor@mail.huji.ac.il).")
99
  with gr.Column():
100
  eng = gr.Textbox(label="English", lines=1, elem_id="liter")
101
  btn.click(translate_arabic,inputs=input_text, outputs=[eng])
 
102
  with gr.Tab("Transliterate"):
103
  with gr.Row():
104
  with gr.Column():
105
  input_text = gr.Textbox(label="Input", placeholder="Enter Levantine Arabic text", lines=1)
106
- gr.Examples(["ุฎู„ูŠู†ุง ู†ุฏูˆุฑ ุนู„ู‰ ู…ุทุนู… ุชุงู†ูŠ", "ูƒุงู† ุจุฏูŠ ุงูˆูƒู„ ุงุดูŠ ู‚ุจู„ ู…ุง ู†ุฑูˆุญ"], input_text)
107
  btn = gr.Button("Transliterate", label="Transliterate")
108
  gr.Markdown("Built by [Guy Mor-Lan](mailto:guy.mor@mail.huji.ac.il)")
109
  with gr.Column():
110
  translit = gr.Textbox(label="Transliteration", lines=1, elem_id="liter")
111
  btn.click(get_transliteration, inputs=input_text, outputs=[translit])
112
 
 
113
  demo.launch()
 
1
  import gradio as gr
2
+ from transformers import pipeline, MarianMTModel, AutoTokenizer
3
  import os
4
  import azure.cognitiveservices.speech as speechsdk
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
 
8
  dialects = {"Palestinian/Jordanian": "P", "Syrian": "S", "Lebanese": "L", "Egyptian": "E"}
9
 
10
+ # translator_en2ar = pipeline(task="translation", model="guymorlan/English2Dialect")
11
+ translator_en2ar = MarianMTModel.from_pretrained("guymorlan/English2Dialect", output_attentions=True)
12
+ tokenizer_en2ar = AutoTokenizer.from_pretrained("guymorlan/English2Dialect")
13
  translator_ar2en = pipeline(task="translation", model="guymorlan/Shami2English")
14
  transliterator = pipeline(task="translation", model="guymorlan/DialectTransliterator")
15
 
16
  speech_config = speechsdk.SpeechConfig(subscription=os.environ.get('SPEECH_KEY'), region=os.environ.get('SPEECH_REGION'))
17
 
18
+ def generate_diverging_colors(num_colors, palette='Set3'): # courtesy of ChatGPT
19
+ # Generate a colormap with a specified number of colors
20
+ cmap = plt.cm.get_cmap(palette, num_colors)
21
+
22
+ # Get the RGB values of the colors in the colormap
23
+ colors_rgb = cmap(np.arange(num_colors))
24
+
25
+ # Convert the RGB values to hexadecimal color codes
26
+ colors_hex = [format(int(color[0]*255)<<16|int(color[1]*255)<<8|int(color[2]*255), '06x') for color in colors_rgb]
27
+
28
+ return colors_hex
29
+
30
+
31
+ def align_words(outputs, tokenizer, encoder_input_ids, decoder_input_ids, threshold=0.4):
32
+ alignment = []
33
+ for i, tok in enumerate(outputs.cross_attentions[2][0][7]):
34
+ alignment.append([[i], (tok > threshold).nonzero().squeeze(-1).tolist()])
35
+
36
+ merged = []
37
+ for i in alignment:
38
+ token = tokenizer.convert_ids_to_tokens([decoder_input_ids[0][i[0]]])[0]
39
+ if token not in tokenizer.convert_tokens_to_ids(["</s>", "<pad>", "<unk>"]):
40
+ if merged:
41
+ tomerge = False
42
+ # check overlap with previous entry
43
+ for x in i[1]:
44
+ if x in merged[-1][1]:# or tokenizer.convert_ids_to_tokens([encoder_input_ids[0][x]])[0][0] != "โ–":
45
+ tomerge = True
46
+ break
47
+ # if first character is not a "โ–"
48
+ if token[0] != "โ–":
49
+ tomerge = True
50
+ if tomerge:
51
+ merged[-1][0] += i[0]
52
+ merged[-1][1] += i[1]
53
+ else:
54
+ merged.append(i)
55
+ else:
56
+ merged.append(i)
57
+
58
+ colordict = {}
59
+ ncolors = 0
60
+ for i in merged:
61
+ src_tok = [f"src_{x}" for x in i[0]]
62
+ trg_tok = [f"trg_{x}" for x in i[1]]
63
+ all_tok = src_tok + trg_tok
64
+ # see if any tokens in entry already have associated color
65
+ newcolor = None
66
+ for t in all_tok:
67
+ if t in colordict:
68
+ newcolor = colordict[t]
69
+ break
70
+ if not newcolor:
71
+ newcolor = ncolors
72
+ ncolors += 1
73
+ for t in all_tok:
74
+ if t not in colordict:
75
+ colordict[t] = newcolor
76
+
77
+ colors = generate_diverging_colors(ncolors, palette="Set2")
78
+ id_to_color = {i: c for i, c in enumerate(colors)}
79
+ for k, v in colordict.items():
80
+ colordict[k] = id_to_color[v]
81
+
82
+
83
+ tgthtml = []
84
+ for i, token in enumerate(decoder_input_ids[0]):
85
+ if f"src_{i}" in colordict:
86
+ label = f"src_{i}"
87
+ tgthtml.append(f"<span style='color: #{colordict[label]}'>{tokenizer.convert_ids_to_tokens([token])[0]}</span>")
88
+ else:
89
+ tgthtml.append(f"<span style='color: --color-text-body'>{tokenizer.convert_ids_to_tokens([token])[0]}</span>")
90
+ tgthtml = "".join(tgthtml)
91
+ tgthtml = tgthtml.replace("โ–", " ")
92
+ tgthtml = f"<span style='font-size: 30px'>{tgthtml}</span>"
93
+
94
+ srchtml = []
95
+ for i, token in enumerate(encoder_input_ids[0]):
96
+ if i == 0:
97
+ continue
98
+ if f"trg_{i}" in colordict:
99
+ label = f"trg_{i}"
100
+ srchtml.append(f"<span style='color: #{colordict[label]}'>{tokenizer.convert_ids_to_tokens([token])[0]}</span>")
101
+ else:
102
+ srchtml.append(f"<span style='color: --color-text-body'>{tokenizer.convert_ids_to_tokens([token])[0]}</span>")
103
+ srchtml = "".join(srchtml)
104
+ srchtml = srchtml.replace("โ–", " ")
105
+ srchtml = f"<span style='font-size: 30px'>{srchtml}</span>"
106
+ return srchtml, tgthtml
107
+
108
  def translate_english(input_text, include):
109
  if not input_text:
110
+ return "", "", "", "", ""
111
 
112
  inputs = [f"{val} {input_text}" for val in dialects.values()]
113
 
 
120
  if not sy:
121
  inputs.pop()
122
 
123
+ input_tokens = tokenizer_en2ar(inputs, return_tensors="pt").input_ids
124
+ # print(input_tokens)
125
+ outputs = translator_en2ar.generate(input_tokens)
126
+ # print(outputs)
127
+
128
+ encoder_input_ids = input_tokens[0].unsqueeze(0)
129
+ decoder_input_ids = outputs[0].unsqueeze(0)
130
 
131
+
132
+ decoded = tokenizer_en2ar.batch_decode(outputs, skip_special_tokens=True)
133
+ # print(decoded)
134
+ pal_out = decoded[0]
135
+ sy_out = decoded[1] if sy else ""
136
+ lb_out = decoded[1 + sy] if lb else ""
137
+ eg_out = decoded[1 + sy + lb] if eg else ""
138
+
139
+ if "Colorize" in include:
140
+ html_outputs = translator_en2ar(input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids)
141
+
142
+ # set dynamic threshold
143
+ # print(input_tokens, input_tokens.shape)
144
+ if input_tokens.shape[1] < 10:
145
+ threshold = 0.4
146
+ elif input_tokens.shape[1] < 20:
147
+ threshold = 0.10
148
+ else:
149
+ threshold = 0.05
150
+
151
+ print("threshold", threshold)
152
+
153
+ srchtml, tgthtml = align_words(html_outputs, tokenizer_en2ar, encoder_input_ids, decoder_input_ids, threshold)
154
+ palhtml = f"{srchtml}<br><br><div style='direction: rtl'>{tgthtml}</div>"
155
+ else:
156
+ palhtml = f"<div style='font-size: 30px; direction: rtl'>{pal_out}</div>"
157
+
158
+
159
+ return palhtml, pal_out, sy_out, lb_out, eg_out
160
 
161
  def translate_arabic(input_text):
162
  if not input_text:
 
173
  speech_synthesis_result = speech_synthesizer.speak_text_async(input_text).get()
174
  return f"{input_text}.wav"
175
 
176
+ def get_transliteration(input_text, include=["Translit."]):
177
+ if "Translit." not in include:
178
  return ""
179
  result = transliterator([input_text])
180
  return result[0]["translation_text"]
181
 
182
 
183
+ bla = """
184
+ """
185
  css = """
186
  #liter textarea, #trans textarea { font-size: 25px;}
187
+ #trans textarea { direction: rtl; }
188
+ #check { border-style: none !important; }
189
+ :root {--button-secondary-background-focus: #2563eb !important;
190
+ --button-secondary-background-base: #2563eb !important;
191
+ --button-secondary-background-hover: linear-gradient(to bottom right, #0692e8, #5859c2);
192
+ --button-secondary-text-color-base: white !important;
193
+ --button-secondary-text-color-hover: white !important;
194
+ --button-secondary-background-focus: rgb(51 122 216 / 70%) !important;
195
+ --button-secondary-text-color-focus: white !important}
196
+ .dark {--button-secondary-background-base: #2563eb !important;
197
+ --button-secondary-background-focus: rgb(51 122 216 / 70%) !important;
198
+ --button-secondary-background-hover: linear-gradient(to bottom right, #0692e8, #5859c2)}
199
+ .feather-music { stroke: #2563eb; }
200
  """
201
 
202
+ def toggle_visibility(include):
203
+ outs = [gr.Textbox.update(visible=True)] * 4
204
+ if "Translit." not in include:
205
+ outs[0] = gr.Textbox.update(visible=False)
206
+ if "Syrian" not in include:
207
+ outs[1] = gr.Textbox.update(visible=False)
208
+ if "Lebanese" not in include:
209
+ outs[2] = gr.Textbox.update(visible=False)
210
+ if "Egyptian" not in include:
211
+ outs[3] = gr.Textbox.update(visible=False)
212
+
213
+ return outs
214
+
215
+ with gr.Blocks(title = "Levantine Arabic Translator", css=css, theme="default") as demo:
216
+
217
+ gr.HTML("<h2><span style='color: #2563eb; font-size: 18px'>Levantine Arabic</span> Translator</h2>")
218
+
219
+ with gr.Tab('En > Ar'):
220
  with gr.Row():
221
  with gr.Column():
222
+ input_text = gr.Textbox(label="Input", placeholder="Enter English text", lines=2)
223
+ gr.Examples(["I wanted to go to the store yesterday, but it rained", "How are you feeling today?"], input_text)
224
  btn = gr.Button("Translate", label="Translate")
225
  with gr.Row():
226
+ include = gr.CheckboxGroup(["Translit.", "SYR", "LEB", "EGY", "Colorize"],
227
  label="Disable features to speed up translation",
228
+ value=["Translit.", "Egyptian", "Colorize"])
229
  gr.Markdown("Built by [Guy Mor-Lan](mailto:guy.mor@mail.huji.ac.il). Pronunciation model is specifically tailored to urban Palestinian Arabic. Text-to-speech uses Microsoft Azure's API and may provide different result from the transliterated pronunciation.")
230
 
231
  with gr.Column():
232
+ with gr.Box(label = "Palestinian"):
233
+ gr.Markdown("Palestinian")
234
+ with gr.Box():
235
+ pal_html = gr.HTML("<br>", visible=True, label="Palestinian", elem_id="main")
236
+ pal = gr.Textbox(lines=1, label="Palestinian", elem_id="trans", visible=False)
237
+ pal_translit = gr.Textbox(lines=1, label="Palestinian Pronunciation (Urban)", elem_id="liter")
238
+ sy = gr.Textbox(lines=1, label="Syrian", elem_id="trans", visible=False)
239
+ lb = gr.Textbox(lines=1, label="Lebanese", elem_id="trans", visible=False)
240
  eg = gr.Textbox(lines=1, label="Egyptian", elem_id="trans")
241
  with gr.Row():
242
  audio = gr.Audio(label="Audio - Palestinian", interactive=False)
243
  audio_button = gr.Button("Get Audio", label="Click Here to Get Audio")
244
  audio_button.click(get_audio, inputs=[pal], outputs=[audio])
245
+ btn.click(translate_english,inputs=[input_text, include], outputs=[pal_html, pal, sy, lb, eg], _js="function jump(x, y){document.getElementById('main').scrollIntoView(); return [x, y];}")
246
+ input_text.submit(translate_english, inputs=[input_text, include], outputs=[pal_html, pal, sy, lb, eg],scroll_to_output=True)
247
+ pal.change(get_transliteration, inputs=[pal, include], outputs=[pal_translit]);
248
+ include.change(toggle_visibility, inputs=[include], outputs=[pal_translit, sy, lb, eg])
249
+
250
+ with gr.Tab('Ar > En'):
251
  with gr.Row():
252
  with gr.Column():
253
  input_text = gr.Textbox(label="Input", placeholder="Enter Levantine Arabic text", lines=1, elem_id="trans")
254
+ gr.Examples(["ุฎู„ูŠู†ุง ู†ุฏูˆุฑ ุนู„ู‰ ู…ุทุนู… ุชุงู†ูŠ", "ู‚ุฏูŠุด ุญู‚ ุงู„ุจู†ุฏูˆุฑุฉุŸ"], input_text)
255
  btn = gr.Button("Translate", label="Translate")
256
  gr.Markdown("Built by [Guy Mor-Lan](mailto:guy.mor@mail.huji.ac.il).")
257
  with gr.Column():
258
  eng = gr.Textbox(label="English", lines=1, elem_id="liter")
259
  btn.click(translate_arabic,inputs=input_text, outputs=[eng])
260
+
261
  with gr.Tab("Transliterate"):
262
  with gr.Row():
263
  with gr.Column():
264
  input_text = gr.Textbox(label="Input", placeholder="Enter Levantine Arabic text", lines=1)
265
+ gr.Examples(["ุฎู„ูŠู†ุง ู†ุฏูˆุฑ ุนู„ู‰ ู…ุทุนู… ุชุงู†ูŠ", "ู‚ุฏูŠุด ุญู‚ ุงู„ุจู†ุฏูˆุฑุฉุŸ"], input_text)
266
  btn = gr.Button("Transliterate", label="Transliterate")
267
  gr.Markdown("Built by [Guy Mor-Lan](mailto:guy.mor@mail.huji.ac.il)")
268
  with gr.Column():
269
  translit = gr.Textbox(label="Transliteration", lines=1, elem_id="liter")
270
  btn.click(get_transliteration, inputs=input_text, outputs=[translit])
271
 
272
+
273
  demo.launch()