ainz committed
Commit 43ea3b4 · verified · 1 Parent(s): 85e4754

Fix model architecture and generation compatibility

Files changed (1)
  1. modeling_tiny_recursive.py +55 -73
modeling_tiny_recursive.py CHANGED
@@ -1,11 +1,11 @@
 
-import torch
-import torch.nn as nn
-from transformers import PreTrainedModel, PretrainedConfig, GPT2TokenizerFast, Trainer, TrainingArguments, DataCollatorForLanguageModeling
+from transformers import PreTrainedModel, PretrainedConfig
 from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2MLP
-from transformers.generation import GenerationMixin # <--- FIXED: Import this explicitly
+from transformers.generation import GenerationMixin
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
-from datasets import load_dataset
+import torch
+import torch.nn as nn
+
 class TRMConfig(PretrainedConfig):
     model_type = "recursive_gpt"
 
@@ -14,9 +14,9 @@ class TRMConfig(PretrainedConfig):
         vocab_size=50257,
         n_positions=1024,
         n_embd=512,
+        n_physical_layers=3,
+        n_loops=8,
         n_head=8,
-        n_physical_layers=2,
-        n_loops=6,
         activation_function="gelu_new",
         resid_pdrop=0.1,
         embd_pdrop=0.1,
@@ -28,13 +28,12 @@
         **kwargs,
     ):
         super().__init__(**kwargs)
-        # Standard config
         self.vocab_size = vocab_size
         self.n_positions = n_positions
         self.n_embd = n_embd
-        self.n_head = n_head
         self.n_physical_layers = n_physical_layers
         self.n_loops = n_loops
+        self.n_head = n_head
         self.activation_function = activation_function
         self.resid_pdrop = resid_pdrop
         self.embd_pdrop = embd_pdrop
@@ -44,17 +43,16 @@
         self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
         self.reorder_and_upcast_attn = reorder_and_upcast_attn
 
-        # --- CRITICAL FIXES FOR COMPATIBILITY ---
-        # These map your custom names to what GPT2Attention expects
-        self.max_position_embeddings = n_positions
+        # Required for transformers compatibility
         self.hidden_size = n_embd
-        self.num_attention_heads = n_head # <--- FIXED: The missing attribute
+        self.num_attention_heads = n_head
         self.num_hidden_layers = n_physical_layers
-        self.n_inner = None # Defaults to 4*hidden_size
+        self.n_inner = None
+        self.is_encoder_decoder = False
 
 class TinyRecursiveModel(PreTrainedModel, GenerationMixin):
     config_class = TRMConfig
-    _tied_weights_keys = ["lm_head.weight"] # <-- Add this line
+    _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -65,90 +63,74 @@ class TinyRecursiveModel(PreTrainedModel, GenerationMixin):
         self.wpe = nn.Embedding(config.n_positions, config.n_embd)
         self.drop = nn.Dropout(config.embd_pdrop)
 
-        # 2. The Logic Core (The "7M" part)
+        # 2. Physical blocks - matching your saved model structure
         self.physical_blocks = nn.ModuleList([
-            RecursiveBlock(config, layer_idx=i) for i in range(config.n_physical_layers)
+            nn.ModuleDict({
+                "ln_1": nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon),
+                "attn": GPT2Attention(config, layer_idx=i),
+                "ln_2": nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon),
+                "mlp": GPT2MLP(4 * config.n_embd, config)
+            }) for i in range(config.n_physical_layers)
         ])
 
+        # 3. Final layer norm
         self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
+
+        # 4. Language modeling head
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
 
-        # Weight tying
-        self.lm_head.weight = self.wte.weight
+        # Initialize weights
         self.post_init()
 
-    def forward( self, input_ids=None, attention_mask=None, labels=None, return_dict=None, **kwargs):
-        # Default to True if not specified, required for generation
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
+        if input_ids is None:
+            return None
 
+        batch_size, seq_len = input_ids.shape
         device = input_ids.device
-        b, t = input_ids.size()
-
-        # Positions & Embeddings
-        pos = torch.arange(0, t, dtype=torch.long, device=device)
-        tok_emb = self.wte(input_ids)
-        pos_emb = self.wpe(pos)
-        hidden_states = self.drop(tok_emb + pos_emb)
-
-        # Attention Mask Handling
-        if attention_mask is None:
-            attention_mask = torch.ones((b, t), device=device)
-
-        # Broadcast mask to (batch, head, seq, seq)
-        # We preserve the original mask for the loss calculation later if needed,
-        # but for the blocks we need the 4D version.
-        extended_attention_mask = attention_mask.view(b, 1, 1, t)
-        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
-
-        # =========================================================
-        # THE RECURSIVE LOOP
-        # =========================================================
-        for loop_i in range(self.config.n_loops):
-            for block in self.physical_blocks:
-                hidden_states = block(hidden_states, attention_mask=extended_attention_mask)
 
+        # Get embeddings
+        token_embeds = self.wte(input_ids)
+        pos_ids = torch.arange(0, seq_len, dtype=torch.long, device=device)
+        pos_embeds = self.wpe(pos_ids)
+        hidden_states = self.drop(token_embeds + pos_embeds)
+
+        # Apply recursive loops through physical blocks
+        for loop in range(self.config.n_loops):
+            block_idx = loop % self.config.n_physical_layers
+            block = self.physical_blocks[block_idx]
+
+            # Attention
+            ln_output = block["ln_1"](hidden_states)
+            attn_output = block["attn"](ln_output, attention_mask=attention_mask)[0]
+            hidden_states = hidden_states + attn_output
+
+            # MLP
+            ln_output = block["ln_2"](hidden_states)
+            mlp_output = block["mlp"](ln_output)
+            hidden_states = hidden_states + mlp_output
+
+        # Final layer norm and projection
         hidden_states = self.ln_f(hidden_states)
         logits = self.lm_head(hidden_states)
 
         loss = None
         if labels is not None:
-            loss_fct = nn.CrossEntropyLoss()
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
+            loss_fct = nn.CrossEntropyLoss()
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
 
-        # <--- CRITICAL FIX: Return CausalLMOutputWithCrossAttentions
-        if not return_dict:
-            output = (logits,)
-            return ((loss,) + output) if loss is not None else output
-
         return CausalLMOutputWithCrossAttentions(
             loss=loss,
             logits=logits,
-            past_key_values=None, # We are not using KV-cache for simplicity in this recursive setup
-            hidden_states=None,
+            hidden_states=hidden_states,
             attentions=None,
+            cross_attentions=None
         )
 
     def prepare_inputs_for_generation(self, input_ids, **kwargs):
         return {"input_ids": input_ids}
 
-class RecursiveBlock(nn.Module):
-    def __init__(self, config, layer_idx):
-        super().__init__()
-        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
-        self.attn = GPT2Attention(config, layer_idx=layer_idx)
-        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
-        self.mlp = GPT2MLP(config.n_embd, config)
-
-    def forward(self, x, layer_past=None, attention_mask=None):
-        residual = x
-        x = self.ln_1(x)
-        # We disable caching (use_cache=False) to simplify the recursion loop
-        attn_outputs = self.attn(x, layer_past=layer_past, attention_mask=attention_mask, use_cache=False)
-        x = residual + attn_outputs[0]
-
-        residual = x
-        x = self.ln_2(x)
-        x = residual + self.mlp(x)
-        return x
+    def _reorder_cache(self, past, beam_idx):
+        return past
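As a quick end-to-end check of the rewritten architecture and the GenerationMixin wiring, the sketch below imports the classes straight from modeling_tiny_recursive.py, builds a randomly initialized model from the new defaults, and runs greedy decoding. This is a minimal smoke test under stated assumptions, not part of the commit: the GPT-2 tokenizer, the prompt, and the explicit max_position_embeddings value are illustrative choices. max_position_embeddings is passed because GPT2Attention reads it when constructing its causal-mask buffer; the repository's config.json is assumed to carry that key already.

import torch
from transformers import GPT2TokenizerFast
from modeling_tiny_recursive import TRMConfig, TinyRecursiveModel

# Fresh model with the new defaults (3 physical blocks, 8 recursive loops).
# max_position_embeddings is forwarded to PretrainedConfig via **kwargs; GPT2Attention
# needs it for its causal-mask buffer. For the uploaded weights one would instead call
# TinyRecursiveModel.from_pretrained(<repo id or local path>).
config = TRMConfig(max_position_embeddings=1024)
model = TinyRecursiveModel(config).eval()

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")  # assumed tokenizer; matches vocab_size=50257
inputs = tokenizer("Once upon a time", return_tensors="pt")

with torch.no_grad():
    out = model.generate(
        inputs["input_ids"],
        max_new_tokens=20,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
print(tokenizer.decode(out[0], skip_special_tokens=True))

With random weights the output is gibberish; the point is only that construction, forward, and generate run together. Because prepare_inputs_for_generation forwards just input_ids and no KV cache is kept, each decoding step re-runs the full n_loops recursion over the whole sequence.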