WCNegentropy committed on
Commit
3c27aeb
·
verified ·
1 Parent(s): 35c1128

🚀 Refined BitTransformerLM: Organized codebase with best practices

Files changed (1)
  1. scripts/training/basic_training.py +180 -0
scripts/training/basic_training.py ADDED
@@ -0,0 +1,180 @@
+ #!/usr/bin/env python3
+ """
+ Basic BitTransformerLM Training Script
+ =====================================
+
+ A simple working training script that follows the ACTUAL BitTransformerLM
+ model implementation exactly as it exists in the codebase.
+ """
+
+ import sys
+ import os
+ import logging
+
+ import torch
+ import torch.nn.functional as F
+
+ # Add paths for imports
+ sys.path.append('/data')
+ sys.path.append('/data/BitTransformerLM')
+
+ from bit_transformer import BitTransformerLM, text_to_bits
+ from BTLM_Extensions import configure_adafactor_optimizer
+
+ # Setup logging
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)
+
+ def create_simple_dataset():
+     """Create a simple bit dataset for testing."""
+     logger.info("Creating simple bit dataset...")
+
+     # Use some simple text examples
+     texts = [
+         "Hello world! This is a test.",
+         "BitTransformerLM processes bits natively.",
+         "Training on binary sequences is interesting.",
+         "Each character becomes 9 bits with parity.",
+         "The model learns bit patterns directly.",
+     ]
+
+     # Convert to bits
+     bit_sequences = []
+     for text in texts:
+         bits = text_to_bits(text)
+         bit_sequences.append(bits)
+
+     # Pad to same length and create training data
+     max_len = min(64, max(len(bits) for bits in bit_sequences))  # Keep it small for testing
+
+     training_data = []
+     for bits in bit_sequences:
+         if len(bits) >= max_len:
+             # Take chunks of max_len
+             for i in range(0, len(bits) - max_len + 1, max_len // 2):
+                 chunk = bits[i:i + max_len]
+                 if len(chunk) == max_len:
+                     training_data.append(chunk)
+
+     # Convert to tensor
+     data_tensor = torch.tensor(training_data, dtype=torch.long)
+     logger.info(f"Created dataset: {data_tensor.shape}")
+
+     return data_tensor
+
+ def create_model():
+     """Create a small BitTransformerLM model for testing."""
+     logger.info("Creating BitTransformerLM model...")
+
+     # Small model configuration for basic testing
+     model = BitTransformerLM(
+         d_model=128,
+         nhead=8,
+         num_layers=2,
+         dim_feedforward=256,
+         max_seq_len=64,
+         lambda_K=0.1,
+         lambda_C=0.1,
+         lambda_S=0.1,
+         use_checkpoint=False,  # Disable for simplicity
+         use_autocast=False,    # Disable for simplicity
+         use_act=False          # Disable for simplicity
+     )
+
+     total_params = sum(p.numel() for p in model.parameters())
+     logger.info(f"Model created: {total_params:,} parameters")
+
+     return model
+
+ def train_basic():
+     """Basic training loop following the example_training_step pattern."""
+     logger.info("Starting basic BitTransformerLM training...")
+
+     # Create model and data
+     model = create_model()
+     data = create_simple_dataset()
+
+     # Calculate total steps
+     batch_size = 2
+     epochs = 5
+     total_steps = (len(data) // batch_size) * epochs
+
+     # Configure optimizer using Fixed LR Adafactor (breakthrough config)
+     logger.info("Configuring Fixed LR Adafactor optimizer...")
+     optimizer, scheduler = configure_adafactor_optimizer(
+         model,
+         lr=1e-3,  # FIXED learning rate - key to breakthrough!
+         weight_decay=0.01,
+         total_steps=total_steps
+     )
+
+     logger.info("Starting training loop...")
+
+     # Training configuration
+
+     model.train()
+
+     for epoch in range(epochs):
+         epoch_losses = []
+
+         # Simple batching
+         for i in range(0, len(data), batch_size):
+             batch = data[i:i + batch_size]
+             if len(batch) < batch_size:
+                 continue  # Skip incomplete batches
+
+             # Zero gradients
+             optimizer.zero_grad()
+
+             # Forward pass - EXACTLY like example_training_step
+             logits, telemetry = model(batch)
+
+             # Loss calculation - EXACTLY like example_training_step
+             pred = logits[:, :-1, :].reshape(-1, 2)
+             target = batch[:, 1:].reshape(-1)
+             loss = F.cross_entropy(pred, target)
+
+             # Backward pass
+             loss.backward()
+
+             # Gradient clipping
+             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+
+             # Optimizer step
+             optimizer.step()
+             if scheduler:
+                 scheduler.step()
+
+             epoch_losses.append(loss.item())
+
+         # Log epoch results
+         avg_loss = sum(epoch_losses) / len(epoch_losses) if epoch_losses else float('inf')
+         logger.info(f"Epoch {epoch + 1}/{epochs}: Average Loss = {avg_loss:.6f}")
+
+         # Log telemetry if available
+         if telemetry:
+             for key, value in telemetry.items():
+                 if torch.is_tensor(value):
+                     logger.info(f"  {key}: {value.mean().item():.4f}")
+
+     logger.info("Basic training completed successfully!")
+     return model
+
+ def main():
+     """Main function."""
+     logger.info("🚀 Starting basic BitTransformerLM training test")
+
+     try:
+         trained_model = train_basic()
+         logger.info("✅ Basic training test PASSED!")
+
+         # Save the model
+         torch.save(trained_model.state_dict(), '/data/BitTransformerLM/basic_model.pt')
+         logger.info("Model saved to basic_model.pt")
+
+     except Exception as e:
+         logger.error(f"❌ Training failed: {e}")
+         raise
+
+ if __name__ == "__main__":
+     main()
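
For a quick sanity check after the script finishes, the checkpoint it writes can be reloaded and run on a single bit sequence. The snippet below is a minimal sketch, not part of this commit: it assumes the same BitTransformerLM constructor arguments as create_model() above, that text_to_bits() returns a flat list of 0/1 integers (9 bits per character with parity), and that the model accepts any sequence no longer than max_seq_len.

    # Hypothetical smoke test for the checkpoint written by basic_training.py.
    import sys
    import torch

    sys.path.append('/data')
    sys.path.append('/data/BitTransformerLM')
    from bit_transformer import BitTransformerLM, text_to_bits

    # Rebuild the model with the configuration used during training (assumed to match create_model()).
    model = BitTransformerLM(
        d_model=128, nhead=8, num_layers=2, dim_feedforward=256,
        max_seq_len=64, lambda_K=0.1, lambda_C=0.1, lambda_S=0.1,
        use_checkpoint=False, use_autocast=False, use_act=False,
    )
    model.load_state_dict(torch.load('/data/BitTransformerLM/basic_model.pt'))
    model.eval()

    # Encode a short string and truncate to the training sequence length.
    bits = text_to_bits("Hello")[:64]
    batch = torch.tensor([bits], dtype=torch.long)  # shape: (1, seq_len)

    with torch.no_grad():
        logits, telemetry = model(batch)  # logits: (1, seq_len, 2), per-position bit logits
    next_bit = logits[0, -1].argmax().item()  # predicted next bit (0 or 1)
    print(f"Predicted next bit: {next_bit}; telemetry keys: {list(telemetry.keys())}")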