Spaces:
Sleeping
Sleeping
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
from transformers import BertModel | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AdamW | |
import json | |
import gradio as gr | |
# Sample review text (note the leading space).
eg_text = ' 酒店的地理位置实在不错,所以从大堂开始就令人惊艳。城景房不但在房间可以看到上海的美景'

# Model / data configuration.
model_name = 'bert-base-chinese'  # checkpoint used for the tokenizer and the BERT branch
max_len = 128                     # inputs are padded/truncated to this many tokens
n_class = 2                       # number of output classes of the fused model
name_list = ['Negative review', 'Positive review']  # label text, indexed by class id

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Tokenizer shared by every branch; its vocabulary size sizes the embedding
# tables of the CNN and LSTM branches.
tokenizer_cn = AutoTokenizer.from_pretrained(model_name)
voc_size = len(tokenizer_cn.vocab)
class bertBlock(nn.Module):
    """BERT branch: a sequence-classification model used as a 100-dim feature extractor.

    Wraps ``AutoModelForSequenceClassification`` loaded from ``model_name``
    with ``num_labels=100`` and returns the head's logits as features.
    """

    def __init__(self,):
        super().__init__()
        self.model_block = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=100
        ).to(device)

    def forward(self, text_b, attention_mask=None):
        """Return the classification-head logits for the token ids ``text_b``.

        Args:
            text_b: token-id tensor, passed to the model as ``input_ids``.
            attention_mask: optional mask (1 = real token, 0 = padding). The
                default ``None`` reproduces the original call; in that case the
                Hugging Face model attends to every position, padding included.

        Returns:
            The model's ``logits`` tensor (100 values per input row).
        """
        outputs = self.model_block(text_b, attention_mask=attention_mask)
        return outputs.logits
class textCNNblock(nn.Module):
    """TextCNN branch: embed token ids, apply parallel 1-D convolutions, pool, project.

    Produces a 100-dimensional feature vector per input row.

    NOTE(review): the embedding output is handed to ``nn.Conv1d`` without a
    transpose, so the ``max_len`` token positions act as the convolution's
    input channels and the 100 embedding dimensions act as its length. Inputs
    therefore must be exactly ``max_len`` tokens long. The saved checkpoint's
    weight shapes depend on this layout, so it is kept as-is.
    """

    def __init__(self,):
        super().__init__()
        emb_dim = 100
        # (kernel size, number of filters) for each parallel convolution.
        conv_specs = ((3, 150), (4, 150), (5, 150))
        self.embd = nn.Embedding(voc_size, emb_dim)
        self.convs = nn.ModuleList(
            [nn.Conv1d(max_len, n_filters, width, padding=width) for width, n_filters in conv_specs]
        )
        self.dropout = nn.Dropout(0.1)
        self.out = nn.Linear(sum(n_filters for _, n_filters in conv_specs), 100)

    def forward(self, x):
        """Map a batch of token ids to 100-dim CNN features."""
        embedded = self.embd(x)
        pooled = []
        for conv in self.convs:
            feature_map = F.relu(conv(embedded))
            # Global max-pool over the last axis, then drop that axis.
            pooled.append(F.max_pool1d(feature_map, feature_map.size(2)).squeeze(2))
        features = self.dropout(torch.cat(pooled, 1))
        return self.out(features)
class LSTMModelblock(nn.Module):
    """LSTM branch: embed token ids, run an LSTM, flatten, project to 100 features.

    NOTE(review): ``nn.LSTM`` is built with the default ``batch_first=False``,
    but it receives the ``(batch, seq, emb_dim)`` embedding output. Dimension 0
    (the batch) is therefore treated as the time axis and the token positions
    as the batch. With a batch of one, every token is processed independently
    as a length-1 sequence. Switching to ``batch_first=True`` keeps the
    parameter shapes but changes the outputs, so it requires retraining the
    checkpoint.
    """

    def __init__(self,):
        super().__init__()
        emb_dim = 100
        hidden_size = 50
        self.embd = nn.Embedding(voc_size, emb_dim)
        self.lstm = nn.LSTM(emb_dim, hidden_size)
        # The flattened LSTM output has max_len * hidden_size features
        # (128 * 50 = 6400). Derive it rather than hard-coding 6400 so the
        # layer stays consistent with max_len.
        self.out = nn.Linear(max_len * hidden_size, 100)
        self.flatten = nn.Flatten()

    def forward(self, x):
        """Map token ids (assumed padded to ``max_len``) to 100-dim LSTM features."""
        x = self.embd(x)
        x, _ = self.lstm(x)  # hidden/cell states are discarded
        x = self.flatten(x)
        x = self.out(x)
        return x
class BERT_CNN_LSTM(nn.Module):
    """Fuse BERT, LSTM and TextCNN branch features and classify them.

    Each branch maps ``input_ids`` to a 100-dim vector. The three vectors are
    stacked to ``(batch, 3, 100)``, passed through a Transformer encoder layer,
    flattened to 300 features, and projected to class logits.

    Args:
        bert, lstm, cnn: optional branch modules. When omitted, the
            module-level ``bert_block``, ``lstm_model_block`` and
            ``text_cnn_block`` are used, looked up at construction time
            (the original behaviour).
        n_classes: number of output logits; defaults to module-level ``n_class``.

    NOTE(review): ``nn.TransformerEncoderLayer`` uses the default
    ``batch_first=False`` while its input is ``(batch, 3, 100)``, so attention
    runs across samples in the batch rather than across the three branches.
    Also, ``fc1`` and ``fc2`` have no activation between them and thus compose
    to a single linear map. Both match the saved checkpoint and are left
    unchanged.
    """

    def __init__(self, bert=None, lstm=None, cnn=None, n_classes=None):
        super(BERT_CNN_LSTM, self).__init__()
        # Fall back to the module-level branch instances only when no module is
        # injected, so the class can be built without those globals.
        self.bert = bert_block if bert is None else bert
        self.lstm = lstm_model_block if lstm is None else lstm
        self.cnn = text_cnn_block if cnn is None else cnn
        num_outputs = n_class if n_classes is None else n_classes
        self.fc1 = nn.Linear(300, 100)
        self.fc2 = nn.Linear(100, num_outputs)
        self.dropout1 = nn.Dropout(0.2)
        self.dropout2 = nn.Dropout(0.2)  # not used in forward; parameter-less
        self.att = nn.TransformerEncoderLayer(d_model=100, nhead=2)
        self.flatten = nn.Flatten()

    def forward(self, input_ids):
        """Return class logits of shape ``(batch, num_outputs)`` for ``input_ids``."""
        bert_out = self.bert(input_ids)
        lstm_out = self.lstm(input_ids)
        cnn_out = self.cnn(input_ids)
        x = torch.stack((bert_out, lstm_out, cnn_out), dim=1)  # (batch, 3, 100)
        x = self.att(x)
        x = self.flatten(x)  # (batch, 300)
        x = self.fc1(x)
        x = self.dropout1(x)
        x = self.fc2(x)
        return x
# Build the branch modules first: BERT_CNN_LSTM() reads these module-level
# names when it is constructed.
bert_block = bertBlock().to(device)
text_cnn_block = textCNNblock().to(device)
lstm_model_block = LSTMModelblock().to(device)

# Create the fused model, move it to the target device and restore the
# trained weights. map_location loads the checkpoint tensors onto the CPU;
# load_state_dict then copies them into the model's parameters.
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints (consider weights_only=True on torch versions that support it).
model_big_load = BERT_CNN_LSTM().to(device)
model_big_load.load_state_dict(torch.load("model_big.pth", map_location=torch.device('cpu')))
model_big_load.eval()  # inference mode: disables dropout
def predict(one_text):
    """Classify one review string and return the result as a JSON string.

    The text is tokenized (padded/truncated to ``max_len``) and run through
    ``model_big_load`` without gradient tracking. The top class is looked up in
    ``name_list``, and the predicted label, score and index are printed to
    stdout.

    Returns:
        A JSON object string with keys ``pred_score`` (softmax probability of
        the top class), ``pred_index`` and ``pred_label``; all values are
        strings.
    """
    encoded = tokenizer_cn(
        one_text,
        padding='max_length',
        max_length=max_len,
        truncation=True,
        return_tensors="pt",
    )
    # First (only) sequence, re-batched as (1, max_len) on the model's device.
    input_ids = encoded["input_ids"][0].unsqueeze(0).to(device)

    # Run the model for prediction.
    with torch.no_grad():
        logits = model_big_load(input_ids)

    # Confidence = highest softmax probability of the single row.
    probabilities = nn.functional.softmax(logits[0], dim=0)
    pred_score = probabilities.max().cpu().numpy()
    pred_index = logits.argmax(dim=1).item()
    pred_label = name_list[pred_index]

    print(f"predict class name : {pred_label} \npredict score : {pred_score}")
    print(pred_index)

    # Serialize to a JSON string.
    payload = dict(
        pred_score=str(pred_score),
        pred_index=str(pred_index),
        pred_label=pred_label,
    )
    return json.dumps(payload)
# Gradio UI: one text box in, the JSON string returned by predict out, with
# two sample Chinese hotel reviews offered as clickable examples.
demo = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="text",
    examples=[
        '酒店的地理位置实在不错,所以从大堂开始就令人惊艳。城景房不但在房间可以看到上海的美景',
        '住了一次,感觉很差。灯光太暗,房间比较旧!',
    ],
)
demo.launch()