# Space: aha-multi-label / app.py
# Author: djsull
# Last commit: 7ffcb26 ("Update app.py")
import gradio as gr
import re
import torch
from transformers import pipeline
import os
# Access token for the (presumably private) model repository. It is read
# from the environment so it never lives in source control; if the variable
# is not set, this lookup raises KeyError at import time.
AUTH_TOKEN = os.environ["AUTH_TOKEN"]

# Hugging Face Hub id of the category-classification model.
MODEL_ID = "djsull/kobigbird-cate-class-finder"

# Text-classification pipeline, built once at import time and reused for
# every request. `return_all_scores=True` asks for a score for every label
# (not just the top one). `function_to_apply='softmax'` turns the raw
# logits into softmax probabilities.
cate_classifier = pipeline(
    task='text-classification',
    model=MODEL_ID,
    use_auth_token=AUTH_TOKEN,
    return_all_scores=True,
    function_to_apply='softmax',
)
# Compiled once at import time rather than on every request.
# Matches HTML-like tags such as "<b>" or "</div>" (non-greedy).
_HTML_TAG_RE = re.compile(r'<.*?>')
# Matches any character that is not a Hangul syllable, ASCII letter,
# digit or space.
_DISALLOWED_CHARS_RE = re.compile(r'[^가-힣a-zA-Z0-9 ]')


def predict(text, *, classifier=None):
    """Classify *text* and return its single most likely category.

    The input is normalised before classification:
    - HTML-like tags (``<...>``) are removed.
    - Every character other than Hangul syllables, ASCII letters, digits
      and spaces is replaced by a space.
    - Runs of whitespace are collapsed to a single space.

    Args:
        text: Raw user input string.
        classifier: Optional callable with the same call/return shape as the
            module-level ``cate_classifier`` pipeline. Defaults to that
            pipeline. Lets callers (and tests) plug in a different model.

    Returns:
        A one-entry dict ``{label: score_percent}``. The percentage is the
        top score times 100, truncated (not rounded) to two decimal places.
        If two labels tie, the later one wins. Returns ``{}`` if the
        classifier produced no scores.
    """
    if classifier is None:
        classifier = cate_classifier

    query = _HTML_TAG_RE.sub('', text)
    query = ' '.join(_DISALLOWED_CHARS_RE.sub(' ', query).split())

    # Classify the cleaned query, not the raw text, so the normalisation
    # above actually takes effect.
    output = classifier(query)

    # With `return_all_scores=True` a single string yields
    # [[{label, score}, ...]]. When the pipeline is configured with
    # `top_k=None` instead, the inner list is returned directly.
    # Accept both shapes.
    scores = output[0] if output and isinstance(output[0], list) else output
    if not scores:
        return {}

    # `max` returns the first maximum it encounters. Scanning in reverse
    # keeps the "last highest score wins" tie-break.
    best = max(reversed(scores), key=lambda item: item['score'])
    return {best['label']: int(best['score'] * 10000) / 100}
# Build the UI as a module-level `demo` object, so the module can be
# imported (for example by gradio's reload mode, which looks for a variable
# named `demo`, or by tests) without immediately starting a web server.
# NOTE(review): `gr.inputs` / `gr.outputs` are deprecated in gradio 3.x and
# removed in 4.x; migrate to `gr.Textbox` once the pinned gradio version
# is confirmed.
demo = gr.Interface(
    predict,
    inputs=gr.inputs.Textbox(label="Type anything"),
    # `predict` returns a one-entry dict; the Textbox shows its str() form.
    outputs=gr.outputs.Textbox(label="labels"),
    title="Single-label Category classification",
)

if __name__ == "__main__":
    demo.launch()