alex6095's picture
Update app.py
620a618
raw
history blame contribute delete
No virus
4.43 kB
import torch
import torch.nn as nn
import re
import streamlit as st
from transformers import DistilBertModel
from tokenization_kobert import KoBertTokenizer
class SanctiMoly(nn.Module):
    """Holy Moly News BERT: predicts a publication-date class for a news article.

    Wraps a BERT-style encoder and an FC-BN-Tanh classification head that
    maps the first-token ([CLS]) embedding to 120 logits (flat month
    indices; see ``i2ym`` for the index -> (year, month) mapping).

    Args:
        bert_model: any encoder whose ``forward(input_ids, attention_mask)``
            returns an object with a ``last_hidden_state`` of shape
            (batch, seq_len, 768).
        freeze_bert: when True (default), encoder weights are excluded from
            gradient updates; only the head trains.
    """

    def __init__(self, bert_model: nn.Module, freeze_bert: bool = True):
        super().__init__()
        self.encoder = bert_model
        # FC-BN-Tanh head: 768 -> 1024 -> 768 -> 120 logits.
        self.linear = nn.Sequential(
            nn.Linear(768, 1024),
            nn.BatchNorm1d(1024),
            nn.Tanh(),
            nn.Dropout(),
            nn.Linear(1024, 768),
            nn.BatchNorm1d(768),
            nn.Tanh(),
            nn.Dropout(),
            nn.Linear(768, 120),
        )
        # One pass replaces the original duplicated if/else loops and the
        # non-idiomatic `== True` comparison.
        for param in self.encoder.parameters():
            param.requires_grad = not freeze_bert

    def forward(self, input_ids: torch.Tensor, input_length: torch.Tensor) -> torch.Tensor:
        """Return (batch, 120) classification logits.

        Args:
            input_ids: (batch, seq_len) right-padded token ids.
            input_length: (batch,) true sequence lengths; positions at or
                beyond each length are masked out of attention.
        """
        # Boolean attention mask: True for real tokens, False for padding.
        attn_mask = torch.arange(input_ids.size(1))
        attn_mask = attn_mask[None, :] < input_length[:, None]
        enc_o = self.encoder(input_ids, attn_mask)
        # Classify from the final hidden state of the first ([CLS]) token.
        return self.linear(enc_o.last_hidden_state[:, 0, :])
@st.cache(allow_output_mutation=True)
def get_model():
    """Load the tokenizer and the fine-tuned date-prediction model on CPU.

    Cached by Streamlit so the heavyweight downloads and checkpoint load
    happen only once per session.

    Returns:
        (model, tokenizer) tuple, with the model in eval mode.
    """
    encoder = DistilBertModel.from_pretrained('alex6095/SanctiMolyOH_Cpu')
    tokenizer = KoBertTokenizer.from_pretrained('monologg/distilkobert')
    classifier = SanctiMoly(encoder, freeze_bert=False)
    checkpoint = torch.load("./model.pt", map_location=torch.device('cpu'))
    classifier.load_state_dict(checkpoint['model_state_dict'])
    classifier.eval()
    return classifier, tokenizer
# Load the (cached) model and tokenizer once at import; Streamlit reruns reuse them.
model, tokenizer = get_model()
class RegexSubstitution(object):
    """Callable transform that replaces every match of *regex* with *sub*.

    Works on a single string or a list of strings. Each string is
    substituted twice — presumably so matches exposed by the first pass
    (e.g. parentheses that become adjacent after inner ones are removed)
    are also handled; original double-application behavior is kept as-is.
    """

    def __init__(self, regex, sub=''):
        # Accept either a pre-compiled pattern or a raw pattern string.
        self.regex = regex if isinstance(regex, re.Pattern) else re.compile(regex)
        self.sub = sub

    def _strip(self, text):
        # Two passes, matching the original double substitution.
        once = self.regex.sub(self.sub, text)
        return self.regex.sub(self.sub, once)

    def __call__(self, target):
        if isinstance(target, list):
            return [self._strip(item) for item in target]
        return self._strip(target)
def i2ym(fl):
    """Map a flat month index (0 == January 2009) to ('year', 'month') strings."""
    year_offset, month_index = divmod(fl, 12)
    return str(year_offset + 2009), str(month_index + 1)
default_text = '''ํ—Œ๋ฒ•์žฌํŒ์†Œ๊ฐ€ ๋ฐ•๊ทผํ˜œ ๋Œ€ํ†ต๋ น์˜ ํŒŒ๋ฉด์„ ๋งŒ์žฅ์ผ์น˜๋กœ ๊ฒฐ์ •ํ–ˆ๋‹ค. ํ˜„์ง ๋Œ€ํ†ต๋ น ํƒ„ํ•ต์ด ์ธ์šฉ๋œ ๊ฒƒ์€ ํ—Œ์ • ์‚ฌ์ƒ ์ตœ์ดˆ๋‹ค. ๋ฐ• ์ „ ๋Œ€ํ†ต๋ น์— ๋Œ€ํ•œ ํŒŒ๋ฉด์ด ๊ฒฐ์ •๋˜๋ฉด์„œ ํ—Œ๋ฒ•๊ณผ ๊ณต์ง์„ ๊ฑฐ๋ฒ•์— ๋”ฐ๋ผ ์•ž์œผ๋กœ 60์ผ ์ด๋‚ด์— ์ฐจ๊ธฐ ๋Œ€ํ†ต๋ น ์„ ๊ฑฐ๊ฐ€ ์น˜๋Ÿฌ์ง„๋‹ค.
์ด์ •๋ฏธ ํ—Œ์žฌ์†Œ์žฅ ๊ถŒํ•œ๋Œ€ํ–‰(์žฌํŒ๊ด€)์€ 10์ผ ์˜ค์ „ 11์‹œ 23๋ถ„ ์„œ์šธ ์ข…๋กœ๊ตฌ ํ—Œ๋ฒ•์žฌํŒ์†Œ ๋Œ€์‹ฌํŒ์ •์—์„œ โ€œํ”ผ์ฒญ๊ตฌ์ธ ๋Œ€ํ†ต๋ น ๋ฐ•๊ทผํ˜œ๋ฅผ ํŒŒ๋ฉดํ•œ๋‹คโ€๊ณ  ์ฃผ๋ฌธ์„ ์„ ๊ณ ํ–ˆ๋‹ค. ๊ทธ ์ˆœ๊ฐ„ ๋Œ€์‹ฌํŒ์ • ๊ณณ๊ณณ์—์„œ ๋ฌด๊ฒ๊ณ  ๋‚˜์งํ•œ ํƒ„์„ฑ์ด ํ„ฐ์ ธ ๋‚˜์™”๋‹ค. ์ด๋‚  ๋Œ€์‹ฌํŒ์ •์—์„  ๋ฐ•๊ทผํ˜œ ์ „ ๋Œ€ํ†ต๋ น ์ธก๊ณผ ๊ตญํšŒ์†Œ์ถ”์œ„์› ์ธก ๊ด€๊ณ„์ž๋“ค๊ณผ ์ทจ์žฌ์ง„ 80๋ช…, ์˜จ๋ผ์ธ ์ ‘์ˆ˜๋ฅผ ํ†ตํ•ด 795๋Œ€ 1์˜ ๊ฒฝ์Ÿ๋ฅ ์„ ๋šซ๊ณ  ์„ ์ •๋œ ์ผ๋ฐ˜๋ฐฉ์ฒญ๊ฐ 24๋ช…์ด ์ˆจ์„ ์ฃฝ์ด๊ณ  ์žˆ์—ˆ๋‹ค.
'''
# ---- Streamlit UI: show the article, then predict its top-3 publication months.
st.title("Date prediction")
text = st.text_area("Input news :", value=default_text)
st.markdown("## Original News Data")
st.write(text)
st.markdown("## Predict Top 3 Date")
if text:
    with st.spinner('processing..'):
        # Remove parenthesized asides and special punctuation before tokenizing.
        text = RegexSubstitution(r'\([^()]+\)|[<>\'"โ–ณโ–ฒโ–กโ– ]')(text)
        encoded_dict = tokenizer(
            text=[text],
            add_special_tokens=True,
            max_length=512,
            truncation=True,
            return_tensors='pt',
            return_length=True
        )
        input_ids = encoded_dict['input_ids']
        input_ids_len = encoded_dict['length']
        # Forward pass over the single-article batch; yields (1, 120) logits.
        pred = model(input_ids, input_ids_len)
        # Indices of the 3 highest-scoring flat month classes.
        _, indices = torch.topk(pred, 3)
        pred_print = []
        for i in indices.squeeze(0):
            # Convert flat class index to ('year', 'month') strings.
            year, month = i2ym(i.item())
            pred_print.append(year+"-"+month)
        st.write(", ".join(pred_print))