adamcasson committed on
Commit
a891744
·
1 Parent(s): 89ab732

add more viz

Browse files
Files changed (1) hide show
  1. app.py +48 -18
app.py CHANGED
@@ -7,6 +7,7 @@ from transformers import AutoTokenizer
7
 
8
  tokenizer = AutoTokenizer.from_pretrained("adamcasson/ul2-tinystories")
9
 
 
10
  def mask_spans(
11
  tokens,
12
  mu,
@@ -19,7 +20,9 @@ def mask_spans(
19
  masked_tokens = tokens[:]
20
 
21
  encoder_inputs = [prepend_id] if prepend_id is not None else []
 
22
  targets = []
 
23
 
24
  # Original T5 code reused tokens at the end of vocab for sentinels
25
  # https://github.com/google-research/text-to-text-transfer-transformer/blob/258fd30687e6c60d18b7204d009dc5c753142987/t5/data/preprocessors.py#L3106C6-L3106C6
@@ -32,7 +35,9 @@ def mask_spans(
32
  0, len(tokens) - random.randint(1, int(2 * mu))
33
  ) # max to handle start < 0
34
  encoder_inputs += tokens[:start] + [sentinel_id]
35
- targets += tokens[start:]
 
 
36
  for i in range(start, len(tokens)):
37
  masked_tokens[i] = -1
38
 
@@ -49,28 +54,30 @@ def mask_spans(
49
  # randomly decide if span should be masked
50
  if np.random.binomial(1, p=r):
51
  encoder_inputs.append(sentinel_id)
 
52
  targets += tokens[start:end]
 
53
  for i in range(start, end):
54
  masked_tokens[i] = -1
55
  prev_span_unmasked = False
56
  sentinel_id -= 1
57
  else:
58
  encoder_inputs += tokens[start:end]
 
59
  # if previous span was also unmasked we don't need to keep adding the sentinel token
60
  if not prev_span_unmasked:
61
  targets.append(sentinel_id)
 
62
  prev_span_unmasked = True
63
  start = end
64
 
65
- encoder_inputs.append(eos_id)
66
  targets.append(eos_id)
67
- decoder_inputs = (
68
- [prepend_id] + targets[:-1]
69
- if prepend_id is not None
70
- else [eos_id] + targets[:-1]
71
- )
72
 
73
- return encoder_inputs, decoder_inputs, targets, masked_tokens
74
 
75
  # Create mixture-of-denoisers
76
  denoiser_map = {
@@ -137,15 +144,23 @@ def mask_viz(denoiser, text):
137
  seq = tokenizer.encode(text)
138
  tokens = tokenizer.tokenize(text)
139
 
140
- out = denoiser_map[denoiser](seq)
141
 
142
- mask = out[-1]
143
-
144
  highlight_tok = []
145
  for tok, tok_mask in zip(tokens, mask):
146
  highlight_tok.append((tok.replace("Ġ", " ").replace("Ċ", "\n"), "masked" if tok_mask == -1 else "unmasked"))
147
 
148
- return highlight_tok
 
 
 
 
 
 
 
 
 
 
149
 
150
  iface = gr.Interface(
151
  fn=mask_viz,
@@ -164,14 +179,29 @@ iface = gr.Interface(
164
  value="R (µ = 3, r = 0.15)",
165
  ),
166
  gr.Textbox(
167
- value='Once upon a time, there was a clever little dog named Max. Max loved to run and play with his friends in the park. One day, Max was running very fast when he fell and hurt his knee. Max went to his friend, the wise old owl, and said, "Owl, my knee hurts. What can I do?" The owl thought for a moment and said, "Max, you should test your knee. Try to walk slowly and see if it still hurts." So Max tested his knee by walking slowly. At first, it hurt a little, but soon Max felt better. He said, "Thank you, Owl, for your help. Now I can play with my friends again." Max was so happy that he could play with his friends without pain. He learned that sometimes, it was good to slow down and listen to his body. And Max and his friends played happily in the park ever after.'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  ),
169
  ],
170
- outputs=gr.HighlightedText(
171
- combine_adjacent=True,
172
- show_legend=True,
173
- color_map={"unmasked": "green", "masked": "red"}
174
- )
175
  )
176
 
177
  iface.launch()
 
7
 
8
  tokenizer = AutoTokenizer.from_pretrained("adamcasson/ul2-tinystories")
9
 
10
+
11
  def mask_spans(
12
  tokens,
13
  mu,
 
20
  masked_tokens = tokens[:]
21
 
22
  encoder_inputs = [prepend_id] if prepend_id is not None else []
23
+ encoder_mask = [1] if prepend_id is not None else []
24
  targets = []
25
+ targets_mask = []
26
 
27
  # Original T5 code reused tokens at the end of vocab for sentinels
28
  # https://github.com/google-research/text-to-text-transfer-transformer/blob/258fd30687e6c60d18b7204d009dc5c753142987/t5/data/preprocessors.py#L3106C6-L3106C6
 
35
  0, len(tokens) - random.randint(1, int(2 * mu))
36
  ) # max to handle start < 0
37
  encoder_inputs += tokens[:start] + [sentinel_id]
38
+ encoder_mask += ([1] * len(tokens[:start])) + [0]
39
+ targets += [sentinel_id] + tokens[start:]
40
+ targets_mask += [0] + ([1] * len(tokens[start:]))
41
  for i in range(start, len(tokens)):
42
  masked_tokens[i] = -1
43
 
 
54
  # randomly decide if span should be masked
55
  if np.random.binomial(1, p=r):
56
  encoder_inputs.append(sentinel_id)
57
+ encoder_mask.append(0)
58
  targets += tokens[start:end]
59
+ targets_mask += ([1] * len(tokens[start:end]))
60
  for i in range(start, end):
61
  masked_tokens[i] = -1
62
  prev_span_unmasked = False
63
  sentinel_id -= 1
64
  else:
65
  encoder_inputs += tokens[start:end]
66
+ encoder_mask += ([1] * len(tokens[start:end]))
67
  # if previous span was also unmasked we don't need to keep adding the sentinel token
68
  if not prev_span_unmasked:
69
  targets.append(sentinel_id)
70
+ targets_mask.append(0)
71
  prev_span_unmasked = True
72
  start = end
73
 
 
74
  targets.append(eos_id)
75
+ targets_mask.append(1)
76
+ decoder_inputs = [eos_id] + targets[:-1]
77
+ decoder_mask = [1] + targets_mask[:-1]
78
+
79
+ return encoder_inputs, encoder_mask, decoder_inputs, decoder_mask, targets, targets_mask, masked_tokens
80
 
 
81
 
82
  # Create mixture-of-denoisers
83
  denoiser_map = {
 
144
  seq = tokenizer.encode(text)
145
  tokens = tokenizer.tokenize(text)
146
 
147
+ enc_in, enc_mask, dec_in, dec_mask, targets, targets_mask, mask = denoiser_map[denoiser](seq)
148
 
 
 
149
  highlight_tok = []
150
  for tok, tok_mask in zip(tokens, mask):
151
  highlight_tok.append((tok.replace("Ġ", " ").replace("Ċ", "\n"), "masked" if tok_mask == -1 else "unmasked"))
152
 
153
+ highlight_enc = []
154
+ enc_tok = tokenizer.convert_ids_to_tokens(enc_in)
155
+ for id, tok, tok_mask in zip(enc_in, enc_tok, enc_mask):
156
+ highlight_enc.append((tok.replace("Ġ", " ").replace("Ċ", "\n") if tok_mask == 1 else str(id), "masked" if tok_mask == 0 else "unmasked"))
157
+
158
+ highlight_dec = []
159
+ dec_tok = tokenizer.convert_ids_to_tokens(dec_in)
160
+ for id, tok, tok_mask in zip(dec_in, dec_tok, dec_mask):
161
+ highlight_dec.append((tok.replace("Ġ", " ").replace("Ċ", "\n") if tok_mask == 1 else str(id), "masked" if tok_mask == 0 else "unmasked"))
162
+
163
+ return highlight_tok, highlight_enc, highlight_dec
164
 
165
  iface = gr.Interface(
166
  fn=mask_viz,
 
179
  value="R (µ = 3, r = 0.15)",
180
  ),
181
  gr.Textbox(
182
+ value='Once upon a time, there was a family with a little boy. His name was Jack.\nOne day, Jack had a thought. He wanted to go to the park and play. His parents were worried because it was getting dark and the park was far away.\n"Mom, I want to play in the park," Jack said.\nHis mother thought for a moment. "It\'s too late to go to the park now. We\'d better stay at home," she said. \nJack was sad, but he understood why his parents were worried. Together they decided to play games at home instead. \nJack was so happy to get to play games with his family. He thought it was the best time ever.'
183
+ ),
184
+ ],
185
+ outputs=[
186
+ gr.HighlightedText(
187
+ label="Corrupted spans",
188
+ combine_adjacent=True,
189
+ show_legend=True,
190
+ color_map={"unmasked": "green", "masked": "red"}
191
+ ),
192
+ gr.HighlightedText(
193
+ label="Encoder input",
194
+ combine_adjacent=True,
195
+ show_legend=True,
196
+ color_map={"unmasked": "green", "masked": "red"}
197
+ ),
198
+ gr.HighlightedText(
199
+ label="Decoder input",
200
+ combine_adjacent=True,
201
+ show_legend=True,
202
+ color_map={"unmasked": "green", "masked": "red"}
203
  ),
204
  ],
 
 
 
 
 
205
  )
206
 
207
  iface.launch()