Spaces:
Sleeping
Sleeping
import re | |
from unittest import result | |
import string | |
import streamlit as st | |
import torch | |
from torch.nn import functional as F | |
from transformers import (AutoModelForCausalLM, AutoModelForQuestionAnswering, | |
AutoModelForSeq2SeqLM, | |
AutoModelForSequenceClassification, AutoTokenizer, | |
GPT2Tokenizer, LogitsProcessor, LogitsProcessorList, | |
pipeline, top_k_top_p_filtering, PhrasalConstraint, DisjunctiveConstraint) | |
import ast | |
class ModifyLogitsProcessor(LogitsProcessor):
    """Ban or shift the logits of every vocabulary token containing given substrings.

    Used for the lipogram constraint: any token whose string form contains one of
    the keys of ``chars_to_modify`` either gets a logit of ``-inf`` (filter mode)
    or has that key's factor added to its logit.

    Args:
        tokenizer: tokenizer whose ``get_vocab()`` maps token strings to token ids.
        chars_to_modify: mapping of substring -> additive logit factor. The
            factors are ignored when ``filter_mode`` is true.
        filter_mode: if true, matching tokens are set to ``-inf``; otherwise the
            corresponding factor is added to their logits.
    """

    def __init__(self, tokenizer, chars_to_modify, filter_mode=True):
        super().__init__()
        self.tokenizer = tokenizer
        self.filter_mode = filter_mode
        self.chars_to_modify = chars_to_modify
        # Precompute the matching token ids once, not on every decoding step.
        # get_vocab() maps token string -> token id, so the id must be taken from
        # the dict value (not from the token's position in the dict), and the
        # vocabulary only needs to be built once for all substrings.
        vocab = self.tokenizer.get_vocab()
        self.tokens_to_modify = {
            char: [token_id for token, token_id in vocab.items() if char in token]
            for char in chars_to_modify
        }

    def __call__(self, input_ids, scores):
        """Modify ``scores`` in place for the configured tokens and return it."""
        num_logits = scores.shape[-1]
        for char, token_ids in self.tokens_to_modify.items():
            index = torch.tensor(token_ids, dtype=torch.long, device=scores.device)
            # Tokenizer and model vocabulary sizes can differ; drop ids the
            # logits tensor cannot index instead of raising an IndexError.
            index = index[index < num_logits]
            if index.numel() == 0:
                continue
            if self.filter_mode:
                scores[:, index] = -float('inf')
            else:
                scores[:, index] += self.chars_to_modify[char]
        return scores
# --- Page layout --------------------------------------------------------------
st.set_page_config(page_title="Gadsby")
st.title("Gadsby - Constrained Text Generation with Transformers")
st.image("https://upload.wikimedia.org/wikipedia/commons/1/1d/Gadsby_%28book_cover%29.jpg")
st.caption("The inspiration for this space: https://en.wikipedia.org/wiki/Gadsby_(novel)")

# --- Sidebar settings form (its submit button is added after the main inputs) --
form = st.sidebar.form("choose_settings")

form.header("Model Settings")
# Strip surrounding whitespace/newlines: a text_area value may end with a newline,
# which would not be a valid model id for from_pretrained.
model_name = form.text_area(
    "Enter the name of the pre-trained model from transformers that we are using for Text Generation",
    value="facebook/opt-1.3b",
).strip()
form.caption("This will download a new model, so it may take awhile or even break if the model is too large")
# The variable keeps its historical spelling because load_the_model(percision) reads it.
percision = form.selectbox("What precision are we loading the model with?", ["8bit", "16bit", "32bit"])
form.caption("The lower the precision, the less ram the model takes and the faster it runs, but the quality is reduced")

form.header("Token Level Constraint Settings")
form.subheader("Lipogram Constraint")
form.caption("Lipograms are compositions where a certain letter or certain letters of the alphabet are omitted or discouraged")
filter_mode = form.checkbox("Filter Mode?", value=False)
form.caption("Enabling filter mode sets all selected tokens probabilities to negative infinity")
naughty_strings_list = form.text_input('Enter letters or words to filter or modify the probabilities of (comma separated):', value="that,e")
factor_input = form.text_input('Enter corresponding factors to add to the logits (comma separated, ignored if in filter mode):', value="5,-99")

form.header("Sequence Level Constraint Settings")
# Individual constraints are sub-sections, matching "Lipogram Constraint" above.
form.subheader("Phrasal Constraint")
force_word = form.text_input("Enter a word or sentence that is guaranteed to appear in the output", value="lipogram")
form.subheader("Disjunctive Constraint")
force_flexible_input = form.text_input('Enter a list of words or sentences that the model must include at least one item from (in Python list format)', '["constraint", "banana"]')
# Parse the disjunctive-constraint input into a list of candidate strings.
# ast.literal_eval only accepts Python literals, so text typed into the box is
# never executed as code.
force_flexible = []
if force_flexible_input:
    try:
        parsed = ast.literal_eval(force_flexible_input)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
        st.write('Failed to parse the list. Please check your input.')
        st.write('Error:', e)
        parsed = []
    if isinstance(parsed, str):
        # A single quoted string is treated as a one-item list.
        parsed = [parsed]
    if isinstance(parsed, (list, tuple)) and all(isinstance(item, str) and item for item in parsed):
        force_flexible = list(parsed)
    else:
        st.write('The disjunctive constraint must be a list of non-empty strings.')
    if not force_flexible:
        # Clear the raw input so the later `if force_flexible_input:` guard does
        # not try to build a DisjunctiveConstraint from an empty list.
        force_flexible_input = ""
# Lipogram settings -> {substring: additive logit factor}.
# Entries are stripped and empty ones dropped: an empty substring (e.g. from a
# trailing comma) would match every vocabulary token.
chars = [item.strip() for item in naughty_strings_list.split(',') if item.strip()]
factors = []
if chars:
    if filter_mode:
        # Factors are ignored in filter mode (as the input label states), so a
        # matching list is not required; use placeholders of the right length.
        factors = [0.0] * len(chars)
    else:
        try:
            factors = [float(item) for item in factor_input.split(',') if item.strip()]
        except ValueError as e:
            # Leave factors empty so the count check before generation reports
            # the problem instead of the script crashing here.
            st.write('Failed to parse the factors. Please enter comma separated numbers.')
            st.write('Error:', e)
            factors = []
chars_to_modify = dict(zip(chars, factors))
# --- Main-page inputs -----------------------------------------------------------
# Keyword arguments forwarded to model.generate(), typed as a Python dict literal.
_default_generate_args = (
    '{"max_new_tokens": 50, "min_new_tokens": 50, "temperature": 2.0, '
    '"num_return_sequences": 1, "do_sample": False, "num_beams": 2, '
    '"repetition_penalty": 3.0}'
)
generate_args = st.text_input(
    'model.generate() arguments (in python dictionary format) ',
    value=_default_generate_args,
)

_generation_docs = [
    "https://huggingface.co/blog/how-to-generate",
    "https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig",
    "https://huggingface.co/docs/transformers/v4.29.1/en/main_classes/text_generation#transformers.GenerationMixin.generate",
]
st.caption(
    "For more details on what these settings mean and a complete list of all settings, see here: "
    + " and ".join(_generation_docs)
)

# The prompt the model continues from.
sequence = st.text_area("Enter a custom prompt", value="Tell me about ")

# Submit button that closes the sidebar settings form.
form.form_submit_button("Generate some Constrained Text!")
def _parse_literal(text):
    """Evaluate *text* as a Python literal, raising ValueError on any failure."""
    try:
        return ast.literal_eval(text.strip())
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as err:
        raise ValueError(f"invalid literal: {text.strip()!r}") from err


def parse_generate_args(args_str):
    """Parse ``model.generate()`` keyword arguments from a string.

    Two formats are accepted:

    * a Python dict literal, e.g. ``'{"max_new_tokens": 50, "do_sample": False}'``
      (the format of the generate-arguments text box's default value);
    * comma separated ``key:value`` pairs, e.g. ``"max_new_tokens: 50, num_beams: 2"``.

    In the ``key:value`` form, keys are stripped of whitespace and quotes, values
    are parsed as Python literals (ints, floats, booleans, strings, ...), and
    entries that do not contain exactly one ``:`` are skipped.

    Args:
        args_str: the text to parse.

    Returns:
        dict mapping argument names to parsed values (empty for blank input).

    Raises:
        ValueError: if a value (or the dict literal) is not a valid Python
            literal, or a ``{...}`` input does not evaluate to a dict.
    """
    text = args_str.strip()
    if not text:
        return {}
    if text.startswith('{'):
        parsed = _parse_literal(text)
        if not isinstance(parsed, dict):
            raise ValueError(f"expected a dict literal, got {type(parsed).__name__}")
        return parsed
    args_dict = {}
    for arg in text.split(','):
        parts = arg.split(':')
        if len(parts) != 2:
            continue
        key, value = parts
        args_dict[key.strip().strip('\'"')] = _parse_literal(value)
    return args_dict
@st.cache_resource(max_entries=1)
def _cached_tokenizer(name):
    """Load the (slow, ``use_fast=False``) tokenizer for *name*, cached per name."""
    return AutoTokenizer.from_pretrained(name, use_fast=False)


def load_the_tokenizer(name=None):
    """Return the tokenizer for *name*, defaulting to the sidebar's ``model_name``.

    Streamlit re-executes the whole script on every interaction. The result is
    therefore cached with ``st.cache_resource`` keyed on the model name, so it
    is not reloaded on each rerun. The cache holds one entry: switching models
    replaces it.
    """
    return _cached_tokenizer(name if name is not None else model_name)
@st.cache_resource(max_entries=1)
def _cached_causal_lm(name, percision):
    """Load *name* as a causal LM at the requested precision (cached per arguments)."""
    if percision == "32bit":
        return AutoModelForCausalLM.from_pretrained(name, device_map='auto', load_in_8bit=False)
    if percision == "16bit":
        return AutoModelForCausalLM.from_pretrained(name, device_map='auto', load_in_8bit=False, torch_dtype=torch.float16)
    # Any other value (the UI offers "8bit") loads the weights in 8-bit.
    return AutoModelForCausalLM.from_pretrained(name, device_map='auto', load_in_8bit=True)


def load_the_model(percision, name=None):
    """Return the causal LM for *name* (default: the sidebar's ``model_name``).

    Args:
        percision: "32bit", "16bit" (float16 weights), or anything else for 8-bit.
        name: model id to load; defaults to the globally selected ``model_name``.

    The model is cached with ``st.cache_resource`` per (name, precision), so
    Streamlit reruns do not reload it. Only one entry is kept, so switching the
    model or precision drops the previously cached model.
    """
    return _cached_causal_lm(name if name is not None else model_name, percision)
# --- Generation -------------------------------------------------------------------
if len(chars) != len(factors):
    st.write("Please ensure that the number of characters matches the number of factors.")
else:
    # Validate the user-supplied generate() kwargs before loading the (large) model.
    try:
        generate_kwargs = ast.literal_eval(generate_args)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
        generate_kwargs = None
        st.write('Failed to parse the model.generate() arguments. Please check your input.')
        st.write('Error:', e)
    if generate_kwargs is not None and not isinstance(generate_kwargs, dict):
        st.write('The model.generate() arguments must be a Python dictionary.')
        generate_kwargs = None

    if generate_kwargs is not None:
        model = load_the_model(percision)
        tokenizer = load_the_tokenizer()

        constraints = []
        if force_word:
            constraints.append(PhrasalConstraint(
                tokenizer(force_word, add_special_tokens=False).input_ids
            ))
        # Gate on the parsed list too: a failed or empty parse leaves nothing to
        # build a DisjunctiveConstraint from.
        if force_flexible_input and isinstance(force_flexible, (list, tuple)) and force_flexible:
            constraints.append(DisjunctiveConstraint(
                tokenizer(list(force_flexible), add_special_tokens=False).input_ids
            ))

        logits_processor = LogitsProcessorList([
            ModifyLogitsProcessor(tokenizer, chars_to_modify, filter_mode=filter_mode)
        ])

        # Move the prompt to the model's device rather than assuming CUDA:
        # device_map='auto' may place the model on CPU.
        input_ids = tokenizer.encode(sequence, return_tensors="pt").to(model.device)
        constraint_kwargs = {"constraints": constraints} if constraints else {}
        output_ids = model.generate(
            input_ids,
            logits_processor=logits_processor,
            **constraint_kwargs,
            **generate_kwargs,
        )

        st.write("GENERATED SEQUENCE(s): ")
        for output in output_ids:
            st.write(tokenizer.decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=True))