data-mining / app.py
ND0210's picture
Update app.py
89e8599 verified
raw
history blame contribute delete
No virus
1.71 kB
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch.nn.functional as F
genres = ['Kinh tế', 'Giáo dục', 'Xe', 'Sức khoẻ', 'Công nghệ - Game']
tokenizer = AutoTokenizer.from_pretrained("mob2711/phoBERT_finetune_news_classification")
model = AutoModelForSequenceClassification.from_pretrained("mob2711/phoBERT_finetune_news_classification")
def tokenize(text):
encoded_text = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
return encoded_text
def predict_proba(text_data):
encoded_data = tokenize(text_data)
with torch.no_grad():
outputs = model(**encoded_data)
logits = outputs.logits
probabilities = F.softmax(logits, dim=-1)[0]
label_probs = {genres[id]: prob for id, prob in enumerate(probabilities)}
return label_probs
# Interface
input_text = gr.Textbox(label="Enter the title")
output_text = gr.Label(label="Predicted Probabilities")
demo = gr.Interface(
fn=predict_proba,
inputs=input_text,
outputs=output_text,
title="Newspaper Title Classifier",
examples=["Chủ tịch HĐQT Trường quốc tế AISVN đề xuất hỗ trợ 125 tỉ đồng",
"Chuyên gia tài chính Nguyễn Trí Hiếu bị 'hack' gần 500 triệu đồng, ngân hàng im lặng suốt 3 tháng?",
"Kon Tum chi hỗ trợ 28.000 liều vắc xin tiêm phòng bệnh dại cho chó, mèo",
"Microsoft hợp tác OpenAI phát triển siêu máy tính AI giá hơn 100 tỉ USD",
"Triệu hồi 170.000 xe điện Hyundai, Kia bị lỗi mất điện",
]
)
demo.launch()