ydshieh committed on
Commit 3ed2a5d
1 Parent(s): 155e823

Clean Flax ViT + GPT2-LM script

Files changed (1)
  1. vit_gpt2/modeling_flax_vit_gpt2_lm.py +141 -269
vit_gpt2/modeling_flax_vit_gpt2_lm.py CHANGED
@@ -6,39 +6,27 @@ import jax.numpy as jnp
 from flax.core.frozen_dict import FrozenDict, unfreeze
 from jax import lax
 from jax.random import PRNGKey
-from transformers import GPT2Config, FlaxViTModel, ViTConfig
 from transformers.modeling_flax_outputs import (
     FlaxCausalLMOutputWithCrossAttentions,
     FlaxSeq2SeqLMOutput,
     FlaxSeq2SeqModelOutput,
 )
-from transformers.models.bart.modeling_flax_bart import (
-    shift_tokens_right,
-)
 from .modeling_flax_gpt2 import (
     FlaxGPT2Module,
     FlaxGPT2Model,
     FlaxGPT2LMHeadModule,
     FlaxGPT2LMHeadModel,
-    FlaxPreTrainedModel
 )
-from transformers.models.vit.modeling_flax_vit import FlaxViTModule
-
-from .configuration_vit_gpt2 import ViTGPT2Config


-def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
-    """
-    Shift input ids one token to the right.
-    """
-    shifted_input_ids = jnp.roll(input_ids, 1, axis=-1)
-    shifted_input_ids = jax.ops.index_update(shifted_input_ids, (..., 0), decoder_start_token_id)
-    # replace possible -100 values in labels by `pad_token_id`
-    shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
-
-    return shifted_input_ids
-
 class FlaxViTGPT2LMModule(nn.Module):
     config: ViTGPT2Config
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation

@@ -54,16 +42,16 @@ class FlaxViTGPT2LMModule(nn.Module):
         return self.decoder

     def __call__(
-        self,
-        pixel_values,
-        input_ids,
-        attention_mask,
-        position_ids,
-        encoder_attention_mask: Optional[jnp.ndarray] = None,
-        output_attentions: bool = False,
-        output_hidden_states: bool = False,
-        return_dict: bool = True,
-        deterministic: bool = True,
     ):
         encoder_outputs = self.encoder(
             pixel_values=pixel_values,
@@ -74,11 +62,11 @@ class FlaxViTGPT2LMModule(nn.Module):
         )

         decoder_outputs = self.decoder(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
             encoder_hidden_states=encoder_outputs[0],
-            encoder_attention_mask=encoder_attention_mask,
             deterministic=deterministic,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
@@ -98,10 +86,14 @@ class FlaxViTGPT2LMModule(nn.Module):
             encoder_attentions=encoder_outputs.attentions,
         )

 class FlaxViTGPT2LMForConditionalGenerationModule(nn.Module):
     config: ViTGPT2Config
     dtype: jnp.dtype = jnp.float32
-    bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros

     def setup(self):
         self.model = FlaxViTGPT2LMModule(config=self.config, dtype=self.dtype)
@@ -115,10 +107,10 @@ class FlaxViTGPT2LMForConditionalGenerationModule(nn.Module):
     def __call__(
         self,
         pixel_values,
-        input_ids,
         attention_mask,
-        position_ids,
-        encoder_attention_mask: Optional[jnp.ndarray] = None,
         output_attentions: bool = False,
         output_hidden_states: bool = False,
         return_dict: bool = True,
@@ -126,10 +118,10 @@ class FlaxViTGPT2LMForConditionalGenerationModule(nn.Module):
     ):
         outputs = self.model(
             pixel_values=pixel_values,
-            input_ids=input_ids,
             attention_mask=attention_mask,
-            position_ids=position_ids,
-            encoder_attention_mask=encoder_attention_mask,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
@@ -140,6 +132,7 @@ class FlaxViTGPT2LMForConditionalGenerationModule(nn.Module):


 class FlaxViTGPT2LMPreTrainedModel(FlaxPreTrainedModel):
     config_class = ViTGPT2Config
     base_model_prefix: str = "model"
     module_class: nn.Module = None
@@ -159,23 +152,23 @@ class FlaxViTGPT2LMPreTrainedModel(FlaxPreTrainedModel):
         )

         module = self.module_class(config=config, dtype=dtype, **kwargs)
-        super().__init__(
-            config, module, input_shape=input_shape, seed=seed, dtype=dtype
-        )

     def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
-        # init input tensors
-        pixel_values = jax.random.normal(rng, input_shape[0])
-        # # make sure initialization pass will work for FlaxBartForSequenceClassificationModule
-        # input_ids = jax.ops.index_update(input_ids, (..., -1), self.config.eos_token_id)

-        input_ids = jnp.zeros(input_shape[1], dtype="i4")
-        attention_mask = jnp.ones_like(input_ids)

-        batch_size, sequence_length = input_ids.shape
-        position_ids = jnp.broadcast_to(
-            jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
-        )

         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}
@@ -183,40 +176,34 @@ class FlaxViTGPT2LMPreTrainedModel(FlaxPreTrainedModel):
         return self.module.init(
             rngs,
             pixel_values,
-            input_ids,
             attention_mask,
-            position_ids,
         )["params"]

     def init_cache(self, batch_size, max_length, encoder_outputs):
-
-        input_ids = jnp.ones((batch_size, max_length), dtype="i4")
-        attention_mask = jnp.ones_like(input_ids)
-        position_ids = jnp.broadcast_to(
-            jnp.arange(jnp.atleast_2d(input_ids).shape[-1]),
-            input_ids.shape,
         )

-        def _decoder_forward(
-            module,
-            input_ids,
-            attention_mask,
-            position_ids,
-            **kwargs,
-        ):
             decoder_module = module._get_decoder_module()
             return decoder_module(
-                input_ids,
-                attention_mask,
-                position_ids,
                 **kwargs,
             )

         init_variables = self.module.init(
             jax.random.PRNGKey(0),
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
             encoder_hidden_states=encoder_outputs[0],
             init_cache=True,
             method=_decoder_forward,  # we only need to call the decoder to init the cache
@@ -234,20 +221,13 @@ class FlaxViTGPT2LMPreTrainedModel(FlaxPreTrainedModel):
         params: dict = None,
         dropout_rng: PRNGKey = None,
     ):
-        output_attentions = (
-            output_attentions
-            if output_attentions is not None
-            else self.config.output_attentions
-        )
         output_hidden_states = (
-            output_hidden_states
-            if output_hidden_states is not None
-            else self.config.output_hidden_states
-        )
-        return_dict = (
-            return_dict if return_dict is not None else self.config.return_dict
         )

         pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

         # Handle any PRNG if needed
@@ -272,11 +252,11 @@ class FlaxViTGPT2LMPreTrainedModel(FlaxPreTrainedModel):

     def decode(
         self,
-        input_ids,
         encoder_outputs,
         encoder_attention_mask: Optional[jnp.ndarray] = None,
-        attention_mask: Optional[jnp.ndarray] = None,
-        position_ids: Optional[jnp.ndarray] = None,
         past_key_values: dict = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
@@ -287,29 +267,23 @@ class FlaxViTGPT2LMPreTrainedModel(FlaxPreTrainedModel):
     ):

         output_attentions = (
-            output_attentions
-            if output_attentions is not None
-            else self.config.output_attentions
         )
         output_hidden_states = (
-            output_hidden_states
-            if output_hidden_states is not None
-            else self.config.output_hidden_states
-        )
-        return_dict = (
-            return_dict if return_dict is not None else self.config.return_dict
         )

         encoder_hidden_states = encoder_outputs[0]
         if encoder_attention_mask is None:
             batch_size, sequence_length = encoder_hidden_states.shape[:2]
             encoder_attention_mask = jnp.ones((batch_size, sequence_length))

-        batch_size, sequence_length = input_ids.shape
-        if attention_mask is None:
-            attention_mask = jnp.ones((batch_size, sequence_length))

-        if position_ids is None:
             if past_key_values is not None:
                 raise ValueError(
                     "Make sure to provide `position_ids` when passing `past_key_values`."
@@ -335,26 +309,20 @@ class FlaxViTGPT2LMPreTrainedModel(FlaxPreTrainedModel):
         else:
             mutable = False

-        def _decoder_forward(
-            module,
-            input_ids,
-            attention_mask,
-            position_ids,
-            **kwargs,
-        ):
             decoder_module = module._get_decoder_module()
             return decoder_module(
-                input_ids,
-                attention_mask,
-                position_ids,
                 **kwargs,
             )

         outputs = self.module.apply(
             inputs,
-            input_ids=jnp.array(input_ids, dtype="i4"),
-            attention_mask=jnp.array(attention_mask, dtype="i4"),
-            position_ids=jnp.array(position_ids, dtype="i4"),
             encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
             output_attentions=output_attentions,
@@ -380,9 +348,10 @@ class FlaxViTGPT2LMPreTrainedModel(FlaxPreTrainedModel):
     def __call__(
         self,
         pixel_values: jnp.ndarray,
-        input_ids: Optional[jnp.ndarray] = None,
         attention_mask: Optional[jnp.ndarray] = None,
-        position_ids: Optional[jnp.ndarray] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
@@ -390,41 +359,24 @@ class FlaxViTGPT2LMPreTrainedModel(FlaxPreTrainedModel):
         params: dict = None,
         dropout_rng: PRNGKey = None,
     ):
-        output_attentions = (
-            output_attentions
-            if output_attentions is not None
-            else self.config.output_attentions
-        )
         output_hidden_states = (
-            output_hidden_states
-            if output_hidden_states is not None
-            else self.config.output_hidden_states
-        )
-        return_dict = (
-            return_dict if return_dict is not None else self.config.return_dict
         )

         pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

-        # # prepare encoder inputs
-        # if encoder_attention_mask is None:
-        #     encoder_attention_mask = jnp.ones_like(input_ids)
-
-        # if position_ids is None:
-        #     batch_size, sequence_length = input_ids.shape
-        #     position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
-
         # prepare decoder inputs
-        # if decoder_input_ids is None:
-        #     decoder_input_ids = shift_tokens_right(
-        #         input_ids, self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id
-        #     )  # TODO: Check how to use this
-
-        if attention_mask is None:
-            attention_mask = jnp.ones_like(input_ids)
-        if position_ids is None:
-            batch_size, sequence_length = input_ids.shape
-            position_ids = jnp.broadcast_to(
                 jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
             )

@@ -434,9 +386,9 @@ class FlaxViTGPT2LMPreTrainedModel(FlaxPreTrainedModel):
         return self.module.apply(
             {"params": params or self.params},
             pixel_values=jnp.array(pixel_values, dtype=jnp.float32),
-            input_ids=jnp.array(input_ids, dtype="i4"),
-            attention_mask=jnp.array(attention_mask, dtype="i4"),
-            position_ids=jnp.array(position_ids, dtype="i4"),
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
@@ -445,17 +397,32 @@ class FlaxViTGPT2LMPreTrainedModel(FlaxPreTrainedModel):
         )


 class FlaxViTGPT2LMForConditionalGeneration(FlaxViTGPT2LMPreTrainedModel):
     module_class = FlaxViTGPT2LMForConditionalGenerationModule
     dtype: jnp.dtype = jnp.float32

     def decode(
         self,
-        input_ids,
         encoder_outputs,
         encoder_attention_mask: Optional[jnp.ndarray] = None,
-        attention_mask: Optional[jnp.ndarray] = None,
-        position_ids: Optional[jnp.ndarray] = None,
         past_key_values: dict = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
@@ -464,135 +431,42 @@ class FlaxViTGPT2LMForConditionalGeneration(FlaxViTGPT2LMPreTrainedModel):
         params: dict = None,
         dropout_rng: PRNGKey = None,
     ):
-        output_attentions = (
-            output_attentions
-            if output_attentions is not None
-            else self.config.output_attentions
-        )
-        output_hidden_states = (
-            output_hidden_states
-            if output_hidden_states is not None
-            else self.config.output_hidden_states
-        )
-        return_dict = (
-            return_dict if return_dict is not None else self.config.return_dict
-        )
-
-        encoder_hidden_states = encoder_outputs[0]
-        if encoder_attention_mask is None:
-            batch_size, sequence_length = encoder_hidden_states.shape[:2]
-            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

-        batch_size, sequence_length = input_ids.shape
-        if attention_mask is None:
-            attention_mask = jnp.ones((batch_size, sequence_length))
-
-        if position_ids is None:
-            if past_key_values is not None:
-                raise ValueError(
-                    "Make sure to provide `position_ids` when passing `past_key_values`."
-                )
-
-            position_ids = jnp.broadcast_to(
-                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
-            )
-
-        # Handle any PRNG if needed
-        rngs = {}
-        if dropout_rng is not None:
-            rngs["dropout"] = dropout_rng
-
-        inputs = {"params": params or self.params}
-
-        # if past_key_values are passed then cache is already initialized a private flag init_cache has to be
-        # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
-        # it can be changed by FlaxGPT2Attention module
-        if past_key_values:
-            inputs["cache"] = past_key_values
-            mutable = ["cache"]
-        else:
-            mutable = False
-
-        def _decoder_forward(
-            module,
-            input_ids,
-            attention_mask,
-            position_ids,
-            **kwargs,
-        ):
-            decoder_module = module._get_decoder_module()
-            outputs = decoder_module(
-                input_ids,
-                attention_mask,
-                position_ids,
-                **kwargs,
-            )
-            lm_logits = outputs[0]
-
-            return lm_logits, outputs
-
-        outputs = self.module.apply(
-            inputs,
-            input_ids=jnp.array(input_ids, dtype="i4"),
-            attention_mask=jnp.array(attention_mask, dtype="i4"),
-            position_ids=jnp.array(position_ids, dtype="i4"),
-            encoder_hidden_states=encoder_hidden_states,
-            encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            deterministic=deterministic,
-            rngs=rngs,
-            mutable=mutable,
-            method=_decoder_forward,
         )

-        if past_key_values is None:
-            lm_logits, outputs = outputs
-        else:
-            (lm_logits, outputs), past = outputs
-
-        if return_dict:
-            outputs = FlaxCausalLMOutputWithCrossAttentions(
-                logits=lm_logits,
-                hidden_states=outputs.decoder_hidden_states,
-                attentions=outputs.decoder_attentions,
-                cross_attentions=outputs.cross_attentions,
-            )
-        else:
-            outputs = (lm_logits,) + outputs[1:]
-
-        # add updated cache to model output
-        if past_key_values is not None and return_dict:
-            outputs["past_key_values"] = unfreeze(past["cache"])
-            return outputs
-        elif past_key_values is not None and not return_dict:
-            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
-
-        return outputs
-
     def prepare_inputs_for_generation(
         self,
-        input_ids,
         max_length,
-        encoder_attention_mask: Optional[jnp.DeviceArray] = None,
         attention_mask: Optional[jnp.DeviceArray] = None,
         encoder_outputs=None,
         **kwargs,
     ):
         # initializing the cache
-        batch_size, seq_length = input_ids.shape

         past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
         # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
         # But since the decoder uses a causal mask, those positions are masked anyways.
         # Thus we can create a single static attention_mask here, which is more efficient for compilation
         extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
-        if attention_mask is not None:
-            position_ids = attention_mask.cumsum(axis=-1) - 1
-            extended_attention_mask = lax.dynamic_update_slice(
-                extended_attention_mask, attention_mask, (0, 0)
-            )
         else:
             position_ids = jnp.broadcast_to(
                 jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
@@ -601,16 +475,14 @@ class FlaxViTGPT2LMForConditionalGeneration(FlaxViTGPT2LMPreTrainedModel):
         return {
             "past_key_values": past_key_values,
             "encoder_outputs": encoder_outputs,
-            "encoder_attention_mask": encoder_attention_mask,
-            "attention_mask": extended_attention_mask,
-            "position_ids": position_ids,
         }

     def update_inputs_for_generation(self, model_outputs, model_kwargs):
         model_kwargs["past_key_values"] = model_outputs.past_key_values
-        model_kwargs["position_ids"] = (
-            model_kwargs["position_ids"][:, -1:] + 1
-        )
         return model_kwargs

     @classmethod

 from flax.core.frozen_dict import FrozenDict, unfreeze
 from jax import lax
 from jax.random import PRNGKey
 from transformers.modeling_flax_outputs import (
     FlaxCausalLMOutputWithCrossAttentions,
     FlaxSeq2SeqLMOutput,
     FlaxSeq2SeqModelOutput,
 )
+from .configuration_vit_gpt2 import ViTGPT2Config
+from transformers import ViTConfig, GPT2Config
+### TODO: check FlaxPreTrainedModel
+from transformers import FlaxPreTrainedModel, FlaxViTModel
+from transformers.models.vit.modeling_flax_vit import FlaxViTModule
 from .modeling_flax_gpt2 import (
+    FlaxGPT2PreTrainedModel,
     FlaxGPT2Module,
     FlaxGPT2Model,
     FlaxGPT2LMHeadModule,
     FlaxGPT2LMHeadModel,
 )


 class FlaxViTGPT2LMModule(nn.Module):
+    """Play the same role as ``FlaxBartModule`` but with the decoder equipped with a LM head."""
     config: ViTGPT2Config
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation

         return self.decoder

     def __call__(
+        self,
+        pixel_values,
+        attention_mask,
+        decoder_input_ids,
+        decoder_attention_mask,
+        decoder_position_ids,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+        deterministic: bool = True,
     ):
         encoder_outputs = self.encoder(
             pixel_values=pixel_values,

         )

         decoder_outputs = self.decoder(
+            input_ids=decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            position_ids=decoder_position_ids,
             encoder_hidden_states=encoder_outputs[0],
+            encoder_attention_mask=attention_mask,
             deterministic=deterministic,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,

             encoder_attentions=encoder_outputs.attentions,
         )

+
 class FlaxViTGPT2LMForConditionalGenerationModule(nn.Module):
+    """Play the same role as ``FlaxBartForConditionalGenerationModule`` but with the decoder equipped with a LM head.
+
+    Actually, it is identical to ``FlaxBartForConditionalGenerationModule`` with a different name.
+    """
     config: ViTGPT2Config
     dtype: jnp.dtype = jnp.float32

     def setup(self):
         self.model = FlaxViTGPT2LMModule(config=self.config, dtype=self.dtype)

     def __call__(
         self,
         pixel_values,
         attention_mask,
+        decoder_input_ids,
+        decoder_attention_mask,
+        decoder_position_ids,
         output_attentions: bool = False,
         output_hidden_states: bool = False,
         return_dict: bool = True,

     ):
         outputs = self.model(
             pixel_values=pixel_values,
             attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            decoder_position_ids=decoder_position_ids,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,

 class FlaxViTGPT2LMPreTrainedModel(FlaxPreTrainedModel):
+    """Play the same role as ``FlaxBartPretrainedModel``"""
     config_class = ViTGPT2Config
     base_model_prefix: str = "model"
     module_class: nn.Module = None

         )

         module = self.module_class(config=config, dtype=dtype, **kwargs)
+        # This will use ``self.init_weights``.
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

     def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:

+        encoder_input_shape, decoder_input_shape = input_shape

+        # init input tensors
+        pixel_values = jax.random.normal(rng, encoder_input_shape)
+        attention_mask = None
+        decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4")
+        # make sure initialization pass will work for FlaxBartForSequenceClassificationModule
+        decoder_input_ids = jax.ops.index_update(decoder_input_ids, (..., -1), self.config.eos_token_id)
+        decoder_attention_mask = jnp.ones_like(decoder_input_ids)
+
+        batch_size, sequence_length = decoder_input_ids.shape
+        decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}

         return self.module.init(
             rngs,
             pixel_values,
             attention_mask,
+            decoder_input_ids,
+            decoder_attention_mask,
+            decoder_position_ids,
         )["params"]

     def init_cache(self, batch_size, max_length, encoder_outputs):
+        # init input variables to retrieve cache
+        decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
+        decoder_attention_mask = jnp.ones_like(decoder_input_ids)
+        decoder_position_ids = jnp.broadcast_to(
+            jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape,
         )

+        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
             decoder_module = module._get_decoder_module()
             return decoder_module(
+                input_ids=decoder_input_ids,
+                attention_mask=decoder_attention_mask,
+                position_ids=decoder_position_ids,
                 **kwargs,
             )

         init_variables = self.module.init(
             jax.random.PRNGKey(0),
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            decoder_position_ids=decoder_position_ids,
             encoder_hidden_states=encoder_outputs[0],
             init_cache=True,
             method=_decoder_forward,  # we only need to call the decoder to init the cache

         params: dict = None,
         dropout_rng: PRNGKey = None,
     ):
+        output_attentions = (output_attentions if output_attentions is not None else self.config.vit_config.output_attentions)
         output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.vit_config.output_hidden_states
         )
+        return_dict = return_dict if return_dict is not None else self.config.vit_config.return_dict

+        # (`transpose` is done in `FlaxViTPreTrainedModel.__call__()`, so we do the same here.)
         pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

         # Handle any PRNG if needed

     def decode(
         self,
+        decoder_input_ids,
         encoder_outputs,
         encoder_attention_mask: Optional[jnp.ndarray] = None,
+        decoder_attention_mask: Optional[jnp.ndarray] = None,
+        decoder_position_ids: Optional[jnp.ndarray] = None,
         past_key_values: dict = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,

     ):

         output_attentions = (
+            output_attentions if output_attentions is not None else self.config.gpt2_config.output_attentions
         )
         output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.gpt2_config.output_hidden_states
         )
+        return_dict = return_dict if return_dict is not None else self.config.gpt2_config.return_dict

         encoder_hidden_states = encoder_outputs[0]
         if encoder_attention_mask is None:
             batch_size, sequence_length = encoder_hidden_states.shape[:2]
             encoder_attention_mask = jnp.ones((batch_size, sequence_length))

+        batch_size, sequence_length = decoder_input_ids.shape
+        if decoder_attention_mask is None:
+            decoder_attention_mask = jnp.ones((batch_size, sequence_length))

+        if decoder_position_ids is None:
             if past_key_values is not None:
                 raise ValueError(
                     "Make sure to provide `position_ids` when passing `past_key_values`."

         else:
             mutable = False

+        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
             decoder_module = module._get_decoder_module()
             return decoder_module(
+                decoder_input_ids,
+                decoder_attention_mask,
+                decoder_position_ids,
                 **kwargs,
             )

         outputs = self.module.apply(
             inputs,
+            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
+            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
+            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
             encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
             output_attentions=output_attentions,

     def __call__(
         self,
         pixel_values: jnp.ndarray,
         attention_mask: Optional[jnp.ndarray] = None,
+        decoder_input_ids: Optional[jnp.ndarray] = None,
+        decoder_attention_mask: Optional[jnp.ndarray] = None,
+        decoder_position_ids: Optional[jnp.ndarray] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,

         params: dict = None,
         dropout_rng: PRNGKey = None,
     ):
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
         output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
+        return_dict = return_dict if return_dict is not None else self.config.return_dict

+        # prepare encoder inputs (`transpose` is done in `FlaxViTPreTrainedModel.__call__()`, so we do the same here.)
         pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

         # prepare decoder inputs
+        if decoder_input_ids is None:
+            decoder_input_ids = self.config.decoder_start_token_id * jnp.ones((pixel_values.shape[0], 1))
+        if decoder_attention_mask is None:
+            decoder_attention_mask = jnp.ones_like(decoder_input_ids)
+        if decoder_position_ids is None:
+            batch_size, sequence_length = decoder_input_ids.shape
+            decoder_position_ids = jnp.broadcast_to(
                 jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
             )

         return self.module.apply(
             {"params": params or self.params},
             pixel_values=jnp.array(pixel_values, dtype=jnp.float32),
+            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
+            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
+            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,

         )


+# @add_start_docstrings(
+#     "The bare Bart Model transformer outputting raw hidden-states without any specific head on top.",
+#     BART_START_DOCSTRING,
+# )
+# class FlaxViTGPT2LMModel(FlaxViTGPT2LMPreTrainedModel):
+#     config: BartConfig
+#     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+#     module_class = FlaxViTGPT2LMModule
+#
+#
+# append_call_sample_docstring(
+#     FlaxBartModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC
+# )
+
+
 class FlaxViTGPT2LMForConditionalGeneration(FlaxViTGPT2LMPreTrainedModel):
     module_class = FlaxViTGPT2LMForConditionalGenerationModule
     dtype: jnp.dtype = jnp.float32

     def decode(
         self,
+        decoder_input_ids,
         encoder_outputs,
         encoder_attention_mask: Optional[jnp.ndarray] = None,
+        decoder_attention_mask: Optional[jnp.ndarray] = None,
+        decoder_position_ids: Optional[jnp.ndarray] = None,
         past_key_values: dict = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,

         params: dict = None,
         dropout_rng: PRNGKey = None,
     ):

+        return super().decode(
+            decoder_input_ids,
+            encoder_outputs,
+            encoder_attention_mask,
+            decoder_attention_mask,
+            decoder_position_ids,
+            past_key_values,
+            output_attentions,
+            output_hidden_states,
+            return_dict,
+            not deterministic,
+            params,
+            dropout_rng,
         )

     def prepare_inputs_for_generation(
         self,
+        decoder_input_ids,
         max_length,
         attention_mask: Optional[jnp.DeviceArray] = None,
+        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
         encoder_outputs=None,
         **kwargs,
     ):
         # initializing the cache
+        batch_size, seq_length = decoder_input_ids.shape

         past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
         # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
         # But since the decoder uses a causal mask, those positions are masked anyways.
         # Thus we can create a single static attention_mask here, which is more efficient for compilation
         extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
+        if decoder_attention_mask is not None:
+            position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
+            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
         else:
             position_ids = jnp.broadcast_to(
                 jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)

         return {
             "past_key_values": past_key_values,
             "encoder_outputs": encoder_outputs,
+            "encoder_attention_mask": attention_mask,
+            "decoder_attention_mask": extended_attention_mask,
+            "decoder_position_ids": position_ids,
         }

     def update_inputs_for_generation(self, model_outputs, model_kwargs):
         model_kwargs["past_key_values"] = model_outputs.past_key_values
+        model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
         return model_kwargs

     @classmethod
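
A minimal smoke test for the interface introduced by this commit is sketched below. It is not part of the diff: it assumes an already-initialized FlaxViTGPT2LMForConditionalGeneration instance named `model`, and the batch size and image shape are illustrative example values; only the argument names (`pixel_values`, `decoder_input_ids`, etc.) come from the new code.

# Hypothetical usage sketch (not from the commit); `model` is assumed to be an
# already-initialized FlaxViTGPT2LMForConditionalGeneration.
import jax.numpy as jnp

batch_size = 2
# __call__ transposes NCHW pixel values to NHWC internally, so pass channels-first images.
pixel_values = jnp.zeros((batch_size, 3, 224, 224), dtype=jnp.float32)

# `decoder_input_ids` replaces the old `input_ids` argument; if it is omitted,
# the new __call__ builds a single `decoder_start_token_id` token per image.
decoder_input_ids = jnp.zeros((batch_size, 1), dtype="i4")

outputs = model(
    pixel_values=pixel_values,
    decoder_input_ids=decoder_input_ids,
)

The same keyword names (`decoder_input_ids`, `decoder_attention_mask`, `decoder_position_ids`) are what `prepare_inputs_for_generation` now feeds into `decode` during generation.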