ainz committed
Commit 85e4754 · verified · 1 Parent(s): 4b236c7

Update modeling file with complete recursive implementation

Files changed (1): modeling_tiny_recursive.py  +72 -45
modeling_tiny_recursive.py CHANGED
@@ -1,11 +1,11 @@
 
-from transformers import PreTrainedModel, PretrainedConfig
-from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2MLP
-from transformers.generation import GenerationMixin
-from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
 import torch
 import torch.nn as nn
-
+from transformers import PreTrainedModel, PretrainedConfig, GPT2TokenizerFast, Trainer, TrainingArguments, DataCollatorForLanguageModeling
+from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2MLP
+from transformers.generation import GenerationMixin  # <--- FIXED: Import this explicitly
+from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
+from datasets import load_dataset
 class TRMConfig(PretrainedConfig):
     model_type = "recursive_gpt"
 
@@ -14,9 +14,9 @@ class TRMConfig(PretrainedConfig):
         vocab_size=50257,
         n_positions=1024,
         n_embd=512,
-        n_physical_layers=3,
-        n_loops=8,
         n_head=8,
+        n_physical_layers=2,
+        n_loops=6,
         activation_function="gelu_new",
         resid_pdrop=0.1,
         embd_pdrop=0.1,
@@ -28,12 +28,13 @@
         **kwargs,
     ):
         super().__init__(**kwargs)
+        # Standard config
         self.vocab_size = vocab_size
         self.n_positions = n_positions
         self.n_embd = n_embd
+        self.n_head = n_head
         self.n_physical_layers = n_physical_layers
         self.n_loops = n_loops
-        self.n_head = n_head
         self.activation_function = activation_function
         self.resid_pdrop = resid_pdrop
         self.embd_pdrop = embd_pdrop
@@ -43,15 +44,17 @@
         self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
         self.reorder_and_upcast_attn = reorder_and_upcast_attn
 
-        # Required for transformers compatibility
+        # --- CRITICAL FIXES FOR COMPATIBILITY ---
+        # These map your custom names to what GPT2Attention expects
+        self.max_position_embeddings = n_positions
         self.hidden_size = n_embd
-        self.num_attention_heads = n_head
+        self.num_attention_heads = n_head  # <--- FIXED: The missing attribute
         self.num_hidden_layers = n_physical_layers
-        self.n_inner = None
+        self.n_inner = None  # Defaults to 4*hidden_size
 
 class TinyRecursiveModel(PreTrainedModel, GenerationMixin):
     config_class = TRMConfig
-    _tied_weights_keys = ["lm_head.weight"]
+    _tied_weights_keys = ["lm_head.weight"]  # <-- Add this line
 
     def __init__(self, config):
         super().__init__(config)
@@ -62,66 +65,90 @@ class TinyRecursiveModel(PreTrainedModel, GenerationMixin):
         self.wpe = nn.Embedding(config.n_positions, config.n_embd)
         self.drop = nn.Dropout(config.embd_pdrop)
 
-        # 2. The Logic Core - Physical transformer blocks
+        # 2. The Logic Core (The "7M" part)
         self.physical_blocks = nn.ModuleList([
-            nn.ModuleDict({
-                "ln_1": nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon),
-                "attn": GPT2Attention(config, layer_idx=i),
-                "ln_2": nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon),
-                "mlp": GPT2MLP(4 * config.n_embd, config)
-            }) for i in range(config.n_physical_layers)
+            RecursiveBlock(config, layer_idx=i) for i in range(config.n_physical_layers)
         ])
 
-        # 3. Final layer norm
         self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
-
-        # 4. Language modeling head
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
 
-        # Initialize weights
+        # Weight tying
+        self.lm_head.weight = self.wte.weight
         self.post_init()
 
-    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
-        batch_size, seq_len = input_ids.shape
+    def forward(self, input_ids=None, attention_mask=None, labels=None, return_dict=None, **kwargs):
+        # Default to True if not specified, required for generation
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        device = input_ids.device
+        b, t = input_ids.size()
 
-        # Get embeddings
-        token_embeds = self.wte(input_ids)
-        pos_ids = torch.arange(0, seq_len, dtype=torch.long, device=input_ids.device)
-        pos_embeds = self.wpe(pos_ids)
-        hidden_states = self.drop(token_embeds + pos_embeds)
+        # Positions & Embeddings
+        pos = torch.arange(0, t, dtype=torch.long, device=device)
+        tok_emb = self.wte(input_ids)
+        pos_emb = self.wpe(pos)
+        hidden_states = self.drop(tok_emb + pos_emb)
 
-        # Apply recursive loops through physical blocks
-        for loop in range(self.config.n_loops):
-            block_idx = loop % self.config.n_physical_layers
-            block = self.physical_blocks[block_idx]
+        # Attention Mask Handling
+        if attention_mask is None:
+            attention_mask = torch.ones((b, t), device=device)
 
-            # Attention
-            ln_output = block["ln_1"](hidden_states)
-            attn_output = block["attn"](ln_output, attention_mask=attention_mask)[0]
-            hidden_states = hidden_states + attn_output
+        # Broadcast mask to (batch, head, seq, seq)
+        # We preserve the original mask for the loss calculation later if needed,
+        # but for the blocks we need the 4D version.
+        extended_attention_mask = attention_mask.view(b, 1, 1, t)
+        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
 
-            # MLP
-            ln_output = block["ln_2"](hidden_states)
-            mlp_output = block["mlp"](ln_output)
-            hidden_states = hidden_states + mlp_output
+        # =========================================================
+        # THE RECURSIVE LOOP
+        # =========================================================
+        for loop_i in range(self.config.n_loops):
+            for block in self.physical_blocks:
+                hidden_states = block(hidden_states, attention_mask=extended_attention_mask)
 
-        # Final layer norm and projection
         hidden_states = self.ln_f(hidden_states)
         logits = self.lm_head(hidden_states)
 
         loss = None
         if labels is not None:
+            loss_fct = nn.CrossEntropyLoss()
             shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
-            loss_fct = nn.CrossEntropyLoss()
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
 
+        # <--- CRITICAL FIX: Return CausalLMOutputWithCrossAttentions
+        if not return_dict:
+            output = (logits,)
+            return ((loss,) + output) if loss is not None else output
+
         return CausalLMOutputWithCrossAttentions(
             loss=loss,
             logits=logits,
+            past_key_values=None,  # We are not using KV-cache for simplicity in this recursive setup
             hidden_states=None,
-            attentions=None
+            attentions=None,
         )
 
     def prepare_inputs_for_generation(self, input_ids, **kwargs):
         return {"input_ids": input_ids}
+
+class RecursiveBlock(nn.Module):
+    def __init__(self, config, layer_idx):
+        super().__init__()
+        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
+        self.attn = GPT2Attention(config, layer_idx=layer_idx)
+        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
+        self.mlp = GPT2MLP(config.n_embd, config)

+    def forward(self, x, layer_past=None, attention_mask=None):
+        residual = x
+        x = self.ln_1(x)
+        # We disable caching (use_cache=False) to simplify the recursion loop
+        attn_outputs = self.attn(x, layer_past=layer_past, attention_mask=attention_mask, use_cache=False)
+        x = residual + attn_outputs[0]
+
+        residual = x
+        x = self.ln_2(x)
+        x = residual + self.mlp(x)
+        return x
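
For reference, a minimal smoke test of the updated file (a sketch, not part of the commit). It assumes the module is importable as modeling_tiny_recursive, that the stock GPT-2 tokenizer is an acceptable stand-in (the config defaults to the GPT-2 vocab size of 50257), and a transformers release whose GPT2Attention forward still accepts the layer_past, attention_mask and use_cache keywords that RecursiveBlock passes.

import torch
from transformers import GPT2TokenizerFast
from modeling_tiny_recursive import TRMConfig, TinyRecursiveModel

# Build the model with the committed defaults: 2 physical blocks unrolled for 6 loops.
config = TRMConfig()
model = TinyRecursiveModel(config)
model.eval()

# GPT-2 tokenizer is an assumption; any tokenizer with a 50257-token vocab would do.
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
inputs = tokenizer("Hello world", return_tensors="pt")

# Forward pass with labels exercises the loss path; logits come back as (batch, seq_len, vocab_size).
with torch.no_grad():
    out = model(input_ids=inputs["input_ids"], labels=inputs["input_ids"])
print(out.loss.item(), out.logits.shape)

# generate() runs through GenerationMixin and prepare_inputs_for_generation;
# the weights are randomly initialized here, so this only checks shapes and the generation plumbing.
generated = model.generate(inputs["input_ids"], max_new_tokens=10, do_sample=False)
print(tokenizer.decode(generated[0]))

With the committed defaults (n_physical_layers=2, n_loops=6), the nested loop in forward applies the two shared blocks twelve times, so the effective depth is 12 at roughly the parameter cost of two layers. Note also that RecursiveBlock passes config.n_embd as the GPT2MLP intermediate size, whereas a stock GPT2Block uses config.n_inner or 4 * n_embd, which keeps the shared core smaller still.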