Uday committed on
Commit c8c8629 · 1 Parent(s): 80dd9c4

Add HF training integration and fix binary file tracking

Files changed (6)
  1. .gitignore +3 -1
  2. Dockerfile +33 -0
  3. GUIDE_HF.md +103 -0
  4. pixi.lock +34 -0
  5. pixi.toml +1 -0
  6. tasks/image_classification/train_energy.py +146 -110
.gitignore CHANGED
@@ -26,4 +26,6 @@ utils/hugging_face/
26
  # pixi environments
27
  .pixi/*
28
  !.pixi/config.toml
29
- changes.md
 
 
 
26
  # pixi environments
27
  .pixi/*
28
  !.pixi/config.toml
29
+ changes.md
30
+ assets/activations.gif
31
+ examples/goldfish.jpg
Dockerfile ADDED
@@ -0,0 +1,33 @@
1
+ FROM ghcr.io/prefix-dev/pixi:0.39.0 AS builder
2
+
3
+ # Copy source code
4
+ COPY . /app
5
+ WORKDIR /app
6
+
7
+ # Install dependencies
8
+ RUN pixi install
9
+
10
+ # Create a shell script to run the training
11
+ # We need to activate the environment
12
+ RUN echo '#!/bin/bash' > /app/entrypoint.sh && \
13
+ echo 'pixi run python tasks/image_classification/train_energy.py "$@"' >> /app/entrypoint.sh && \
14
+ chmod +x /app/entrypoint.sh
15
+
16
+ # Runtime image (optional, but good for size)
17
+ # For simplicity, we'll just use the builder image for now as it has everything.
18
+ # But HF Spaces might need specific permissions.
19
+
20
+ # Set up user for HF Spaces (optional but recommended)
21
+ # RUN useradd -m -u 1000 user
22
+ # USER user
23
+ # ENV HOME=/home/user \
24
+ # PATH=/home/user/.local/bin:$PATH
25
+
26
+ # ENTRYPOINT ["/app/entrypoint.sh"]
27
+ # CMD ["--help"]
28
+
29
+ # Let's try a simpler approach compatible with standard HF Spaces
30
+ # They often just run the CMD.
31
+
32
+ ENTRYPOINT ["pixi", "run", "python", "tasks/image_classification/train_energy.py"]
33
+ CMD ["--energy_head_enabled", "--loss_type", "energy_contrastive", "--push_to_hub", "--hub_model_id", "Uday/ctm-energy-based-halting", "--hub_token", "$HF_TOKEN"]
GUIDE_HF.md ADDED
@@ -0,0 +1,103 @@
1
+ # Training on Hugging Face with GPUs
2
+
3
+ This guide explains how to train the Energy Halting experiment on Hugging Face infrastructure, including local GPU training with `accelerate` and deployment to Hugging Face Spaces.
4
+
5
+ ## Prerequisites
6
+
7
+ 1. **Hugging Face Account**: Create one at [huggingface.co](https://huggingface.co).
8
+ 2. **Access Token**: Get a write token from [huggingface.co/settings/tokens](https://huggingface.co/settings/tokens).
9
+ 3. **Pixi**: Installed locally.
10
+
11
+ ## 1. Local Training with Accelerate
12
+
13
+ We use Hugging Face `accelerate` for robust multi-GPU and mixed-precision training.
14
+
15
+ ### Setup
16
+
17
+ Ensure dependencies are installed:
18
+
19
+ ```bash
20
+ pixi install
21
+ ```
22
+
23
+ ### Configure Accelerate
24
+
25
+ Run the configuration wizard to set up your GPU environment (e.g., number of GPUs, mixed precision):
26
+
27
+ ```bash
28
+ pixi run accelerate config
29
+ ```
30
+
31
+ ### Run Training
32
+
33
+ Use `accelerate launch` to start training. This handles device placement automatically.
34
+
35
+ ```bash
36
+ pixi run accelerate launch tasks/image_classification/train_energy.py \
37
+ --energy_head_enabled \
38
+ --loss_type energy_contrastive \
39
+ --dataset cifar10 \
40
+ --batch_size 32 \
41
+ --use_amp \
42
+ --push_to_hub \
43
+ --hub_model_id <your-username>/ctm-energy-cifar10 \
44
+ --hub_token <your-token>
45
+ ```
46
+
47
+ ## 2. Deploying to Hugging Face Spaces (GPU)
48
+
49
+ You can run this training job on a Hugging Face Space with a GPU.
50
+
51
+ ### Create a Space
52
+
53
+ 1. Go to [huggingface.co/new-space](https://huggingface.co/new-space).
54
+ 2. Name: `ctm-energy-training` (or similar).
55
+ 3. SDK: **Docker**.
56
+ 4. Hardware: Choose a **GPU** instance (e.g., Nvidia T4, A10G).
57
+
58
+ ### Deploy Code
59
+
60
+ You can deploy by pushing your code to the Space's repository.
61
+
62
+ 1. **Clone the Space**:
63
+
64
+ ```bash
65
+ git clone https://huggingface.co/spaces/<your-username>/ctm-energy-training
66
+ cd ctm-energy-training
67
+ ```
68
+
69
+ 2. **Copy Files**:
70
+ Copy your project files into this directory (excluding `.git`, `.pixi`, `data`, `logs`).
71
+ _Crucially, ensure `Dockerfile`, `pixi.toml`, `pixi.lock`, `tasks/`, `models/`, `utils/`, and `configs/` are present._
72
+
73
+ 3. **Push**:
74
+ ```bash
75
+ git add .
76
+ git commit -m "Deploy training job"
77
+ git push
78
+ ```
79
+
80
+ ### Environment Variables
81
+
82
+ To allow the Space to push the trained model back to the Hub, you need to set your HF token as a secret.
83
+
84
+ 1. Go to your Space's **Settings**.
85
+ 2. Scroll to **Variables and secrets**.
86
+ 3. Add a New Secret:
87
+ - Name: `HF_TOKEN`
88
+ - Value: Your write token.
89
+
90
+ ### Update Dockerfile CMD (Optional)
91
+
92
+ The `Dockerfile` in this commit already starts training on deployment, and its `CMD` pushes to `Uday/ctm-energy-based-halting`. To push to your own repository instead, modify the `CMD` in the `Dockerfile` before pushing:
93
+
94
+ ```dockerfile
95
+ CMD ["--energy_head_enabled", "--loss_type", "energy_contrastive", "--push_to_hub", "--hub_model_id", "<your-username>/ctm-energy-cifar10", "--hub_token", "$HF_TOKEN"]
96
+ ```
97
+
98
+ _Note: The token must reach the script at runtime, either through the `HF_TOKEN` environment variable or as a literal `--hub_token` argument._
99
+
100
+ ## 3. Monitoring
101
+
102
+ - **Local**: Check the `logs/` directory or WandB if enabled (`--wandb`).
103
+ - **Spaces**: Check the **Logs** tab in your Space.
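On the `"--hub_token", "$HF_TOKEN"` pair in the `CMD` lines above: Docker's exec (JSON-array) form runs without a shell, so no variable substitution happens and the script would receive the literal string `$HF_TOKEN`. Below is a minimal sketch of one workaround. It assumes the Space exposes the secret as an `HF_TOKEN` environment variable at runtime. `resolve_hub_token` is a hypothetical helper for illustration and is not part of `train_energy.py`.

```python
import os


def resolve_hub_token(cli_token, env=None):
    """Pick the Hub token from the CLI value or the HF_TOKEN environment variable.

    Hypothetical helper for illustration only; not part of the commit.
    """
    env = os.environ if env is None else env
    # An unexpanded placeholder such as "$HF_TOKEN" is not a real token,
    # so treat it the same as a missing argument.
    if cli_token and not cli_token.startswith("$"):
        return cli_token
    # Fall back to the environment; an empty string counts as "not set".
    return env.get("HF_TOKEN") or None
```

With a fallback like this, `--hub_token` could be left out of `CMD` altogether, which would also keep the token out of the process arguments.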
pixi.lock CHANGED
@@ -9,6 +9,7 @@ environments:
9
  packages:
10
  osx-arm64:
11
  - conda: https://conda.anaconda.org/conda-forge/noarch/_python_abi3_support-1.0-hd8ed1ab_2.conda
 
12
  - conda: https://conda.anaconda.org/conda-forge/noarch/aiohappyeyeballs-2.6.1-pyhd8ed1ab_0.conda
13
  - conda: https://conda.anaconda.org/conda-forge/osx-arm64/aiohttp-3.13.2-py312he52fbff_0.conda
14
  - conda: https://conda.anaconda.org/conda-forge/noarch/aiosignal-1.4.0-pyhd8ed1ab_0.conda
@@ -221,6 +222,7 @@ environments:
221
  - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pixman-0.46.4-h81086ad_1.conda
222
  - conda: https://conda.anaconda.org/conda-forge/osx-arm64/prometheus-cpp-1.3.0-h0967b3e_0.conda
223
  - conda: https://conda.anaconda.org/conda-forge/osx-arm64/propcache-0.3.1-py312h998013c_0.conda
 
224
  - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pthread-stubs-0.4-hd74edd7_1002.conda
225
  - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pugixml-1.15-hd3d436d_0.conda
226
  - conda: https://conda.anaconda.org/conda-forge/osx-arm64/py-opencv-4.12.0-qt6_py312he92a2c1_607.conda
@@ -302,6 +304,24 @@ packages:
302
  purls: []
303
  size: 8191
304
  timestamp: 1744137672556
 
305
  - conda: https://conda.anaconda.org/conda-forge/noarch/aiohappyeyeballs-2.6.1-pyhd8ed1ab_0.conda
306
  sha256: 7842ddc678e77868ba7b92a726b437575b23aaec293bca0d40826f1026d90e27
307
  md5: 18fd895e0e775622906cdabfc3cf0fb4
@@ -3253,6 +3273,20 @@ packages:
3253
  - pkg:pypi/propcache?source=hash-mapping
3254
  size: 51972
3255
  timestamp: 1744525285336
3256
  - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pthread-stubs-0.4-hd74edd7_1002.conda
3257
  sha256: 8ed65e17fbb0ca944bfb8093b60086e3f9dd678c3448b5de212017394c247ee3
3258
  md5: 415816daf82e0b23a736a069a75e9da7
 
9
  packages:
10
  osx-arm64:
11
  - conda: https://conda.anaconda.org/conda-forge/noarch/_python_abi3_support-1.0-hd8ed1ab_2.conda
12
+ - conda: https://conda.anaconda.org/conda-forge/noarch/accelerate-1.12.0-pyhcf101f3_0.conda
13
  - conda: https://conda.anaconda.org/conda-forge/noarch/aiohappyeyeballs-2.6.1-pyhd8ed1ab_0.conda
14
  - conda: https://conda.anaconda.org/conda-forge/osx-arm64/aiohttp-3.13.2-py312he52fbff_0.conda
15
  - conda: https://conda.anaconda.org/conda-forge/noarch/aiosignal-1.4.0-pyhd8ed1ab_0.conda
 
222
  - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pixman-0.46.4-h81086ad_1.conda
223
  - conda: https://conda.anaconda.org/conda-forge/osx-arm64/prometheus-cpp-1.3.0-h0967b3e_0.conda
224
  - conda: https://conda.anaconda.org/conda-forge/osx-arm64/propcache-0.3.1-py312h998013c_0.conda
225
+ - conda: https://conda.anaconda.org/conda-forge/osx-arm64/psutil-7.1.3-py312h37e1c23_0.conda
226
  - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pthread-stubs-0.4-hd74edd7_1002.conda
227
  - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pugixml-1.15-hd3d436d_0.conda
228
  - conda: https://conda.anaconda.org/conda-forge/osx-arm64/py-opencv-4.12.0-qt6_py312he92a2c1_607.conda
 
304
  purls: []
305
  size: 8191
306
  timestamp: 1744137672556
307
+ - conda: https://conda.anaconda.org/conda-forge/noarch/accelerate-1.12.0-pyhcf101f3_0.conda
308
+ sha256: 7351587f4771eb96b5858902d34efb4c67c1e579e745d955bc7052e204b029a6
309
+ md5: e02f90d5f2ee4dd409884c49839bf64c
310
+ depends:
311
+ - python >=3.10
312
+ - numpy >=1.17
313
+ - packaging >=20.0
314
+ - psutil
315
+ - pyyaml
316
+ - pytorch >=2.0.0
317
+ - huggingface_hub >=0.21.0
318
+ - safetensors >=0.4.3
319
+ - python
320
+ license: Apache-2.0
321
+ purls:
322
+ - pkg:pypi/accelerate?source=hash-mapping
323
+ size: 272809
324
+ timestamp: 1763737594988
325
  - conda: https://conda.anaconda.org/conda-forge/noarch/aiohappyeyeballs-2.6.1-pyhd8ed1ab_0.conda
326
  sha256: 7842ddc678e77868ba7b92a726b437575b23aaec293bca0d40826f1026d90e27
327
  md5: 18fd895e0e775622906cdabfc3cf0fb4
 
3273
  - pkg:pypi/propcache?source=hash-mapping
3274
  size: 51972
3275
  timestamp: 1744525285336
3276
+ - conda: https://conda.anaconda.org/conda-forge/osx-arm64/psutil-7.1.3-py312h37e1c23_0.conda
3277
+ sha256: cd831dfe655fdb581e1c2c71fa072d2fce38538474a36cbde3ae2dd910a2ae76
3278
+ md5: d0b2f83de57eafaa6d7700b589c66096
3279
+ depends:
3280
+ - python
3281
+ - __osx >=11.0
3282
+ - python 3.12.* *_cpython
3283
+ - python_abi 3.12.* *_cp312
3284
+ license: BSD-3-Clause
3285
+ license_family: BSD
3286
+ purls:
3287
+ - pkg:pypi/psutil?source=hash-mapping
3288
+ size: 508014
3289
+ timestamp: 1762093047823
3290
  - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pthread-stubs-0.4-hd74edd7_1002.conda
3291
  sha256: 8ed65e17fbb0ca944bfb8093b60086e3f9dd678c3448b5de212017394c247ee3
3292
  md5: 415816daf82e0b23a736a069a75e9da7
pixi.toml CHANGED
@@ -23,6 +23,7 @@ datasets = "*"
23
  huggingface_hub = "*"
24
  safetensors = "*"
25
  ffmpeg = "*"
 
26
 
27
  [pypi-dependencies]
28
  autoclip = "*"
 
23
  huggingface_hub = "*"
24
  safetensors = "*"
25
  ffmpeg = "*"
26
+ accelerate = ">=1.12.0,<2"
27
 
28
  [pypi-dependencies]
29
  autoclip = "*"
tasks/image_classification/train_energy.py CHANGED
@@ -33,6 +33,9 @@ from utils.housekeeping import set_seed, zip_python_code
33
  from utils.losses import image_classification_loss, EnergyContrastiveLoss # Used by CTM, LSTM
34
  from utils.schedulers import WarmupCosineAnnealingLR, WarmupMultiStepLR, warmup
35
 
 
 
 
36
  from autoclip.torch import QuantileClip
37
 
38
  import gc
@@ -127,14 +130,20 @@ def parse_args():
127
  parser.add_argument('--data_root', type=str, default='data/', help='Where to save dataset.')
128
  parser.add_argument('--save_every', type=int, default=1000, help='Save checkpoints every this many iterations.')
129
  parser.add_argument('--seed', type=int, default=412, help='Random seed.')
130
- parser.add_argument('--reload', action=argparse.BooleanOptionalAction, default=False, help='Reload from disk?')
131
  parser.add_argument('--reload_model_only', action=argparse.BooleanOptionalAction, default=False, help='Reload only the model from disk?')
132
  parser.add_argument('--strict_reload', action=argparse.BooleanOptionalAction, default=True, help='Should use strict reload for model weights.') # Added back
133
  parser.add_argument('--track_every', type=int, default=1000, help='Track metrics every this many iterations.')
134
  parser.add_argument('--n_test_batches', type=int, default=20, help='How many minibatches to approx metrics. Set to -1 for full eval')
135
  parser.add_argument('--device', type=int, nargs='+', default=[-1], help='List of GPU(s) to use. Set to -1 to use CPU.')
136
  parser.add_argument('--use_amp', action=argparse.BooleanOptionalAction, default=False, help='AMP autocast.')
137
-
 
 
 
 
 
 
 
138
 
139
  args = parser.parse_args()
140
  return args
@@ -208,8 +217,23 @@ if __name__=='__main__':
208
  # Housekeeping
209
  args = parse_args()
210
 
211
- set_seed(args.seed, False)
212
- if not os.path.exists(args.log_dir): os.makedirs(args.log_dir)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
 
214
  assert args.dataset in ['cifar10', 'cifar100', 'imagenet']
215
 
@@ -229,12 +253,6 @@ if __name__=='__main__':
229
  print(args, file=f)
230
 
231
  # Configure device string (support MPS on macOS)
232
- if args.device[0] != -1:
233
- device = f'cuda:{args.device[0]}'
234
- elif torch.backends.mps.is_available():
235
- device = 'mps'
236
- else:
237
- device = 'cpu'
238
  print(f'Running model {args.model} on {device}')
239
 
240
  # Build model conditionally
@@ -265,34 +283,31 @@ if __name__=='__main__':
265
  ).to(device)
266
  elif args.model == 'lstm':
267
  model = LSTMBaseline(
268
- num_layers=args.num_layers,
269
- iterations=args.iterations,
270
- d_model=args.d_model,
271
  d_input=args.d_input,
272
- heads=args.heads,
273
- backbone_type=args.backbone_type,
274
- positional_embedding_type=args.positional_embedding_type,
275
  out_dims=args.out_dims,
276
- prediction_reshaper=prediction_reshaper,
277
  dropout=args.dropout,
278
- ).to(device)
279
  elif args.model == 'ff':
280
  model = FFBaseline(
281
- d_model=args.d_model,
282
- backbone_type=args.backbone_type,
283
  out_dims=args.out_dims,
284
  dropout=args.dropout,
285
- ).to(device)
286
  else:
287
- raise ValueError(f"Unknown model type: {args.model}")
 
 
 
 
 
288
 
289
 
290
  # For lazy modules so that we can get param count
291
  pseudo_inputs = train_data.__getitem__(0)[0].unsqueeze(0).to(device)
292
  model(pseudo_inputs)
293
-
294
- model.train()
295
-
296
 
297
  print(f'Total params: {sum(p.numel() for p in model.parameters())}')
298
  decay_params = []
@@ -332,6 +347,11 @@ if __name__=='__main__':
332
  else:
333
  raise NotImplementedError
334
 
 
 
 
 
 
335
 
336
  # Metrics tracking
337
  start_iter = 0
@@ -344,9 +364,13 @@ if __name__=='__main__':
344
  train_accuracies_most_certain = [] if args.model in ['ctm', 'lstm'] else None
345
  test_accuracies_most_certain = [] if args.model in ['ctm', 'lstm'] else None
346
 
 
 
 
347
  # scaler = torch.amp.GradScaler("cuda" if "cuda" in device else "cpu", enabled=args.use_amp)
348
  # Fallback for older torch versions or specific builds
349
- scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp)
 
350
 
351
  # Reloading logic
352
  if args.reload:
@@ -355,14 +379,14 @@ if __name__=='__main__':
355
  print(f'Reloading from: {checkpoint_path}')
356
  checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
357
  if not args.strict_reload: print('WARNING: not using strict reload for model weights!')
358
- load_result = model.load_state_dict(checkpoint['model_state_dict'], strict=args.strict_reload)
359
  print(f" Loaded state_dict. Missing: {load_result.missing_keys}, Unexpected: {load_result.unexpected_keys}")
360
 
361
  if not args.reload_model_only:
362
  print('Reloading optimizer etc.')
363
  optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
364
  scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
365
- scaler.load_state_dict(checkpoint['scaler_state_dict'])
366
  start_iter = checkpoint['iteration']
367
  # Load common metrics
368
  train_losses = checkpoint['train_losses']
@@ -414,59 +438,58 @@ if __name__=='__main__':
414
  iterator = iter(trainloader)
415
  inputs, targets = next(iterator)
416
 
417
- inputs = inputs.to(device)
418
- targets = targets.to(device)
419
 
420
  loss = None
421
  accuracy = None
422
  # Model-specific forward and loss calculation
423
- with torch.autocast(device_type="cuda" if "cuda" in device else "cpu", dtype=torch.float16, enabled=args.use_amp):
424
- if args.do_compile: # CUDAGraph marking for clean compile
425
- torch.compiler.cudagraph_mark_step_begin()
426
-
427
- if args.model == 'ctm':
428
- if args.energy_head_enabled:
429
- predictions, certainties, energies = model(inputs)
430
- if args.loss_type == 'energy_contrastive':
431
- criterion = EnergyContrastiveLoss(margin=args.energy_margin, energy_scale=args.energy_scale)
432
- loss, stats = criterion(predictions, energies, targets)
433
- # Use standard accuracy metric for now
434
- where_most_certain = certainties[:,1].argmax(-1)
435
- accuracy = (predictions.argmax(1)[torch.arange(predictions.size(0), device=predictions.device),where_most_certain] == targets).float().mean().item()
436
- pbar_desc = f'CTM Loss={loss.item():0.3f}. Acc={accuracy:0.3f}. LR={current_lr:0.6f}. Avg Energy={stats["avg_energy"]:0.3f}'
437
- else:
438
- # Fallback to standard loss even if energy head is enabled (but unused)
439
- loss, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
440
- accuracy = (predictions.argmax(1)[torch.arange(predictions.size(0), device=predictions.device),where_most_certain] == targets).float().mean().item()
441
- pbar_desc = f'CTM Loss={loss.item():0.3f}. Acc={accuracy:0.3f}. LR={current_lr:0.6f}'
442
- else:
443
- predictions, certainties, synchronisation = model(inputs)
444
- loss, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
445
  accuracy = (predictions.argmax(1)[torch.arange(predictions.size(0), device=predictions.device),where_most_certain] == targets).float().mean().item()
446
- pbar_desc = f'CTM Loss={loss.item():0.3f}. Acc={accuracy:0.3f}. LR={current_lr:0.6f}. Where_certain={where_most_certain.float().mean().item():0.2f}+-{where_most_certain.float().std().item():0.2f} ({where_most_certain.min().item():d}<->{where_most_certain.max().item():d})'
447
-
448
- elif args.model == 'lstm':
 
 
 
 
449
  predictions, certainties, synchronisation = model(inputs)
450
  loss, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
451
- # LSTM where_most_certain will just be -1 because use_most_certain is False owing to stability issues with LSTM training
452
  accuracy = (predictions.argmax(1)[torch.arange(predictions.size(0), device=predictions.device),where_most_certain] == targets).float().mean().item()
453
- pbar_desc = f'LSTM Loss={loss.item():0.3f}. Acc={accuracy:0.3f}. LR={current_lr:0.6f}. Where_certain={where_most_certain.float().mean().item():0.2f}+-{where_most_certain.float().std().item():0.2f} ({where_most_certain.min().item():d}<->{where_most_certain.max().item():d})'
454
-
455
- elif args.model == 'ff':
456
- predictions = model(inputs)
457
- loss = nn.CrossEntropyLoss()(predictions, targets)
458
- accuracy = (predictions.argmax(1) == targets).float().mean().item()
459
- pbar_desc = f'FF Loss={loss.item():0.3f}. Acc={accuracy:0.3f}. LR={current_lr:0.6f}'
460
-
461
- scaler.scale(loss).backward()
462
-
463
- if args.gradient_clipping!=-1:
464
- scaler.unscale_(optimizer)
465
- torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.gradient_clipping)
466
-
467
- scaler.step(optimizer)
468
- scaler.update()
469
- optimizer.zero_grad(set_to_none=True)
 
 
 
 
 
 
470
  scheduler.step()
471
 
472
  pbar.set_description(f'Dataset={args.dataset}. Model={args.model}. {pbar_desc}')
@@ -493,16 +516,16 @@ if __name__=='__main__':
493
 
494
  pbar.set_description('Tracking: Computing TRAIN metrics')
495
  with torch.no_grad(): # Should use inference_mode? CTM/LSTM scripts used no_grad
496
- loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size_test, shuffle=True, num_workers=num_workers_test)
497
  all_targets_list = []
498
  all_predictions_list = [] # List to store raw predictions (B, C, T) or (B, C)
499
  all_predictions_most_certain_list = [] # Only for CTM/LSTM
500
  all_losses = []
501
 
502
- with tqdm(total=len(loader), initial=0, leave=False, position=1, dynamic_ncols=True) as pbar_inner:
503
- for inferi, (inputs, targets) in enumerate(loader):
504
- inputs = inputs.to(device)
505
- targets = targets.to(device)
506
  all_targets_list.append(targets.detach().cpu().numpy())
507
 
508
  # Model-specific forward and loss for evaluation
@@ -552,16 +575,16 @@ if __name__=='__main__':
552
  model.eval()
553
  pbar.set_description('Tracking: Computing TEST metrics')
554
  with torch.inference_mode(): # Use inference_mode for test eval
555
- loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size_test, shuffle=True, num_workers=num_workers_test)
556
  all_targets_list = []
557
  all_predictions_list = []
558
  all_predictions_most_certain_list = [] # Only for CTM/LSTM
559
  all_losses = []
560
 
561
- with tqdm(total=len(loader), initial=0, leave=False, position=1, dynamic_ncols=True) as pbar_inner:
562
- for inferi, (inputs, targets) in enumerate(loader):
563
- inputs = inputs.to(device)
564
- targets = targets.to(device)
565
  all_targets_list.append(targets.detach().cpu().numpy())
566
 
567
  # Model-specific forward and loss for evaluation
@@ -655,13 +678,13 @@ if __name__=='__main__':
655
  if args.model in ['ctm', 'lstm']:
656
  try: # For safety
657
  inputs_viz, targets_viz = next(iter(testloader)) # Get a fresh batch
658
- inputs_viz = inputs_viz.to(device)
659
- targets_viz = targets_viz.to(device)
660
 
661
  pbar.set_description('Tracking: Processing test data for viz')
662
  predictions_viz, certainties_viz, _, pre_activations_viz, post_activations_viz, attention_tracking_viz = model(inputs_viz, track=True)
663
 
664
- att_shape = (model.kv_features.shape[2], model.kv_features.shape[3])
665
  attention_tracking_viz = attention_tracking_viz.reshape(
666
  attention_tracking_viz.shape[0],
667
  attention_tracking_viz.shape[1], -1, att_shape[0], att_shape[1])
@@ -694,32 +717,45 @@ if __name__=='__main__':
694
  model.train() # Switch back to train mode
695
 
696
 
 
697
  # Save model checkpoint (conditional metrics)
698
  if (bi % args.save_every == 0 or bi == args.training_iterations - 1) and bi != start_iter:
699
- pbar.set_description('Saving model checkpoint...')
700
- checkpoint_data = {
701
- 'model_state_dict': model.state_dict(),
702
- 'optimizer_state_dict': optimizer.state_dict(),
703
- 'scheduler_state_dict': scheduler.state_dict(),
704
- 'scaler_state_dict': scaler.state_dict(),
705
- 'iteration': bi,
706
- # Always save these
707
- 'train_losses': train_losses,
708
- 'test_losses': test_losses,
709
- 'train_accuracies': train_accuracies, # This is list of scalars for FF, list of arrays for CTM/LSTM
710
- 'test_accuracies': test_accuracies, # This is list of scalars for FF, list of arrays for CTM/LSTM
711
- 'iters': iters,
712
- 'args': args, # Save args used for this run
713
- # RNG states
714
- 'torch_rng_state': torch.get_rng_state(),
715
- 'numpy_rng_state': np.random.get_state(),
716
- 'random_rng_state': random.getstate(),
717
- }
718
- # Conditionally add metrics specific to CTM/LSTM
719
- if args.model in ['ctm', 'lstm']:
720
- checkpoint_data['train_accuracies_most_certain'] = train_accuracies_most_certain
721
- checkpoint_data['test_accuracies_most_certain'] = test_accuracies_most_certain
722
 
723
- torch.save(checkpoint_data, f'{args.log_dir}/checkpoint.pt')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
724
 
725
  pbar.update(1)
 
33
  from utils.losses import image_classification_loss, EnergyContrastiveLoss # Used by CTM, LSTM
34
  from utils.schedulers import WarmupCosineAnnealingLR, WarmupMultiStepLR, warmup
35
 
36
+ from accelerate import Accelerator
37
+ from huggingface_hub import upload_folder
38
+
39
  from autoclip.torch import QuantileClip
40
 
41
  import gc
 
130
  parser.add_argument('--data_root', type=str, default='data/', help='Where to save dataset.')
131
  parser.add_argument('--save_every', type=int, default=1000, help='Save checkpoints every this many iterations.')
132
  parser.add_argument('--seed', type=int, default=412, help='Random seed.')
 
133
  parser.add_argument('--reload_model_only', action=argparse.BooleanOptionalAction, default=False, help='Reload only the model from disk?')
134
  parser.add_argument('--strict_reload', action=argparse.BooleanOptionalAction, default=True, help='Should use strict reload for model weights.') # Added back
135
  parser.add_argument('--track_every', type=int, default=1000, help='Track metrics every this many iterations.')
136
  parser.add_argument('--n_test_batches', type=int, default=20, help='How many minibatches to approx metrics. Set to -1 for full eval')
137
  parser.add_argument('--device', type=int, nargs='+', default=[-1], help='List of GPU(s) to use. Set to -1 to use CPU.')
138
  parser.add_argument('--use_amp', action=argparse.BooleanOptionalAction, default=False, help='AMP autocast.')
139
+ parser.add_argument('--reload', type=str, default=None, help='Path to checkpoint to reload from.')
140
+ parser.add_argument('--wandb', action=argparse.BooleanOptionalAction, default=False, help='Log to WandB.')
141
+
142
+ # HF Hub
143
+ parser.add_argument('--push_to_hub', action=argparse.BooleanOptionalAction, default=False, help='Push model to HF Hub.')
144
+ parser.add_argument('--hub_model_id', type=str, default=None, help='HF Hub model ID (e.g., username/repo).')
145
+ parser.add_argument('--hub_token', type=str, default=None, help='HF Hub token.')
146
+ parser.add_argument('--hub_private', action=argparse.BooleanOptionalAction, default=False, help='Make HF Hub repo private.')
147
 
148
  args = parser.parse_args()
149
  return args
 
217
  # Housekeeping
218
  args = parse_args()
219
 
220
+ set_seed(args.seed)
221
+
222
+ # Initialize Accelerator
223
+ accelerator = Accelerator(log_with="wandb" if args.wandb else None)
224
+ device = accelerator.device
225
+
226
+ # Setup Logging
227
+ if accelerator.is_main_process:
228
+ if not os.path.exists(args.log_dir):
229
+ os.makedirs(args.log_dir)
230
+ print(f"Logging to {args.log_dir}")
231
+ if args.wandb:
232
+ accelerator.init_trackers(
233
+ project_name="continuous-thought-machines",
234
+ config=vars(args),
235
+ init_kwargs={"wandb": {"name": args.log_dir.split('/')[-1]}}
236
+ )
237
 
238
  assert args.dataset in ['cifar10', 'cifar100', 'imagenet']
239
 
 
253
  print(args, file=f)
254
 
255
  # Configure device string (support MPS on macOS)
 
 
 
 
 
 
256
  print(f'Running model {args.model} on {device}')
257
 
258
  # Build model conditionally
 
283
  ).to(device)
284
  elif args.model == 'lstm':
285
  model = LSTMBaseline(
286
+ d_model=args.d_model,
 
 
287
  d_input=args.d_input,
288
+ num_layers=args.num_layers,
 
 
289
  out_dims=args.out_dims,
 
290
  dropout=args.dropout,
291
+ )
292
  elif args.model == 'ff':
293
  model = FFBaseline(
294
+ d_model=args.d_model,
295
+ d_input=args.d_input,
296
  out_dims=args.out_dims,
297
  dropout=args.dropout,
298
+ )
299
  else:
300
+ raise NotImplementedError
301
+
302
+ model.train()
303
+
304
+
305
+ # Param counting moved after initialization
306
 
307
 
308
  # For lazy modules so that we can get param count
309
  pseudo_inputs = train_data.__getitem__(0)[0].unsqueeze(0).to(device)
310
  model(pseudo_inputs)
 
 
 
311
 
312
  print(f'Total params: {sum(p.numel() for p in model.parameters())}')
313
  decay_params = []
 
347
  else:
348
  raise NotImplementedError
349
 
350
+ # Prepare with Accelerator
351
+ # Note: Accelerate handles device placement
352
+ model, optimizer, trainloader, testloader, scheduler = accelerator.prepare(
353
+ model, optimizer, trainloader, testloader, scheduler
354
+ )
355
 
356
  # Metrics tracking
357
  start_iter = 0
 
364
  train_accuracies_most_certain = [] if args.model in ['ctm', 'lstm'] else None
365
  test_accuracies_most_certain = [] if args.model in ['ctm', 'lstm'] else None
366
 
367
+ train_accuracies_most_certain = [] if args.model in ['ctm', 'lstm'] else None
368
+ test_accuracies_most_certain = [] if args.model in ['ctm', 'lstm'] else None
369
+
370
  # scaler = torch.amp.GradScaler("cuda" if "cuda" in device else "cpu", enabled=args.use_amp)
371
  # Fallback for older torch versions or specific builds
372
+ # scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp)
373
+ # Accelerate handles mixed precision automatically
374
 
375
  # Reloading logic
376
  if args.reload:
 
379
  print(f'Reloading from: {checkpoint_path}')
380
  checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
381
  if not args.strict_reload: print('WARNING: not using strict reload for model weights!')
382
+ load_result = accelerator.unwrap_model(model).load_state_dict(checkpoint['model_state_dict'], strict=args.strict_reload)
383
  print(f" Loaded state_dict. Missing: {load_result.missing_keys}, Unexpected: {load_result.unexpected_keys}")
384
 
385
  if not args.reload_model_only:
386
  print('Reloading optimizer etc.')
387
  optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
388
  scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
389
+ # scaler.load_state_dict(checkpoint['scaler_state_dict']) # Scaler is handled by accelerator
390
  start_iter = checkpoint['iteration']
391
  # Load common metrics
392
  train_losses = checkpoint['train_losses']
 
438
  iterator = iter(trainloader)
439
  inputs, targets = next(iterator)
440
 
441
+ # inputs = inputs.to(device) # Handled by accelerator.prepare
442
+ # targets = targets.to(device) # Handled by accelerator.prepare
443
 
444
  loss = None
445
  accuracy = None
446
  # Model-specific forward and loss calculation
447
+ # with torch.autocast(device_type="cuda" if "cuda" in device else "cpu", dtype=torch.float16, enabled=args.use_amp): # Handled by accelerator
448
+ if args.do_compile: # CUDAGraph marking for clean compile
449
+ torch.compiler.cudagraph_mark_step_begin()
450
+
451
+ if args.model == 'ctm':
452
+ if args.energy_head_enabled:
453
+ predictions, certainties, energies = model(inputs)
454
+ if args.loss_type == 'energy_contrastive':
455
+ criterion = EnergyContrastiveLoss(margin=args.energy_margin, energy_scale=args.energy_scale)
456
+ loss, stats = criterion(predictions, energies, targets)
457
+ # Use standard accuracy metric for now
458
+ where_most_certain = certainties[:,1].argmax(-1)
 
 
 
 
 
 
 
 
 
 
459
  accuracy = (predictions.argmax(1)[torch.arange(predictions.size(0), device=predictions.device),where_most_certain] == targets).float().mean().item()
460
+ pbar_desc = f'CTM Loss={loss.item():0.3f}. Acc={accuracy:0.3f}. LR={current_lr:0.6f}. Avg Energy={stats["avg_energy"]:0.3f}'
461
+ else:
462
+ # Fallback to standard loss even if energy head is enabled (but unused)
463
+ loss, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
464
+ accuracy = (predictions.argmax(1)[torch.arange(predictions.size(0), device=predictions.device),where_most_certain] == targets).float().mean().item()
465
+ pbar_desc = f'CTM Loss={loss.item():0.3f}. Acc={accuracy:0.3f}. LR={current_lr:0.6f}'
466
+ else:
467
  predictions, certainties, synchronisation = model(inputs)
468
  loss, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
 
469
  accuracy = (predictions.argmax(1)[torch.arange(predictions.size(0), device=predictions.device),where_most_certain] == targets).float().mean().item()
470
+ pbar_desc = f'CTM Loss={loss.item():0.3f}. Acc={accuracy:0.3f}. LR={current_lr:0.6f}. Where_certain={where_most_certain.float().mean().item():0.2f}+-{where_most_certain.float().std().item():0.2f} ({where_most_certain.min().item():d}<->{where_most_certain.max().item():d})'
471
+
472
+ elif args.model == 'lstm':
473
+ predictions, certainties, synchronisation = model(inputs)
474
+ loss, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
475
+ # NOTE: this call passes use_most_certain=True. where_most_certain would only be -1 with use_most_certain=False, the setting this comment attributes to stability issues with LSTM training
476
+ accuracy = (predictions.argmax(1)[torch.arange(predictions.size(0), device=predictions.device),where_most_certain] == targets).float().mean().item()
477
+ pbar_desc = f'LSTM Loss={loss.item():0.3f}. Acc={accuracy:0.3f}. LR={current_lr:0.6f}. Where_certain={where_most_certain.float().mean().item():0.2f}+-{where_most_certain.float().std().item():0.2f} ({where_most_certain.min().item():d}<->{where_most_certain.max().item():d})'
478
+
479
+ elif args.model == 'ff':
480
+ predictions = model(inputs)
481
+ loss = nn.CrossEntropyLoss()(predictions, targets)
482
+ accuracy = (predictions.argmax(1) == targets).float().mean().item()
483
+ pbar_desc = f'FF Loss={loss.item():0.3f}. Acc={accuracy:0.3f}. LR={current_lr:0.6f}'
484
+
485
+ # Backward pass with Accelerate
486
+ accelerator.backward(loss)
487
+
488
+ if args.gradient_clipping > 0:
489
+ accelerator.clip_grad_norm_(model.parameters(), args.gradient_clipping)
490
+
491
+ optimizer.step()
492
+ optimizer.zero_grad()
493
  scheduler.step()
494
 
495
  pbar.set_description(f'Dataset={args.dataset}. Model={args.model}. {pbar_desc}')
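The accuracy expression in the CTM and LSTM branches selects, for each sample, the class predicted at the internal tick where certainty is highest. A minimal pure-Python sketch of that indexing follows. It assumes `predictions` has shape (batch, classes, ticks) and `certainties` has shape (batch, 2, ticks) with index 1 holding the certainty score. The helper names `argmax` and `accuracy_at_most_certain` are illustrative and not part of the repository.

```python
def argmax(values):
    """Index of the largest element (the first one on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def accuracy_at_most_certain(predictions, certainties, targets):
    """predictions: [B][C][T] logits, certainties: [B][2][T], targets: [B].

    Mirrors predictions.argmax(1)[arange(B), where_most_certain] == targets
    using nested lists instead of tensors.
    """
    correct = 0
    for logits, certainty, target in zip(predictions, certainties, targets):
        tick = argmax(certainty[1])  # where_most_certain for this sample
        predicted_class = argmax([class_logits[tick] for class_logits in logits])
        correct += int(predicted_class == target)
    return correct / len(targets)


# Two samples, three classes, two ticks.
predictions = [
    [[0.1, 2.0], [0.9, 0.2], [0.0, 0.1]],  # tick 0 -> class 1, tick 1 -> class 0
    [[0.3, 0.1], [0.2, 0.0], [0.1, 0.9]],  # tick 0 -> class 0, tick 1 -> class 2
]
certainties = [
    [[0.6, 0.2], [0.4, 0.8]],  # most certain at tick 1
    [[0.1, 0.7], [0.9, 0.3]],  # most certain at tick 0
]
targets = [0, 2]
print(accuracy_at_most_certain(predictions, certainties, targets))  # 0.5
```

The first sample is scored at tick 1 (class 0, correct) and the second at tick 0 (class 0, incorrect), so the batch accuracy is 0.5.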
 
516
 
517
  pbar.set_description('Tracking: Computing TRAIN metrics')
518
  with torch.no_grad(): # Should use inference_mode? CTM/LSTM scripts used no_grad
519
+ # loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size_test, shuffle=True, num_workers=num_workers_test) # Use prepared loader
520
  all_targets_list = []
521
  all_predictions_list = [] # List to store raw predictions (B, C, T) or (B, C)
522
  all_predictions_most_certain_list = [] # Only for CTM/LSTM
523
  all_losses = []
524
 
525
+ with tqdm(total=len(trainloader), initial=0, leave=False, position=1, dynamic_ncols=True) as pbar_inner:
526
+ for inferi, (inputs, targets) in enumerate(trainloader):
527
+ # inputs = inputs.to(device) # Handled by accelerator.prepare
528
+ # targets = targets.to(device) # Handled by accelerator.prepare
529
  all_targets_list.append(targets.detach().cpu().numpy())
530
 
531
  # Model-specific forward and loss for evaluation
 
575
  model.eval()
576
  pbar.set_description('Tracking: Computing TEST metrics')
577
  with torch.inference_mode(): # Use inference_mode for test eval
578
+ # loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size_test, shuffle=True, num_workers=num_workers_test) # Use prepared loader
579
  all_targets_list = []
580
  all_predictions_list = []
581
  all_predictions_most_certain_list = [] # Only for CTM/LSTM
582
  all_losses = []
583
 
584
+ with tqdm(total=len(testloader), initial=0, leave=False, position=1, dynamic_ncols=True) as pbar_inner:
585
+ for inferi, (inputs, targets) in enumerate(testloader):
586
+ # inputs = inputs.to(device) # Handled by accelerator.prepare
587
+ # targets = targets.to(device) # Handled by accelerator.prepare
588
  all_targets_list.append(targets.detach().cpu().numpy())
589
 
590
  # Model-specific forward and loss for evaluation
 
678
  if args.model in ['ctm', 'lstm']:
679
  try: # For safety
680
  inputs_viz, targets_viz = next(iter(testloader)) # Get a fresh batch
681
+ # inputs_viz = inputs_viz.to(device) # Handled by accelerator.prepare
682
+ # targets_viz = targets_viz.to(device) # Handled by accelerator.prepare
683
 
684
  pbar.set_description('Tracking: Processing test data for viz')
685
  predictions_viz, certainties_viz, _, pre_activations_viz, post_activations_viz, attention_tracking_viz = model(inputs_viz, track=True)
686
 
687
+ att_shape = (accelerator.unwrap_model(model).kv_features.shape[2], accelerator.unwrap_model(model).kv_features.shape[3])
688
  attention_tracking_viz = attention_tracking_viz.reshape(
689
  attention_tracking_viz.shape[0],
690
  attention_tracking_viz.shape[1], -1, att_shape[0], att_shape[1])
 
717
  model.train() # Switch back to train mode
718
 
719
 
720
  # Save model checkpoint (conditional metrics)
722
  if (bi % args.save_every == 0 or bi == args.training_iterations - 1) and bi != start_iter:
723
+ if accelerator.is_main_process:
724
+ pbar.set_description('Saving model checkpoint...')
725
+ checkpoint_data = {
726
+ 'model_state_dict': accelerator.unwrap_model(model).state_dict(),
727
+ 'optimizer_state_dict': optimizer.state_dict(),
728
+ 'scheduler_state_dict': scheduler.state_dict(),
729
+ 'iteration': bi,
730
+ 'train_losses': train_losses,
731
+ 'test_losses': test_losses,
732
+ 'train_accuracies': train_accuracies,
733
+ 'test_accuracies': test_accuracies,
734
+ 'iters': iters,
735
+ 'args': args,
736
+ 'torch_rng_state': torch.get_rng_state(),
737
+ 'numpy_rng_state': np.random.get_state(),
738
+ 'random_rng_state': random.getstate(),
739
+ }
740
+
741
+ if args.model in ['ctm', 'lstm']:
742
+ checkpoint_data['train_accuracies_most_certain'] = train_accuracies_most_certain
743
+ checkpoint_data['test_accuracies_most_certain'] = test_accuracies_most_certain
744
 
745
+ accelerator.save(checkpoint_data, f'{args.log_dir}/checkpoint.pt')
746
+
747
+ # Push to Hub
748
+ if args.push_to_hub and args.hub_model_id:
749
+ if bi % (args.save_every * 5) == 0: # Upload less frequently
750
+ try:
751
+ upload_folder(
752
+ folder_path=args.log_dir,
753
+ repo_id=args.hub_model_id,
754
+ token=args.hub_token,
755
+ commit_message=f"Training checkpoint {bi}",
756
+ ignore_patterns=["*.pt"],
757
+ )
758
+ except Exception as e:
759
+ print(f"Failed to upload to hub: {e}")
760
 
761
  pbar.update(1)
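The checkpoint condition and the Hub-upload condition interact: an upload is attempted only on a checkpoint iteration that is also a multiple of `save_every * 5`. The sketch below enumerates both sets of iterations for given settings. `checkpoint_schedule` is a hypothetical helper, and it assumes the upload check sits inside the checkpoint branch, which the flattened diff suggests but does not show conclusively.

```python
def checkpoint_schedule(training_iterations, save_every, start_iter=0, upload_factor=5):
    """Return (save_iters, upload_iters) under the conditions in the diff.

    Assumption: the upload check is nested inside the checkpoint branch.
    """
    save_iters, upload_iters = [], []
    for bi in range(start_iter, training_iterations):
        is_save = (bi % save_every == 0 or bi == training_iterations - 1) and bi != start_iter
        if not is_save:
            continue
        save_iters.append(bi)
        # Upload less frequently than checkpoints are written.
        if bi % (save_every * upload_factor) == 0:
            upload_iters.append(bi)
    return save_iters, upload_iters


saves, uploads = checkpoint_schedule(training_iterations=1250, save_every=100)
print(saves)    # [100, 200, ..., 1200, 1249]
print(uploads)  # [500, 1000]
```

With these settings the final iteration (1249) is checkpointed but not uploaded, because 1249 is not a multiple of 500. Whether that matters depends on whether a final upload happens elsewhere in the script. Note also that `ignore_patterns=["*.pt"]` keeps `checkpoint.pt` itself out of the upload, so each Hub commit carries only the other contents of `log_dir`.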