# Author: guymorlan (Hugging Face)
# Last commit: fff6c5b, "Update app.py"
import html
import os
import tempfile

import azure.cognitiveservices.speech as speechsdk
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from transformers import pipeline, MarianMTModel, AutoTokenizer
# Dialect label -> single-letter code that translate_english prefixes to the
# English input; one output is produced per prefixed copy of the text.
dialects = {"Palestinian/Jordanian": "P", "Syrian": "S", "Lebanese": "L", "Egyptian": "E"}

# Both MarianMT models are loaded with output_attentions=True so that a forward
# pass returns the cross-attentions align_words uses for colouring.
translator_en2ar = MarianMTModel.from_pretrained("guymorlan/English2Dialect", output_attentions=True)
tokenizer_en2ar = AutoTokenizer.from_pretrained("guymorlan/English2Dialect")
translator_ar2en = MarianMTModel.from_pretrained("guymorlan/Shami2English", output_attentions=True)
tokenizer_ar2en = AutoTokenizer.from_pretrained("guymorlan/Shami2English")
# Transliteration model (its output is shown as "Pronunciation" in the UI),
# run through the generic translation pipeline.
transliterator = pipeline(task="translation", model="guymorlan/DialectTransliterator")

# Azure text-to-speech configuration; credentials come from the SPEECH_KEY and
# SPEECH_REGION environment variables.
speech_config = speechsdk.SpeechConfig(subscription=os.environ.get('SPEECH_KEY'), region=os.environ.get('SPEECH_REGION'))
def generate_diverging_colors(num_colors, palette='Set3'):  # courtesy of ChatGPT
    """Return ``num_colors`` hex colour strings sampled from a matplotlib colormap.

    Each string is six lowercase hex digits (``"rrggbb"``, no leading ``#``).

    Args:
        num_colors: number of colours to return; zero or fewer yields ``[]``.
        palette: name of a registered matplotlib colormap.
    """
    if num_colors <= 0:
        return []
    # pyplot.get_cmap(name, lut) resamples the colormap to `num_colors` entries;
    # matplotlib.cm.get_cmap (reached via plt.cm) was removed in Matplotlib 3.9.
    cmap = plt.get_cmap(palette, num_colors)
    colors_rgba = cmap(np.arange(num_colors))
    # Channels are floats in [0, 1]; truncate to 0-255 and pack as 2 hex digits each.
    return [
        f"{int(r * 255):02x}{int(g * 255):02x}{int(b * 255):02x}"
        for r, g, b, _alpha in colors_rgba
    ]
# SentencePiece marks the first piece of every word with U+2581 ("▁"). It is
# written as an escape so a source re-encoding cannot silently break it.
_SP_WORD_START = "\u2581"
# Tokens with no visible text; they are neither aligned nor rendered.
_SPECIAL_TOKENS = frozenset(("</s>", "<pad>", "<unk>"))


def _attention_alignment(outputs, threshold, layer, head):
    """Pair each decoder position with the encoder positions it attends to above ``threshold``.

    Returns a list of ``[[decoder_index], [encoder_index, ...]]`` entries, one per
    decoder position, read from cross-attention ``layer``/``head`` of batch item 0.
    """
    # Per the transformers docs each cross_attentions entry has shape
    # (batch, heads, target_len, source_len), so each row is one decoder step.
    attn = outputs.cross_attentions[layer][0][head]
    return [
        [[dec_idx], (row > threshold).nonzero().squeeze(-1).tolist()]
        for dec_idx, row in enumerate(attn)
    ]


def _merge_subwords(alignment, tokenizer, decoder_ids):
    """Group decoder positions into words, skipping special tokens.

    A token joins the previous group when it continues a word (it does not start
    with the SentencePiece word marker) or when it attends to an encoder
    position already in that group.
    """
    merged = []
    for dec_idxs, enc_idxs in alignment:
        token = tokenizer.convert_ids_to_tokens([int(decoder_ids[dec_idxs[0]])])[0]
        if token in _SPECIAL_TOKENS:
            continue
        continues_word = not token.startswith(_SP_WORD_START)
        if merged and (continues_word or any(x in merged[-1][1] for x in enc_idxs)):
            merged[-1][0].extend(dec_idxs)
            merged[-1][1].extend(enc_idxs)
        else:
            merged.append([list(dec_idxs), list(enc_idxs)])
    return merged


def _assign_color_ids(merged):
    """Give every aligned position a colour index; linked groups share one.

    Returns ``(color_ids, ncolors)``: ``color_ids`` maps ``"dec_<i>"`` and
    ``"enc_<j>"`` keys to integers in ``range(ncolors)``.
    """
    color_ids = {}
    ncolors = 0
    for dec_idxs, enc_idxs in merged:
        keys = [f"dec_{x}" for x in dec_idxs] + [f"enc_{x}" for x in enc_idxs]
        # Reuse the colour of any position an earlier group already coloured.
        color = next((color_ids[k] for k in keys if k in color_ids), None)
        if color is None:  # colour index 0 is valid, so test for None, not falsiness
            color = ncolors
            ncolors += 1
        for k in keys:
            color_ids.setdefault(k, color)
    return color_ids, ncolors


def _render_tokens(tokenizer, ids, colordict, prefix, skip_first=False):
    """Render token ``ids`` as escaped HTML spans coloured by ``colordict[prefix + index]``."""
    spans = []
    for idx, token in enumerate(tokenizer.convert_ids_to_tokens([int(t) for t in ids])):
        if (skip_first and idx == 0) or token in _SPECIAL_TOKENS:
            continue
        color = colordict.get(f"{prefix}{idx}")
        style = f"color: #{color}" if color is not None else "color: var(--color-text-body)"
        # Escape so model/user text can never be interpreted as markup.
        text = html.escape(token.replace(_SP_WORD_START, " "))
        spans.append(f"<span style='{style}'>{text}</span>")
    return "<span style='font-size: 30px'>" + "".join(spans) + "</span>"


def align_words(outputs, tokenizer, encoder_input_ids, decoder_input_ids, threshold=0.4,
                skip_first_src=True, layer=2, head=7):
    """Colour-code source and target tokens that the model aligns with each other.

    Alignment is read from one cross-attention head. Target sub-words are merged
    into words, and every word shares a colour with the source tokens it attends
    to above ``threshold``.

    Args:
        outputs: model forward output exposing ``cross_attentions``.
        tokenizer: tokenizer used to turn ids back into tokens for display.
        encoder_input_ids: source ids; row 0 is rendered.
        decoder_input_ids: target ids; row 0 is rendered.
        threshold: minimum attention weight for a source token to count as aligned.
        skip_first_src: drop the first source token from the rendering (e.g. a
            dialect-code prefix).
        layer, head: which cross-attention layer/head to read (defaults keep the
            previously hard-coded choice).

    Returns:
        ``(srchtml, tgthtml)`` HTML strings.
    """
    alignment = _attention_alignment(outputs, threshold, layer, head)
    merged = _merge_subwords(alignment, tokenizer, decoder_input_ids[0])
    color_ids, ncolors = _assign_color_ids(merged)
    palette = generate_diverging_colors(ncolors, palette="Set2") if ncolors else []
    colordict = {key: palette[idx] for key, idx in color_ids.items()}
    srchtml = _render_tokens(tokenizer, encoder_input_ids[0], colordict, "enc_", skip_first=skip_first_src)
    tgthtml = _render_tokens(tokenizer, decoder_input_ids[0], colordict, "dec_")
    return srchtml, tgthtml
def translate_english(input_text, include):
    """Translate English text into Palestinian and any other selected dialects.

    Args:
        input_text: English source text; empty input returns five empty strings.
        include: selected UI options. "SYR", "LEB" and "EGY" enable those
            dialects; "Colorize" renders the Palestinian output with
            word-alignment colours.

    Returns:
        ``(pal_html, pal_out, sy_out, lb_out, eg_out)``; disabled dialects are "".
    """
    if not input_text:
        return "", "", "", "", ""
    # (dialect code, requested?) in output order; Palestinian is always produced.
    requested = [
        (dialects["Palestinian/Jordanian"], True),
        (dialects["Syrian"], "SYR" in include),
        (dialects["Lebanese"], "LEB" in include),
        (dialects["Egyptian"], "EGY" in include),
    ]
    # Translate only the requested dialects, in one batch, in the order above.
    inputs = [f"{code} {input_text}" for code, wanted in requested if wanted]
    input_tokens = tokenizer_en2ar(inputs, return_tensors="pt").input_ids
    outputs = translator_en2ar.generate(input_tokens)
    decoded = iter(tokenizer_en2ar.batch_decode(outputs, skip_special_tokens=True))
    # Decoded rows follow `inputs`, so consume them only for requested dialects.
    pal_out, sy_out, lb_out, eg_out = (next(decoded) if wanted else "" for _, wanted in requested)

    if "Colorize" in include:
        # Row 0 is always the Palestinian pair. Re-run it teacher-forced to get
        # cross-attentions; no gradients are needed for that.
        encoder_input_ids = input_tokens[0].unsqueeze(0)
        decoder_input_ids = outputs[0].unsqueeze(0)
        with torch.no_grad():
            html_outputs = translator_en2ar(input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids)
        # Dynamic threshold: lower for longer inputs (presumably because
        # attention is spread over more tokens).
        n_tokens = input_tokens.shape[1]
        if n_tokens < 10:
            threshold = 0.4
        elif n_tokens < 20:
            threshold = 0.10
        else:
            threshold = 0.05
        print("threshold", threshold)
        srchtml, tgthtml = align_words(html_outputs, tokenizer_en2ar, encoder_input_ids, decoder_input_ids, threshold)
        palhtml = f"{srchtml}<br><br><div style='direction: rtl'>{tgthtml}</div>"
    else:
        palhtml = f"<div style='font-size: 30px; direction: rtl'>{html.escape(pal_out)}</div>"
    return palhtml, pal_out, sy_out, lb_out, eg_out
def translate_arabic(input_text, include=("Colorize",)):
    """Translate Levantine Arabic text into English, returned as HTML.

    Args:
        input_text: Arabic source text; empty input returns "".
        include: selected options; if it contains "Colorize" the source and the
            translation are rendered with word-alignment colours. A tuple
            default avoids a shared mutable default.

    Returns:
        HTML for the English output component.
    """
    if not input_text:
        return ""
    input_tokens = tokenizer_ar2en(input_text, return_tensors="pt").input_ids
    outputs = translator_ar2en.generate(input_tokens)
    if "Colorize" not in include:
        # Decode with the same (ar2en) tokenizer whose vocabulary produced these ids.
        decoded = tokenizer_ar2en.batch_decode(outputs, skip_special_tokens=True)
        return f"<div style='font-size: 30px;'>{html.escape(decoded[0])}</div>"

    encoder_input_ids = input_tokens[0].unsqueeze(0)
    decoder_input_ids = outputs[0].unsqueeze(0)
    # Teacher-forced pass over the generated output to obtain cross-attentions.
    with torch.no_grad():
        html_outputs = translator_ar2en(input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids)
    n_tokens = input_tokens.shape[1]
    # NOTE(review): 0.01 for 20-29 tokens but 0.05 for 30+ is non-monotonic,
    # unlike translate_english; confirm these two values are not swapped.
    if n_tokens < 20:
        threshold = 0.1
    elif n_tokens < 30:
        threshold = 0.01
    else:
        threshold = 0.05
    print("threshold", threshold)
    srchtml, tgthtml = align_words(html_outputs, tokenizer_ar2en, encoder_input_ids, decoder_input_ids, threshold, skip_first_src=False)
    return f"<div style='direction: rtl'>{srchtml}</div><br><br><div>{tgthtml}</div>"
def get_audio(input_text):
    """Synthesize ``input_text`` with Azure TTS (voice ar-SY-AmanyNeural) into a WAV file.

    Returns:
        Path of the written WAV file, or None when ``input_text`` is empty.

    Raises:
        gr.Error: if Azure reports that synthesis did not complete.
    """
    if not input_text:
        return None
    # Write to a fresh temp file: the text itself is not a safe file name (it
    # may contain path separators or exceed file-name length limits).
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        wav_path = tmp.name
    audio_config = speechsdk.audio.AudioOutputConfig(filename=wav_path)
    speech_config.speech_synthesis_voice_name = 'ar-SY-AmanyNeural'
    synthesizer = speechsdk.SpeechSynthesizer(speech_config=speech_config, audio_config=audio_config)
    result = synthesizer.speak_text_async(input_text).get()
    if result.reason != speechsdk.ResultReason.SynthesizingAudioCompleted:
        raise gr.Error(f"Text-to-speech failed: {result.reason}")
    return wav_path
def get_transliteration(input_text, include=("Translit.",)):
    """Return the transliteration of ``input_text``, or "" when disabled or empty.

    Args:
        input_text: Arabic text to transliterate.
        include: selected options; transliteration runs only if it contains
            "Translit." (a tuple default avoids a shared mutable default).
    """
    # Skip the model call when the feature is off or there is nothing to
    # transliterate (e.g. the output box was just cleared).
    if "Translit." not in include or not input_text:
        return ""
    result = transliterator([input_text])
    return result[0]["translation_text"]
# NOTE(review): `bla` is an empty placeholder that nothing in this file reads;
# kept in case other code imports it.
bla = """
"""
# Custom stylesheet passed to gr.Blocks: larger text in the #liter/#trans
# boxes, right-to-left #trans, and blue secondary buttons in light and dark themes.
# (:root previously declared --button-secondary-background-focus twice; only the
# later declaration ever applied, so the dead one is removed.)
css = """
#liter textarea, #trans textarea { font-size: 25px;}
#trans textarea { direction: rtl; }
#check { border-style: none !important; }
:root {--button-secondary-background-base: #2563eb !important;
--button-secondary-background-hover: linear-gradient(to bottom right, #0692e8, #5859c2);
--button-secondary-text-color-base: white !important;
--button-secondary-text-color-hover: white !important;
--button-secondary-background-focus: rgb(51 122 216 / 70%) !important;
--button-secondary-text-color-focus: white !important}
.dark {--button-secondary-background-base: #2563eb !important;
--button-secondary-background-focus: rgb(51 122 216 / 70%) !important;
--button-secondary-background-hover: linear-gradient(to bottom right, #0692e8, #5859c2)}
.feather-music { stroke: #2563eb; }
"""
def toggle_visibility(include):
    """Show or hide the optional output boxes according to the selected options.

    Returns one Textbox update per box, in the order transliteration, Syrian,
    Lebanese, Egyptian; a box is visible only if its option is in ``include``.
    """
    option_per_box = ("Translit.", "SYR", "LEB", "EGY")
    return [gr.Textbox.update(visible=option in include) for option in option_per_box]
# Gradio UI: three tabs (En > Ar, Ar > En, Transliterate) wired to the
# functions above, then launched.
with gr.Blocks(title="Levantine Arabic Translator", css=css, theme="default") as demo:
    gr.HTML("<h2><span style='color: #2563eb'>Levantine Arabic</span> Translator</h2>")

    with gr.Tab('En > Ar'):
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(label="Input", placeholder="Enter English text", lines=2)
                gr.Examples(["I wanted to go to the store yesterday, but it rained", "How are you feeling today?"], input_text)
                btn = gr.Button("Translate", label="Translate")
                with gr.Row():
                    include = gr.CheckboxGroup(["Translit.", "SYR", "LEB", "EGY", "Colorize"],
                                               label="Disable features to speed up translation",
                                               value=["Translit.", "EGY", "Colorize"])
                gr.Markdown("Built by [Guy Mor-Lan](mailto:guy.mor@mail.huji.ac.il). Pronunciation model is specifically tailored to urban Palestinian Arabic. Text-to-speech uses Microsoft Azure's API and may provide different result from the transliterated pronunciation.")
            with gr.Column():
                with gr.Box(label="Palestinian"):
                    gr.Markdown("Palestinian")
                    with gr.Box():
                        pal_html = gr.HTML("<br>", visible=True, label="Palestinian", elem_id="main")
                # Hidden plain-text Palestinian output; it drives transliteration and audio.
                pal = gr.Textbox(lines=1, label="Palestinian", elem_id="trans", visible=False)
                pal_translit = gr.Textbox(lines=1, label="Palestinian Pronunciation (Urban)", elem_id="liter")
                sy = gr.Textbox(lines=1, label="Syrian", elem_id="trans", visible=False)
                lb = gr.Textbox(lines=1, label="Lebanese", elem_id="trans", visible=False)
                eg = gr.Textbox(lines=1, label="Egyptian", elem_id="trans")
                audio = gr.Audio(label="Audio - Palestinian", interactive=False)
                audio_button = gr.Button("Get Audio", label="Click Here to Get Audio")
                audio_button.click(get_audio, inputs=[pal], outputs=[audio])
        # The _js hook scrolls the results into view before the Python call runs.
        btn.click(translate_english, inputs=[input_text, include], outputs=[pal_html, pal, sy, lb, eg], api_name="en2ar", _js="function jump(x, y){document.getElementById('main').scrollIntoView(); return [x, y];}")
        input_text.submit(translate_english, inputs=[input_text, include], outputs=[pal_html, pal, sy, lb, eg], scroll_to_output=True)
        pal.change(get_transliteration, inputs=[pal, include], outputs=[pal_translit])
        include.change(toggle_visibility, inputs=[include], outputs=[pal_translit, sy, lb, eg])

    with gr.Tab('Ar > En'):
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(label="Input", placeholder="Enter Levantine Arabic text", lines=1, elem_id="trans")
                gr.Examples(["خلينا ندور على مطعم تاني", "قديش حق البندورة؟"], input_text)
                btn = gr.Button("Translate", label="Translate")
                gr.Markdown("Built by [Guy Mor-Lan](mailto:guy.mor@mail.huji.ac.il).")
            with gr.Column():
                with gr.Box(label="English"):
                    gr.Markdown("English")
                    with gr.Box():
                        eng = gr.HTML("<br>", label="English", elem_id="main")
        btn.click(translate_arabic, inputs=input_text, outputs=[eng], api_name="ar2en")

    with gr.Tab("Transliterate"):
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(label="Input", placeholder="Enter Levantine Arabic text", lines=1)
                gr.Examples(["خلينا ندور على مطعم تاني", "قديش حق البندورة؟"], input_text)
                btn = gr.Button("Transliterate", label="Transliterate")
                gr.Markdown("Built by [Guy Mor-Lan](mailto:guy.mor@mail.huji.ac.il)")
            with gr.Column():
                translit = gr.Textbox(label="Transliteration", lines=1, elem_id="liter")
        btn.click(get_transliteration, inputs=input_text, outputs=[translit])

demo.launch()