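# Hugging Face Space file-page residue (kept as comments so the file parses):
# author: snoop2head (avatar: "snoop2head's picture")
# commit: 849819f ("upload post processing")
# page actions: raw | history | blame | contribute | delete
# scan result: No virus
# file size: 4.89 kB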
# -*- coding: utf-8 -*-
import json
import pandas as pd
import numpy as np
import torch
import streamlit as st
from transformers import AutoTokenizer, AutoModelForTokenClassification
# Page-level settings. Streamlit expects set_page_config to run before any
# other st.* command, which is why it sits directly under the imports.
# NOTE(review): the Korean literals in this copy look mis-decoded (UTF-8 bytes
# rendered as other code-page characters); restore them from the original
# UTF-8 source rather than editing by hand.
PAGE_CONFIG = dict(
    page_title="NER ๊ธฐ๋ฐ˜ ๋ฏผ๊ฐ์ •๋ณด ์‹๋ณ„",
    layout="wide",
    initial_sidebar_state="expanded",
)
st.set_page_config(**PAGE_CONFIG)
def _resource_cache():
    """Return a Streamlit decorator suited to caching a large, shared model.

    Newer Streamlit releases provide ``st.cache_resource``, which stores the
    returned object without hashing it. Older releases only have the
    deprecated ``st.cache``. There, ``allow_output_mutation=True`` stops
    Streamlit from re-hashing the whole returned model on every rerun to
    detect mutations.
    """
    if hasattr(st, "cache_resource"):
        return st.cache_resource
    return st.cache(allow_output_mutation=True)


@_resource_cache()
def load_model(model_name):
    """Load a token-classification model, cached across script reruns.

    Args:
        model_name: Hub id or local path passed to
            ``AutoModelForTokenClassification.from_pretrained``.

    Returns:
        The loaded model, switched to evaluation mode (dropout disabled).
    """
    model = AutoModelForTokenClassification.from_pretrained(model_name)
    # Inference-only app; set eval mode once, inside the cached loader.
    model.eval()
    return model
# --- Page header -------------------------------------------------------------
# NOTE(review): the Korean literals below look mis-decoded in this copy;
# restore them from the original UTF-8 source rather than editing by hand.
st.title("๐Ÿ”’ NER ๊ธฐ๋ฐ˜ ๋ฏผ๊ฐ์ •๋ณด ์‹๋ณ„๊ธฐ")
st.write("๋ฌธ์žฅ์„ ์ž…๋ ฅํ•˜์‹œ๊ณ , CTRL+Enter(CMD+Enter)๋ฅผ ๋ˆ„๋ฅด์„ธ์š” ๐Ÿค—")

# --- Model artefacts ---------------------------------------------------------
# The tokenizer is loaded from the "klue/roberta-base" checkpoint, while the
# classifier weights come from "QuoQA-NLP/konec-privacy". Presumably the
# classifier was fine-tuned on that tokenizer's vocabulary -- TODO confirm.
tokenizer = AutoTokenizer.from_pretrained("klue/roberta-base")
model = load_model("QuoQA-NLP/konec-privacy")
model.eval()

# --- User input --------------------------------------------------------------
# Example sentence pre-filled in the text box.
default_value = "์˜์ง„๋‹˜, ๋‹น๋‡จ ๊ฒ€์‚ฌํ•œ ๊ฑฐ ๊ฒฐ๊ณผ ๋‚˜์˜ค์…จ์–ด์š”."
src_text = st.text_area(
    label="๊ฒ€์‚ฌํ•˜๊ณ  ์‹ถ์€ ๋ฌธ์žฅ์„ ์ž…๋ ฅํ•˜์„ธ์š”.",
    value=default_value,
    height=300,
    max_chars=150,
)
# BIO tag -> class index of the token-classification head (17 classes).
CLASS_MAP = {
    "B-ADD": 0,
    "I-ADD": 1,
    "B-DN": 2,
    "I-DN": 3,
    "B-DT": 4,
    "I-DT": 5,
    "B-LC": 6,
    "I-LC": 7,
    "B-OG": 8,
    "I-OG": 9,
    "B-PS": 10,
    "I-PS": 11,
    "B-QT": 12,
    "I-QT": 13,
    "B-RL": 14,
    "I-RL": 15,
    "O": 16,
}
# Class index -> BIO tag, used to decode the argmax predictions.
CLASS_MAP_INVERTED = {index: tag for tag, index in CLASS_MAP.items()}
# Entity type (BIO prefix stripped) -> human-readable label shown in the UI.
# NOTE(review): these Korean values look mis-decoded in this copy, but other
# parts of the app compare against them byte-for-byte -- keep them verbatim.
LABEL_MAP = {
    "ADD": "์ฃผ์†Œ ์ •๋ณด",
    "DN": "์งˆํ™˜ ์ •๋ณด",
    "DT": "๋‚ ์งœ ์ •๋ณด",
    "LC": "์žฅ์†Œ ์ •๋ณด",
    "OG": "๊ธฐ๊ด€ ์ •๋ณด",
    "PS": "์ธ๋ช…/๋ณ„๋ช… ์ •๋ณด",
    "QT": "์ˆ˜๋Ÿ‰ ์ •๋ณด",
    "RL": "๊ด€๊ณ„ ์ •๋ณด",
    "O": "๋น„๋ฏผ๊ฐ ์ •๋ณด",
}
# Column names of the DataFrame returned by yield_df (token, predicted label).
TOKEN_COLUMN = "ํ˜•ํƒœ์†Œ"
LABEL_COLUMN = "์˜ˆ์ƒ ๋ผ๋ฒจ"


def yield_df(default_value, ner_tokenizer=None, ner_model=None):
    """Tag every sub-word token of a sentence with a privacy label.

    Args:
        default_value: The sentence to analyse (name kept for compatibility).
        ner_tokenizer: Tokenizer exposing ``encode`` and
            ``convert_ids_to_tokens``; defaults to the module-level
            ``tokenizer``.
        ner_model: Token-classification model returning an object with
            ``.logits``; defaults to the module-level ``model``.

    Returns:
        A DataFrame with one row per token and two columns (token text,
        human-readable label). The first and last encoded tokens (the special
        tokens ``encode`` adds) are dropped, so the index starts at 1.
    """
    tok = tokenizer if ner_tokenizer is None else ner_tokenizer
    mdl = model if ner_model is None else ner_model

    # truncation=True keeps over-long inputs within the tokenizer's maximum
    # model length instead of overflowing the model's position limit.
    token_ids = tok.encode(default_value, truncation=True)

    # Pure inference: no_grad avoids building an autograd graph on every rerun.
    # (No debug printing here: the input is, by design, potentially sensitive.)
    with torch.no_grad():
        logits = mdl(input_ids=torch.tensor([token_ids])).logits
    # Batch of one: take row 0, then the best-scoring class for every token.
    # Indexing (rather than squeeze()) stays 1-D even for a single token.
    predictions = logits.argmax(-1)[0].tolist()

    tokens = tok.convert_ids_to_tokens(token_ids)
    rows = []
    for token, class_index in zip(tokens, predictions):
        # "B-PS"/"I-PS" -> "PS"; "O" has no prefix and stays "O".
        entity = CLASS_MAP_INVERTED[class_index].split("-")[-1]
        rows.append({TOKEN_COLUMN: token, LABEL_COLUMN: LABEL_MAP[entity]})

    df = pd.DataFrame(rows, columns=[TOKEN_COLUMN, LABEL_COLUMN])
    # Remove the first and last rows: the special tokens around the sentence.
    return df.iloc[1:-1]
def convert_df(df: pd.DataFrame):
    """Serialise *df* to CSV without its index and return the UTF-8 bytes."""
    csv_text = df.to_csv(index=False)
    return bytes(csv_text, "utf-8")
def convert_json(df: pd.DataFrame):
    """Serialise *df* as a JSON object keyed by row index.

    Returns:
        A str of the form ``{"<index>": {"<column>": value, ...}, ...}``.
        Non-ASCII text (the Korean tokens and labels) is written as-is rather
        than as unicode escape sequences, so the downloaded file stays
        human-readable.
    """
    # pandas escapes non-ASCII by default; re-serialise with the json module
    # so the characters are emitted literally.
    parsed = json.loads(df.to_json(orient="index"))
    return json.dumps(parsed, ensure_ascii=False, indent=2)
# Human-readable label -> placeholder tag substituted into the filtered
# sentence. Keys must match the label strings produced for each token.
# NOTE(review): the Korean text looks mis-decoded in this copy; keep keys
# byte-identical to the labels they are matched against.
filtering_map = dict(
    [
        ("์ฃผ์†Œ ์ •๋ณด", "[์ฃผ์†Œ]"),
        ("์งˆํ™˜ ์ •๋ณด", "[์งˆํ™˜]"),
        ("๋‚ ์งœ ์ •๋ณด", "[๋‚ ์งœ]"),
        ("์žฅ์†Œ ์ •๋ณด", "[์žฅ์†Œ]"),
        ("๊ธฐ๊ด€ ์ •๋ณด", "[๊ธฐ๊ด€]"),
        ("์ธ๋ช…/๋ณ„๋ช… ์ •๋ณด", "[์ด๋ฆ„]"),
        ("์ˆ˜๋Ÿ‰ ์ •๋ณด", "[์ˆ˜๋Ÿ‰]"),
        ("๊ด€๊ณ„ ์ •๋ณด", "[๊ด€๊ณ„]"),
        ("๋น„๋ฏผ๊ฐ ์ •๋ณด", "[๋น„๋ฏผ๊ฐ]"),
    ]
)
# Label that yield_df assigns to tokens carrying no sensitive information.
NON_SENSITIVE_LABEL = "๋น„๋ฏผ๊ฐ ์ •๋ณด"


def mask_sentence(df_result):
    """Rebuild the analysed sentence with sensitive tokens replaced by tags.

    Args:
        df_result: DataFrame from ``yield_df`` (token column and
            predicted-label column).

    Returns:
        The sentence with each sensitive piece replaced by its
        ``filtering_map`` tag. Continuation pieces (prefixed with "##") are
        glued to the preceding piece, and consecutive continuation pieces
        sharing one sensitive label collapse into a single tag.
    """
    pieces = []
    previous_label = None
    for token, label in zip(df_result["ํ˜•ํƒœ์†Œ"], df_result["์˜ˆ์ƒ ๋ผ๋ฒจ"]):
        is_continuation = token.startswith("##")
        if label == NON_SENSITIVE_LABEL:
            text = token[2:] if is_continuation else token
        elif is_continuation and label == previous_label:
            # Same entity continues inside this word; its tag is already emitted.
            continue
        else:
            text = filtering_map[label]
        # Word-initial pieces start a new space-separated word.
        pieces.append(text if is_continuation else " " + text)
        previous_label = label
    return "".join(pieces).strip()


if not src_text.strip():
    # Empty or whitespace-only input: nothing to analyse.
    st.warning("Please **enter text** to analyze")
else:
    df_result = yield_df(src_text)

    st.markdown("### ํ•„ํ„ฐ๋ง ๋œ ๋ฌธ์žฅ")
    st.write(mask_sentence(df_result))

    st.markdown("### ๋ถ„๋ฅ˜๋œ ๋‹จ์–ด๋“ค")
    st.header("")  # vertical spacer
    # The outer columns only pad the three centred download buttons.
    _, c1, c2, c3, _ = st.columns([0.75, 1.5, 1.5, 1.5, 0.75])
    st.table(df_result)
    with c1:
        st.download_button(
            label="๐Ÿ“ฅ csv๋กœ ๋‹ค์šด๋กœ๋“œ",
            data=convert_df(df_result),
            file_name="results.csv",
            mime="text/csv",
            key="csv",
        )
    with c2:
        # Same CSV payload, offered as a plain-text file.
        st.download_button(
            label="๐Ÿ“ฅ txt๋กœ ๋‹ค์šด๋กœ๋“œ",
            data=convert_df(df_result),
            file_name="results.txt",
            mime="text/plain",
            key="text",
        )
    with c3:
        st.download_button(
            label="๐Ÿ“ฅ json์œผ๋กœ ๋‹ค์šด๋กœ๋“œ",
            data=convert_json(df_result),
            file_name="results.json",
            mime="application/json",
            key="json",
        )
# Funding acknowledgement (Korean text, project no. A1504-22-1005), shown in
# an expander that starts open.
# NOTE(review): the Korean text looks mis-decoded in this copy; restore it
# from the original UTF-8 source rather than editing by hand.
ACKNOWLEDGEMENT_TEXT = """
ํ•ด๋‹น ๋ฐ๋ชจ๋Š” 2022๋…„๋„ ๊ณผํ•™๊ธฐ์ˆ ์ •๋ณดํ†ต์‹ ๋ถ€์˜ ์žฌ์›์œผ๋กœ ์ •๋ณดํ†ต์‹ ์‚ฐ์—…์ง„ํฅ์›์˜ ์ง€์›์„ ๋ฐ›์•„ ์ˆ˜ํ–‰๋œ ์—ฐ๊ตฌ์ž„
(๊ณผ์ œ๋ฒˆํ˜ธ: A1504-22-1005)
"""
acknowledgement_box = st.expander("(์ฃผ) ์ฟผ์นด์—์ด์•„์ด ๋ฐ๋ชจ ์‚ฌ์‚ฌ ๊ด€๋ จ", expanded=True)
acknowledgement_box.write(ACKNOWLEDGEMENT_TEXT)