"""Streamlit app: two-sentence Korean news summarizer built on a fine-tuned KoBART model."""

import torch
import streamlit as st
from transformers import PreTrainedTokenizerFast
from transformers.models.bart import BartForConditionalGeneration


# NOTE: st.cache_resource (not st.cache_data) is the correct cache for ML
# models — cache_data serializes/copies its return value on every hit, which
# is wrong for a torch nn.Module. This replaces the earlier attempts with
# @st.cache / allow_output_mutation / @st.cache_data().
@st.cache_resource
def load_model():
    """Download (once) and return the fine-tuned KoBART summarization model."""
    return BartForConditionalGeneration.from_pretrained('LeeJang/news-summarization-v2')


model = load_model()
tokenizer = PreTrainedTokenizerFast.from_pretrained('gogamza/kobart-base-v1')

st.title("2문장 뉴스 요약기")
text = st.text_area("뉴스 입력:")

st.markdown("## 뉴스 원문")
st.write(text)

if text:
    # Flatten newlines and trim, then cap the input at 501 space-separated
    # words so the encoded sequence stays within the model's input limits.
    text = text.replace('\n', ' ').strip()
    words = text.split(' ')
    if len(words) > 501:
        text = ' '.join(words[:501])

    st.markdown("## 요약 결과")
    with st.spinner('processing..'):
        # Encode to token ids and add a batch dimension: (1, seq_len).
        input_ids = torch.tensor(tokenizer.encode(text)).unsqueeze(0)
        # Inference only — no_grad avoids building the autograd graph.
        with torch.no_grad():
            output = model.generate(input_ids,
                                    eos_token_id=1,
                                    max_length=512,
                                    num_beams=5)
        summary = tokenizer.decode(output[0], skip_special_tokens=True)
    st.write(summary)