zswwsz's picture
Update app.py
c0ec2aa
raw
history blame contribute delete
No virus
2.1 kB
import gradio as gr
from transformers import BertTokenizer, BertForSequenceClassification
import numpy as np
import re
import torch
# Mood labels, in the same order as the classifier's output logits
# (classify_text maps labels[i] to logits[i]). English glosses:
# sad, happy, inspirational, cathartic, calm, touching.
labels = ['伤感', '快乐', '励志', '宣泄', '平静', '感人']

# Directory holding the fine-tuned tokenizer and model weights, resolved
# relative to the current working directory.
MODEL_DIR = r'./model'

tokenizer = BertTokenizer.from_pretrained(MODEL_DIR)
# Derive the head size from `labels` so the two can never drift apart.
model = BertForSequenceClassification.from_pretrained(MODEL_DIR, num_labels=len(labels))
# Full-width punctuation that model input is normalised to. Written as \u
# escapes so tools that NFKC-normalise text cannot fold them into their ASCII
# look-alikes, which would silently make the ASCII->full-width conversions
# no-ops and turn patterns such as the doubled question mark into invalid
# regexes ("nothing to repeat").
_FW_COMMA = "\uff0c"     # U+FF0C FULLWIDTH COMMA
_FW_QUESTION = "\uff1f"  # U+FF1F FULLWIDTH QUESTION MARK
_FW_EXCLAIM = "\uff01"   # U+FF01 FULLWIDTH EXCLAMATION MARK
_FW_PERIOD = "\u3002"    # U+3002 IDEOGRAPHIC FULL STOP

# ASCII , ? ! .  ->  their full-width counterparts, in a single pass.
_ASCII_TO_FULLWIDTH = str.maketrans({
    ",": _FW_COMMA,
    "?": _FW_QUESTION,
    "!": _FW_EXCLAIM,
    ".": _FW_PERIOD,
})

# Everything except CJK ideographs U+4E00..U+9FA5, the four full-width marks
# and newlines.
_DISALLOWED_CHARS = re.compile(
    "[^\u4e00-\u9fa5" + _FW_COMMA + _FW_PERIOD + _FW_QUESTION + _FW_EXCLAIM + "\n]+"
)

# Two identical punctuation marks in a row.
_DOUBLED_PUNCT = re.compile(
    "|".join(mark * 2 for mark in (_FW_COMMA, _FW_EXCLAIM, _FW_QUESTION, _FW_PERIOD))
)
# A comma adjacent (either side) to a stronger mark.
_COMMA_EXCLAIM = re.compile(_FW_COMMA + _FW_EXCLAIM + "|" + _FW_EXCLAIM + _FW_COMMA)
_COMMA_QUESTION = re.compile(_FW_COMMA + _FW_QUESTION + "|" + _FW_QUESTION + _FW_COMMA)
_COMMA_PERIOD = re.compile(_FW_COMMA + _FW_PERIOD + "|" + _FW_PERIOD + _FW_COMMA)


def _collapse_punctuation_once(text):
    """Apply one pass of the punctuation-collapsing rules to *text*."""
    # NOTE(review): a doubled exclamation/question mark/full stop collapses to
    # a *comma*, not to a single copy of itself. Kept as-is because it defines
    # the exact text the model sees, but it looks unintended.
    text = _DOUBLED_PUNCT.sub(_FW_COMMA, text)
    # A comma next to a stronger mark is absorbed by that mark.
    text = _COMMA_EXCLAIM.sub(_FW_EXCLAIM, text)
    text = _COMMA_QUESTION.sub(_FW_QUESTION, text)
    text = _COMMA_PERIOD.sub(_FW_PERIOD, text)
    return text


def preprocess(temp):
    """Normalise raw lyric text into one line of Chinese text and punctuation.

    Processing steps:
    - ASCII ``, ? ! .`` are converted to their full-width forms.
    - Every character that is not a CJK ideograph (U+4E00-U+9FA5), one of
      those four full-width marks, or a newline is dropped.
    - Newlines become full-width commas.
    - Runs of punctuation are collapsed.
    - Leading and trailing commas are stripped.

    Args:
        temp: raw input text.

    Returns:
        The normalised string; empty if no Chinese text survives.
    """
    temp = re.sub("\n\n", "\n", temp)
    temp = re.sub("(^\n)|(\n$)", "", temp)
    temp = temp.translate(_ASCII_TO_FULLWIDTH)
    temp = _DISALLOWED_CHARS.sub("", temp)
    temp = temp.replace("\n", _FW_COMMA)
    # Iterate to a fixed point rather than a fixed len(temp)/2 passes. Every
    # substitution replaces two characters with one, so any pass that changes
    # the string also shortens it. This guarantees termination, and typically
    # only a few passes are needed instead of O(n) full-string scans.
    while True:
        collapsed = _collapse_punctuation_once(temp)
        if collapsed == temp:
            break
        temp = collapsed
    return temp.strip(_FW_COMMA)
def classify_text(inp):
    """Score a piece of lyric text against every mood label.

    The text is cleaned with ``preprocess`` and tokenised (truncated to 512
    tokens) as a single-example batch. It is then run through the classifier
    without gradient tracking, and the logits are turned into probabilities
    with a softmax.

    Args:
        inp: raw text entered in the Gradio textbox.

    Returns:
        dict mapping each entry of ``labels`` to its probability as a Python
        float, the shape expected by the ``'label'`` output component.
    """
    cleaned = preprocess(inp)
    encoded = tokenizer(
        cleaned,
        padding=True,
        max_length=512,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        raw_scores = model(**encoded).logits
    # Normalise over the label dimension, then drop the batch dimension of 1.
    probabilities = torch.softmax(raw_scores, dim=1).squeeze(0)
    return {name: float(probabilities[idx].item()) for idx, name in enumerate(labels)}
# Sample lyrics, offered in the UI as a clickable example. Named so that it
# does not shadow the builtin `input`, as the former module-level name did.
example_lyrics = '明天又是好日子\n千金的光阴不能等\n明天又是好日子\n赶上了盛世咱享太平\n今天是个好日子\n心想的事儿都能成\n明天又是好日子\n千金的光阴不能等\n今天明天都是好日子\n赶上了盛世咱享太平'

# Build the web UI: a 5-line textbox feeding classify_text, whose
# label->probability dict is rendered by the 'label' output component.
gr.Interface(
    classify_text,
    gr.inputs.Textbox(lines=5, default=""),
    outputs='label',
    examples=[[example_lyrics]],
).launch(debug=True)