# text_classify / app.py
# azhongai666666's picture
# Update app.py
# 3417af0 verified
from transformers import AutoTokenizer
import torch
from torch import nn
import torch.nn.functional as F
import json
import gradio as gr
# ---------------------------------------------------------------------------
# Inference configuration
# ---------------------------------------------------------------------------

# Sample input text (the same string is shown as the example in the UI).
eg_text = '科研报告文档.pdf'

# Identifier of the pretrained tokenizer to load.
name = 'bert-base-chinese'

# Token-id length every input is padded or truncated to before inference.
max_len = 50

# Class labels, looked up by predicted class index.  These are placeholders;
# the labels previously used here were Chinese for: references, job postings,
# submission materials, computer courses.
name_list = ['aaaa', 'bbbb', 'cccc', 'dddd']
# Width of the classifier output layer.
n_class = 4

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# The tokenizer supplies both the token ids and the embedding-table size.
tokenizer_cn = AutoTokenizer.from_pretrained(name)
voc_size = len(tokenizer_cn.vocab)
class TransformerModel(nn.Module):
    """Token-id classifier: embedding -> Transformer encoder -> flatten -> linear.

    Every constructor argument is optional.  ``TransformerModel()`` builds the
    same architecture as before (embedding width 100, 2 attention heads,
    3 encoder layers, classifier input of ``max_len * 100`` features), and the
    sub-module names are unchanged, so existing checkpoints keep loading.

    Args:
        vocab_size: Number of rows in the embedding table.  Defaults to the
            module-level ``voc_size``.
        emb_dim: Embedding width, also used as the encoder ``d_model``.
            Must be divisible by ``nhead``.
        seq_len: Number of token ids per example.  Defaults to the
            module-level ``max_len``.
        num_classes: Number of output logits.  Defaults to the module-level
            ``n_class``.
        nhead: Attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
    """

    def __init__(self, *, vocab_size=None, emb_dim=100, seq_len=None,
                 num_classes=None, nhead=2, num_layers=3):
        super().__init__()
        # Module-level settings are only consulted for arguments left unset,
        # so the class can also be built without them.
        if vocab_size is None:
            vocab_size = voc_size
        if seq_len is None:
            seq_len = max_len
        if num_classes is None:
            num_classes = n_class

        self.embd = nn.Embedding(vocab_size, emb_dim)
        # NOTE(review): the encoder layer keeps PyTorch's default
        # ``batch_first=False`` while ``forward`` passes (batch, seq, emb)
        # tensors, so the first two axes are read as (seq, batch).  Left as is
        # because switching it would change the outputs of already-trained
        # weights -- confirm against the training code.
        encoder_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=nhead)
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=num_layers
        )
        # The flattened encoder output has seq_len * emb_dim features per row
        # (50 * 100 = 5000 with the defaults).
        self.out = nn.Linear(seq_len * emb_dim, num_classes)
        self.flatten = nn.Flatten()

    def forward(self, x):
        """Return class logits for a 2-D tensor of token ids.

        Args:
            x: Integer tensor of shape ``(rows, seq_len)``.

        Returns:
            Float tensor of shape ``(rows, num_classes)`` holding raw logits.
        """
        embedded = self.embd(x)
        encoded = self.transformer_encoder(embedded)
        return self.out(self.flatten(encoded))
# Build the network on the selected device and restore the trained weights.
transformer_model_load = TransformerModel().to(device)
# map_location remaps tensors saved on another device (for example a
# checkpoint written on a GPU machine) so loading also works on CPU-only hosts.
transformer_model_load.load_state_dict(
    torch.load("transformer_model.pth", map_location=device)
)
# Inference only: turn off training-time behaviour such as dropout.
transformer_model_load.eval()
def predict(one_text):
    """Classify one text and return the prediction as a JSON string.

    The text is tokenized to ``max_len`` ids (padded or truncated), passed
    through ``transformer_model_load``, and the highest-scoring class is
    reported.  The result is also printed to stdout.

    Args:
        one_text: Text to classify.

    Returns:
        A JSON object string with the keys ``pred_score`` (largest softmax
        probability), ``pred_index`` (position of the winning class in
        ``name_list``) and ``pred_label``.  Score and index are encoded as
        strings.
    """
    encoded = tokenizer_cn(
        one_text,
        padding='max_length',
        max_length=max_len,
        truncation=True,
        return_tensors="pt",
    )
    # Take the first row and put the leading batch axis back: (1, max_len).
    batch_ids = encoded.input_ids[0].unsqueeze(0).to(device)

    # Run the model without gradient tracking.
    with torch.no_grad():
        output = transformer_model_load(batch_ids)

    # Prediction confidence: the largest softmax probability over the logits.
    probabilities = F.softmax(output[0], dim=0)
    pred_score = probabilities.max().cpu().numpy()

    # Predicted class index and its label.
    pred_index = output.argmax(dim=1).item()
    pred_label = name_list[pred_index]

    print(f"predict class name : {pred_label} \npredict score : {pred_score}")
    print(pred_index)

    # Serialize to a JSON string.
    return json.dumps({
        'pred_score': str(pred_score),
        'pred_index': str(pred_index),
        'pred_label': pred_label,
    })
# Gradio UI: one text box in, the JSON string returned by ``predict`` out.
# The clickable example reuses ``eg_text`` rather than repeating the literal,
# so the sample text is defined in a single place.
demo = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="text",
    examples=[eg_text],
)
demo.launch()