haven-jeon commited on
Commit
96484a5
β€’
1 Parent(s): 23b2f5d

fix tokenizer

Browse files
Files changed (1) hide show
  1. app.py +10 -17
app.py CHANGED
@@ -1,10 +1,7 @@
1
  import torch
2
  import string
3
  import streamlit as st
4
- from transformers import GPT2LMHeadModel
5
- from tokenizers import Tokenizer
6
-
7
-
8
 
9
 
10
  @st.cache
@@ -13,7 +10,13 @@ def get_model():
13
  model.eval()
14
  return model
15
 
16
- tokenizer = Tokenizer.from_file('skt/kogpt2-base-v2')
 
 
 
 
 
 
17
 
18
  default_text = "ν˜„λŒ€μΈλ“€μ€ μ™œ 항상 λΆˆμ•ˆν•΄ ν• κΉŒ?"
19
 
@@ -39,26 +42,16 @@ st.markdown("""
39
 
40
  text = st.text_area("Input Text:", value=default_text)
41
  st.write(text)
42
- st.markdown("""
43
- > *ν˜„μž¬ 2core μΈμŠ€ν„΄μŠ€μ—μ„œ 예츑이 μ§„ν–‰λ˜μ–΄ λ‹€μ†Œ 느릴 수 있음*
44
- """)
45
  punct = ('!', '?', '.')
46
 
47
  if text:
48
  st.markdown("## Predict")
49
  with st.spinner('processing..'):
50
  print(f'input > {text}')
51
- input_ids = tokenizer.encode(text).ids
52
  gen_ids = model.generate(torch.tensor([input_ids]),
53
  max_length=128,
54
- repetition_penalty=2.0,
55
- # num_beams=2,
56
- # length_penalty=1.0,
57
- use_cache=True,
58
- pad_token_id=tokenizer.token_to_id('<pad>'),
59
- eos_token_id=tokenizer.token_to_id('</s>'),
60
- bos_token_id=tokenizer.token_to_id('</s>'),
61
- bad_words_ids=[[tokenizer.token_to_id('<unk>')] ])
62
  generated = tokenizer.decode(gen_ids[0,:].tolist()).strip()
63
  if generated != '' and generated[-1] not in punct:
64
  for i in reversed(range(len(generated))):
1
  import torch
2
  import string
3
  import streamlit as st
4
+ from transformers import GPT2LMHeadModel, PreTrainedTokenizerFast
 
 
 
5
 
6
 
7
  @st.cache
10
  model.eval()
11
  return model
12
 
13
+ tokenizer = PreTrainedTokenizerFast.from_pretrained("skt/kogpt2-base-v2",
14
+ bos_token='</s>',
15
+ eos_token='</s>',
16
+ unk_token='<unk>',
17
+ pad_token='<pad>',
18
+ mask_token='<mask>')
19
+
20
 
21
  default_text = "ν˜„λŒ€μΈλ“€μ€ μ™œ 항상 λΆˆμ•ˆν•΄ ν• κΉŒ?"
22
 
42
 
43
  text = st.text_area("Input Text:", value=default_text)
44
  st.write(text)
 
 
 
45
  punct = ('!', '?', '.')
46
 
47
  if text:
48
  st.markdown("## Predict")
49
  with st.spinner('processing..'):
50
  print(f'input > {text}')
51
+ input_ids = tokenizer(text)['input_ids']
52
  gen_ids = model.generate(torch.tensor([input_ids]),
53
  max_length=128,
54
+ repetition_penalty=2.0)
 
 
 
 
 
 
 
55
  generated = tokenizer.decode(gen_ids[0,:].tolist()).strip()
56
  if generated != '' and generated[-1] not in punct:
57
  for i in reversed(range(len(generated))):