# Hugging Face Space page status when this file was captured: Runtime error
import gradio as gr | |
from transformers import BertTokenizer, BertForSequenceClassification | |
import numpy as np | |
import re | |
import torch | |
# Output classes, in the index order classify_text pairs them with the model's
# logits: 伤感 sad, 快乐 happy, 励志 inspirational, 宣泄 cathartic/venting,
# 平静 calm, 感人 touching. NOTE(review): this order must match the one used
# when the checkpoint was fine-tuned — not verifiable from this file.
labels = ['伤感', '快乐', '励志', '宣泄', '平静', '感人']

# Tokenizer and fine-tuned classifier are loaded from ./model, which is
# resolved against the current working directory.
MODEL_DIR = r'./model'
tokenizer = BertTokenizer.from_pretrained(MODEL_DIR)
# Size the classification head from the label list so the two cannot drift apart.
model = BertForSequenceClassification.from_pretrained(MODEL_DIR, num_labels=len(labels))
# Punctuation variants folded onto the canonical marks that _DISALLOWED_CHARS
# keeps. Full-width comma/question/exclamation become ASCII; ASCII and
# full-width periods become the ideographic full stop.
# NOTE(review): the canonical set should match the preprocessing used when the
# model was trained — confirm.
_PUNCT_MAP = str.maketrans({"，": ",", "？": "?", "！": "!", ".": "。", "．": "。"})

# Anything that is not a CJK unified ideograph (U+4E00-U+9FA5), one of the
# canonical marks, or a newline.
_DISALLOWED_CHARS = re.compile("[^\u4e00-\u9fa5,。?!\n]+")

# Merge rules for adjacent punctuation, applied in this order on every pass.
# Metacharacters are escaped so '?' is matched literally.
_COLLAPSE_RULES = (
    (re.compile(r",,|!!|\?\?|。。"), ","),  # doubled mark -> comma
    (re.compile(r",!|!,"), "!"),
    (re.compile(r",\?|\?,"), "?"),
    (re.compile(r",。|。,"), "。"),
)


def preprocess(temp):
    """Normalize raw Chinese text before it is tokenized for classification.

    Steps: collapse blank lines and trim edge newlines, fold punctuation
    variants onto the canonical marks ``,`` ``?`` ``!`` ``。``, drop every
    character that is neither a CJK ideograph (U+4E00-U+9FA5) nor one of
    those marks, turn line breaks into commas, then repeatedly merge
    adjacent punctuation until the text stops changing. A doubled mark
    (e.g. ``!!``) is merged into a comma; a comma next to ``!``/``?``/``。``
    is absorbed by that mark.

    Args:
        temp: Raw input text.

    Returns:
        The normalized string, with no leading or trailing commas.
    """
    text = re.sub("\n\n", "\n", temp)
    text = re.sub("(^\n)|(\n$)", "", text)
    text = text.translate(_PUNCT_MAP)
    text = _DISALLOWED_CHARS.sub("", text)
    # Line breaks act as clause separators.
    text = text.replace("\n", ",")
    # Every rule replaces two characters with one, so each pass that changes
    # the text shortens it and this loop always terminates.
    while True:
        before = text
        for pattern, replacement in _COLLAPSE_RULES:
            text = pattern.sub(replacement, text)
        if text == before:
            break
    return text.strip(",")
def classify_text(inp):
    """Classify the emotion of a piece of Chinese text.

    Args:
        inp: Raw text from the Gradio textbox.

    Returns:
        Dict mapping each entry of ``labels`` to its softmax probability,
        the ``{label: confidence}`` shape a Gradio ``label`` output displays.
    """
    text = preprocess(inp)
    # Truncate to the model's position-embedding limit so long inputs cannot
    # index past BERT's positional embeddings.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=model.config.max_position_embeddings,
    )
    with torch.no_grad():
        # One input string -> logits of shape (batch=1, num_labels).
        logits = model(**inputs).logits
    # Softmax over the class axis of the single row (not over the batch axis,
    # which has size 1 and would yield all ones).
    probs = torch.softmax(logits[0], dim=-1)
    return dict(zip(labels, probs.tolist()))
# Web UI: a five-line text box in, a label/confidence display out.
# debug=True keeps launch() blocking and surfaces errors in the console.
gr.Interface(
    fn=classify_text,
    inputs=gr.Textbox(lines=5, value=""),
    outputs="label",
).launch(debug=True)