""" python interactive.py """ import torch from transformers import AutoTokenizer, BertForSequenceClassification, AutoModelForSequenceClassification, AutoConfig from transformers import TextClassificationPipeline import gradio as gr # global var MODEL_NAME = 'momo/KcBERT-base_Hate_speech_Privacy_Detection' tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) model = AutoModelForSequenceClassification.from_pretrained( MODEL_NAME, num_labels= 15, problem_type="multi_label_classification" ) MODEL_BUF = { "name": MODEL_NAME, "tokenizer": tokenizer, "model": model, } def change_model_name(name): MODEL_BUF["name"] = name MODEL_BUF["tokenizer"] = AutoTokenizer.from_pretrained(name) MODEL_BUF["model"] = AutoModelForSequenceClassification.from_pretrained(name) def predict(model_name, text): if model_name != MODEL_BUF["name"]: change_model_name(model_name) tokenizer = MODEL_BUF["tokenizer"] model = MODEL_BUF["model"] unsmile_labels = ["여성/가족","남성","성소수자","인종/국적","연령","지역","종교","기타 혐오","악플/욕설", "clean", '이름', '전화번호', '주소', '계좌번호', '주민번호'] num_labels = len(unsmile_labels) model.config.id2label = {i: label for i, label in zip(range(num_labels), unsmile_labels)} model.config.label2id = {label: i for i, label in zip(range(num_labels), unsmile_labels)} pipe = TextClassificationPipeline( model = model, tokenizer = tokenizer, return_all_scores=True, function_to_apply='sigmoid' ) return pipe(text)[0] if __name__ == '__main__': exam1 = '경기도 성남시 수정구 태평3동은 우리 동네야!' exam2 = '내 핸드폰 번호는 010-3930-8237 이야!' exam3 = '아 젠장 너무 짜증난다' model_name_list = [ 'momo/KcELECTRA-base_Hate_speech_Privacy_Detection', "momo/KcBERT-base_Hate_speech_Privacy_Detection", ] #Create a gradio app with a button that calls predict() app = gr.Interface( fn=predict, inputs=[gr.inputs.Dropdown(model_name_list, label="Model Name"), 'text'], outputs='text', examples = [ [MODEL_BUF["name"], exam1], [MODEL_BUF["name"], exam2], [MODEL_BUF["name"], exam3] ], title="한국어 혐오표현, 개인정보 판별기 (Korean Hate Speech and Privacy Detection)", description="Korean Hate Speech and Privacy Detection. \t 15개 label Detection: 여성/가족, 남성, 성소수자, 인종/국적, 연령, 지역, 종교, 기타 혐오, 악플/욕설, clean, name, number, address, bank, person" ) app.launch()