asdf98 committed on
Commit
d0236fe
·
verified ·
1 Parent(s): 67a401e

Add ETA to every training log line + epoch summary

Browse files
Files changed (1) hide show
  1. train.py +28 -17
train.py CHANGED
@@ -6,6 +6,7 @@ Optimized for Colab free tier:
6
  - Auto-limits large datasets (WikiArt capped at 10K by default)
7
  - Latent pre-caching: train on pure tensors, no VAE during training
8
  - Gradient checkpointing + auto batch size = no OOM
 
9
  - All datasets pure parquet, open SDXL VAE (no login)
10
  """
11
 
@@ -28,7 +29,7 @@ DATASET_PRESETS = {
28
  "image_column": "image",
29
  "label_column": "",
30
  "num_classes": 0,
31
- "max_default": 0, # 0 = use all (~2.5K, small enough)
32
  "description": "~2.5K cartoon/anime, unconditional, 181MB — fast",
33
  },
34
  "flowers": {
@@ -46,7 +47,7 @@ DATASET_PRESETS = {
46
  "image_column": "image",
47
  "label_column": "style",
48
  "num_classes": 0,
49
- "max_default": 10000, # Auto-cap: 105K is too many for Colab encoding
50
  "description": "~105K paintings with styles (auto-capped to 10K for speed)",
51
  },
52
  "art_painting": {
@@ -62,7 +63,6 @@ DATASET_PRESETS = {
62
 
63
 
64
  def auto_batch_size(model_size, image_size, gpu_mem_gb):
65
- """Safe batch size for model + resolution + GPU."""
66
  param_mem = {"small": 0.66, "base": 1.68, "large": 3.35}
67
  base = param_mem.get(model_size, 1.0)
68
  act_per_sample = {"small": {256: 0.02, 512: 0.07},
@@ -78,6 +78,13 @@ def auto_batch_size(model_size, image_size, gpu_mem_gb):
78
  return max(1, bs)
79
 
80
 
 
 
 
 
 
 
 
81
  @dataclass
82
  class TrainConfig:
83
  model_size: str = "small"
@@ -85,11 +92,11 @@ class TrainConfig:
85
  class_drop_prob: float = 0.1
86
  dataset_preset: str = "cartoon"
87
  image_size: int = 256
88
- max_images: int = 0 # 0 = use dataset's default cap
89
  vae_id: str = "madebyollin/sdxl-vae-fp16-fix"
90
  vae_scaling_factor: float = 0.13025
91
  latent_channels: int = 4
92
- batch_size: int = 0 # 0 = auto
93
  gradient_accumulation_steps: int = 1
94
  learning_rate: float = 1e-4
95
  weight_decay: float = 0.01
@@ -175,7 +182,6 @@ def precache_latents(config, cache_path=None):
175
  transforms.CenterCrop(config.image_size), transforms.ToTensor(),
176
  ])
177
 
178
- # Determine max images: user override > dataset default > all
179
  if config.max_images > 0:
180
  max_imgs = config.max_images
181
  elif preset.get("max_default", 0) > 0:
@@ -184,10 +190,9 @@ def precache_latents(config, cache_path=None):
184
  else:
185
  max_imgs = len(dataset)
186
  max_imgs = min(max_imgs, len(dataset))
187
- print(f" Encoding {max_imgs} of {len(dataset)} images")
188
 
189
- # VAE encode batch size: bigger = faster. 64 for 256px, 32 for 512px
190
  encode_bs = 64 if config.image_size <= 256 else 32
 
191
 
192
  img_col, lbl_col = preset["image_column"], preset["label_column"]
193
  style_to_id = {}
@@ -220,7 +225,7 @@ def precache_latents(config, cache_path=None):
220
  speed = count / elapsed
221
  eta = (max_imgs - count) / speed if speed > 0 else 0
222
  if count % (encode_bs * 4) == 0:
223
- print(f" {count}/{max_imgs} ({speed:.0f} img/s, ~{eta:.0f}s left)")
224
 
225
  if batch_px:
226
  with torch.no_grad():
@@ -237,8 +242,7 @@ def precache_latents(config, cache_path=None):
237
  print(f" {len(style_to_id)} style classes")
238
  torch.save(save_data, cache_path)
239
  mb = os.path.getsize(cache_path) / 1024**2
240
- elapsed = time.time() - t0
241
- print(f"Cached {count} latents -> {cache_path} ({mb:.0f}MB, {elapsed:.0f}s)")
242
  del vae
243
  if torch.cuda.is_available(): torch.cuda.empty_cache()
244
  return cache_path
@@ -330,7 +334,7 @@ def train(config):
330
  scaler = GradScaler("cuda", enabled=config.mixed_precision and torch.cuda.is_available())
331
  fm = FlowMatchingScheduler(config.min_timestep, config.max_timestep)
332
  lat_size = config.image_size // 8
333
- print(f"Steps: {total_steps}, Batch: {config.batch_size}")
334
 
335
  gs = 0; la = 0.0; vae = None; vae_loaded = False
336
  print(f"\nTraining!\n")
@@ -355,10 +359,14 @@ def train(config):
355
  ema.update(model); gs += 1
356
  if gs % config.log_every_n_steps == 0:
357
  al = la / config.log_every_n_steps
 
 
 
358
  vram = torch.cuda.memory_allocated()/1024**3 if torch.cuda.is_available() else 0
359
- print(f"step={gs:>6d} | ep={epoch} | loss={al:.4f} | gn={gn:.2f} | "
360
- f"lr={opt.param_groups[0]['lr']:.2e} | vram={vram:.1f}G | "
361
- f"{gs/max(time.time()-t_start,1):.1f} st/s")
 
362
  la = 0.0
363
  if math.isnan(al) or al > 50: print("Diverged!"); return
364
  if gs % config.sample_every_n_steps == 0:
@@ -380,8 +388,11 @@ def train(config):
380
  torch.save({"model": model.state_dict(), "ema": ema.shadow,
381
  "optimizer": opt.state_dict(), "step": gs, "model_config": mcfg},
382
  f"{config.output_dir}/checkpoints/step_{gs:07d}.pt")
383
- print(f"Epoch {epoch} | {time.time()-et:.0f}s\n")
 
 
384
 
385
  final = f"{config.output_dir}/checkpoints/final.pt"
386
  torch.save({"model": model.state_dict(), "ema": ema.shadow, "model_config": mcfg, "step": gs}, final)
387
- print(f"\nDone! {gs} steps, {(time.time()-t_start)/60:.1f}min -> {final}")
 
 
6
  - Auto-limits large datasets (WikiArt capped at 10K by default)
7
  - Latent pre-caching: train on pure tensors, no VAE during training
8
  - Gradient checkpointing + auto batch size = no OOM
9
+ - ETA shown on every log line
10
  - All datasets pure parquet, open SDXL VAE (no login)
11
  """
12
 
 
29
  "image_column": "image",
30
  "label_column": "",
31
  "num_classes": 0,
32
+ "max_default": 0,
33
  "description": "~2.5K cartoon/anime, unconditional, 181MB — fast",
34
  },
35
  "flowers": {
 
47
  "image_column": "image",
48
  "label_column": "style",
49
  "num_classes": 0,
50
+ "max_default": 10000,
51
  "description": "~105K paintings with styles (auto-capped to 10K for speed)",
52
  },
53
  "art_painting": {
 
63
 
64
 
65
  def auto_batch_size(model_size, image_size, gpu_mem_gb):
 
66
  param_mem = {"small": 0.66, "base": 1.68, "large": 3.35}
67
  base = param_mem.get(model_size, 1.0)
68
  act_per_sample = {"small": {256: 0.02, 512: 0.07},
 
78
  return max(1, bs)
79
 
80
 
81
+ def _fmt_time(seconds):
82
+ """Format seconds into human readable string."""
83
+ if seconds < 60: return f"{seconds:.0f}s"
84
+ if seconds < 3600: return f"{seconds/60:.1f}m"
85
+ return f"{seconds/3600:.1f}h"
86
+
87
+
88
  @dataclass
89
  class TrainConfig:
90
  model_size: str = "small"
 
92
  class_drop_prob: float = 0.1
93
  dataset_preset: str = "cartoon"
94
  image_size: int = 256
95
+ max_images: int = 0
96
  vae_id: str = "madebyollin/sdxl-vae-fp16-fix"
97
  vae_scaling_factor: float = 0.13025
98
  latent_channels: int = 4
99
+ batch_size: int = 0
100
  gradient_accumulation_steps: int = 1
101
  learning_rate: float = 1e-4
102
  weight_decay: float = 0.01
 
182
  transforms.CenterCrop(config.image_size), transforms.ToTensor(),
183
  ])
184
 
 
185
  if config.max_images > 0:
186
  max_imgs = config.max_images
187
  elif preset.get("max_default", 0) > 0:
 
190
  else:
191
  max_imgs = len(dataset)
192
  max_imgs = min(max_imgs, len(dataset))
 
193
 
 
194
  encode_bs = 64 if config.image_size <= 256 else 32
195
+ print(f" Encoding {max_imgs} images (batch={encode_bs})...")
196
 
197
  img_col, lbl_col = preset["image_column"], preset["label_column"]
198
  style_to_id = {}
 
225
  speed = count / elapsed
226
  eta = (max_imgs - count) / speed if speed > 0 else 0
227
  if count % (encode_bs * 4) == 0:
228
+ print(f" {count}/{max_imgs} | {speed:.0f} img/s | ETA {_fmt_time(eta)}")
229
 
230
  if batch_px:
231
  with torch.no_grad():
 
242
  print(f" {len(style_to_id)} style classes")
243
  torch.save(save_data, cache_path)
244
  mb = os.path.getsize(cache_path) / 1024**2
245
+ print(f"Cached {count} latents -> {cache_path} ({mb:.0f}MB, {_fmt_time(time.time()-t0)})")
 
246
  del vae
247
  if torch.cuda.is_available(): torch.cuda.empty_cache()
248
  return cache_path
 
334
  scaler = GradScaler("cuda", enabled=config.mixed_precision and torch.cuda.is_available())
335
  fm = FlowMatchingScheduler(config.min_timestep, config.max_timestep)
336
  lat_size = config.image_size // 8
337
+ print(f"Steps: {total_steps} | Batch: {config.batch_size} | Epochs: {config.num_epochs}")
338
 
339
  gs = 0; la = 0.0; vae = None; vae_loaded = False
340
  print(f"\nTraining!\n")
 
359
  ema.update(model); gs += 1
360
  if gs % config.log_every_n_steps == 0:
361
  al = la / config.log_every_n_steps
362
+ elapsed = time.time() - t_start
363
+ sps = gs / max(elapsed, 1)
364
+ remaining = (total_steps - gs) / sps if sps > 0 else 0
365
  vram = torch.cuda.memory_allocated()/1024**3 if torch.cuda.is_available() else 0
366
+ pct = gs / total_steps * 100
367
+ print(f"step={gs:>6d}/{total_steps} ({pct:.0f}%) | ep={epoch} | "
368
+ f"loss={al:.4f} | gn={gn:.2f} | lr={opt.param_groups[0]['lr']:.2e} | "
369
+ f"vram={vram:.1f}G | {sps:.1f} st/s | ETA {_fmt_time(remaining)}")
370
  la = 0.0
371
  if math.isnan(al) or al > 50: print("Diverged!"); return
372
  if gs % config.sample_every_n_steps == 0:
 
388
  torch.save({"model": model.state_dict(), "ema": ema.shadow,
389
  "optimizer": opt.state_dict(), "step": gs, "model_config": mcfg},
390
  f"{config.output_dir}/checkpoints/step_{gs:07d}.pt")
391
+ ep_time = time.time() - et
392
+ ep_eta = ep_time * (config.num_epochs - epoch - 1)
393
+ print(f"Epoch {epoch}/{config.num_epochs} done | {_fmt_time(ep_time)} | ETA {_fmt_time(ep_eta)}\n")
394
 
395
  final = f"{config.output_dir}/checkpoints/final.pt"
396
  torch.save({"model": model.state_dict(), "ema": ema.shadow, "model_config": mcfg, "step": gs}, final)
397
+ total_time = time.time() - t_start
398
+ print(f"\nDone! {gs} steps in {_fmt_time(total_time)} -> {final}")