Gadsby / Text-Generation.py
Hellisotherpeople's picture
Update Text-Generation.py
979a861
raw
history blame
6.93 kB
import re
from unittest import result
import string
import streamlit as st
import torch
from torch.nn import functional as F
from transformers import (AutoModelForCausalLM, AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification, AutoTokenizer,
GPT2Tokenizer, LogitsProcessor, LogitsProcessorList,
pipeline, top_k_top_p_filtering, PhrasalConstraint, DisjunctiveConstraint)
import ast
class ModifyLogitsProcessor(LogitsProcessor):
    """Ban or re-weight every vocabulary token that contains a given substring.

    Used to build lipograms, e.g. anything with the letter "e" in it.

    Args:
        tokenizer: Tokenizer whose ``get_vocab()`` returns a mapping of token
            string to token id.
        chars_to_modify: Mapping of substring -> factor. The factor is added
            to the logits of matching tokens when ``filter_mode`` is False.
        filter_mode: When True, matching tokens have their score set to
            negative infinity and the factors are ignored.

    Attributes:
        tokens_to_modify: Mapping of substring -> list of token ids whose
            token string contains that substring, computed once at
            construction time.
    """

    def __init__(self, tokenizer, chars_to_modify, filter_mode=True):
        super().__init__()
        self.tokenizer = tokenizer
        self.filter_mode = filter_mode
        self.chars_to_modify = chars_to_modify

        # Fetch the vocabulary once; it does not change between substrings.
        vocab = self.tokenizer.get_vocab()

        # Compute the tokens to modify at initialization.
        self.tokens_to_modify = {}
        for char in chars_to_modify:
            # An empty substring is contained in every token; in filter mode
            # that would ban the whole vocabulary, so it is skipped.
            if not char:
                continue
            # Use the id stored in the vocab mapping rather than the position
            # of the token in the dict's iteration order.
            self.tokens_to_modify[char] = [
                token_id for token, token_id in vocab.items() if char in token
            ]

    def __call__(self, input_ids, scores):
        """Apply the ban / offset to ``scores`` in place and return it.

        ``scores`` is indexed as ``scores[:, token_ids]``, i.e. its second
        axis is addressed by token id. ``input_ids`` is not used.
        """
        for char, tokens in self.tokens_to_modify.items():
            if not tokens:
                # Nothing in the vocabulary contains this substring.
                continue
            if self.filter_mode:
                scores[:, tokens] = -float('inf')
            else:
                # A token that contains several of the configured substrings
                # receives each of their factors in turn.
                scores[:, tokens] += self.chars_to_modify[char]
        return scores
# ---- Page chrome ----------------------------------------------------------
st.set_page_config(page_title="Gadsby")
st.title("Gadsby - Constrained Text Generation with Transformers")
st.image(
    "https://upload.wikimedia.org/wikipedia/commons/1/1d/Gadsby_%28book_cover%29.jpg"
)
st.caption(
    "The inspiration for this space: https://en.wikipedia.org/wiki/Gadsby_(novel)"
)

# ---- Sidebar form: every setting below is a widget of this form -----------
form = st.sidebar.form("choose_settings")

# Model selection and load precision.
form.header("Model Settings")
model_name = form.text_area(
    "Enter the name of the pre-trained model from transformers that we are using for Text Generation",
    value="facebook/opt-1.3b",
)
form.caption(
    "This will download a new model, so it may take awhile or even break if the model is too large"
)
percision = form.selectbox(
    "What percision are we loading the model with?",
    ["8bit", "16bit", "32bit"],
)
form.caption(
    "The lower the percision, the less ram the model takes and the faster it runs, but the quality is reduced"
)

# Token-level (lipogram) constraint: substrings and their logit factors.
form.header("Token Level Constraint Settings")
form.subheader("Lipogram Constraint")
form.caption(
    "Lipograms are compositions where a certain letter or certain letters of the alphabet are omitted or discouraged"
)
filter_mode = form.checkbox("Filter Mode?", value=False)
form.caption(
    "Enabling filter mode sets all selected tokens probabilities to negative infinity"
)
naughty_strings_list = form.text_input(
    "Enter letters or words to filter or modify the probabilities of (comma separated):",
    value="that,e",
)
factor_input = form.text_input(
    "Enter corresponding factors to add to the logits (comma separated, ignored if in filter mode):",
    value="5,-99",
)

# Sequence-level constraints: a forced phrase and a "one of these" list.
form.header("Sequence Level Constraint Settings")
form.header("Phrasal Constraint")
force_word = form.text_input(
    "Enter a word or sentence that is guaranteed to appear in the output",
    value="lipogram",
)
form.header("Disjunctive Constraint")
force_flexible_input = form.text_input(
    "Enter a list of words or sentences that the model must include at least one item from (in Python list format)",
    value='["constraint", "banana"]',
)
# Parse the disjunctive-constraint input into `force_flexible`, a list of
# strings. The name is always bound (to an empty list when the input is empty
# or unusable) so later code can test it without risking a NameError.
force_flexible = []
if force_flexible_input:
    try:
        # literal_eval only accepts Python literals; it never executes code.
        parsed_flexible = ast.literal_eval(force_flexible_input)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
        st.write('Failed to parse the list. Please check your input.')
        st.write('Error:', e)
    else:
        is_string_list = isinstance(parsed_flexible, (list, tuple)) and all(
            isinstance(item, str) for item in parsed_flexible
        )
        if is_string_list:
            force_flexible = list(parsed_flexible)
        else:
            # A valid literal of the wrong shape (e.g. a bare string or a
            # dict) is rejected here instead of being passed on as-is.
            st.write('Failed to parse the list. Please check your input.')
            st.write('Error:', 'expected a Python list of strings')
# Build `chars` (substrings), `factors` (logit offsets) and the
# `chars_to_modify` mapping from the two comma-separated sidebar inputs.
# `chars` and `factors` keep one entry per comma-separated item so that their
# lengths can be compared afterwards.
if naughty_strings_list:
    # Strip whitespace so that "that, e" yields "e" rather than " e".
    chars = [entry.strip() for entry in naughty_strings_list.split(',')]
    if filter_mode:
        # The factors are ignored in filter mode, so the factor input is not
        # parsed at all; placeholder zeros keep the two lists the same length.
        factors = [0.0] * len(chars)
    else:
        try:
            factors = [float(entry) for entry in factor_input.split(',')]
        except ValueError as e:
            # Empty or non-numeric factor: report it instead of crashing.
            # The empty list makes the length comparison below fail.
            st.write('Failed to parse the factors. Please enter comma separated numbers.')
            st.write('Error:', e)
            factors = []
    # Empty substrings (e.g. from a trailing comma) are left out of the
    # mapping because an empty string is contained in every token.
    chars_to_modify = {
        char: factor for char, factor in zip(chars, factors) if char
    }
else:
    chars = []
    factors = []
    chars_to_modify = {}
# Main-page inputs. These two widgets are created on `st`, not on the sidebar
# `form` object.
generate_args = st.text_input(
    "model.generate() arguments (in python dictionary format) ",
    value='{"max_new_tokens": 50, "min_new_tokens": 50, "temperature": 2.0, "num_return_sequences": 1, "do_sample": False, "num_beams": 2, "repetition_penalty": 3.0}',
)
st.caption(
    "For more details on what these settings mean and a complete list of all settings, see here: https://huggingface.co/blog/how-to-generate and https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig and https://huggingface.co/docs/transformers/v4.29.1/en/main_classes/text_generation#transformers.GenerationMixin.generate"
)
sequence = st.text_area("Enter a custom prompt", value="Tell me about ")

# Submit button of the sidebar settings form.
form.form_submit_button("Generate some Constrained Text!")
def parse_generate_args(args_str):
    """Parse a ``"key:int,key:int"`` string into a dict of integer values.

    Items are separated by commas. An item is kept only if it contains
    exactly one ``:``; anything else (including an empty item) is skipped.
    Whitespace around the key is removed so that ``"a:1, b:2"`` produces the
    keys ``"a"`` and ``"b"``.

    Args:
        args_str: The comma-separated ``key:value`` string.

    Returns:
        A dict mapping each key to its value converted with ``int``.

    Raises:
        ValueError: If a value cannot be converted to an integer.
    """
    args_dict = {}
    for arg in args_str.split(','):
        parts = arg.split(':')
        if len(parts) != 2:
            continue
        key, raw_value = parts
        # int() already tolerates surrounding whitespace in the value.
        args_dict[key.strip()] = int(raw_value)
    return args_dict
@st.cache_resource
def _load_tokenizer_for(name):
    """Load the slow (``use_fast=False``) tokenizer for ``name``, cached per name."""
    return AutoTokenizer.from_pretrained(name, use_fast=False)


def load_the_tokenizer():
    """Return the tokenizer for the model currently selected in ``model_name``.

    The cached work is done by a helper that takes the model name as an
    argument. A cached function with no arguments would keep returning the
    tokenizer of whichever model was loaded first, even after the model name
    is changed in the UI.
    """
    return _load_tokenizer_for(model_name)
@st.cache_resource
def _load_model_for(name, percision):
    """Load the causal LM ``name`` at the requested precision, cached per (name, precision)."""
    if percision == "32bit":
        return AutoModelForCausalLM.from_pretrained(
            name, device_map='auto', load_in_8bit=False
        )
    if percision == "16bit":
        return AutoModelForCausalLM.from_pretrained(
            name, device_map='auto', load_in_8bit=False, torch_dtype=torch.float16
        )
    # Any other value (the UI offers "8bit") loads the model in 8-bit.
    return AutoModelForCausalLM.from_pretrained(
        name, device_map='auto', load_in_8bit=True
    )


def load_the_model(percision):
    """Return the model selected in ``model_name`` at the given precision.

    Args:
        percision: "32bit", "16bit", or anything else for 8-bit loading.

    The model name is passed to the cached helper so that it is part of the
    cache key. Keying on the precision alone would keep returning the first
    model loaded after the model name is changed in the UI.

    NOTE(review): every distinct (model, precision) pair stays cached;
    consider `max_entries` on the cache if memory becomes a problem.
    """
    return _load_model_for(model_name, percision)
# ---- Validate the inputs, then load the model and generate ----------------
# `generate_kwargs` stays None until every input has been validated, so the
# (slow) model load only happens when generation can actually run.
generate_kwargs = None
if len(chars) != len(factors):
    st.write("Please ensure that the number of characters matches the number of factors.")
else:
    try:
        # literal_eval only accepts Python literals; it never executes code.
        parsed_generate_kwargs = ast.literal_eval(generate_args)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as parse_error:
        st.write("Failed to parse the model.generate() arguments. Please check your input.")
        st.write("Error:", parse_error)
    else:
        if isinstance(parsed_generate_kwargs, dict):
            generate_kwargs = parsed_generate_kwargs
        else:
            st.write("The model.generate() arguments must be a Python dictionary.")

if generate_kwargs is not None:
    model = load_the_model(percision)
    tokenizer = load_the_tokenizer()

    # Sequence-level constraints: a phrase that must appear, and a list of
    # alternatives of which at least one must appear.
    constraints = []
    if force_word:
        constraints.append(PhrasalConstraint(
            tokenizer(force_word, add_special_tokens=False).input_ids
        ))
    # `force_flexible` is empty when the list failed to parse; skip the
    # constraint then instead of building it from an empty list.
    if force_flexible_input and force_flexible:
        constraints.append(DisjunctiveConstraint(
            tokenizer(force_flexible, add_special_tokens=False).input_ids
        ))

    # Token-level constraint: ban or re-weight tokens containing the
    # configured substrings.
    logits_processor = LogitsProcessorList(
        [ModifyLogitsProcessor(tokenizer, chars_to_modify, filter_mode=filter_mode)]
    )

    # The model is loaded with device_map='auto', so send the prompt to
    # whichever device the model reports instead of assuming CUDA.
    input_ids = tokenizer.encode(sequence, return_tensors="pt").to(model.device)

    # Only pass `constraints` when there is at least one.
    constraint_kwargs = {"constraints": constraints} if constraints else {}
    output_ids = model.generate(
        input_ids,
        logits_processor=logits_processor,
        **constraint_kwargs,
        **generate_kwargs,
    )

    st.write("GENERATED SEQUENCE(s): ")
    for output in output_ids:
        st.write(tokenizer.decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=True))