SanctiMolyDemo1 / app.py
alex6095's picture
Update app.py
5ef0c66
import torch
import re
import streamlit as st
import pandas as pd
from transformers import PreTrainedTokenizerFast, DistilBertForSequenceClassification, BartForConditionalGeneration
from tokenization_kobert import KoBertTokenizer
from tokenizers import SentencePieceBPETokenizer
@st.cache(allow_output_mutation=True)
def get_topic():
model = DistilBertForSequenceClassification.from_pretrained(
'alex6095/SanctiMolyTopic', problem_type="multi_label_classification", num_labels=9)
model.eval()
tokenizer = KoBertTokenizer.from_pretrained('monologg/distilkobert')
return model, tokenizer
@st.cache(allow_output_mutation=True)
def get_date():
model = BartForConditionalGeneration.from_pretrained('alex6095/SanctiMoly-Bart')
model.eval()
tokenizer = PreTrainedTokenizerFast.from_pretrained('gogamza/kobart-summarization')
return model, tokenizer
class RegexSubstitution(object):
"""Regex substitution class for transform"""
def __init__(self, regex, sub=''):
if isinstance(regex, re.Pattern):
self.regex = regex
else:
self.regex = re.compile(regex)
self.sub = sub
def __call__(self, target):
if isinstance(target, list):
return [self.regex.sub(self.sub, self.regex.sub(self.sub, string)) for string in target]
else:
return self.regex.sub(self.sub, self.regex.sub(self.sub, target))
default_text = '''์งˆ๋ณ‘๊ด€๋ฆฌ์ฒญ์€ 23์ผ ์ง€๋ฐฉ์ž์น˜๋‹จ์ฒด๊ฐ€ ๋ณด๊ฑด๋‹น๊ตญ๊ณผ ํ˜‘์˜ ์—†์ด ๋‹จ๋…์œผ๋กœ ์ธํ”Œ๋ฃจ์—”์ž(๋…๊ฐ) ๋ฐฑ์‹  ์ ‘์ข… ์ค‘๋‹จ์„ ๊ฒฐ์ •ํ•ด์„œ๋Š” ์•ˆ ๋œ๋‹ค๋Š” ์ž…์žฅ์„ ๋ฐํ˜”๋‹ค.
์งˆ๋ณ‘์ฒญ์€ ์ด๋‚  ์ฐธ๊ณ ์ž๋ฃŒ๋ฅผ ๋ฐฐํฌํ•˜๊ณ  โ€œํ–ฅํ›„ ์ „์ฒด ๊ตญ๊ฐ€ ์˜ˆ๋ฐฉ์ ‘์ข…์‚ฌ์—…์ด ์ฐจ์งˆ ์—†์ด ์ง„ํ–‰๋˜๋„๋ก ์ง€์ž์ฒด๊ฐ€ ์ž์ฒด์ ์œผ๋กœ ์ ‘์ข… ์œ ๋ณด ์—ฌ๋ถ€๋ฅผ ๊ฒฐ์ •ํ•˜์ง€ ์•Š๋„๋ก ์•ˆ๋‚ด๋ฅผ ํ–ˆ๋‹คโ€๊ณ  ์„ค๋ช…ํ–ˆ๋‹ค.
๋…๊ฐ๋ฐฑ์‹ ์„ ์ ‘์ข…ํ•œ ํ›„ ๊ณ ๋ น์ธต์„ ์ค‘์‹ฌ์œผ๋กœ ์ „๊ตญ์—์„œ ์‚ฌ๋ง์ž๊ฐ€ ์ž‡๋”ฐ๋ฅด์ž ์„œ์šธ ์˜๋“ฑํฌ๊ตฌ๋ณด๊ฑด์†Œ๋Š” ์ „๋‚ , ๊ฒฝ๋ถ ํฌํ•ญ์‹œ๋Š” ์ด๋‚  ๊ด€๋‚ด ์˜๋ฃŒ๊ธฐ๊ด€์— ์ ‘์ข…์„ ๋ณด๋ฅ˜ํ•ด๋‹ฌ๋ผ๋Š” ๊ณต๋ฌธ์„ ๋‚ด๋ ค๋ณด๋ƒˆ๋‹ค. ์ด๋Š” ์˜ˆ๋ฐฉ์ ‘์ข…๊ณผ ์‚ฌ๋ง ๊ฐ„ ์ง์ ‘์  ์—ฐ๊ด€์„ฑ์ด ๋‚ฎ์•„ ์ ‘์ข…์„ ์ค‘๋‹จํ•  ์ƒํ™ฉ์€ ์•„๋‹ˆ๋ผ๋Š” ์งˆ๋ณ‘์ฒญ์˜ ํŒ๋‹จ๊ณผ๋Š” ๋‹ค๋ฅธ ๊ฒƒ์ด๋‹ค.
์งˆ๋ณ‘์ฒญ์€ ์ง€๋‚œ 21์ผ ์ „๋ฌธ๊ฐ€ ๋“ฑ์ด ์ฐธ์—ฌํ•œ โ€˜์˜ˆ๋ฐฉ์ ‘์ข… ํ”ผํ•ด์กฐ์‚ฌ๋ฐ˜โ€™์˜ ๋ถ„์„ ๊ฒฐ๊ณผ๋ฅผ ๋ฐ”ํƒ•์œผ๋กœ ๋…๊ฐ ์˜ˆ๋ฐฉ์ ‘์ข… ์‚ฌ์—…์„ ์ผ์ •๋Œ€๋กœ ์ง„ํ–‰ํ•˜๊ธฐ๋กœ ํ–ˆ๋‹ค. ํŠนํžˆ ๊ณ ๋ น ์–ด๋ฅด์‹ ๊ณผ ์–ด๋ฆฐ์ด, ์ž„์‹ ๋ถ€ ๋“ฑ ๋…๊ฐ ๊ณ ์œ„ํ—˜๊ตฐ์€ ๋ฐฑ์‹ ์„ ์ ‘์ข…ํ•˜์ง€ ์•Š์•˜์„ ๋•Œ ํ•ฉ๋ณ‘์ฆ ํ”ผํ•ด๊ฐ€ ํด ์ˆ˜ ์žˆ๋‹ค๋ฉด์„œ ์ ‘์ข…์„ ๋…๋ คํ–ˆ๋‹ค. ํ•˜์ง€๋งŒ ์ ‘์ข…์‚ฌ์—… ์œ ์ง€ ๋ฐœํ‘œ ์ดํ›„์—๋„ ์‚ฌ๋ง ๋ณด๊ณ ๊ฐ€ ์ž‡๋”ฐ๋ฅด์ž ์งˆ๋ณ‘์ฒญ์€ ์ด๋‚  โ€˜์˜ˆ๋ฐฉ์ ‘์ข… ํ”ผํ•ด์กฐ์‚ฌ๋ฐ˜ ํšŒ์˜โ€™์™€ โ€˜์˜ˆ๋ฐฉ์ ‘์ข… ์ „๋ฌธ์œ„์›ํšŒโ€™๋ฅผ ๊ฐœ์ตœํ•ด ๋…๊ฐ๋ฐฑ์‹ ๊ณผ ์‚ฌ๋ง ๊ฐ„ ๊ด€๋ จ์„ฑ, ์ ‘์ข…์‚ฌ์—… ์œ ์ง€ ์—ฌ๋ถ€ ๋“ฑ์— ๋Œ€ํ•ด ๋‹ค์‹œ ๊ฒฐ๋ก  ๋‚ด๋ฆฌ๊ธฐ๋กœ ํ–ˆ๋‹ค. ํšŒ์˜ ๊ฒฐ๊ณผ๋Š” ์ด๋‚  ์˜คํ›„ 7์‹œ ๋„˜์–ด ๋ฐœํ‘œ๋  ์˜ˆ์ •์ด๋‹ค.
'''
topics_raw = ['IT/๊ณผํ•™', '๊ฒฝ์ œ', '๋ฌธํ™”', '๋ฏธ์šฉ/๊ฑด๊ฐ•', '์‚ฌํšŒ', '์ƒํ™œ', '์Šคํฌ์ธ ', '์—ฐ์˜ˆ', '์ •์น˜']
topic_model, topic_tokenizer = get_topic()
date_model, date_tokenizer = get_date()
st.sidebar.header('Menu')
name = st.sidebar.selectbox('Model', ['Topic Classification', 'Date Prediction'])
if name == 'Topic Classification':
title = 'News Topic Classification'
model, tokenizer = topic_model, topic_tokenizer
elif name == 'Date Prediction':
title = 'News Date prediction'
model, tokenizer = date_model, date_tokenizer
st.title(title)
text = st.text_area("Input news :", value=default_text)
st.markdown("## Original News Data")
st.write(text)
if name == 'Topic Classification':
st.markdown("## Predict Topic")
col1, col2 = st.columns(2)
if text:
with st.spinner('processing..'):
text = RegexSubstitution(r'\([^()]+\)|[<>\'"โ–ณโ–ฒโ–กโ– ]')(text)
encoded_dict = tokenizer(
text=text,
add_special_tokens=True,
max_length=512,
truncation=True,
return_tensors='pt',
return_length=True
)
input_ids = encoded_dict['input_ids']
input_ids_len = encoded_dict['length'].unsqueeze(0)
attn_mask = torch.arange(input_ids.size(1))
attn_mask = attn_mask[None, :] < input_ids_len[:, None]
outputs = model(input_ids=input_ids, attention_mask=attn_mask)
_, preds = torch.max(outputs.logits, 1)
col1.write(topics_raw[preds.squeeze(0)])
softmax = torch.nn.Softmax(dim=1)
prob = softmax(outputs.logits).squeeze(0).detach()
chart_data = pd.DataFrame({
'Topic': topics_raw,
'Probability': prob
})
chart_data = chart_data.set_index('Topic')
col2.bar_chart(chart_data)
elif name == 'Date Prediction':
st.markdown("## Predict 3 possible Date")
if text:
with st.spinner('processing..'):
text = RegexSubstitution(r'\([^()]+\)|[<>\'"โ–ณโ–ฒโ–กโ– ]')(text)
raw_input_ids = tokenizer.encode(text)
input_ids = [tokenizer.bos_token_id] + \
raw_input_ids + [tokenizer.eos_token_id]
outputs = model.generate(torch.tensor([input_ids]),
early_stopping=True,
do_sample=True, #์ƒ˜ํ”Œ๋ง ์ „๋žต ์‚ฌ์šฉ
max_length=50, # ์ตœ๋Œ€ ๋””์ฝ”๋”ฉ ๊ธธ์ด๋Š” 50
top_k=50, # ํ™•๋ฅ  ์ˆœ์œ„๊ฐ€ 50์œ„ ๋ฐ–์ธ ํ† ํฐ์€ ์ƒ˜ํ”Œ๋ง์—์„œ ์ œ์™ธ
top_p=0.95, # ๋ˆ„์  ํ™•๋ฅ ์ด 95%์ธ ํ›„๋ณด์ง‘ํ•ฉ์—์„œ๋งŒ ์ƒ์„ฑ
num_return_sequences=3 #3๊ฐœ์˜ ๊ฒฐ๊ณผ๋ฅผ ๋””์ฝ”๋”ฉํ•ด๋‚ธ๋‹ค
)
pred_print = []
for output in outputs:
pred_print.append(tokenizer.decode(output.squeeze().tolist(), skip_special_tokens=True, clean_up_tokenization_spaces=True))
st.write(", ".join(pred_print))