soyoung97 committed on
Commit
22a4fe2
•
1 Parent(s): 81a0216

Initial commit

Files changed (2)
  1. app.py +44 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,44 @@
+ from transformers import PreTrainedTokenizerFast
+ from tokenizers import SentencePieceBPETokenizer  # imported but unused in this app
+ from transformers import BartForConditionalGeneration
+ import streamlit as st
+ import torch
+
+
+
+ def tokenizer():
+     tokenizer = PreTrainedTokenizerFast.from_pretrained('Soyoung97/gec_kr')
+     return tokenizer
+
+
+ @st.cache(allow_output_mutation=True)  # cache the loaded model across Streamlit reruns
+ def get_model():
+     model = BartForConditionalGeneration.from_pretrained('Soyoung97/gec_kr')
+     model.eval()
+     return model
+
+
+ default_text = '나는 오늘 지베 가써요'  # intentionally misspelled Korean: "I went home today"
+
+ model = get_model()
+ tokenizer = tokenizer()
+ st.title("GEC_KR Model Test")
+ text = st.text_area("Input corrupted sentence:", value=default_text)
+
+ st.markdown("## Original sentence:")
+ st.write(text)
+
+ if text:
+     st.markdown("## Corrected output")
+     with st.spinner('processing...'):
+         raw_input_ids = tokenizer.encode(text)
+         input_ids = [tokenizer.bos_token_id] + \
+             raw_input_ids + [tokenizer.eos_token_id]  # add explicit BOS/EOS ids
+         corrected_ids = model.generate(torch.tensor([input_ids]),
+                                        max_length=256,
+                                        eos_token_id=1,
+                                        num_beams=4,
+                                        early_stopping=True,
+                                        repetition_penalty=2.0)
+         summ = tokenizer.decode(corrected_ids.squeeze().tolist(), skip_special_tokens=True)
+         st.write(summ)
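For reference, the correction pipeline above can be exercised outside Streamlit. A minimal sketch, assuming the Soyoung97/gec_kr checkpoint on the Hugging Face Hub serves both the tokenizer and the BART weights exactly as app.py loads them, and that the garbled default sentence decodes to the intentionally misspelled '나는 오늘 지베 가써요':

from transformers import PreTrainedTokenizerFast, BartForConditionalGeneration
import torch

tokenizer = PreTrainedTokenizerFast.from_pretrained('Soyoung97/gec_kr')
model = BartForConditionalGeneration.from_pretrained('Soyoung97/gec_kr')
model.eval()  # inference only

text = '나는 오늘 지베 가써요'  # the app's intentionally misspelled default
input_ids = [tokenizer.bos_token_id] + tokenizer.encode(text) + [tokenizer.eos_token_id]
with torch.no_grad():  # no gradients needed for generation
    output_ids = model.generate(torch.tensor([input_ids]), max_length=256,
                                eos_token_id=1, num_beams=4,
                                early_stopping=True, repetition_penalty=2.0)
print(tokenizer.decode(output_ids.squeeze().tolist(), skip_special_tokens=True))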
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ torch
+ transformers
+ streamlit
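With these three dependencies installed (pip install -r requirements.txt), the Space should also run locally via the standard Streamlit entry point, streamlit run app.py.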