Spaces:
Runtime error
Runtime error
updated app.py
Browse files
app.py
CHANGED
@@ -4,6 +4,33 @@ words= st.text_input('Enter some words')
|
|
4 |
num_words= st.slider('How long should the output be?', 0, 100, 5)
|
5 |
button = st.button('Submit')
|
6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
if button:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
st.write('Clicked!')
|
9 |
st.write(words, num_words)
|
|
|
4 |
num_words= st.slider('How long should the output be?', 0, 100, 5)
|
5 |
button = st.button('Submit')
|
6 |
|
7 |
+
@st.cache # only run the function once
|
8 |
+
def download_transformer():
|
9 |
+
#for reproducability
|
10 |
+
#SEED = 12
|
11 |
+
|
12 |
+
from transformers import TFGPT2LMHeadModel, GPT2Tokenizer
|
13 |
+
tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")
|
14 |
+
GPT2 = TFGPT2LMHeadModel.from_pretrained("gpt2-medium", pad_token_id=tokenizer.eos_token_id)
|
15 |
+
|
16 |
+
return tokenizer, GPT2
|
17 |
+
|
18 |
+
|
19 |
+
tokenizer, GPT2 = download_transformer()
|
20 |
+
|
21 |
+
def input_seq(input_words):
|
22 |
+
import tensorflow as tf
|
23 |
+
return tokenizer.encode(input_words, return_tensors='tf')
|
24 |
+
|
25 |
if button:
|
26 |
+
|
27 |
+
sample_output = GPT2.generate(
|
28 |
+
input_seq(words),
|
29 |
+
do_sample = True,
|
30 |
+
max_length = num_words,
|
31 |
+
top_p = 0.8,
|
32 |
+
top_k = 0)
|
33 |
+
|
34 |
+
|
35 |
st.write('Clicked!')
|
36 |
st.write(words, num_words)
|