WCNegentropy committed
Commit 8ef2120 · verified · 1 parent: e260177

Remove massive_scale_simple.py - cleanup for OS launch

Files changed (1)
  1. massive_scale_simple.py +0 -395
massive_scale_simple.py DELETED
@@ -1,395 +0,0 @@
#!/usr/bin/env python3
"""
BitTransformerLM Massive Scale Training - SIMPLIFIED & OPTIMIZED
=================================================================

Fixed version that properly initializes 680M parameter model with all optimizations!
Uses DataParallel for multi-GPU instead of FSDP to avoid initialization issues.
"""

import os
import sys
import time
import json
import logging
from datetime import datetime
from typing import Dict, Any, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import datasets
from datasets import load_dataset
import numpy as np

# BitTransformerLM imports
from bit_transformer.model import BitTransformerLM
from bit_transformer.bit_io import text_to_bits, bits_to_text
from bit_transformer.utils import set_dropout

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')
logger = logging.getLogger(__name__)


class OptimizedConfig:
    """Optimized 680M parameter configuration with ALL BitTransformerLM features enabled."""

    # Model Architecture (680M parameters - CONFIRMED)
    D_MODEL = 1536
    NUM_LAYERS = 24
    NUM_HEADS = 24
    DIM_FEEDFORWARD = 6144
    MAX_SEQ_LEN = 2048

    # Training Configuration
    BATCH_SIZE_PER_GPU = 1  # Ultra conservative for 680M model
    NUM_GPUS = 4
    TOTAL_BATCH_SIZE = BATCH_SIZE_PER_GPU * NUM_GPUS  # 4
    GRADIENT_ACCUMULATION_STEPS = 32  # Effective batch size = 128

    LEARNING_RATE = 3e-4  # Optimal for 680M model
    WEIGHT_DECAY = 0.01
    MAX_STEPS = 10000
    WARMUP_STEPS = 500

    # BitTransformerLM Optimizations - ALL ENABLED!
    USE_REVERSIBLE = True  # 50% memory savings
    USE_GRADIENT_CHECKPOINTING = True  # Additional memory savings
    USE_MIXED_PRECISION = True  # FP16 training
    USE_AUTOCAST = True  # CPU mixed precision when needed
    CHUNK_SIZE = None  # Full attention (no chunking)
    FULL_ATTN_LOGGING = False  # Memory optimization

    # Safety & Telemetry
    LAMBDA_K = 1.0
    LAMBDA_C = 1.0
    LAMBDA_S = 1.0
    NEGENTROPY_THRESHOLD = 0.2
    LZ_COMPLEXITY_THRESHOLD = 0.3
    SYMBIOSIS_THRESHOLD = 0.5

    @classmethod
    def get_model_config(cls) -> Dict[str, Any]:
        """Get optimized model configuration."""
        return {
            "d_model": cls.D_MODEL,
            "nhead": cls.NUM_HEADS,
            "num_layers": cls.NUM_LAYERS,
            "dim_feedforward": cls.DIM_FEEDFORWARD,
            "max_seq_len": cls.MAX_SEQ_LEN,
            "lambda_K": cls.LAMBDA_K,
            "lambda_C": cls.LAMBDA_C,
            "lambda_S": cls.LAMBDA_S,
            "reversible": cls.USE_REVERSIBLE,
            "use_checkpoint": cls.USE_GRADIENT_CHECKPOINTING,
            "use_autocast": cls.USE_AUTOCAST,
            "chunk_size": cls.CHUNK_SIZE,
            "full_attn_logging": cls.FULL_ATTN_LOGGING,
        }


class SimpleWikiTextDataset(torch.utils.data.Dataset):
    """Simplified WikiText dataset for bit-level training."""

    def __init__(self, split: str = "train", max_samples: int = 1000, max_length: int = 2048):
        self.max_length = max_length

        logger.info(f"Loading WikiText-103 {split} split (max {max_samples} samples)...")
        dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split=split)

        # Filter and limit samples
        texts = [item['text'] for item in dataset if len(item['text'].strip()) > 100][:max_samples]
        self.texts = texts

        logger.info(f"Loaded {len(self.texts)} text samples from {split}")

    def __len__(self) -> int:
        return len(self.texts)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        text = self.texts[idx]

        try:
            # Convert text to bits
            bits = text_to_bits(text)

            # Truncate or pad to max_length
            if len(bits) > self.max_length:
                bits = bits[:self.max_length]
            elif len(bits) < self.max_length:
                bits = bits + [0] * (self.max_length - len(bits))

            # Convert to tensor
            input_bits = torch.tensor(bits[:-1], dtype=torch.long)
            target_bits = torch.tensor(bits[1:], dtype=torch.long)

            return {
                'input_ids': input_bits,
                'labels': target_bits,
                'attention_mask': torch.ones_like(input_bits)
            }

        except Exception as e:
            logger.warning(f"Error processing text at index {idx}: {e}")
            # Fallback
            fallback_bits = [0, 1] * (self.max_length // 2)
            input_bits = torch.tensor(fallback_bits[:-1], dtype=torch.long)
            target_bits = torch.tensor(fallback_bits[1:], dtype=torch.long)

            return {
                'input_ids': input_bits,
                'labels': target_bits,
                'attention_mask': torch.ones_like(input_bits)
            }


def create_optimized_model(config: OptimizedConfig) -> nn.Module:
    """Create properly optimized BitTransformerLM model."""

    # Create model on CPU first
    logger.info("🏗️ Creating optimized BitTransformerLM model...")
    model_config = config.get_model_config()

    logger.info("Model configuration:")
    for k, v in model_config.items():
        logger.info(f"  {k}: {v}")

    model = BitTransformerLM(**model_config)

    # Count parameters
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"✅ Model created: {params:,} parameters ({params/1e6:.1f}M)")

    # Move to GPU and setup DataParallel
    if torch.cuda.is_available() and torch.cuda.device_count() >= config.NUM_GPUS:
        logger.info(f"🚀 Setting up multi-GPU training on {config.NUM_GPUS} GPUs...")

        # Move model to GPU 0
        model = model.cuda()

        # Wrap with DataParallel for multi-GPU
        if config.NUM_GPUS > 1:
            model = nn.DataParallel(model, device_ids=list(range(config.NUM_GPUS)))
            logger.info(f"✅ DataParallel setup complete across GPUs: {list(range(config.NUM_GPUS))}")

    else:
        logger.warning("⚠️ Limited GPU availability - using single GPU or CPU")
        if torch.cuda.is_available():
            model = model.cuda()

    return model


def train_step(model: nn.Module, batch: Dict[str, torch.Tensor],
               optimizer: torch.optim.Optimizer, scaler: torch.cuda.amp.GradScaler,
               config: OptimizedConfig) -> tuple:
    """Optimized training step with all BitTransformerLM features."""

    model.train()
    set_dropout(model, 0.1)  # Enable dropout for training

    # Move batch to GPU
    input_ids = batch['input_ids'].cuda(non_blocking=True)
    labels = batch['labels'].cuda(non_blocking=True)

    # Forward pass with mixed precision
    with torch.cuda.amp.autocast(enabled=config.USE_MIXED_PRECISION):
        outputs = model(input_ids)

        if isinstance(outputs, tuple):
            logits, telemetry = outputs
        else:
            logits, telemetry = outputs, {}

        # Compute loss
        loss = F.cross_entropy(logits.view(-1, 2), labels.view(-1), reduction='mean')

        # Add safety penalties if enabled
        safety_penalty = 0.0
        if telemetry:
            negentropy = telemetry.get('negentropy', 1.0)
            lz_complexity = telemetry.get('lz_complexity', 1.0)
            symbiosis = telemetry.get('symbiosis', 1.0)

            if (negentropy < config.NEGENTROPY_THRESHOLD or
                    lz_complexity < config.LZ_COMPLEXITY_THRESHOLD or
                    symbiosis < config.SYMBIOSIS_THRESHOLD):
                safety_penalty = 0.1
                loss = loss + safety_penalty

        # Scale for gradient accumulation
        loss = loss / config.GRADIENT_ACCUMULATION_STEPS

    # Backward pass
    scaler.scale(loss).backward()

    return loss.item() * config.GRADIENT_ACCUMULATION_STEPS, telemetry, safety_penalty


def main():
    """Main training function."""

    logger.info("🚀 OPTIMIZED MASSIVE SCALE BITTRANSFORMERLM TRAINING!")
    logger.info("=" * 60)

    config = OptimizedConfig()

    # Check CUDA
    if not torch.cuda.is_available():
        logger.error("❌ CUDA not available!")
        return

    logger.info(f"🔥 Hardware: {torch.cuda.device_count()}x GPUs detected")
    for i in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(i)
        logger.info(f"  GPU {i}: {props.name} ({props.total_memory / 1024**3:.1f}GB)")

    # Create model
    model = create_optimized_model(config)

    # Create datasets
    logger.info("📚 Loading datasets...")
    train_dataset = SimpleWikiTextDataset("train", max_samples=2000, max_length=config.MAX_SEQ_LEN)
    val_dataset = SimpleWikiTextDataset("validation", max_samples=100, max_length=config.MAX_SEQ_LEN)

    # Create dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.BATCH_SIZE_PER_GPU,
        shuffle=True,
        num_workers=2,
        pin_memory=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config.BATCH_SIZE_PER_GPU,
        shuffle=False,
        num_workers=1,
        pin_memory=True
    )

    # Setup optimizer and scheduler
    logger.info("⚙️ Setting up optimizer...")
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.LEARNING_RATE,
        weight_decay=config.WEIGHT_DECAY,
        betas=(0.9, 0.95)
    )

    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=config.LEARNING_RATE,
        total_steps=config.MAX_STEPS,
        pct_start=config.WARMUP_STEPS / config.MAX_STEPS,
    )

    scaler = torch.cuda.amp.GradScaler(enabled=config.USE_MIXED_PRECISION)

    # Training loop
    logger.info("🎯 Starting training...")
    logger.info(f"Target steps: {config.MAX_STEPS}")
    logger.info(f"Effective batch size: {config.TOTAL_BATCH_SIZE * config.GRADIENT_ACCUMULATION_STEPS}")

    step = 0
    running_loss = 0.0
    start_time = time.time()

    for epoch in range(100):  # Large number
        for batch_idx, batch in enumerate(train_loader):
            # Training step
            loss, telemetry, safety_penalty = train_step(
                model, batch, optimizer, scaler, config
            )
            running_loss += loss

            # Gradient accumulation
            if (batch_idx + 1) % config.GRADIENT_ACCUMULATION_STEPS == 0:
                # Gradient clipping
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

                # Optimizer step
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                optimizer.zero_grad()

                step += 1

                # Logging
                if step % 10 == 0:
                    avg_loss = running_loss / 10
                    elapsed = time.time() - start_time
                    samples_per_sec = (config.TOTAL_BATCH_SIZE * 10) / elapsed
                    memory_used = torch.cuda.max_memory_allocated() / (1024**3)

                    logger.info(
                        f"Step {step:4d} | "
                        f"Loss: {avg_loss:.4f} | "
                        f"K: {telemetry.get('negentropy', 0):.3f} | "
                        f"C: {telemetry.get('lz_complexity', 0):.3f} | "
                        f"S: {telemetry.get('symbiosis', 0):.3f} | "
                        f"LR: {scheduler.get_last_lr()[0]:.2e} | "
                        f"Speed: {samples_per_sec:.1f} samp/s | "
                        f"Mem: {memory_used:.1f}GB"
                        + (f" | Safety: {safety_penalty:.3f}" if safety_penalty > 0 else "")
                    )

                    running_loss = 0.0
                    start_time = time.time()

                # Validation
                if step % 100 == 0:
                    model.eval()
                    set_dropout(model, 0.0)
                    val_loss = 0

                    with torch.no_grad():
                        for val_batch in val_loader:
                            val_input_ids = val_batch['input_ids'].cuda()
                            val_labels = val_batch['labels'].cuda()

                            with torch.cuda.amp.autocast(enabled=config.USE_MIXED_PRECISION):
                                val_outputs = model(val_input_ids)
                                if isinstance(val_outputs, tuple):
                                    val_logits, _ = val_outputs
                                else:
                                    val_logits = val_outputs

                            val_loss += F.cross_entropy(
                                val_logits.view(-1, 2),
                                val_labels.view(-1)
                            ).item()

                    val_loss /= len(val_loader)
                    logger.info(f"📊 Validation Loss: {val_loss:.4f}")

                # Save checkpoint
                if step % 500 == 0:
                    checkpoint_dir = f"/data/checkpoints/massive_simple_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
                    os.makedirs(checkpoint_dir, exist_ok=True)

                    torch.save({
                        'step': step,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'config': config.get_model_config(),
                    }, f"{checkpoint_dir}/checkpoint_step_{step:06d}.pt")

                    logger.info(f"💾 Checkpoint saved: step {step}")

                if step >= config.MAX_STEPS:
                    logger.info("🏁 Training completed!")
                    return

        if step >= config.MAX_STEPS:
            break


if __name__ == "__main__":
    main()
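
For context on the "680M parameters - CONFIRMED" note in the deleted config: the figure is consistent with a back-of-the-envelope estimate for a transformer stack of this shape. The sketch below is illustrative only; the function name estimate_params is made up here, and it assumes a conventional block layout (QKV plus output projections contribute roughly 4·d_model² parameters per layer, the feed-forward network roughly 2·d_model·dim_feedforward). BitTransformerLM's actual internals (reversible layers, bit embeddings, telemetry heads) may shift the exact total slightly.

# Rough parameter estimate for the deleted 680M configuration.
# Illustrative only: assumes a conventional transformer block;
# the real BitTransformerLM layer layout may differ slightly.

def estimate_params(d_model: int, num_layers: int, dim_feedforward: int) -> int:
    attn = 4 * d_model * d_model         # Q, K, V and output projections
    ffn = 2 * d_model * dim_feedforward  # two feed-forward projections
    return num_layers * (attn + ffn)     # embeddings and norms are negligible at this scale

if __name__ == "__main__":
    total = estimate_params(d_model=1536, num_layers=24, dim_feedforward=6144)
    print(f"~{total / 1e6:.0f}M parameters")  # ~679M, in line with the 680M claim

Running this gives roughly 679M parameters, which matches the config's stated 680M within rounding.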