"""Gradio demo: single-label category classification of free text.

Loads the ``djsull/kobigbird-cate-class-finder`` text-classification pipeline
(authenticated via the ``AUTH_TOKEN`` environment variable) and serves a
text-in / text-out interface that reports the top-scoring label and its score.
"""
import os
import re

import gradio as gr
import torch  # noqa: F401  (kept from the original file)
from transformers import pipeline

# os.environ[...] raises KeyError at import time if the token is not set.
AUTH_TOKEN = os.environ["AUTH_TOKEN"]

# return_all_scores=True asks the pipeline for a score per label (one inner list
# per input), so predict() can pick the best label itself.
# NOTE(review): return_all_scores / use_auth_token are older transformers
# argument names; confirm against the installed version before upgrading.
cate_classifier = pipeline(
    "text-classification",
    model="djsull/kobigbird-cate-class-finder",
    use_auth_token=AUTH_TOKEN,
    return_all_scores=True,
    function_to_apply="softmax",
)

# Compiled once at import rather than on every request.
_TAG_RE = re.compile("<.*?>")
_DISALLOWED_CHARS_RE = re.compile("[^가-힣a-zA-Z0-9 ]")


def _clean_query(text):
    """Remove ``<...>`` tags and keep only Hangul, ASCII letters, digits and single spaces."""
    without_tags = _TAG_RE.sub("", text)
    # Replace every other character with a space, then collapse whitespace runs.
    return " ".join(_DISALLOWED_CHARS_RE.sub(" ", without_tags).split())


def _top_label(scores):
    """Return the entry with the highest ``score``; on ties the last such entry wins."""
    # max() keeps the first maximum it meets, so iterating in reverse makes the
    # last maximum win, matching the original `>=` comparison loop.
    return max(reversed(scores), key=lambda entry: entry["score"])


def predict(text):
    """Classify *text* and return ``{top_label: score_percent}``.

    The input is cleaned with ``_clean_query`` before classification. The score
    is the softmax probability expressed as a percentage and truncated (not
    rounded) to two decimals, e.g. 0.98765 -> 98.76.
    """
    query = _clean_query(text)
    # Fix: the original computed the cleaned query but classified the raw text.
    result = cate_classifier(query)[0]
    best = _top_label(result)
    return {best["label"]: int(best["score"] * 10000) / 100}


# NOTE(review): gr.inputs / gr.outputs are legacy Gradio namespaces; check them
# against the installed Gradio version (newer releases expose gr.Textbox).
gr.Interface(
    predict,
    inputs=gr.inputs.Textbox(label="Type anything"),
    outputs=gr.outputs.Textbox(label="labels"),
    title="Single-label Category classification",
).launch()