F22-NoorHedayat / app.py
Montazer's picture
Duplicate from ArefSadeghian/arabert-finetuned-caner
693378c
import html
import re

import gradio as gr
from transformers import pipeline
# Right-to-left container for the rendered result; the "{}" placeholder is
# filled via str.format with the highlighted markup.
HTML_WRAPPER = (
    '<div dir="rtl" style="overflow-x: auto; border: 1px solid #e6e9ef; '
    'border-radius: 0.25rem; padding: 1rem; margin-bottom: 2.5rem">'
    '{}</div>'
)
# Hub id of the fine-tuned checkpoint to serve; point this at a newer
# checkpoint to change the model the app uses.
model_checkpoint = "ArefSadeghian/arabert-finetuned-caner"

# Built once at import time. The rest of this file reads the keys
# entity_group / score / word / start / end from each returned entity.
token_classifier = pipeline(
    task="token-classification",
    model=model_checkpoint,
    aggregation_strategy="simple",
)
import re
import unicodedata
# Arabic marks stripped by remove_diacritics. Only the keys are used; the
# None values are placeholders.
diacritics = {
    '\u064B': None,  # FATHATAN
    '\u064C': None,  # DAMMATAN
    '\u064D': None,  # KASRATAN
    '\u064E': None,  # FATHA
    '\u064F': None,  # DAMMA
    '\u0650': None,  # KASRA
    '\u0651': None,  # SHADDA
    '\u0652': None,  # SUKUN
}

# str.translate needs code points as keys. Build the deletion table once at
# import time instead of rebuilding it on every call.
_DIACRITICS_TABLE = dict.fromkeys(map(ord, diacritics))


def remove_diacritics(text):
    """Return *text* NFKD-normalized with the marks in ``diacritics`` removed.

    NFKD is a compatibility decomposition, so characters that have a
    decomposition (ligatures, presentation forms, ...) come back expanded,
    and any combining mark that is not listed in ``diacritics`` is kept.

    Args:
        text: The string to clean.

    Returns:
        The normalized string without the listed diacritics.
    """
    normalized_text = unicodedata.normalize('NFKD', text)
    return normalized_text.translate(_DIACRITICS_TABLE)
def remove_punctuation(text):
    """Drop every character that is neither a word character nor whitespace.

    Word characters (``\\w``) include the underscore, so ``_`` is kept.
    """
    not_word_or_space = re.compile(r'[^\w\s]')
    return not_word_or_space.sub('', text)
def preprocess_arabic_text(text):
    """Clean *text* through a fixed sequence of normalization steps.

    In order: strip diacritics, strip punctuation, collapse every run of
    whitespace to a single space, then lowercase. Leading and trailing
    whitespace is collapsed but not trimmed.
    """
    without_marks = remove_punctuation(remove_diacritics(text))
    single_spaced = re.sub(r'\s+', ' ', without_marks)
    # lower() only changes cased characters, e.g. Latin text mixed into the input.
    return single_spaced.lower()
def highlight_text(text, entities):
    """Render *text* as HTML with every entity wrapped in a coloured <mark>.

    Args:
        text: The exact string the entity offsets refer to.
        entities: Iterable of dicts with the keys ``entity_group``, ``score``,
            ``word``, ``start`` and ``end``, ordered by position. ``start``
            and ``end`` are used as character offsets into *text*, with
            ``end`` exclusive.

    Returns:
        ``HTML_WRAPPER`` filled with the space-joined plain words and marks.
        Plain words, entity words and labels are HTML-escaped.
    """
    entity_colors = {"Allah": "#ffe5cc", "Book": "#b3daff", "Clan": "#faedcb", "Crime": "#ffb3d9",
                     "Date": "#cce6ff", "Day": "#cce6ff", "Hell": "#d9d9d9", "Loc": "#d9b3ff",
                     "Meas": "#e6ccff", "Mon": "#ffd6cc", "Month": "#ffd6cc", "NatOb": "#ffe0b3",
                     "Number": "#ffe0cc", "Org": "#c1ffb3", "Para": "#f2f2f2", "Pers": "#b3ffb3",
                     "Prophet": "#e6ccff", "Rlig": "#ffff80", "Sect": "#b3d9ff", "Time": "#ffb3ba"}
    # A label missing from the palette is still highlighted rather than
    # aborting the whole render with a KeyError.
    default_color = "#eeeeee"

    highlighted = []
    cursor = 0
    for entity in entities:
        start = int(entity['start'])
        end = int(entity['end'])
        # Plain text between the previous entity and this one.
        highlighted.extend(html.escape(word) for word in text[cursor:start].split())

        entity_group = entity['entity_group']
        color = entity_colors.get(entity_group, default_color)
        safe_group = html.escape(str(entity_group))
        safe_word = html.escape(str(entity['word']))
        score = entity['score']
        highlighted.append(
            f'<mark class="{safe_group}" style="background-color: {color}">'
            f'{safe_word}<sub>{safe_group}</sub><sup>{score:.2f}</sup></mark>'
        )
        # Resume exactly at the end offset: skipping one extra character
        # would drop it whenever the entity is followed by non-whitespace.
        # A following space is discarded by split() anyway.
        cursor = end
    highlighted.extend(html.escape(word) for word in text[cursor:].split())
    return HTML_WRAPPER.format(' '.join(highlighted))
# Callback wired into the Gradio interface
def predict_ner(text):
    """Run NER on *text* and return the result as a single string.

    The input is preprocessed, classified and highlighted; the return value
    is the highlighted HTML, a blank line, then ``str()`` of the raw entity
    list. Any exception is printed and its message is returned instead.
    """
    try:
        cleaned = preprocess_arabic_text(text)
        found = token_classifier(cleaned)
        markup = highlight_text(cleaned, found)
        return '\n\n'.join((markup, str(found)))
    except Exception as e:
        print(e)
        return str(e)
# Build the UI: one text box in, rendered HTML out. Components are taken
# from the top-level gradio namespace because gr.inputs / gr.outputs are
# deprecated in Gradio 3 and no longer exist in Gradio 4.
iface = gr.Interface(
    fn=predict_ner,
    inputs=gr.Textbox(label="Enter Hadith in Arabic"),
    outputs=gr.HTML(label="Predicted Labels"),
    title="Hadith Analysis",
)

# Launch the interface
iface.launch()