Rainsilves commited on
Commit
013fb26
β€’
1 Parent(s): 944943b

added Gadsby

Browse files
Files changed (2) hide show
  1. app.py +109 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from unittest import result
3
+
4
+ import streamlit as st
5
+ import torch
6
+ from torch.nn import functional as F
7
+ from transformers import (AutoModelForCausalLM, AutoModelForQuestionAnswering,
8
+ AutoModelForSeq2SeqLM,
9
+ AutoModelForSequenceClassification, AutoTokenizer,
10
+ GPT2Tokenizer, LogitsProcessor, LogitsProcessorList,
11
+ pipeline, top_k_top_p_filtering)
12
+
13
+
14
+ st.set_page_config(page_title="Gadsby")
15
+ st.title("Gadsby - Constrained Text Generation with Transformers")
16
+ st.caption("By Allen Roush")
17
+ st.caption("Find me on Linkedin: https://www.linkedin.com/in/allen-roush-27721011b/")
18
+ st.image("https://upload.wikimedia.org/wikipedia/commons/1/1d/Gadsby_%28book_cover%29.jpg")
19
+ st.caption("The inspiration for this space: https://en.wikipedia.org/wiki/Gadsby_(novel)")
20
+
21
+
22
+
23
+ form = st.sidebar.form("choose_settings")
24
+ form.header("Main Settings")
25
+
26
+ model_name = form.text_area("Enter the name of the pre-trained model from transformers that we are using for Text Generation", value = "gpt2")
27
+ form.caption("This will download a new model, so it may take awhile or even break if the model is too large")
28
+ mode = form.selectbox("What kind of constrained generation are we doing?", ["lipogram", "reverse_lipogram", "length_constrained", "greater_than_length"])
29
+ form.caption("Lipograms mean that a letter (or substring) is not allowed in the generated string, reverse lipograms force a letter to be in the generated string")
30
+
31
+ if mode == "lipogram":
32
+ naughty_strings_list = st.text_area("Enter the list of strings that you don't want in each word seperated by a space", value = "E e")
33
+ naughty_strings = naughty_strings_list.split(" ")
34
+ elif mode == "reverse_lipogram":
35
+ nice_strings_list = st.text_area("Enter the list of strings that you DO want in each word seperated by a space", value = "t T")
36
+ nice_strings = nice_strings_list.split(" ")
37
+ else:
38
+ length_constraint = form.number_input("Enter the length should each word be restricted to (or greater/less than)", value = 5) + 1
39
+
40
+
41
+ length = form.number_input("Select how long you want the generated text to be", value = 100)
42
+ number_of_tokens_to_sample = form.number_input("Select how many tokens we want to search through when we do the filtering", value = 1000)
43
+ form.caption("Settings this to higher numbers will improve the experience but will cause generating to slow. Low numbers may cause lots of blank or failed generations")
44
+ temperature = form.number_input("How spicy/interesting do we want our models output to be", value = 1.05, min_value = 0.0)
45
+ form.caption("Setting this higher decreases the likelihood of high probability words and increases the likelihood of low probability (and presumably more interesting) words")
46
+ form.caption("For more details on what these settings mean, see here: https://huggingface.co/blog/how-to-generate")
47
+
48
+
49
+ sequence = st.text_area("Enter a custom prompt", value = "I don't want")
50
+
51
+ form.form_submit_button("Generate some Constrained Text!")
52
+
53
+
54
+ with st.spinner("Please wait while the model loads:"):
55
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
56
+ model = AutoModelForCausalLM.from_pretrained(model_name)
57
+
58
+
59
+ def isPalindrome(s):
60
+ return s == s[::-1]
61
+
62
+
63
+ def get_next_word_without_e(input_sequence):
64
+ input_ids = tokenizer.encode(sequence, return_tensors="pt")
65
+ # get logits of last hidden state
66
+ next_token_candidates_logits = model(input_ids)[0][:, -1, :]
67
+ if temperature != 1.0:
68
+ next_token_candidates_logits = next_token_candidates_logits / temperature
69
+ # filter
70
+ filtered_next_token_candidates_logits = top_k_top_p_filtering(next_token_candidates_logits, top_k=number_of_tokens_to_sample, top_p=number_of_tokens_to_sample)
71
+ # sample and get a probability distribution
72
+ probs = F.softmax(filtered_next_token_candidates_logits, dim=-1)
73
+ next_token_candidates = torch.multinomial(probs, num_samples=number_of_tokens_to_sample) ## 10000 random samples
74
+ word_list = []
75
+ for candidate_string in next_token_candidates:
76
+ for candidate in candidate_string:
77
+ resulting_string = tokenizer.decode(candidate) #skip_special_tokens=True, clean_up_tokenization_spaces=True)
78
+ ###Constrained text generation starts HERE
79
+ ##Lipogram - No naughty strings used
80
+ if mode == "lipogram":
81
+ if all(nauty_string not in resulting_string for nauty_string in naughty_strings): ## This returns at the first naughty strings
82
+ return resulting_string
83
+ ##Reverse-Lipogram - Must use things in nice_strings
84
+ elif mode == "reverse_lipogram":
85
+ if any(nice_string in resulting_string for nice_string in nice_strings):
86
+ return resulting_string
87
+ ##Length constraints
88
+ elif mode == "length_constrained":
89
+ ##Seems reliable if length is greater than 4
90
+ if len(resulting_string) == length_constraint:
91
+ return resulting_string
92
+ elif mode == "greater_than_length":
93
+ ##Only sort of works
94
+ if len(resulting_string) >= length_constraint:
95
+ return resulting_string
96
+ return " "
97
+
98
+
99
+
100
+ i = length
101
+ while i > 0:
102
+ new_word = get_next_word_without_e(input_sequence= sequence)
103
+ sequence = sequence + new_word
104
+ i = i-1
105
+
106
+ st.write("GENERATED SEQUENCE: ")
107
+ st.write(sequence)
108
+
109
+
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ streamlit
2
+ torch
3
+ transformers