AbstractPhil committed on
Commit 3aac63e · verified · 1 Parent(s): e30e73e

Update trainer.py

Files changed (1)
  1. trainer.py +705 -135
trainer.py CHANGED
@@ -1,26 +1,46 @@
  """
- MobiusNet - CIFAR-100 (Dynamic Stages)
- ======================================
- Properly handles variable stage counts.
-
- Author: AbstractPhil
- https://huggingface.co/AbstractPhil/mobiusnet
- License: Apache 2.0
  """

  import math
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
  from torch import Tensor
- from typing import Tuple
  from torchvision import datasets, transforms
  from torch.utils.data import DataLoader
  from tqdm.auto import tqdm

  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  print(f"Device: {device}")


  # ============================================================================
  # MÖBIUS LENS
@@ -36,6 +56,9 @@ class MobiusLens(nn.Module):
  ):
  super().__init__()

  self.t = layer_idx / max(total_layers - 1, 1)

  scale_span = scale_range[1] - scale_range[0]
@@ -45,12 +68,10 @@

  self.register_buffer('scales', torch.tensor([scale_low, scale_high]))

- # TWIST IN
  self.twist_in_angle = nn.Parameter(torch.tensor(self.t * math.pi))
  self.twist_in_proj = nn.Linear(dim, dim, bias=False)
  nn.init.orthogonal_(self.twist_in_proj.weight)

- # CENTER LENS
  self.omega = nn.Parameter(torch.tensor(math.pi))
  self.alpha = nn.Parameter(torch.tensor(1.5))
@@ -64,7 +85,8 @@
  self.accum_weights = nn.Parameter(torch.tensor([0.4, 0.2, 0.4]))
  self.xor_weight = nn.Parameter(torch.tensor(0.7))

- # TWIST OUT
  self.twist_out_angle = nn.Parameter(torch.tensor(-self.t * math.pi))
  self.twist_out_proj = nn.Linear(dim, dim, bias=False)
  nn.init.orthogonal_(self.twist_out_proj.weight)
@@ -99,9 +121,9 @@

  gate = w[0] * L + w[1] * M + w[2] * R
  gate = gate * (0.5 + 0.5 * lr)
- gate = gate / (gate.mean() + 1e-6) * 0.5

- return x * gate.clamp(0, 1)

  def _twist_out(self, x: Tensor) -> Tensor:
  cos_t = torch.cos(self.twist_out_angle)
@@ -110,6 +132,19 @@

  def forward(self, x: Tensor) -> Tensor:
  return self._twist_out(self._center_lens(self._twist_in(x)))


  # ============================================================================
@@ -157,25 +192,24 @@ class MobiusConvBlock(nn.Module):

  rw = torch.sigmoid(self.residual_weight)
  return rw * identity + (1 - rw) * h


  # ============================================================================
- # MÖBIUS NET - DYNAMIC STAGES
  # ============================================================================

  class MobiusNet(nn.Module):
- """
- Pure conv with Möbius topology.
- Dynamic number of stages based on len(depths).
- """
-
  def __init__(
  self,
  in_chans: int = 3,
- num_classes: int = 100,
- channels: Tuple[int, ...] = (64, 64, 128, 128),
- depths: Tuple[int, ...] = (8, 4, 2),
  scale_range: Tuple[float, float] = (0.5, 2.5),
  ):
  super().__init__()
@@ -184,22 +218,22 @@ class MobiusNet(nn.Module):

  self.total_layers = total_layers
  self.scale_range = scale_range
- self.channels = channels
- self.depths = depths
  self.num_stages = num_stages

- # Ensure we have enough channel specs
  channels = list(channels)
  while len(channels) < num_stages:
  channels.append(channels[-1])

- # Stem
  self.stem = nn.Sequential(
- nn.Conv2d(in_chans, channels[0], 3, padding=1, bias=False),
  nn.BatchNorm2d(channels[0]),
  )

- # Build stages dynamically
  layer_idx = 0
  self.stages = nn.ModuleList()
  self.downsamples = nn.ModuleList()
@@ -207,16 +241,12 @@
  for stage_idx in range(num_stages):
  ch = channels[stage_idx]

- # Stage blocks
  stage = nn.ModuleList()
  for _ in range(depths[stage_idx]):
- stage.append(MobiusConvBlock(
- ch, layer_idx, total_layers, scale_range
- ))
  layer_idx += 1
  self.stages.append(stage)

- # Downsample between stages (not after last)
  if stage_idx < num_stages - 1:
  ch_next = channels[stage_idx + 1]
  self.downsamples.append(nn.Sequential(
@@ -224,9 +254,18 @@
  nn.BatchNorm2d(ch_next),
  ))

- # Head
  self.pool = nn.AdaptiveAvgPool2d(1)
- self.head = nn.Linear(channels[num_stages - 1], num_classes)

  def forward(self, x: Tensor) -> Tensor:
  x = self.stem(x)
@@ -237,40 +276,104 @@ class MobiusNet(nn.Module):
  if i < len(self.downsamples):
  x = self.downsamples[i](x)

  return self.head(self.pool(x).flatten(1))

- def get_info(self) -> str:
- return (
- f"MobiusNet: channels={self.channels}, depths={self.depths}, "
- f"total_layers={self.total_layers}, scale_range={self.scale_range}"
- )

- def get_topology_info(self) -> str:
- lines = ["Möbius Ribbon Topology:"]
- lines.append("=" * 60)
-
- scale_span = self.scale_range[1] - self.scale_range[0]
  layer_idx = 0
-
- for stage_idx, depth in enumerate(self.depths):
- ch = self.channels[stage_idx] if stage_idx < len(self.channels) else self.channels[-1]
- for local_idx in range(depth):
- t = layer_idx / max(self.total_layers - 1, 1)
- scale_low = self.scale_range[0] + t * scale_span
- scale_high = scale_low + scale_span / self.total_layers
-
- lines.append(
- f"Layer {layer_idx:2d} (Stage {stage_idx+1}, ch={ch:3d}): "
- f"t={t:.3f}, scales=[{scale_low:.3f}, {scale_high:.3f}]"
- )
  layer_idx += 1

- if stage_idx < self.num_stages - 1:
- ch_next = self.channels[stage_idx + 1] if stage_idx + 1 < len(self.channels) else self.channels[-1]
- lines.append(f" ↓ Downsample {ch} → {ch_next}")
-
- lines.append("=" * 60)
- return "\n".join(lines)


  # ============================================================================
@@ -278,112 +381,520 @@ class MobiusNet(nn.Module):
  # ============================================================================

  PRESETS = {
- 'mobius_xs': {
- 'channels': (64, 64, 128),
- 'depths': (4, 2, 2),
  'scale_range': (0.5, 2.5),
  },
- 'mobius_stretched': {
- 'channels': (32, 64, 96, 128, 192, 256, 320, 384, 448),
- 'depths': (4, 4, 4, 3, 3, 3, 2, 2, 2),
- 'scale_range': (0.2915, 2.85),
  },
- 'mobius_m': {
- 'channels': (64, 128, 256, 256),
- 'depths': (8, 4, 2),
- 'scale_range': (0.5, 3.0),
- },
- 'mobius_deep': {
- 'channels': (64, 64, 128, 128),
- 'depths': (12, 6, 4),
  'scale_range': (0.5, 3.5),
  },
- 'mobius_wide': {
- 'channels': (96, 96, 192, 192),
- 'depths': (8, 4, 2),
- 'scale_range': (0.5, 2.5),
  },
 }


  # ============================================================================
  # TRAINING
  # ============================================================================

- def train_mobius_cifar100(
- preset: str = 'mobius_s',
  epochs: int = 100,
  lr: float = 1e-3,
  batch_size: int = 128,
- use_autoaugment: bool = True,
  ):
  config = PRESETS[preset]

  print("=" * 70)
- print(f"MÖBIUS NET - {preset.upper()} - CIFAR-100")
  print("=" * 70)
  print(f"Device: {device}")
  print(f"Channels: {config['channels']}")
  print(f"Depths: {config['depths']}")
  print(f"Scale range: {config['scale_range']}")
- print(f"AutoAugment: {use_autoaugment}")
  print()

- # CIFAR-100 normalization
- mean = (0.5071, 0.4867, 0.4408)
- std = (0.2675, 0.2565, 0.2761)

- train_transforms = [
- transforms.RandomCrop(32, padding=4),
- transforms.RandomHorizontalFlip(),
- ]
- if use_autoaugment:
- train_transforms.append(transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10))
- train_transforms.extend([
- transforms.ToTensor(),
- transforms.Normalize(mean, std),
- ])
-
- train_tf = transforms.Compose(train_transforms)
- test_tf = transforms.Compose([
- transforms.ToTensor(),
- transforms.Normalize(mean, std),
- ])
-
- train_ds = datasets.CIFAR100('./data', train=True, download=True, transform=train_tf)
- test_ds = datasets.CIFAR100('./data', train=False, download=True, transform=test_tf)
-
- train_loader = DataLoader(
- train_ds, batch_size=batch_size, shuffle=True,
- num_workers=8, pin_memory=True, persistent_workers=True
- )
- test_loader = DataLoader(
- test_ds, batch_size=256, num_workers=2, pin_memory=True, persistent_workers=True,
  )

  model = MobiusNet(
  in_chans=3,
- num_classes=100,
  **config
  ).to(device)

- print(model.get_info())
- print()
- print(model.get_topology_info())
- print()
-
- model.compile(mode='reduce-overhead')
-
  total_params = sum(p.numel() for p in model.parameters())
  print(f"Total params: {total_params:,}")
  print()

  optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)
  scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

  best_acc = 0.0

- for epoch in range(1, epochs + 1):
  model.train()
  train_loss, train_correct, train_total = 0, 0, 0
@@ -406,30 +917,76 @@ def train_mobius_cifar100(

  scheduler.step()

  model.eval()
  val_correct, val_total = 0, 0
  with torch.no_grad():
- for x, y in test_loader:
  x, y = x.to(device), y.to(device)
  logits = model(x)
  val_correct += (logits.argmax(1) == y).sum().item()
  val_total += x.size(0)

  train_acc = train_correct / train_total
  val_acc = val_correct / val_total
- best_acc = max(best_acc, val_acc)
- marker = " ★" if val_acc >= best_acc else ""

- print(f"Epoch {epoch:3d} | Loss: {train_loss/train_total:.4f} | "
  f"Train: {train_acc:.4f} | Val: {val_acc:.4f} | Best: {best_acc:.4f}{marker}")

  print()
  print("=" * 70)
  print("FINAL RESULTS")
  print("=" * 70)
- print(model.get_info())
  print(f"Best accuracy: {best_acc:.4f}")
  print(f"Total params: {total_params:,}")
  print("=" * 70)

  return model, best_acc
@@ -440,9 +997,22 @@ def train_mobius_cifar100(
  # ============================================================================

  if __name__ == '__main__':
- model, best_acc = train_mobius_cifar100(
- preset='mobius_stretched', # channels=(64, 64, 128, 128), depths=(8, 4, 2)
- epochs=100,
- lr=1e-3,
- use_autoaugment=True,
  )
 
  """
+ MobiusNet Trainer with TensorBoard, SafeTensors, and HuggingFace Upload
+ =======================================================================
  """

+ import os
+ import re
+ import json
  import math
+ import shutil
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
  from torch import Tensor
+ from typing import Tuple, Optional, Dict, Any
  from torchvision import datasets, transforms
  from torch.utils.data import DataLoader
+ from torch.utils.tensorboard import SummaryWriter
  from tqdm.auto import tqdm
+ from datetime import datetime
+ from pathlib import Path
+ from safetensors.torch import save_file as save_safetensors, load_file as load_safetensors
+ from huggingface_hub import HfApi, login
+
+ # Colab HF login
+ try:
+ from google.colab import userdata
+ token = userdata.get('HF_TOKEN')
+ os.environ['HF_TOKEN'] = token
+ login(token=token)
+ print("Logged in to HuggingFace via Colab")
+ except Exception:
+ # Not in Colab or token not set
+ pass

  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  print(f"Device: {device}")

+ # Enable TF32 for faster computation on Ampere+ GPUs
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+ torch.set_float32_matmul_precision('high')
+
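Outside Colab, the same `login` call works with a token taken from the environment; a minimal sketch (not part of this commit), assuming an HF_TOKEN variable is exported in the shell:

    # Hypothetical non-Colab fallback: reuse the HF_TOKEN environment variable.
    import os
    from huggingface_hub import login

    token = os.environ.get('HF_TOKEN')
    if token:
        login(token=token)  # same call the Colab branch above uses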
  # ============================================================================
  # MÖBIUS LENS

  ):
  super().__init__()

+ self.dim = dim
+ self.layer_idx = layer_idx
+ self.total_layers = total_layers
  self.t = layer_idx / max(total_layers - 1, 1)

  scale_span = scale_range[1] - scale_range[0]

  self.register_buffer('scales', torch.tensor([scale_low, scale_high]))

  self.twist_in_angle = nn.Parameter(torch.tensor(self.t * math.pi))
  self.twist_in_proj = nn.Linear(dim, dim, bias=False)
  nn.init.orthogonal_(self.twist_in_proj.weight)

  self.omega = nn.Parameter(torch.tensor(math.pi))
  self.alpha = nn.Parameter(torch.tensor(1.5))

  self.accum_weights = nn.Parameter(torch.tensor([0.4, 0.2, 0.4]))
  self.xor_weight = nn.Parameter(torch.tensor(0.7))

+ self.gate_norm = nn.LayerNorm(dim)
+
  self.twist_out_angle = nn.Parameter(torch.tensor(-self.t * math.pi))
  self.twist_out_proj = nn.Linear(dim, dim, bias=False)
  nn.init.orthogonal_(self.twist_out_proj.weight)

  gate = w[0] * L + w[1] * M + w[2] * R
  gate = gate * (0.5 + 0.5 * lr)
+ gate = torch.sigmoid(self.gate_norm(gate))

+ return x * gate

  def _twist_out(self, x: Tensor) -> Tensor:
  cos_t = torch.cos(self.twist_out_angle)

  def forward(self, x: Tensor) -> Tensor:
  return self._twist_out(self._center_lens(self._twist_in(x)))
+
+ def get_lens_stats(self) -> Dict[str, float]:
+ """Return lens parameters for logging."""
+ return {
+ 'omega': self.omega.item(),
+ 'alpha': self.alpha.item(),
+ 'twist_in_angle': self.twist_in_angle.item(),
+ 'twist_out_angle': self.twist_out_angle.item(),
+ 'xor_weight': torch.sigmoid(self.xor_weight).item(),
+ 'accum_weights_l': torch.softmax(self.accum_weights, dim=0)[0].item(),
+ 'accum_weights_m': torch.softmax(self.accum_weights, dim=0)[1].item(),
+ 'accum_weights_r': torch.softmax(self.accum_weights, dim=0)[2].item(),
+ }

  # ============================================================================
 
  rw = torch.sigmoid(self.residual_weight)
  return rw * identity + (1 - rw) * h
+
+ def get_residual_weight(self) -> float:
+ return torch.sigmoid(self.residual_weight).item()


  # ============================================================================
+ # MÖBIUS NET
  # ============================================================================

  class MobiusNet(nn.Module):
  def __init__(
  self,
  in_chans: int = 3,
+ num_classes: int = 200,
+ channels: Tuple[int, ...] = (64, 128, 256, 512),
+ depths: Tuple[int, ...] = (2, 2, 2, 2),
  scale_range: Tuple[float, float] = (0.5, 2.5),
+ use_integrator: bool = True,
  ):
  super().__init__()

  self.total_layers = total_layers
  self.scale_range = scale_range
+ self.channels = tuple(channels)
+ self.depths = tuple(depths)
  self.num_stages = num_stages
+ self.use_integrator = use_integrator
+ self.num_classes = num_classes
+ self.in_chans = in_chans

  channels = list(channels)
  while len(channels) < num_stages:
  channels.append(channels[-1])

  self.stem = nn.Sequential(
+ nn.Conv2d(in_chans, channels[0], 3, stride=1, padding=1, bias=False),
  nn.BatchNorm2d(channels[0]),
  )

  layer_idx = 0
  self.stages = nn.ModuleList()
  self.downsamples = nn.ModuleList()

  for stage_idx in range(num_stages):
  ch = channels[stage_idx]

  stage = nn.ModuleList()
  for _ in range(depths[stage_idx]):
+ stage.append(MobiusConvBlock(ch, layer_idx, total_layers, scale_range))
  layer_idx += 1
  self.stages.append(stage)

  if stage_idx < num_stages - 1:
  ch_next = channels[stage_idx + 1]
  self.downsamples.append(nn.Sequential(
  nn.BatchNorm2d(ch_next),
  ))

+ final_ch = channels[num_stages - 1]
+ if use_integrator:
+ self.integrator = nn.Sequential(
+ nn.Conv2d(final_ch, final_ch, 3, padding=1, bias=False),
+ nn.BatchNorm2d(final_ch),
+ nn.GELU(),
+ )
+ else:
+ self.integrator = nn.Identity()
+
  self.pool = nn.AdaptiveAvgPool2d(1)
+ self.head = nn.Linear(final_ch, num_classes)

  def forward(self, x: Tensor) -> Tensor:
  x = self.stem(x)

  if i < len(self.downsamples):
  x = self.downsamples[i](x)

+ x = self.integrator(x)
  return self.head(self.pool(x).flatten(1))

+ def get_config(self) -> Dict[str, Any]:
+ """Return model configuration for saving."""
+ return {
+ 'in_chans': self.in_chans,
+ 'num_classes': self.num_classes,
+ 'channels': self.channels,
+ 'depths': self.depths,
+ 'scale_range': self.scale_range,
+ 'use_integrator': self.use_integrator,
+ 'total_layers': self.total_layers,
+ 'num_stages': self.num_stages,
+ }

+ def get_all_lens_stats(self) -> Dict[str, Dict[str, float]]:
+ """Return stats from all lenses for logging."""
+ stats = {}
  layer_idx = 0
+ for stage_idx, stage in enumerate(self.stages):
+ for block_idx, block in enumerate(stage):
+ key = f"stage{stage_idx}_block{block_idx}"
+ stats[key] = block.lens.get_lens_stats()
+ stats[key]['residual_weight'] = block.get_residual_weight()
  layer_idx += 1
+ return stats
+
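A quick sanity check for the updated class (a sketch, not part of the diff): with the new defaults, a Tiny-ImageNet-sized batch should come out as 200 logits, and the adaptive pooling keeps the head independent of input resolution:

    import torch

    # Defaults from the diff: channels=(64, 128, 256, 512), depths=(2, 2, 2, 2), num_classes=200
    model = MobiusNet()
    logits = model(torch.randn(2, 3, 64, 64))
    print(logits.shape)  # expected: torch.Size([2, 200])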
+
+ # ============================================================================
+ # TINY IMAGENET DATASET
+ # ============================================================================
+
+ def get_tiny_imagenet_loaders(data_dir='./data/tiny-imagenet-200', batch_size=128):
+ train_dir = os.path.join(data_dir, 'train')
+ val_dir = os.path.join(data_dir, 'val')
+
+ val_images_dir = os.path.join(val_dir, 'images')
+ if os.path.exists(val_images_dir):
+ print("Reorganizing validation folder...")
+ reorganize_val_folder(val_dir)
+
+ train_transform = transforms.Compose([
+ transforms.RandomCrop(64, padding=8),
+ transforms.RandomHorizontalFlip(),
+ transforms.AutoAugment(transforms.AutoAugmentPolicy.IMAGENET),
+ transforms.ToTensor(),
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+ ])
+
+ val_transform = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+ ])
+
+ train_dataset = datasets.ImageFolder(train_dir, transform=train_transform)
+ val_dataset = datasets.ImageFolder(val_dir, transform=val_transform)
+
+ train_loader = DataLoader(
+ train_dataset, batch_size=batch_size, shuffle=True,
+ num_workers=8, pin_memory=True, persistent_workers=True
+ )
+ val_loader = DataLoader(
+ val_dataset, batch_size=256, shuffle=False,
+ num_workers=4, pin_memory=True, persistent_workers=True
+ )
+
+ return train_loader, val_loader
+
+
+ def reorganize_val_folder(val_dir):
+ """Reorganize Tiny ImageNet val folder into class subfolders."""
+ val_images_dir = os.path.join(val_dir, 'images')
+ val_annotations = os.path.join(val_dir, 'val_annotations.txt')
+
+ if not os.path.exists(val_images_dir):
+ return
+
+ with open(val_annotations, 'r') as f:
+ for line in f:
+ parts = line.strip().split('\t')
+ img_name, class_id = parts[0], parts[1]

+ class_dir = os.path.join(val_dir, class_id)
+ os.makedirs(class_dir, exist_ok=True)
+
+ src = os.path.join(val_images_dir, img_name)
+ dst = os.path.join(class_dir, img_name)
+
+ if os.path.exists(src):
+ shutil.move(src, dst)
+
+ if os.path.exists(val_images_dir):
+ shutil.rmtree(val_images_dir)
+ if os.path.exists(val_annotations):
+ os.remove(val_annotations)
+
+ print("Validation folder reorganized.")

  # ============================================================================
  # PRESETS
  # ============================================================================

  PRESETS = {
+ 'mobius_tiny_s': {
+ 'channels': (64, 128, 256),
+ 'depths': (2, 2, 2),
  'scale_range': (0.5, 2.5),
  },
+ 'mobius_tiny_m': {
+ 'channels': (64, 128, 256, 512, 768),
+ 'depths': (2, 2, 4, 2, 2),
+ 'scale_range': (0.25, 2.75),
  },
+ 'mobius_tiny_l': {
+ 'channels': (96, 192, 384, 768),
+ 'depths': (3, 3, 3, 3),
  'scale_range': (0.5, 3.5),
  },
+ 'mobius_base': {
+ 'channels': (128, 256, 512, 768, 1024),
+ 'depths': (2, 2, 2, 2, 2),
+ 'scale_range': (0.25, 2.75),
  },
 }
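Each preset is just a kwargs dict for the MobiusNet constructor; this is how train_tiny_imagenet below expands it:

    config = PRESETS['mobius_tiny_s']
    model = MobiusNet(in_chans=3, num_classes=200, use_integrator=True, **config)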

+ # ============================================================================
+ # CHECKPOINT MANAGER
+ # ============================================================================
+
+ class CheckpointManager:
+ def __init__(
+ self,
+ base_dir: str,
+ variant_name: str,
+ dataset_name: str,
+ hf_repo: str = "AbstractPhil/mobiusnet",
+ upload_every_n_epochs: int = 10,
+ save_every_n_epochs: int = 10,
+ timestamp: Optional[str] = None,
+ ):
+ self.timestamp = timestamp or datetime.now().strftime("%Y%m%d_%H%M%S")
+ self.variant_name = variant_name
+ self.dataset_name = dataset_name
+ self.hf_repo = hf_repo
+ self.upload_every_n_epochs = upload_every_n_epochs
+ self.save_every_n_epochs = save_every_n_epochs
+
+ # Directory structure
+ self.run_name = f"{variant_name}_{dataset_name}"
+ self.run_dir = Path(base_dir) / "checkpoints" / self.run_name / self.timestamp
+ self.checkpoints_dir = self.run_dir / "checkpoints"
+ self.tensorboard_dir = self.run_dir / "tensorboard"
+
+ # Create directories
+ self.checkpoints_dir.mkdir(parents=True, exist_ok=True)
+ self.tensorboard_dir.mkdir(parents=True, exist_ok=True)
+
+ # TensorBoard writer
+ self.writer = SummaryWriter(log_dir=str(self.tensorboard_dir))
+
+ # HuggingFace API
+ self.hf_api = HfApi()
+ self.uploaded_files = set()
+
+ # Track best
+ self.best_acc = 0.0
+ self.best_epoch = 0
+ self.best_changed_since_upload = False
+
+ print(f"Checkpoint directory: {self.run_dir}")
+
+ @staticmethod
+ def extract_timestamp(checkpoint_path: str) -> Optional[str]:
+ """Extract timestamp from checkpoint path."""
+ # Match YYYYMMDD_HHMMSS pattern
+ match = re.search(r'(\d{8}_\d{6})', checkpoint_path)
+ if match:
+ return match.group(1)
+ return None
+
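extract_timestamp is what lets a resumed run keep writing into its original directory; for example:

    path = 'outputs/checkpoints/mobius_base_tiny_imagenet/20240101_120000/checkpoints/best_model.pt'
    CheckpointManager.extract_timestamp(path)  # -> '20240101_120000'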
+ def save_config(self, config: Dict[str, Any], training_config: Dict[str, Any]):
+ """Save model and training configuration."""
+ full_config = {
+ 'model': config,
+ 'training': training_config,
+ 'timestamp': self.timestamp,
+ 'variant_name': self.variant_name,
+ 'dataset_name': self.dataset_name,
+ }
+
+ config_path = self.run_dir / "config.json"
+ with open(config_path, 'w') as f:
+ json.dump(full_config, f, indent=2)
+
+ return config_path
+
+ def save_checkpoint(
+ self,
+ model: nn.Module,
+ optimizer: torch.optim.Optimizer,
+ scheduler: Any,
+ epoch: int,
+ train_acc: float,
+ val_acc: float,
+ train_loss: float,
+ is_best: bool = False,
+ ):
+ """Save a checkpoint every N epochs; always save the best (overwriting)."""
+
+ # Unwrap compiled model if necessary
+ raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model
+
+ # Checkpoint data
+ checkpoint = {
+ 'epoch': epoch,
+ 'train_acc': train_acc,
+ 'val_acc': val_acc,
+ 'train_loss': train_loss,
+ 'best_acc': self.best_acc,
+ 'optimizer_state_dict': optimizer.state_dict(),
+ 'scheduler_state_dict': scheduler.state_dict(),
+ }
+
+ # Save epoch checkpoint every N epochs
+ if epoch % self.save_every_n_epochs == 0:
+ epoch_pt_path = self.checkpoints_dir / f"checkpoint_epoch_{epoch:04d}.pt"
+ torch.save({**checkpoint, 'model_state_dict': raw_model.state_dict()}, epoch_pt_path)
+
+ epoch_st_path = self.checkpoints_dir / f"checkpoint_epoch_{epoch:04d}.safetensors"
+ save_safetensors(raw_model.state_dict(), str(epoch_st_path))
+
+ # Save best model (overwrites previous best)
+ if is_best:
+ self.best_acc = val_acc
+ self.best_epoch = epoch
+ self.best_changed_since_upload = True
+
+ # PyTorch best
+ best_pt_path = self.checkpoints_dir / "best_model.pt"
+ torch.save({**checkpoint, 'model_state_dict': raw_model.state_dict()}, best_pt_path)
+
+ # SafeTensors best
+ best_st_path = self.checkpoints_dir / "best_model.safetensors"
+ save_safetensors(raw_model.state_dict(), str(best_st_path))
+
+ # Save accuracy info
+ acc_path = self.run_dir / "best_accuracy.json"
+ with open(acc_path, 'w') as f:
+ json.dump({
+ 'best_acc': val_acc,
+ 'best_epoch': epoch,
+ 'train_acc': train_acc,
+ 'train_loss': train_loss,
+ }, f, indent=2)
+
+ def save_final(self, model: nn.Module, final_acc: float, final_epoch: int):
+ """Save final model."""
+ raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model
+
+ # SafeTensors final
+ final_st_path = self.checkpoints_dir / "final_model.safetensors"
+ save_safetensors(raw_model.state_dict(), str(final_st_path))
+
+ # PyTorch final
+ final_pt_path = self.checkpoints_dir / "final_model.pt"
+ torch.save({
+ 'model_state_dict': raw_model.state_dict(),
+ 'final_acc': final_acc,
+ 'final_epoch': final_epoch,
+ 'best_acc': self.best_acc,
+ 'best_epoch': self.best_epoch,
+ }, final_pt_path)
+
+ # Final accuracy info
+ acc_path = self.run_dir / "final_accuracy.json"
+ with open(acc_path, 'w') as f:
+ json.dump({
+ 'final_acc': final_acc,
+ 'final_epoch': final_epoch,
+ 'best_acc': self.best_acc,
+ 'best_epoch': self.best_epoch,
+ }, f, indent=2)
+
+ return final_st_path, final_pt_path
+
+ def log_scalars(self, epoch: int, scalars: Dict[str, float], prefix: str = ""):
+ """Log scalars to TensorBoard."""
+ for name, value in scalars.items():
+ tag = f"{prefix}/{name}" if prefix else name
+ self.writer.add_scalar(tag, value, epoch)
+
+ def log_lens_stats(self, epoch: int, model: nn.Module):
+ """Log lens statistics to TensorBoard."""
+ raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model
+ stats = raw_model.get_all_lens_stats()
+
+ for block_name, block_stats in stats.items():
+ for stat_name, value in block_stats.items():
+ self.writer.add_scalar(f"lens/{block_name}/{stat_name}", value, epoch)
+
+ def log_histograms(self, epoch: int, model: nn.Module):
+ """Log weight histograms to TensorBoard."""
+ raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model
+
+ for name, param in raw_model.named_parameters():
+ if param.requires_grad:
+ self.writer.add_histogram(f"weights/{name}", param.data, epoch)
+ if param.grad is not None:
+ self.writer.add_histogram(f"gradients/{name}", param.grad, epoch)
+
+ def upload_to_hf(self, epoch: int, force: bool = False):
+ """Upload checkpoints every N epochs; the best model uploads only on upload epochs, and only if it changed."""
+ if not force and epoch % self.upload_every_n_epochs != 0:
+ return
+
+ try:
+ hf_base_path = f"checkpoints/{self.run_name}/{self.timestamp}"
+
+ files_to_upload = []
+
+ # Always upload config
+ config_path = self.run_dir / "config.json"
+ if config_path.exists():
+ files_to_upload.append(config_path)
+
+ # Upload checkpoint if saved this epoch
+ if epoch % self.save_every_n_epochs == 0:
+ ckpt_st = self.checkpoints_dir / f"checkpoint_epoch_{epoch:04d}.safetensors"
+ ckpt_pt = self.checkpoints_dir / f"checkpoint_epoch_{epoch:04d}.pt"
+ if ckpt_st.exists():
+ files_to_upload.append(ckpt_st)
+ if ckpt_pt.exists():
+ files_to_upload.append(ckpt_pt)
+
+ # Upload best if it changed since last upload
+ if self.best_changed_since_upload:
+ best_files = [
+ self.checkpoints_dir / "best_model.safetensors",
+ self.checkpoints_dir / "best_model.pt",
+ self.run_dir / "best_accuracy.json",
+ ]
+ for f in best_files:
+ if f.exists():
+ files_to_upload.append(f)
+ self.best_changed_since_upload = False
+
+ # Upload files
+ for local_path in files_to_upload:
+ rel_path = local_path.relative_to(self.run_dir)
+ hf_path = f"{hf_base_path}/{rel_path}"
+
+ try:
+ self.hf_api.upload_file(
+ path_or_fileobj=str(local_path),
+ path_in_repo=hf_path,
+ repo_id=self.hf_repo,
+ repo_type="model",
+ )
+ print(f"Uploaded: {hf_path}")
+ except Exception as e:
+ print(f"Failed to upload {rel_path}: {e}")
+
+ except Exception as e:
+ print(f"HuggingFace upload error: {e}")
+
+ def close(self):
+ """Close TensorBoard writer."""
+ self.writer.close()
+
+ @staticmethod
+ def load_checkpoint(
+ checkpoint_path: str,
+ model: nn.Module,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[Any] = None,
+ hf_repo: str = "AbstractPhil/mobiusnet",
+ device: torch.device = torch.device('cpu'),
+ ) -> Dict[str, Any]:
+ """
+ Load checkpoint from local path or HuggingFace repo.
+
+ Args:
+ checkpoint_path: Either:
+ - Local file path to .pt checkpoint
+ - Local directory containing checkpoints
+ - HuggingFace path like "checkpoints/variant_dataset/timestamp"
+ model: Model to load weights into
+ optimizer: Optional optimizer to restore state
+ scheduler: Optional scheduler to restore state
+ hf_repo: HuggingFace repo ID
+ device: Device to load tensors to
+
+ Returns:
+ Dict with checkpoint info (epoch, best_acc, etc.)
+ """
+ from huggingface_hub import hf_hub_download, list_repo_files
+
+ checkpoint_file = None
+
+ # Check if it's a local file
+ if os.path.isfile(checkpoint_path):
+ checkpoint_file = checkpoint_path
+
+ # Check if it's a local directory
+ elif os.path.isdir(checkpoint_path):
+ # Look for best_model.pt or latest checkpoint
+ best_path = os.path.join(checkpoint_path, "checkpoints", "best_model.pt")
+ if os.path.exists(best_path):
+ checkpoint_file = best_path
+ else:
+ # Find latest epoch checkpoint
+ ckpt_dir = os.path.join(checkpoint_path, "checkpoints")
+ if os.path.isdir(ckpt_dir):
+ pt_files = sorted([f for f in os.listdir(ckpt_dir) if f.startswith("checkpoint_epoch_") and f.endswith(".pt")])
+ if pt_files:
+ checkpoint_file = os.path.join(ckpt_dir, pt_files[-1])
+
+ # Try HuggingFace download
+ if checkpoint_file is None:
+ print(f"Attempting to download from HuggingFace: {hf_repo}/{checkpoint_path}")
+ try:
+ # If checkpoint_path is a directory path in the repo
+ if not checkpoint_path.endswith(".pt"):
+ # Try to download best_model.pt
+ try:
+ checkpoint_file = hf_hub_download(
+ repo_id=hf_repo,
+ filename=f"{checkpoint_path}/checkpoints/best_model.pt",
+ repo_type="model",
+ )
+ print(f"Downloaded best_model.pt from {hf_repo}")
+ except Exception:
+ # List files and find latest checkpoint
+ files = list_repo_files(repo_id=hf_repo, repo_type="model")
+ ckpt_files = sorted([f for f in files if checkpoint_path in f and f.endswith(".pt") and "checkpoint_epoch_" in f])
+ if ckpt_files:
+ checkpoint_file = hf_hub_download(
+ repo_id=hf_repo,
+ filename=ckpt_files[-1],
+ repo_type="model",
+ )
+ print(f"Downloaded {ckpt_files[-1]} from {hf_repo}")
+ else:
+ # Direct file path
+ checkpoint_file = hf_hub_download(
+ repo_id=hf_repo,
+ filename=checkpoint_path,
+ repo_type="model",
+ )
+ print(f"Downloaded {checkpoint_path} from {hf_repo}")
+ except Exception as e:
+ raise FileNotFoundError(f"Could not find or download checkpoint: {checkpoint_path}. Error: {e}")
+
+ if checkpoint_file is None:
+ raise FileNotFoundError(f"Could not find checkpoint: {checkpoint_path}")
+
+ print(f"Loading checkpoint from: {checkpoint_file}")
+ checkpoint = torch.load(checkpoint_file, map_location=device, weights_only=False)
+
+ # Load model weights
+ raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model
+ raw_model.load_state_dict(checkpoint['model_state_dict'])
+ print("Loaded model weights")
+
+ # Load optimizer state
+ if optimizer is not None and 'optimizer_state_dict' in checkpoint:
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+ print("Loaded optimizer state")
+
+ # Load scheduler state
+ if scheduler is not None and 'scheduler_state_dict' in checkpoint:
+ scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
+ print("Loaded scheduler state")
+
+ info = {
+ 'epoch': checkpoint.get('epoch', 0),
+ 'best_acc': checkpoint.get('best_acc', 0.0),
+ 'train_acc': checkpoint.get('train_acc', 0.0),
+ 'val_acc': checkpoint.get('val_acc', 0.0),
+ 'train_loss': checkpoint.get('train_loss', 0.0),
+ }
+
+ print(f"Resuming from epoch {info['epoch']} (best_acc: {info['best_acc']:.4f})")
+
+ return info
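A resume sketch combining the pieces above (paths are illustrative; the bare HF-style path triggers the download branch):

    model = MobiusNet(num_classes=200, **PRESETS['mobius_base']).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
    info = CheckpointManager.load_checkpoint(
        'checkpoints/mobius_base_tiny_imagenet/20240101_120000',
        model, optimizer, scheduler, device=device,
    )
    start_epoch = info['epoch'] + 1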

  # ============================================================================
  # TRAINING
  # ============================================================================

+ def train_tiny_imagenet(
+ preset: str = 'mobius_tiny_m',
  epochs: int = 100,
  lr: float = 1e-3,
  batch_size: int = 128,
+ use_integrator: bool = True,
+ data_dir: str = './data/tiny-imagenet-200',
+ output_dir: str = './outputs',
+ hf_repo: str = "AbstractPhil/mobiusnet",
+ save_every_n_epochs: int = 10,
+ upload_every_n_epochs: int = 10,
+ log_histograms_every: int = 10,
+ use_compile: bool = True,
+ continue_from: Optional[str] = None,
  ):
+ """
+ Train MobiusNet on Tiny ImageNet.
+
+ Args:
+ preset: Model preset name
+ epochs: Total epochs to train
+ lr: Learning rate
+ batch_size: Batch size
+ use_integrator: Whether to use the integrator layer
+ data_dir: Path to Tiny ImageNet data
+ output_dir: Output directory for checkpoints
+ hf_repo: HuggingFace repo for uploads/downloads
+ save_every_n_epochs: Save a checkpoint every N epochs
+ upload_every_n_epochs: Upload to HF every N epochs
+ log_histograms_every: Log weight histograms every N epochs
+ use_compile: Whether to use torch.compile
+ continue_from: Resume from checkpoint. Can be:
+ - Local .pt file path
+ - Local checkpoint directory
+ - HuggingFace path (e.g., "checkpoints/mobius_base_tiny_imagenet/20240101_120000")
+ """
  config = PRESETS[preset]
+ dataset_name = "tiny_imagenet"

  print("=" * 70)
+ print(f"MÖBIUS NET - {preset.upper()} - TINY IMAGENET")
  print("=" * 70)
  print(f"Device: {device}")
  print(f"Channels: {config['channels']}")
  print(f"Depths: {config['depths']}")
  print(f"Scale range: {config['scale_range']}")
+ print(f"Integrator: {use_integrator}")
+ if continue_from:
+ print(f"Continuing from: {continue_from}")
  print()

+ # Extract timestamp from checkpoint path if continuing
+ resume_timestamp = None
+ if continue_from:
+ resume_timestamp = CheckpointManager.extract_timestamp(continue_from)
+ if resume_timestamp:
+ print(f"Using original timestamp: {resume_timestamp}")

+ # Initialize checkpoint manager
+ ckpt_manager = CheckpointManager(
+ base_dir=output_dir,
+ variant_name=preset,
+ dataset_name=dataset_name,
+ hf_repo=hf_repo,
+ upload_every_n_epochs=upload_every_n_epochs,
+ save_every_n_epochs=save_every_n_epochs,
+ timestamp=resume_timestamp,
  )

+ # Data
+ train_loader, val_loader = get_tiny_imagenet_loaders(data_dir, batch_size)
+
+ # Model
  model = MobiusNet(
  in_chans=3,
+ num_classes=200,
+ use_integrator=use_integrator,
  **config
  ).to(device)

  total_params = sum(p.numel() for p in model.parameters())
  print(f"Total params: {total_params:,}")
  print()

+ # Save config
+ training_config = {
+ 'epochs': epochs,
+ 'lr': lr,
+ 'batch_size': batch_size,
+ 'optimizer': 'AdamW',
+ 'weight_decay': 0.05,
+ 'scheduler': 'CosineAnnealingLR',
+ 'total_params': total_params,
+ }
+ ckpt_manager.save_config(model.get_config(), training_config)
+
+ # Compile model
+ if use_compile:
+ model = torch.compile(model, mode='reduce-overhead')
+
+ # Optimizer and scheduler
  optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)
  scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

+ # Load checkpoint if continuing
+ start_epoch = 1
  best_acc = 0.0

+ if continue_from:
+ ckpt_info = CheckpointManager.load_checkpoint(
+ checkpoint_path=continue_from,
+ model=model,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ hf_repo=hf_repo,
+ device=device,
+ )
+ start_epoch = ckpt_info['epoch'] + 1
+ best_acc = ckpt_info['best_acc']
+ ckpt_manager.best_acc = best_acc
+ ckpt_manager.best_epoch = ckpt_info['epoch']
+ print(f"Resuming training from epoch {start_epoch}")
+
+ for epoch in range(start_epoch, epochs + 1):
+ # Training
  model.train()
  train_loss, train_correct, train_total = 0, 0, 0

  scheduler.step()

+ # Validation
  model.eval()
  val_correct, val_total = 0, 0
  with torch.no_grad():
+ for x, y in val_loader:
  x, y = x.to(device), y.to(device)
  logits = model(x)
  val_correct += (logits.argmax(1) == y).sum().item()
  val_total += x.size(0)

+ # Metrics
  train_acc = train_correct / train_total
  val_acc = val_correct / val_total
+ avg_loss = train_loss / train_total
+ current_lr = scheduler.get_last_lr()[0]

+ is_best = val_acc > best_acc
+ if is_best:
+ best_acc = val_acc
+
+ marker = " ★" if is_best else ""
+ print(f"Epoch {epoch:3d} | Loss: {avg_loss:.4f} | "
  f"Train: {train_acc:.4f} | Val: {val_acc:.4f} | Best: {best_acc:.4f}{marker}")
+
+ # TensorBoard logging
+ ckpt_manager.log_scalars(epoch, {
+ 'loss': avg_loss,
+ 'train_acc': train_acc,
+ 'val_acc': val_acc,
+ 'best_acc': best_acc,
+ 'learning_rate': current_lr,
+ }, prefix="train")
+
+ # Log lens stats
+ ckpt_manager.log_lens_stats(epoch, model)
+
+ # Log histograms periodically
+ if epoch % log_histograms_every == 0:
+ ckpt_manager.log_histograms(epoch, model)
+
+ # Save checkpoint
+ ckpt_manager.save_checkpoint(
+ model=model,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ epoch=epoch,
+ train_acc=train_acc,
+ val_acc=val_acc,
+ train_loss=avg_loss,
+ is_best=is_best,
+ )
+
+ # Upload to HuggingFace (handles both checkpoint and best)
+ ckpt_manager.upload_to_hf(epoch)
+
+ # Save final model
+ ckpt_manager.save_final(model, val_acc, epochs)
+
+ # Final upload
+ ckpt_manager.upload_to_hf(epochs, force=True)
+ ckpt_manager.close()

  print()
  print("=" * 70)
  print("FINAL RESULTS")
  print("=" * 70)
+ print(f"Preset: {preset}")
  print(f"Best accuracy: {best_acc:.4f}")
  print(f"Total params: {total_params:,}")
+ print(f"Checkpoints: {ckpt_manager.run_dir}")
  print("=" * 70)

  return model, best_acc
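For inference after a run, the safetensors artifact loads back through the load_safetensors alias imported at the top (a sketch; <timestamp> stands for the run's actual directory name):

    state = load_safetensors('./outputs/checkpoints/mobius_base_tiny_imagenet/<timestamp>/checkpoints/best_model.safetensors')
    model = MobiusNet(num_classes=200, **PRESETS['mobius_base'])
    model.load_state_dict(state)
    model.eval()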
 
  # ============================================================================

  if __name__ == '__main__':
+ model, best_acc = train_tiny_imagenet(
+ preset='mobius_base',
+ epochs=200,
+ lr=3e-4,
+ batch_size=128,
+ use_integrator=True,
+ data_dir='./data/tiny-imagenet-200',
+ output_dir='./outputs',
+ hf_repo='AbstractPhil/mobiusnet',
+ save_every_n_epochs=10,
+ upload_every_n_epochs=10,
+ log_histograms_every=10,
+ use_compile=True,
+ continue_from='/content/outputs/checkpoints/mobius_base_tiny_imagenet/20260110_132436/checkpoints/best_model.pt', # Set to path or HF checkpoint to resume
+ # Examples:
+ # continue_from="./outputs/checkpoints/mobius_base_tiny_imagenet/20240101_120000"
+ # continue_from="./outputs/checkpoints/mobius_base_tiny_imagenet/20240101_120000/checkpoints/best_model.pt"
+ # continue_from="checkpoints/mobius_base_tiny_imagenet/20240101_120000" # downloads from HF
  )