Spaces:
Runtime error
Runtime error
from transformers import PreTrainedTokenizerFast | |
from tokenizers import SentencePieceBPETokenizer | |
from transformers import BartForConditionalGeneration | |
import streamlit as st | |
import torch | |
def tokenizer(): | |
tokenizer = PreTrainedTokenizerFast.from_pretrained('gogamza/kobart-summarization') | |
return tokenizer | |
def get_model(): | |
model = BartForConditionalGeneration.from_pretrained('gogamza/kobart-summarization') | |
model.eval() | |
return model | |
default_text = '''์ง๋ณ๊ด๋ฆฌ์ฒญ์ 23์ผ ์ง๋ฐฉ์์น๋จ์ฒด๊ฐ ๋ณด๊ฑด๋น๊ตญ๊ณผ ํ์ ์์ด ๋จ๋ ์ผ๋ก ์ธํ๋ฃจ์์(๋ ๊ฐ) ๋ฐฑ์ ์ ์ข ์ค๋จ์ ๊ฒฐ์ ํด์๋ ์ ๋๋ค๋ ์ ์ฅ์ ๋ฐํ๋ค. | |
์ง๋ณ์ฒญ์ ์ด๋ ์ฐธ๊ณ ์๋ฃ๋ฅผ ๋ฐฐํฌํ๊ณ โํฅํ ์ ์ฒด ๊ตญ๊ฐ ์๋ฐฉ์ ์ข ์ฌ์ ์ด ์ฐจ์ง ์์ด ์งํ๋๋๋ก ์ง์์ฒด๊ฐ ์์ฒด์ ์ผ๋ก ์ ์ข ์ ๋ณด ์ฌ๋ถ๋ฅผ ๊ฒฐ์ ํ์ง ์๋๋ก ์๋ด๋ฅผ ํ๋คโ๊ณ ์ค๋ช ํ๋ค. | |
๋ ๊ฐ๋ฐฑ์ ์ ์ ์ข ํ ํ ๊ณ ๋ น์ธต์ ์ค์ฌ์ผ๋ก ์ ๊ตญ์์ ์ฌ๋ง์๊ฐ ์๋ฐ๋ฅด์ ์์ธ ์๋ฑํฌ๊ตฌ๋ณด๊ฑด์๋ ์ ๋ , ๊ฒฝ๋ถ ํฌํญ์๋ ์ด๋ ๊ด๋ด ์๋ฃ๊ธฐ๊ด์ ์ ์ข ์ ๋ณด๋ฅํด๋ฌ๋ผ๋ ๊ณต๋ฌธ์ ๋ด๋ ค๋ณด๋๋ค. ์ด๋ ์๋ฐฉ์ ์ข ๊ณผ ์ฌ๋ง ๊ฐ ์ง์ ์ ์ฐ๊ด์ฑ์ด ๋ฎ์ ์ ์ข ์ ์ค๋จํ ์ํฉ์ ์๋๋ผ๋ ์ง๋ณ์ฒญ์ ํ๋จ๊ณผ๋ ๋ค๋ฅธ ๊ฒ์ด๋ค. | |
์ง๋ณ์ฒญ์ ์ง๋ 21์ผ ์ ๋ฌธ๊ฐ ๋ฑ์ด ์ฐธ์ฌํ โ์๋ฐฉ์ ์ข ํผํด์กฐ์ฌ๋ฐโ์ ๋ถ์ ๊ฒฐ๊ณผ๋ฅผ ๋ฐํ์ผ๋ก ๋ ๊ฐ ์๋ฐฉ์ ์ข ์ฌ์ ์ ์ผ์ ๋๋ก ์งํํ๊ธฐ๋ก ํ๋ค. ํนํ ๊ณ ๋ น ์ด๋ฅด์ ๊ณผ ์ด๋ฆฐ์ด, ์์ ๋ถ ๋ฑ ๋ ๊ฐ ๊ณ ์ํ๊ตฐ์ ๋ฐฑ์ ์ ์ ์ข ํ์ง ์์์ ๋ ํฉ๋ณ์ฆ ํผํด๊ฐ ํด ์ ์๋ค๋ฉด์ ์ ์ข ์ ๋ ๋ คํ๋ค. ํ์ง๋ง ์ ์ข ์ฌ์ ์ ์ง ๋ฐํ ์ดํ์๋ ์ฌ๋ง ๋ณด๊ณ ๊ฐ ์๋ฐ๋ฅด์ ์ง๋ณ์ฒญ์ ์ด๋ โ์๋ฐฉ์ ์ข ํผํด์กฐ์ฌ๋ฐ ํ์โ์ โ์๋ฐฉ์ ์ข ์ ๋ฌธ์์ํโ๋ฅผ ๊ฐ์ตํด ๋ ๊ฐ๋ฐฑ์ ๊ณผ ์ฌ๋ง ๊ฐ ๊ด๋ จ์ฑ, ์ ์ข ์ฌ์ ์ ์ง ์ฌ๋ถ ๋ฑ์ ๋ํด ๋ค์ ๊ฒฐ๋ก ๋ด๋ฆฌ๊ธฐ๋ก ํ๋ค. ํ์ ๊ฒฐ๊ณผ๋ ์ด๋ ์คํ 7์ ๋์ด ๋ฐํ๋ ์์ ์ด๋ค. | |
''' | |
model = get_model() | |
tokenizer = tokenizer() | |
st.title("Summarization Model Test") | |
text = st.text_area("Input news :", value=default_text) | |
st.markdown("## Original News Data") | |
st.write(text) | |
if text: | |
st.markdown("## Predict Summary") | |
with st.spinner('processing..'): | |
raw_input_ids = tokenizer.encode(text) | |
input_ids = [tokenizer.bos_token_id] + \ | |
raw_input_ids + [tokenizer.eos_token_id] | |
summary_ids = model.generate(torch.tensor([input_ids]), | |
max_length=256, | |
early_stopping=True, | |
repetition_penalty=2.0) | |
summ = tokenizer.decode(summary_ids.squeeze().tolist(), skip_special_tokens=True) | |
st.write(summ) | |