arjunpatel commited on
Commit
ef0bdc3
1 Parent(s): 658b022

Working demo with history

Browse files
Files changed (1) hide show
  1. gradio_demo.py +104 -33
gradio_demo.py CHANGED
@@ -1,7 +1,8 @@
1
  import gradio as gr
2
  from transformers import AutoTokenizer
3
  from transformers import pipeline
4
-
 
5
  model_checkpoint = "distilgpt2"
6
 
7
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
@@ -9,52 +10,122 @@ tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
9
  generate = pipeline("text-generation",
10
  model="arjunpatel/distilgpt2-finetuned-pokemon-moves",
11
  tokenizer=tokenizer)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
 
 
 
 
 
13
 
14
- def filter_text(generated_move):
15
- # removes any moves that follow after the genrated move
16
- print(generated_move)
17
- sentences = generated_move.split(".")
18
- if len(sentences) > 2:
19
- ret_set = " ".join(sentences[0:1])
20
- else:
21
- ret_set = generated_move
22
- return ret_set
23
 
24
- def create_move(move):
25
- seed_text = "This move is called "
26
- generated_move = generate(seed_text + move, num_return_sequences=2,
27
- no_repeat_ngram_size=4)[0]["generated_text"]
28
- return generated_move
 
 
 
 
 
 
 
 
 
 
29
 
30
 
31
- # # demo = gr.Interface(fn=greet, inputs = "text", outputs="text")
32
- #
33
- # gr.Interface(fn=create_move,
34
- # inputs="text", outputs="text").launch()
35
- # # demo.launch()
36
 
37
- def filler_move(test_move, temperature):
38
- return test_move + " with temperature " + str(temperature)
39
 
40
  demo = gr.Blocks()
41
 
42
  with demo:
43
- gr.Markdown("What's that Pokemon Move?")
 
 
44
  with gr.Tabs():
45
  with gr.TabItem("Standard Generation"):
46
  with gr.Row():
47
- text_input_baseline = gr.Textbox()
48
- text_output_baseline = gr.Textbox()
 
 
49
  text_button_baseline = gr.Button("Create my move!")
50
- with gr.TabItem("Temperature Search"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  with gr.Row():
52
- temperature = gr.Slider(minimum = 0.3, maximum = 4, value = 1, step = 0.1,
53
- label = "Temperature")
54
- text_input_temp = gr.Textbox(label="Move Name")
55
- text_output_temp = gr.Textbox(label = "Move Description")
 
56
  text_button_temp = gr.Button("Create my move!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
- #text_button_baseline.click(filler_move, inputs=[text_input_baseline, 0], outputs=text_output_baseline)
59
- text_button_temp.click(filler_move, inputs=[text_input_temp, temperature], outputs=text_output_temp)
60
- demo.launch()
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer
3
  from transformers import pipeline
4
+ from utils import format_moves
5
+ import pandas as pd
6
  model_checkpoint = "distilgpt2"
7
 
8
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
 
10
  generate = pipeline("text-generation",
11
  model="arjunpatel/distilgpt2-finetuned-pokemon-moves",
12
  tokenizer=tokenizer)
13
+ # load in the model
14
+ seed_text = "This move is called "
15
+ import tensorflow as tf
16
+ tf.random.set_seed(0)
17
+
18
+ #need a function to sanitize imputs
19
+ # - remove extra spaces
20
+ # - make sure each word is capitalized
21
+ # - format the moves such that it's clearer when each move is listed
22
+ # - play with the max length parameter abit, and try to remove sentences that don't end in periods.
23
+
24
+ def update_history(df, move_name, move_desc, generation, parameters):
25
+ # needs to format each move description with new lines to cut down on width
26
+
27
+ new_row = [{"Move Name": move_name,
28
+ "Move Description": move_desc,
29
+ "Generation Type": generation,
30
+ "Parameters": parameters}]
31
+ return pd.concat([df, pd.DataFrame(new_row)])
32
+
33
+ def create_move(move, history):
34
+ generated_move = format_moves(generate(seed_text + move, num_return_sequences=1))
35
+ return generated_move, update_history(history, move, generated_move,
36
+ "baseline", "None")
37
+
38
+
39
+ def create_greedy_search_move(move):
40
+ generated_move = generate(seed_text + move, do_sample=False)
41
+ return format_moves(generated_move)
42
+
43
 
44
+ def create_beam_search_move(move, num_beams=2):
45
+ generated_move = generate(seed_text + move, num_beams=num_beams,
46
+ num_return_sequences=1,
47
+ do_sample=False, early_stopping=True)
48
+ return format_moves(generated_move)
49
 
 
 
 
 
 
 
 
 
 
50
 
51
+ def create_sampling_search_move(move, do_sample=True, temperature=1):
52
+ generated_move = generate(seed_text + move, do_sample=do_sample, temperature= float(temperature),
53
+ num_return_sequences=1, topk=0)
54
+ return format_moves(generated_move)
55
+
56
+
57
+ def create_top_search_move(move, topk=0, topp=0.90):
58
+ generated_move = generate(
59
+ seed_text + move,
60
+ do_sample=True,
61
+ num_return_sequences=1,
62
+ top_k=topk,
63
+ top_p=topp,
64
+ force_word_ids=tokenizer.encode("The user", return_tensors='tf'))
65
+ return format_moves(generated_move)
66
 
67
 
 
 
 
 
 
68
 
 
 
69
 
70
  demo = gr.Blocks()
71
 
72
  with demo:
73
+ gr.Markdown("<h1><center>What's that Pokemon Move?</center></h1>")
74
+ gr.Markdown("This Gradio demo is a small GPT-2 model fine-tuned on a dataset of Pokemon moves! It'll generate a move description given a name.")
75
+ gr.Markdown("Enter a two to three word Pokemon Move name of your imagination below!")
76
  with gr.Tabs():
77
  with gr.TabItem("Standard Generation"):
78
  with gr.Row():
79
+ text_input_baseline = gr.Textbox(label = "Move",
80
+ placeholder = "Type a two or three word move name here! Try \"Wonder Shield\"!")
81
+ text_output_baseline = gr.Textbox(label = "Move Description",
82
+ placeholder= "Leave this blank!")
83
  text_button_baseline = gr.Button("Create my move!")
84
+ with gr.TabItem("Greedy Search"):
85
+ gr.Markdown("This tab lets you learn about using greedy search!")
86
+ with gr.Row():
87
+ text_input_greedy = gr.Textbox(label="Move")
88
+ text_output_greedy = gr.Textbox(label="Move Description")
89
+ text_button_greedy = gr.Button("Create my move!")
90
+ with gr.TabItem("Beam Search"):
91
+ gr.Markdown("This tab lets you learn about using beam search!")
92
+ with gr.Row():
93
+ num_beams = gr.Slider(minimum=2, maximum=10, value=2, step=1,
94
+ label="Number of Beams")
95
+ text_input_beam = gr.Textbox(label="Move")
96
+ text_output_beam = gr.Textbox(label="Move Description")
97
+ text_button_beam = gr.Button("Create my move!")
98
+ with gr.TabItem("Sampling and Temperature Search"):
99
+ gr.Markdown("This tab lets you experiment with adjusting the temperature of the generator")
100
  with gr.Row():
101
+ temperature = gr.Slider(minimum=0.3, maximum=4.0, value=1.0, step=0.1,
102
+ label="Temperature")
103
+ sample_boolean = gr.Checkbox(label = "Enable Sampling?")
104
+ text_input_temp = gr.Textbox(label="Move")
105
+ text_output_temp = gr.Textbox(label="Move Description")
106
  text_button_temp = gr.Button("Create my move!")
107
+ with gr.TabItem("Top K and Top P Sampling"):
108
+ gr.Markdown("This tab lets you learn about Top K and Top P Sampling")
109
+ with gr.Row():
110
+ topk = gr.Slider(minimum=10, maximum=100, value=50, step=5,
111
+ label="Top K")
112
+ topp = gr.Slider(minimum=0.10, maximum=0.95, value=1, step=0.05,
113
+ label="Top P")
114
+ text_input_top = gr.Textbox(label="Move")
115
+ text_output_top = gr.Textbox(label="Move Description")
116
+ text_button_top = gr.Button("Create my move!")
117
+ with gr.Box():
118
+ # Displays a dataframe with the history of moves generated, with parameters
119
+ history = gr.Dataframe(headers= ["Move Name", "Move Description", "Generation Type", "Parameters"])
120
+
121
+
122
+ text_button_baseline.click(create_move, inputs=[text_input_baseline, history], outputs=[text_output_baseline, history])
123
+ text_button_greedy.click(create_greedy_search_move, inputs=text_input_greedy, outputs=text_output_greedy)
124
+ text_button_temp.click(create_sampling_search_move, inputs=[text_input_temp, sample_boolean, temperature],
125
+ outputs=text_output_temp)
126
+ text_button_beam.click(create_beam_search_move, inputs=[text_input_beam, num_beams], outputs=text_output_beam)
127
+ text_button_top.click(create_top_search_move, inputs=[text_input_top, topk, topp], outputs=text_output_top)
128
 
129
+ #Whenever any of the output boxes updates, take that output box and add it to the History dataframe
130
+ #text_output_baseline.change(update_history, inputs = [history, text_input_baseline, text_output_baseline], outputs = history)
131
+ demo.launch(share=True)