# xiaogantongxue's picture
# Upload 4 files
# 129d2da verified
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel
import gradio as gr
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Maps the classifier's output index to its sentiment label
# (negative / neutral / positive).
id2label = dict(enumerate(("负向", "中性", "正向")))
class BertCNNClassifier(nn.Module):
    """BERT encoder followed by a TextCNN head for sequence classification.

    Three parallel convolutions (window heights of 3, 4 and 5 tokens, each
    spanning the full hidden width) slide over the encoder's token
    representations. Every feature map is max-pooled over the sequence axis,
    the pooled vectors are concatenated, passed through dropout and projected
    to ``num_classes`` logits.

    The encoder call is wrapped in ``torch.no_grad()``, so no gradients flow
    into ``bert_model`` from this module's forward pass.

    Args:
        bert_model: Encoder module exposing ``config.hidden_size`` and
            returning an object with a ``last_hidden_state`` attribute.
        num_classes: Number of output logits.
        dropout: Dropout probability applied to the pooled features.
    """

    def __init__(self, bert_model, num_classes=3, dropout=0.3):
        super().__init__()
        hidden_size = bert_model.config.hidden_size
        num_filters = 100

        self.bert = bert_model
        # The attribute names conv1/conv2/conv3/fc are part of the state_dict
        # keys; they must stay as they are so saved checkpoints keep loading.
        self.conv1 = nn.Conv2d(1, num_filters, (3, hidden_size))
        self.conv2 = nn.Conv2d(1, num_filters, (4, hidden_size))
        self.conv3 = nn.Conv2d(1, num_filters, (5, hidden_size))
        self.dropout = nn.Dropout(dropout)
        # One pooled vector of size num_filters per convolution branch.
        self.fc = nn.Linear(3 * num_filters, num_classes)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Return unnormalised class logits, one row per input sequence.

        Sequences shorter than the widest convolution window are zero-padded
        along the sequence axis after encoding, so very short inputs no longer
        fail inside the convolution layers. Longer inputs are unaffected.
        """
        with torch.no_grad():
            outputs = self.bert(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
            )
        hidden = outputs.last_hidden_state

        branches = (self.conv1, self.conv2, self.conv3)

        # A convolution cannot slide over fewer rows than its kernel height;
        # pad the sequence axis (second to last) up to the tallest kernel.
        min_len = max(conv.kernel_size[0] for conv in branches)
        shortfall = min_len - hidden.size(1)
        if shortfall > 0:
            hidden = nn.functional.pad(hidden, (0, 0, 0, shortfall))

        # Add the single input channel expected by Conv2d.
        x = hidden.unsqueeze(1)

        pooled = []
        for conv in branches:
            # The kernel spans the whole hidden width, so the last axis
            # collapses to size 1 and can be squeezed away.
            feature_map = torch.relu(conv(x)).squeeze(3)
            # Max over every remaining position (global max pooling).
            pooled.append(
                torch.max_pool1d(feature_map, feature_map.size(2)).squeeze(2)
            )

        x = torch.cat(pooled, dim=1)
        x = self.dropout(x)
        return self.fc(x)
# Pretrained encoder checkpoint that the classifier is built on.
model_name = "hfl/chinese-macbert-base"

# The tokenizer is read from a local directory; the encoder is loaded by name.
tokenizer = BertTokenizer.from_pretrained("bert_cnn_tokenizer")
bert_model = BertModel.from_pretrained(model_name)

# Build the classifier on the selected device, then restore the weights
# stored in the local checkpoint file.
model = BertCNNClassifier(bert_model)
model = model.to(device)
# NOTE(review): torch.load unpickles the file; consider weights_only=True
# if the installed torch version supports it — confirm before changing.
state_dict = torch.load("bert_cnn_sentiment.pth", map_location=device)
model.load_state_dict(state_dict)

# Switch to inference mode (disables dropout).
model.eval()
def predict(text):
    """Classify the sentiment of one piece of text and return its label.

    Args:
        text: Raw text entered by the user.

    Returns:
        The ``id2label`` entry for the highest-scoring class.

    Raises:
        gr.Error: If ``text`` is empty or contains only whitespace.
    """
    # Blank input carries no sentiment; report it to the user instead of
    # letting the model fail on it.
    if not text or not text.strip():
        raise gr.Error("请输入文本内容")

    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128,
    ).to(device)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    # The classifier's widest convolution window covers 5 tokens, so shorter
    # sequences would crash it. Right-pad with PAD tokens (masked out for
    # attention), the same way batch padding would.
    min_tokens = 5
    shortfall = min_tokens - input_ids.size(1)
    if shortfall > 0:
        input_ids = nn.functional.pad(
            input_ids, (0, shortfall), value=tokenizer.pad_token_id
        )
        attention_mask = nn.functional.pad(
            attention_mask, (0, shortfall), value=0
        )

    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask)

    # Single input, so the batch holds exactly one prediction.
    return id2label[torch.argmax(logits, dim=1).item()]
# Web UI: a three-line textbox feeds predict(), and the returned label is
# shown as plain text.
text_input = gr.Textbox(lines=3, placeholder="请输入朋友圈文案...")

interface = gr.Interface(
    fn=predict,
    inputs=text_input,
    outputs="text",
    title="朋友圈情绪识别",
    description="输入一段朋友圈内容,判断情绪:负向 / 中性 / 正向",
)

# Start the Gradio server for this interface.
interface.launch()