|
from transformers import AutoTokenizer |
|
import torch |
|
from torch import nn |
|
import torch.nn.functional as F |
|
import json |
|
import gradio as gr |
|
|
|
|
|
# Example input string shown in the Gradio UI (a Chinese filename).
eg_text = '科研报告文档.pdf'

# Hugging Face model id — used here only for its tokenizer/vocabulary.
name = 'bert-base-chinese'


# Fixed token-sequence length: shorter inputs are padded, longer truncated.
max_len = 50

# Number of target classes; must stay in sync with len(name_list) below.
n_class = 4

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Downloads/loads the tokenizer from the Hugging Face hub on first run.
tokenizer_cn = AutoTokenizer.from_pretrained(name)

# Embedding-table size for TransformerModel, taken from the tokenizer vocab.
voc_size = len(tokenizer_cn.vocab)


# Human-readable labels, indexed by the model's predicted class id.
# NOTE(review): placeholder names — presumably to be replaced with real
# document-category labels; confirm against the training label order.
name_list = ['aaaa', 'bbbb', 'cccc', 'dddd']
|
|
|
class TransformerModel(nn.Module):
    """Transformer-encoder text classifier.

    Embeds a batch of token-id sequences, runs them through a stack of
    Transformer encoder layers, flattens the whole sequence representation
    and classifies it with a single linear layer.

    Args:
        vocab_size: embedding-table size; defaults to the module-level
            ``voc_size`` (tokenizer vocabulary size).
        embed_dim: embedding / encoder model dimension; must be divisible
            by ``num_heads``.
        num_heads: attention heads per encoder layer.
        num_layers: number of stacked encoder layers.
        num_classes: number of output classes; defaults to module-level
            ``n_class``.
        seq_len: fixed input sequence length; defaults to module-level
            ``max_len``.  The head flattens the full sequence, so every
            input batch must have exactly this length.

    All defaults reproduce the original hard-coded architecture, so
    ``TransformerModel()`` still matches the saved checkpoint.
    """

    def __init__(self, vocab_size=None, embed_dim=100, num_heads=2,
                 num_layers=3, num_classes=None, seq_len=None):
        super().__init__()
        # Resolve None defaults lazily from the module-level config so the
        # no-argument form keeps its original behavior.
        vocab_size = voc_size if vocab_size is None else vocab_size
        num_classes = n_class if num_classes is None else num_classes
        seq_len = max_len if seq_len is None else seq_len

        self.embd = nn.Embedding(vocab_size, embed_dim)
        # NOTE(review): batch_first is left at its default (False) to match
        # the semantics the saved weights were trained with; flipping it
        # would change which axis attention runs over — confirm before
        # changing.
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.flatten = nn.Flatten()
        # Replaces the original hard-coded 5000: seq_len * embed_dim
        # (50 * 100) is the flattened feature size, kept in sync with the
        # constructor arguments instead of a magic number.
        self.out = nn.Linear(seq_len * embed_dim, num_classes)

    def forward(self, x):
        """Map token ids of shape (batch, seq_len) to class logits (batch, num_classes)."""
        x = self.embd(x)
        x = self.transformer_encoder(x)
        x = self.flatten(x)
        return self.out(x)
|
|
|
# Instantiate the network and restore the trained weights for inference.
transformer_model_load = TransformerModel().to(device)

# map_location=device lets a checkpoint that was saved on a GPU machine
# load on a CPU-only host; the original call would raise a CUDA
# deserialization error there.
transformer_model_load.load_state_dict(
    torch.load("transformer_model.pth", map_location=device)
)
transformer_model_load.eval()  # disable dropout etc. for inference
|
|
|
|
|
|
|
def predict(one_text):
    """Classify a single text string and return the result as a JSON string.

    Tokenizes *one_text* with the BERT tokenizer (padded/truncated to
    ``max_len``), runs the loaded classifier under ``no_grad``, and returns
    JSON with keys ``pred_score`` (confidence of the chosen class),
    ``pred_index`` (class id) and ``pred_label`` (name from ``name_list``).
    """
    encoded = tokenizer_cn(one_text, padding='max_length', max_length=max_len,
                           truncation=True, return_tensors="pt")
    # return_tensors="pt" already yields a (1, max_len) batch, so the
    # original [0]/unsqueeze(0) round-trip is unnecessary.  Only the token
    # ids are used; the model takes no attention mask.
    input_ids = encoded.input_ids.to(device)

    with torch.no_grad():
        logits = transformer_model_load(input_ids)

    # Derive score and index from the same softmax so they always agree
    # (the original took argmax of raw logits and max of softmax through
    # two separate code paths).
    probs = F.softmax(logits[0], dim=0)
    pred_index = int(torch.argmax(probs).item())
    pred_score = probs[pred_index].item()
    pred_label = name_list[pred_index]

    print(f"predict class name : {pred_label} \npredict score : {pred_score}")
    print(pred_index)

    result_dict = {'pred_score': str(pred_score), 'pred_index': str(pred_index), 'pred_label': pred_label}
    return json.dumps(result_dict)
|
|
|
|
|
# Build the Gradio UI: one text box in, the JSON result string out.
demo = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="text",
    examples=[eg_text],  # reuse the module-level example instead of a duplicate literal
)


# Launch only when run as a script, so importing this module (e.g. from a
# test or another app) does not start a web server as a side effect.
if __name__ == "__main__":
    demo.launch()
|
|
|
|
|
|