Fix model architecture and generation compatibility
modeling_tiny_recursive.py  +55 -73
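The central change to forward() is the loop schedule: instead of running every physical block on every loop, the n_loops iterations now cycle through the n_physical_layers weight-shared blocks with a modulo index. A minimal sketch of that schedule (illustration only, not part of the committed file), using the new config defaults from the diff below:

# Illustrative sketch of the modulo schedule used by the updated forward().
# n_physical_layers=3 and n_loops=8 are the new TRMConfig defaults.
n_physical_layers = 3
n_loops = 8

for loop in range(n_loops):
    block_idx = loop % n_physical_layers  # same indexing as the new forward()
    print(f"loop {loop} -> physical_blocks[{block_idx}]")

# Prints: loop 0 -> physical_blocks[0], loop 1 -> physical_blocks[1],
# loop 2 -> physical_blocks[2], loop 3 -> physical_blocks[0], ...
# i.e. 8 loop iterations reuse the weights of only 3 physical blocks.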
@@ -1,11 +1,11 @@
 
-import torch
-import torch.nn as nn
-from transformers import PreTrainedModel, PretrainedConfig, GPT2TokenizerFast, Trainer, TrainingArguments, DataCollatorForLanguageModeling
+from transformers import PreTrainedModel, PretrainedConfig
 from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2MLP
-from transformers.generation import GenerationMixin
+from transformers.generation import GenerationMixin
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
-
+import torch
+import torch.nn as nn
+
 class TRMConfig(PretrainedConfig):
     model_type = "recursive_gpt"
 
@@ -14,9 +14,9 @@ class TRMConfig(PretrainedConfig):
         vocab_size=50257,
         n_positions=1024,
         n_embd=512,
+        n_physical_layers=3,
+        n_loops=8,
         n_head=8,
-        n_physical_layers=2,
-        n_loops=6,
         activation_function="gelu_new",
         resid_pdrop=0.1,
         embd_pdrop=0.1,
@@ -28,13 +28,12 @@
         **kwargs,
     ):
         super().__init__(**kwargs)
-        # Standard config
         self.vocab_size = vocab_size
         self.n_positions = n_positions
         self.n_embd = n_embd
-        self.n_head = n_head
         self.n_physical_layers = n_physical_layers
         self.n_loops = n_loops
+        self.n_head = n_head
         self.activation_function = activation_function
         self.resid_pdrop = resid_pdrop
         self.embd_pdrop = embd_pdrop
@@ -44,17 +43,16 @@
         self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
         self.reorder_and_upcast_attn = reorder_and_upcast_attn
 
-        #
-        # These map your custom names to what GPT2Attention expects
-        self.max_position_embeddings = n_positions
+        # Required for transformers compatibility
         self.hidden_size = n_embd
-        self.num_attention_heads = n_head
+        self.num_attention_heads = n_head
         self.num_hidden_layers = n_physical_layers
-        self.n_inner = None
+        self.n_inner = None
+        self.is_encoder_decoder = False
 
 class TinyRecursiveModel(PreTrainedModel, GenerationMixin):
     config_class = TRMConfig
-    _tied_weights_keys = ["lm_head.weight"]
+    _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -65,90 +63,74 @@ class TinyRecursiveModel(PreTrainedModel, GenerationMixin):
         self.wpe = nn.Embedding(config.n_positions, config.n_embd)
         self.drop = nn.Dropout(config.embd_pdrop)
 
-        # 2.
+        # 2. Physical blocks - matching your saved model structure
         self.physical_blocks = nn.ModuleList([
-
+            nn.ModuleDict({
+                "ln_1": nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon),
+                "attn": GPT2Attention(config, layer_idx=i),
+                "ln_2": nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon),
+                "mlp": GPT2MLP(4 * config.n_embd, config)
+            }) for i in range(config.n_physical_layers)
         ])
 
+        # 3. Final layer norm
         self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
+
+        # 4. Language modeling head
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
 
-        #
-        self.lm_head.weight = self.wte.weight
+        # Initialize weights
         self.post_init()
 
-    def forward(
-
-
+    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
+        if input_ids is None:
+            return None
+
+        batch_size, seq_len = input_ids.shape
         device = input_ids.device
-        b, t = input_ids.size()
-
-        # Positions & Embeddings
-        pos = torch.arange(0, t, dtype=torch.long, device=device)
-        tok_emb = self.wte(input_ids)
-        pos_emb = self.wpe(pos)
-        hidden_states = self.drop(tok_emb + pos_emb)
-
-        # Attention Mask Handling
-        if attention_mask is None:
-            attention_mask = torch.ones((b, t), device=device)
-
-        # Broadcast mask to (batch, head, seq, seq)
-        # We preserve the original mask for the loss calculation later if needed,
-        # but for the blocks we need the 4D version.
-        extended_attention_mask = attention_mask.view(b, 1, 1, t)
-        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
-
-        # =========================================================
-        # THE RECURSIVE LOOP
-        # =========================================================
-        for loop_i in range(self.config.n_loops):
-            for block in self.physical_blocks:
-                hidden_states = block(hidden_states, attention_mask=extended_attention_mask)
-
+
+        # Get embeddings
+        token_embeds = self.wte(input_ids)
+        pos_ids = torch.arange(0, seq_len, dtype=torch.long, device=device)
+        pos_embeds = self.wpe(pos_ids)
+        hidden_states = self.drop(token_embeds + pos_embeds)
+
+        # Apply recursive loops through physical blocks
+        for loop in range(self.config.n_loops):
+            block_idx = loop % self.config.n_physical_layers
+            block = self.physical_blocks[block_idx]
+
+            # Attention
+            ln_output = block["ln_1"](hidden_states)
+            attn_output = block["attn"](ln_output, attention_mask=attention_mask)[0]
+            hidden_states = hidden_states + attn_output
+
+            # MLP
+            ln_output = block["ln_2"](hidden_states)
+            mlp_output = block["mlp"](ln_output)
+            hidden_states = hidden_states + mlp_output
+
+        # Final layer norm and projection
         hidden_states = self.ln_f(hidden_states)
         logits = self.lm_head(hidden_states)
 
         loss = None
         if labels is not None:
-            loss_fct = nn.CrossEntropyLoss()
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
+            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
 
-        # <--- CRITICAL FIX: Return CausalLMOutputWithCrossAttentions
-        if not return_dict:
-            output = (logits,)
-            return ((loss,) + output) if loss is not None else output
-
         return CausalLMOutputWithCrossAttentions(
             loss=loss,
             logits=logits,
-
-            hidden_states=None,
+            hidden_states=hidden_states,
             attentions=None,
+            cross_attentions=None
         )
 
     def prepare_inputs_for_generation(self, input_ids, **kwargs):
         return {"input_ids": input_ids}
 
-
-
-        super().__init__()
-        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
-        self.attn = GPT2Attention(config, layer_idx=layer_idx)
-        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
-        self.mlp = GPT2MLP(config.n_embd, config)
-
-    def forward(self, x, layer_past=None, attention_mask=None):
-        residual = x
-        x = self.ln_1(x)
-        # We disable caching (use_cache=False) to simplify the recursion loop
-        attn_outputs = self.attn(x, layer_past=layer_past, attention_mask=attention_mask, use_cache=False)
-        x = residual + attn_outputs[0]
-
-        residual = x
-        x = self.ln_2(x)
-        x = residual + self.mlp(x)
-        return x
+    def _reorder_cache(self, past, beam_idx):
+        return past
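With GenerationMixin mixed into the class and forward() returning a CausalLMOutputWithCrossAttentions, the model can be driven by the standard generate() loop. A minimal usage sketch, assuming a GPT-2 tokenizer (the config's vocab_size of 50257 matches GPT-2's) and a hypothetical local checkpoint path; neither is specified by this commit:

# Sketch only: exercising the updated model through generate().
# "path/to/checkpoint" is a placeholder and the GPT-2 tokenizer is an assumption.
import torch
from transformers import GPT2TokenizerFast
from modeling_tiny_recursive import TinyRecursiveModel

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
model = TinyRecursiveModel.from_pretrained("path/to/checkpoint")
model.eval()

prompt = tokenizer("Hello", return_tensors="pt")
with torch.no_grad():
    # Each decoding step re-runs the full forward pass over the whole prefix,
    # because prepare_inputs_for_generation() passes only input_ids and no KV
    # cache is returned by forward().
    output_ids = model.generate(
        prompt["input_ids"],
        max_new_tokens=20,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

Because no past_key_values are produced, generation cost grows with the square of the output length; that is the trade-off of the simplified prepare_inputs_for_generation above.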