Spaces:
Sleeping
Sleeping
import gradio as gr | |
from transformers import pipeline | |
import re | |
# Right-to-left <div> that frames the rendered result; the "{}" placeholder is
# filled in with str.format() by the code that builds the highlighted markup.
HTML_WRAPPER = (
    '<div dir="rtl" style="overflow-x: auto; border: 1px solid #e6e9ef; '
    'border-radius: 0.25rem; padding: 1rem; margin-bottom: 2.5rem">'
    '{}</div>'
)
# Model identifier handed to the pipeline below; change this string to switch
# to a newer checkpoint.
model_checkpoint = "ArefSadeghian/arabert-finetuned-caner"

# Token-classification pipeline created once at import time and reused for
# every request.
token_classifier = pipeline(
    "token-classification",
    model=model_checkpoint,
    aggregation_strategy="simple",
)
import re | |
import unicodedata | |
# Arabic diacritical marks (U+064B through U+0652) that remove_diacritics
# deletes. Every value is None so the mapping doubles as a deletion table.
diacritics = {
    '\u064B': None,  # FATHATAN
    '\u064C': None,  # DAMMATAN
    '\u064D': None,  # KASRATAN
    '\u064E': None,  # FATHA
    '\u064F': None,  # DAMMA
    '\u0650': None,  # KASRA
    '\u0651': None,  # SHADDA
    '\u0652': None,  # SUKUN
}

# str.translate() needs ordinal keys; build that table once at import time
# instead of rebuilding it on every call.
_DIACRITICS_TABLE = str.maketrans(diacritics)


def remove_diacritics(text):
    """Return ``text`` in NFKD form with the marks listed in ``diacritics`` removed.

    NFKD normalisation is applied first, so compatibility characters (for
    example presentation-form ligatures) are expanded before the marks are
    deleted. The result is left in decomposed form.

    Args:
        text: The string to clean.

    Returns:
        The normalised string without the listed diacritical marks.
    """
    normalized_text = unicodedata.normalize('NFKD', text)
    return normalized_text.translate(_DIACRITICS_TABLE)
def remove_punctuation(text):
    """Delete every character that is neither a word character nor whitespace.

    "Word character" follows the regex ``\\w`` class, so letters, digits and
    the underscore are kept.
    """
    not_word_or_space = re.compile(r'[^\w\s]')
    return not_word_or_space.sub('', text)
def preprocess_arabic_text(text):
    """Clean ``text`` before it is handed to the token classifier.

    Steps, in order: strip diacritics, strip punctuation, collapse each run
    of whitespace to a single space, then lower-case the result.
    """
    without_marks = remove_punctuation(remove_diacritics(text))
    single_spaced = re.sub(r'\s+', ' ', without_marks)
    return single_spaced.lower()
# Render the classifier's entities as colour-coded <mark> spans inside the text.
def highlight_text(text, entities):
    """Return ``text`` as HTML with every entity wrapped in a coloured ``<mark>``.

    Args:
        text: The string the entities were extracted from.
        entities: Iterable of dicts providing ``start``, ``end``,
            ``entity_group``, ``score`` and ``word``. ``start``/``end`` are
            used as character offsets into ``text`` (``end`` exclusive).

    Returns:
        ``HTML_WRAPPER`` formatted with the space-joined plain words and
        ``<mark>`` elements. Each mark shows the entity word, its label as a
        subscript and its score (two decimals) as a superscript.
    """
    # Local import keeps this block self-contained.
    from html import escape

    entity_colors = {"Allah": "#ffe5cc", "Book": "#b3daff", "Clan": "#faedcb", "Crime": "#ffb3d9",
                     "Date": "#cce6ff", "Day": "#cce6ff", "Hell": "#d9d9d9", "Loc": "#d9b3ff",
                     "Meas": "#e6ccff", "Mon": "#ffd6cc", "Month": "#ffd6cc", "NatOb": "#ffe0b3",
                     "Number": "#ffe0cc", "Org": "#c1ffb3", "Para": "#f2f2f2", "Pers": "#b3ffb3",
                     "Prophet": "#e6ccff", "Rlig": "#ffff80", "Sect": "#b3d9ff", "Time": "#ffb3ba"}
    # Neutral fallback so a label missing from the palette does not raise KeyError.
    default_color = "#eeeeee"

    highlighted = []
    cursor = 0
    for entity in entities:
        start = int(entity['start'])
        end = int(entity['end'])
        # Plain text between the previous entity and this one, escaped for HTML.
        highlighted.extend(escape(text[cursor:start]).split())
        entity_group = entity['entity_group']
        score = entity['score']
        color = entity_colors.get(entity_group, default_color)
        label = escape(str(entity_group))
        word = escape(str(entity['word']))
        marked_text = f'<mark class="{label}" style="background-color: {color}">{word}<sub>{label}</sub><sup>{score:.2f}</sup></mark>'
        highlighted.append(marked_text)
        # Resume exactly at ``end``: split() already discards surrounding
        # whitespace, and skipping one extra character would drop a letter
        # when the entity is not followed by a space.
        cursor = end
    highlighted.extend(escape(text[cursor:]).split())
    return HTML_WRAPPER.format(' '.join(highlighted))
# Callback wired into the Gradio interface.
def predict_ner(text):
    """Run NER on ``text`` and return the highlighted HTML plus the raw entities.

    The input is preprocessed, classified and highlighted; the result is the
    highlighted markup, a blank line, and the string form of the entity list.
    Any exception is printed and its message is returned instead.
    """
    try:
        cleaned = preprocess_arabic_text(text)
        found = token_classifier(cleaned)
        return f"{highlight_text(cleaned, found)}\n\n{found}"
    except Exception as err:
        print(err)
        return str(err)
# Build the Gradio UI: one text box in, one HTML panel out, with predict_ner as
# the callback. The top-level gr.Textbox / gr.HTML components replace the
# deprecated gr.inputs / gr.outputs namespaces, which newer Gradio releases no
# longer provide.
iface = gr.Interface(
    fn=predict_ner,
    inputs=gr.Textbox(label="Enter Hadith in Arabic"),
    outputs=gr.HTML(label="Predicted Labels"),
    title="Hadith Analysis",
)

# Launch the interface
iface.launch()