# aha-curse-class / app.py
# Author: djsull
# Commit: 870d1f0 ("Create new file")
import gradio as gr
import re
import os
from transformers import pipeline
# Hugging Face access token. It is read eagerly at import time, so a missing
# AUTH_TOKEN environment variable fails immediately with KeyError.
AUTH_TOKEN = os.environ["AUTH_TOKEN"]

# Multi-label text classifier, built once when the module is loaded.
#   function_to_apply='sigmoid' -> each label's logit is squashed
#                                  independently (multi-label, not softmax)
#   return_all_scores=True      -> a score is reported for every label
# NOTE(review): newer transformers releases deprecate `use_auth_token`
# (replaced by `token`) and `return_all_scores` (replaced by `top_k=None`).
# They are kept here because the pinned transformers version is not visible.
classifier = pipeline(
    task='text-classification',
    model="djsull/kobigbird-hate-multi-label_short",
    function_to_apply='sigmoid',
    return_all_scores=True,
    use_auth_token=AUTH_TOKEN,
)
# Compiled once at import time rather than on every request.
# Non-greedy, so each HTML-like tag is matched separately.
_TAG_RE = re.compile(r'<.*?>')
# Matches everything except precomposed Hangul syllables (가-힣), ASCII
# letters, digits and the space character. Hangul compatibility jamo such as
# 'ㅋ' or 'ㅅ' fall outside 가-힣 and are therefore removed as well.
# NOTE(review): this presumably mirrors the preprocessing the model was
# trained with. Confirm before widening the character class.
_DISALLOWED_RE = re.compile(r'[^가-힣a-zA-Z0-9 ]')


def predict(text, *, threshold=0.1, model=None):
    """Return the labels whose score exceeds *threshold*, comma-separated.

    The input is cleaned before classification: HTML-like tags are removed,
    every character outside ``[가-힣a-zA-Z0-9 ]`` is replaced by a space, and
    runs of whitespace are collapsed to single spaces.

    Args:
        text: Raw user input. ``None`` is treated as an empty string.
        threshold: A label is reported only if its score is strictly greater
            than this value. The default of 0.1 is the previously hard-coded
            cutoff.
        model: Callable used for classification. It defaults to the
            module-level ``classifier`` pipeline. It is called with the cleaned
            string and must return a list whose first element is a list of
            ``{'label': ..., 'score': ...}`` dicts.

    Returns:
        The selected labels joined with ``', '``, in the order the model
        returned them. Returns ``''`` when no label exceeds the threshold or
        when nothing is left after cleaning.
    """
    if model is None:
        model = classifier
    query = _TAG_RE.sub('', text or '')
    query = ' '.join(_DISALLOWED_RE.sub(' ', query).split())
    if not query:
        # Nothing meaningful to classify; don't ask the model to label "".
        return ''
    # One input string in, so element [0] holds that string's per-label scores.
    scores = model(query)[0]
    return ', '.join(item['label'] for item in scores if item['score'] > threshold)
# Web UI: a single free-text box goes in, and the comma-separated labels
# produced by `predict` come out.
# NOTE(review): `gr.inputs` / `gr.outputs` are Gradio's legacy component
# namespaces. Migrating to `gr.Textbox` depends on the pinned Gradio version,
# which is not visible here.
demo = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Textbox(label="Type anything"),
    outputs=gr.outputs.Textbox(label="labels"),
    title="curse classification",
)

# Launched at module top level, so the server starts whenever this file runs.
demo.launch()