dreamlessx committed on
Commit
5b7e166
·
verified ·
1 Parent(s): a76de72

Update landmarkdiff/config.py to v0.3.2

Browse files
Files changed (1) hide show
  1. landmarkdiff/config.py +51 -15
landmarkdiff/config.py CHANGED
@@ -18,7 +18,6 @@ Usage:
18
 
19
  from __future__ import annotations
20
 
21
- import os
22
  from dataclasses import asdict, dataclass, field
23
  from pathlib import Path
24
  from typing import Any
@@ -48,8 +47,9 @@ class TrainingConfig:
48
  gradient_accumulation_steps: int = 4
49
  max_train_steps: int = 50000
50
  warmup_steps: int = 500
51
- mixed_precision: str = "fp16"
52
  seed: int = 42
 
53
 
54
  # Optimizer
55
  optimizer: str = "adamw" # "adamw", "adam8bit", "prodigy"
@@ -62,15 +62,23 @@ class TrainingConfig:
62
  lr_scheduler: str = "cosine"
63
  lr_scheduler_kwargs: dict[str, Any] = field(default_factory=dict)
64
 
 
 
 
 
65
  # Phase B specific
66
  identity_loss_weight: float = 0.1
67
  perceptual_loss_weight: float = 0.05
68
  use_differentiable_arcface: bool = False
69
  arcface_weights_path: str | None = None
70
 
 
 
 
71
  # Checkpointing
72
  save_every_n_steps: int = 5000
73
  resume_from_checkpoint: str | None = None
 
74
 
75
  # Validation
76
  validate_every_n_steps: int = 2500
@@ -81,9 +89,9 @@ class TrainingConfig:
81
  class DataConfig:
82
  """Dataset configuration."""
83
 
84
- train_dir: str = "data/training"
85
- val_dir: str = "data/validation"
86
- test_dir: str = "data/test"
87
  image_size: int = 512
88
  num_workers: int = 4
89
  pin_memory: bool = True
@@ -92,6 +100,8 @@ class DataConfig:
92
  random_flip: bool = True
93
  random_rotation: float = 5.0 # degrees
94
  color_jitter: float = 0.1
 
 
95
 
96
  # Procedure filtering
97
  procedures: list[str] = field(
@@ -100,6 +110,8 @@ class DataConfig:
100
  "blepharoplasty",
101
  "rhytidectomy",
102
  "orthognathic",
 
 
103
  ]
104
  )
105
  intensity_range: tuple[float, float] = (30.0, 100.0)
@@ -128,7 +140,7 @@ class InferenceConfig:
128
 
129
  # Identity verification
130
  verify_identity: bool = True
131
- identity_threshold: float = 0.6
132
 
133
 
134
  @dataclass
@@ -154,6 +166,7 @@ class WandbConfig:
154
  entity: str | None = None
155
  run_name: str | None = None
156
  tags: list[str] = field(default_factory=list)
 
157
 
158
 
159
  @dataclass
@@ -161,7 +174,7 @@ class SlurmConfig:
161
  """SLURM job submission parameters."""
162
 
163
  partition: str = "batch_gpu"
164
- account: str = os.environ.get("SLURM_ACCOUNT", "default_gpu")
165
  gpu_type: str = "nvidia_rtx_a6000"
166
  num_gpus: int = 1
167
  mem: str = "48G"
@@ -174,7 +187,7 @@ class SlurmConfig:
174
  class SafetyConfig:
175
  """Clinical safety and responsible AI parameters."""
176
 
177
- identity_threshold: float = 0.6
178
  max_displacement_fraction: float = 0.05
179
  watermark_enabled: bool = True
180
  watermark_text: str = "AI-GENERATED PREDICTION"
@@ -190,7 +203,7 @@ class ExperimentConfig:
190
 
191
  experiment_name: str = "default"
192
  description: str = ""
193
- version: str = "0.3.0"
194
 
195
  model: ModelConfig = field(default_factory=ModelConfig)
196
  training: TrainingConfig = field(default_factory=TrainingConfig)
@@ -217,7 +230,7 @@ class ExperimentConfig:
217
  return cls(
218
  experiment_name=raw.get("experiment_name", "default"),
219
  description=raw.get("description", ""),
220
- version=raw.get("version", "0.3.0"),
221
  model=_from_dict(ModelConfig, raw.get("model", {})),
222
  training=_from_dict(TrainingConfig, raw.get("training", {})),
223
  data=_from_dict(DataConfig, raw.get("data", {})),
@@ -242,20 +255,41 @@ class ExperimentConfig:
242
  return asdict(self)
243
 
244
 
 
 
 
 
 
 
 
 
 
 
 
 
245
  def _from_dict(cls: type, d: dict) -> Any:
246
- """Create a dataclass from a dict, ignoring unknown keys."""
 
 
 
 
247
  import dataclasses
248
 
249
  field_map = {f.name: f for f in dataclasses.fields(cls)}
250
  filtered = {}
251
  for k, v in d.items():
252
- if k not in field_map:
 
 
 
 
 
253
  continue
254
  # Convert lists back to tuples where the field type is tuple
255
- f = field_map[k]
256
  if isinstance(v, list) and "tuple" in str(f.type):
257
  v = tuple(v)
258
- filtered[k] = v
259
  return cls(**filtered)
260
 
261
 
@@ -288,12 +322,14 @@ def load_config(
288
  for key, value in overrides.items():
289
  parts = key.split(".")
290
  obj = config
 
291
  for part in parts[:-1]:
292
  if hasattr(obj, part):
293
  obj = getattr(obj, part)
294
  else:
 
295
  break
296
- if hasattr(obj, parts[-1]):
297
  setattr(obj, parts[-1], value)
298
 
299
  return config
 
18
 
19
  from __future__ import annotations
20
 
 
21
  from dataclasses import asdict, dataclass, field
22
  from pathlib import Path
23
  from typing import Any
 
47
  gradient_accumulation_steps: int = 4
48
  max_train_steps: int = 50000
49
  warmup_steps: int = 500
50
+ mixed_precision: str = "bf16"
51
  seed: int = 42
52
+ ema_decay: float = 0.9999
53
 
54
  # Optimizer
55
  optimizer: str = "adamw" # "adamw", "adam8bit", "prodigy"
 
62
  lr_scheduler: str = "cosine"
63
  lr_scheduler_kwargs: dict[str, Any] = field(default_factory=dict)
64
 
65
+ # Logging intervals
66
+ log_every: int = 100
67
+ sample_every: int = 1000
68
+
69
  # Phase B specific
70
  identity_loss_weight: float = 0.1
71
  perceptual_loss_weight: float = 0.05
72
  use_differentiable_arcface: bool = False
73
  arcface_weights_path: str | None = None
74
 
75
+ # Loss weights (alternative to individual weights)
76
+ loss_weights: dict[str, float] = field(default_factory=dict)
77
+
78
  # Checkpointing
79
  save_every_n_steps: int = 5000
80
  resume_from_checkpoint: str | None = None
81
+ resume_phase_a: str | None = None
82
 
83
  # Validation
84
  validate_every_n_steps: int = 2500
 
89
  class DataConfig:
90
  """Dataset configuration."""
91
 
92
+ train_dir: str = "data/training_combined"
93
+ val_dir: str = "data/splits/val"
94
+ test_dir: str = "data/splits/test"
95
  image_size: int = 512
96
  num_workers: int = 4
97
  pin_memory: bool = True
 
100
  random_flip: bool = True
101
  random_rotation: float = 5.0 # degrees
102
  color_jitter: float = 0.1
103
+ clinical_augment: bool = False
104
+ geometric_augment: bool = True
105
 
106
  # Procedure filtering
107
  procedures: list[str] = field(
 
110
  "blepharoplasty",
111
  "rhytidectomy",
112
  "orthognathic",
113
+ "brow_lift",
114
+ "mentoplasty",
115
  ]
116
  )
117
  intensity_range: tuple[float, float] = (30.0, 100.0)
 
140
 
141
  # Identity verification
142
  verify_identity: bool = True
143
+ identity_threshold: float = 0.5
144
 
145
 
146
  @dataclass
 
166
  entity: str | None = None
167
  run_name: str | None = None
168
  tags: list[str] = field(default_factory=list)
169
+ mode: str = "online" # "online", "offline", "disabled"
170
 
171
 
172
  @dataclass
 
174
  """SLURM job submission parameters."""
175
 
176
  partition: str = "batch_gpu"
177
+ account: str = "" # Set via YAML or SLURM_ACCOUNT env var
178
  gpu_type: str = "nvidia_rtx_a6000"
179
  num_gpus: int = 1
180
  mem: str = "48G"
 
187
  class SafetyConfig:
188
  """Clinical safety and responsible AI parameters."""
189
 
190
+ identity_threshold: float = 0.5
191
  max_displacement_fraction: float = 0.05
192
  watermark_enabled: bool = True
193
  watermark_text: str = "AI-GENERATED PREDICTION"
 
203
 
204
  experiment_name: str = "default"
205
  description: str = ""
206
+ version: str = "0.3.2"
207
 
208
  model: ModelConfig = field(default_factory=ModelConfig)
209
  training: TrainingConfig = field(default_factory=TrainingConfig)
 
230
  return cls(
231
  experiment_name=raw.get("experiment_name", "default"),
232
  description=raw.get("description", ""),
233
+ version=raw.get("version", "0.3.2"),
234
  model=_from_dict(ModelConfig, raw.get("model", {})),
235
  training=_from_dict(TrainingConfig, raw.get("training", {})),
236
  data=_from_dict(DataConfig, raw.get("data", {})),
 
255
  return asdict(self)
256
 
257
 
258
_FIELD_ALIASES: dict[str, str] = {
    # YAML name -> dataclass field name
    "max_steps": "max_train_steps",
    "save_interval": "save_every_n_steps",
    "sample_interval": "sample_every",
    "log_interval": "log_every",
    "adam_weight_decay": "weight_decay",
    "lr_warmup_steps": "warmup_steps",
    "resume_from": "resume_from_checkpoint",
}


def _from_dict(cls: type, d: dict) -> Any:
    """Create a dataclass instance from a dict, ignoring unknown keys.

    Supports field aliases so YAML configs using train_controlnet.py-style
    names (e.g. ``max_steps``) map to dataclass fields (``max_train_steps``).

    If both an alias and the explicit (canonical) field name are present,
    the canonical name always wins — regardless of the dict's iteration
    order. (Previously whichever key happened to come first won, which made
    the result order-dependent.)

    Args:
        cls: A dataclass type to instantiate.
        d: Raw mapping (typically parsed from YAML/JSON).

    Returns:
        An instance of ``cls`` built from the recognized keys of ``d``.
    """
    import dataclasses

    field_map = {f.name: f for f in dataclasses.fields(cls)}
    filtered: dict[str, Any] = {}
    for k, v in d.items():
        # Resolve aliases to the canonical dataclass field name.
        canonical = _FIELD_ALIASES.get(k, k)
        if canonical not in field_map:
            # Unknown keys are dropped silently for forward/backward compat.
            continue
        # An alias must never override a value already set (whether by the
        # canonical key or an earlier alias); the explicit canonical key may.
        if canonical in filtered and k != canonical:
            continue
        # YAML/JSON round-trips tuples as lists; convert back when the field
        # is annotated as a tuple (annotations are strings under
        # `from __future__ import annotations`, hence the string check).
        f = field_map[canonical]
        if isinstance(v, list) and "tuple" in str(f.type):
            v = tuple(v)
        filtered[canonical] = v
    return cls(**filtered)
294
 
295
 
 
322
  for key, value in overrides.items():
323
  parts = key.split(".")
324
  obj = config
325
+ resolved = True
326
  for part in parts[:-1]:
327
  if hasattr(obj, part):
328
  obj = getattr(obj, part)
329
  else:
330
+ resolved = False
331
  break
332
+ if resolved and hasattr(obj, parts[-1]):
333
  setattr(obj, parts[-1], value)
334
 
335
  return config