File size: 2,391 Bytes
e13e24e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline

# Base checkpoint and the adapter repository that is loaded on top of it.
model_id = "xlm-roberta-base"
peft_model_id = "rasyosef/xlm-roberta-base-lora-amharic-news-classification"

# Class names in label-id order (list index == class id). English glosses, in
# order: Local News, Entertainment, Sports, Business, International News,
# Politics.
# NOTE: these literals were mojibake (UTF-8 bytes decoded with a single-byte
# code page, losing some bytes); they are restored to proper Ethiopic text so
# that the labels reported by the classifier are readable.
categories = [
    'ሀገር አቀፍ ዜና',
    'መዝናኛ',
    'ስፖርት',
    'ቢዝነስ',
    'ዓለም አቀፍ ዜና',
    'ፖለቲካ',
]
# Forward (id -> label) and reverse (label -> id) lookups for the model config.
id2label = dict(enumerate(categories))
label2id = {lbl: i for i, lbl in id2label.items()}

# Tokenizer of the base checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Base model with a sequence-classification head sized to the category list
# (6 classes) and carrying the id <-> label mappings in its config.
model = AutoModelForSequenceClassification.from_pretrained(
    model_id, num_labels=len(categories), id2label=id2label, label2id=label2id
)

# Load the adapter stored under `peft_model_id` into the base model.
model.load_adapter(peft_model_id)

# Text-classification pipeline that the UI callback invokes.
classifier = pipeline(task="text-classification", model=model, tokenizer=tokenizer)

def predict(text):
    """Classify a single piece of text with the module-level pipeline.

    Args:
        text: The text to classify.

    Returns:
        The pipeline's prediction for ``text`` (the first and only element
        of the batch result).
    """
    # News articles can exceed the encoder's maximum sequence length; without
    # truncation the pipeline raises on such inputs instead of predicting.
    return classifier([text], truncation=True)[0]


# Gradio UI: a text box plus button on the left, the prediction on the right,
# and one clickable example article below.
# NOTE: the Amharic text in this block was mojibake (UTF-8 bytes decoded with a
# single-byte code page); it is restored to proper Ethiopic text.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Amharic News Article Classification 
        This is A finetuned RoBERTa model (xlm-roberta-base) that classifies amharic news articles into one of 6 categories.
        - ሀገር አቀፍ ዜና (Local News)
        - መዝናኛ (Entertainment)
        - ስፖርት (Sports)
        - ቢዝነስ (Business)
        - ዓለም አቀፍ ዜና (International News)
        - ፖለቲካ (Politics)
        """
    )

    with gr.Row():
        with gr.Column():
            # Named `text_input` rather than `input` to avoid shadowing the builtin.
            text_input = gr.Textbox(label="Amharic text")
            classify_btn = gr.Button(value="Classify")
        with gr.Column():
            output = gr.Textbox(label="Predicted class")

    # Run `predict` on the text box content and show the result on the right.
    classify_btn.click(predict, inputs=text_input, outputs=output)
    examples = gr.Examples(
        examples=[
            """በአፄ ቴዎድሮስ የንግስና ቦታ ደረስጌ ማሪያም ተጀምሮ የቆመው የሙዚየሙ ግንባታ ተጠናቀቆ ስራ እንዲጀምር ነዋሪዎች ጠይቀዋል። ዘመነ መሳፍንት መቋጫ ያገኘባት የኢትዮጵያ አንድነት የታወጀባት ዳግማዊ አፄ ቴዎድሮስ ከመንገሳቸው በፊት ደጃች ውቤን ቧሂት ከሚባል ቦታ ድል አድርገው ደጃች ውቤ ለንግስና ባዘጋጁት የንግስና ቦታና እቃዎች ንጉሰ ነገስት ዘኢትዮጵያ ተብለው የነገሱባት ቦታ ናት።"""
        ],
        inputs=[text_input],
    )

demo.launch()