zswwsz's picture
Update app.py
c0ec2aa
import gradio as gr
from transformers import BertTokenizer, BertForSequenceClassification
import numpy as np
import re
import torch
# Mood labels, in the same order as the classifier's output logits
# (sad, happy, inspirational, cathartic, calm, touching).
labels = ['伤感', '快乐', '励志', '宣泄', '平静', '感人']
tokenizer = BertTokenizer.from_pretrained(r'./model')
# Size the classification head from `labels` so the two cannot drift apart.
model = BertForSequenceClassification.from_pretrained(r'./model', num_labels=len(labels))
# Canonical full-width punctuation kept by `preprocess`. Written as escapes so
# the characters cannot be silently normalized to their ASCII look-alikes
# (which would turn the conversions into no-ops and make `?` a regex
# metacharacter in the collapse patterns).
_FW_COMMA = "\uff0c"     # FULLWIDTH COMMA
_FW_QUESTION = "\uff1f"  # FULLWIDTH QUESTION MARK
_FW_EXCLAIM = "\uff01"   # FULLWIDTH EXCLAMATION MARK
_FW_STOP = "\u3002"      # IDEOGRAPHIC FULL STOP

# ASCII punctuation -> its full-width equivalent.
_TO_FULL_WIDTH = str.maketrans(
    {",": _FW_COMMA, "?": _FW_QUESTION, "!": _FW_EXCLAIM, ".": _FW_STOP}
)

# Runs of anything other than basic CJK ideographs (U+4E00-U+9FA5), the kept
# full-width punctuation, or newlines.
_DISALLOWED = re.compile(
    f"[^\u4e00-\u9fa5{_FW_COMMA}{_FW_STOP}{_FW_QUESTION}{_FW_EXCLAIM}\n]+"
)

# Applied in this order on every pass: a doubled mark becomes a comma, and a
# comma adjacent to '!', '?' or the full stop is absorbed into that mark.
_COLLAPSE_RULES = (
    (
        re.compile(
            f"{_FW_COMMA}{_FW_COMMA}|{_FW_EXCLAIM}{_FW_EXCLAIM}"
            f"|{_FW_QUESTION}{_FW_QUESTION}|{_FW_STOP}{_FW_STOP}"
        ),
        _FW_COMMA,
    ),
    (re.compile(f"{_FW_COMMA}{_FW_EXCLAIM}|{_FW_EXCLAIM}{_FW_COMMA}"), _FW_EXCLAIM),
    (re.compile(f"{_FW_COMMA}{_FW_QUESTION}|{_FW_QUESTION}{_FW_COMMA}"), _FW_QUESTION),
    (re.compile(f"{_FW_COMMA}{_FW_STOP}|{_FW_STOP}{_FW_COMMA}"), _FW_STOP),
)


def preprocess(temp):
    """Normalize raw lyric text into one line of Chinese text with full-width punctuation.

    Steps:
      1. Collapse doubled newlines and drop one leading/trailing newline.
      2. Convert ASCII ',', '?', '!', '.' to their full-width forms, so
         punctuation typed in either form is kept.
      3. Delete every character that is not a basic CJK ideograph
         (U+4E00-U+9FA5), one of the kept punctuation marks, or a newline.
      4. Turn newlines into commas, then repeatedly collapse adjacent
         punctuation until the text stops changing.
      5. Strip leading/trailing (full-width) commas.

    Args:
        temp: Raw input text.

    Returns:
        The cleaned string; empty if nothing survives the filter.
    """
    temp = temp.replace("\n\n", "\n")
    temp = re.sub(r"(^\n)|(\n$)", "", temp)
    temp = temp.translate(_TO_FULL_WIDTH)
    temp = _DISALLOWED.sub("", temp)
    temp = temp.replace("\n", _FW_COMMA)
    # Every effective pass replaces two characters with one, so the string
    # strictly shrinks and this loop always terminates.
    while True:
        collapsed = temp
        for pattern, replacement in _COLLAPSE_RULES:
            collapsed = pattern.sub(replacement, collapsed)
        if collapsed == temp:
            break
        temp = collapsed
    return temp.strip(_FW_COMMA)
def classify_text(inp):
    """Score a lyric text against every mood label.

    The text is cleaned with ``preprocess``, tokenized (truncated to at most
    512 tokens), and run through the BERT classifier without gradient
    tracking. The logits are then converted to probabilities with a softmax.

    Args:
        inp: Raw text from the Gradio textbox.

    Returns:
        dict mapping each entry of ``labels`` to its probability (a Python
        float), as consumed by the Gradio ``'label'`` output.
    """
    text = preprocess(inp)
    encoded = tokenizer(
        text, padding=True, max_length=512, truncation=True, return_tensors="pt"
    )
    with torch.no_grad():
        logits = model(**encoded).logits
    # A single text gives a batch of size 1: softmax over the label axis,
    # then drop the batch axis.
    probs = torch.softmax(logits, dim=1).squeeze(0).tolist()
    return {label: float(p) for label, p in zip(labels, probs)}
# Sample lyrics offered as a clickable example in the UI. (Not named `input`,
# which would shadow the built-in.)
example_lyrics = '明天又是好日子\n千金的光阴不能等\n明天又是好日子\n赶上了盛世咱享太平\n今天是个好日子\n心想的事儿都能成\n明天又是好日子\n千金的光阴不能等\n今天明天都是好日子\n赶上了盛世咱享太平'
gr.Interface(
    classify_text,
    # NOTE(review): `gr.inputs` is deprecated in newer Gradio releases; when
    # upgrading, use `gr.Textbox(lines=5, value="")` instead.
    gr.inputs.Textbox(lines=5, default=""),
    outputs='label',
    examples=[[example_lyrics]],
).launch(debug=True)