import csv

import streamlit as st
# Static page header: browser-tab title, headline, author link and the xkcd
# comic the app is named after.  set_page_config is kept as the first
# Streamlit call in the script.
st.set_page_config(page_title="Reassuring Parables")

st.title("Reassuring Parables generator - by Allen Roush")
st.caption("Find me on Linkedin: https://www.linkedin.com/in/allen-roush-27721011b/")

st.image("https://imgs.xkcd.com/comics/reassuring.png")
st.caption("From https://xkcd.com/1263/")
# Kept below the page header, where the original script placed it
# (presumably so the header is emitted before this heavier import runs).
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Fifteen copies of the same prompt.  Not referenced elsewhere in the visible
# script; presumably paired index-by-index with ``target_text`` as seq2seq
# training inputs -- TODO confirm against the training code.
source_text = ["Computers will never"] * 15
# Fifteen hand-written completions of the prompt "Computers will never".
# Not referenced elsewhere in the visible script; presumably the training
# targets matching ``source_text`` -- TODO confirm against the training code.
# Every entry ends with a comma: without one, adjacent string literals are
# silently concatenated into a single element.
target_text = [
    "Computers will never understand a sonnet",
    "Computers will never enjoy a salad",
    "Computers will never know how to love",
    "Computers will never know how to smell",
    "Computers will never have a sense of being",
    "Computers will never feel",
    "Computers will never appreciate art",
    "Computers will never have good manners",
    "Computers will never understand god",
    "Computers will never solve the halting problem",
    "Computers will never be conscious",
    "Computers will never prove that they aren't P-zombies",
    "Computers will never replace the human brain",
    "Computers will never write better reassuring parables than humans",
    "Computers will never replace humans",
]
def train_model(
    *,
    outputdir="/home/lain/lain/CX_DB8/outputs",
    max_epochs=4,
    batch_size=1,
    use_gpu=True,
):
    """Call ``model.train`` on the module-level ``train_df`` / ``eval_df``.

    All arguments are keyword-only and default to the values that were
    previously hard-coded, so ``train_model()`` behaves as before.

    Args:
        outputdir: Forwarded as ``outputdir``.  The default is an absolute
            path on the original author's machine; override it elsewhere.
        max_epochs: Forwarded as ``max_epochs``.
        batch_size: Forwarded as ``batch_size``.
        use_gpu: Forwarded as ``use_gpu``.

    NOTE(review): nothing in the visible script defines ``train_df`` or
    ``eval_df`` or calls this function, and the only ``model`` assigned in the
    script is a ``transformers`` seq2seq model.  The keyword arguments look
    like simpleT5's ``SimpleT5.train`` -- confirm which object this is meant
    to run against before calling it.
    """
    model.train(
        train_df=train_df,
        eval_df=eval_df,
        source_max_token_len=512,
        target_max_token_len=128,
        batch_size=batch_size,
        max_epochs=max_epochs,
        use_gpu=use_gpu,
        outputdir=outputdir,
        early_stopping_patience_epochs=0,
        precision=32,
    )
# Hub id shared by the tokenizer and the model weights.
MODEL_NAME = "Hellisotherpeople/T5_Reassuring_Parables"

# Load the tokenizer and the seq2seq model while showing a spinner.
# NOTE(review): Streamlit re-executes the script on each interaction, so this
# load is repeated every time; consider wrapping it in the caching decorator
# offered by the installed Streamlit version.
with st.spinner("Please wait while the model loads:"):
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
# Sidebar form collecting the sampling settings passed to ``model.generate``.
form = st.sidebar.form("choose_settings")
form.header("Main Settings")

# Integer inputs get an explicit lower bound so that zero or negative
# counts/lengths cannot be submitted to the generator.
number_of_parables = form.number_input(
    "Select how many reassuring parables you want to generate",
    min_value=1,
    max_value=1000,
    value=20,
)
max_length_of_parable = form.number_input(
    "What's the max length of the parable?",
    min_value=1,
    max_value=128,
    value=20,
)
# The minimum length may not exceed the maximum length chosen above.
min_length_of_parable = form.number_input(
    "What's the min length of the parable?",
    min_value=0,
    max_value=max_length_of_parable,
    value=0,
)

top_k = form.number_input(
    "What value of K should we use for Top-K sampling? Set to zero to disable",
    min_value=0,
    value=50,
)
form.caption("In Top-K sampling, the K most likely next words are filtered and the probability mass is redistributed among only those K next words. ")

# NOTE(review): the label says zero disables top-p; in transformers a value of
# 1.0 is what leaves the distribution untouched -- confirm and reword if so.
top_p = form.number_input(
    "What value of P should we use for Top-p sampling? Set to zero to disable",
    min_value=0.0,
    max_value=1.0,
    value=0.95,
)
form.caption("Top-p sampling chooses from the smallest possible set of words whose cumulative probability exceeds the probability p. The probability mass is then redistributed among this set of words.")

# NOTE(review): 0.0 is accepted here; confirm the installed transformers
# version tolerates temperature == 0 when sampling.
temperature = form.number_input(
    "How spicy/interesting do we want our models output to be",
    min_value=0.0,
    value=1.05,
)
form.caption("Setting this higher decreases the likelihood of high probability words and increases the likelihood of low probability (and presumably more interesting) words")
form.caption("For more details on what these settings mean, see here: https://huggingface.co/blog/how-to-generate")

# The button's return value is not used: the generation code below runs on
# every script execution, including the first page load.
form.form_submit_button("Generate some Reassuring Parables!")
# Sample continuations of the fixed prompt using the sidebar settings, decode
# them to plain strings and display the resulting list.
with st.spinner("Generating Reassuring Parables"):
    input_ids = tokenizer.encode("Computers will never", return_tensors="pt")

    sample_outputs = model.generate(
        input_ids,
        do_sample=True,
        max_length=max_length_of_parable,
        min_length=min_length_of_parable,
        top_k=top_k,
        top_p=top_p,
        num_return_sequences=number_of_parables,
        temperature=temperature,
    )

    # One decoded string per returned sequence, special tokens removed.
    list_of_parables = [
        tokenizer.decode(sample_output, skip_special_tokens=True)
        for sample_output in sample_outputs
    ]

st.write(list_of_parables)