File size: 2,979 Bytes
693378c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# Standard library imports first, third-party after (PEP 8); the duplicate
# mid-file `import re` / `import unicodedata` lines have been consolidated here.
import re
import unicodedata

import gradio as gr
from transformers import pipeline

# HTML shell for the highlighted output; dir="rtl" renders Arabic right-to-left.
HTML_WRAPPER = """<div dir="rtl" style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem; margin-bottom: 2.5rem">{}</div>"""

# Replace this with above latest checkpoint
model_checkpoint = "ArefSadeghian/arabert-finetuned-caner"
# aggregation_strategy="simple" merges word-piece tokens into whole-word
# entities with character-level start/end offsets.
token_classifier = pipeline(
    "token-classification", model=model_checkpoint, aggregation_strategy="simple"
)

# Arabic diacritic (tashkeel) code points to strip before inference.
# Values are unused placeholders; only the keys matter (see remove_diacritics).
diacritics = {
    '\u064B': None,   # FATHATAN
    '\u064C': None,   # DAMMATAN
    '\u064D': None,   # KASRATAN
    '\u064E': None,   # FATHA
    '\u064F': None,   # DAMMA
    '\u0650': None,   # KASRA
    '\u0651': None,   # SHADDA
    '\u0652': None,   # SUKUN
}

def remove_diacritics(text):
    """Strip Arabic diacritical marks (tashkeel) from *text*.

    The input is NFKD-decomposed first so precomposed characters expose
    their combining marks, then the eight tashkeel code points are removed
    in a single C-level pass via ``str.translate``.
    """
    tashkeel = '\u064B\u064C\u064D\u064E\u064F\u0650\u0651\u0652'
    strip_table = {ord(mark): None for mark in tashkeel}
    decomposed = unicodedata.normalize('NFKD', text)
    return decomposed.translate(strip_table)

def remove_punctuation(text):
    """Return *text* with every character that is neither word-like
    (``\\w``) nor whitespace (``\\s``) removed."""
    return ''.join(re.findall(r'[\w\s]', text))

def preprocess_arabic_text(text):
    """Normalise raw Arabic input before it reaches the NER pipeline.

    Steps, in order: strip diacritics, strip punctuation, collapse any
    run of whitespace to a single space, then lowercase.  Lowercasing is
    a no-op for Arabic script but normalises any embedded Latin text.
    """
    cleaned = remove_punctuation(remove_diacritics(text))
    collapsed = re.sub(r'\s+', ' ', cleaned)
    return collapsed.lower()


# Define a function to highlight different labels in the text
def highlight_text(text, entities):
    entity_colors = {"Allah": "#ffe5cc", "Book": "#b3daff", "Clan": "#faedcb", "Crime": "#ffb3d9",
                     "Date": "#cce6ff", "Day": "#cce6ff", "Hell": "#d9d9d9", "Loc": "#d9b3ff",
                     "Meas": "#e6ccff", "Mon": "#ffd6cc", "Month": "#ffd6cc", "NatOb": "#ffe0b3",
                     "Number": "#ffe0cc", "Org": "#c1ffb3", "Para": "#f2f2f2", "Pers": "#b3ffb3",
                     "Prophet": "#e6ccff", "Rlig": "#ffff80", "Sect": "#b3d9ff", "Time": "#ffb3ba"}
    highlighted = []
    i = 0
    for entity in entities:
        highlighted.extend(text[i:int(entity['start'])].split())
        entity_group = entity['entity_group']
        score = entity['score']
        marked_text = f'<mark class="{entity_group}" style="background-color: {entity_colors[entity_group]}">{entity["word"]}<sub>{entity_group}</sub><sup>{score:.2f}</sup></mark>'
        highlighted.append(marked_text)
        i = int(entity['end']) + 1
    highlighted.extend(text[i:].split())
    return HTML_WRAPPER.format(' '.join(highlighted))


# Create the Gradio interface
# Create the Gradio interface
def predict_ner(text):
    """Preprocess *text*, run NER, and return highlighted HTML followed by
    the raw entity list.  On any failure the error text is returned so the
    UI shows it instead of crashing."""
    try:
        cleaned = preprocess_arabic_text(text)
        found = token_classifier(cleaned)
        markup = highlight_text(cleaned, found)
    except Exception as err:
        # Log server-side, surface the message in the UI.
        print(err)
        return str(err)
    return markup + '\n\n' + str(found)


# Build the web UI.  gr.inputs.Textbox / gr.outputs.HTML were deprecated in
# Gradio 3 and removed in Gradio 4; the top-level components are the
# supported, equivalent replacements.
iface = gr.Interface(
    fn=predict_ner,
    inputs=gr.Textbox(label="Enter Hadith in Arabic"),
    outputs=gr.HTML(label="Predicted Labels"),
    title="Hadith Analysis"
)

# Launch the interface
iface.launch()