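# Scraped file-view header (author: momo; commit ec82774 "add app"; raw / history / blame; 2.7 kB).
"""Gradio demo for Korean hate-speech and personal-information (privacy) detection.

Usage:
    python interactive.py
"""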
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import TextClassificationPipeline
import gradio as gr
# ---------------------------------------------------------------------------
# Module-level state
# ---------------------------------------------------------------------------
# Checkpoint loaded when the module is imported.
MODEL_NAME = 'momo/KcELECTRA-base_Hate_speech_Privacy_Detection'

# Load tokenizer and classifier once at import time. The classification head
# is configured for 15 outputs treated as a multi-label problem.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    problem_type="multi_label_classification",
    num_labels=15,
)

# Mutable holder for the currently active checkpoint. The UI code swaps the
# entries when the user picks a different model name.
MODEL_BUF = dict(name=MODEL_NAME, tokenizer=tokenizer, model=model)
def change_model_name(name):
    """Load checkpoint *name* and make it the active model in ``MODEL_BUF``.

    The tokenizer and model are loaded into locals first, and ``MODEL_BUF`` is
    only updated once both loads succeed. If a download fails, the previously
    active checkpoint therefore stays intact. Otherwise the buffer would record
    the new name next to the old model, and ``predict`` (which compares names to
    decide whether to reload) would never retry the load.

    Args:
        name: Hugging Face Hub id or local path of the checkpoint to load.
    """
    new_tokenizer = AutoTokenizer.from_pretrained(name)
    new_model = AutoModelForSequenceClassification.from_pretrained(
        name,
        # Same head configuration as the checkpoint loaded at start-up.
        num_labels=15,
        problem_type="multi_label_classification",
    )
    MODEL_BUF["tokenizer"] = new_tokenizer
    MODEL_BUF["model"] = new_model
    # Record the name last: it is the "already loaded" marker used by predict().
    MODEL_BUF["name"] = name
# Output labels of the classifier, in index order (index i -> UNSMILE_LABELS[i]):
# nine hate-speech categories plus "clean", followed by five personal-information
# types (name, phone number, address, bank-account number, resident-registration
# number).
UNSMILE_LABELS = [
    "여성/가족", "남성", "성소수자", "인종/국적", "연령", "지역", "종교",
    "기타 혐오", "악플/욕설", "clean",
    "이름", "전화번호", "주소", "계좌번호", "주민번호",
]


def predict(model_name, text):
    """Score *text* against every label using the selected checkpoint.

    Args:
        model_name: Hub id of the checkpoint to use. If it differs from the
            currently loaded one, that checkpoint is loaded first. An empty
            value (e.g. nothing selected in the dropdown) keeps the current
            model instead of trying to load ``None``.
        text: Input sentence to classify.

    Returns:
        The first entry of the pipeline output, or ``None`` if the pipeline
        returned nothing. With ``return_all_scores=True`` that entry is
        presumably a list of ``{"label", "score"}`` dicts, one per label; each
        score is an independent sigmoid probability.
    """
    if model_name and model_name != MODEL_BUF["name"]:
        change_model_name(model_name)
    tokenizer = MODEL_BUF["tokenizer"]
    model = MODEL_BUF["model"]

    # Replace the checkpoint's label names so the output is human-readable.
    model.config.id2label = dict(enumerate(UNSMILE_LABELS))
    model.config.label2id = {label: i for i, label in enumerate(UNSMILE_LABELS)}

    pipe = TextClassificationPipeline(
        model=model,
        tokenizer=tokenizer,
        return_all_scores=True,       # report every label, not only the top one
        function_to_apply='sigmoid',  # labels are scored independently
    )
    outputs = pipe(text)
    if not outputs:
        return None
    # A single string input yields a one-element batch; unwrap it.
    result = outputs[0]
    print(result)
    return result
if __name__ == '__main__':
    # Example inputs shown under the form: an address (exam1), a phone
    # number (exam2), and a mild curse (exam3).
    exam1 = '경기도 성남시 수정구 태평3동은 우리 동네야!'
    exam2 = '내 핸드폰 번호는 010-3930-8237 이야!'
    exam3 = '아 젠장 너무 짜증난다'
    # Checkpoints selectable in the UI; the first matches the start-up model.
    model_name_list = [
        'momo/KcELECTRA-base_Hate_speech_Privacy_Detection',
        "momo/KcBERT-base_Hate_speech_Privacy_Detection",
    ]
    # Create a Gradio app whose submit button calls predict(model_name, text).
    # gr.Dropdown replaces gr.inputs.Dropdown (deprecated in Gradio 3, removed
    # in Gradio 4). A preselected value means a model name is always sent.
    app = gr.Interface(
        fn=predict,
        inputs=[
            gr.Dropdown(model_name_list, value=model_name_list[0], label="Model Name"),
            'text',
        ],
        outputs='text',
        examples=[[MODEL_BUF["name"], exam] for exam in (exam1, exam2, exam3)],
        title="한국어 혐오표현, 개인정보 판별기 (Korean Hate Speech and Privacy Detection)",
        description="Korean Hate Speech and Privacy Detection. \t 15개 label Detection: 여성/가족, 남성, 성소수자, 인종/국적, 연령, 지역, 종교, 기타 혐오, 악플/욕설, clean, name, number, address, bank, person",
    )
    app.launch()