dreamlessx commited on
Commit
57d9540
·
verified ·
1 Parent(s): 59c75b7

Upload landmarkdiff/__main__.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. landmarkdiff/__main__.py +136 -0
landmarkdiff/__main__.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """CLI entry point for python -m landmarkdiff."""
2
+
3
+ import argparse
4
+ import sys
5
+
6
+
7
+ def main():
8
+ parser = argparse.ArgumentParser(
9
+ prog="landmarkdiff",
10
+ description="Facial surgery outcome prediction from clinical photography",
11
+ )
12
+ parser.add_argument("--version", action="store_true", help="Print version and exit")
13
+
14
+ subparsers = parser.add_subparsers(dest="command")
15
+
16
+ # inference
17
+ infer = subparsers.add_parser("infer", help="Run inference on an image")
18
+ infer.add_argument("image", type=str, help="Path to input face image")
19
+ infer.add_argument("--procedure", type=str, default="rhinoplasty",
20
+ choices=["rhinoplasty", "blepharoplasty", "rhytidectomy", "orthognathic", "brow_lift", "mentoplasty"])
21
+ infer.add_argument("--intensity", type=float, default=60.0,
22
+ help="Deformation intensity (0-100)")
23
+ infer.add_argument("--mode", type=str, default="tps",
24
+ choices=["tps", "controlnet", "img2img", "controlnet_ip"])
25
+ infer.add_argument("--output", type=str, default="output/")
26
+ infer.add_argument("--steps", type=int, default=30)
27
+ infer.add_argument("--seed", type=int, default=None)
28
+
29
+ # landmarks
30
+ lm = subparsers.add_parser("landmarks", help="Extract and visualize landmarks")
31
+ lm.add_argument("image", type=str, help="Path to input face image")
32
+ lm.add_argument("--output", type=str, default="output/landmarks.png")
33
+
34
+ # demo
35
+ subparsers.add_parser("demo", help="Launch Gradio web demo")
36
+
37
+ args = parser.parse_args()
38
+
39
+ if args.version:
40
+ from landmarkdiff import __version__
41
+ print(f"landmarkdiff {__version__}")
42
+ return
43
+
44
+ if args.command is None:
45
+ parser.print_help()
46
+ return
47
+
48
+ if args.command == "infer":
49
+ _run_inference(args)
50
+ elif args.command == "landmarks":
51
+ _run_landmarks(args)
52
+ elif args.command == "demo":
53
+ _run_demo()
54
+
55
+
56
+ def _run_inference(args):
57
+ from pathlib import Path
58
+ import numpy as np
59
+ from PIL import Image
60
+ from landmarkdiff.landmarks import extract_landmarks
61
+ from landmarkdiff.manipulation import apply_procedure_preset
62
+
63
+ output_dir = Path(args.output)
64
+ output_dir.mkdir(parents=True, exist_ok=True)
65
+
66
+ img = Image.open(args.image).convert("RGB").resize((512, 512))
67
+ img_array = np.array(img)
68
+
69
+ landmarks = extract_landmarks(img_array)
70
+ if landmarks is None:
71
+ print("no face detected")
72
+ sys.exit(1)
73
+
74
+ deformed = apply_procedure_preset(landmarks, args.procedure, intensity=args.intensity)
75
+
76
+ if args.mode == "tps":
77
+ from landmarkdiff.synthetic.tps_warp import warp_image_tps
78
+ src = landmarks.pixel_coords[:, :2].copy()
79
+ dst = deformed.pixel_coords[:, :2].copy()
80
+ src[:, 0] *= 512 / landmarks.image_width
81
+ src[:, 1] *= 512 / landmarks.image_height
82
+ dst[:, 0] *= 512 / deformed.image_width
83
+ dst[:, 1] *= 512 / deformed.image_height
84
+ warped = warp_image_tps(img_array, src, dst)
85
+ Image.fromarray(warped).save(str(output_dir / "prediction.png"))
86
+ print(f"saved tps result to {output_dir / 'prediction.png'}")
87
+ else:
88
+ from landmarkdiff.inference import LandmarkDiffPipeline
89
+ pipeline = LandmarkDiffPipeline(mode=args.mode, device="cuda")
90
+ pipeline.load()
91
+ result = pipeline.generate(
92
+ img_array,
93
+ procedure=args.procedure,
94
+ intensity=args.intensity,
95
+ num_inference_steps=args.steps,
96
+ seed=args.seed,
97
+ )
98
+ result["output"].save(str(output_dir / "prediction.png"))
99
+ print(f"saved result to {output_dir / 'prediction.png'}")
100
+
101
+
102
+ def _run_landmarks(args):
103
+ from pathlib import Path
104
+ import numpy as np
105
+ from PIL import Image
106
+ from landmarkdiff.landmarks import extract_landmarks, render_landmark_image
107
+
108
+ img = np.array(Image.open(args.image).convert("RGB").resize((512, 512)))
109
+ landmarks = extract_landmarks(img)
110
+ if landmarks is None:
111
+ print("no face detected")
112
+ sys.exit(1)
113
+
114
+ mesh = render_landmark_image(landmarks, 512, 512)
115
+
116
+ output_path = Path(args.output)
117
+ output_path.parent.mkdir(parents=True, exist_ok=True)
118
+
119
+ from PIL import Image
120
+ Image.fromarray(mesh).save(str(output_path))
121
+ print(f"saved landmark mesh to {output_path}")
122
+ print(f"detected {len(landmarks.landmarks)} landmarks, confidence {landmarks.confidence:.2f}")
123
+
124
+
125
+ def _run_demo():
126
+ try:
127
+ from scripts.app import build_app
128
+ demo = build_app()
129
+ demo.launch()
130
+ except ImportError:
131
+ print("gradio not installed - run: pip install landmarkdiff[app]")
132
+ sys.exit(1)
133
+
134
+
135
+ if __name__ == "__main__":
136
+ main()