from transformers import AutoTokenizer
import torch
from torch import nn
import torch.nn.functional as F
import json
import gradio as gr


eg_text = '科研报告文档.pdf'  # example input ("research report document.pdf")
name = 'bert-base-chinese'

max_len = 50
n_class = 4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer_cn = AutoTokenizer.from_pretrained(name)
voc_size = len(tokenizer_cn.vocab)
# name_list = ['参考文献', '招聘信息', '提交材料', '计算机课程']
# ("references", "job postings", "submission materials", "computer courses")
name_list = ['aaaa', 'bbbb', 'cccc', 'dddd']

class TransformerModel(nn.Module):
    def __init__(self):
        super().__init__()
        emb_dim = 100

        self.embd = nn.Embedding(voc_size, emb_dim)
        # batch_first=True so inputs are (batch, seq, emb); make sure the
        # checkpoint was trained with the same setting, otherwise the encoder
        # silently attends over the batch dimension instead of the tokens
        encoder_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=2,
                                                   batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=3)
        self.flatten = nn.Flatten()
        self.out = nn.Linear(max_len * emb_dim, n_class)  # 50 * 100 = 5000

    def forward(self, x):
        x = self.embd(x)
        x = self.transformer_encoder(x)
        x = self.flatten(x)
        x = self.out(x)
        return x

transformer_model_load = TransformerModel().to(device)

# map_location lets the checkpoint load on CPU-only machines as well
transformer_model_load.load_state_dict(
    torch.load("transformer_model.pth", map_location=device)
)
transformer_model_load.eval()



def predict(one_text):
    one_result = tokenizer_cn(one_text, padding='max_length', max_length=max_len,
                              truncation=True, return_tensors="pt")
    one_ids = one_result.input_ids.to(device)  # shape: (1, max_len)

    # run the model without tracking gradients
    with torch.no_grad():
        output = transformer_model_load(one_ids)

    # probability of the predicted class
    pred_score = F.softmax(output[0], dim=0)
    pred_score = torch.max(pred_score).cpu().numpy()

    # predicted class index and label
    pred_index = torch.argmax(output, dim=1).item()
    pred_label = name_list[pred_index]

    print(f"predict class name : {pred_label} \npredict score : {pred_score}")

    # serialize the result as a JSON string
    result_dict = {'pred_score': str(pred_score), 'pred_index': str(pred_index),
                   'pred_label': pred_label}
    result_json = json.dumps(result_dict)

    return result_json


demo = gr.Interface(fn=predict,
                    inputs="text",
                    outputs="text",
                    examples=[eg_text],
                    )

# demo.launch(debug=True)
demo.launch()