arjunpatel committed on
Commit
61e66c3
1 Parent(s): ef0bdc3

History up and running

Browse files
Files changed (1) hide show
  1. gradio_demo.py +51 -41
gradio_demo.py CHANGED
@@ -3,6 +3,7 @@ from transformers import AutoTokenizer
3
  from transformers import pipeline
4
  from utils import format_moves
5
  import pandas as pd
 
6
  model_checkpoint = "distilgpt2"
7
 
8
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
@@ -13,9 +14,11 @@ generate = pipeline("text-generation",
13
  # load in the model
14
  seed_text = "This move is called "
15
  import tensorflow as tf
 
16
  tf.random.set_seed(0)
17
 
18
- #need a function to sanitize imputs
 
19
  # - remove extra spaces
20
  # - make sure each word is capitalized
21
  # - format the moves such that it's clearer when each move is listed
@@ -25,61 +28,67 @@ def update_history(df, move_name, move_desc, generation, parameters):
25
  # needs to format each move description with new lines to cut down on width
26
 
27
  new_row = [{"Move Name": move_name,
28
- "Move Description": move_desc,
29
- "Generation Type": generation,
30
- "Parameters": parameters}]
31
  return pd.concat([df, pd.DataFrame(new_row)])
32
 
 
33
  def create_move(move, history):
34
  generated_move = format_moves(generate(seed_text + move, num_return_sequences=1))
35
  return generated_move, update_history(history, move, generated_move,
36
- "baseline", "None")
37
 
38
 
39
- def create_greedy_search_move(move):
40
- generated_move = generate(seed_text + move, do_sample=False)
41
- return format_moves(generated_move)
 
42
 
43
 
44
- def create_beam_search_move(move, num_beams=2):
45
- generated_move = generate(seed_text + move, num_beams=num_beams,
46
- num_return_sequences=1,
47
- do_sample=False, early_stopping=True)
48
- return format_moves(generated_move)
 
49
 
50
 
51
- def create_sampling_search_move(move, do_sample=True, temperature=1):
52
- generated_move = generate(seed_text + move, do_sample=do_sample, temperature= float(temperature),
53
- num_return_sequences=1, topk=0)
54
- return format_moves(generated_move)
 
 
55
 
56
 
57
- def create_top_search_move(move, topk=0, topp=0.90):
58
- generated_move = generate(
59
  seed_text + move,
60
  do_sample=True,
61
  num_return_sequences=1,
62
  top_k=topk,
63
  top_p=topp,
64
- force_word_ids=tokenizer.encode("The user", return_tensors='tf'))
65
- return format_moves(generated_move)
66
-
67
-
68
 
69
 
70
  demo = gr.Blocks()
71
 
72
  with demo:
73
  gr.Markdown("<h1><center>What's that Pokemon Move?</center></h1>")
74
- gr.Markdown("This Gradio demo is a small GPT-2 model fine-tuned on a dataset of Pokemon moves! It'll generate a move description given a name.")
 
75
  gr.Markdown("Enter a two to three word Pokemon Move name of your imagination below!")
76
  with gr.Tabs():
77
  with gr.TabItem("Standard Generation"):
78
  with gr.Row():
79
- text_input_baseline = gr.Textbox(label = "Move",
80
- placeholder = "Type a two or three word move name here! Try \"Wonder Shield\"!")
81
- text_output_baseline = gr.Textbox(label = "Move Description",
82
- placeholder= "Leave this blank!")
83
  text_button_baseline = gr.Button("Create my move!")
84
  with gr.TabItem("Greedy Search"):
85
  gr.Markdown("This tab lets you learn about using greedy search!")
@@ -100,14 +109,14 @@ with demo:
100
  with gr.Row():
101
  temperature = gr.Slider(minimum=0.3, maximum=4.0, value=1.0, step=0.1,
102
  label="Temperature")
103
- sample_boolean = gr.Checkbox(label = "Enable Sampling?")
104
  text_input_temp = gr.Textbox(label="Move")
105
  text_output_temp = gr.Textbox(label="Move Description")
106
  text_button_temp = gr.Button("Create my move!")
107
  with gr.TabItem("Top K and Top P Sampling"):
108
  gr.Markdown("This tab lets you learn about Top K and Top P Sampling")
109
  with gr.Row():
110
- topk = gr.Slider(minimum=10, maximum=100, value=50, step=5,
111
  label="Top K")
112
  topp = gr.Slider(minimum=0.10, maximum=0.95, value=1, step=0.05,
113
  label="Top P")
@@ -116,16 +125,17 @@ with demo:
116
  text_button_top = gr.Button("Create my move!")
117
  with gr.Box():
118
  # Displays a dataframe with the history of moves generated, with parameters
119
- history = gr.Dataframe(headers= ["Move Name", "Move Description", "Generation Type", "Parameters"])
120
-
121
-
122
- text_button_baseline.click(create_move, inputs=[text_input_baseline, history], outputs=[text_output_baseline, history])
123
- text_button_greedy.click(create_greedy_search_move, inputs=text_input_greedy, outputs=text_output_greedy)
124
- text_button_temp.click(create_sampling_search_move, inputs=[text_input_temp, sample_boolean, temperature],
125
- outputs=text_output_temp)
126
- text_button_beam.click(create_beam_search_move, inputs=[text_input_beam, num_beams], outputs=text_output_beam)
127
- text_button_top.click(create_top_search_move, inputs=[text_input_top, topk, topp], outputs=text_output_top)
 
 
 
128
 
129
- #Whenever any of the output boxes updates, take that output box and add it to the History dataframe
130
- #text_output_baseline.change(update_history, inputs = [history, text_input_baseline, text_output_baseline], outputs = history)
131
  demo.launch(share=True)
 
3
  from transformers import pipeline
4
  from utils import format_moves
5
  import pandas as pd
6
+
7
  model_checkpoint = "distilgpt2"
8
 
9
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
 
14
  # load in the model
15
  seed_text = "This move is called "
16
  import tensorflow as tf
17
+
18
  tf.random.set_seed(0)
19
 
20
+
21
+ # need a function to sanitize inputs
22
  # - remove extra spaces
23
  # - make sure each word is capitalized
24
  # - format the moves such that it's clearer when each move is listed
 
28
  # needs to format each move description with new lines to cut down on width
29
 
30
  new_row = [{"Move Name": move_name,
31
+ "Move Description": move_desc,
32
+ "Generation Type": generation,
33
+ "Parameters": parameters}]
34
  return pd.concat([df, pd.DataFrame(new_row)])
35
 
36
+
37
  def create_move(move, history):
38
  generated_move = format_moves(generate(seed_text + move, num_return_sequences=1))
39
  return generated_move, update_history(history, move, generated_move,
40
+ "baseline", "None")
41
 
42
 
43
def create_greedy_search_move(move, history):
    """Generate a description for ``move`` with sampling disabled and log it.

    Returns a pair: the formatted description and whatever ``update_history``
    returns after adding a "greedy" row for this move.
    """
    prompt = seed_text + move
    raw_output = generate(prompt, do_sample=False)
    generated_move = format_moves(raw_output)
    updated_history = update_history(history, move, generated_move,
                                     "greedy", "None")
    return generated_move, updated_history
47
 
48
 
49
def create_beam_search_move(move, num_beams, history):
    """Generate a description for ``move`` using beam search and log it.

    Args:
        move: Move name typed by the user; appended to ``seed_text``.
        num_beams: Number of beams forwarded to the generation pipeline.
        history: Current history table, forwarded to ``update_history``.

    Returns:
        A pair: the formatted description and whatever ``update_history``
        returns after adding a "beam" row for this move.
    """
    generated_move = format_moves(generate(seed_text + move, num_beams=num_beams,
                                           num_return_sequences=1,
                                           do_sample=False, early_stopping=True))
    # Record the beam count that was actually used; the previous code always
    # logged 2 regardless of the value passed in.
    return generated_move, update_history(history, move, generated_move,
                                          "beam", {"num_beams": num_beams})
55
 
56
 
57
def create_sampling_search_move(move, do_sample, temperature, history):
    """Generate a description for ``move`` with a temperature setting and log it.

    Args:
        move: Move name typed by the user; appended to ``seed_text``.
        do_sample: Whether the pipeline should sample instead of decoding
            deterministically.
        temperature: Sampling temperature; converted to ``float`` before use.
        history: Current history table, forwarded to ``update_history``.

    Returns:
        A pair: the formatted description and whatever ``update_history``
        returns after adding a "temperature" row for this move.
    """
    # ``top_k`` is the generation argument name; the previous ``topk=0`` was
    # not a recognised option, so top-k filtering was never switched off.
    generated_move = format_moves(generate(seed_text + move, do_sample=do_sample,
                                           temperature=float(temperature),
                                           num_return_sequences=1, top_k=0))
    return generated_move, update_history(history, move, generated_move,
                                          "temperature", {"do_sample": do_sample,
                                                          "temperature": temperature})
63
 
64
 
65
def create_top_search_move(move, topk, topp, history):
    """Generate a description for ``move`` with top-k / top-p sampling and log it.

    Args:
        move: Move name typed by the user; appended to ``seed_text``.
        topk: Value forwarded to the pipeline as ``top_k``.
        topp: Value forwarded to the pipeline as ``top_p``.
        history: Current history table, forwarded to ``update_history``.

    Returns:
        A pair: the formatted description and whatever ``update_history``
        returns after adding a "top" row for this move.
    """
    # NOTE(review): the earlier ``force_word_ids=tokenizer.encode("The user", ...)``
    # argument was dropped. The name was misspelled (the generation option is
    # ``force_words_ids``), and forced-word constraints are a beam-search
    # feature that does not combine with ``do_sample=True``. Confirm whether
    # forcing "The user" is still wanted and, if so, implement it separately.
    generated_move = format_moves(generate(
        seed_text + move,
        do_sample=True,
        num_return_sequences=1,
        top_k=topk,
        top_p=topp))
    return generated_move, update_history(history, move, generated_move,
                                          "top", {"top k": topk,
                                                  "top p": topp})
76
 
77
 
78
  demo = gr.Blocks()
79
 
80
  with demo:
81
  gr.Markdown("<h1><center>What's that Pokemon Move?</center></h1>")
82
+ gr.Markdown(
83
+ "This Gradio demo is a small GPT-2 model fine-tuned on a dataset of Pokemon moves! It'll generate a move description given a name.")
84
  gr.Markdown("Enter a two to three word Pokemon Move name of your imagination below!")
85
  with gr.Tabs():
86
  with gr.TabItem("Standard Generation"):
87
  with gr.Row():
88
+ text_input_baseline = gr.Textbox(label="Move",
89
+ placeholder="Type a two or three word move name here! Try \"Wonder Shield\"!")
90
+ text_output_baseline = gr.Textbox(label="Move Description",
91
+ placeholder="Leave this blank!")
92
  text_button_baseline = gr.Button("Create my move!")
93
  with gr.TabItem("Greedy Search"):
94
  gr.Markdown("This tab lets you learn about using greedy search!")
 
109
  with gr.Row():
110
  temperature = gr.Slider(minimum=0.3, maximum=4.0, value=1.0, step=0.1,
111
  label="Temperature")
112
+ sample_boolean = gr.Checkbox(label="Enable Sampling?")
113
  text_input_temp = gr.Textbox(label="Move")
114
  text_output_temp = gr.Textbox(label="Move Description")
115
  text_button_temp = gr.Button("Create my move!")
116
  with gr.TabItem("Top K and Top P Sampling"):
117
  gr.Markdown("This tab lets you learn about Top K and Top P Sampling")
118
  with gr.Row():
119
+ topk = gr.Slider(minimum=10, maximum=100, value=0, step=5,
120
  label="Top K")
121
  topp = gr.Slider(minimum=0.10, maximum=0.95, value=1, step=0.05,
122
  label="Top P")
 
125
  text_button_top = gr.Button("Create my move!")
126
  with gr.Box():
127
  # Displays a dataframe with the history of moves generated, with parameters
128
+ history = gr.Dataframe(headers=["Move Name", "Move Description", "Generation Type", "Parameters"])
129
+
130
+ text_button_baseline.click(create_move, inputs=[text_input_baseline, history],
131
+ outputs=[text_output_baseline, history])
132
+ text_button_greedy.click(create_greedy_search_move, inputs=[text_input_greedy, history],
133
+ outputs=[text_output_greedy, history])
134
+ text_button_temp.click(create_sampling_search_move, inputs=[text_input_temp, sample_boolean, temperature, history],
135
+ outputs=[text_output_temp, history])
136
+ text_button_beam.click(create_beam_search_move, inputs=[text_input_beam, num_beams, history],
137
+ outputs=[text_output_beam, history])
138
+ text_button_top.click(create_top_search_move, inputs=[text_input_top, topk, topp, history],
139
+ outputs=[text_output_top, history])
140
 
 
 
141
  demo.launch(share=True)