# Streamlit demo: Korean product-review summarization with a fine-tuned KoBART model.
# NOTE(review): the original file began with Hugging Face Spaces page residue
# ("Spaces:" / "Runtime error" / "Runtime error") left over from a web extraction;
# preserved here as a comment so the file is valid Python.
import streamlit as st
import torch

from tokenizers import SentencePieceBPETokenizer
from transformers import BartForConditionalGeneration
from transformers import PreTrainedTokenizerFast
def tokenizer():
    """Load and return the fast tokenizer for the fine-tuned KoBART summarizer.

    Downloads tokenizer files from the Hugging Face Hub on first call
    (network access required).

    NOTE(review): calling this at module level below rebinds the name
    ``tokenizer`` to the returned instance, shadowing this function.
    The name is kept unchanged for interface stability.
    """
    return PreTrainedTokenizerFast.from_pretrained(
        'dnrso/koBART_Sum_Review_finetuning')
def get_model():
    """Load the fine-tuned KoBART summarization model, ready for inference.

    Returns the model with ``eval()`` applied so dropout and other
    training-only layers are disabled for deterministic generation.
    """
    model = BartForConditionalGeneration.from_pretrained(
        'dnrso/koBART_Sum_Review_finetuning')
    model.eval()  # inference mode: disable dropout / batch-norm updates
    return model
default_text = '''게임을 하면서 사용하기 좋아요 음질도 괜찮고 착용감도 좋고 이어컵측면에 불빛도 이뻐요 가성비 정말 좋은 제품입니다
'''

# Load once at startup; both calls hit the Hugging Face Hub.
model = get_model()
tokenizer = tokenizer()

st.title("Review Summarization Test")
text = st.text_area("Input:", value=default_text)

st.markdown("Review Data")
st.write(text)

if text:  # skip generation when the text area is cleared
    st.markdown("## Predict Summary")
    with st.spinner('processing..'):
        # BART expects the sequence wrapped in BOS/EOS special tokens.
        raw_input_ids = tokenizer.encode(text)
        input_ids = [tokenizer.bos_token_id] + \
            raw_input_ids + [tokenizer.eos_token_id]
        # Inference only: no_grad avoids building the autograd graph.
        with torch.no_grad():
            summary_ids = model.generate(
                torch.tensor([input_ids]),
                max_length=256,
                # Fix: early_stopping only applies to beam search; without
                # num_beams > 1 it was a no-op (transformers emits a warning).
                num_beams=4,
                early_stopping=True,
                repetition_penalty=2.0)
        summ = tokenizer.decode(summary_ids.squeeze().tolist(),
                                skip_special_tokens=True)
        st.write(summ)