File size: 1,443 Bytes
c82548e
 
 
 
 
 
8b34a32
c82548e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97ebb40
c82548e
 
 
 
c30ae2b
 
c82548e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import gradio as gr
import transformers
from transformers import pipeline, BertForSequenceClassification, BertTokenizer

# Category name for each numeric class index the classifier can emit.
# Indices 5 and 11 have no entry; labels resolving to them fall back to the
# raw pipeline label in _category_name() instead of raising KeyError.
# NOTE(review): 'argriculture' looks like a misspelling of 'agriculture'. It is
# kept unchanged because it is the key shown to users -- confirm before fixing.
_CLASS_NAMES = {
    0: 'story',
    1: 'culture',
    2: 'entertainment',
    3: 'sports',
    4: 'finance',
    6: 'house',
    7: 'car',
    8: 'edu',
    9: 'tech',
    10: 'military',
    12: 'travel',
    13: 'world',
    14: 'stock',
    15: 'argriculture',
    16: 'game',
}

# Pipeline shared by all calls; stays None until the first call builds it.
_classifier = None


def _get_classifier():
    """Return the shared text-classification pipeline, building it on first use."""
    global _classifier
    if _classifier is None:
        # Loading the tokenizer and model is the expensive part, so do it once
        # rather than on every request.
        tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
        model = BertForSequenceClassification.from_pretrained('./bert-cls')
        _classifier = pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=3)
    return _classifier


def _category_name(raw_label):
    """Map a pipeline label of the form '<prefix>_<int>' to its category name.

    Labels that do not have that form, or whose index is not in
    ``_CLASS_NAMES``, are returned unchanged.
    """
    try:
        return _CLASS_NAMES[int(raw_label.split('_')[1])]
    except (IndexError, ValueError, KeyError):
        return raw_label


def classify(input_text, *, classifier=None):
    """Classify one piece of text and return its top categories with scores.

    Args:
        input_text: The text to classify.
        classifier: Optional callable used in place of the shared pipeline.
            It is called with ``[input_text]`` and must return, for each
            input, a list of ``{'label': ..., 'score': ...}`` dicts. When
            omitted, the pipeline built from ``'bert-base-chinese'`` and
            ``'./bert-cls'`` with ``top_k=3`` is used.

    Returns:
        A dict mapping category name to score, in the order the classifier
        returned the predictions.
    """
    if classifier is None:
        classifier = _get_classifier()
    # One text goes in as a one-element batch, so only result [0] is relevant.
    predictions = classifier([input_text])[0]
    return {_category_name(pred['label']): pred['score'] for pred in predictions}
# Sample inputs handed to the interface through its `examples` argument.
examples = ["习近平驾崩", "吴亦凡出狱"]

# Output component; classify() returns a {label: score} dict for it.
label = gr.Label()

# One text input in, the label component out.
iface = gr.Interface(
    fn=classify,
    inputs="text",
    outputs=label,
    title='chinese news classification',
    examples=examples,
)

# Start the app as soon as this module runs.
iface.launch(inline=False)