sohamsh commited on
Commit
4d04a00
1 Parent(s): 9c03bcf

updated app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -0
app.py CHANGED
@@ -4,6 +4,33 @@ words= st.text_input('Enter some words')
4
  num_words= st.slider('How long should the output be?', 0, 100, 5)
5
  button = st.button('Submit')
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  if button:
 
 
 
 
 
 
 
 
 
8
  st.write('Clicked!')
9
  st.write(words, num_words)
 
4
  num_words= st.slider('How long should the output be?', 0, 100, 5)
5
  button = st.button('Submit')
6
 
7
+ @st.cache # only run the function once
8
+ def download_transformer():
9
+ #for reproducability
10
+ #SEED = 12
11
+
12
+ from transformers import TFGPT2LMHeadModel, GPT2Tokenizer
13
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")
14
+ GPT2 = TFGPT2LMHeadModel.from_pretrained("gpt2-medium", pad_token_id=tokenizer.eos_token_id)
15
+
16
+ return tokenizer, GPT2
17
+
18
+
19
+ tokenizer, GPT2 = download_transformer()
20
+
21
+ def input_seq(input_words):
22
+ import tensorflow as tf
23
+ return tokenizer.encode(input_words, return_tensors='tf')
24
+
25
  if button:
26
+
27
+ sample_output = GPT2.generate(
28
+ input_seq(words),
29
+ do_sample = True,
30
+ max_length = num_words,
31
+ top_p = 0.8,
32
+ top_k = 0)
33
+
34
+
35
  st.write('Clicked!')
36
  st.write(words, num_words)