File size: 2,098 Bytes
e6d5ea5
 
 
 
 
 
c876259
4a3c3e1
d2540d8
63d92be
 
d2540d8
 
 
c876259
029f31f
c876259
d4f0a84
d2540d8
 
 
 
 
 
 
 
 
 
 
 
cdd28fa
d2540d8
cdd28fa
0d3b4b7
cdd28fa
d2540d8
ba3c131
cdd28fa
 
96496f0
 
cdd28fa
96496f0
 
d2540d8
d4f0a84
 
d2540d8
 
 
c0ec2aa
d2540d8
 
 
 
9d33a47
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import gradio as gr
from transformers import BertTokenizer, BertForSequenceClassification
import numpy as np
import re
import torch

# Emotion classes, in the order of the classifier's output logits:
# classify_text scores logit i as labels[i].
# (sad, happy, inspirational, cathartic/venting, calm, touching)
# Defined before the model so the classification-head size is derived from
# this list instead of being a separate hard-coded number that can drift.
labels = ['伤感', '快乐', '励志', '宣泄', '平静', '感人']

# Tokenizer and fine-tuned BERT sequence classifier, both loaded from the
# local ./model directory (a relative path, so it is resolved against the
# current working directory).
tokenizer = BertTokenizer.from_pretrained(r'./model')
model = BertForSequenceClassification.from_pretrained(r'./model', num_labels=len(labels))

# Punctuation normalisation applied before filtering. Full-width comma,
# question mark and exclamation mark are mapped onto the ASCII forms that the
# character filter keeps (otherwise the filter would silently delete them);
# the ASCII period is mapped to the ideographic full stop.
_PUNCT_TABLE = str.maketrans({
    "\uff0c": ",",       # full-width comma
    "\uff1f": "?",       # full-width question mark
    "\uff01": "!",       # full-width exclamation mark
    ".": "\u3002",       # ASCII period -> ideographic full stop
})
# Any run of characters that is not a CJK ideograph (U+4E00..U+9FA5),
# one of the kept punctuation marks, or a newline.
_DISALLOWED = re.compile('[^\u4e00-\u9fa5,。?!\n]+')
# Two or more identical punctuation marks in a row.
_REPEATED_MARK = re.compile(r"([,!?。])\1+")
# A terminal mark together with any commas directly before/after it.
_COMMAS_AROUND_MARK = re.compile(r",*([!?。]),*")


def preprocess(temp: str) -> str:
    """Normalise raw lyrics into one comma-separated line of Chinese text.

    Steps:
      1. Map full-width ``,?!`` to ASCII and ``.`` to ``。``.
      2. Drop every character that is not a CJK ideograph, one of
         ``,?!。``, or a newline.
      3. Turn newlines (breaks between lyric lines) into commas.
      4. Until the text stops changing: collapse runs of the same mark to a
         single mark, and fold commas touching ``!``, ``?`` or ``。`` into
         that mark.
      5. Strip leading/trailing commas.

    Args:
        temp: Raw input text.

    Returns:
        The cleaned text; ``""`` if nothing survives the filter.
    """
    temp = temp.translate(_PUNCT_TABLE)
    temp = _DISALLOWED.sub("", temp)
    temp = temp.replace("\n", ",")
    # Every substitution that actually fires replaces 2+ characters with one,
    # so the string strictly shrinks on each changing pass and this loop
    # always terminates at a fixed point.
    previous = None
    while temp != previous:
        previous = temp
        temp = _REPEATED_MARK.sub(r"\1", temp)
        temp = _COMMAS_AROUND_MARK.sub(r"\1", temp)
    return temp.strip(",")

def classify_text(inp):
    """Score a piece of text against every emotion label.

    The text is cleaned with ``preprocess``, tokenized (truncated to at most
    512 tokens), passed through the BERT classifier with gradient tracking
    disabled, and the resulting logits are converted to probabilities with a
    softmax over the class axis.

    Args:
        inp: Raw text entered in the Gradio textbox.

    Returns:
        Dict mapping each entry of ``labels`` to its probability as a Python
        float, the shape consumed by the ``'label'`` output component.
    """
    cleaned = preprocess(inp)
    encoded = tokenizer(
        cleaned,
        padding=True,
        max_length=512,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        scores = model(**encoded).logits
    # One input -> batch of one: normalise over classes, then drop the batch axis.
    probs = torch.softmax(scores, dim=1).squeeze(0)
    return {name: float(probs[idx].item()) for idx, name in enumerate(labels)}

# Sample lyrics (one lyric line per newline), offered as a clickable example
# in the UI. Named so it no longer shadows the built-in ``input``.
example_text = '明天又是好日子\n千金的光阴不能等\n明天又是好日子\n赶上了盛世咱享太平\n今天是个好日子\n心想的事儿都能成\n明天又是好日子\n千金的光阴不能等\n今天明天都是好日子\n赶上了盛世咱享太平'

# Web UI: a 5-line textbox feeding classify_text, whose label->probability
# dict is rendered by Gradio's 'label' output component. The app is launched
# as soon as this module runs.
gr.Interface(
    classify_text,
    gr.inputs.Textbox(lines=5, default=""),
    outputs='label',
    examples=[[example_text]],
).launch(debug=True)