Spaces:
Running
Running
| import argparse | |
| import gradio as gr | |
| import numpy as np | |
| import cv2 | |
| import torch | |
| from model import SimpleHRNet, ViTHeatmap | |
| from heatmap_utils import heatmaps_to_coords_dark | |
| from secure_torch_load import secure_torch_load | |
def parse_args():
    """Parse the command-line options of the inference app.

    Returns:
        argparse.Namespace: carries the attributes ``checkpoint``, ``device``,
        ``server_port``, ``server_name``, ``share`` and ``inbrowser``.
    """
    # Evaluated each time the parser is built, so the default follows the
    # hardware visible to the launching process.
    fallback_device = "cuda" if torch.cuda.is_available() else "cpu"

    cli = argparse.ArgumentParser(description="Cephalogram landmark inference app")

    # Options that carry a value: (flag, type, default, help text).
    valued_options = (
        ("--checkpoint", str, "best.pt.enc", "Path to model checkpoint"),
        ("--device", str, fallback_device, "Torch device, e.g. cuda or cpu"),
        ("--server-port", int, 44065, "Port for Gradio app"),
        ("--server-name", str, "127.0.0.1", "Host for Gradio app"),
    )
    for flag, value_type, default, help_text in valued_options:
        cli.add_argument(flag, type=value_type, default=default, help=help_text)

    # Boolean switches: False unless the flag is present.
    switches = (
        ("--share", "Enable public Gradio share link"),
        ("--inbrowser", "Open app in browser on launch"),
    )
    for flag, help_text in switches:
        cli.add_argument(flag, action="store_true", help=help_text)

    return cli.parse_args()
def load_model(checkpoint_path, device):
    """Rebuild the landmark network stored in a checkpoint.

    The checkpoint is read as a mapping with the keys ``"args"`` (the
    configuration used to construct the network), ``"model_state_dict"``
    and, optionally, ``"landmark_symbols"``.

    Args:
        checkpoint_path: Location of the checkpoint, handed to
            ``secure_torch_load``.
        device: Device the network is moved to after its weights are loaded.

    Returns:
        A ``(model, args, landmark_symbols)`` tuple; ``landmark_symbols`` is
        ``None`` when the checkpoint does not carry that key.
    """
    # Loaded through secure_torch_load (not torch.load directly); tensors are
    # mapped to the CPU first and only moved to `device` further down.
    checkpoint = secure_torch_load(checkpoint_path, map_location="cpu")
    train_args = checkpoint["args"]
    stored_symbols = checkpoint.get("landmark_symbols")

    # "hrnet" selects SimpleHRNet; every other value builds a ViTHeatmap.
    if train_args["model"] != "hrnet":
        network = ViTHeatmap(
            num_landmarks=train_args["num_landmarks"],
            model_name=train_args["vit_name"],
            pretrained=False,  # weights come from the checkpoint below
            img_size=(train_args["input_height"], train_args["input_width"]),
        )
    else:
        network = SimpleHRNet(num_landmarks=train_args["num_landmarks"])

    network.load_state_dict(checkpoint["model_state_dict"])
    network.to(device)
    network.eval()
    return network, train_args, stored_symbols
def get_symbols(n, checkpoint_symbols):
    """Return ``n`` landmark labels.

    Args:
        n: Number of labels required.
        checkpoint_symbols: Candidate label sequence, or ``None``.

    Returns:
        ``checkpoint_symbols`` itself when it is not ``None`` and holds
        exactly ``n`` entries; otherwise a new list of generic labels
        ``"LM_0"`` .. ``"LM_{n-1}"``.
    """
    if checkpoint_symbols is None or len(checkpoint_symbols) != n:
        return [f"LM_{index}" for index in range(n)]
    return checkpoint_symbols
def preprocess(image, model_args, device):
    """Turn an image array into a single-item network input batch.

    Args:
        image: NumPy image; its first two axes are read as height and width.
            The channel-last to channel-first permute below requires a
            3-axis array (presumably H x W x C uint8 - confirm with caller).
        model_args: Mapping providing ``"input_height"`` and
            ``"input_width"``.
        device: Device the resulting tensor is moved to.

    Returns:
        ``(tensor, (h_orig, w_orig), (h_in, w_in))`` where ``tensor`` is the
        resized image as a float tensor divided by 255, with the last axis
        moved to the front and a leading batch axis of size 1.
    """
    orig_h, orig_w = image.shape[:2]
    net_h, net_w = model_args["input_height"], model_args["input_width"]

    # cv2.resize takes its target size as (width, height).
    scaled = cv2.resize(image, (net_w, net_h))

    channels_first = torch.from_numpy(scaled).permute(2, 0, 1)
    batch = (channels_first.float() / 255.0).unsqueeze(0)
    return batch.to(device), (orig_h, orig_w), (net_h, net_w)
def decode(pred_heatmaps, orig_size, input_size):
    """Convert predicted heatmaps into coordinates in original-image pixels.

    Args:
        pred_heatmaps: Network output; axes 2 and 3 are read as heatmap
            height and width.
        orig_size: ``(height, width)`` of the original image.
        input_size: ``(height, width)`` of the network input.

    Returns:
        NumPy array of the coordinates of the first batch item, with column 0
        scaled by the width ratios and column 1 by the height ratios.
    """
    orig_h, orig_w = orig_size
    net_h, net_w = input_size
    map_h, map_w = pred_heatmaps.shape[2], pred_heatmaps.shape[3]

    # Work on a copy so the helper's output is never modified in place.
    points = heatmaps_to_coords_dark(pred_heatmaps)[0].clone()

    # Column 0 follows the width, column 1 the height. Each column is scaled
    # in two steps: heatmap grid -> network input -> original image.
    per_column = ((map_w, net_w, orig_w), (map_h, net_h, orig_h))
    for column, (map_len, net_len, orig_len) in enumerate(per_column):
        points[:, column] *= (net_len / map_len)
        points[:, column] *= (orig_len / net_len)

    return points.cpu().numpy()
def compute_confidence(heatmaps):
    """Return the peak value of every heatmap of the first batch item.

    Args:
        heatmaps: Tensor whose first axis is indexed to pick item 0; the
            leading axis of that item enumerates the individual heatmaps.

    Returns:
        1-D NumPy array with one maximum per heatmap.
    """
    first_item = heatmaps[0].detach().cpu().numpy()
    # Flatten everything behind the leading axis so each row is one heatmap.
    flattened = first_item.reshape(first_item.shape[0], -1)
    return flattened.max(axis=1)
def draw_points(image, coords, symbols, color=(255, 0, 0)):
    """Draw labelled landmark markers on a copy of ``image``.

    Args:
        image: NumPy image; it is copied, never modified.
        coords: Iterable of ``(x, y)`` pairs in pixel units of ``image``.
        symbols: Labels, indexed in parallel with ``coords``.
        color: Colour used for both the marker and its label.

    Returns:
        The annotated copy. Points that are non-finite (NaN or infinite) or
        that round to a position outside the image are skipped.
    """
    out = image.copy()
    h, w = out.shape[:2]
    for i, (x, y) in enumerate(coords):
        fx, fy = float(x), float(y)
        # int(round(...)) raises on NaN/inf; one degenerate prediction must
        # not abort drawing of the remaining landmarks.
        if not (np.isfinite(fx) and np.isfinite(fy)):
            continue
        px, py = int(round(fx)), int(round(fy))
        if not (0 <= px < w and 0 <= py < h):
            continue
        # Filled circle (thickness -1) with the label offset up and right.
        cv2.circle(out, (px, py), 4, color, -1, lineType=cv2.LINE_AA)
        cv2.putText(
            out,
            symbols[i],
            (px + 5, py - 5),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.4,
            color,
            1,
            cv2.LINE_AA,
        )
    return out
def heatmap_overlay(image, heatmap):
    """Blend a colour-mapped heatmap over ``image``.

    Args:
        image: NumPy image; its first two axes give the target height and
            width. It is blended with a 3-channel colour map, so it is
            presumably an RGB uint8 array - confirm with caller.
        heatmap: 2-D array of heatmap values at any resolution.

    Returns:
        The blend ``0.6 * image + 0.4 * coloured heatmap``.
    """
    height, width = image.shape[:2]
    stretched = cv2.resize(heatmap, (width, height), interpolation=cv2.INTER_LINEAR)

    # Min-max normalise; the epsilon keeps a constant map from dividing by 0.
    normalised = (stretched - stretched.min()) / (
        stretched.max() - stretched.min() + 1e-6
    )
    levels = (normalised * 255).astype(np.uint8)

    # applyColorMap output is converted BGR -> RGB before blending.
    coloured = cv2.cvtColor(
        cv2.applyColorMap(levels, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB
    )
    return cv2.addWeighted(image, 0.6, coloured, 0.4, 0)
def make_single_landmark_view(orig, coords, symbols, hm_np, idx):
    """Render one landmark: its heatmap overlay plus a white marker.

    Args:
        orig: Image the overlay is drawn on.
        coords: Landmark coordinates, indexed by landmark.
        symbols: Landmark labels, indexed by landmark.
        hm_np: Heatmaps, indexed by landmark.
        idx: Index of the landmark to show.

    Returns:
        The heatmap overlay for landmark ``idx`` with that single landmark
        drawn and labelled in white.
    """
    backdrop = heatmap_overlay(orig, hm_np[idx])
    # draw_points expects a sequence of points, so wrap the single one.
    lone_point = np.array([coords[idx]], dtype=np.float32)
    lone_label = [symbols[idx]]
    return draw_points(backdrop, lone_point, lone_label, color=(255, 255, 255))
def build_demo(model, model_args, checkpoint_symbols, device):
    """Assemble the Gradio interface around a loaded model.

    Args:
        model: Callable network; invoked on the preprocessed batch.
        model_args: Mapping with at least ``"num_landmarks"`` plus the keys
            ``preprocess`` reads.
        checkpoint_symbols: Preferred landmark labels, or ``None``.
        device: Device the input batch is moved to.

    Returns:
        The ``gr.Blocks`` application, not yet launched.
    """
    default_symbols = get_symbols(model_args["num_landmarks"], checkpoint_symbols)

    def run_inference(image):
        """Run the model on ``image`` and produce every output of the UI."""
        if image is None:
            # Nothing uploaded: clear the six outputs, leave the dropdown as is.
            return (None,) * 6 + (gr.Dropdown(),)

        source = image.copy()
        batch, source_hw, network_hw = preprocess(source, model_args, device)
        with torch.no_grad():
            predicted = model(batch)

        points = decode(predicted, source_hw, network_hw)
        maps = predicted[0].detach().cpu().numpy()
        scores = compute_confidence(predicted)
        names = get_symbols(len(points), checkpoint_symbols)

        marked = draw_points(source, points, names)
        combined = heatmap_overlay(source, maps.sum(axis=0))
        first_view = make_single_landmark_view(source, points, names, maps, 0)

        rows = [
            [name, float(points[i, 0]), float(points[i, 1]), float(scores[i])]
            for i, name in enumerate(names)
        ]

        # Kept in gr.State so the landmark selector can redraw without
        # running the model again.
        state = {
            "orig": source,
            "coords": points,
            "symbols": names,
            "heatmaps": maps,
            "pred_overlay": marked,
            "summed_overlay": combined,
            "table": rows,
        }

        selector_update = gr.Dropdown(choices=names, value=names[0])
        return source, marked, combined, first_view, rows, state, selector_update

    def update_selected_landmark(selected_landmark, cache):
        """Redraw the single-landmark view for the chosen label."""
        if cache is None:
            return None

        names = cache["symbols"]
        # Unknown selections fall back to the first landmark.
        try:
            position = names.index(selected_landmark)
        except ValueError:
            position = 0

        return make_single_landmark_view(
            cache["orig"],
            cache["coords"],
            cache["symbols"],
            cache["heatmaps"],
            position,
        )

    with gr.Blocks() as demo:
        gr.Markdown("## Cephalogram Landmark Inference")
        cache_state = gr.State()

        with gr.Row():
            # Left column: input image, trigger button, landmark selector.
            with gr.Column(scale=1, min_width=320):
                input_image = gr.Image(type="numpy", label="Input Image", height=420)
                run_button = gr.Button("Run Inference", variant="primary")
                selected_landmark = gr.Dropdown(
                    choices=default_symbols,
                    value=default_symbols[0],
                    label="Landmark Heatmap Selector",
                )

            # Right column: 2 x 2 grid of result images.
            with gr.Column(scale=2):
                with gr.Row():
                    out_orig = gr.Image(label="Original", height=284)
                    out_pred = gr.Image(label="Predictions", height=284)
                with gr.Row():
                    out_sum = gr.Image(label="All-Landmark Heatmap Overlay", height=284)
                    out_single = gr.Image(
                        label="Selected Landmark Heatmap Overlay", height=284
                    )

        out_table = gr.Dataframe(
            headers=["Landmark", "X", "Y", "Confidence"],
            label="Predictions",
            interactive=False,
            wrap=True,
        )

        # Order must match the tuple returned by run_inference.
        inference_outputs = [
            out_orig,
            out_pred,
            out_sum,
            out_single,
            out_table,
            cache_state,
            selected_landmark,
        ]
        run_button.click(
            fn=run_inference,
            inputs=[input_image],
            outputs=inference_outputs,
        )
        selected_landmark.change(
            fn=update_selected_landmark,
            inputs=[selected_landmark, cache_state],
            outputs=[out_single],
        )

    return demo
if __name__ == "__main__":
    cli_args = parse_args()
    model, model_args, checkpoint_symbols = load_model(
        cli_args.checkpoint,
        cli_args.device,
    )

    # TEMPORARY HARD CODE: replaces whatever symbols load_model returned with
    # this fixed list before the UI is built.
    checkpoint_symbols = [
        "A", "ANS", "B", "Me", "N", "Or", "Pog", "PNS", "Pn", "R",
        "S", "Ar", "Co", "Gn", "Go", "Po", "LPM", "LIT", "LMT", "UPM",
        "UIA", "UIT", "UMT", "LIA", "Li", "Ls", "N`", "Pog`", "Sn",
    ]

    demo = build_demo(model, model_args, checkpoint_symbols, cli_args.device)

    # NOTE: the server/share/inbrowser CLI options are parsed but currently
    # not forwarded; launch() runs with Gradio's own defaults. To forward
    # them again, pass:
    #     server_name=cli_args.server_name,
    #     server_port=cli_args.server_port,
    #     share=cli_args.share,
    #     inbrowser=cli_args.inbrowser,
    demo.launch()