ydshieh committed
Commit b8c22f0
Parent: e30ab96

Fix project_encoder

vit_gpt2/configuration_vit_gpt2.py CHANGED

@@ -47,9 +47,9 @@ class ViTGPT2Config(PretrainedConfig):
         _project_encoder = getattr(self.text_config, "project_encoder", None)
         if project_encoder is not None and _project_encoder is not None:
             assert project_encoder == _project_encoder
-        elif project_encoder:
+        elif project_encoder is not None:
             _project_encoder = project_encoder
-        elif _project_encoder:
+        elif _project_encoder is not None:
             project_encoder = _project_encoder
         else:
             project_encoder = False
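The configuration change swaps truthiness checks (`elif project_encoder:`) for explicit `is not None` checks, so an explicitly passed `project_encoder=False` is reconciled with the text config instead of falling through to the default branch. A minimal sketch of the patched logic, extracted into a hypothetical standalone helper for illustration:

```python
# Sketch of the patched reconciliation in ViTGPT2Config.__init__;
# the helper name `resolve_project_encoder` is illustrative, not part of the repo.
def resolve_project_encoder(project_encoder, text_config_value):
    """Reconcile an explicit `project_encoder` argument with the value
    already stored on the text config (None means "not set")."""
    _project_encoder = text_config_value
    if project_encoder is not None and _project_encoder is not None:
        # Both sides set: they must agree.
        assert project_encoder == _project_encoder
    elif project_encoder is not None:
        # Only the argument is set (True *or* False): propagate it.
        _project_encoder = project_encoder
    elif _project_encoder is not None:
        # Only the text config is set: adopt its value.
        project_encoder = _project_encoder
    else:
        # Neither side set: default to no projection.
        project_encoder = False
    return project_encoder, _project_encoder

# With the old truthiness check, an explicit False was never recorded
# into the text-config side; with the `is not None` check it is.
assert resolve_project_encoder(False, None) == (False, False)
```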
vit_gpt2/modeling_flax_vit_gpt2_lm.py CHANGED

@@ -534,6 +534,7 @@ class FlaxViTGPT2LMForConditionalGeneration(FlaxViTGPT2LMPreTrainedModel):
             vision_pretrained_model_name_or_path, *vision_model_args, **vision_kwargs
         )
 
+        project_encoder = kwargs.pop("project_encoder", None)
         if text_model is None:
             assert (
                 text_pretrained_model_name_or_path is not None
@@ -542,6 +543,8 @@ class FlaxViTGPT2LMForConditionalGeneration(FlaxViTGPT2LMPreTrainedModel):
             if "config" not in text_kwargs:
                 text_config = GPT2Config.from_pretrained(text_pretrained_model_name_or_path)
                 text_config.project_encoder = text_kwargs.pop("project_encoder", None)
+                if project_encoder is not None:
+                    text_config.project_encoder = project_encoder
                 text_kwargs["config"] = text_config
 
         text_kwargs["config"].add_cross_attention = True
@@ -553,8 +556,6 @@ class FlaxViTGPT2LMForConditionalGeneration(FlaxViTGPT2LMPreTrainedModel):
 
         # instantiate config with corresponding kwargs
         dtype = kwargs.pop("dtype", jnp.float32)
-        project_encoder = kwargs.pop("project_encoder", None)
-
         config = ViTGPT2Config.from_vision_text_configs(
             vision_model.config, text_model.config, project_encoder=project_encoder, **kwargs
         )
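In the modeling file, `project_encoder` is now popped from the model-level kwargs before the GPT-2 text config is assembled, so a value passed there can override `text_config.project_encoder` and still be forwarded to `ViTGPT2Config.from_vision_text_configs` at the end. A minimal, library-free sketch of that ordering, assuming a hypothetical `build_text_config` helper with a `SimpleNamespace` standing in for `GPT2Config`:

```python
from types import SimpleNamespace

def build_text_config(kwargs, text_kwargs):
    # The fix: pop `project_encoder` from the model-level kwargs *before*
    # the text config is assembled, so the value is available here.
    project_encoder = kwargs.pop("project_encoder", None)

    text_config = SimpleNamespace()  # stand-in for GPT2Config.from_pretrained(...)
    # The text-level kwarg is applied first, as in the original code...
    text_config.project_encoder = text_kwargs.pop("project_encoder", None)
    # ...then a model-level value, if given, overrides it.
    if project_encoder is not None:
        text_config.project_encoder = project_encoder

    # `project_encoder` is also the value later forwarded to
    # ViTGPT2Config.from_vision_text_configs(..., project_encoder=project_encoder).
    return text_config, project_encoder

config, project_encoder = build_text_config({"project_encoder": True}, {})
assert config.project_encoder is True and project_encoder is True
```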