Create app.py
Browse files
Add initial app.py content for sentiment classifier
app.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
| 3 |
+
import torch
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
# --- 1. Global model loading ---
MODEL_NAME = "uer/roberta-base-finetuned-jd-binary-chinese"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the tokenizer and model once at module import so that every request
# reuses the same objects.
try:
    print(f"正在加载模型: {MODEL_NAME} 到设备: {DEVICE}")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
    model.to(DEVICE)
    model.eval()  # switch the model to evaluation mode
    print("模型加载成功。")
except Exception as e:
    print(f"模型加载失败,请检查安装和网络连接: {e}")
    # Fail fast: without a re-raise, `tokenizer` and `model` stay undefined
    # and every later prediction call dies with a NameError that hides the
    # real cause reported above.
    raise

# Map the model's class ids to the labels shown in the UI.
LABEL_MAP = {0: "消极 (Negative)", 1: "积极 (Positive)"}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# --- 2. Gradio 核心预测函数 ---
|
| 27 |
+
def classify_text(text):
    """Classify the sentiment of ``text`` and build the values shown in the UI.

    Args:
        text: Text entered by the user. ``None``, empty and whitespace-only
            values are treated as "no input".

    Returns:
        A 2-tuple ``(result_text, confidence_dict)``:

        * ``result_text`` -- Markdown string naming the predicted label, or a
          prompt asking for input when ``text`` is empty.
        * ``confidence_dict`` -- mapping of each ``LABEL_MAP`` label to its
          softmax probability, or ``None`` when ``text`` is empty.
    """
    # The interface declares two output components, so every return path must
    # yield exactly two values.
    if not text or not text.strip():
        return "请输入需要分类的文本。", None

    # Tokenize and move the resulting tensors to the model's device.
    inputs = tokenizer(
        text,
        padding=True,
        truncation=True,
        max_length=128,
        return_tensors="pt",
    ).to(DEVICE)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # logits -> softmax probabilities; index 0 selects the single input row.
    probabilities = torch.softmax(outputs.logits, dim=1)[0]

    predicted_class_id = torch.argmax(probabilities).item()
    predicted_label = LABEL_MAP[predicted_class_id]
    result_text = f"预测类别:**{predicted_label}**"

    # Build the label -> probability mapping from LABEL_MAP so the label
    # strings are defined in one place only.
    confidence_dict = {
        label: probabilities[class_id].item()
        for class_id, label in LABEL_MAP.items()
    }

    return result_text, confidence_dict
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# --- 3. Gradio interface configuration and launch ---

# Title, description and clickable example inputs shown on the demo page.
title = "Hugging Face 中文情感分析演示"
description = "使用 uer/roberta-base-finetuned-jd-binary-chinese 模型对输入的中文文本进行积极/消极情感二分类。"
examples = [
    ["这家餐厅的菜味道太棒了,服务员也很热情。"],
    ["我等了两个小时,包裹还没送到,非常生气。"],
    ["我对这款产品不满意,但也不算太差。"],
]

# Build the Gradio interface: one textbox in, two components out. The order of
# `outputs` must match the order of the values returned by `classify_text`.
iface = gr.Interface(
    fn=classify_text,  # prediction function defined above
    inputs=gr.Textbox(lines=5, label="输入您的中文文本"),  # input component
    outputs=[
        gr.Markdown(label="分类结果"),  # predicted label as Markdown text
        gr.Label(label="置信度分数", num_top_classes=2),  # per-class probabilities
    ],
    title=title,
    description=description,
    examples=examples,
)

# Start the web server only when run as a script, so importing this module
# (e.g. to reuse `iface` or `classify_text`) does not launch a server.
if __name__ == "__main__":
    iface.launch()
|