dreamlessx's picture
Update landmarkdiff/cli.py to v0.3.2
efcf612 verified
"""Unified CLI for LandmarkDiff.
Usage:
landmarkdiff infer IMAGE --procedure rhinoplasty --intensity 65
landmarkdiff evaluate --test-dir data/test --checkpoint checkpoints/latest
landmarkdiff train --config configs/phaseA.yaml
landmarkdiff demo IMAGE --output demo_report.png
landmarkdiff config --show
landmarkdiff validate IMAGE --output validated.png
"""
from __future__ import annotations
import argparse
import sys
def cmd_infer(args: argparse.Namespace) -> None:
"""Run single-image inference."""
from pathlib import Path
import cv2
from landmarkdiff.inference import LandmarkDiffPipeline
image = cv2.imread(args.image)
if image is None:
print(f"ERROR: Cannot read image: {args.image}")
sys.exit(1)
image = cv2.resize(image, (512, 512))
pipeline = LandmarkDiffPipeline(
mode=args.mode,
controlnet_checkpoint=args.checkpoint,
displacement_model_path=args.displacement_model,
)
pipeline.load()
result = pipeline.generate(
image,
procedure=args.procedure,
intensity=args.intensity,
seed=args.seed,
)
out_path = Path(args.output)
out_path.parent.mkdir(parents=True, exist_ok=True)
cv2.imwrite(str(out_path), result["output"])
print(f"Output saved: {out_path}")
if args.watermark:
from landmarkdiff.safety import SafetyValidator
validator = SafetyValidator()
watermarked = validator.apply_watermark(result["output"])
wm_path = out_path.with_stem(out_path.stem + "_watermarked")
cv2.imwrite(str(wm_path), watermarked)
print(f"Watermarked: {wm_path}")
def cmd_ensemble(args: argparse.Namespace) -> None:
"""Run ensemble inference."""
from landmarkdiff.ensemble import ensemble_inference
ensemble_inference(
image_path=args.image,
procedure=args.procedure,
intensity=args.intensity,
output_dir=args.output,
n_samples=args.n_samples,
strategy=args.strategy,
mode=args.mode,
controlnet_checkpoint=args.checkpoint,
displacement_model_path=args.displacement_model,
seed=args.seed,
)
def cmd_evaluate(args: argparse.Namespace) -> None:
"""Run evaluation on test set.
Delegates to scripts/run_evaluation.py via subprocess to avoid
a circular dependency (landmarkdiff package should not import
from scripts/).
"""
import subprocess
script = str(
__import__("pathlib").Path(__file__).resolve().parent.parent
/ "scripts"
/ "run_evaluation.py"
)
cmd = [sys.executable, script, "--test_dir", args.test_dir, "--output", args.output]
if args.checkpoint:
cmd += ["--checkpoint", args.checkpoint]
if args.max_samples:
cmd += ["--max_samples", str(args.max_samples)]
subprocess.run(cmd, check=True)
def cmd_config(args: argparse.Namespace) -> None:
"""Show or validate configuration."""
from landmarkdiff.config import ExperimentConfig, load_config, validate_config
if args.file:
config = load_config(args.file)
else:
config = ExperimentConfig()
if args.validate:
warnings = validate_config(config)
if warnings:
print("Validation warnings:")
for w in warnings:
print(f" - {w}")
else:
print("Configuration valid (no warnings).")
else:
from dataclasses import asdict
import yaml
print(yaml.dump(asdict(config), default_flow_style=False, sort_keys=False))
def cmd_validate(args: argparse.Namespace) -> None:
"""Run safety validation on an output image."""
import cv2
from landmarkdiff.safety import SafetyValidator
input_img = cv2.imread(args.input)
output_img = cv2.imread(args.output_image)
if input_img is None or output_img is None:
print("ERROR: Cannot read input or output image.")
sys.exit(1)
validator = SafetyValidator(
watermark_enabled=args.watermark,
)
result = validator.validate(
input_image=input_img,
output_image=output_img,
face_confidence=args.face_confidence,
)
print(result.summary())
if not result.passed:
sys.exit(1)
def cmd_version(args: argparse.Namespace) -> None:
"""Print version info."""
from landmarkdiff import __version__
print(f"LandmarkDiff v{__version__}")
def main(argv: list[str] | None = None) -> None:
"""Main CLI entry point."""
parser = argparse.ArgumentParser(
prog="landmarkdiff",
description="LandmarkDiff: Facial surgery outcome prediction via latent diffusion",
)
subparsers = parser.add_subparsers(dest="command", help="Available commands")
# --- infer ---
p_infer = subparsers.add_parser("infer", help="Run single-image inference")
p_infer.add_argument("image", help="Input face image path")
p_infer.add_argument("--procedure", default="rhinoplasty",
choices=["rhinoplasty", "blepharoplasty", "rhytidectomy", "orthognathic", "brow_lift", "mentoplasty"])
p_infer.add_argument("--intensity", type=float, default=65.0)
p_infer.add_argument("--output", default="output.png")
p_infer.add_argument("--mode", default="tps", choices=["controlnet", "controlnet_ip", "controlnet_fast", "img2img", "tps"])
p_infer.add_argument("--checkpoint", default=None)
p_infer.add_argument("--displacement-model", default=None)
p_infer.add_argument("--seed", type=int, default=42)
p_infer.add_argument("--watermark", action="store_true")
p_infer.set_defaults(func=cmd_infer)
# --- ensemble ---
p_ensemble = subparsers.add_parser("ensemble", help="Run ensemble inference")
p_ensemble.add_argument("image", help="Input face image path")
p_ensemble.add_argument("--procedure", default="rhinoplasty")
p_ensemble.add_argument("--intensity", type=float, default=65.0)
p_ensemble.add_argument("--output", default="ensemble_output")
p_ensemble.add_argument("--n-samples", type=int, default=5)
p_ensemble.add_argument("--strategy", default="best_of_n",
choices=["pixel_average", "weighted_average", "best_of_n", "median"])
p_ensemble.add_argument("--mode", default="tps", choices=["controlnet", "controlnet_ip", "controlnet_fast", "img2img", "tps"])
p_ensemble.add_argument("--checkpoint", default=None)
p_ensemble.add_argument("--displacement-model", default=None)
p_ensemble.add_argument("--seed", type=int, default=42)
p_ensemble.set_defaults(func=cmd_ensemble)
# --- evaluate ---
p_eval = subparsers.add_parser("evaluate", help="Evaluate on test set")
p_eval.add_argument("--test-dir", required=True)
p_eval.add_argument("--output", default="eval_results")
p_eval.add_argument("--checkpoint", default=None)
p_eval.add_argument("--max-samples", type=int, default=0)
p_eval.set_defaults(func=cmd_evaluate)
# --- config ---
p_config = subparsers.add_parser("config", help="Show or validate configuration")
p_config.add_argument("--file", default=None, help="YAML config file")
p_config.add_argument("--validate", action="store_true")
p_config.set_defaults(func=cmd_config)
# --- validate ---
p_validate = subparsers.add_parser("validate", help="Run safety validation")
p_validate.add_argument("input", help="Original input image")
p_validate.add_argument("output_image", help="Generated output image")
p_validate.add_argument("--watermark", action="store_true")
p_validate.add_argument("--face-confidence", type=float, default=1.0)
p_validate.set_defaults(func=cmd_validate)
# --- version ---
p_version = subparsers.add_parser("version", help="Print version")
p_version.set_defaults(func=cmd_version)
args = parser.parse_args(argv)
if not hasattr(args, "func"):
parser.print_help()
sys.exit(1)
args.func(args)
if __name__ == "__main__":
main()