ydshieh committed
Commit 7f266a7
Parent: 06bcf58

revert to original flax gpt2

Files changed (1)
  1. vit_gpt2/modeling_flax_gpt2.py +36 -185
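Net effect of the revert: the local copy matches the stock Flax GPT-2 decoder again, i.e. causal self-attention only (the `key_value_states` / `encoder_hidden_states` cross-attention plumbing is removed) and a plain `FlaxCausalLMOutput` return type. A minimal usage sketch against the upstream `FlaxGPT2LMHeadModel`, which the reverted file mirrors; the checkpoint and prompt below are illustrative only:

```python
from transformers import AutoTokenizer, FlaxGPT2LMHeadModel

# Illustrative checkpoint; after the revert the decoder is driven exactly like
# stock Flax GPT-2: no encoder_hidden_states / encoder_attention_mask arguments.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")

inputs = tokenizer("a photo of", return_tensors="np")
outputs = model(**inputs)

# FlaxCausalLMOutput exposes logits / hidden_states / attentions,
# with no cross_attentions field.
print(outputs.logits.shape)  # (1, sequence_length, vocab_size)
```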
vit_gpt2/modeling_flax_gpt2.py CHANGED
@@ -23,11 +23,11 @@ from flax.linen import combine_masks, make_causal_mask
 from flax.linen.attention import dot_product_attention_weights
 from jax import lax
 
-from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
-from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPast, FlaxCausalLMOutput, FlaxBaseModelOutputWithPastAndCrossAttentions, FlaxCausalLMOutputWithCrossAttentions
-from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
-from transformers.utils import logging
-from transformers.models.gpt2.configuration_gpt2 import GPT2Config
+from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
+from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPast, FlaxCausalLMOutput
+from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
+from ...utils import logging
+from .configuration_gpt2 import GPT2Config
 
 
 logger = logging.get_logger(__name__)
@@ -117,8 +117,6 @@ class FlaxConv1D(nn.Module):
 class FlaxGPT2Attention(nn.Module):
     config: GPT2Config
     dtype: jnp.dtype = jnp.float32
-    causal: bool = True
-    self_attn: bool = True
 
     def setup(self):
         config = self.config
@@ -126,18 +124,10 @@ class FlaxGPT2Attention(nn.Module):
         self.num_heads = config.num_attention_heads
         self.head_dim = self.embed_dim // self.num_heads
 
-        factor = 3 if self.self_attn else 2
-        self.c_attn = FlaxConv1D(features=factor * self.embed_dim, dtype=self.dtype)
+        self.c_attn = FlaxConv1D(features=3 * self.embed_dim, dtype=self.dtype)
         self.c_proj = FlaxConv1D(self.embed_dim, dtype=self.dtype)
-
-        if not self.self_attn:
-            self.c_query_attn = FlaxConv1D(features=1 * self.embed_dim, dtype=self.dtype)
-
         self.resid_dropout = nn.Dropout(rate=config.resid_pdrop)
-        if self.causal:
-            self.causal_mask = make_causal_mask(
-                jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool"
-            )
+        self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")
 
     def _split_heads(self, hidden_states):
         return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
@@ -180,30 +170,13 @@ class FlaxGPT2Attention(nn.Module):
     def __call__(
         self,
         hidden_states,
-        key_value_states: Optional[jnp.ndarray] = None,
         attention_mask=None,
         deterministic: bool = True,
         init_cache: bool = False,
         output_attentions: bool = False,
     ):
-
-        # if key_value_states are provided this layer is used as a cross-attention layer
-        # for the decoder
-        is_cross_attention = key_value_states is not None
-
-        if not is_cross_attention:
-            # self_attention
-            assert self.self_attn
-            qkv_out = self.c_attn(hidden_states)
-            query, key, value = jnp.split(qkv_out, 3, axis=2)
-        else:
-            # cross_attentions
-            assert not self.self_attn
-            assert not self.causal
-            q_out = self.c_query_attn(hidden_states)
-            (query,) = jnp.split(q_out, 1, axis=2)
-            kv_out = self.c_attn(key_value_states)
-            key, value = jnp.split(kv_out, 2, axis=2)
+        qkv_out = self.c_attn(hidden_states)
+        query, key, value = jnp.split(qkv_out, 3, axis=2)
 
         query = self._split_heads(query)
         key = self._split_heads(key)
@@ -211,27 +184,20 @@ class FlaxGPT2Attention(nn.Module):
 
         query_length, key_length = query.shape[1], key.shape[1]
 
-        if self.causal:
-            if self.has_variable("cache", "cached_key"):
-                mask_shift = self.variables["cache"]["cache_index"]
-                max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
-                causal_mask = lax.dynamic_slice(
-                    self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
-                )
-            else:
-                causal_mask = self.causal_mask[:, :, :query_length, :key_length]
-
-            batch_size = hidden_states.shape[0]
-            causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
-
-        # combine masks if needed
-        if attention_mask is not None and self.causal:
-            attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
-            attention_mask = combine_masks(attention_mask, causal_mask)
-        elif self.causal:
-            attention_mask = causal_mask
-        elif attention_mask is not None:
-            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
+        if self.has_variable("cache", "cached_key"):
+            mask_shift = self.variables["cache"]["cache_index"]
+            max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
+            causal_mask = lax.dynamic_slice(
+                self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
+            )
+        else:
+            causal_mask = self.causal_mask[:, :, :query_length, :key_length]
+
+        batch_size = hidden_states.shape[0]
+        causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
+
+        attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
+        attention_mask = combine_masks(attention_mask, causal_mask)
 
         dropout_rng = None
         if not deterministic and self.config.attn_pdrop > 0.0:
@@ -239,18 +205,15 @@ class FlaxGPT2Attention(nn.Module):
 
         # During fast autoregressive decoding, we feed one position at a time,
         # and cache the keys and values step by step.
-        if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
+        if self.has_variable("cache", "cached_key") or init_cache:
             key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask)
 
         # transform boolean mask into float mask
-        if attention_mask is not None:
-            attention_bias = lax.select(
-                attention_mask > 0,
-                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
-            )
-        else:
-            attention_bias = None
+        attention_bias = lax.select(
+            attention_mask > 0,
+            jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
+            jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
+        )
 
         # usual dot product attention
         attn_weights = dot_product_attention_weights(
@@ -298,23 +261,11 @@ class FlaxGPT2Block(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
-
         hidden_size = self.config.hidden_size
         inner_dim = self.config.n_inner if self.config.n_inner is not None else 4 * hidden_size
 
         self.ln_1 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
         self.attn = FlaxGPT2Attention(self.config, dtype=self.dtype)
-
-        if self.config.add_cross_attention:
-            self.ln_cross_attn = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
-            # [IMPORTANT] Cross attention requires ``causal=False``! This is a bug I made previously.
-            self.crossattention = FlaxGPT2Attention(config=self.config, dtype=self.dtype, causal=False, self_attn=False)
-
-        project_encoder = getattr(self.config, "project_encoder", None)
-        if project_encoder:
-            self.encoder_projection_ln = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
-            self.encoder_projection_mlp = FlaxGPT2MLP(self.config, self.config.hidden_size, dtype=self.dtype)
-
         self.ln_2 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
         self.mlp = FlaxGPT2MLP(self.config, inner_dim, dtype=self.dtype)
 
@@ -322,8 +273,6 @@ class FlaxGPT2Block(nn.Module):
         self,
         hidden_states,
         attention_mask=None,
-        encoder_hidden_states: Optional[jnp.ndarray] = None,
-        encoder_attention_mask: Optional[jnp.ndarray] = None,
         deterministic: bool = True,
         init_cache: bool = False,
         output_attentions: bool = False,
@@ -341,61 +290,13 @@ class FlaxGPT2Block(nn.Module):
         attn_output = outputs[0]
         hidden_states = attn_output + residual
 
-        # Cross-Attention Block
-        cross_attn_weights = None
-        if encoder_hidden_states is not None:
-
-            # add one self-attention block for cross-attention
-            if not hasattr(self, "crossattention"):
-                raise ValueError(
-                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
-                    "cross-attention layers by setting `config.add_cross_attention=True`"
-                )
-
-            project_encoder = getattr(self.config, "project_encoder", None)
-            if project_encoder:
-                residual = encoder_hidden_states
-                encoder_hidden_states = self.encoder_projection_ln(encoder_hidden_states)
-                feed_forward_hidden_states = self.encoder_projection_mlp(
-                    encoder_hidden_states, deterministic=deterministic
-                )
-                # residual connection
-                encoder_hidden_states = residual + feed_forward_hidden_states
-
-            residual = hidden_states
-            hidden_states = self.ln_cross_attn(hidden_states)
-
-            cross_attn_outputs = self.crossattention(
-                hidden_states=hidden_states,
-                key_value_states=encoder_hidden_states,
-                attention_mask=encoder_attention_mask,
-                deterministic=deterministic,
-                # `init_cache` is only for decoder's `self_attn`
-                init_cache=False,
-                output_attentions=output_attentions,
-            )
-            # residual connection
-            cross_attn_output = cross_attn_outputs[0]
-            hidden_states = cross_attn_output + residual
-
-            if output_attentions:
-                cross_attn_weights = cross_attn_outputs[1]
-
         residual = hidden_states
         hidden_states = self.ln_2(hidden_states)
         feed_forward_hidden_states = self.mlp(hidden_states, deterministic=deterministic)
         # residual connection
         hidden_states = residual + feed_forward_hidden_states
 
-        outputs = (hidden_states,)
-
-        if output_attentions:
-            self_attn_weights = attn_output[1]
-            outputs += (self_attn_weights,)
-            if cross_attn_weights is not None:
-                outputs += (cross_attn_weights,)
-
-        return outputs
+        return (hidden_states,) + outputs[1:]
 
 
 class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
@@ -427,24 +328,7 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}
 
-        if self.config.add_cross_attention:
-            encoder_hidden_states = jnp.zeros(input_shape + (self.config.n_embd,))
-            encoder_attention_mask = attention_mask
-            module_init_outputs = self.module.init(
-                rngs, input_ids, attention_mask, position_ids,
-                encoder_hidden_states, encoder_attention_mask, return_dict=False
-            )
-        else:
-            module_init_outputs = self.module.init(
-                rngs, input_ids, attention_mask, position_ids, return_dict=False
-            )
-
-        return module_init_outputs["params"]
-
-        # TODO: Remove if OK
-        # @classmethod
-        # def _from_config(cls, config, **kwargs):
-        #     return super()._from_config(config, **kwargs)
+        return self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"]
 
     def init_cache(self, batch_size, max_length):
         r"""
@@ -471,8 +355,6 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
         input_ids,
         attention_mask=None,
         position_ids=None,
-        encoder_hidden_states: Optional[jnp.ndarray] = None,
-        encoder_attention_mask: Optional[jnp.ndarray] = None,
         params: dict = None,
         past_key_values: dict = None,
         dropout_rng: jax.random.PRNGKey = None,
@@ -487,10 +369,6 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
         )
         return_dict = return_dict if return_dict is not None else self.config.return_dict
 
-        if encoder_hidden_states is not None and encoder_attention_mask is None:
-            batch_size, sequence_length = encoder_hidden_states.shape[:2]
-            encoder_attention_mask = jnp.ones((batch_size, sequence_length))
-
         batch_size, sequence_length = input_ids.shape
 
         if position_ids is None:
@@ -521,8 +399,6 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
-            encoder_hidden_states,
-            encoder_attention_mask,
            not train,
            False,
            output_attentions,
@@ -557,8 +433,6 @@ class FlaxGPT2BlockCollection(nn.Module):
         self,
         hidden_states,
         attention_mask=None,
-        encoder_hidden_states: Optional[jnp.ndarray] = None,
-        encoder_attention_mask: Optional[jnp.ndarray] = None,
         deterministic: bool = True,
         init_cache: bool = False,
         output_attentions: bool = False,
@@ -567,7 +441,6 @@ class FlaxGPT2BlockCollection(nn.Module):
     ):
         all_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None
-        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
 
         for block in self.blocks:
             if output_hidden_states:
@@ -576,8 +449,6 @@ class FlaxGPT2BlockCollection(nn.Module):
             layer_outputs = block(
                 hidden_states,
                 attention_mask,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
                 deterministic=deterministic,
                 init_cache=init_cache,
                 output_attentions=output_attentions,
@@ -587,25 +458,19 @@ class FlaxGPT2BlockCollection(nn.Module):
             if output_attentions:
                 all_attentions += (layer_outputs[1],)
 
-                if encoder_hidden_states is not None:
-                    all_cross_attentions += (layer_outputs[2],)
-
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
 
-        # In Flax, `past_key_values` is not contained in modules' outputs.
-        outputs = [hidden_states, all_hidden_states, all_attentions, all_cross_attentions]
+        outputs = (hidden_states,)
 
         if not return_dict:
             return tuple(v for v in outputs if v is not None)
 
-        # with cross_attn
-        return FlaxBaseModelOutputWithPastAndCrossAttentions(
+        return FlaxBaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=None,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
-            cross_attentions=all_cross_attentions,
        )
 
 
@@ -637,8 +502,6 @@ class FlaxGPT2Module(nn.Module):
         input_ids,
         attention_mask,
         position_ids,
-        encoder_hidden_states: Optional[jnp.ndarray] = None,
-        encoder_attention_mask: Optional[jnp.ndarray] = None,
         deterministic=True,
         init_cache: bool = False,
         output_attentions: bool = False,
@@ -654,8 +517,6 @@ class FlaxGPT2Module(nn.Module):
         outputs = self.h(
            hidden_states,
            attention_mask,
-            encoder_hidden_states,
-            encoder_attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
@@ -669,11 +530,10 @@ class FlaxGPT2Module(nn.Module):
         if not return_dict:
             return (hidden_states,) + outputs[1:]
 
-        return FlaxBaseModelOutputWithPastAndCrossAttentions(
+        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
-            cross_attentions=outputs.cross_attentions,
        )
 
 
@@ -708,8 +568,6 @@ class FlaxGPT2LMHeadModule(nn.Module):
         input_ids,
         attention_mask,
         position_ids,
-        encoder_hidden_states: Optional[jnp.ndarray] = None,
-        encoder_attention_mask: Optional[jnp.ndarray] = None,
         deterministic: bool = True,
         init_cache: bool = False,
         output_attentions: bool = False,
@@ -720,8 +578,6 @@ class FlaxGPT2LMHeadModule(nn.Module):
            input_ids,
            attention_mask,
            position_ids,
-            encoder_hidden_states,
-            encoder_attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
@@ -740,13 +596,8 @@ class FlaxGPT2LMHeadModule(nn.Module):
         if not return_dict:
             return (lm_logits,) + outputs[1:]
 
-        return FlaxCausalLMOutputWithCrossAttentions(
-            logits=lm_logits,
-            past_key_values=None,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-            cross_attentions=outputs.cross_attentions
-        )
+        return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
+
 
 @add_start_docstrings(
     """