File size: 12,628 Bytes
658b022
 
 
ef0bdc3
 
b4075da
61e66c3
658b022
 
 
 
 
 
 
ef0bdc3
 
61e66c3
ef0bdc3
 
61e66c3
ef0bdc3
 
61e66c3
 
 
ef0bdc3
 
61e66c3
ef0bdc3
 
 
61e66c3
ef0bdc3
 
61e66c3
 
 
 
ef0bdc3
658b022
61e66c3
 
 
 
 
 
658b022
 
61e66c3
 
 
 
 
 
ef0bdc3
 
61e66c3
 
ef0bdc3
 
 
 
 
61e66c3
 
 
 
658b022
 
 
 
 
ef0bdc3
61e66c3
b4075da
 
 
 
ef09c1c
 
b4075da
 
 
 
 
 
658b022
 
b4075da
 
 
658b022
61e66c3
b4075da
 
61e66c3
 
658b022
b4075da
 
 
 
 
 
 
 
 
 
 
 
 
ef0bdc3
 
 
 
 
b4075da
 
 
 
ef09c1c
 
b4075da
ef09c1c
 
 
b4075da
 
ef09c1c
ef0bdc3
 
 
 
 
 
 
b4075da
ef09c1c
 
b4075da
 
ef09c1c
 
 
 
b4075da
 
ef09c1c
 
 
b4075da
ef09c1c
 
b4075da
 
 
658b022
ef0bdc3
 
 
b4075da
 
ef0bdc3
658b022
ef0bdc3
b4075da
ef09c1c
 
b4075da
 
ef09c1c
 
b4075da
 
 
 
 
 
 
 
 
 
ef0bdc3
b4075da
ef0bdc3
b4075da
ef0bdc3
b4075da
 
 
ef0bdc3
 
 
b4075da
ef0bdc3
61e66c3
ef09c1c
19c6000
 
ef09c1c
19c6000
 
ef09c1c
 
19c6000
 
ef09c1c
 
19c6000
 
ef09c1c
 
 
 
 
 
61e66c3
 
 
 
 
 
 
 
 
 
658b022
ef0bdc3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
import gradio as gr
from transformers import AutoTokenizer
from transformers import pipeline
from utils import format_moves
import pandas as pd
import tensorflow as tf

# Name of the base checkpoint; in this file it is only used to load the tokenizer.
model_checkpoint = "distilgpt2"

# NOTE(review): the tokenizer is loaded from the base checkpoint while the model
# below comes from the fine-tuned repo — presumably both share one vocabulary;
# confirm.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# Build the text-generation pipeline around the fine-tuned model. `generate` is
# the callable that every create_*_move function below invokes.
generate = pipeline("text-generation",
                    model="arjunpatel/distilgpt2-finetuned-pokemon-moves",
                    tokenizer=tokenizer)
# Prompt prefix that is prepended to the user's move name before generation.
seed_text = "This move is called "

# Seed TensorFlow's global random state after the model has been loaded.
# NOTE(review): presumably intended to make sampling repeatable across restarts;
# this only matters if the pipeline runs on the TensorFlow backend — confirm.
tf.random.set_seed(0)


def update_history(df, move_name, move_desc, generation, parameters):
    """Return the history table with one generated move appended.

    Args:
        df: Existing history as a pandas DataFrame. It is not modified.
        move_name: Name of the move supplied by the user.
        move_desc: Description text generated for the move.
        generation: Label of the decoding strategy used (e.g. "beam").
        parameters: Decoding settings to record; callers in this file pass
            either the string "None" or a dict of settings.

    Returns:
        A new DataFrame holding the rows of ``df`` followed by the new row,
        with the index renumbered from 0.
    """
    new_row = pd.DataFrame([{
        "Move Name": move_name,
        "Move Description": move_desc,
        "Generation Type": generation,
        "Parameters": parameters,
    }])
    # ignore_index renumbers the rows 0..n-1. Without it every appended row
    # keeps label 0, so the history accumulates duplicate index labels.
    return pd.concat([df, new_row], ignore_index=True)


def create_move(move, history):
    """Generate a move description without any extra decoding options.

    The prompt is ``seed_text`` followed by ``move``; a single sequence is
    requested from the pipeline and passed through ``format_moves``.

    Returns:
        A ``(description, history)`` pair, where ``history`` has gained a row
        tagged "baseline" with parameters "None".
    """
    prompt = seed_text + move
    raw_output = generate(prompt, num_return_sequences=1)
    description = format_moves(raw_output)
    updated_history = update_history(history, move, description,
                                     "baseline", "None")
    return description, updated_history


def create_greedy_search_move(move, history):
    """Generate a move description with sampling switched off.

    The prompt is ``seed_text`` followed by ``move``; the pipeline is called
    with ``do_sample=False`` and its output is passed through ``format_moves``.

    Returns:
        A ``(description, history)`` pair, where ``history`` has gained a row
        tagged "greedy" with parameters "None".
    """
    prompt = seed_text + move
    raw_output = generate(prompt, do_sample=False)
    description = format_moves(raw_output)
    updated_history = update_history(history, move, description,
                                     "greedy", "None")
    return description, updated_history


def create_beam_search_move(move, num_beams, history):
    """Generate a move description using beam search.

    Args:
        move: Move name typed by the user; appended to ``seed_text``.
        num_beams: Number of beams chosen on the slider. It is converted to
            ``int`` because a slider may deliver the value as a float.
        history: Current history DataFrame.

    Returns:
        A ``(description, history)`` pair, where ``history`` has gained a row
        tagged "beam" that records the beam count actually used.
    """
    beam_count = int(num_beams)
    generated_move = format_moves(generate(seed_text + move,
                                           num_beams=beam_count,
                                           num_return_sequences=1,
                                           do_sample=False,
                                           early_stopping=True))
    # Record the real beam count; the previous code always logged 2 here,
    # whatever the slider was set to.
    return generated_move, update_history(history, move, generated_move,
                                          "beam", {"num_beams": beam_count})


def create_sampling_search_move(move, do_sample, temperature, history):
    """Generate a move description with optional sampling and a temperature.

    Args:
        move: Move name typed by the user; appended to ``seed_text``.
        do_sample: Whether the pipeline should sample instead of decoding
            deterministically.
        temperature: Temperature value from the slider; converted to ``float``
            before it is passed to the pipeline.
        history: Current history DataFrame.

    Returns:
        A ``(description, history)`` pair, where ``history`` has gained a row
        tagged "temperature" recording ``do_sample`` and ``temperature``.
    """
    generated_move = format_moves(generate(
        seed_text + move,
        do_sample=do_sample,
        temperature=float(temperature),
        num_return_sequences=1,
        # The option is spelled `top_k`; the previous `topk=0` was not a
        # recognised generation argument, so top-k filtering was never
        # switched off as intended for this tab.
        top_k=0))
    return generated_move, update_history(history, move, generated_move,
                                          "temperature", {"do_sample": do_sample,
                                                          "temperature": temperature})


def create_top_search_move(move, topk, topp, history):
    """Generate a move description using top-k and/or top-p sampling.

    Args:
        move: Move name typed by the user; appended to ``seed_text``.
        topk: Top-k value from the slider. Converted to ``int`` before use,
            since a slider may deliver the value as a float.
        topp: Top-p value from the slider. Converted to ``float`` before use.
        history: Current history DataFrame.

    Returns:
        A ``(description, history)`` pair, where ``history`` has gained a row
        tagged "top" recording the slider values.
    """
    # The previous call also passed `force_word_ids=...`. That is not a
    # generation option (the documented name is `force_words_ids`, and it
    # belongs to constrained beam search rather than sampling), so it is
    # dropped instead of being renamed.
    generated_move = format_moves(generate(
        seed_text + move,
        do_sample=True,
        num_return_sequences=1,
        top_k=int(topk),
        top_p=float(topp)))
    return generated_move, update_history(history, move, generated_move,
                                          "top", {"top k": topk,
                                                  "top p": topp})


# Assemble the Gradio interface: an introduction, one tab per decoding
# strategy, a shared generation-history table and a credits section. The
# widgets are created first and wired to their callbacks at the end of the
# `with demo:` block.
demo = gr.Blocks()

with demo:
    gr.Markdown("<h1><center>What's that Pokemon Move?</center></h1>")
    gr.Markdown(
        """This Gradio demo allows you to generate Pokemon Move descriptions given a name, and learn more about text
        decoding methods in the process! Each tab aims to explain each generation methodology available for the
        model. The dataframe below allows you to keep track of each move generated, to compare!""")
    gr.Markdown("<h3> How does text generation work? </h3>")
    gr.Markdown(
        """Roughly, text generation models accept an input sequence of words (or parts of words,
        known as tokens). These models then output a corresponding set of words or tokens. Given the input, the
        model estimates the probability of another possible word or token appearing right after the given sequence.
        In other words, the model estimates conditional probabilities and ranks them in order to generate
        sequences.""")
    gr.Markdown("Enter a two to three word Pokemon Move name of your imagination below, with each word capitalized!")
    gr.Markdown("<h3> Move Generation </h3>")
    with gr.Tabs():
        with gr.TabItem("Standard Generation"):
            gr.Markdown(
                """The default parameters for distilgpt2 work well to generate moves. Use this tab to have fun and as
                a baseline for your experiments.""")
            with gr.Row():
                text_input_baseline = gr.Textbox(label="Move",
                                                 placeholder="Type a two or three word move name here! Try \"Wonder "
                                                             "Shield\"!")
                text_output_baseline = gr.Textbox(label="Move Description",
                                                  placeholder="Leave this blank!")
            text_button_baseline = gr.Button("Create my move!")
        with gr.TabItem("Greedy Search Decoding"):
            gr.Markdown("""

            Greedy search is a decoding method that relies on finding the word that has the highest estimated
            probability of following the sequence thus far.

            Therefore, the model \"greedily\" grabs the highest
            probability word and continues generating the sentence.

            This has the side effect of finding sequences that are reasonable, but avoids sequences that are
            less probable but way more interesting.
            Try the other decoding methods to get sentences with more variety!
            """)
            with gr.Row():
                text_input_greedy = gr.Textbox(label="Move")
                text_output_greedy = gr.Textbox(label="Move Description")
            text_button_greedy = gr.Button("Create my move!")
        with gr.TabItem("Beam Search"):
            gr.Markdown("""Beam search is an improvement on Greedy Search. Instead of directly grabbing the word that
            maximizes probability, we conduct a search with B number of candidates. We then try to find the next word
            that would most likely follow each beam, and we grab the top B candidates of that search. This may
            eliminate one of the original beams we started with, and that's okay! That is how the algorithm decides
            on an optimal candidate. Eventually, the beam sequences terminate or are eliminated due to being too
            improbable.

            Increasing the number of beams will increase model generation time, but also result in a more thorough
            search. Decreasing the number of beams will decrease decoding time, but it may not find an optimal
            sentence.

            Play around with the num_beams parameter to experiment! """
                        )
            with gr.Row():
                num_beams = gr.Slider(minimum=2, maximum=10, value=2, step=1,
                                      label="Number of Beams")
                text_input_beam = gr.Textbox(label="Move")
                text_output_beam = gr.Textbox(label="Move Description")
            text_button_beam = gr.Button("Create my move!")
        with gr.TabItem("Sampling and Temperature Search"):
            gr.Markdown(
                """Greedy Search and Beam Search were both good at finding sequences that are likely to follow our
                input text, but when generating cool move descriptions, we want some more variety!

                Instead of choosing the word or token that is most likely to follow a given sequence, we can instead
                ask the model to sample across the probability distribution of likely words.

                It's kind of like walking into the tall grass and finding a Pokemon encounter.
                There are different encounter rates, which allow
                for the most common mons to appear (looking at you, Zubat), but also account for surprise, like shinys!

                We might even want to go further, though. We can rescale the probability distributions directly
                instead, allowing for rare words to temporarily become more frequent. We do this using the
                temperature parameter.

                Turn the temperature up, and rare tokens become very likely! Cool down, and we approach more sensible
                output.

                Experiment with turning sampling on and off, and by varying temperature below!
                """)
            with gr.Row():
                temperature = gr.Slider(minimum=0.3, maximum=4.0, value=1.0, step=0.1,
                                        label="Temperature")
                text_input_temp = gr.Textbox(label="Move")
            with gr.Row():
                sample_boolean = gr.Checkbox(label="Enable Sampling?")
                text_output_temp = gr.Textbox(label="Move Description")
            text_button_temp = gr.Button("Create my move!")
        with gr.TabItem("Top K and Top P Sampling"):
            gr.Markdown(
                """When we want more control over the words we get to sample from, we turn to Top K and Top P
                decoding methods!


                The Top K sampling method selects the K most probable words given a sequence, and then samples from
                that subset, rather than the whole vocabulary. This effectively cuts out low probability words.


                Top P also reduces the available vocabulary to sample from, but instead of choosing the number of
                words or tokens in advance, we sort the vocabulary from most to least likely word, and we
                grab the smallest set of words that sum to P. This allows for the number of words we look at to
                change while sampling, instead of being fixed.

                We can even use both methods at the same time! To disable Top K, set it to 0 using the slider.
                To disable Top P, set it to 1.""")

            with gr.Row():
                topk = gr.Slider(minimum=0, maximum=200, value=0, step=5,
                                 label="Top K")

                text_input_top = gr.Textbox(label="Move")
            with gr.Row():
                topp = gr.Slider(minimum=0.10, maximum=1, value=1, step=0.05,
                                 label="Top P")
                text_output_top = gr.Textbox(label="Move Description")
            text_button_top = gr.Button("Create my move!")
    with gr.Box():
        gr.Markdown("<h3> Generation History </h3>")
        # Displays a dataframe with the history of moves generated, with parameters
        history = gr.Dataframe(headers=["Move Name", "Move Description", "Generation Type", "Parameters"])
    with gr.Box():
        gr.Markdown("<h3>How did you make this?</h3>")
        # Markdown links need "](" with no space in between, otherwise they
        # are rendered as literal text instead of hyperlinks.
        gr.Markdown("""
        I collected the dataset from [Serebii](https://www.serebii.net), a news source and aggregator of Pokemon info.


        I then added a seed phrase "This move is called" just before each move in order to assist the model in
        generation.


        I then followed HuggingFace's handy language_modeling.ipynb for fine-tuning distilgpt2 on this tiny dataset,
        and it surprisingly worked!


        I learned all about text generation using the book
        [Natural Language Processing with Transformers](https://www.oreilly.com/library/view/natural-language-processing/9781098103231/)
        by Lewis Tunstall, Leandro von Werra and Thomas Wolf, as well as
        [this fantastic article](https://huggingface.co/blog/how-to-generate) by Patrick von Platen.
        Thanks to all of these folks for creating these learning materials, and thanks to the Hugging Face team for
        developing this product! """)

    # Wire each tab's button to its generator. Every callback receives the
    # current history table and returns the updated one, so all tabs share a
    # single history.
    text_button_baseline.click(create_move, inputs=[text_input_baseline, history],
                               outputs=[text_output_baseline, history])
    text_button_greedy.click(create_greedy_search_move, inputs=[text_input_greedy, history],
                             outputs=[text_output_greedy, history])
    text_button_temp.click(create_sampling_search_move, inputs=[text_input_temp, sample_boolean, temperature, history],
                           outputs=[text_output_temp, history])
    text_button_beam.click(create_beam_search_move, inputs=[text_input_beam, num_beams, history],
                           outputs=[text_output_beam, history])
    text_button_top.click(create_top_search_move, inputs=[text_input_top, topk, topp, history],
                          outputs=[text_output_top, history])

demo.launch(share=True)