import streamlit as st
import torch
from torch.nn import functional as F
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          top_k_top_p_filtering)
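# Page header: title, author links, and the novel that inspired this Space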
st.set_page_config(page_title="Gadsby")
st.title("Gadsby - Constrained Text Generation with Transformers")
st.caption("By Allen Roush")
st.caption("Find me on Linkedin: https://www.linkedin.com/in/allen-roush-27721011b/")
st.image("https://upload.wikimedia.org/wikipedia/commons/1/1d/Gadsby_%28book_cover%29.jpg")
st.caption("The inspiration for this space: https://en.wikipedia.org/wiki/Gadsby_(novel)")
form = st.sidebar.form("choose_settings")
form.header("Main Settings")
model_name = form.text_area("Enter the name of the pre-trained model (from the Hugging Face hub) to use for text generation", value="gpt2")
form.caption("This will download a new model, so it may take a while, or even fail if the model is too large")
mode = form.selectbox("What kind of constrained generation are we doing?", ["lipogram", "reverse_lipogram", "length_constrained", "greater_than_length"])
form.caption("A lipogram forbids a letter (or substring) from appearing in the generated text; a reverse lipogram forces a letter to appear in every generated word")
if mode == "lipogram": | |
naughty_strings_list = st.text_area("Enter the list of strings that you don't want in each word seperated by a space", value = "E e") | |
naughty_strings = naughty_strings_list.split(" ") | |
elif mode == "reverse_lipogram": | |
nice_strings_list = st.text_area("Enter the list of strings that you DO want in each word seperated by a space", value = "t T") | |
nice_strings = nice_strings_list.split(" ") | |
else: | |
length_constraint = form.number_input("Enter the length should each word be restricted to (or greater/less than)", value = 5) + 1 | |
length = form.number_input("Select how long you want the generated text to be", value=100)
number_of_tokens_to_sample = form.number_input("Select how many tokens we want to search through when we do the filtering", value=1000)
form.caption("Setting this to higher numbers will improve the experience but will slow down generation. Low numbers may cause lots of blank or failed generations")
temperature = form.number_input("How spicy/interesting do we want our model's output to be", value=1.05, min_value=0.0)
form.caption("Setting this higher decreases the likelihood of high-probability words and increases the likelihood of low-probability (and presumably more interesting) words")
form.caption("For more details on what these settings mean, see here: https://huggingface.co/blog/how-to-generate")
sequence = st.text_area("Enter a custom prompt", value="I do ")
form.form_submit_button("Generate some Constrained Text!")
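# Load the chosen tokenizer and model; the first run downloads the weights, which can take a while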
with st.spinner("Please wait while the model loads:"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
# Helper for a possible palindrome constraint (currently unused)
def isPalindrome(s):
    return s == s[::-1]
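# Core of the app: instead of modifying the decoder itself, sample a pool of
# candidate next tokens and return the first decoded candidate that satisfies
# the chosen constraint (simple rejection sampling). The name reflects the
# original no-"e" lipogram use case, but it handles all four modes.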
def get_next_word_without_e(input_sequence):
    input_ids = tokenizer.encode(input_sequence, return_tensors="pt")
    # Get the next-token logits from the last position of the final hidden state
    next_token_candidates_logits = model(input_ids)[0][:, -1, :]
    if temperature != 1.0:
        next_token_candidates_logits = next_token_candidates_logits / temperature
    # Keep only the top-k candidates; top_p is a probability threshold, so 1.0 disables nucleus filtering
    filtered_next_token_candidates_logits = top_k_top_p_filtering(next_token_candidates_logits, top_k=number_of_tokens_to_sample, top_p=1.0)
    # Turn the filtered logits into a probability distribution
    probs = F.softmax(filtered_next_token_candidates_logits, dim=-1)
    # Draw the candidate pool from that distribution (sampled without replacement)
    next_token_candidates = torch.multinomial(probs, num_samples=number_of_tokens_to_sample)
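    # Scan the sampled candidates in order; the first one that passes the active constraint wins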
    for candidate_string in next_token_candidates:
        for candidate in candidate_string:
            resulting_string = tokenizer.decode(candidate)  # optionally: skip_special_tokens=True, clean_up_tokenization_spaces=True
            ### Constrained text generation starts HERE
            ## Lipogram - reject any candidate containing a naughty string
            if mode == "lipogram":
                if all(naughty_string not in resulting_string for naughty_string in naughty_strings):
                    return resulting_string
            ## Reverse lipogram - the candidate must contain one of the nice strings
            elif mode == "reverse_lipogram":
                if any(nice_string in resulting_string for nice_string in nice_strings):
                    return resulting_string
            ## Length constraints
            elif mode == "length_constrained":
                ## Seems reliable if length is greater than 4
                if len(resulting_string) == length_constraint:
                    return resulting_string
            elif mode == "greater_than_length":
                ## Only sort of works
                if len(resulting_string) >= length_constraint:
                    return resulting_string
    # No sampled candidate satisfied the constraint; emit a space and move on
    return " "
for _ in range(length):
    new_word = get_next_word_without_e(input_sequence=sequence)
    sequence = sequence + new_word
st.write("GENERATED SEQUENCE: ") | |
st.write(sequence) | |