File size: 6,934 Bytes
013fb26
 
f42affe
013fb26
 
 
 
 
 
 
2dce12e
 
013fb26
 
ef797fb
2dce12e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
013fb26
 
 
3b7d642
013fb26
 
 
 
2dce12e
013fb26
979a861
013fb26
2dce12e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
013fb26
2dce12e
013fb26
 
2dce12e
 
 
 
 
 
 
 
 
979a861
2dce12e
013fb26
 
2dce12e
013fb26
 
 
2dce12e
 
 
 
013fb26
2dce12e
d55c571
2dce12e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
013fb26
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import re
from unittest import result
import string
import streamlit as st
import torch
from torch.nn import functional as F
from transformers import (AutoModelForCausalLM, AutoModelForQuestionAnswering,
                          AutoModelForSeq2SeqLM,
                          AutoModelForSequenceClassification, AutoTokenizer,
                          GPT2Tokenizer, LogitsProcessor, LogitsProcessorList,
                          pipeline, top_k_top_p_filtering, PhrasalConstraint, DisjunctiveConstraint)
import ast




class ModifyLogitsProcessor(LogitsProcessor):
    """Ban or shift the logits of vocabulary tokens containing given substrings.

    Used for lipogram-style constraints: every vocabulary entry whose token
    string contains one of the keys of ``chars_to_modify`` either gets its
    logit set to ``-inf`` (``filter_mode=True``) or has that key's factor
    added to its logit (``filter_mode=False``).

    Args:
        tokenizer: Tokenizer exposing ``get_vocab()`` (token string -> id).
        chars_to_modify: Mapping of substring -> additive logit factor.
            The factor is ignored in filter mode.
        filter_mode: If True, ban matching tokens instead of adding the factor.
    """

    def __init__(self, tokenizer, chars_to_modify, filter_mode=True):
        super().__init__()
        self.tokenizer = tokenizer
        self.filter_mode = filter_mode
        self.chars_to_modify = chars_to_modify

        # Precompute the matching token ids once; __call__ runs every step.
        # get_vocab() maps token string -> token id, so the id must be taken
        # from the mapping's value. The iteration position is not the id,
        # because the dict's order is not guaranteed to follow id order.
        vocab = self.tokenizer.get_vocab()
        self.tokens_to_modify = {}
        for char in chars_to_modify:
            if not char:
                # An empty substring is contained in every token; skip it
                # rather than banning/shifting the entire vocabulary.
                continue
            self.tokens_to_modify[char] = [
                token_id for token, token_id in vocab.items() if char in token
            ]

    def __call__(self, input_ids, scores):
        """Return ``scores`` with the configured tokens banned or shifted.

        ``scores`` is modified in place and also returned. ``input_ids`` is
        not used but is part of the ``LogitsProcessor`` call signature.
        """
        for char, tokens in self.tokens_to_modify.items():
            if not tokens:
                continue
            if self.filter_mode:
                scores[:, tokens] = -float('inf')
            else:
                # Additive shift taken from the user-supplied factor.
                scores[:, tokens] += self.chars_to_modify[char]
        return scores


# --- Page header -------------------------------------------------------------
# Keep set_page_config as the first st.* call (Streamlit has historically
# required this).
st.set_page_config(page_title="Gadsby")
st.title("Gadsby - Constrained Text Generation with Transformers")
st.image("https://upload.wikimedia.org/wikipedia/commons/1/1d/Gadsby_%28book_cover%29.jpg")
st.caption("The inspiration for this space: https://en.wikipedia.org/wiki/Gadsby_(novel)")



# --- Sidebar settings form ---------------------------------------------------
# Widgets inside this form are submitted together via form_submit_button,
# which is created further down the script.
form = st.sidebar.form("choose_settings")
form.header("Model Settings")

# Strip so a stray space/newline typed into the text area does not become part
# of the model id handed to from_pretrained().
model_name = form.text_area("Enter the name of the pre-trained model from transformers that we are using for Text Generation", value="facebook/opt-1.3b").strip()
form.caption("This will download a new model, so it may take a while or even break if the model is too large")
percision = form.selectbox("What precision are we loading the model with?", ["8bit", "16bit", "32bit"])
form.caption("The lower the precision, the less RAM the model takes and the faster it runs, but the quality is reduced")

form.header("Token Level Constraint Settings")
form.subheader("Lipogram Constraint")
form.caption("Lipograms are compositions where a certain letter or certain letters of the alphabet are omitted or discouraged")
filter_mode = form.checkbox("Filter Mode?", value=False)
form.caption("Enabling filter mode sets the logits of all selected tokens to negative infinity")
naughty_strings_list = form.text_input('Enter letters or words to filter or modify the probabilities of (comma separated):', value="that,e")
factor_input = form.text_input('Enter corresponding factors to add to the logits (comma separated, ignored if in filter mode):', value="5,-99")

form.header("Sequence Level Constraint Settings")
# Sub-sections use subheader, matching "Lipogram Constraint" above.
form.subheader("Phrasal Constraint")
force_word = form.text_input("Enter a word or sentence that is guaranteed to appear in the output", value="lipogram")
form.subheader("Disjunctive Constraint")
force_flexible_input = form.text_input('Enter a list of words or sentences that the model must include at least one item from (in Python list format)', '["constraint", "banana"]')

# Parse the disjunctive-constraint input into a list of non-empty strings.
# force_flexible is always bound (empty list when there is nothing usable).
force_flexible = []
if force_flexible_input:
    try:
        # literal_eval only evaluates Python literals, so user input cannot
        # run code here (unlike eval). These are the exceptions it documents.
        parsed_flexible = ast.literal_eval(force_flexible_input)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as e:
        st.write('Failed to parse the list. Please check your input.')
        st.write('Error:', e)
    else:
        # Accept a single quoted string as a one-item list.
        if isinstance(parsed_flexible, str):
            parsed_flexible = [parsed_flexible]
        is_valid = (
            isinstance(parsed_flexible, (list, tuple))
            and len(parsed_flexible) > 0
            and all(isinstance(item, str) and item for item in parsed_flexible)
        )
        if is_valid:
            force_flexible = list(parsed_flexible)
        else:
            # DisjunctiveConstraint needs a non-empty list of token sequences,
            # so anything else would crash generation later.
            st.write('Failed to parse the list. Please check your input.')
            st.write('Error:', 'expected a non-empty list of non-empty strings')


# Build {substring: additive logit factor} for ModifyLogitsProcessor.
if naughty_strings_list:
    # Strip whitespace and drop empty entries: an empty substring would match
    # every vocabulary token. A literal space cannot match byte-level BPE or
    # SentencePiece vocab entries anyway; they encode it as a marker symbol.
    chars = [c.strip() for c in naughty_strings_list.split(',') if c.strip()]
    if filter_mode:
        # Factors are ignored in filter mode (per the sidebar label), so do
        # not parse them; one placeholder per substring keeps counts aligned.
        factors = [0.0] * len(chars)
    else:
        try:
            factors = [float(f) for f in factor_input.split(',') if f.strip()]
        except ValueError as e:
            st.write('Failed to parse the factors. Please enter comma separated numbers.')
            st.write('Error:', e)
            # Empty factors make the count check below report a mismatch.
            factors = []
    chars_to_modify = dict(zip(chars, factors))
else:
    chars = []
    factors = []
    chars_to_modify = {}

# Extra keyword arguments for model.generate(), entered as a Python dict
# literal; parsed with ast.literal_eval further down.
# NOTE: this widget and the prompt below live in the main page, not in the
# sidebar form, so editing them reruns the script without the submit button.
generate_args = st.text_input('model.generate() arguments (in python dictionary format) ', '{"max_new_tokens": 50, "min_new_tokens": 50, "temperature": 2.0, "num_return_sequences": 1, "do_sample": False, "num_beams": 2, "repetition_penalty": 3.0}')
st.caption("For more details on what these settings mean and a complete list of all settings, see here: https://huggingface.co/blog/how-to-generate and https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig and https://huggingface.co/docs/transformers/v4.29.1/en/main_classes/text_generation#transformers.GenerationMixin.generate")


# Prompt text that the model continues from.
sequence = st.text_area("Enter a custom prompt", value = "Tell me about ")

# Submit button for the sidebar settings form.
form.form_submit_button("Generate some Constrained Text!")

def parse_generate_args(args_str):
    """Parse ``"key:value,key:value"`` into a ``{key: int(value)}`` dict.

    Entries that do not split into exactly two parts on ``:`` are skipped.
    Whitespace around keys is stripped (``int`` already tolerates it around
    values).

    Raises:
        ValueError: if a value is not an integer literal.
    """
    args_dict = {}
    for arg in args_str.split(','):
        parts = arg.split(':')
        if len(parts) != 2:
            continue
        key, value = parts
        args_dict[key.strip()] = int(value)
    return args_dict

@st.cache_resource
def _load_tokenizer_cached(name):
    """Load the slow (use_fast=False) tokenizer for ``name``, cached per id."""
    return AutoTokenizer.from_pretrained(name, use_fast = False)


def load_the_tokenizer():
    """Return the tokenizer for the model currently named in the sidebar.

    st.cache_resource keys its cache on the function's arguments (and source),
    not on globals. Caching a zero-argument function that reads ``model_name``
    would therefore return the first tokenizer loaded forever. Passing the
    model id to a cached helper makes the cache follow the selected model.
    """
    return _load_tokenizer_cached(model_name)

@st.cache_resource
def _load_model_cached(name, percision):
    """Load causal LM ``name`` at the given precision, cached per (name, precision)."""
    if percision == "32bit":
        precision_kwargs = {"load_in_8bit": False}
    elif percision == "16bit":
        precision_kwargs = {"load_in_8bit": False, "torch_dtype": torch.float16}
    else:
        # Any other value (the "8bit" option) loads 8-bit weights.
        precision_kwargs = {"load_in_8bit": True}
    return AutoModelForCausalLM.from_pretrained(name, device_map='auto', **precision_kwargs)


def load_the_model(percision):
    """Return the model currently named in the sidebar at precision ``percision``.

    The model id is passed into the cached helper so the st.cache_resource key
    includes it. Otherwise, changing the model name would keep returning the
    first model loaded for that precision.
    """
    return _load_model_cached(model_name, percision)

if len(chars) != len(factors):
    st.write("Please ensure that the number of characters matches the number of factors.")
else:
    model = load_the_model(percision)
    tokenizer = load_the_tokenizer()

    # Sequence-level constraints passed to model.generate().
    constraints = []
    if force_word:
        constraints.append(PhrasalConstraint(
            tokenizer(force_word, add_special_tokens=False).input_ids
        ))
    # Only build the disjunctive constraint when parsing produced a non-empty
    # list; DisjunctiveConstraint rejects an empty or flat list of token ids.
    # (force_flexible is bound whenever force_flexible_input is non-empty.)
    if force_flexible_input and force_flexible:
        if isinstance(force_flexible, str):
            flexible_items = [force_flexible]
        else:
            flexible_items = list(force_flexible)
        constraints.append(DisjunctiveConstraint(
            tokenizer(flexible_items, add_special_tokens=False).input_ids
        ))

    # Token-level lipogram constraint applied to the logits at every step.
    logits_processor = LogitsProcessorList(
        [ModifyLogitsProcessor(tokenizer, chars_to_modify, filter_mode=filter_mode)]
    )

    # Move the prompt to wherever device_map='auto' put the model instead of
    # hard-coding 'cuda' (which fails on CPU-only machines).
    input_ids = tokenizer.encode(sequence, return_tensors="pt").to(model.device)

    try:
        generate_kwargs = ast.literal_eval(generate_args)
        if not isinstance(generate_kwargs, dict):
            raise ValueError("expected a Python dictionary")
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as e:
        st.write("Failed to parse the model.generate() arguments. Please check your input.")
        st.write("Error:", e)
        # Halt this script run; nothing below can proceed without valid kwargs.
        st.stop()

    if constraints:
        output_ids = model.generate(input_ids, constraints=constraints, logits_processor=logits_processor, **generate_kwargs)
    else:
        output_ids = model.generate(input_ids, logits_processor=logits_processor, **generate_kwargs)
    st.write("GENERATED SEQUENCE(s): ")
    for output in output_ids:
        st.write(tokenizer.decode(output, skip_special_tokens = True, clean_up_tokenization_spaces = True))