Gadsby / app.py
Hellisotherpeople's picture
Update app.py
37a4549
raw
history blame
5.44 kB
import re
from unittest import result
import streamlit as st
import torch
from torch.nn import functional as F
from transformers import (AutoModelForCausalLM, AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification, AutoTokenizer,
GPT2Tokenizer, LogitsProcessor, LogitsProcessorList,
pipeline, top_k_top_p_filtering)
st.set_page_config(page_title="Gadsby")
st.title("Gadsby - Constrained Text Generation with Transformers")
st.caption("By Allen Roush")
st.caption("Find me on Linkedin: https://www.linkedin.com/in/allen-roush-27721011b/")
st.image("https://upload.wikimedia.org/wikipedia/commons/1/1d/Gadsby_%28book_cover%29.jpg")
st.caption("The inspiration for this space: https://en.wikipedia.org/wiki/Gadsby_(novel)")

# Generation settings. The values read here become the module-level globals
# consumed by get_next_word_without_e() and the generation loop below.
form = st.sidebar.form("choose_settings")
form.header("Main Settings")
# strip() so a stray trailing newline/space typed into the text area does not
# become part of the model identifier handed to from_pretrained().
model_name = form.text_area(
    "Enter the name of the pre-trained model from transformers that we are using for Text Generation",
    value="gpt2",
).strip()
form.caption("This will download a new model, so it may take a while or even break if the model is too large")
mode = form.selectbox(
    "What kind of constrained generation are we doing?",
    ["lipogram", "reverse_lipogram", "length_constrained", "greater_than_length"],
)
form.caption("Lipograms mean that a letter (or substring) is not allowed in the generated string, reverse lipograms force a letter to be in the generated string")
if mode == "lipogram":
    naughty_strings_list = st.text_area(
        "Enter the list of strings that you don't want in each word separated by a space",
        value="E e",
    )
    # split() with no argument drops the empty strings that split(" ") yields
    # for repeated or trailing spaces. "" is a substring of every token, so an
    # empty entry made the lipogram reject every candidate.
    naughty_strings = naughty_strings_list.split()
elif mode == "reverse_lipogram":
    nice_strings_list = st.text_area(
        "Enter the list of strings that you DO want in each word separated by a space",
        value="t T",
    )
    # Same reason as above: an empty entry would make the "must contain"
    # check succeed for every token.
    nice_strings = nice_strings_list.split()
else:
    # Only the two length modes reach this branch. The +1 presumably accounts
    # for the leading space GPT-2-style tokens carry when decoded - TODO
    # confirm for other tokenizers.
    length_constraint = form.number_input(
        "Enter the length each word should be restricted to (or the minimum length)",
        value=5,
        min_value=1,
    ) + 1
# Number of constrained tokens to append to the prompt.
length = form.number_input("Select how long you want the generated text to be", value=100, min_value=0)
# torch.multinomial rejects num_samples <= 0, so require at least one candidate.
number_of_tokens_to_sample = form.number_input(
    "Select how many tokens we want to search through when we do the filtering",
    value=1000,
    min_value=1,
)
form.caption("Setting this to higher numbers will improve the experience but will cause generating to slow. Low numbers may cause lots of blank or failed generations")
# The logits are divided by the temperature; 0 would produce inf/NaN
# probabilities and crash sampling, so keep it strictly positive.
temperature = form.number_input(
    "How spicy/interesting do we want our model's output to be",
    value=1.05,
    min_value=0.01,
)
form.caption("Setting this higher decreases the likelihood of high probability words and increases the likelihood of low probability (and presumably more interesting) words")
form.caption("For more details on what these settings mean, see here: https://huggingface.co/blog/how-to-generate")
# Initial prompt; the generation loop appends constrained tokens to it.
sequence = st.text_area("Enter a custom prompt", value="I do ")
form.form_submit_button("Generate some Constrained Text!")
# Fetch the tokenizer and the causal language model for the checkpoint named in
# the sidebar. Nothing here caches the result, so this presumably runs again on
# every Streamlit script rerun.
with st.spinner("Please wait while the model loads:"):
    # Tuple elements evaluate left to right: tokenizer first, then model.
    tokenizer, model = (
        AutoTokenizer.from_pretrained(model_name),
        AutoModelForCausalLM.from_pretrained(model_name),
    )
def isPalindrome(s):
    """Return True when ``s`` reads the same forwards and backwards.

    Works on any indexable sequence (str, list, tuple). Not called by the
    code visible in this file.
    """
    last = len(s) - 1
    # Compare mirrored positions pairwise, stopping at the middle.
    return all(s[k] == s[last - k] for k in range(len(s) // 2))
def _passes_constraint(text):
    """Return True if the decoded token ``text`` satisfies the active constraint.

    Reads the module-level ``mode`` and, depending on it, ``naughty_strings``,
    ``nice_strings`` or ``length_constraint``. Returns False for an
    unrecognised mode.
    """
    if mode == "lipogram":
        # None of the banned substrings may appear. Empty entries are skipped
        # because "" is contained in every string and would reject everything.
        return not any(banned in text for banned in naughty_strings if banned)
    if mode == "reverse_lipogram":
        # At least one required substring must appear; an empty requirement
        # list (after dropping "" entries) imposes no restriction.
        required = [wanted for wanted in nice_strings if wanted]
        return not required or any(wanted in text for wanted in required)
    if mode == "length_constrained":
        # Exact length; the original author noted this seems reliable when the
        # length is greater than 4.
        return len(text) == length_constraint
    if mode == "greater_than_length":
        # Minimum length; the original author noted this "only sort of works".
        return len(text) >= length_constraint
    return False


def get_next_word_without_e(input_sequence):
    """Sample the next token of ``input_sequence`` that satisfies the active constraint.

    Despite the name, the constraint depends on the module-level ``mode``:
    ``"lipogram"`` (no banned substring), ``"reverse_lipogram"`` (at least one
    required substring), ``"length_constrained"`` (exact decoded length) or
    ``"greater_than_length"`` (minimum decoded length).

    Up to ``number_of_tokens_to_sample`` candidate tokens are drawn without
    replacement from the model's temperature-scaled, top-k-filtered
    next-token distribution; the first one whose decoded text passes the
    constraint is returned.

    Args:
        input_sequence: The text generated so far, used as the prompt.

    Returns:
        The decoded text of the first acceptable candidate token, or ``" "``
        if none of the sampled candidates satisfies the constraint.

    Raises:
        ValueError: If the module-level ``temperature`` is not positive
            (dividing the logits by it would yield inf/NaN probabilities).
    """
    if temperature <= 0:
        raise ValueError("temperature must be greater than 0")

    # Use the argument rather than the global ``sequence`` it shadows.
    input_ids = tokenizer.encode(input_sequence, return_tensors="pt")

    # Inference only: no need to build an autograd graph.
    with torch.no_grad():
        # Logits for the position after the last prompt token.
        next_token_logits = model(input_ids)[0][:, -1, :]

    if temperature != 1.0:
        next_token_logits = next_token_logits / temperature

    # torch.multinomial without replacement cannot draw more samples than there
    # are categories, nor zero samples, so clamp to [1, vocab size].
    vocab_size = next_token_logits.size(-1)
    num_candidates = max(1, min(int(number_of_tokens_to_sample), vocab_size))

    # Keep only the top-k logits. top_p is a probability; 1.0 disables nucleus
    # filtering (the token count was previously passed here by mistake).
    filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=num_candidates, top_p=1.0)
    probs = F.softmax(filtered_logits, dim=-1)
    # Candidate token ids in sampling order, one row per batch entry.
    candidate_ids = torch.multinomial(probs, num_samples=num_candidates)

    for candidate_id in candidate_ids.flatten().tolist():
        candidate_text = tokenizer.decode(candidate_id)
        # Return the first sampled candidate that meets the constraint.
        if _passes_constraint(candidate_text):
            return candidate_text
    return " "
# Extend the prompt one constrained token per step; each step feeds the whole
# text produced so far back into get_next_word_without_e().
for _ in range(length):
    sequence += get_next_word_without_e(input_sequence=sequence)

# Show the prompt plus everything that was appended to it.
st.write("GENERATED SEQUENCE: ")
st.write(sequence)