File size: 2,168 Bytes
e13e24e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc285d8
e13e24e
 
 
 
 
 
 
 
 
 
 
dc285d8
e13e24e
 
 
 
 
 
 
37c8501
 
e13e24e
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline

model_id = "xlm-roberta-base"
peft_model_id = "rasyosef/xlm-roberta-base-lora-amharic-news-classification"

# Amharic category names (English: Local News, Entertainment, Sports,
# Business, International News, Politics). The list position of each name is
# its class id in the classification head, so this order must match the label
# order used when the adapter was trained — NOTE(review): confirm against the
# adapter's training setup.
categories = ['ሀገር አቀፍ ዜና', 'መዝናኛ', 'ስፖርት', 'ቢዝነስ', 'ዓለም አቀፍ ዜና', 'ፖለቲካ']
id2label = {i: lbl for i, lbl in enumerate(categories)}
label2id = {lbl: i for i, lbl in enumerate(categories)}

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Base model with a classification head sized to the category list; the label
# mappings make pipeline predictions carry the category names.
model = AutoModelForSequenceClassification.from_pretrained(
    model_id,
    num_labels=len(categories),  # 6
    id2label=id2label,
    label2id=label2id,
)

# Attach the fine-tuned LoRA adapter from the Hub (transformers needs `peft`
# installed for this). NOTE(review): assumes the adapter checkpoint also stores
# the trained classification head (e.g. via `modules_to_save`); otherwise the
# head keeps its random initialisation — confirm against the adapter config.
model.load_adapter(peft_model_id)
# Put the whole model, including freshly injected adapter layers, into
# inference mode so any adapter dropout is disabled at prediction time.
model.eval()

classifier = pipeline("text-classification", model=model, tokenizer=tokenizer)

def predict(text):
    """Classify one Amharic news text and return the predicted category name.

    Args:
        text: Raw article text from the input textbox.

    Returns:
        The predicted label string produced by the classifier pipeline, or an
        empty string when ``text`` is None, empty, or whitespace-only.
    """
    # Nothing to classify: skip the model call entirely.
    if text is None or not text.strip():
        return ""
    # truncation=True clips the tokenized input to the tokenizer's maximum
    # length, so long articles don't exceed what the model can accept.
    prediction = classifier([text], truncation=True)[0]
    # The output component is a Textbox labelled "Predicted class", so return
    # just the label instead of the raw {"label": ..., "score": ...} dict.
    return prediction["label"]


with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Amharic News Article Classification
        This XLM-RoBERTa model (`xlm-roberta-base`) was fine-tuned with Low-Rank Adaptation (LoRA) to classify Amharic news articles into one of the following 6 categories:
        - ሀገር አቀፍ ዜና (Local News)
        - መዝናኛ (Entertainment)
        - ስፖርት (Sports)
        - ቢዝነስ (Business)
        - ዓለም አቀፍ ዜና (International News)
        - ፖለቲካ (Politics)
        """
    )

    with gr.Row():
        with gr.Column():
            # Named text_input (not `input`) to avoid shadowing the builtin.
            text_input = gr.Textbox(label="Amharic text", placeholder="Enter text here", lines=3)
            classify_btn = gr.Button(value="Classify")
        with gr.Column():
            output = gr.Textbox(label="Predicted class")

    classify_btn.click(predict, inputs=text_input, outputs=output)
    # No fn/outputs are given, so choosing an example only fills the input
    # textbox; the user still clicks "Classify" to run the model.
    examples = gr.Examples(
        examples=[
            "ኢትዮጵያ ፕሪምየር ሊግ 6ኛ ሳምንት የእሁድ ጨዋታዎች ቅድመ ዳሰሳ",
            "በአፄ ቴዎድሮስ የንግስና ቦታ ደረስጌ ማሪያም ተጀምሮ የቆመው የሙዚየሙ ግንባታ ተጠናቀቆ ስራ እንዲጀምር ነዋሪዎች ጠይቀዋል።",
        ],
        inputs=[text_input],
    )

# Launch only when run as a script, so importing this module (or Gradio's
# reload mode, which looks up `demo` itself) does not start a server.
if __name__ == "__main__":
    demo.launch()