yuchenxie committed
Commit e0eae5b
1 Parent(s): 2a6dccf

Update modeling_arlow_gpt.py

Files changed (1)
  1. modeling_arlow_gpt.py +26 -54
modeling_arlow_gpt.py CHANGED
@@ -1,15 +1,14 @@
 # modeling_arlow_gpt.py
-import torch
-import torch.nn as nn
-from transformers import PreTrainedModel, CLIPModel, GPT2Model
+from transformers import PreTrainedModel, PreTrainedModel, CLIPModel, GPT2Model
+from transformers.modeling_outputs import Seq2SeqLMOutput
 from typing import Optional, Union, Dict, Tuple
 from .configuration_arlow_gpt import ArlowGPTConfig
 
 class ArlowGPTPreTrainedModel(PreTrainedModel):
-    """Base class for ArlowGPT model."""
     config_class = ArlowGPTConfig
     base_model_prefix = "arlow_gpt"
     supports_gradient_checkpointing = True
+    _keys_to_ignore_on_load_missing = [r"clip", r"gpt2"]
 
     def _init_weights(self, module):
         if isinstance(module, nn.Linear):
@@ -18,59 +17,32 @@ class ArlowGPTPreTrainedModel(PreTrainedModel):
                 module.bias.data.zero_()
 
 class ArlowGPTModel(ArlowGPTPreTrainedModel):
+    # Same as before
+
+class ArlowGPTForImageTextToText(ArlowGPTPreTrainedModel):
     def __init__(self, config: ArlowGPTConfig):
         super().__init__(config)
-
-        self.clip = CLIPModel.from_pretrained(config.clip_model_name)
-        self.gpt2 = GPT2Model.from_pretrained(config.gpt2_model_name)
-
-        self.feature_projection = nn.Linear(
-            self.clip.vision_model.config.hidden_size + self.gpt2.config.hidden_size,
-            config.projection_dim
-        )
+        self.arlow_gpt = ArlowGPTModel(config)
+        self.lm_head = nn.Linear(config.projection_dim, config.vocab_size)
 
         # Initialize weights and apply final processing
         self.post_init()
 
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        attention_mask: torch.Tensor,
-        pixel_values: torch.Tensor,
-        return_dict: bool = True,
-    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
-        vision_outputs = self.clip.get_image_features(pixel_values=pixel_values)
-        text_outputs = self.gpt2(
-            input_ids=input_ids,
-            attention_mask=attention_mask
-        ).last_hidden_state
-
-        batch_size = text_outputs.shape[0]
-        seq_length = text_outputs.shape[1]
-
-        vision_features = vision_outputs.unsqueeze(1).expand(
-            batch_size, seq_length, -1
-        )
-
-        combined_features = torch.cat(
-            [vision_features, text_outputs],
-            dim=-1
-        )
-
-        hidden_states = self.feature_projection(combined_features)
+    def save_pretrained(self, save_directory, **kwargs):
+        """Override save_pretrained to save all components"""
+        super().save_pretrained(save_directory, **kwargs)
+        self.arlow_gpt.save_pretrained(save_directory)
 
-        if return_dict:
-            return {"hidden_states": hidden_states}
-        return hidden_states
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        """Override from_pretrained to handle custom loading logic"""
+        config = kwargs.get("config", None)
+        if config is None:
+            config = ArlowGPTConfig.from_pretrained(pretrained_model_name_or_path)
+            kwargs["config"] = config
 
-class ArlowGPTForCausalLM(ArlowGPTPreTrainedModel):
-    def __init__(self, config: ArlowGPTConfig):
-        super().__init__(config)
-        self.arlow_gpt = ArlowGPTModel(config)
-        self.output_projection = nn.Linear(config.projection_dim, config.vocab_size)
-
-        # Initialize weights and apply final processing
-        self.post_init()
+        config._name_or_path = pretrained_model_name_or_path
+        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
 
     def forward(
         self,
@@ -88,7 +60,7 @@ class ArlowGPTForCausalLM(ArlowGPTPreTrainedModel):
         )
 
         hidden_states = outputs["hidden_states"]
-        logits = self.output_projection(hidden_states)
+        logits = self.lm_head(hidden_states)
 
         loss = None
         if labels is not None:
@@ -96,10 +68,10 @@ class ArlowGPTForCausalLM(ArlowGPTPreTrainedModel):
             loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
 
         if return_dict:
-            return {
-                "loss": loss,
-                "logits": logits
-            }
+            return Seq2SeqLMOutput(
+                loss=loss,
+                logits=logits,
+            )
         return (loss, logits) if loss is not None else logits
 
     def prepare_inputs_for_generation(
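A minimal usage sketch of the renamed ArlowGPTForImageTextToText class after this commit. It is an illustration, not the repository's documented API: the checkpoint id and package layout are placeholders, and the forward keyword arguments are assumed to mirror the removed ArlowGPTModel.forward signature (input_ids, attention_mask, pixel_values) plus the labels and return_dict handling visible in this diff.

import torch

from arlow_gpt.configuration_arlow_gpt import ArlowGPTConfig  # hypothetical package layout
from arlow_gpt.modeling_arlow_gpt import ArlowGPTForImageTextToText

checkpoint = "yuchenxie/ArlowGPT"  # placeholder checkpoint id

# The overridden from_pretrained resolves an ArlowGPTConfig from the
# checkpoint when no config kwarg is supplied.
model = ArlowGPTForImageTextToText.from_pretrained(checkpoint)
model.eval()

input_ids = torch.tensor([[50256, 318, 257]])   # toy GPT-2 token ids
attention_mask = torch.ones_like(input_ids)
pixel_values = torch.rand(1, 3, 224, 224)       # stand-in for CLIP-preprocessed pixels

# Assumed keyword arguments; this hunk only shows the labels/return_dict logic.
outputs = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    pixel_values=pixel_values,
    labels=input_ids,
    return_dict=True,
)
print(outputs.loss, outputs.logits.shape)  # Seq2SeqLMOutput fields after this commit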
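The new save_pretrained and from_pretrained overrides can be exercised with a round-trip sketch like the one below. It assumes ArlowGPTConfig accepts the clip_model_name, gpt2_model_name, projection_dim, and vocab_size fields that the modeling code reads; the base-model names, dimensions, and output directory are placeholders.

from arlow_gpt.configuration_arlow_gpt import ArlowGPTConfig  # hypothetical package layout
from arlow_gpt.modeling_arlow_gpt import ArlowGPTForImageTextToText

# Field names taken from the modeling code in this diff; values are examples.
config = ArlowGPTConfig(
    clip_model_name="openai/clip-vit-base-patch32",
    gpt2_model_name="gpt2",
    projection_dim=768,
    vocab_size=50257,
)

model = ArlowGPTForImageTextToText(config)

# The override saves the wrapper and then the inner ArlowGPTModel into the
# same directory.
model.save_pretrained("./arlow-gpt-checkpoint")

# The override loads the ArlowGPTConfig from that directory before delegating
# to PreTrainedModel.from_pretrained.
reloaded = ArlowGPTForImageTextToText.from_pretrained("./arlow-gpt-checkpoint")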