ranranlove commited on
Commit
5b67a04
·
verified ·
1 Parent(s): ca188ff

Create app.py

Browse files

Add initial app.py content for sentiment classifier

Files changed (1) hide show
  1. app.py +94 -0
app.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
3
+ import torch
4
+ import os
5
+
6
+ # --- 1. 全局模型加载 (---
7
+ MODEL_NAME = "uer/roberta-base-finetuned-jd-binary-chinese"
8
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+
10
+ # 加载分词器和模型
11
+ try:
12
+ print(f"正在加载模型: {MODEL_NAME} 到设备: {DEVICE}")
13
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
14
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
15
+ model.to(DEVICE)
16
+ model.eval() # 将模型设置为评估模式
17
+ print("模型加载成功。")
18
+ except Exception as e:
19
+ print(f"模型加载失败,请检查安装和网络连接: {e}")
20
+ # 在实际应用中,您可能需要优雅地退出或提供一个备用方案
21
+
22
+ # 定义类别映射
23
+ LABEL_MAP = {0: "消极 (Negative)", 1: "积极 (Positive)"}
24
+
25
+
26
+ # --- 2. Gradio 核心预测函数 ---
27
+ def classify_text(text):
28
+ """
29
+ 接收用户输入的文本,返回分类结果和置信度。
30
+ """
31
+ if not text:
32
+ return "请输入需要分类的文本。", None, None, None
33
+
34
+ # 分词和预处理
35
+ inputs = tokenizer(text,
36
+ padding=True,
37
+ truncation=True,
38
+ max_length=128,
39
+ return_tensors='pt').to(DEVICE)
40
+
41
+ # 模型预测
42
+ with torch.no_grad():
43
+ outputs = model(**inputs)
44
+
45
+ # 处理输出结果
46
+ # logits -> softmax 转换为概率
47
+ predictions = torch.softmax(outputs.logits, dim=1)[0] # 取出第一个(也是唯一的)输入结果
48
+
49
+ # 预测的类别
50
+ predicted_class_id = torch.argmax(predictions).item()
51
+ predicted_label = LABEL_MAP[predicted_class_id]
52
+
53
+ # 预测分数
54
+ score_negative = predictions[0].item()
55
+ score_positive = predictions[1].item()
56
+
57
+ # 格式化输出文本
58
+ result_text = f"预测类别:**{predicted_label}**"
59
+
60
+ # 格式化置信度字典,用于 Gradio 的 Label 组件
61
+ confidence_dict = {
62
+ "消极 (Negative)": score_negative,
63
+ "积极 (Positive)": score_positive
64
+ }
65
+
66
+ return result_text, confidence_dict
67
+
68
+
69
+ # --- 3. Gradio 接口配置和启动 ---
70
+
71
+ # 定义演示界面的标题和描述
72
+ title = "Hugging Face 中文情感分析演示"
73
+ description = "使用 uer/roberta-base-finetuned-jd-binary-chinese 模型对输入的中文文本进行积极/消极情感二分类。"
74
+ examples = [
75
+ ["这家餐厅的菜味道太棒了,服务员也很热情。"],
76
+ ["我等了两个小时,包裹还没送到,非常生气。"],
77
+ ["我对这款产品不满意,但也不算太差。"]
78
+ ]
79
+
80
+ # 创建 Gradio 接口
81
+ iface = gr.Interface(
82
+ fn=classify_text, # 绑定上面定义的预测函数
83
+ inputs=gr.Textbox(lines=5, label="输入您的中文文本"), # 输入组件
84
+ outputs=[
85
+ gr.Markdown(label="分类结果"), # 显示最终的预测结果文本
86
+ gr.Label(label="置信度分数", num_top_classes=2) # 显示两个类别的概率
87
+ ],
88
+ title=title,
89
+ description=description,
90
+ examples=examples
91
+ )
92
+
93
+ # 启动 Web 服务
94
+ iface.launch()