dreamlessx commited on
Commit
efcf612
·
verified ·
1 Parent(s): 41f1384

Update landmarkdiff/cli.py to v0.3.2

Browse files
Files changed (1) hide show
  1. landmarkdiff/cli.py +29 -30
landmarkdiff/cli.py CHANGED
@@ -51,7 +51,6 @@ def cmd_infer(args: argparse.Namespace) -> None:
51
 
52
  if args.watermark:
53
  from landmarkdiff.safety import SafetyValidator
54
-
55
  validator = SafetyValidator()
56
  watermarked = validator.apply_watermark(result["output"])
57
  wm_path = out_path.with_stem(out_path.stem + "_watermarked")
@@ -78,26 +77,36 @@ def cmd_ensemble(args: argparse.Namespace) -> None:
78
 
79
 
80
  def cmd_evaluate(args: argparse.Namespace) -> None:
81
- """Run evaluation on test set."""
82
- from pathlib import Path
83
-
84
- # Import evaluation functions
85
- sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
86
- from scripts.run_evaluation import run_evaluation
87
-
88
- run_evaluation(
89
- test_dir=args.test_dir,
90
- output_dir=args.output,
91
- checkpoint=args.checkpoint,
92
- max_samples=args.max_samples,
93
  )
 
 
 
 
 
 
 
94
 
95
 
96
  def cmd_config(args: argparse.Namespace) -> None:
97
  """Show or validate configuration."""
98
  from landmarkdiff.config import ExperimentConfig, load_config, validate_config
99
 
100
- config = load_config(args.file) if args.file else ExperimentConfig()
 
 
 
101
 
102
  if args.validate:
103
  warnings = validate_config(config)
@@ -111,7 +120,6 @@ def cmd_config(args: argparse.Namespace) -> None:
111
  from dataclasses import asdict
112
 
113
  import yaml
114
-
115
  print(yaml.dump(asdict(config), default_flow_style=False, sort_keys=False))
116
 
117
 
@@ -147,7 +155,6 @@ def cmd_validate(args: argparse.Namespace) -> None:
147
  def cmd_version(args: argparse.Namespace) -> None:
148
  """Print version info."""
149
  from landmarkdiff import __version__
150
-
151
  print(f"LandmarkDiff v{__version__}")
152
 
153
 
@@ -162,14 +169,11 @@ def main(argv: list[str] | None = None) -> None:
162
  # --- infer ---
163
  p_infer = subparsers.add_parser("infer", help="Run single-image inference")
164
  p_infer.add_argument("image", help="Input face image path")
165
- p_infer.add_argument(
166
- "--procedure",
167
- default="rhinoplasty",
168
- choices=["rhinoplasty", "blepharoplasty", "rhytidectomy", "orthognathic"],
169
- )
170
  p_infer.add_argument("--intensity", type=float, default=65.0)
171
  p_infer.add_argument("--output", default="output.png")
172
- p_infer.add_argument("--mode", default="tps", choices=["controlnet", "img2img", "tps"])
173
  p_infer.add_argument("--checkpoint", default=None)
174
  p_infer.add_argument("--displacement-model", default=None)
175
  p_infer.add_argument("--seed", type=int, default=42)
@@ -183,12 +187,9 @@ def main(argv: list[str] | None = None) -> None:
183
  p_ensemble.add_argument("--intensity", type=float, default=65.0)
184
  p_ensemble.add_argument("--output", default="ensemble_output")
185
  p_ensemble.add_argument("--n-samples", type=int, default=5)
186
- p_ensemble.add_argument(
187
- "--strategy",
188
- default="best_of_n",
189
- choices=["pixel_average", "weighted_average", "best_of_n", "median"],
190
- )
191
- p_ensemble.add_argument("--mode", default="tps", choices=["controlnet", "img2img", "tps"])
192
  p_ensemble.add_argument("--checkpoint", default=None)
193
  p_ensemble.add_argument("--displacement-model", default=None)
194
  p_ensemble.add_argument("--seed", type=int, default=42)
@@ -198,9 +199,7 @@ def main(argv: list[str] | None = None) -> None:
198
  p_eval = subparsers.add_parser("evaluate", help="Evaluate on test set")
199
  p_eval.add_argument("--test-dir", required=True)
200
  p_eval.add_argument("--output", default="eval_results")
201
- p_eval.add_argument("--mode", default="tps")
202
  p_eval.add_argument("--checkpoint", default=None)
203
- p_eval.add_argument("--displacement-model", default=None)
204
  p_eval.add_argument("--max-samples", type=int, default=0)
205
  p_eval.set_defaults(func=cmd_evaluate)
206
 
 
51
 
52
  if args.watermark:
53
  from landmarkdiff.safety import SafetyValidator
 
54
  validator = SafetyValidator()
55
  watermarked = validator.apply_watermark(result["output"])
56
  wm_path = out_path.with_stem(out_path.stem + "_watermarked")
 
77
 
78
 
79
  def cmd_evaluate(args: argparse.Namespace) -> None:
80
+ """Run evaluation on test set.
81
+
82
+ Delegates to scripts/run_evaluation.py via subprocess to avoid
83
+ a circular dependency (landmarkdiff package should not import
84
+ from scripts/).
85
+ """
86
+ import subprocess
87
+
88
+ script = str(
89
+ __import__("pathlib").Path(__file__).resolve().parent.parent
90
+ / "scripts"
91
+ / "run_evaluation.py"
92
  )
93
+ cmd = [sys.executable, script, "--test_dir", args.test_dir, "--output", args.output]
94
+ if args.checkpoint:
95
+ cmd += ["--checkpoint", args.checkpoint]
96
+ if args.max_samples:
97
+ cmd += ["--max_samples", str(args.max_samples)]
98
+
99
+ subprocess.run(cmd, check=True)
100
 
101
 
102
  def cmd_config(args: argparse.Namespace) -> None:
103
  """Show or validate configuration."""
104
  from landmarkdiff.config import ExperimentConfig, load_config, validate_config
105
 
106
+ if args.file:
107
+ config = load_config(args.file)
108
+ else:
109
+ config = ExperimentConfig()
110
 
111
  if args.validate:
112
  warnings = validate_config(config)
 
120
  from dataclasses import asdict
121
 
122
  import yaml
 
123
  print(yaml.dump(asdict(config), default_flow_style=False, sort_keys=False))
124
 
125
 
 
155
  def cmd_version(args: argparse.Namespace) -> None:
156
  """Print version info."""
157
  from landmarkdiff import __version__
 
158
  print(f"LandmarkDiff v{__version__}")
159
 
160
 
 
169
  # --- infer ---
170
  p_infer = subparsers.add_parser("infer", help="Run single-image inference")
171
  p_infer.add_argument("image", help="Input face image path")
172
+ p_infer.add_argument("--procedure", default="rhinoplasty",
173
+ choices=["rhinoplasty", "blepharoplasty", "rhytidectomy", "orthognathic", "brow_lift", "mentoplasty"])
 
 
 
174
  p_infer.add_argument("--intensity", type=float, default=65.0)
175
  p_infer.add_argument("--output", default="output.png")
176
+ p_infer.add_argument("--mode", default="tps", choices=["controlnet", "controlnet_ip", "controlnet_fast", "img2img", "tps"])
177
  p_infer.add_argument("--checkpoint", default=None)
178
  p_infer.add_argument("--displacement-model", default=None)
179
  p_infer.add_argument("--seed", type=int, default=42)
 
187
  p_ensemble.add_argument("--intensity", type=float, default=65.0)
188
  p_ensemble.add_argument("--output", default="ensemble_output")
189
  p_ensemble.add_argument("--n-samples", type=int, default=5)
190
+ p_ensemble.add_argument("--strategy", default="best_of_n",
191
+ choices=["pixel_average", "weighted_average", "best_of_n", "median"])
192
+ p_ensemble.add_argument("--mode", default="tps", choices=["controlnet", "controlnet_ip", "controlnet_fast", "img2img", "tps"])
 
 
 
193
  p_ensemble.add_argument("--checkpoint", default=None)
194
  p_ensemble.add_argument("--displacement-model", default=None)
195
  p_ensemble.add_argument("--seed", type=int, default=42)
 
199
  p_eval = subparsers.add_parser("evaluate", help="Evaluate on test set")
200
  p_eval.add_argument("--test-dir", required=True)
201
  p_eval.add_argument("--output", default="eval_results")
 
202
  p_eval.add_argument("--checkpoint", default=None)
 
203
  p_eval.add_argument("--max-samples", type=int, default=0)
204
  p_eval.set_defaults(func=cmd_evaluate)
205