haven-jeon
initial commit
bca6d91
raw history blame
No virus
2.86 kB
from transformers import PreTrainedTokenizerFast
from tokenizers import SentencePieceBPETokenizer
from transformers import BartForConditionalGeneration
import streamlit as st
import torch
def tokenizer():
tokenizer = PreTrainedTokenizerFast.from_pretrained('gogamza/kobart-summarization')
return tokenizer
@st.cache(allow_output_mutation=True)
def get_model():
model = BartForConditionalGeneration.from_pretrained('gogamza/kobart-summarization')
model.eval()
return model
default_text = '''์งˆ๋ณ‘๊ด€๋ฆฌ์ฒญ์€ 23์ผ ์ง€๋ฐฉ์ž์น˜๋‹จ์ฒด๊ฐ€ ๋ณด๊ฑด๋‹น๊ตญ๊ณผ ํ˜‘์˜ ์—†์ด ๋‹จ๋…์œผ๋กœ ์ธํ”Œ๋ฃจ์—”์ž(๋…๊ฐ) ๋ฐฑ์‹  ์ ‘์ข… ์ค‘๋‹จ์„ ๊ฒฐ์ •ํ•ด์„œ๋Š” ์•ˆ ๋œ๋‹ค๋Š” ์ž…์žฅ์„ ๋ฐํ˜”๋‹ค.
์งˆ๋ณ‘์ฒญ์€ ์ด๋‚  ์ฐธ๊ณ ์ž๋ฃŒ๋ฅผ ๋ฐฐํฌํ•˜๊ณ  โ€œํ–ฅํ›„ ์ „์ฒด ๊ตญ๊ฐ€ ์˜ˆ๋ฐฉ์ ‘์ข…์‚ฌ์—…์ด ์ฐจ์งˆ ์—†์ด ์ง„ํ–‰๋˜๋„๋ก ์ง€์ž์ฒด๊ฐ€ ์ž์ฒด์ ์œผ๋กœ ์ ‘์ข… ์œ ๋ณด ์—ฌ๋ถ€๋ฅผ ๊ฒฐ์ •ํ•˜์ง€ ์•Š๋„๋ก ์•ˆ๋‚ด๋ฅผ ํ–ˆ๋‹คโ€๊ณ  ์„ค๋ช…ํ–ˆ๋‹ค.
๋…๊ฐ๋ฐฑ์‹ ์„ ์ ‘์ข…ํ•œ ํ›„ ๊ณ ๋ น์ธต์„ ์ค‘์‹ฌ์œผ๋กœ ์ „๊ตญ์—์„œ ์‚ฌ๋ง์ž๊ฐ€ ์ž‡๋”ฐ๋ฅด์ž ์„œ์šธ ์˜๋“ฑํฌ๊ตฌ๋ณด๊ฑด์†Œ๋Š” ์ „๋‚ , ๊ฒฝ๋ถ ํฌํ•ญ์‹œ๋Š” ์ด๋‚  ๊ด€๋‚ด ์˜๋ฃŒ๊ธฐ๊ด€์— ์ ‘์ข…์„ ๋ณด๋ฅ˜ํ•ด๋‹ฌ๋ผ๋Š” ๊ณต๋ฌธ์„ ๋‚ด๋ ค๋ณด๋ƒˆ๋‹ค. ์ด๋Š” ์˜ˆ๋ฐฉ์ ‘์ข…๊ณผ ์‚ฌ๋ง ๊ฐ„ ์ง์ ‘์  ์—ฐ๊ด€์„ฑ์ด ๋‚ฎ์•„ ์ ‘์ข…์„ ์ค‘๋‹จํ•  ์ƒํ™ฉ์€ ์•„๋‹ˆ๋ผ๋Š” ์งˆ๋ณ‘์ฒญ์˜ ํŒ๋‹จ๊ณผ๋Š” ๋‹ค๋ฅธ ๊ฒƒ์ด๋‹ค.
์งˆ๋ณ‘์ฒญ์€ ์ง€๋‚œ 21์ผ ์ „๋ฌธ๊ฐ€ ๋“ฑ์ด ์ฐธ์—ฌํ•œ โ€˜์˜ˆ๋ฐฉ์ ‘์ข… ํ”ผํ•ด์กฐ์‚ฌ๋ฐ˜โ€™์˜ ๋ถ„์„ ๊ฒฐ๊ณผ๋ฅผ ๋ฐ”ํƒ•์œผ๋กœ ๋…๊ฐ ์˜ˆ๋ฐฉ์ ‘์ข… ์‚ฌ์—…์„ ์ผ์ •๋Œ€๋กœ ์ง„ํ–‰ํ•˜๊ธฐ๋กœ ํ–ˆ๋‹ค. ํŠนํžˆ ๊ณ ๋ น ์–ด๋ฅด์‹ ๊ณผ ์–ด๋ฆฐ์ด, ์ž„์‹ ๋ถ€ ๋“ฑ ๋…๊ฐ ๊ณ ์œ„ํ—˜๊ตฐ์€ ๋ฐฑ์‹ ์„ ์ ‘์ข…ํ•˜์ง€ ์•Š์•˜์„ ๋•Œ ํ•ฉ๋ณ‘์ฆ ํ”ผํ•ด๊ฐ€ ํด ์ˆ˜ ์žˆ๋‹ค๋ฉด์„œ ์ ‘์ข…์„ ๋…๋ คํ–ˆ๋‹ค. ํ•˜์ง€๋งŒ ์ ‘์ข…์‚ฌ์—… ์œ ์ง€ ๋ฐœํ‘œ ์ดํ›„์—๋„ ์‚ฌ๋ง ๋ณด๊ณ ๊ฐ€ ์ž‡๋”ฐ๋ฅด์ž ์งˆ๋ณ‘์ฒญ์€ ์ด๋‚  โ€˜์˜ˆ๋ฐฉ์ ‘์ข… ํ”ผํ•ด์กฐ์‚ฌ๋ฐ˜ ํšŒ์˜โ€™์™€ โ€˜์˜ˆ๋ฐฉ์ ‘์ข… ์ „๋ฌธ์œ„์›ํšŒโ€™๋ฅผ ๊ฐœ์ตœํ•ด ๋…๊ฐ๋ฐฑ์‹ ๊ณผ ์‚ฌ๋ง ๊ฐ„ ๊ด€๋ จ์„ฑ, ์ ‘์ข…์‚ฌ์—… ์œ ์ง€ ์—ฌ๋ถ€ ๋“ฑ์— ๋Œ€ํ•ด ๋‹ค์‹œ ๊ฒฐ๋ก  ๋‚ด๋ฆฌ๊ธฐ๋กœ ํ–ˆ๋‹ค. ํšŒ์˜ ๊ฒฐ๊ณผ๋Š” ์ด๋‚  ์˜คํ›„ 7์‹œ ๋„˜์–ด ๋ฐœํ‘œ๋  ์˜ˆ์ •์ด๋‹ค.
'''
model = get_model()
tokenizer = tokenizer()
st.title("Summarization Model Test")
text = st.text_area("Input news :", value=default_text)
st.markdown("## Original News Data")
st.write(text)
if text:
st.markdown("## Predict Summary")
with st.spinner('processing..'):
raw_input_ids = tokenizer.encode(text)
input_ids = [tokenizer.bos_token_id] + \
raw_input_ids + [tokenizer.eos_token_id]
summary_ids = model.generate(torch.tensor([input_ids]),
max_length=256,
early_stopping=True,
repetition_penalty=2.0)
summ = tokenizer.decode(summary_ids.squeeze().tolist(), skip_special_tokens=True)
st.write(summ)