ydshieh commited on
Commit
e30ab96
1 Parent(s): 5081c5d

Fix project_encoder

Browse files
vit_gpt2/modeling_flax_vit_gpt2_lm.py CHANGED
@@ -553,8 +553,10 @@ class FlaxViTGPT2LMForConditionalGeneration(FlaxViTGPT2LMPreTrainedModel):
553
 
554
  # instantiate config with corresponding kwargs
555
  dtype = kwargs.pop("dtype", jnp.float32)
 
 
556
  config = ViTGPT2Config.from_vision_text_configs(
557
- vision_model.config, text_model.config, **kwargs
558
  )
559
 
560
  # init model
 
553
 
554
  # instantiate config with corresponding kwargs
555
  dtype = kwargs.pop("dtype", jnp.float32)
556
+ project_encoder = kwargs.pop("project_encoder", None)
557
+
558
  config = ViTGPT2Config.from_vision_text_configs(
559
+ vision_model.config, text_model.config, project_encoder=project_encoder, **kwargs
560
  )
561
 
562
  # init model