dreamlessx committed on
Commit
82d5f3d
·
verified ·
1 Parent(s): 46ecbf8

Update landmarkdiff/synthetic/pair_generator.py to v0.3.2

Browse files
landmarkdiff/synthetic/pair_generator.py CHANGED
@@ -1,7 +1,10 @@
1
- """Synthetic pair generator for ControlNet fine-tuning.
2
 
3
- FFHQ -> landmarks -> random FFD -> conditioning + mask -> augment input.
4
- Augmentations on INPUT only, never target.
 
 
 
5
  """
6
 
7
  from __future__ import annotations
@@ -16,6 +19,7 @@ import numpy as np
16
  from landmarkdiff.conditioning import generate_conditioning
17
  from landmarkdiff.landmarks import extract_landmarks, render_landmark_image
18
  from landmarkdiff.manipulation import (
 
19
  apply_procedure_preset,
20
  )
21
  from landmarkdiff.masking import generate_surgical_mask
@@ -27,16 +31,16 @@ from landmarkdiff.synthetic.tps_warp import warp_image_tps
27
  class TrainingPair:
28
  """A single training sample for ControlNet fine-tuning."""
29
 
30
- input_image: np.ndarray # augmented input (512x512 BGR)
31
- target_image: np.ndarray # clean target (512x512 BGR) - TPS-warped original
32
- conditioning: np.ndarray # landmark rendering (512x512 BGR)
33
- canny: np.ndarray # canny edge map (512x512 grayscale)
34
- mask: np.ndarray # feathered surgical mask (512x512 float32)
35
  procedure: str
36
  intensity: float
37
 
38
 
39
- PROCEDURES = ["rhinoplasty", "blepharoplasty", "rhytidectomy", "orthognathic"]
40
 
41
 
42
  def generate_pair(
@@ -46,7 +50,18 @@ def generate_pair(
46
  target_size: int = 512,
47
  rng: np.random.Generator | None = None,
48
  ) -> TrainingPair | None:
49
- """Generate a single training pair from a face image."""
 
 
 
 
 
 
 
 
 
 
 
50
  rng = rng or np.random.default_rng()
51
 
52
  # Resize to target
@@ -97,17 +112,38 @@ def generate_pairs_from_directory(
97
  num_pairs: int = 1000,
98
  target_size: int = 512,
99
  seed: int = 42,
 
 
100
  ) -> Iterator[TrainingPair]:
101
- """Generate training pairs from a directory of face images."""
 
 
 
 
 
 
 
 
 
 
 
 
102
  rng = np.random.default_rng(seed)
103
  image_dir = Path(image_dir)
104
 
105
  extensions = {".jpg", ".jpeg", ".png", ".webp"}
106
- image_files = sorted(f for f in image_dir.iterdir() if f.suffix.lower() in extensions)
 
 
 
107
 
108
  if not image_files:
109
  raise FileNotFoundError(f"No images found in {image_dir}")
110
 
 
 
 
 
111
  generated = 0
112
  consecutive_failures = 0
113
  idx = 0
@@ -123,6 +159,27 @@ def generate_pairs_from_directory(
123
  break
124
  continue
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  pair = generate_pair(image, target_size=target_size, rng=rng)
127
  if pair is not None:
128
  yield pair
@@ -134,6 +191,9 @@ def generate_pairs_from_directory(
134
  print(f"Warning: {consecutive_failures} consecutive failures, stopping early")
135
  break
136
 
 
 
 
137
 
138
  def save_pair(pair: TrainingPair, output_dir: Path, index: int) -> None:
139
  """Save a training pair to disk."""
 
1
+ """Synthetic training pair generator.
2
 
3
+ Creates (input, conditioning, mask, target) tuples for ControlNet fine-tuning.
4
+ Pipeline: FFHQ image -> extract landmarks -> random FFD manipulation ->
5
+ generate conditioning + mask -> apply clinical augmentation to input.
6
+
7
+ Augmentations are applied to INPUT only, never to target (ground truth).
8
  """
9
 
10
  from __future__ import annotations
 
19
  from landmarkdiff.conditioning import generate_conditioning
20
  from landmarkdiff.landmarks import extract_landmarks, render_landmark_image
21
  from landmarkdiff.manipulation import (
22
+ PROCEDURE_LANDMARKS,
23
  apply_procedure_preset,
24
  )
25
  from landmarkdiff.masking import generate_surgical_mask
 
31
  class TrainingPair:
32
  """A single training sample for ControlNet fine-tuning."""
33
 
34
+ input_image: np.ndarray # augmented input (512x512 BGR)
35
+ target_image: np.ndarray # clean target (512x512 BGR) TPS-warped original
36
+ conditioning: np.ndarray # landmark rendering (512x512 BGR)
37
+ canny: np.ndarray # canny edge map (512x512 grayscale)
38
+ mask: np.ndarray # feathered surgical mask (512x512 float32)
39
  procedure: str
40
  intensity: float
41
 
42
 
43
+ PROCEDURES = list(PROCEDURE_LANDMARKS.keys())
44
 
45
 
46
  def generate_pair(
 
50
  target_size: int = 512,
51
  rng: np.random.Generator | None = None,
52
  ) -> TrainingPair | None:
53
+ """Generate a single training pair from a face image.
54
+
55
+ Args:
56
+ image: BGR input image (any size).
57
+ procedure: Procedure type (random if None).
58
+ intensity: Manipulation intensity 0-100 (random 30-90 if None).
59
+ target_size: Output resolution.
60
+ rng: Random number generator.
61
+
62
+ Returns:
63
+ TrainingPair or None if face detection fails.
64
+ """
65
  rng = rng or np.random.default_rng()
66
 
67
  # Resize to target
 
112
  num_pairs: int = 1000,
113
  target_size: int = 512,
114
  seed: int = 42,
115
+ quality_check: bool = True,
116
+ min_quality: float = 45.0,
117
  ) -> Iterator[TrainingPair]:
118
+ """Generate training pairs from a directory of face images.
119
+
120
+ Args:
121
+ image_dir: Directory containing face images.
122
+ num_pairs: Total number of pairs to generate.
123
+ target_size: Output resolution.
124
+ seed: Random seed.
125
+ quality_check: Run face verifier quality check on source images.
126
+ min_quality: Minimum quality score to use image (0-100).
127
+
128
+ Yields:
129
+ TrainingPair instances.
130
+ """
131
  rng = np.random.default_rng(seed)
132
  image_dir = Path(image_dir)
133
 
134
  extensions = {".jpg", ".jpeg", ".png", ".webp"}
135
+ image_files = sorted(
136
+ f for f in image_dir.iterdir()
137
+ if f.suffix.lower() in extensions
138
+ )
139
 
140
  if not image_files:
141
  raise FileNotFoundError(f"No images found in {image_dir}")
142
 
143
+ # Optional quality pre-filter
144
+ _quality_cache: dict[str, float] = {}
145
+ quality_rejects = 0
146
+
147
  generated = 0
148
  consecutive_failures = 0
149
  idx = 0
 
159
  break
160
  continue
161
 
162
+ # Quality gate: reject low-quality source images before pair generation
163
+ if quality_check:
164
+ cache_key = str(img_path)
165
+ if cache_key not in _quality_cache:
166
+ try:
167
+ from landmarkdiff.face_verifier import analyze_distortions
168
+ resized = cv2.resize(image, (target_size, target_size))
169
+ report = analyze_distortions(resized)
170
+ _quality_cache[cache_key] = report.quality_score
171
+ except Exception:
172
+ _quality_cache[cache_key] = 100.0 # Can't check — allow through
173
+
174
+ if _quality_cache[cache_key] < min_quality:
175
+ quality_rejects += 1
176
+ if quality_rejects % 100 == 0:
177
+ print(f" Quality filter: {quality_rejects} images rejected so far")
178
+ consecutive_failures += 1
179
+ if consecutive_failures > len(image_files):
180
+ break
181
+ continue
182
+
183
  pair = generate_pair(image, target_size=target_size, rng=rng)
184
  if pair is not None:
185
  yield pair
 
191
  print(f"Warning: {consecutive_failures} consecutive failures, stopping early")
192
  break
193
 
194
+ if quality_rejects > 0:
195
+ print(f"Quality filter: rejected {quality_rejects} low-quality source images")
196
+
197
 
198
  def save_pair(pair: TrainingPair, output_dir: Path, index: int) -> None:
199
  """Save a training pair to disk."""