File size: 9,745 Bytes
0a604e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53a8589
0a604e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53a8589
0a604e2
 
53a8589
0a604e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import re
from unittest import result
import string
import streamlit as st
import torch
from torch.nn import functional as F
from transformers import (AutoModelForCausalLM, AutoModelForQuestionAnswering,
                          AutoModelForSeq2SeqLM,
                          AutoModelForSequenceClassification, AutoTokenizer,
                          GPT2Tokenizer, LogitsProcessor, LogitsProcessorList,
                          pipeline, top_k_top_p_filtering)



# Page header. Streamlit renders elements in the order they are called, and the
# page configuration is set first, before any other element is drawn.
st.set_page_config(page_title="Gadsby")
st.title("Gadsby - Constrained Text G̶e̶n̶e̶r̶a̶t̶i̶o̶n̶  to Text with Transformers")
st.image("https://upload.wikimedia.org/wikipedia/commons/1/1d/Gadsby_%28book_cover%29.jpg")
st.caption("The inspiration for this space: https://en.wikipedia.org/wiki/Gadsby_(novel)")



# Sidebar form holding the main settings (model, constraint mode).
form = st.sidebar.form("choose_settings")
form.header("Main Settings")

model_name = form.text_area("Enter the name of the pre-trained model from transformers that we are using for Text-to-Text", value = "google/pegasus-cnn_dailymail")
form.caption("This will download a new model, so it may take awhile or even break if the model is too large")
mode = form.selectbox("What kind of constrained generation are we doing?", ["lipogram", "reverse_lipogram", "e-prime", "rhopalism", "length_constrained", "greater_than_length", "Pangram", "rhopalism-lipogram"])
form.caption("Lipograms mean that a letter (or substring) is not allowed in the generated string, reverse lipograms force a letter to be in the generated string")

# Mode-specific inputs. Word lists are split with str.split() (no argument),
# i.e. on runs of any whitespace, so repeated/leading/trailing spaces or
# newlines cannot produce empty entries. An empty string is a substring of
# every string, so under substring matching an empty banned entry would reject
# every candidate token and an empty required entry would accept every one.
if mode == "lipogram":
    naughty_strings_list = st.text_area("Enter the list of strings that you don't want in each word seperated by a space", value = "E e")
    naughty_strings = naughty_strings_list.split()
elif mode == "e-prime":
    # Default list: forms and contractions of the verb "to be" in lower,
    # title and upper case.
    e_prime_string = """be being been am is isn't are aren't was wasn't were weren't i'm you're we're they're he's she's it's there's here's where's how's what's who's that's aint isnt arent wasnt werent im youre were theyre hes shes its theres heres wheres hows whats whos thats aint Be Being Been Am Is Isn't Are Aren't Was Wasn't Were Weren't I'm You're We're They're He's She's It's There's Here's Where's How's What's Who's That's Aint Isnt Arent Wasnt Werent Im Youre Were Theyre Hes Shes Its Theres Heres Wheres Hows Whats Whos Thats Aint BE BEING BEEN AM IS ISN'T ARE AREN'T WAS WASN'T WERE WEREN'T I'M YOU'RE WE'RE THEY'RE HE'S SHE'S IT'S THERE'S HERE'S WHERE'S HOW'S WHAT'S WHO'S THAT'S AINT ISNT ARENT WASNT WERENT IM YOURE WERE THEYRE HES SHES ITS THERES HERES WHERES HOWS WHATS WHOS THATS AINT"""
    st.caption("The default word list is the list needed to enforce the language model to generate english without usage of the verb to be")
    naughty_strings_list = st.text_area("Enter the list of strings that you don't want to be generated (exact match)", value = e_prime_string)
    naughty_strings = naughty_strings_list.split()
elif mode == "reverse_lipogram":
    nice_strings_list = st.text_area("Enter the list of strings that you DO want in each word seperated by a space", value = "t T")
    nice_strings = nice_strings_list.split()
elif mode == "rhopalism":
    # Length of the first generated token; each later token must be one longer.
    length_constraint = form.number_input("Enter the length that the Rhopalism shoud start with", value = 1)
    st.caption("Rhopalisms are usually reliable but sometimes you need to try generating two or three times for a perfect one")
elif mode == "rhopalism-lipogram":
    naughty_strings_list = st.text_area("Enter the list of strings that you don't want in each word seperated by a space", value = "E e")
    naughty_strings = naughty_strings_list.split()
    length_constraint = form.number_input("Enter the length that the Rhopalism shoud start with", value = 1)
    st.caption("Rhopalisms are usually reliable but sometimes you need to try generating two or three times for a perfect one")
else:
    # Used by the length_constrained / greater_than_length modes.
    # NOTE(review): the +1 presumably allows for a leading space in decoded
    # tokens — confirm for the chosen tokenizer.
    length_constraint = form.number_input("Enter the length should each word be restricted to (or greater/less than)", value = 5) + 1


# Number of tokens the generation loop below produces (one per iteration).
length = form.number_input("Select how long you want the generated text to be", value = 100)
# Size of the candidate pool get_next_word_without_e() filters and samples
# from before checking the constraint.
number_of_tokens_to_sample = form.number_input("Select how many tokens we want to search through when we do the filtering", value = 25000)
form.caption("Settings this to higher numbers will improve the experience but will cause generating to slow. Low numbers may cause lots of blank or failed generations")
# The next-token logits are divided by this value before sampling.
temperature = form.number_input("How spicy/interesting do we want our models output to be", value = 0.10, min_value = 0.0)
form.caption("Setting this higher decreases the likelihood of high probability words and increases the likelihood of low probability (and presumably more interesting) words")
form.caption("For more details on what these settings mean, see here: https://huggingface.co/blog/how-to-generate")


# Prompt text; it is encoded as the model's encoder input on every step.
# (Rendered on the main page, not inside the sidebar form.)
sequence = st.text_area("Enter a custom prompt", value = "The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930. It was the first structure to reach a height of 300 metres. Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the Chrysler Building by 5.2 metres (17 ft). Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct.")
# NOTE(review): decoded_sequence is not read anywhere in the visible code.
decoded_sequence = ""

# The button's return value is not checked: the rest of the script (model load
# and generation) runs on every script run.
form.form_submit_button("Generate some Constrained Text!")


# Load the tokenizer and seq2seq model named in the sidebar, showing a spinner
# while from_pretrained runs.
with st.spinner("Please wait while the model loads:"):
    tokenizer, model = (
        AutoTokenizer.from_pretrained(model_name),
        AutoModelForSeq2SeqLM.from_pretrained(model_name),
    )
    # Reuse the EOS id as the padding id.
    # NOTE(review): this mirrors causal-LM recipes; for seq2seq models pad and
    # EOS are normally distinct ids — confirm this override is intended.
    model.config.pad_token_id = model.config.eos_token_id

def isPalindrome(s):
    """Return True when *s* reads the same forwards and backwards.

    Not used by the visible code.
    """
    forwards = list(s)
    backwards = list(reversed(s))
    return forwards == backwards


# Rhopalism modes: the first generated token must have length
# `length_constraint`; the generation loop lengthens the target by one per token.
if mode == "rhopalism" or mode == "rhopalism-lipogram":
    rhopalism_len = length_constraint



# Pangram mode: lowercase letters that still have to appear in the output.
nice_strings_pangram = list(string.ascii_lowercase)

# Seed the decoder with the model's configured decoder start token, which is
# what encoder-decoder models are trained to see first. For models such as
# Pegasus/T5 this is the pad token, so the result matches encoding "<pad>".
_decoder_start_id = getattr(model.config, "decoder_start_token_id", None)
if _decoder_start_id is None:
    # No start token configured: fall back to encoding the literal "<pad>" token.
    decoder_input_ids = tokenizer.encode("<pad>", add_special_tokens=False, return_tensors="pt")
else:
    # Shape (1, 1): batch of one, one decoder position.
    decoder_input_ids = torch.tensor([[_decoder_start_id]])

def _constraint_check():
    """Return a predicate ``check(text) -> bool`` for the currently selected ``mode``.

    The predicate reads the module-level settings (word lists, length targets,
    remaining pangram letters) at call time. An unknown mode accepts nothing.
    """
    if mode == "lipogram":
        # No banned substring may appear anywhere in the token text.
        return lambda text: all(bad not in text for bad in naughty_strings)
    if mode == "e-prime":
        # The e-prime list is entered as "exact match" words: compare the whole
        # token (ignoring surrounding whitespace) instead of substrings, which
        # would also ban unrelated words such as "this" ("is") or "same" ("am").
        banned_words = frozenset(naughty_strings)
        return lambda text: text.strip() not in banned_words
    if mode == "reverse_lipogram":
        # At least one required substring must appear.
        return lambda text: any(good in text for good in nice_strings)
    if mode == "length_constrained":
        return lambda text: len(text) == length_constraint
    if mode == "greater_than_length":
        return lambda text: len(text) >= length_constraint
    if mode == "rhopalism":
        return lambda text: len(text) == rhopalism_len
    if mode == "Pangram":
        # The token must contain at least one letter that is still unused.
        return lambda text: any(c in nice_strings_pangram for c in text)
    if mode == "rhopalism-lipogram":
        return lambda text: len(text) == rhopalism_len and all(bad not in text for bad in naughty_strings)
    return lambda text: False


def get_next_word_without_e():
    """Pick the next decoder token whose decoded text satisfies the selected constraint.

    Runs the model on the prompt (``sequence``) and the tokens generated so far
    (``decoder_input_ids``), orders up to ``number_of_tokens_to_sample``
    candidate tokens, and returns the first candidate accepted by the constraint
    for ``mode``. Despite its name this serves every mode, not only "no letter e".

    Returns:
        tuple: ``(text, token)``, where ``text`` is the decoded string of the
        chosen token and ``token`` is its id as a 0-d tensor.

    Raises:
        ValueError: if no candidate satisfies the constraint.
    """
    input_ids = tokenizer.encode(sequence, return_tensors="pt")
    # Inference only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        # Logits for the position after the last decoder token.
        next_token_candidates_logits = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)[0][:, -1, :]

    # Sampling without replacement cannot draw more tokens than the vocabulary
    # holds, and at least one candidate is needed.
    num_candidates = max(1, min(int(number_of_tokens_to_sample), next_token_candidates_logits.size(-1)))

    if temperature <= 0:
        # Temperature 0 would divide the logits by zero (NaN probabilities);
        # treat it as greedy: try candidates in order of decreasing logit.
        next_token_candidates = torch.topk(next_token_candidates_logits, k=num_candidates, dim=-1).indices
    else:
        if temperature != 1.0:
            next_token_candidates_logits = next_token_candidates_logits / temperature
        # Keep the top-k logits only; top_p=1.0 leaves nucleus filtering off.
        filtered_next_token_candidates_logits = top_k_top_p_filtering(next_token_candidates_logits, top_k=num_candidates, top_p=1.0)
        probs = F.softmax(filtered_next_token_candidates_logits, dim=-1)
        # Drawing without replacement yields a probability-weighted random
        # ordering of the candidates.
        next_token_candidates = torch.multinomial(probs, num_samples=num_candidates)

    satisfies_constraint = _constraint_check()
    for candidate in next_token_candidates.reshape(-1):
        resulting_string = tokenizer.decode(candidate, skip_special_tokens=True)
        if satisfies_constraint(resulting_string):
            return resulting_string, candidate

    # Previously a bare " " was returned here, which the caller could not unpack
    # into (word, token); fail with an actionable message instead.
    raise ValueError(
        f"None of the {num_candidates} candidate tokens satisfied the {mode!r} constraint; "
        "try searching through more tokens or changing the temperature."
    )


# Word-by-word view of the output (built for debugging; not displayed).
new_sequence = ""

# Generate `length` tokens, each chosen by get_next_word_without_e() to satisfy
# the selected constraint, and feed every chosen token back to the decoder.
for _ in range(int(length)):
    new_word, new_candidate = get_next_word_without_e()
    decoder_input_ids = torch.cat([decoder_input_ids, new_candidate.view(1, -1)], dim=-1)
    new_sequence += new_word if new_word.endswith(" ") else new_word + " "
    if mode == "rhopalism" or mode == "rhopalism-lipogram":
        # Each successive token must be one character longer.
        rhopalism_len += 1
    if mode == "Pangram":
        # Mark the letters used by the generated token. (Iterating the prompt
        # here would never update the set, since the prompt does not change.)
        for character in new_word.lower():
            if character in nice_strings_pangram:
                nice_strings_pangram.remove(character)
        if not nice_strings_pangram:
            # Every letter has been used: start a new pangram so the
            # "contains an unused letter" constraint stays satisfiable.
            nice_strings_pangram = list(string.ascii_lowercase)

st.write("GENERATED SEQUENCE: ")
st.write(tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True))