ydshieh committed
Commit • 54ece9e
1 Parent(s): 7b292b7

Add project_encoder and related layers
vit_gpt2/configuration_vit_gpt2.py CHANGED

@@ -16,7 +16,10 @@ class ViTGPT2Config(PretrainedConfig):
 
     def __init__(self, vision_config_dict=None, text_config_dict=None, **kwargs):
         super().__init__(
-
+            vision_config_dict=vision_config_dict, text_config_dict=text_config_dict, **kwargs
+        )
+
+        project_encoder = kwargs.pop("project_encoder", None)
 
         if vision_config_dict is None:
             vision_config_dict = {}
@@ -41,6 +44,19 @@ class ViTGPT2Config(PretrainedConfig):
         self.decoder_start_token_id = self.text_config.bos_token_id
         self.forced_eos_token_id = self.text_config.eos_token_id
 
+        _project_encoder = getattr(self.text_config, "project_encoder", None)
+        if project_encoder is not None and _project_encoder is not None:
+            assert project_encoder == _project_encoder
+        elif project_encoder:
+            _project_encoder = project_encoder
+        elif _project_encoder:
+            project_encoder = _project_encoder
+        else:
+            project_encoder = False
+
+        self.config.project_encoder = project_encoder
+        self.text_config.project_encoder = project_encoder
+
     @classmethod
     def from_vision_text_configs(cls, vision_config: ViTConfig, text_config: GPT2Config, **kwargs):
 
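The second hunk reconciles the `project_encoder` flag from two possible sources: an explicit `project_encoder=...` kwarg and an attribute already stored on `text_config`. When both are set they must agree, when only one is set it wins, and when neither is set the flag defaults to `False`; the result is then written back onto both configs (note: inside a config class, the `self.config.project_encoder` assignment looks like it may intend `self.project_encoder`). A minimal sketch of that precedence logic; `resolve_project_encoder` is a hypothetical helper name, not part of the commit:

# Sketch of the flag-reconciliation logic in the hunk above; not repo code.
def resolve_project_encoder(kwarg_value, text_config_value):
    if kwarg_value is not None and text_config_value is not None:
        # Both sources set the flag: they must agree.
        assert kwarg_value == text_config_value
        return kwarg_value
    if kwarg_value:
        return kwarg_value        # only the kwarg was given
    if text_config_value:
        return text_config_value  # only text_config carries the flag
    return False                  # neither source set it: default to off

assert resolve_project_encoder(True, None) is True
assert resolve_project_encoder(None, True) is True
assert resolve_project_encoder(None, None) is False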
vit_gpt2/modeling_flax_gpt2.py CHANGED

@@ -308,9 +308,13 @@ class FlaxGPT2Block(nn.Module):
         self.attn = FlaxGPT2Attention(self.config, dtype=self.dtype)
 
         if not self.only_self_attn:
-            self.
+            self.cross_attn_ln = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
             # [IMPORTANT] Cross attention requires ``causal=False``! This is a bug I made previously.
-            self.
+            self.cross_attn = FlaxGPT2Attention(config=self.config, dtype=self.dtype, causal=False, self_attn=False)
+
+            if self.config.project_encoder:
+                self.encoder_projection_ln = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
+                self.encoder_projection_mlp = FlaxGPT2MLP(self.config, self.config.hidden_size, dtype=self.dtype)
 
         self.ln_2 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
         self.mlp = FlaxGPT2MLP(self.config, inner_dim, dtype=self.dtype)
@@ -348,10 +352,19 @@ class FlaxGPT2Block(nn.Module):
         cross_attn_weights = None
         if encoder_hidden_states is not None:
 
+            if self.project_encoder:
+                residual = encoder_hidden_states
+                encoder_hidden_states = self.encoder_projection_ln(encoder_hidden_states)
+                feed_forward_hidden_states = self.encoder_projection_mlp(
+                    encoder_hidden_states, deterministic=deterministic
+                )
+                # residual connection
+                encoder_hidden_states = residual + feed_forward_hidden_states
+
             residual = hidden_states
-            hidden_states = self.
+            hidden_states = self.cross_attn_ln(hidden_states)
 
-            cross_attn_outputs = self.
+            cross_attn_outputs = self.cross_attn(
                 hidden_states=hidden_states,
                 key_value_states=encoder_hidden_states,
                 attention_mask=encoder_attention_mask,
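When `project_encoder` is enabled, the decoder block now runs the encoder hidden states through a pre-LayerNorm feed-forward projection with a residual connection before they feed cross-attention, mirroring the pre-LN pattern GPT-2 uses elsewhere (the setup gate reads `self.config.project_encoder` while the call site reads `self.project_encoder`). Below is a minimal, self-contained Flax sketch of just that sub-block; `EncoderProjection` is a hypothetical name, a plain two-layer GELU MLP stands in for `FlaxGPT2MLP`, and the sizes and epsilon are illustrative assumptions:

import jax
import jax.numpy as jnp
import flax.linen as nn

class EncoderProjection(nn.Module):
    hidden_size: int = 768           # illustrative; matches ViT-Base / GPT-2 small
    layer_norm_epsilon: float = 1e-5

    @nn.compact
    def __call__(self, encoder_hidden_states):
        residual = encoder_hidden_states
        # Pre-LN: normalize, project through an MLP, then add the residual,
        # as in the `if self.project_encoder:` branch of the diff.
        x = nn.LayerNorm(epsilon=self.layer_norm_epsilon)(encoder_hidden_states)
        x = nn.Dense(self.hidden_size)(x)  # the diff sizes the MLP inner dim to hidden_size
        x = nn.gelu(x)
        x = nn.Dense(self.hidden_size)(x)
        return residual + x

x = jnp.ones((1, 197, 768))  # e.g. a ViT patch sequence
params = EncoderProjection().init(jax.random.PRNGKey(0), x)
y = EncoderProjection().apply(params, x)
assert y.shape == x.shape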
vit_gpt2/modeling_flax_vit_gpt2_lm.py CHANGED

@@ -541,6 +541,7 @@ class FlaxViTGPT2LMForConditionalGeneration(FlaxViTGPT2LMPreTrainedModel):
 
         if "config" not in text_kwargs:
             text_config = GPT2Config.from_pretrained(text_pretrained_model_name_or_path)
+            text_config.project_encoder = text_kwargs.pop("project_encoder", None)
             text_kwargs["config"] = text_config
 
         text_kwargs["config"].add_cross_attention = True