File size: 1,097 Bytes
50e4884
 
 
6227ecf
50e4884
 
7ffcb26
50e4884
6227ecf
 
 
 
 
 
50e4884
 
 
 
 
 
6227ecf
 
72d67dc
 
 
6227ecf
72d67dc
 
 
ad29555
 
72d67dc
 
50e4884
 
 
 
 
7a78cc3
50e4884
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import gradio as gr
import re
import torch
from transformers import pipeline
import os


# Access token for the model repository. Indexing os.environ (rather than
# using .get) means a missing variable raises KeyError at import time.
AUTH_TOKEN = os.environ["AUTH_TOKEN"]

# Category text-classification pipeline. return_all_scores=True asks for a
# softmax score per label; predict() indexes the result with [0] and picks
# the highest-scoring entry.
# NOTE(review): newer transformers releases deprecate `use_auth_token`
# (-> `token`) and `return_all_scores=True` (-> `top_k=None`). `top_k=None`
# appears to return a flat list for a single string, which would break the
# `[0]` indexing in predict(). Confirm before migrating.
cate_classifier = pipeline(
    "text-classification",
    model="djsull/kobigbird-cate-class-finder",
    use_auth_token=AUTH_TOKEN,
    return_all_scores=True,
    function_to_apply="softmax",
)

# Precompiled once at import instead of on every request.
_TAG_RE = re.compile(r"<.*?>")
# Anything that is not a Hangul syllable, ASCII letter, digit or space.
_DISALLOWED_CHARS_RE = re.compile(r"[^가-힣a-zA-Z0-9 ]")


def predict(text, *, classifier=None):
    """Return the single highest-scoring category for *text*.

    The input is normalised before classification: tag-like ``<...>`` spans
    are removed, every character other than Hangul syllables, ASCII letters,
    digits and spaces is replaced by a space, and whitespace runs are
    collapsed to single spaces.

    Args:
        text: Raw user input.
        classifier: Optional callable used instead of the module-level
            ``cate_classifier``. It is called with the cleaned string and
            must return a sequence whose first element is a list of dicts
            with ``"label"`` and ``"score"`` keys.

    Returns:
        ``{label: percent}`` for the best-scoring label. ``percent`` is the
        score times 100, truncated (not rounded) to two decimal places.
        Returns an empty dict if the classifier produced no scores.
    """
    if classifier is None:
        classifier = cate_classifier

    query = _TAG_RE.sub("", text)
    query = " ".join(_DISALLOWED_CHARS_RE.sub(" ", query).split())

    # Classify the cleaned query, not the raw input, so the normalisation
    # above actually takes effect.
    scores = classifier(query)[0]
    if not scores:
        return {}

    # On an exact tie, max() keeps the first label encountered.
    best = max(scores, key=lambda entry: entry["score"])
    # int(x * 10000) / 100 truncates the percentage to two decimals.
    return {best["label"]: int(best["score"] * 10000) / 100}

# Web UI: a single free-text input; the predicted {label: percent} dict is
# rendered in a text output box.
# NOTE(review): gr.inputs / gr.outputs are the legacy component namespaces
# (removed in later Gradio releases). Confirm the pinned Gradio version
# before upgrading.
demo = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Textbox(label="Type anything"),
    outputs=gr.outputs.Textbox(label="labels"),
    title="Single-label Category classification",
)
# Launched unconditionally at import time (there is no __main__ guard).
demo.launch()