# Author: rasyosef
# Last commit: 37c8501 ("Update app.py")
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
# Base checkpoint and the LoRA adapter fine-tuned on top of it.
model_id = "xlm-roberta-base"
peft_model_id = "rasyosef/xlm-roberta-base-lora-amharic-news-classification"

# Class labels in model-output order: logit index i corresponds to
# categories[i] (see id2label below).
categories = [
    "ሀገር አቀፍ ዜና",  # Local News
    "መዝናኛ",  # Entertainment
    "ስፖርት",  # Sports
    "ቢዝነስ",  # Business
    "ዓለም አቀፍ ዜና",  # International News
    "ፖለቲካ",  # Politics
]
# Index <-> label lookups passed to the model config so pipeline output
# carries the human-readable Amharic label.
id2label = dict(enumerate(categories))
label2id = {label: idx for idx, label in id2label.items()}
# Label configuration for the freshly initialised classification head: one
# output per category, with the index/label maps defined above.
head_config = {
    "num_labels": len(categories),  # 6
    "id2label": id2label,
    "label2id": label2id,
}

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id, **head_config)

# Pull in the adapter weights published under peft_model_id (a LoRA adapter,
# per its name). transformers' load_adapter relies on the `peft` package.
model.load_adapter(peft_model_id)

# Text-classification pipeline wrapping the adapted model and its tokenizer.
classifier = pipeline(task="text-classification", model=model, tokenizer=tokenizer)
def predict(text):
    """Classify a single Amharic news text with the module-level pipeline.

    Args:
        text: Raw article text entered in the UI.

    Returns:
        The first element of the pipeline's result list for ``[text]``,
        i.e. the prediction for this one input (presumably a
        ``{"label": ..., "score": ...}`` dict).
    """
    # xlm-roberta-base accepts at most 512 tokens; truncating lets long
    # articles be classified from their leading portion instead of failing.
    return classifier([text], truncation=True)[0]
# Intro text rendered above the demo. Defined at column 0 so the Markdown
# lines are not indented (indented lines would render as a code block).
DESCRIPTION = """
# Amharic News Article Classification
This XLM-RoBERTa model (xlm-roberta-base) was fine-tuned with Low-Rank Adaptation (LoRA) to classify Amharic news articles into one of the following 6 categories:
- ሀገር አቀፍ ዜና (Local News)
- መዝናኛ (Entertainment)
- ስፖርት (Sports)
- ቢዝነስ (Business)
- ዓለም አቀፍ ዜና (International News)
- ፖለቲካ (Politics)
"""

# Sample Amharic inputs users can click to fill the textbox.
EXAMPLE_TEXTS = [
    "ኢትዮጵያ ፕሪምየር ሊግ 6ኛ ሳምንት የእሁድ ጨዋታዎች ቅድመ ዳሰሳ",
    "በአፄ ቴዎድሮስ የንግስና ቦታ ደረስጌ ማሪያም ተጀምሮ የቆመው የሙዚየሙ ግንባታ ተጠናቀቆ ስራ እንዲጀምር ነዋሪዎች ጠይቀዋል።",
]

with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            # Named text_input rather than `input` to avoid shadowing the builtin.
            text_input = gr.Textbox(label="Amharic text", placeholder="Enter text here", lines=3)
            classify_btn = gr.Button(value="Classify")
        with gr.Column():
            # Shows whatever predict() returns, as text.
            label_output = gr.Textbox(label="Predicted class")
    # Event listeners must be registered inside the Blocks context.
    classify_btn.click(predict, inputs=text_input, outputs=label_output)
    gr.Examples(examples=EXAMPLE_TEXTS, inputs=[text_input])

if __name__ == "__main__":
    demo.launch()