# "Spaces: Sleeping" — HuggingFace Space status banner captured when this file
# was exported from the web page; not part of the program.
from transformers import AutoTokenizer
import torch
from torch import nn
import torch.nn.functional as F
import json
import gradio as gr

# --- Configuration -----------------------------------------------------------
eg_text = '科研报告文档.pdf'   # example input shown in the Gradio UI
name = 'bert-base-chinese'     # HF checkpoint; used here only for its tokenizer
max_len = 50                   # fixed token length (pad / truncate target)
n_class = 4                    # number of output classes
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer_cn = AutoTokenizer.from_pretrained(name)
voc_size = len(tokenizer_cn.vocab)  # embedding table size for the model below

# Human-readable class names, indexed by predicted class id.
# name_list = ['参考文献', '招聘信息', '提交材料', '计算机课程']
name_list = ['aaaa', 'bbbb', 'cccc', 'dddd']
class TransformerModel(nn.Module):
    """Small text classifier: token embedding -> Transformer encoder -> linear head.

    Takes integer token ids of shape (batch, max_len) and returns raw class
    logits of shape (batch, n_class). Reads the module-level globals
    ``voc_size``, ``max_len`` and ``n_class`` at construction time.
    """

    def __init__(self):
        super().__init__()
        emb_dim = 100
        self.embd = nn.Embedding(voc_size, emb_dim)
        # NOTE(review): batch_first is left at its default (False), so the
        # encoder treats dim 0 as the sequence axis. The trained checkpoint
        # was produced with this layout, so it is kept unchanged — confirm
        # before running with batch sizes > 1.
        encoder_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=2)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=3)
        # Head input size is the flattened encoder output: max_len * emb_dim
        # (50 * 100 = 5000 — previously a hard-coded magic number).
        self.out = nn.Linear(max_len * emb_dim, n_class)
        self.flatten = nn.Flatten()

    def forward(self, x):
        x = self.embd(x)                  # (batch, max_len) -> (batch, max_len, emb_dim)
        x = self.transformer_encoder(x)   # same shape
        x = self.flatten(x)               # -> (batch, max_len * emb_dim)
        x = self.out(x)                   # -> (batch, n_class) logits
        return x
# Instantiate the classifier and load the trained weights.
# map_location makes a GPU-saved checkpoint loadable on CPU-only machines,
# where a bare torch.load would raise a deserialization error.
transformer_model_load = TransformerModel().to(device)
transformer_model_load.load_state_dict(
    torch.load("transformer_model.pth", map_location=device)
)
transformer_model_load.eval()  # inference mode (disables dropout)
def predict(one_text):
    """Classify one text string and return the result as a JSON string.

    The JSON object has keys 'pred_score' (max softmax probability as a
    string), 'pred_index' (predicted class id as a string) and 'pred_label'
    (class name from ``name_list``).
    """
    # Tokenize to a fixed-length (max_len) id sequence.
    one_result = tokenizer_cn(one_text, padding='max_length', max_length=max_len,
                              truncation=True, return_tensors="pt")
    one_ids = one_result.input_ids[0]
    one_ids = one_ids.unsqueeze(0).to(device)  # add batch dim: (1, max_len)

    # Run the model without tracking gradients.
    with torch.no_grad():
        output = transformer_model_load(one_ids)

    # Confidence = largest softmax probability over the class logits.
    pred_score = F.softmax(output[0], dim=0)
    pred_score = torch.max(pred_score).cpu().numpy()
    # Predicted class index and its human-readable label.
    pred_index = torch.argmax(output, dim=1).item()
    pred_label = name_list[pred_index]
    print(f"predict class name : {pred_label} \npredict score : {pred_score}")
    print(pred_index)

    # Serialize to JSON; ensure_ascii=False keeps CJK labels readable
    # (the production label list is Chinese — see the commented name_list).
    result_dict = {'pred_score': str(pred_score), 'pred_index': str(pred_index),
                   'pred_label': pred_label}
    result_json = json.dumps(result_dict, ensure_ascii=False)
    return result_json
# Build and serve the Gradio demo: free-text input, JSON-string output.
demo = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="text",
    examples=['科研报告文档.pdf'],
)
demo.launch()