File size: 1,004 Bytes
75d6a4a
 
4336167
0266d30
2b77521
75d6a4a
4d04a00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75d6a4a
4d04a00
 
 
 
 
 
 
 
 
0266d30
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import streamlit as st

# User inputs; Streamlit re-runs the whole script whenever one of them changes.
words = st.text_input('Enter some words')
# Requested output length, passed on to GPT2.generate() below. The lower bound
# is 1 because a length of 0 cannot produce any generated text.
num_words = st.slider('How long should the output be?', min_value=1, max_value=100, value=5)
# True only on the rerun triggered by clicking the button.
button = st.button('Submit')

def _cache_model_loader(func):
    """Memoize *func* across Streamlit reruns without hashing its return value.

    Newer Streamlit releases provide ``st.cache_resource`` for global resources
    such as ML models. Older releases only have ``st.cache``. By default that
    decorator hashes the returned object after every call to detect mutation,
    which fails (or is extremely slow) for a TensorFlow model. Passing
    ``allow_output_mutation=True`` disables that check.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource(func)
    return st.cache(allow_output_mutation=True)(func)


@_cache_model_loader  # only load the model once per server process
def download_transformer():
    """Load the GPT-2 medium tokenizer and TensorFlow language-model head.

    Returns:
        A ``(tokenizer, model)`` tuple. The tokenizer is a ``GPT2Tokenizer`` and
        the model is a ``TFGPT2LMHeadModel``, both built from the
        ``"gpt2-medium"`` checkpoint.
    """
    # Imported lazily so the heavy transformers import only runs on a cache miss.
    from transformers import TFGPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")
    # GPT-2 ships without a padding token, so the end-of-sequence token is
    # reused as the pad id for generate().
    model = TFGPT2LMHeadModel.from_pretrained(
        "gpt2-medium", pad_token_id=tokenizer.eos_token_id
    )
    return tokenizer, model


tokenizer, GPT2 = download_transformer()

def input_seq(input_words):
    """Encode *input_words* into GPT-2 token ids as a TensorFlow tensor.

    Uses the module-level ``tokenizer`` returned by ``download_transformer()``.

    Args:
        input_words: The prompt text to encode.

    Returns:
        The token ids produced by ``tokenizer.encode(..., return_tensors='tf')``.
        This is presumably a batch of one sequence, shape ``(1, num_tokens)``.
    """
    return tokenizer.encode(input_words, return_tensors='tf')

if button:
    if not words:
        # An empty prompt encodes to zero tokens, which generate() cannot
        # continue from.
        st.warning('Please enter some words first.')
    else:
        input_ids = input_seq(words)
        # max_length counts the prompt tokens as well. Add the requested
        # length on top of the prompt, otherwise a prompt longer than num_words
        # yields no new text.
        sample_output = GPT2.generate(
            input_ids,
            do_sample=True,
            max_length=int(input_ids.shape[-1]) + num_words,
            top_p=0.8,
            # top_k=0 disables top-k filtering, leaving pure nucleus (top-p) sampling.
            top_k=0,
        )
        # generate() returns a batch of token-id sequences; decode the first
        # (and only) one back into text so the user can actually see it.
        generated_text = tokenizer.decode(sample_output[0], skip_special_tokens=True)

        st.write('Clicked!')
        st.write(words, num_words)
        st.write(generated_text)