# Scraped from a Hugging Face Spaces page (Space status: "Runtime error").
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import tensorflow as tf

# TensorFlow 1.x-style session with explicit thread-pool sizes and device counts.
# NOTE(review): `session` is not referenced anywhere else in this script, and the
# model loaded later is the PyTorch AutoModelForCausalLM, so these settings do
# not influence generation — confirm nothing else relies on it before removing.
config = tf.compat.v1.ConfigProto(
    intra_op_parallelism_threads=3,
    inter_op_parallelism_threads=2,
    allow_soft_placement=True,
    device_count={'GPU': 1, 'CPU': 4},
)
session = tf.compat.v1.Session(config=config)
# Fixed seed for reproducibility of any random sampling during generation.
SEED = 64

# Prompt ("seed words") that the generated quote is conditioned on.
title = st.text_input('Enter the seed words', ' ')
input_sequence = title

# Generation-length budget chosen by the user. GPT-2 counts tokens (often
# sub-word pieces), so this only approximates a number of words.
# The second positional argument of st.number_input is min_value; with no
# explicit value, the widget also starts at that minimum (1).
number = st.number_input('Insert how many words', min_value=1)
MAX_LEN = number
# Hugging Face Hub checkpoint used for quote generation.
MODEL_NAME = "ml6team/gpt-2-medium-conditional-quote-generator"

# Load the checkpoint once per server process instead of on every click.
# st.cache_resource exists in Streamlit >= 1.18; older releases only provide
# the legacy st.cache, which needs allow_output_mutation for model objects.
_cache_model = getattr(st, "cache_resource", None) or st.cache(allow_output_mutation=True)


@_cache_model
def _load_quote_generator(model_name=MODEL_NAME):
    """Return a ``(tokenizer, model)`` pair for *model_name*.

    AutoModelForCausalLM resolves to the PyTorch model class, so every tensor
    passed to ``model.generate`` must be a PyTorch tensor.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return tokenizer, model


if st.button('Submit'):
    from transformers import set_seed

    # tf.random.set_seed would not reach the PyTorch model; transformers'
    # set_seed seeds Python, NumPy and each installed backend. It only matters
    # if the checkpoint's generation config enables sampling.
    set_seed(SEED)

    tokenizer, GPT2 = _load_quote_generator()

    # Encode once, as PyTorch tensors ('pt') to match the PyTorch model;
    # the tokenizer call also yields the attention mask generate() expects.
    encoded = tokenizer(input_sequence, return_tensors='pt')
    input_ids = encoded['input_ids']

    if input_ids.shape[-1] == 0:
        # An empty prompt leaves the model nothing to condition on.
        st.warning('Please enter at least one seed word.')
    else:
        # max_new_tokens counts only generated tokens; max_length would also
        # count the prompt and could be smaller than the prompt itself.
        greedy_output = GPT2.generate(
            input_ids,
            attention_mask=encoded['attention_mask'],
            max_new_tokens=int(MAX_LEN),
            # GPT-2 tokenizers typically define no pad token; reuse EOS.
            pad_token_id=tokenizer.eos_token_id,
        )
        # print() only reaches the server log; render the result in the page.
        st.write("Output:")
        st.write(tokenizer.decode(greedy_output[0], skip_special_tokens=True))
else:
    st.write(' ')