"""Gradio demo: translate between English and Levantine Arabic dialects.

The English > Arabic tab can colour source/target tokens that the model's
cross-attention links together, show a Latin-script transliteration of the
Palestinian output, and synthesise it with Azure text-to-speech.
"""
import html
import os
import tempfile

import azure.cognitiveservices.speech as speechsdk
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from transformers import AutoTokenizer, MarianMTModel, pipeline

# Dialect display name -> tag prepended to the English input of the
# English2Dialect model. translate_english() batches dialects in this order.
dialects = {"Palestinian/Jordanian": "P", "Syrian": "S", "Lebanese": "L", "Egyptian": "E"}

# output_attentions=True so that forward passes return the cross-attentions
# consumed by align_words().
translator_en2ar = MarianMTModel.from_pretrained("guymorlan/English2Dialect", output_attentions=True)
tokenizer_en2ar = AutoTokenizer.from_pretrained("guymorlan/English2Dialect")
translator_ar2en = MarianMTModel.from_pretrained("guymorlan/Shami2English", output_attentions=True)
tokenizer_ar2en = AutoTokenizer.from_pretrained("guymorlan/Shami2English")
transliterator = pipeline(task="translation", model="guymorlan/DialectTransliterator")

speech_config = speechsdk.SpeechConfig(subscription=os.environ.get('SPEECH_KEY'),
                                       region=os.environ.get('SPEECH_REGION'))


def generate_diverging_colors(num_colors, palette='Set3'):
    """Return ``num_colors`` hex colour codes sampled from a matplotlib colormap.

    Codes are 6-digit lowercase hex strings WITHOUT a leading ``#``
    (e.g. ``'66c2a5'``). Returns ``[]`` when ``num_colors`` is not positive.
    """
    if num_colors <= 0:
        return []
    # pyplot.get_cmap works on all matplotlib versions; cm.get_cmap was
    # removed in matplotlib 3.9.
    cmap = plt.get_cmap(palette, num_colors)
    colors_rgba = cmap(np.arange(num_colors))
    return [format(int(r * 255) << 16 | int(g * 255) << 8 | int(b * 255), '06x')
            for r, g, b, _alpha in colors_rgba]


def _render_tokens(tokenizer, ids, prefix, colordict, special_tokens, skip_first=False):
    """Render a row of token ids as HTML, colouring positions found in ``colordict``.

    ``prefix`` + position is the lookup key into ``colordict`` (e.g. ``"enc_3"``).
    Special tokens are not rendered; SentencePiece word markers become spaces.
    """
    spans = []
    for pos, tok_id in enumerate(ids):
        if skip_first and pos == 0:
            continue
        token = tokenizer.convert_ids_to_tokens([int(tok_id)])[0]
        if token in special_tokens:
            continue
        text = html.escape(token.replace("▁", " "))
        color = colordict.get(f"{prefix}{pos}")
        if color is None:
            spans.append(f"<span>{text}</span>")
        else:
            spans.append(f"<span style='background-color: #{color};'>{text}</span>")
    joined = "".join(spans)
    return f"<span dir='auto'>{joined}</span>"


def align_words(outputs, tokenizer, encoder_input_ids, decoder_input_ids, threshold=0.4,
                skip_first_src=True, layer=2, head=7):
    """Colour-code source and target tokens that cross-attention aligns.

    Args:
        outputs: forward-pass output of a seq2seq model run with
            ``output_attentions=True`` (must expose ``cross_attentions``).
        tokenizer: tokenizer used to map ids back to token strings.
        encoder_input_ids: source ids; only row 0 is used.
        decoder_input_ids: target ids; only row 0 is used.
        threshold: attention weight above which a source token counts as
            aligned to a target token.
        skip_first_src: do not render the first source token (the dialect
            tag prepended by translate_english()).
        layer, head: which cross-attention layer and head to read.

    Returns:
        ``(srchtml, tgthtml)`` HTML strings.
    """
    special_tokens = set(tokenizer.all_special_tokens)

    # One row per decoder position, holding attention weights over encoder
    # positions (batch row 0 of the chosen layer/head).
    attention = outputs.cross_attentions[layer][0][head]
    alignment = [[[dec_pos], (row > threshold).nonzero().squeeze(-1).tolist()]
                 for dec_pos, row in enumerate(attention)]

    # Group consecutive decoder tokens into words: merge into the previous
    # group when they attend to an overlapping source token, or when the
    # token does not start a new word (SentencePiece marks word starts "▁").
    merged = []
    for entry in alignment:
        token = tokenizer.convert_ids_to_tokens([int(decoder_input_ids[0][entry[0][0]])])[0]
        if token in special_tokens:
            continue
        if merged:
            prev = merged[-1]
            overlaps = any(x in prev[1] for x in entry[1])
            if overlaps or not token.startswith("▁"):
                prev[0] += entry[0]
                prev[1] += entry[1]
                continue
        merged.append(entry)

    # Give each group a colour id, reusing the id of the first position in
    # the group that is already coloured by an earlier group.
    color_ids = {}
    ncolors = 0
    for dec_positions, enc_positions in merged:
        keys = [f"dec_{x}" for x in dec_positions] + [f"enc_{x}" for x in enc_positions]
        color = next((color_ids[k] for k in keys if k in color_ids), None)
        if color is None:  # not `if not color`: colour id 0 is a valid id
            color = ncolors
            ncolors += 1
        for k in keys:
            color_ids.setdefault(k, color)

    palette = generate_diverging_colors(ncolors, palette="Set2")
    colordict = {k: palette[v] for k, v in color_ids.items()}

    srchtml = _render_tokens(tokenizer, encoder_input_ids[0], "enc_", colordict,
                             special_tokens, skip_first=skip_first_src)
    tgthtml = _render_tokens(tokenizer, decoder_input_ids[0], "dec_", colordict, special_tokens)
    return srchtml, tgthtml


def translate_english(input_text, include):
    """Translate English into Palestinian plus any selected dialects.

    ``include`` holds the checked options ("SYR", "LEB", "EGY", "Colorize", ...).
    Returns ``(palhtml, pal_out, sy_out, lb_out, eg_out)``; unselected dialect
    outputs are empty strings.
    """
    if not input_text:
        return "", "", "", "", ""
    sy, lb, eg = "SYR" in include, "LEB" in include, "EGY" in include
    # Palestinian is always translated; keep dialects order so the decoded
    # rows can be indexed below.
    wanted = {"P": True, "S": sy, "L": lb, "E": eg}
    inputs = [f"{tag} {input_text}" for tag in dialects.values() if wanted[tag]]

    # padding=True: dialect tags may tokenize to different lengths.
    batch = tokenizer_en2ar(inputs, return_tensors="pt", padding=True)
    outputs = translator_en2ar.generate(**batch)
    decoded = tokenizer_en2ar.batch_decode(outputs, skip_special_tokens=True)

    pal_out = decoded[0]
    sy_out = decoded[1] if sy else ""
    lb_out = decoded[1 + sy] if lb else ""
    eg_out = decoded[1 + sy + lb] if eg else ""

    if "Colorize" in include:
        # Alignment is shown for the Palestinian row only; drop any padding.
        encoder_input_ids = batch.input_ids[0][batch.attention_mask[0].bool()].unsqueeze(0)
        decoder_input_ids = outputs[0].unsqueeze(0)
        with torch.no_grad():
            html_outputs = translator_en2ar(input_ids=encoder_input_ids,
                                            decoder_input_ids=decoder_input_ids)
        # Longer inputs spread attention thinner, so lower the threshold.
        n_src = encoder_input_ids.shape[1]
        if n_src < 10:
            threshold = 0.4
        elif n_src < 20:
            threshold = 0.10
        else:
            threshold = 0.05
        print("threshold", threshold)
        srchtml, tgthtml = align_words(html_outputs, tokenizer_en2ar, encoder_input_ids,
                                       decoder_input_ids, threshold)
        palhtml = f"<div>{srchtml}</div><br><div style='font-size: 25px;'>{tgthtml}</div>"
    else:
        palhtml = f"<div dir='rtl' style='font-size: 25px;'>{html.escape(pal_out)}</div>"
    return palhtml, pal_out, sy_out, lb_out, eg_out


def translate_arabic(input_text, include=("Colorize",)):
    """Translate Levantine Arabic into English and return it as HTML.

    With "Colorize" in ``include`` the source and translation are shown with
    aligned tokens colour-coded.
    """
    if not input_text:
        return ""
    input_tokens = tokenizer_ar2en(input_text, return_tensors="pt").input_ids
    outputs = translator_ar2en.generate(input_tokens)
    decoded = tokenizer_ar2en.batch_decode(outputs, skip_special_tokens=True)

    if "Colorize" in include:
        encoder_input_ids = input_tokens[0].unsqueeze(0)
        decoder_input_ids = outputs[0].unsqueeze(0)
        with torch.no_grad():
            html_outputs = translator_ar2en(input_ids=encoder_input_ids,
                                            decoder_input_ids=decoder_input_ids)
        # NOTE(review): these thresholds are not monotonic in input length
        # (0.1, 0.01, 0.05), unlike translate_english; confirm 0.01 is intended.
        if input_tokens.shape[1] < 20:
            threshold = 0.1
        elif input_tokens.shape[1] < 30:
            threshold = 0.01
        else:
            threshold = 0.05
        print("threshold", threshold)
        srchtml, tgthtml = align_words(html_outputs, tokenizer_ar2en, encoder_input_ids,
                                       decoder_input_ids, threshold, skip_first_src=False)
        enhtml = f"<div style='font-size: 25px;'>{srchtml}</div><br><div>{tgthtml}</div>"
    else:
        enhtml = f"<div>{html.escape(decoded[0])}</div>"
    return enhtml


def get_audio(input_text):
    """Synthesise ``input_text`` with Azure TTS (ar-SY-AmanyNeural); return the .wav path."""
    if not input_text:
        return None
    # Use a fresh temp file instead of f"{input_text}.wav": the text may
    # contain path separators or characters that are invalid in file names.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        wav_path = tmp.name
    audio_config = speechsdk.audio.AudioOutputConfig(filename=wav_path)
    speech_config.speech_synthesis_voice_name = 'ar-SY-AmanyNeural'
    speech_synthesizer = speechsdk.SpeechSynthesizer(speech_config=speech_config,
                                                     audio_config=audio_config)
    result = speech_synthesizer.speak_text_async(input_text).get()
    if result.reason != speechsdk.ResultReason.SynthesizingAudioCompleted:
        raise gr.Error("Speech synthesis failed.")
    return wav_path


def get_transliteration(input_text, include=("Translit.",)):
    """Return the Latin-script transliteration of ``input_text``, or "" if disabled/empty."""
    if "Translit." not in include or not input_text:
        return ""
    result = transliterator([input_text])
    return result[0]["translation_text"]


css = """
#liter textarea, #trans textarea { font-size: 25px;}
#trans textarea { direction: rtl; }
#check { border-style: none !important; }
:root {--button-secondary-background-focus: #2563eb !important; --button-secondary-background-base: #2563eb !important; --button-secondary-background-hover: linear-gradient(to bottom right, #0692e8, #5859c2); --button-secondary-text-color-base: white !important; --button-secondary-text-color-hover: white !important; --button-secondary-background-focus: rgb(51 122 216 / 70%) !important; --button-secondary-text-color-focus: white !important}
.dark {--button-secondary-background-base: #2563eb !important; --button-secondary-background-focus: rgb(51 122 216 / 70%) !important; --button-secondary-background-hover: linear-gradient(to bottom right, #0692e8, #5859c2)}
.feather-music { stroke: #2563eb; }
"""


def toggle_visibility(include):
    """Show the transliteration / Syrian / Lebanese / Egyptian boxes per the checked options."""
    return [gr.Textbox.update(visible=flag in include)
            for flag in ("Translit.", "SYR", "LEB", "EGY")]


with gr.Blocks(title="Levantine Arabic Translator", css=css, theme="default") as demo:
    gr.HTML("<h1 style='text-align: center;'>Levantine Arabic Translator</h1>")

    with gr.Tab('En > Ar'):
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(label="Input", placeholder="Enter English text", lines=2)
                gr.Examples(["I wanted to go to the store yesterday, but it rained",
                             "How are you feeling today?"], input_text)
                btn = gr.Button("Translate", label="Translate")
                with gr.Row():
                    include = gr.CheckboxGroup(["Translit.", "SYR", "LEB", "EGY", "Colorize"],
                                               label="Disable features to speed up translation",
                                               value=["Translit.", "EGY", "Colorize"])
                gr.Markdown("Built by [Guy Mor-Lan](mailto:guy.mor@mail.huji.ac.il). Pronunciation model is specifically tailored to urban Palestinian Arabic. Text-to-speech uses Microsoft Azure's API and may provide different result from the transliterated pronunciation.")
            with gr.Column():
                with gr.Box(label="Palestinian"):
                    gr.Markdown("Palestinian")
                    with gr.Box():
                        pal_html = gr.HTML("", visible=True, label="Palestinian", elem_id="main")
                    pal = gr.Textbox(lines=1, label="Palestinian", elem_id="trans", visible=False)
                    pal_translit = gr.Textbox(lines=1, label="Palestinian Pronunciation (Urban)", elem_id="liter")
                sy = gr.Textbox(lines=1, label="Syrian", elem_id="trans", visible=False)
                lb = gr.Textbox(lines=1, label="Lebanese", elem_id="trans", visible=False)
                eg = gr.Textbox(lines=1, label="Egyptian", elem_id="trans")
                audio = gr.Audio(label="Audio - Palestinian", interactive=False)
                audio_button = gr.Button("Get Audio", label="Click Here to Get Audio")
                audio_button.click(get_audio, inputs=[pal], outputs=[audio])
        btn.click(translate_english, inputs=[input_text, include],
                  outputs=[pal_html, pal, sy, lb, eg], api_name="en2ar",
                  _js="function jump(x, y){document.getElementById('main').scrollIntoView(); return [x, y];}")
        input_text.submit(translate_english, inputs=[input_text, include],
                          outputs=[pal_html, pal, sy, lb, eg], scroll_to_output=True)
        pal.change(get_transliteration, inputs=[pal, include], outputs=[pal_translit])
        include.change(toggle_visibility, inputs=[include], outputs=[pal_translit, sy, lb, eg])

    with gr.Tab('Ar > En'):
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(label="Input", placeholder="Enter Levantine Arabic text",
                                        lines=1, elem_id="trans")
                gr.Examples(["خلينا ندور على مطعم تاني", "قديش حق البندورة؟"], input_text)
                btn = gr.Button("Translate", label="Translate")
                gr.Markdown("Built by [Guy Mor-Lan](mailto:guy.mor@mail.huji.ac.il).")
            with gr.Column():
                with gr.Box(label="English"):
                    gr.Markdown("English")
                    with gr.Box():
                        eng = gr.HTML("", label="English", elem_id="main")
        btn.click(translate_arabic, inputs=input_text, outputs=[eng], api_name="ar2en")

    with gr.Tab("Transliterate"):
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(label="Input", placeholder="Enter Levantine Arabic text", lines=1)
                gr.Examples(["خلينا ندور على مطعم تاني", "قديش حق البندورة؟"], input_text)
                btn = gr.Button("Transliterate", label="Transliterate")
                gr.Markdown("Built by [Guy Mor-Lan](mailto:guy.mor@mail.huji.ac.il)")
            with gr.Column():
                translit = gr.Textbox(label="Transliteration", lines=1, elem_id="liter")
        btn.click(get_transliteration, inputs=input_text, outputs=[translit])

if __name__ == "__main__":
    demo.launch()