ydshieh
committed on
Commit
•
b8c22f0
1
Parent(s):
e30ab96
Fix project_encoder
Browse files
vit_gpt2/configuration_vit_gpt2.py
CHANGED
@@ -47,9 +47,9 @@ class ViTGPT2Config(PretrainedConfig):
|
|
47 |
_project_encoder = getattr(self.text_config, "project_encoder", None)
|
48 |
if project_encoder is not None and _project_encoder is not None:
|
49 |
assert project_encoder == _project_encoder
|
50 |
-
elif project_encoder:
|
51 |
_project_encoder = project_encoder
|
52 |
-
elif _project_encoder:
|
53 |
project_encoder = _project_encoder
|
54 |
else:
|
55 |
project_encoder = False
|
|
|
47 |
_project_encoder = getattr(self.text_config, "project_encoder", None)
|
48 |
if project_encoder is not None and _project_encoder is not None:
|
49 |
assert project_encoder == _project_encoder
|
50 |
+
elif project_encoder is not None:
|
51 |
_project_encoder = project_encoder
|
52 |
+
elif _project_encoder is not None:
|
53 |
project_encoder = _project_encoder
|
54 |
else:
|
55 |
project_encoder = False
|
vit_gpt2/modeling_flax_vit_gpt2_lm.py
CHANGED
@@ -534,6 +534,7 @@ class FlaxViTGPT2LMForConditionalGeneration(FlaxViTGPT2LMPreTrainedModel):
|
|
534 |
vision_pretrained_model_name_or_path, *vision_model_args, **vision_kwargs
|
535 |
)
|
536 |
|
|
|
537 |
if text_model is None:
|
538 |
assert (
|
539 |
text_pretrained_model_name_or_path is not None
|
@@ -542,6 +543,8 @@ class FlaxViTGPT2LMForConditionalGeneration(FlaxViTGPT2LMPreTrainedModel):
|
|
542 |
if "config" not in text_kwargs:
|
543 |
text_config = GPT2Config.from_pretrained(text_pretrained_model_name_or_path)
|
544 |
text_config.project_encoder = text_kwargs.pop("project_encoder", None)
|
|
|
|
|
545 |
text_kwargs["config"] = text_config
|
546 |
|
547 |
text_kwargs["config"].add_cross_attention = True
|
@@ -553,8 +556,6 @@ class FlaxViTGPT2LMForConditionalGeneration(FlaxViTGPT2LMPreTrainedModel):
|
|
553 |
|
554 |
# instantiate config with corresponding kwargs
|
555 |
dtype = kwargs.pop("dtype", jnp.float32)
|
556 |
-
project_encoder = kwargs.pop("project_encoder", None)
|
557 |
-
|
558 |
config = ViTGPT2Config.from_vision_text_configs(
|
559 |
vision_model.config, text_model.config, project_encoder=project_encoder, **kwargs
|
560 |
)
|
|
|
534 |
vision_pretrained_model_name_or_path, *vision_model_args, **vision_kwargs
|
535 |
)
|
536 |
|
537 |
+
project_encoder = kwargs.pop("project_encoder", None)
|
538 |
if text_model is None:
|
539 |
assert (
|
540 |
text_pretrained_model_name_or_path is not None
|
|
|
543 |
if "config" not in text_kwargs:
|
544 |
text_config = GPT2Config.from_pretrained(text_pretrained_model_name_or_path)
|
545 |
text_config.project_encoder = text_kwargs.pop("project_encoder", None)
|
546 |
+
if project_encoder is not None:
|
547 |
+
text_config.project_encoder = project_encoder
|
548 |
text_kwargs["config"] = text_config
|
549 |
|
550 |
text_kwargs["config"].add_cross_attention = True
|
|
|
556 |
|
557 |
# instantiate config with corresponding kwargs
|
558 |
dtype = kwargs.pop("dtype", jnp.float32)
|
|
|
|
|
559 |
config = ViTGPT2Config.from_vision_text_configs(
|
560 |
vision_model.config, text_model.config, project_encoder=project_encoder, **kwargs
|
561 |
)
|