"""Gradio demo: diacritize colloquial Arabic and transliterate it into Hebrew script.

The script loads a CANINE-based multi-label classifier, predicts diacritic
labels for every input character, and renders two lines of HTML: the Hebrew
transliteration ("taatik") with nikud and the Arabic text with harakat.
"""

import html
import logging

import gradio as gr
import torch
from torch import nn
from transformers import (
    CanineForTokenClassification,
    CanineModel,
    CaninePreTrainedModel,
    CanineTokenizer,
)
from transformers.modeling_outputs import TokenClassifierOutput

logger = logging.getLogger(__name__)

# Arabic character -> Hebrew transliteration. Some values are longer than one
# character (doubled alef, letter + apostrophe).
arabic_to_hebrew = {
    # regular letters
    "ا": "א",
    "أ": "א",
    "إ": "א",
    "ء": "א",
    "ئ": "א",
    "ؤ": "א",
    "آ": "אא",
    "ى": "א",
    "ب": "ב",
    "ت": "ת",
    "ث": "ת'",
    "ج": "ג'",
    "ح": "ח",
    "خ": "ח'",
    "د": "ד",
    "ذ": "ד'",
    "ر": "ר",
    "ز": "ז",
    "س": "ס",
    "ش": "ש",
    "ص": "צ",
    "ض": "צ'",
    "ط": "ט",
    "ظ": "ט'",
    "ع": "ע",
    "غ": "ע'",
    "ف": "פ",
    "ق": "ק",
    "ك": "כ",
    "ل": "ל",
    "م": "מ",
    "ن": "נ",
    "ه": "ה",
    "و": "ו",
    "ي": "י",
    "ة": "ה",
    # special characters (combining marks are written as escapes so they stay
    # legible in editors)
    "\u060c": ",",  # ARABIC COMMA
    "\u064e": "\u05b7",  # FATHA -> PATAH
    "\u064f": "\u05bb",  # DAMMA -> QUBUTS
    "\u0650": "\u05b4",  # KASRA -> HIRIQ
}

# Arabic letters whose Hebrew counterpart has a distinct word-final form.
final_letters = {
    "ن": "ן",
    "م": "ם",
    "ص": "ץ",
    "ض": "ץ'",
    "ف": "ף",
}

# Characters that end a word for the purpose of choosing a final form:
# whitespace, Latin punctuation, and the Arabic comma / semicolon / question mark.
_WORD_BOUNDARIES = frozenset(" \t\n.,;:!?\u060c\u061b\u061f")

# Font stack used when rendering the result paragraphs.
_FONT_FAMILY = "Arial Unicode MS, Tahoma, sans-serif"

# Meaning of each classifier output, in output order.
_NIKKUD_LABELS = ("SHADDA", "TSERE", "HOLAM", "PATACH", "SHVA", "KUBUTZ", "HIRIQ")

# Hebrew label -> Arabic label. Key order is the order labels are emitted in.
_NIKKUD_TO_HARAKA = {
    "SHADDA": "SHADDA",
    "TSERE": "KASRA",
    "HOLAM": "DAMMA",
    "PATACH": "FATHA",
    "SHVA": "SUKUN",
    "KUBUTZ": "DAMMA",
    "HIRIQ": "KASRA",
}

# Label name -> combining character.
_LABEL_TO_CHAR = {
    "SHADDA": "\u0651",  # ARABIC SHADDA (also used in the Hebrew line)
    "TSERE": "\u05b5",
    "HOLAM": "\u05b9",
    "PATACH": "\u05b7",
    "SHVA": "\u05b0",
    "KUBUTZ": "\u05bb",
    "HIRIQ": "\u05b4",
    # for these, return arabic harakat
    "DAMMA": "\u064f",
    "KASRA": "\u0650",
    "FATHA": "\u064e",
    "SUKUN": "\u0652",
}


def to_taatik(arabic):
    """Transliterate Arabic text into Hebrew script, one entry per input character.

    Args:
        arabic: The Arabic text.

    Returns:
        A list with exactly ``len(arabic)`` strings. Each entry is the Hebrew
        rendering of the character at the same index; characters without a
        mapping are passed through unchanged. Keeping a one-to-one list (rather
        than a joined string) lets callers attach per-character diacritics.
    """
    taatik = []
    last_index = len(arabic) - 1
    for index, letter in enumerate(arabic):
        at_word_end = index == last_index or arabic[index + 1] in _WORD_BOUNDARIES
        if at_word_end and letter in final_letters:
            taatik.append(final_letters[letter])
        else:
            taatik.append(arabic_to_hebrew.get(letter, letter))
    return taatik


class TaatikModel(CaninePreTrainedModel):
    """CANINE encoder with a per-character multi-label classification head.

    Structured like ``CanineForTokenClassification``, but trained with
    ``BCEWithLogitsLoss`` so that every character can carry several labels at
    once.
    """

    def __init__(self, config, num_labels=7):
        """Build the encoder, dropout and linear classification head.

        Args:
            config: CANINE configuration; its ``num_labels`` is overwritten.
            num_labels: Number of independent binary outputs per character.
                The decoding code in this file reads the seven outputs as
                SHADDA plus six nikud marks.
        """
        super().__init__(config)
        config.num_labels = num_labels
        self.num_labels = config.num_labels

        self.canine = CanineModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
    ):
        """Run the encoder and classify every position.

        Args:
            labels: Optional multi-hot targets with the same shape as the
                logits. When given, the BCE-with-logits loss is returned too.
            Remaining arguments are forwarded to ``CanineModel`` unchanged.

        Returns:
            A ``TokenClassifierOutput`` holding ``loss`` (``None`` without
            labels), raw ``logits``, and the encoder's hidden states and
            attentions.
        """
        outputs = self.canine(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        sequence_output = self.dropout(outputs[0])
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            # BCEWithLogitsLoss needs floating-point targets; integer multi-hot
            # labels would otherwise raise a dtype error.
            loss = self.criterion(logits, labels.to(logits.dtype))

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


tokenizer = CanineTokenizer.from_pretrained("google/canine-s")
model = TaatikModel.from_pretrained("guymorlan/Arabic2Taatik")


def convert_nikkud_to_harakat(nikkud):
    """Map Hebrew nikud label names to the corresponding Arabic haraka names.

    Args:
        nikkud: Collection of Hebrew label names for one character.

    Returns:
        Arabic label names in a fixed order. TSERE/HIRIQ both map to KASRA and
        HOLAM/KUBUTZ both map to DAMMA; each Arabic label is emitted at most
        once so a character never receives the same mark twice.
    """
    harakat = []
    for hebrew_name, arabic_name in _NIKKUD_TO_HARAKA.items():
        if hebrew_name in nikkud and arabic_name not in harakat:
            harakat.append(arabic_name)
    return harakat


def convert_binary_to_labels(binary_labels):
    """Turn one multi-hot prediction row into the list of active label names.

    Args:
        binary_labels: Sequence of 0/1 values, one per classifier output.

    Returns:
        Names of the outputs equal to 1, in output order.
    """
    return [name for flag, name in zip(binary_labels, _NIKKUD_LABELS) if flag == 1]


def convert_label_names_to_chars(label):
    """Return the combining character for a label name, or ``""`` if unknown."""
    return _LABEL_TO_CHAR.get(label, "")


def _attach_marks(base, label_names):
    """Append the combining characters for ``label_names`` to ``base``."""
    return base + "".join(convert_label_names_to_chars(name) for name in label_names)


def _move_apostrophe_last(cluster):
    """Move an apostrophe in second position to the end of the cluster.

    Letters such as ג' carry the apostrophe right after the base letter; the
    diacritics must follow the base letter directly, so the apostrophe is
    pushed behind them.
    """
    if len(cluster) > 1 and cluster[1] == "'":
        return cluster[0] + cluster[2:] + cluster[1]
    return cluster


def _render_paragraph(text):
    """Wrap ``text`` in a right-to-left paragraph, escaping HTML metacharacters."""
    # NOTE(review): the exact markup is a reconstruction (two RTL paragraphs);
    # confirm the styling against the deployed app.
    return (
        f'<p dir="rtl" style="font-size: 1.5em; font-family: {_FONT_FAMILY};">'
        f"{html.escape(text)}</p>"
    )


def predict(input, prefix="P "):  # noqa: A002 - parameter name kept for callers
    """Diacritize Arabic text and return the result as an HTML fragment.

    Args:
        input: Arabic text typed by the user (untrusted).
        prefix: Text prepended before tokenization. Presumably a tag the model
            was trained with - confirm against the training setup.

    Returns:
        HTML containing the Hebrew transliteration with nikud followed by the
        Arabic text with harakat, or ``""`` for empty input.
    """
    if not input:
        return ""

    encoded = tokenizer(prefix + input, return_tensors="pt")
    with torch.no_grad():  # inference only; no autograd graph needed
        outputs = model(**encoded)

    rows = outputs.logits.sigmoid().round().int().tolist()[0]
    # CanineTokenizer yields one id per character wrapped in [CLS] ... [SEP];
    # drop [CLS], the prefix characters and the trailing [SEP] so the rows
    # line up with the characters of `input`.
    rows = rows[1 + len(prefix):-1]

    hebrew_labels = [convert_binary_to_labels(row) for row in rows]
    arabic_labels = [convert_nikkud_to_harakat(names) for names in hebrew_labels]
    hebrew_letters = to_taatik(input)

    logger.debug("input=%r labels=%s", input, rows)
    logger.debug("labels_hebrew=%s labels_arabic=%s", hebrew_labels, arabic_labels)
    if len(rows) != len(input):
        logger.warning(
            "label/character count mismatch: %d labels for %d characters",
            len(rows),
            len(input),
        )

    hebrew = "".join(
        _move_apostrophe_last(_attach_marks(letter, names))
        for letter, names in zip(hebrew_letters, hebrew_labels)
    )
    arabic = "".join(
        _attach_marks(letter, names) for letter, names in zip(input, arabic_labels)
    )

    return _render_paragraph(hebrew) + _render_paragraph(arabic)


font_url = ""

with gr.Blocks(theme=gr.themes.Soft(), title="Ammiya Diacritizer") as demo:
    # NOTE(review): heading markup is a reconstruction; confirm the intended tags.
    gr.HTML(
        "<h2>Colloquial Arabic</h2>"
        "<p>Diacritizer and Hebrew Transliterator</p>" + font_url
    )
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Input", placeholder="Enter Arabic text", lines=1
            )
            gr.Examples(["بديش اروح معك"], text_input)
            # The caption is the first positional argument (`value`).
            btn = gr.Button("Analyze")
        with gr.Column():
            # NOTE(review): gr.Box was removed in Gradio 4; switch to gr.Group
            # when upgrading.
            with gr.Box():
                result_html = gr.HTML()
    btn.click(predict, inputs=[text_input], outputs=[result_html])
    text_input.submit(predict, inputs=[text_input], outputs=[result_html])

    demo.load()

demo.launch()