ydshieh committed
Commit 54ece9e (1 parent: 7b292b7)

Add project_encoder and related layers

vit_gpt2/configuration_vit_gpt2.py CHANGED
@@ -16,7 +16,10 @@ class ViTGPT2Config(PretrainedConfig):
 
     def __init__(self, vision_config_dict=None, text_config_dict=None, **kwargs):
         super().__init__(
-            text_config_dict=text_config_dict, vision_config_dict=vision_config_dict, **kwargs)
+            vision_config_dict=vision_config_dict, text_config_dict=text_config_dict, **kwargs
+        )
+
+        project_encoder = kwargs.pop("project_encoder", None)
 
         if vision_config_dict is None:
             vision_config_dict = {}
@@ -41,6 +44,19 @@ class ViTGPT2Config(PretrainedConfig):
         self.decoder_start_token_id = self.text_config.bos_token_id
         self.forced_eos_token_id = self.text_config.eos_token_id
 
+        _project_encoder = getattr(self.text_config, "project_encoder", None)
+        if project_encoder is not None and _project_encoder is not None:
+            assert project_encoder == _project_encoder
+        elif project_encoder:
+            _project_encoder = project_encoder
+        elif _project_encoder:
+            project_encoder = _project_encoder
+        else:
+            project_encoder = False
+
+        self.config.project_encoder = project_encoder
+        self.text_config.project_encoder = project_encoder
+
     @classmethod
     def from_vision_text_configs(cls, vision_config: ViTConfig, text_config: GPT2Config, **kwargs):
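For reference, a standalone sketch (not part of the commit) of the precedence rule the new __init__ lines implement: an explicitly passed project_encoder and a value already stored on the text config must agree when both are given; otherwise whichever one is set wins, and the flag defaults to False. The helper below is hypothetical and only mirrors that logic.

def resolve_project_encoder(passed, on_text_config):
    """Hypothetical helper mirroring the reconciliation done in ViTGPT2Config.__init__."""
    if passed is not None and on_text_config is not None:
        # Both sources set: they must agree.
        assert passed == on_text_config
        return passed
    if passed:
        return passed
    if on_text_config:
        return on_text_config
    return False  # neither source set: projection disabled by default

# Either source alone can enable the projection; unset means disabled.
assert resolve_project_encoder(None, None) is False
assert resolve_project_encoder(True, None) is True
assert resolve_project_encoder(None, True) is True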
 
vit_gpt2/modeling_flax_gpt2.py CHANGED
@@ -308,9 +308,13 @@ class FlaxGPT2Block(nn.Module):
         self.attn = FlaxGPT2Attention(self.config, dtype=self.dtype)
 
         if not self.only_self_attn:
-            self.encoder_ln = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
+            self.cross_attn_ln = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
             # [IMPORTANT] Cross attention requires ``causal=False``! This is a bug I made previously.
-            self.encoder_attn = FlaxGPT2Attention(config=self.config, dtype=self.dtype, causal=False, self_attn=False)
+            self.cross_attn = FlaxGPT2Attention(config=self.config, dtype=self.dtype, causal=False, self_attn=False)
+
+            if self.config.project_encoder:
+                self.encoder_projection_ln = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
+                self.encoder_projection_mlp = FlaxGPT2MLP(self.config, self.config.hidden_size, dtype=self.dtype)
 
         self.ln_2 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
         self.mlp = FlaxGPT2MLP(self.config, inner_dim, dtype=self.dtype)
@@ -348,10 +352,19 @@ class FlaxGPT2Block(nn.Module):
         cross_attn_weights = None
         if encoder_hidden_states is not None:
 
+            if self.project_encoder:
+                residual = encoder_hidden_states
+                encoder_hidden_states = self.encoder_projection_ln(encoder_hidden_states)
+                feed_forward_hidden_states = self.encoder_projection_mlp(
+                    encoder_hidden_states, deterministic=deterministic
+                )
+                # residual connection
+                encoder_hidden_states = residual + feed_forward_hidden_states
+
             residual = hidden_states
-            hidden_states = self.encoder_ln(hidden_states)
+            hidden_states = self.cross_attn_ln(hidden_states)
 
-            cross_attn_outputs = self.encoder_attn(
+            cross_attn_outputs = self.cross_attn(
                 hidden_states=hidden_states,
                 key_value_states=encoder_hidden_states,
                 attention_mask=encoder_attention_mask,
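For reference, a self-contained Flax sketch (not the commit's modules) of the encoder projection applied above when project_encoder is enabled: LayerNorm, then an MLP whose inner width matches hidden_size (as passed to FlaxGPT2MLP in setup), then a residual connection over encoder_hidden_states before cross-attention. Sizes, dropout handling, and names below are illustrative assumptions.

import jax
import jax.numpy as jnp
import flax.linen as nn


class EncoderProjectionSketch(nn.Module):
    """Pre-LayerNorm MLP with a residual connection over the encoder states."""
    hidden_size: int
    layer_norm_epsilon: float = 1e-5

    @nn.compact
    def __call__(self, encoder_hidden_states):
        residual = encoder_hidden_states
        x = nn.LayerNorm(epsilon=self.layer_norm_epsilon)(encoder_hidden_states)
        # The commit passes config.hidden_size as the MLP's inner dimension,
        # so the projection keeps the width constant; dropout is omitted here.
        x = nn.Dense(self.hidden_size)(x)
        x = nn.gelu(x)
        x = nn.Dense(self.hidden_size)(x)
        return residual + x


# Shape check with assumed ViT-like dimensions (batch 2, 197 patches, width 768).
module = EncoderProjectionSketch(hidden_size=768)
dummy = jnp.ones((2, 197, 768))
params = module.init(jax.random.PRNGKey(0), dummy)
assert module.apply(params, dummy).shape == (2, 197, 768)
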
vit_gpt2/modeling_flax_vit_gpt2_lm.py CHANGED
@@ -541,6 +541,7 @@ class FlaxViTGPT2LMForConditionalGeneration(FlaxViTGPT2LMPreTrainedModel):
 
         if "config" not in text_kwargs:
             text_config = GPT2Config.from_pretrained(text_pretrained_model_name_or_path)
+            text_config.project_encoder = text_kwargs.pop("project_encoder", None)
             text_kwargs["config"] = text_config
 
         text_kwargs["config"].add_cross_attention = True
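For reference, a sketch (hypothetical call site, not the commit's full classmethod) of what the added line does when only a text checkpoint name is given: the flag is popped out of the caller's text_kwargs and stored on the freshly loaded GPT2Config. The checkpoint name "gpt2" and the text_kwargs contents are illustrative.

from transformers import GPT2Config

text_kwargs = {"project_encoder": True}           # caller-supplied kwargs for the text side
text_config = GPT2Config.from_pretrained("gpt2")  # any GPT-2 checkpoint works here
text_config.project_encoder = text_kwargs.pop("project_encoder", None)
text_config.add_cross_attention = True            # as in the following line of the method

assert text_config.project_encoder is True
assert "project_encoder" not in text_kwargs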