LeeJang committed on
Commit
4ca7e91
1 Parent(s): 2f7dce2

Upload app.py

Files changed (1)
  1. app.py +43 -0
app.py ADDED
@@ -0,0 +1,43 @@
+ # Streamlit app: two-sentence Korean news summarizer built on KoBART.
+ import torch
+ import streamlit as st
+ from transformers import BartForConditionalGeneration
+ from transformers import PreTrainedTokenizerFast
+
+
+ # st.cache_resource keeps the loaded model in memory across reruns;
+ # st.cache_data would try to serialize the model object instead.
+ @st.cache_resource
+ def load_model():
+     model = BartForConditionalGeneration.from_pretrained('LeeJang/news-summarization-v2')
+     return model
+
+
+ model = load_model()
+ tokenizer = PreTrainedTokenizerFast.from_pretrained('gogamza/kobart-base-v1')
+
+ st.title("2문장 뉴스 요약기")  # "Two-sentence news summarizer"
+ text = st.text_area("뉴스 입력:")  # "Enter news text:"
+
+ st.markdown("## 뉴스 원문")  # "Original article"
+ st.write(text)
+
+ if text:
+     # Flatten newlines and trim surrounding whitespace.
+     text = text.replace('\n', ' ').strip()
+
+     # Rough word-count cap; token-level truncation below enforces the hard limit.
+     words = text.split(' ')
+     if len(words) > 501:
+         words = words[:501]
+         text = ' '.join(words)
+
+     st.markdown("## 요약 결과")  # "Summary"
+     with st.spinner('Processing...'):
+         # Encode, truncate to the model's 512-token maximum, and add a batch dimension.
+         input_ids = tokenizer.encode(text, max_length=512, truncation=True)
+         input_ids = torch.tensor(input_ids).unsqueeze(0)
+         # KoBART uses token id 1 as the end-of-sequence token.
+         output = model.generate(input_ids, eos_token_id=1, max_length=512, num_beams=5)
+         summary = tokenizer.decode(output[0], skip_special_tokens=True)
+         st.write(summary)