rasyosef commited on
Commit
e13e24e
โ€ข
1 Parent(s): 019b401

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -0
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
3
+
4
+ model_id = "xlm-roberta-base"
5
+ peft_model_id = "rasyosef/xlm-roberta-base-lora-amharic-news-classification"
6
+
7
+ categories = ['แˆ€แŒˆแˆญ แŠ แ‰€แ แ‹œแŠ“', 'แˆ˜แ‹แŠ“แŠ›', 'แˆตแ–แˆญแ‰ต', 'แ‰ขแ‹แŠแˆต', 'แ‹“แˆˆแˆ แŠ แ‰€แ แ‹œแŠ“', 'แ–แˆˆแ‰ฒแŠซ']
8
+ id2label = {i: lbl for i, lbl in enumerate(categories)}
9
+ label2id = {lbl: i for i, lbl in enumerate(categories)}
10
+
11
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
12
+
13
+ model = AutoModelForSequenceClassification.from_pretrained(
14
+ model_id,
15
+ num_labels=len(categories), # 6
16
+ id2label=id2label,
17
+ label2id=label2id
18
+ )
19
+
20
+ model.load_adapter(peft_model_id)
21
+
22
+ classifier = pipeline("text-classification", model=model, tokenizer=tokenizer)
23
+
24
+ def predict(text):
25
+ return classifier([text])[0]
26
+
27
+
28
+ with gr.Blocks() as demo:
29
+ gr.Markdown(
30
+ """
31
+ # Amharic News Article Classification
32
+ This is A finetuned RoBERTa model (xlm-roberta-base) that classifies amharic news articles into one of 6 categories.
33
+ - แˆ€แŒˆแˆญ แŠ แ‰€แ แ‹œแŠ“ (Local News)
34
+ - แˆ˜แ‹แŠ“แŠ› (Entertainment)
35
+ - แˆตแ–แˆญแ‰ต (Sports)
36
+ - แ‰ขแ‹แŠแˆต (Business)
37
+ - แ‹“แˆˆแˆ แŠ แ‰€แ แ‹œแŠ“ (International News)
38
+ - แ–แˆˆแ‰ฒแŠซ (Politics)
39
+ """
40
+ )
41
+
42
+ with gr.Row():
43
+ with gr.Column():
44
+ input = gr.Textbox(label="Amharic text")
45
+ classify_btn = gr.Button(value="Classify")
46
+ with gr.Column():
47
+ output = gr.Textbox(label="Predicted class")
48
+
49
+ classify_btn.click(predict, inputs=input, outputs=output)
50
+ examples = gr.Examples(
51
+ examples=[
52
+ """แ‰ แŠ แ„ แ‰ดแ‹Žแ‹ตแˆฎแˆต แ‹จแŠ•แŒแˆตแŠ“ แ‰ฆแ‰ณ แ‹ฐแˆจแˆตแŒŒ แˆ›แˆชแ‹ซแˆ แ‰ฐแŒ€แˆแˆฎ แ‹จแ‰†แˆ˜แ‹ แ‹จแˆ™แ‹šแ‹จแˆ™ แŒแŠ•แ‰ฃแ‰ณ แ‰ฐแŒ แŠ“แ‰€แ‰† แˆตแˆซ แŠฅแŠ•แ‹ฒแŒ€แˆแˆญ แŠแ‹‹แˆชแ‹Žแ‰ฝ แŒ แ‹ญแ‰€แ‹‹แˆแข แ‹˜แˆ˜แŠ แˆ˜แˆณแแŠ•แ‰ต แˆ˜แ‰‹แŒซ แ‹ซแŒˆแŠ˜แ‰ฃแ‰ต แ‹จแŠขแ‰ตแ‹ฎแŒตแ‹ซ แŠ แŠ•แ‹ตแŠแ‰ต แ‹จแ‰ณแ‹ˆแŒ€แ‰ฃแ‰ต แ‹ณแŒแˆ›แ‹Š แŠ แ„ แ‰ดแ‹Žแ‹ตแˆฎแˆต แŠจแˆ˜แŠ•แŒˆแˆณแ‰ธแ‹ แ‰ แŠแ‰ต แ‹ฐแŒƒแ‰ฝ แ‹แ‰คแŠ• แ‰งแˆ‚แ‰ต แŠจแˆšแ‰ฃแˆ แ‰ฆแ‰ณ แ‹ตแˆ แŠ แ‹ตแˆญแŒˆแ‹ แ‹ฐแŒƒแ‰ฝ แ‹แ‰ค แˆˆแŠ•แŒแˆตแŠ“ แ‰ฃแ‹˜แŒ‹แŒแ‰ต แ‹จแŠ•แŒแˆตแŠ“ แ‰ฆแ‰ณแŠ“ แŠฅแ‰ƒแ‹Žแ‰ฝ แŠ•แŒ‰แˆฐ แŠแŒˆแˆตแ‰ต แ‹˜แŠขแ‰ตแ‹ฎแŒตแ‹ซ แ‰ฐแ‰ฅแˆˆแ‹ แ‹จแŠแŒˆแˆฑแ‰ฃแ‰ต แ‰ฆแ‰ณ แŠ“แ‰ตแข"""
53
+ ],
54
+ inputs=[input],
55
+ )
56
+
57
+ demo.launch()