"""Gradio demo serving a fine-tuned BERT six-way text classifier.

The script loads a tokenizer and a sequence-classification model from
``./model`` at import time, then launches a Gradio interface whose single
text box is pre-filled with an example (song lyrics).
"""
import re

import gradio as gr
import numpy as np
import torch
from transformers import BertForSequenceClassification, BertTokenizer

tokenizer = BertTokenizer.from_pretrained(r'./model')
model = BertForSequenceClassification.from_pretrained(r'./model', num_labels=6)

# Class names, indexed by the model's output logit position.
# (sad, happy, inspirational, cathartic, calm, touching)
labels = ['伤感', '快乐', '励志', '宣泄', '平静', '感人']

# Punctuation normalisation applied before filtering.
# NOTE(review): as written, the first three rules replace a character with
# itself; presumably they once mapped ASCII punctuation to full-width forms.
# Confirm against the preprocessing that was used to train the model.
_PUNCT_SUBSTITUTIONS = (
    (re.compile(r","), ","),
    (re.compile(r"\?"), "?"),
    (re.compile(r"!"), "!"),
    (re.compile(r"\."), "。"),
)

# Anything that is not a CJK ideograph, kept punctuation or a newline.
_DISALLOWED_CHARS = re.compile('[^\u4e00-\u9fa5,。?!\n]+')

# Rules that merge adjacent punctuation marks.  The "?" characters must be
# escaped: a bare "?" at the start of an alternative makes ``re`` raise
# "nothing to repeat".
_COLLAPSE_RULES = (
    (re.compile(r",,|!!|\?\?|。。"), ","),
    (re.compile(r",!|!,"), "!"),
    (re.compile(r",\?|\?,"), "?"),
    (re.compile(r",。|。,"), "。"),
)


def preprocess(temp):
    """Clean raw text into a single punctuated line for the tokenizer.

    Steps, in order: collapse double newlines, drop a leading/trailing
    newline, normalise punctuation, remove every character that is not a
    CJK ideograph / kept punctuation / newline, turn newlines into commas,
    merge runs of adjacent punctuation, and strip commas from both ends.

    Args:
        temp: Raw input text.

    Returns:
        The cleaned string (possibly empty).
    """
    temp = re.sub(r"\n\n", "\n", temp)
    temp = re.sub(r"(^\n)|(\n$)", "", temp)

    for pattern, replacement in _PUNCT_SUBSTITUTIONS:
        temp = pattern.sub(replacement, temp)

    temp = _DISALLOWED_CHARS.sub("", temp)
    temp = temp.replace("\n", ",")

    # Apply the merge rules until nothing changes.  Every substitution
    # replaces two characters with one, so the loop always terminates.
    while True:
        before = temp
        for pattern, replacement in _COLLAPSE_RULES:
            temp = pattern.sub(replacement, temp)
        if temp == before:
            break

    return temp.strip(',')


def classify_text(inp):
    """Return the model's probability for each label on the given text.

    Args:
        inp: Raw text from the Gradio text box.

    Returns:
        A dict mapping each entry of ``labels`` to its softmax probability
        (a Python float), the format consumed by the 'label' output.
    """
    cleaned = preprocess(inp)
    encoded = tokenizer(
        cleaned,
        padding=True,
        max_length=512,
        truncation=True,
        return_tensors="pt",
    )

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        logits = model(**encoded).logits

    # A single text is encoded per call, so drop the leading batch axis.
    probabilities = torch.softmax(logits, dim=1).squeeze(0)
    return dict(zip(labels, probabilities.tolist()))


# Example text shown in the input box when the page loads.
example_lyrics = '明天又是好日子\n千金的光阴不能等\n明天又是好日子\n赶上了盛世咱享太平\n今天是个好日子\n心想的事儿都能成\n明天又是好日子\n千金的光阴不能等\n今天明天都是好日子\n赶上了盛世咱享太平'

# NOTE(review): ``gr.inputs.Textbox(default=...)`` is Gradio's legacy input
# API; newer releases spell this ``gr.Textbox(value=...)``.  Confirm the
# Gradio version this app is pinned to before changing it.
gr.Interface(
    classify_text,
    gr.inputs.Textbox(lines=5, default=example_lyrics),
    outputs='label',
).launch(debug=True)