juliekallini commited on
Commit
ad1f315
·
verified ·
1 Parent(s): 58a57f0

Upload model

Browse files
config.json ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/nlp/scr3/nlp/llms-in-llms/babylm_models/babylm_shuffle_even_odd_100M_randinit_no_positional_encodings/babylm_shuffle_even_odd_100M_randinit_no_positional_encodings_seed0/runs/babylm_shuffle_even_odd_100M_randinit_no_positional_encodings_seed0/checkpoint-3000",
3
+ "activation_function": "gelu_new",
4
+ "architectures": [
5
+ "GPT2NoPositionalEncodingLMHeadModel"
6
+ ],
7
+ "attn_pdrop": 0.1,
8
+ "auto_map": {
9
+ "AutoModelForCausalLM": "modeling_gpt2_no_pos.GPT2NoPositionalEncodingLMHeadModel"
10
+ },
11
+ "bos_token_id": 50256,
12
+ "embd_pdrop": 0.1,
13
+ "eos_token_id": 50256,
14
+ "initializer_range": 0.02,
15
+ "layer_norm_epsilon": 1e-05,
16
+ "model_type": "gpt2",
17
+ "n_ctx": 1024,
18
+ "n_embd": 768,
19
+ "n_head": 12,
20
+ "n_inner": null,
21
+ "n_layer": 12,
22
+ "n_positions": 1024,
23
+ "reorder_and_upcast_attn": true,
24
+ "resid_pdrop": 0.1,
25
+ "scale_attn_by_inverse_layer_idx": true,
26
+ "scale_attn_weights": true,
27
+ "summary_activation": null,
28
+ "summary_first_dropout": 0.2,
29
+ "summary_proj_to_labels": true,
30
+ "summary_type": "cls_index",
31
+ "summary_use_proj": true,
32
+ "task_specific_params": {
33
+ "text-generation": {
34
+ "do_sample": true,
35
+ "max_length": 1024
36
+ }
37
+ },
38
+ "torch_dtype": "float32",
39
+ "transformers_version": "4.35.2",
40
+ "use_cache": false,
41
+ "vocab_size": 50257
42
+ }
generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 50256,
4
+ "eos_token_id": 50256,
5
+ "transformers_version": "4.35.2",
6
+ "use_cache": false
7
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:72e5ca054a8467d11edd3b1216dc9a8d412a1aa56ad1dc96824e45d1e07eb36e
3
+ size 494628384
modeling_gpt2_no_pos.py ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modeling_gpt2_no_pos.py
2
+ # Adapted from Huggingface's transformers library
3
+
4
+ import torch
5
+ from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2PreTrainedModel
6
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, BaseModelOutputWithPastAndCrossAttentions
7
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
8
+ from torch import nn
9
+ from torch.nn import CrossEntropyLoss
10
+ from typing import Optional, Tuple, Union
11
+
12
+ class GPT2NoPositionalEncodingModel(GPT2PreTrainedModel):
13
+ def __init__(self, config):
14
+ super().__init__(config)
15
+
16
+ self.embed_dim = config.hidden_size
17
+
18
+ self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
19
+
20
+ self.drop = nn.Dropout(config.embd_pdrop)
21
+ self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)])
22
+ self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
23
+
24
+ # Model parallel
25
+ self.model_parallel = False
26
+ self.device_map = None
27
+ self.gradient_checkpointing = False
28
+
29
+ # Initialize weights and apply final processing
30
+ self.post_init()
31
+
32
+ def parallelize(self, device_map=None):
33
+ # Check validity of device_map
34
+ self.device_map = (
35
+ get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
36
+ )
37
+ assert_device_map(self.device_map, len(self.h))
38
+ self.model_parallel = True
39
+ self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
40
+ self.last_device = "cuda:" + str(max(self.device_map.keys()))
41
+ self.wte = self.wte.to(self.first_device)
42
+ # Load onto devices
43
+ for k, v in self.device_map.items():
44
+ for block in v:
45
+ cuda_device = "cuda:" + str(k)
46
+ self.h[block] = self.h[block].to(cuda_device)
47
+ # ln_f to last
48
+ self.ln_f = self.ln_f.to(self.last_device)
49
+
50
+ def deparallelize(self):
51
+ self.model_parallel = False
52
+ self.device_map = None
53
+ self.first_device = "cpu"
54
+ self.last_device = "cpu"
55
+ self.wte = self.wte.to("cpu")
56
+ for index in range(len(self.h)):
57
+ self.h[index] = self.h[index].to("cpu")
58
+ self.ln_f = self.ln_f.to("cpu")
59
+ torch.cuda.empty_cache()
60
+
61
+ def get_input_embeddings(self):
62
+ return self.wte
63
+
64
+ def set_input_embeddings(self, new_embeddings):
65
+ self.wte = new_embeddings
66
+
67
+ def _prune_heads(self, heads_to_prune):
68
+ """
69
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
70
+ """
71
+ for layer, heads in heads_to_prune.items():
72
+ self.h[layer].attn.prune_heads(heads)
73
+
74
+ def forward(
75
+ self,
76
+ input_ids: Optional[torch.LongTensor] = None,
77
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
78
+ attention_mask: Optional[torch.FloatTensor] = None,
79
+ token_type_ids: Optional[torch.LongTensor] = None,
80
+ position_ids: Optional[torch.LongTensor] = None,
81
+ head_mask: Optional[torch.FloatTensor] = None,
82
+ inputs_embeds: Optional[torch.FloatTensor] = None,
83
+ encoder_hidden_states: Optional[torch.Tensor] = None,
84
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
85
+ use_cache: Optional[bool] = None,
86
+ output_attentions: Optional[bool] = None,
87
+ output_hidden_states: Optional[bool] = None,
88
+ return_dict: Optional[bool] = None,
89
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
90
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
91
+ output_hidden_states = (
92
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
93
+ )
94
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
95
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
96
+
97
+ if input_ids is not None and inputs_embeds is not None:
98
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
99
+ elif input_ids is not None:
100
+ input_shape = input_ids.size()
101
+ input_ids = input_ids.view(-1, input_shape[-1])
102
+ batch_size = input_ids.shape[0]
103
+ elif inputs_embeds is not None:
104
+ input_shape = inputs_embeds.size()[:-1]
105
+ batch_size = inputs_embeds.shape[0]
106
+ else:
107
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
108
+
109
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
110
+
111
+ if token_type_ids is not None:
112
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
113
+
114
+ if past_key_values is None:
115
+ past_length = 0
116
+ past_key_values = tuple([None] * len(self.h))
117
+ else:
118
+ past_length = past_key_values[0][0].size(-2)
119
+ if position_ids is None:
120
+ position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
121
+ position_ids = position_ids.unsqueeze(0)
122
+
123
+ # GPT2Attention mask.
124
+ if attention_mask is not None:
125
+ if batch_size <= 0:
126
+ raise ValueError("batch_size has to be defined and > 0")
127
+ attention_mask = attention_mask.view(batch_size, -1)
128
+ # We create a 3D attention mask from a 2D tensor mask.
129
+ # Sizes are [batch_size, 1, 1, to_seq_length]
130
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
131
+ # this attention mask is more simple than the triangular masking of causal attention
132
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
133
+ attention_mask = attention_mask[:, None, None, :]
134
+
135
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
136
+ # masked positions, this operation will create a tensor which is 0.0 for
137
+ # positions we want to attend and the dtype's smallest value for masked positions.
138
+ # Since we are adding it to the raw scores before the softmax, this is
139
+ # effectively the same as removing these entirely.
140
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
141
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
142
+
143
+ # If a 2D or 3D attention mask is provided for the cross-attention
144
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
145
+ if self.config.add_cross_attention and encoder_hidden_states is not None:
146
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
147
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
148
+ if encoder_attention_mask is None:
149
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
150
+ encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
151
+ else:
152
+ encoder_attention_mask = None
153
+
154
+ # Prepare head mask if needed
155
+ # 1.0 in head_mask indicate we keep the head
156
+ # attention_probs has shape bsz x n_heads x N x N
157
+ # head_mask has shape n_layer x batch x n_heads x N x N
158
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
159
+
160
+ if inputs_embeds is None:
161
+ inputs_embeds = self.wte(input_ids)
162
+ hidden_states = inputs_embeds
163
+
164
+ if token_type_ids is not None:
165
+ token_type_embeds = self.wte(token_type_ids)
166
+ hidden_states = hidden_states + token_type_embeds
167
+
168
+ hidden_states = self.drop(hidden_states)
169
+
170
+ output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
171
+
172
+ if self.gradient_checkpointing and self.training:
173
+ if use_cache:
174
+ use_cache = False
175
+
176
+ presents = () if use_cache else None
177
+ all_self_attentions = () if output_attentions else None
178
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
179
+ all_hidden_states = () if output_hidden_states else None
180
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
181
+ # Model parallel
182
+ if self.model_parallel:
183
+ torch.cuda.set_device(hidden_states.device)
184
+ # Ensure layer_past is on same device as hidden_states (might not be correct)
185
+ if layer_past is not None:
186
+ layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
187
+ # Ensure that attention_mask is always on the same device as hidden_states
188
+ if attention_mask is not None:
189
+ attention_mask = attention_mask.to(hidden_states.device)
190
+ if isinstance(head_mask, torch.Tensor):
191
+ head_mask = head_mask.to(hidden_states.device)
192
+ if output_hidden_states:
193
+ all_hidden_states = all_hidden_states + (hidden_states,)
194
+
195
+ if self.gradient_checkpointing and self.training:
196
+ outputs = self._gradient_checkpointing_func(
197
+ block.__call__,
198
+ hidden_states,
199
+ None,
200
+ attention_mask,
201
+ head_mask[i],
202
+ encoder_hidden_states,
203
+ encoder_attention_mask,
204
+ use_cache,
205
+ output_attentions,
206
+ )
207
+ else:
208
+ outputs = block(
209
+ hidden_states,
210
+ layer_past=layer_past,
211
+ attention_mask=attention_mask,
212
+ head_mask=head_mask[i],
213
+ encoder_hidden_states=encoder_hidden_states,
214
+ encoder_attention_mask=encoder_attention_mask,
215
+ use_cache=use_cache,
216
+ output_attentions=output_attentions,
217
+ )
218
+
219
+ hidden_states = outputs[0]
220
+ if use_cache is True:
221
+ presents = presents + (outputs[1],)
222
+
223
+ if output_attentions:
224
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
225
+ if self.config.add_cross_attention:
226
+ all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
227
+
228
+ # Model Parallel: If it's the last layer for that device, put things on the next device
229
+ if self.model_parallel:
230
+ for k, v in self.device_map.items():
231
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
232
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
233
+
234
+ hidden_states = self.ln_f(hidden_states)
235
+
236
+ hidden_states = hidden_states.view(output_shape)
237
+ # Add last hidden state
238
+ if output_hidden_states:
239
+ all_hidden_states = all_hidden_states + (hidden_states,)
240
+
241
+ if not return_dict:
242
+ return tuple(
243
+ v
244
+ for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
245
+ if v is not None
246
+ )
247
+
248
+ return BaseModelOutputWithPastAndCrossAttentions(
249
+ last_hidden_state=hidden_states,
250
+ past_key_values=presents,
251
+ hidden_states=all_hidden_states,
252
+ attentions=all_self_attentions,
253
+ cross_attentions=all_cross_attentions,
254
+ )
255
+
256
+ class GPT2NoPositionalEncodingLMHeadModel(GPT2PreTrainedModel):
257
+ _tied_weights_keys = ["lm_head.weight"]
258
+
259
+ def __init__(self, config):
260
+ super().__init__(config)
261
+ self.transformer = GPT2NoPositionalEncodingModel(config)
262
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
263
+
264
+ # Model parallel
265
+ self.model_parallel = False
266
+ self.device_map = None
267
+
268
+ # Initialize weights and apply final processing
269
+ self.post_init()
270
+
271
+ def parallelize(self, device_map=None):
272
+ self.device_map = (
273
+ get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
274
+ if device_map is None
275
+ else device_map
276
+ )
277
+ assert_device_map(self.device_map, len(self.transformer.h))
278
+ self.transformer.parallelize(self.device_map)
279
+ self.lm_head = self.lm_head.to(self.transformer.first_device)
280
+ self.model_parallel = True
281
+
282
+ def deparallelize(self):
283
+ self.transformer.deparallelize()
284
+ self.transformer = self.transformer.to("cpu")
285
+ self.lm_head = self.lm_head.to("cpu")
286
+ self.model_parallel = False
287
+ torch.cuda.empty_cache()
288
+
289
+ def get_output_embeddings(self):
290
+ return self.lm_head
291
+
292
+ def set_output_embeddings(self, new_embeddings):
293
+ self.lm_head = new_embeddings
294
+
295
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
296
+ token_type_ids = kwargs.get("token_type_ids", None)
297
+ # Omit tokens covered by past_key_values
298
+ if past_key_values:
299
+ past_length = past_key_values[0][0].shape[2]
300
+
301
+ # Some generation methods already pass only the last input ID
302
+ if input_ids.shape[1] > past_length:
303
+ remove_prefix_length = past_length
304
+ else:
305
+ # Default to old behavior: keep only final ID
306
+ remove_prefix_length = input_ids.shape[1] - 1
307
+
308
+ input_ids = input_ids[:, remove_prefix_length:]
309
+ if token_type_ids is not None:
310
+ token_type_ids = token_type_ids[:, -input_ids.shape[1] :]
311
+
312
+ attention_mask = kwargs.get("attention_mask", None)
313
+ position_ids = kwargs.get("position_ids", None)
314
+
315
+ if attention_mask is not None and position_ids is None:
316
+ # create position_ids on the fly for batch generation
317
+ position_ids = attention_mask.long().cumsum(-1) - 1
318
+ position_ids.masked_fill_(attention_mask == 0, 1)
319
+ if past_key_values:
320
+ position_ids = position_ids[:, -input_ids.shape[1] :]
321
+ else:
322
+ position_ids = None
323
+
324
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
325
+ if inputs_embeds is not None and past_key_values is None:
326
+ model_inputs = {"inputs_embeds": inputs_embeds}
327
+ else:
328
+ model_inputs = {"input_ids": input_ids}
329
+
330
+ model_inputs.update(
331
+ {
332
+ "past_key_values": past_key_values,
333
+ "use_cache": kwargs.get("use_cache"),
334
+ "position_ids": position_ids,
335
+ "attention_mask": attention_mask,
336
+ "token_type_ids": token_type_ids,
337
+ }
338
+ )
339
+
340
+ return model_inputs
341
+
342
+ def forward(
343
+ self,
344
+ input_ids: Optional[torch.LongTensor] = None,
345
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
346
+ attention_mask: Optional[torch.FloatTensor] = None,
347
+ token_type_ids: Optional[torch.LongTensor] = None,
348
+ position_ids: Optional[torch.LongTensor] = None,
349
+ head_mask: Optional[torch.FloatTensor] = None,
350
+ inputs_embeds: Optional[torch.FloatTensor] = None,
351
+ encoder_hidden_states: Optional[torch.Tensor] = None,
352
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
353
+ labels: Optional[torch.LongTensor] = None,
354
+ use_cache: Optional[bool] = None,
355
+ output_attentions: Optional[bool] = None,
356
+ output_hidden_states: Optional[bool] = None,
357
+ return_dict: Optional[bool] = None,
358
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
359
+ r"""
360
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
361
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
362
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
363
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
364
+ """
365
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
366
+
367
+ transformer_outputs = self.transformer(
368
+ input_ids,
369
+ past_key_values=past_key_values,
370
+ attention_mask=attention_mask,
371
+ token_type_ids=token_type_ids,
372
+ position_ids=position_ids,
373
+ head_mask=head_mask,
374
+ inputs_embeds=inputs_embeds,
375
+ encoder_hidden_states=encoder_hidden_states,
376
+ encoder_attention_mask=encoder_attention_mask,
377
+ use_cache=use_cache,
378
+ output_attentions=output_attentions,
379
+ output_hidden_states=output_hidden_states,
380
+ return_dict=return_dict,
381
+ )
382
+ hidden_states = transformer_outputs[0]
383
+
384
+ # Set device for model parallelism
385
+ if self.model_parallel:
386
+ torch.cuda.set_device(self.transformer.first_device)
387
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
388
+
389
+ lm_logits = self.lm_head(hidden_states)
390
+
391
+ loss = None
392
+ if labels is not None:
393
+ # move labels to correct device to enable model parallelism
394
+ labels = labels.to(lm_logits.device)
395
+ # Shift so that tokens < n predict n
396
+ shift_logits = lm_logits[..., :-1, :].contiguous()
397
+ shift_labels = labels[..., 1:].contiguous()
398
+ # Flatten the tokens
399
+ loss_fct = CrossEntropyLoss()
400
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
401
+
402
+ if not return_dict:
403
+ output = (lm_logits,) + transformer_outputs[1:]
404
+ return ((loss,) + output) if loss is not None else output
405
+
406
+ return CausalLMOutputWithCrossAttentions(
407
+ loss=loss,
408
+ logits=lm_logits,
409
+ past_key_values=transformer_outputs.past_key_values,
410
+ hidden_states=transformer_outputs.hidden_states,
411
+ attentions=transformer_outputs.attentions,
412
+ cross_attentions=transformer_outputs.cross_attentions,
413
+ )
414
+
415
+ @staticmethod
416
+ def _reorder_cache(
417
+ past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
418
+ ) -> Tuple[Tuple[torch.Tensor]]:
419
+ """
420
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
421
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
422
+ beam_idx at every generation step.
423
+ """
424
+ return tuple(
425
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
426
+ for layer_past in past_key_values
427
+ )