momo commited on
Commit
98186a7
โ€ข
1 Parent(s): 91863a0
Files changed (1) hide show
  1. app.py +52 -28
app.py CHANGED
@@ -2,52 +2,76 @@
2
  python interactive.py
3
  """
4
  import torch
5
- from transformers import AutoTokenizer, BertForSequenceClassification
6
  from transformers import TextClassificationPipeline
7
  import gradio as gr
8
 
9
- model_name = 'momo/KcBERT-base_Hate_speech_Privacy_Detection'
10
-
11
- model_name_list = [
12
- 'momo/KcELECTRA-base_Hate_speech_Privacy_Detection',
13
- "momo/KcBERT-base_Hate_speech_Privacy_Detection",
14
- ]
15
-
16
- model = BertForSequenceClassification.from_pretrained(
17
- model_name,
18
  num_labels= 15,
19
  problem_type="multi_label_classification"
20
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 
 
 
 
 
23
 
24
- unsmile_labels = ["์—ฌ์„ฑ/๊ฐ€์กฑ","๋‚จ์„ฑ","์„ฑ์†Œ์ˆ˜์ž","์ธ์ข…/๊ตญ์ ","์—ฐ๋ น","์ง€์—ญ","์ข…๊ต","๊ธฐํƒ€ ํ˜์˜ค","์•…ํ”Œ/์š•์„ค","clean", 'name', 'number', 'address', 'bank', 'person']
25
- num_labels = len(unsmile_labels)
26
 
27
- model.config.id2label = {i: label for i, label in zip(range(num_labels), unsmile_labels)}
28
- model.config.label2id = {label: i for i, label in zip(range(num_labels), unsmile_labels)}
29
 
30
- pipe = TextClassificationPipeline(
31
  model = model,
32
  tokenizer = tokenizer,
33
  return_all_scores=True,
34
  function_to_apply='sigmoid'
35
  )
36
-
37
- def dectection(input):
38
- for result in pipe(input)[0]:
39
- return result
40
 
41
- #Create a gradio app with a button that calls predict()
42
- app = gr.Interface(
43
- fn=dectection,
44
- inputs=[gr.inputs.Dropdown(model_name_list, label="Model Name"), 'text'], outputs=['label'],
45
- title="ํ•œ๊ตญ์–ด ํ˜์˜คํ‘œํ˜„, ๊ฐœ์ธ์ •๋ณด ํŒ๋ณ„๊ธฐ (Korean Hate Speech and Privacy Detection)",
46
- description="Korean Hate Speech and Privacy Detection."
47
- )
48
- app.launch(share=True)
49
 
 
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
 
53
 
 
2
  python interactive.py
3
  """
4
  import torch
5
+ from transformers import AutoTokenizer, BertForSequenceClassification, AutoModelForSequenceClassification, AutoConfig
6
  from transformers import TextClassificationPipeline
7
  import gradio as gr
8
 
9
+ # global var
10
+ MODEL_NAME = 'momo/KcBERT-base_Hate_speech_Privacy_Detection'
11
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
12
+ model = AutoModelForSequenceClassification.from_pretrained(
13
+ MODEL_NAME,
 
 
 
 
14
  num_labels= 15,
15
  problem_type="multi_label_classification"
16
  )
17
+ config = AutoConfig.from_pretrained(MODEL_NAME)
18
+
19
+ MODEL_BUF = {
20
+ "name": MODEL_NAME,
21
+ "tokenizer": tokenizer,
22
+ "model": model,
23
+ "config": config
24
+ }
25
+
26
+ def change_model_name(name):
27
+ MODEL_BUF["name"] = name
28
+ MODEL_BUF["tokenizer"] = AutoTokenizer.from_pretrained(name)
29
+ MODEL_BUF["model"] = AutoModelForSequenceClassification.from_pretrained(name)
30
+ MODEL_BUF["config"] = AutoConfig.from_pretrained(name)
31
+
32
 
33
+ def predict(model_name, text):
34
+ if model_name != MODEL_BUF["name"]:
35
+ change_model_name(model_name)
36
+
37
+ tokenizer = MODEL_BUF["tokenizer"]
38
+ model = MODEL_BUF["model"]
39
+ config = MODEL_BUF["config"]
40
 
41
+ unsmile_labels = ["์—ฌ์„ฑ/๊ฐ€์กฑ","๋‚จ์„ฑ","์„ฑ์†Œ์ˆ˜์ž","์ธ์ข…/๊ตญ์ ","์—ฐ๋ น","์ง€์—ญ","์ข…๊ต","๊ธฐํƒ€ ํ˜์˜ค","์•…ํ”Œ/์š•์„ค","clean", 'name', 'number', 'address', 'bank', 'person']
42
+ num_labels = len(unsmile_labels)
43
 
44
+ model.config.id2label = {i: label for i, label in zip(range(num_labels), unsmile_labels)}
45
+ model.config.label2id = {label: i for i, label in zip(range(num_labels), unsmile_labels)}
46
 
47
+ pipe = TextClassificationPipeline(
48
  model = model,
49
  tokenizer = tokenizer,
50
  return_all_scores=True,
51
  function_to_apply='sigmoid'
52
  )
 
 
 
 
53
 
54
+ for result in pipe(text)[0]:
 
 
 
 
 
 
 
55
 
56
+ return result
57
 
58
+ if __name__ == '__main__':
59
+ text = '์ฟ๋”ด๊ฑธ ํ™๋ณฟ๊ธ€ ์ฟ๋ž‰๊ณญ ์Œ‘์ ฉ๋‚„๊ณ  ์•‰์•Ÿ์žˆ๋ƒฉ'
60
+
61
+ model_name_list = [
62
+ 'momo/KcELECTRA-base_Hate_speech_Privacy_Detection',
63
+ "momo/KcBERT-base_Hate_speech_Privacy_Detection",
64
+ ]
65
+
66
+ #Create a gradio app with a button that calls predict()
67
+ app = gr.Interface(
68
+ fn=predict,
69
+ inputs=[gr.inputs.Dropdown(model_name_list, label="Model Name"), 'text'], outputs=['label', 'plot'],
70
+ examples = [[MODEL_BUF["name"], text], [MODEL_BUF["name"], "4=๐Ÿฆ€ 4โ‰ ๐Ÿฆ€"]],
71
+ title="ํ•œ๊ตญ์–ด ํ˜์˜คํ‘œํ˜„, ๊ฐœ์ธ์ •๋ณด ํŒ๋ณ„๊ธฐ (Korean Hate Speech and Privacy Detection)",
72
+ description="Korean Hate Speech and Privacy Detection."
73
+ )
74
+ app.launch(inline=False)
75
 
76
 
77