# deepwords / app.py
import streamlit as st
from transformers import AutoTokenizer, TFAutoModelForCausalLM
import tensorflow as tf

# limit TensorFlow's CPU thread pools and allow soft device placement
# (the TF2 equivalent of the old tf.compat.v1 ConfigProto/Session setup,
# which has no effect in eager mode)
tf.config.threading.set_intra_op_parallelism_threads(3)
tf.config.threading.set_inter_op_parallelism_threads(2)
tf.config.set_soft_device_placement(True)
# for reproducibility
SEED = 64

# MAX_LEN caps the generated length in tokens (roughly words), prompt included;
# it is set from the number input below
input_sequence = st.text_input('Enter the seed words', ' ')
number = st.number_input('Insert how many words', min_value=1, value=30)
MAX_LEN = int(number)
if st.button('Submit'):
    # load the tokenizer and the TensorFlow model head
    tokenizer = AutoTokenizer.from_pretrained("ml6team/gpt-2-medium-conditional-quote-generator")
    # from_pt=True converts the checkpoint's PyTorch weights to TensorFlow;
    # drop it if the hub repo also ships native TF weights
    GPT2 = TFAutoModelForCausalLM.from_pretrained(
        "ml6team/gpt-2-medium-conditional-quote-generator", from_pt=True)
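    # Reloading a GPT-2 medium checkpoint on every click is slow. A minimal
    # sketch of the usual Streamlit fix, assuming Streamlit >= 1.18
    # (st.cache_resource; older versions used st.cache). load_model is a
    # helper name used here only for illustration, defined at module level:
    #
    # @st.cache_resource
    # def load_model():
    #     tok = AutoTokenizer.from_pretrained("ml6team/gpt-2-medium-conditional-quote-generator")
    #     mdl = TFAutoModelForCausalLM.from_pretrained(
    #         "ml6team/gpt-2-medium-conditional-quote-generator", from_pt=True)
    #     return tok, mdl
    #
    # tokenizer, GPT2 = load_model()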
    tf.random.set_seed(SEED)

    input_ids = tokenizer.encode(input_sequence, return_tensors='tf')
    # greedy decoding until the output length (prompt included) reaches MAX_LEN
    greedy_output = GPT2.generate(input_ids, max_length=MAX_LEN)

    # print() only reaches the server log; render the result in the app instead
    st.write(tokenizer.decode(greedy_output[0], skip_special_tokens=True))
else:
    st.write(' ')
# print("Output:\n" + 100 * '-')
# print(tokenizer.decode(sample_output[0], skip_special_tokens = True), '...')
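# A sampling variant of the generation step (a sketch; do_sample, top_k and
# top_p are standard transformers generate() arguments). It would go inside
# the Submit branch in place of the greedy call:
#
#     sample_output = GPT2.generate(input_ids, do_sample=True,
#                                   max_length=MAX_LEN, top_k=50, top_p=0.95)
#     st.write(tokenizer.decode(sample_output[0], skip_special_tokens=True))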