snoop2head committed on
Commit
498ff0a
•
1 Parent(s): 4c2695f

update demo (before postprocessing)

Browse files
Files changed (2)
  1. app.py +142 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,142 @@
+ # -*- coding: utf-8 -*-
+ import json
+ import pandas as pd
+ import numpy as np
+ import torch
+ import streamlit as st
+ from transformers import AutoTokenizer, AutoModelForTokenClassification
+
+ st.set_page_config(
+     page_title="NER-based Sensitive Information Detection", layout="wide", initial_sidebar_state="expanded"
+ )
+
+ @st.cache(allow_output_mutation=True)  # allow_output_mutation: skip hashing the returned torch model
+ def load_model(model_name):
+     model = AutoModelForTokenClassification.from_pretrained(model_name)
+     return model
+
+
+ st.title("🔒 NER-based Sensitive Information Detector")
+ st.write("Enter a sentence and press CTRL+Enter (CMD+Enter) 🤗")
+
+ tokenizer = AutoTokenizer.from_pretrained("klue/roberta-base")
+ model = load_model("QuoQA-NLP/konec-privacy")
+
+ model.eval()
+
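+ # NOTE: the tokenizer comes from the klue/roberta-base backbone, which
+ # QuoQA-NLP/konec-privacy was presumably fine-tuned from, so the vocabularies
+ # match; eval() disables dropout for deterministic inference.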
+
+ default_value = "성욱님, 당뇨 검사한 거 결과 나오셨어요."  # sample input: "Seong-uk, the results of your diabetes test are in."
+
+ src_text = st.text_area(
+     "Enter the sentence you want to check.",
+     default_value,
+     height=300,
+     max_chars=150,
+ )
+
+
+ def yield_df(text):
+     tokenized = tokenizer.encode(text)
+     print(tokenized)
+
+     # run inference without building a gradient graph
+     with torch.no_grad():
+         output = model(input_ids=torch.tensor([tokenized]))
+     logits = output.logits
+     print(logits.size())
+
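+     # logits has shape (batch=1, seq_len, num_labels=17): one score per BIO tag
+     # for every subword token; argmax over the last axis picks the best tag.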
+     # get the predicted class index for each token over the 17 classes
+     pred = logits.argmax(-1).squeeze().numpy()
+     print(pred)
+
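+     # The tag set follows the BIO scheme: "B-X" marks the first token of an
+     # entity of type X, "I-X" a continuation token, and "O" a token outside
+     # any entity. Eight entity types plus "O" give the 17 classes below.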
+     class_map = {
+         "B-ADD": 0,
+         "I-ADD": 1,
+         "B-DN": 2,
+         "I-DN": 3,
+         "B-DT": 4,
+         "I-DT": 5,
+         "B-LC": 6,
+         "I-LC": 7,
+         "B-OG": 8,
+         "I-OG": 9,
+         "B-PS": 10,
+         "I-PS": 11,
+         "B-QT": 12,
+         "I-QT": 13,
+         "B-RL": 14,
+         "I-RL": 15,
+         "O": 16,
+     }
+
+     class_map_inverted = {v: k for k, v in class_map.items()}
+
+     # decode each prediction index back into its BIO tag string
+     class_decoded = [class_map_inverted[p] for p in pred]
+     print(class_decoded)
+
+     label_map = {
+         "ADD": "address information",  # assumed label: the original mapped "ADD" to the integer 0 by mistake
+         "DN": "disease information",
+         "DT": "date information",
+         "LC": "address information (region, email address, etc.)",
+         "OG": "organization information",
+         "PS": "person name/nickname information",
+         "QT": "quantity information",
+         "RL": "relationship information",
+         "O": "non-sensitive information",
+     }
+
+     # pair each token with its predicted label
+     tokenized_text = tokenizer.convert_ids_to_tokens(tokenized)
+     list_result = []
+     for token, tag in zip(tokenized_text, class_decoded):
+         pred_class = tag.split("-")[-1]  # strip the B-/I- prefix
+         label = label_map[pred_class]
+         result = {"morpheme": token, "predicted label": label}
+         list_result.append(result)
+
+     df = pd.DataFrame(list_result)
+     # drop the first and last rows ([CLS] and [SEP] special tokens)
+     df = df.iloc[1:-1]
+     st.table(df)
+     return df
+
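+
+ # The commit message notes this demo ships *before* postprocessing. A minimal
+ # sketch of that missing step follows (hypothetical helper, not yet wired into
+ # the UI): contiguous B-/I- tags of one entity type are merged into a single
+ # surface span, rejoining "##" wordpiece continuations without spaces.
+ def merge_bio_spans(tokens, tags):
+     spans, current = [], None
+     for token, tag in zip(tokens, tags):
+         piece = token[2:] if token.startswith("##") else token
+         if tag == "O":
+             current = None  # outside any entity: close the open span
+             continue
+         ent_type = tag.split("-")[-1]
+         if tag.startswith("B-") or current is None or current["type"] != ent_type:
+             current = {"type": ent_type, "text": piece}
+             spans.append(current)
+         elif token.startswith("##"):
+             current["text"] += piece  # wordpiece continuation: no space
+         else:
+             current["text"] += " " + piece
+     return spans
+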
+ def convert_df(df: pd.DataFrame):
+     return df.to_csv(index=False).encode("utf-8")
+
+ def convert_json(df: pd.DataFrame):
+     result = df.to_json(orient="index")
+     parsed = json.loads(result)
+     json_string = json.dumps(parsed)
+     # st.json(json_string, expanded=True)
+     return json_string
+
+
+ if src_text == "":
+     st.warning("Please **enter text** to analyze")
+ else:
+     st.markdown("### Classified words")
+     st.header("")
+     cs, c1, c2, c3, cLast = st.columns([0.75, 1.5, 1.5, 1.5, 0.75])
+
+     df_result = yield_df(src_text)
+
+     with c1:
+         # csvbutton = download_button(results, "results.csv", "📥 Download .csv")
+         csvbutton = st.download_button(label="📥 Download as csv", data=convert_df(df_result), file_name="results.csv", mime="text/csv", key="csv")
+     with c2:
+         # textbutton = download_button(results, "results.txt", "📥 Download .txt")
+         textbutton = st.download_button(label="📥 Download as txt", data=convert_df(df_result), file_name="results.txt", mime="text/plain", key="text")
+     with c3:
+         # jsonbutton = download_button(results, "results.json", "📥 Download .json")
+         jsonbutton = st.download_button(label="📥 Download as json", data=convert_json(df_result), file_name="results.json", mime="application/json", key="json")
+
+ with st.expander("QuoQA AI Inc. demo acknowledgment", expanded=True):
+     st.write(
+         """
+         This demo is the result of research carried out with support from the
+         National IT Industry Promotion Agency (NIPA), funded by the Ministry of
+         Science and ICT in 2022 (project number: A1504-22-1005).
+         """
+     )
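To try the demo locally, install the dependencies listed in requirements.txt below and start the app with: streamlit run app.py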
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ transformers
+ streamlit
+ torch
+ pandas
+ numpy