deep_words / app.py
sohamsh's picture
updated app.py
4d04a00
raw
history blame
1 kB
import streamlit as st
# Prompt text for GPT-2 to continue.
words = st.text_input('Enter some words')
# Requested output length, passed on to generation. The minimum is 1
# because a value of 0 would ask for an empty output.
num_words = st.slider('How long should the output be?', 1, 100, 5)
# st.button returns True only on the script run triggered by the click.
button = st.button('Submit')
# Cache decorator for the model loader. Prefer st.cache_resource (newer
# Streamlit), which is meant for shared, unhashable objects such as models.
# Older Streamlit only has st.cache. Without allow_output_mutation=True it
# hashes the returned model on every rerun to detect mutation, which is slow
# and can fail for TF models. `or` short-circuits, so st.cache is not touched
# when cache_resource exists.
_cache_model = getattr(st, 'cache_resource', None) or st.cache(allow_output_mutation=True)


@_cache_model  # load the tokenizer/model once; later reruns reuse them
def download_transformer():
    """Load the GPT-2 medium tokenizer and TensorFlow language model.

    Returns:
        tuple: ``(tokenizer, model)`` -- a ``GPT2Tokenizer`` and a
        ``TFGPT2LMHeadModel`` for the "gpt2-medium" checkpoint.
    """
    # Imported inside the cached function so the heavy transformers /
    # TensorFlow import happens only when the loader actually runs.
    from transformers import TFGPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")
    # Reuse the EOS token id as the pad token id for generation.
    model = TFGPT2LMHeadModel.from_pretrained(
        "gpt2-medium", pad_token_id=tokenizer.eos_token_id
    )
    return tokenizer, model


tokenizer, GPT2 = download_transformer()
def input_seq(input_words):
    """Encode *input_words* into GPT-2 token ids as a TensorFlow tensor.

    Uses the module-level ``tokenizer``. ``return_tensors='tf'`` makes the
    tokenizer return a TF tensor, the form the TF GPT-2 model's
    ``generate()`` is given below.
    """
    return tokenizer.encode(input_words, return_tensors='tf')
# Run generation only on the rerun triggered by the Submit button.
if button:
    if not words.strip():
        # An empty prompt encodes to no tokens, leaving generate() nothing
        # to continue from; ask for input instead of attempting it.
        st.warning('Please enter some words first.')
    else:
        input_ids = input_seq(words)
        # max_length counts the prompt tokens too, so add the prompt length
        # to make the slider control how much *new* text is produced.
        # Note: the slider counts GPT-2 tokens, not whole words.
        sample_output = GPT2.generate(
            input_ids,
            do_sample=True,
            max_length=input_ids.shape[-1] + num_words,
            top_p=0.8,  # nucleus sampling: keep tokens covering 80% of mass
            top_k=0,    # 0 disables top-k, so only top_p filtering applies
        )
        # Decode the first generated sequence and show it to the user.
        st.write(tokenizer.decode(sample_output[0], skip_special_tokens=True))