Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from transformers import BertTokenizer, BertModel | |
| import gradio as gr | |
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Class index -> display label (negative / neutral / positive). The order
# presumably matches the label ids used when the checkpoint was trained.
id2label = dict(enumerate(("负向", "中性", "正向")))
class BertCNNClassifier(nn.Module):
    """Sentence classifier: a frozen BERT encoder followed by a TextCNN head.

    The encoder's token representations are convolved with three filter banks
    of heights 3, 4 and 5 tokens (100 filters each, spanning the full hidden
    size). Each bank is max-pooled over the sequence. The pooled features are
    concatenated, passed through dropout and projected to ``num_classes``
    logits.

    Args:
        bert_model: BERT-style encoder exposing ``config.hidden_size`` whose
            forward returns an object with a ``last_hidden_state`` tensor.
        num_classes: Number of output logits.
        dropout: Dropout probability applied to the pooled features.
    """

    def __init__(self, bert_model, num_classes=3, dropout=0.3):
        super().__init__()
        self.bert = bert_model
        hidden_size = bert_model.config.hidden_size
        # conv1/conv2/conv3 appear as keys in saved state_dicts, so these
        # attribute names must stay unchanged for checkpoints to load.
        self.conv1 = nn.Conv2d(1, 100, (3, hidden_size))
        self.conv2 = nn.Conv2d(1, 100, (4, hidden_size))
        self.conv3 = nn.Conv2d(1, 100, (5, hidden_size))
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(300, num_classes)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Return class logits of shape ``(batch, num_classes)``.

        Args:
            input_ids: Token ids, ``(batch, seq_len)``.
            attention_mask: Attention mask with the same shape as ``input_ids``.
            token_type_ids: Optional segment ids, forwarded to BERT unchanged.
        """
        # BERT runs under no_grad, so it acts as a frozen feature extractor:
        # gradients only reach the CNN filters and the final linear layer.
        with torch.no_grad():
            outputs = self.bert(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
            )
        # (batch, seq_len, hidden) -> (batch, 1, seq_len, hidden): one channel.
        x = outputs.last_hidden_state.unsqueeze(1)

        convs = (self.conv1, self.conv2, self.conv3)
        # A kernel taller than the sequence raises a RuntimeError. An empty or
        # very short text is only a few tokens once [CLS]/[SEP] are added.
        # Zero-pad the sequence axis up to the tallest kernel; inputs that are
        # already long enough are left untouched.
        min_len = max(conv.kernel_size[0] for conv in convs)
        shortfall = min_len - x.size(2)
        if shortfall > 0:
            x = nn.functional.pad(x, (0, 0, 0, shortfall))

        pooled = []
        for conv in convs:
            # (batch, 100, L', 1) -> (batch, 100, L') -> max over L' -> (batch, 100)
            feats = torch.relu(conv(x)).squeeze(3)
            pooled.append(torch.max_pool1d(feats, feats.size(2)).squeeze(2))
        x = torch.cat(pooled, dim=1)
        x = self.dropout(x)
        return self.fc(x)
# Hub id of the pretrained encoder that the classifier wraps.
model_name = "hfl/chinese-macbert-base"

# The tokenizer comes from a local path (presumably saved next to the
# fine-tuned weights); the encoder weights come from the Hub id above.
tokenizer = BertTokenizer.from_pretrained("bert_cnn_tokenizer")
bert_model = BertModel.from_pretrained(model_name)

# nn.Module.to() moves parameters in place and returns the same module.
model = BertCNNClassifier(bert_model)
model.to(device)
# NOTE(review): torch.load unpickles the file. Only load checkpoints from a
# trusted source; consider weights_only=True if the installed PyTorch
# supports it.
model.load_state_dict(
    torch.load("bert_cnn_sentiment.pth", map_location=device)
)
# Inference only: disable dropout.
model.eval()
def predict(text):
    """Classify the sentiment of one piece of text for the Gradio UI.

    Args:
        text: Raw text from the input textbox.

    Returns:
        One of the labels in ``id2label`` ("负向" / "中性" / "正向"), or a
        prompt asking for input when ``text`` is empty or whitespace only.
    """
    # An empty submission has nothing to classify: it would reach the model
    # as only the special tokens. Prompt the user instead of running it.
    if not text or not text.strip():
        return "请输入文本内容"
    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128,
    ).to(device)
    with torch.no_grad():
        logits = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )
    # A single text gives a batch of one; take the highest-scoring class id.
    pred = torch.argmax(logits, dim=1).item()
    return id2label[pred]
# Gradio UI: a 3-line textbox in, the predicted label (plain text) out.
interface = gr.Interface(
    predict,
    gr.Textbox(placeholder="请输入朋友圈文案...", lines=3),
    "text",
    title="朋友圈情绪识别",
    description="输入一段朋友圈内容,判断情绪:负向 / 中性 / 正向",
)
# Start the web app.
interface.launch()