Hellisotherpeople commited on
Commit
2dce12e
·
1 Parent(s): efa9e89

Upload Text-Generation.py

Browse files
Files changed (1) hide show
  1. Text-Generation.py +113 -130
Text-Generation.py CHANGED
@@ -8,10 +8,37 @@ from transformers import (AutoModelForCausalLM, AutoModelForQuestionAnswering,
8
  AutoModelForSeq2SeqLM,
9
  AutoModelForSequenceClassification, AutoTokenizer,
10
  GPT2Tokenizer, LogitsProcessor, LogitsProcessorList,
11
- pipeline, top_k_top_p_filtering)
 
12
 
13
 
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  st.set_page_config(page_title="Gadsby")
16
  st.title("Gadsby - Constrained Text Generation with Transformers")
17
  st.image("https://upload.wikimedia.org/wikipedia/commons/1/1d/Gadsby_%28book_cover%29.jpg")
@@ -20,146 +47,102 @@ st.caption("The inspiration for this space: https://en.wikipedia.org/wiki/Gadsby
20
 
21
 
22
  form = st.sidebar.form("choose_settings")
23
- form.header("Main Settings")
24
 
25
- model_name = form.text_area("Enter the name of the pre-trained model from transformers that we are using for Text Generation", value = "gpt2")
26
  form.caption("This will download a new model, so it may take awhile or even break if the model is too large")
27
- mode = form.selectbox("What kind of constrained generation are we doing?", ["lipogram", "reverse_lipogram", "e-prime", "rhopalism", "length_constrained", "greater_than_length", "Pangram", "rhopalism-lipogram"])
28
- form.caption("Lipograms mean that a letter (or substring) is not allowed in the generated string, reverse lipograms force a letter to be in the generated string")
29
-
30
- if mode == "lipogram":
31
- naughty_strings_list = st.text_area("Enter the list of strings that you don't want in each word seperated by a space", value = "E e")
32
- naughty_strings = naughty_strings_list.split(" ")
33
- elif mode == "e-prime":
34
- e_prime_string = """be being been am is isn't are aren't was wasn't were weren't i'm you're we're they're he's she's it's there's here's where's how's what's who's that's aint isnt arent wasnt werent im youre were theyre hes shes its theres heres wheres hows whats whos thats aint Be Being Been Am Is Isn't Are Aren't Was Wasn't Were Weren't I'm You're We're They're He's She's It's There's Here's Where's How's What's Who's That's Aint Isnt Arent Wasnt Werent Im Youre Were Theyre Hes Shes Its Theres Heres Wheres Hows Whats Whos Thats Aint BE BEING BEEN AM IS ISN'T ARE AREN'T WAS WASN'T WERE WEREN'T I'M YOU'RE WE'RE THEY'RE HE'S SHE'S IT'S THERE'S HERE'S WHERE'S HOW'S WHAT'S WHO'S THAT'S AINT ISNT ARENT WASNT WERENT IM YOURE WERE THEYRE HES SHES ITS THERES HERES WHERES HOWS WHATS WHOS THATS AINT"""
35
- st.caption("The default word list is the list needed to enforce the language model to generate english without usage of the verb to be")
36
- naughty_strings_list = st.text_area("Enter the list of strings that you don't want to be generated (exact match)", value = e_prime_string)
37
- naughty_strings = naughty_strings_list.split(" ")
38
- elif mode == "reverse_lipogram":
39
- nice_strings_list = st.text_area("Enter the list of strings that you DO want in each word seperated by a space", value = "t T")
40
- nice_strings = nice_strings_list.split(" ")
41
- elif mode == "rhopalism":
42
- length_constraint = form.number_input("Enter the length that the Rhopalism shoud start with", value = 1)
43
- st.caption("Rhopalisms are usually reliable but sometimes you need to try generating two or three times for a perfect one")
44
- elif mode == "rhopalism-lipogram":
45
- naughty_strings_list = st.text_area("Enter the list of strings that you don't want in each word seperated by a space", value = "E e")
46
- naughty_strings = naughty_strings_list.split(" ")
47
- length_constraint = form.number_input("Enter the length that the Rhopalism shoud start with", value = 1)
48
- st.caption("Rhopalisms are usually reliable but sometimes you need to try generating two or three times for a perfect one")
 
 
49
  else:
50
- length_constraint = form.number_input("Enter the length should each word be restricted to (or greater/less than)", value = 5) + 1
51
 
52
 
53
- length = form.number_input("Select how long you want the generated text to be", value = 100)
54
- number_of_tokens_to_sample = form.number_input("Select how many tokens we want to search through when we do the filtering", value = 25000)
55
- form.caption("Settings this to higher numbers will improve the experience but will cause generating to slow. Low numbers may cause lots of blank or failed generations")
56
- temperature = form.number_input("How spicy/interesting do we want our models output to be", value = 0.90, min_value = 0.0)
57
- form.caption("Setting this higher decreases the likelihood of high probability words and increases the likelihood of low probability (and presumably more interesting) words")
58
- form.caption("For more details on what these settings mean, see here: https://huggingface.co/blog/how-to-generate")
 
 
 
 
 
59
 
60
 
61
- sequence = st.text_area("Enter a custom prompt", value = "I do ")
62
 
63
  form.form_submit_button("Generate some Constrained Text!")
64
 
 
 
 
 
65
 
66
- @st.cache(allow_output_mutation=True)
67
  def load_the_tokenizer():
68
- tokenizer = AutoTokenizer.from_pretrained(model_name)
69
- return tokenizer
70
-
71
- @st.cache(allow_output_mutation=True)
72
- def load_the_model():
73
- model = AutoModelForCausalLM.from_pretrained(model_name)
74
- return model
75
-
76
-
77
- model = load_the_model()
78
- tokenizer = load_the_tokenizer()
79
-
80
-
81
-
82
- def isPalindrome(s):
83
- return s == s[::-1]
84
-
85
-
86
- if mode == "rhopalism" or mode == "rhopalism-lipogram":
87
- rhopalism_len = length_constraint
88
-
89
-
90
-
91
- nice_strings_pangram = list(string.ascii_lowercase)
92
-
93
-
94
-
95
- def get_next_word_without_e(input_sequence):
96
- input_ids = tokenizer.encode(sequence, return_tensors="pt")
97
- # get logits of last hidden state
98
- next_token_candidates_logits = model(input_ids)[0][:, -1, :]
99
- if temperature != 1.0:
100
- next_token_candidates_logits = next_token_candidates_logits / temperature
101
- # filter
102
- filtered_next_token_candidates_logits = top_k_top_p_filtering(next_token_candidates_logits, top_k=int(number_of_tokens_to_sample), top_p=int(number_of_tokens_to_sample))
103
- # sample and get a probability distribution
104
- probs = F.softmax(filtered_next_token_candidates_logits, dim=-1)
105
- next_token_candidates = torch.multinomial(probs, num_samples=int(number_of_tokens_to_sample)) ## 10000 random samples
106
- word_list = []
107
- for candidate_string in next_token_candidates:
108
- for candidate in candidate_string:
109
- resulting_string = tokenizer.decode(candidate) #skip_special_tokens=True, clean_up_tokenization_spaces=True)
110
- ###Constrained text generation starts HERE
111
- ##Lipogram - No naughty strings used
112
- if mode == "lipogram" or mode == "e-prime":
113
- if all(nauty_string not in resulting_string for nauty_string in naughty_strings): ## This returns at the first naughty strings
114
- return resulting_string
115
- ##Reverse-Lipogram - Must use things in nice_strings
116
- elif mode == "reverse_lipogram":
117
- if any(nice_string in resulting_string for nice_string in nice_strings):
118
- return resulting_string
119
- ##Length constraints
120
- elif mode == "length_constrained":
121
- ##Seems reliable if length is greater than 4
122
- if len(resulting_string) == length_constraint:
123
- return resulting_string
124
- elif mode == "greater_than_length":
125
- ##Only sort of works
126
- if len(resulting_string) >= length_constraint:
127
- return resulting_string
128
- elif mode == "rhopalism":
129
- ##Mostly works
130
- if len(resulting_string) == rhopalism_len:
131
- return resulting_string
132
- elif mode == "Pangram":
133
- if any(c in nice_strings_pangram for c in resulting_string):
134
- return resulting_string
135
- elif mode == "rhopalism-lipogram":
136
- if len(resulting_string) == rhopalism_len:
137
- if all(nauty_string not in resulting_string for nauty_string in naughty_strings):
138
- return resulting_string
139
-
140
-
141
-
142
- return " "
143
-
144
-
145
-
146
-
147
- j = 0
148
- i = length
149
- while i > 0:
150
- new_word = get_next_word_without_e(input_sequence= sequence)
151
- sequence = sequence + new_word
152
- if mode == "rhopalism" or mode == "rhopalism-lipogram":
153
- rhopalism_len += 1
154
- i = i-1
155
- if mode == "Pangram":
156
- for character in sequence:
157
- if character in nice_strings_pangram:
158
- nice_strings_pangram.remove(character)
159
- j += 1
160
-
161
- st.write("GENERATED SEQUENCE: ")
162
- st.write(sequence)
163
- #st.write(nice_strings_pangram)
164
 
165
 
 
8
  AutoModelForSeq2SeqLM,
9
  AutoModelForSequenceClassification, AutoTokenizer,
10
  GPT2Tokenizer, LogitsProcessor, LogitsProcessorList,
11
+ pipeline, top_k_top_p_filtering, PhrasalConstraint, DisjunctiveConstraint)
12
+ import ast
13
 
14
 
15
 
16
+
17
+ class ModifyLogitsProcessor(LogitsProcessor):
18
+ ### Anything with the letter "e" in it
19
+ def __init__(self, tokenizer, chars_to_modify, filter_mode=True):
20
+ super().__init__()
21
+ self.tokenizer = tokenizer
22
+ self.filter_mode = filter_mode
23
+ self.chars_to_modify = chars_to_modify
24
+
25
+ # Compute the tokens to modify at initialization
26
+ self.tokens_to_modify = {}
27
+ for char, factor in chars_to_modify.items():
28
+ mod_tokens = [token_id for token_id, token in enumerate(self.tokenizer.get_vocab()) if char in token]
29
+ self.tokens_to_modify[char] = mod_tokens
30
+
31
+ def __call__(self, input_ids, scores):
32
+ for char, tokens in self.tokens_to_modify.items():
33
+ if self.filter_mode:
34
+ scores[:, tokens] = -float('inf')
35
+ else:
36
+ # Fetch the corresponding factor from chars_to_modify dictionary
37
+ factor = self.chars_to_modify[char]
38
+ scores[:, tokens] += factor
39
+ return scores
40
+
41
+
42
  st.set_page_config(page_title="Gadsby")
43
  st.title("Gadsby - Constrained Text Generation with Transformers")
44
  st.image("https://upload.wikimedia.org/wikipedia/commons/1/1d/Gadsby_%28book_cover%29.jpg")
 
47
 
48
 
49
  form = st.sidebar.form("choose_settings")
50
+ form.header("Model Settings")
51
 
52
+ model_name = form.text_area("Enter the name of the pre-trained model from transformers that we are using for Text Generation", value = "eachadea/vicuna-7b-1.1")
53
  form.caption("This will download a new model, so it may take awhile or even break if the model is too large")
54
+ percision = form.selectbox("What percision are we loading the model with?", ["8bit", "16bit", "32bit"], )
55
+ form.caption("The lower the percision, the less ram the model takes and the faster it runs, but the quality is reduced")
56
+
57
+ form.header("Token Level Constraint Settings")
58
+ form.subheader("Lipogram Constraint")
59
+ form.caption("Lipograms are compositions where a certain letter or certain letters of the alphabet are omitted or discouraged")
60
+ filter_mode = form.checkbox("Filter Mode?", value=False)
61
+ form.caption("Enabling filter mode sets all selected tokens probabilities to negative infinity")
62
+ naughty_strings_list = form.text_input('Enter letters or words to filter or modify the probabilities of (comma separated):', value = "that,e")
63
+ factor_input = form.text_input('Enter corresponding factors to add to the logits (comma separated, ignored if in filter mode):', value = "5,-99")
64
+
65
+ form.header("Sequence Level Constraint Settings")
66
+ form.header("Phrasal Constraint")
67
+ force_word = form.text_input("Enter a word or sentence that is guaranteed to appear in the output", value = "lipogram")
68
+ form.header("Disjunctive Constraint")
69
+ force_flexible_input = form.text_input('Enter a list of words or sentences that the model must include at least one item from (in Python list format)', '["constraint", "banana"]')
70
+
71
+ if force_flexible_input:
72
+ try:
73
+ force_flexible = ast.literal_eval(force_flexible_input)
74
+ except Exception as e:
75
+ st.write('Failed to parse the list. Please check your input.')
76
+ st.write('Error:', e)
77
+ force_flexible = []
78
  else:
79
+ pass
80
 
81
 
82
+ if naughty_strings_list:
83
+ chars = naughty_strings_list.split(',')
84
+ factors = list(map(float, factor_input.split(',')))
85
+ chars_to_modify = dict(zip(chars, factors))
86
+ else:
87
+ chars = ""
88
+ factors = []
89
+ chars_to_modify = {}
90
+
91
+ generate_args = st.text_input('model.generate() arguments (in python dictionary format) ', '{"max_length": 50, "min_length": 50, "temperature": 2.0, "num_return_sequences": 1, "do_sample": False, "num_beams": 2, "repetition_penalty": 3.0}')
92
+ st.caption("For more details on what these settings mean and a complete list of all settings, see here: https://huggingface.co/blog/how-to-generate and https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig and https://huggingface.co/docs/transformers/v4.29.1/en/main_classes/text_generation#transformers.GenerationMixin.generate")
93
 
94
 
95
+ sequence = st.text_area("Enter a custom prompt", value = "Tell me about ")
96
 
97
  form.form_submit_button("Generate some Constrained Text!")
98
 
99
+ def parse_generate_args(args_str):
100
+ args_list = args_str.split(',')
101
+ args_dict = {arg.split(':')[0]: int(arg.split(':')[1]) for arg in args_list if len(arg.split(':')) == 2}
102
+ return args_dict
103
 
104
+ @st.cache_resource
105
  def load_the_tokenizer():
106
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast = False)
107
+ return tokenizer
108
+
109
+ @st.cache_resource
110
+ def load_the_model(percision):
111
+ if percision == "32bit":
112
+ model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto', load_in_8bit=False)
113
+ elif percision =="16bit":
114
+ model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto', load_in_8bit=False, torch_dtype=torch.float16)
115
+ else:
116
+ model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto', load_in_8bit=True)
117
+ return model
118
+
119
+ if len(chars) != len(factors):
120
+ st.write("Please ensure that the number of characters matches the number of factors.")
121
+ else:
122
+ model = load_the_model(percision)
123
+ tokenizer = load_the_tokenizer()
124
+ constraints = []
125
+ if force_word:
126
+ constraints.append(PhrasalConstraint(
127
+ tokenizer(force_word, add_special_tokens=False).input_ids
128
+ ))
129
+ if force_flexible_input:
130
+ constraints.append(DisjunctiveConstraint(
131
+ tokenizer(force_flexible, add_special_tokens=False).input_ids
132
+ ))
133
+ if filter_mode:
134
+ logits_processor = LogitsProcessorList([ModifyLogitsProcessor(tokenizer, chars_to_modify, filter_mode=True)])
135
+ else:
136
+ logits_processor = LogitsProcessorList([ModifyLogitsProcessor(tokenizer, chars_to_modify, filter_mode=False)])
137
+ input_ids = tokenizer.encode(sequence, return_tensors="pt").to('cuda')
138
+ generate_kwargs = ast.literal_eval(generate_args)
139
+ if constraints:
140
+ output_ids = model.generate(input_ids, constraints=constraints, logits_processor=logits_processor, **generate_kwargs)
141
+ else:
142
+ output_ids = model.generate(input_ids, logits_processor=logits_processor, **generate_kwargs)
143
+ st.write("GENERATED SEQUENCE(s): ")
144
+ for output in output_ids:
145
+ st.write(tokenizer.decode(output, skip_special_tokens = True, clean_up_tokenization_spaces = True))
146
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148