ydshieh committed
Commit 7d3b1a0
1 Parent(s): e5b3a97

Change a parameter in decode(): deterministic -> train

vit_gpt2/modeling_flax_vit_gpt2_lm.py CHANGED
@@ -398,21 +398,6 @@ class FlaxViTGPT2LMPreTrainedModel(FlaxPreTrainedModel):
         )
 
 
-# @add_start_docstrings(
-#     "The bare Bart Model transformer outputting raw hidden-states without any specific head on top.",
-#     BART_START_DOCSTRING,
-# )
-# class FlaxViTGPT2LMModel(FlaxViTGPT2LMPreTrainedModel):
-#     config: BartConfig
-#     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-#     module_class = FlaxViTGPT2LMModule
-#
-#
-# append_call_sample_docstring(
-#     FlaxBartModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC
-# )
-
-
 class FlaxViTGPT2LMForConditionalGeneration(FlaxViTGPT2LMPreTrainedModel):
     module_class = FlaxViTGPT2LMForConditionalGenerationModule
     dtype: jnp.dtype = jnp.float32
@@ -428,7 +413,7 @@ class FlaxViTGPT2LMForConditionalGeneration(FlaxViTGPT2LMPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        deterministic: bool = True,
+        train: bool = False,
         params: dict = None,
         dropout_rng: PRNGKey = None,
     ):
@@ -443,7 +428,7 @@ class FlaxViTGPT2LMForConditionalGeneration(FlaxViTGPT2LMPreTrainedModel):
             output_attentions,
             output_hidden_states,
             return_dict,
-            not deterministic,
+            train,
             params,
             dropout_rng,
         )
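
For context, `deterministic` and `train` are simply negations of each other: Flax's built-in layers such as `nn.Dropout` take `deterministic`, while the public entry point here now takes `train` and forwards it directly instead of passing `not deterministic`. The sketch below is not the repository's model; it is a minimal, hypothetical Flax module (`TinyHead`, its sizes, and the dropout rate are made up) illustrating that convention.

# Minimal sketch of the train/deterministic convention (hypothetical module,
# not FlaxViTGPT2LMForConditionalGeneration).
import jax
import jax.numpy as jnp
import flax.linen as nn


class TinyHead(nn.Module):
    rate: float = 0.1

    @nn.compact
    def __call__(self, x, train: bool = False):
        # Public API takes `train`; Flax's Dropout wants `deterministic`,
        # so the wrapper passes the negation.
        x = nn.Dropout(rate=self.rate)(x, deterministic=not train)
        return nn.Dense(features=4)(x)


model = TinyHead()
x = jnp.ones((2, 8))
# init runs with the default train=False, so no dropout rng is needed.
variables = model.init(jax.random.PRNGKey(0), x)

# Inference: train=False disables dropout.
y_eval = model.apply(variables, x, train=False)

# Training: train=True enables dropout, so a "dropout" rng must be supplied.
y_train = model.apply(
    variables, x, train=True, rngs={"dropout": jax.random.PRNGKey(1)}
)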