# CephVIT / app.py
# (Hugging Face Hub listing residue: uploaded by farrell236, "Upload 4 files", commit 325d063 verified)
import argparse
import gradio as gr
import numpy as np
import cv2
import torch
from model import SimpleHRNet, ViTHeatmap
from heatmap_utils import heatmaps_to_coords_dark
from secure_torch_load import secure_torch_load
def parse_args():
    """Parse command-line options for the cephalogram landmark inference app."""
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    parser = argparse.ArgumentParser(description="Cephalogram landmark inference app")
    parser.add_argument("--checkpoint", type=str, default="best.pt.enc",
                        help="Path to model checkpoint")
    parser.add_argument("--device", type=str, default=default_device,
                        help="Torch device, e.g. cuda or cpu")
    parser.add_argument("--server-port", type=int, default=44065,
                        help="Port for Gradio app")
    parser.add_argument("--server-name", type=str, default="127.0.0.1",
                        help="Host for Gradio app")
    parser.add_argument("--share", action="store_true",
                        help="Enable public Gradio share link")
    parser.add_argument("--inbrowser", action="store_true",
                        help="Open app in browser on launch")
    return parser.parse_args()
def load_model(checkpoint_path, device):
    """Load a landmark-model checkpoint and restore it for inference.

    Parameters
    ----------
    checkpoint_path : str
        Path to the (encrypted) checkpoint file readable by secure_torch_load.
    device : str or torch.device
        Device the restored model is moved to.

    Returns
    -------
    tuple
        ``(model, args, landmark_symbols)`` where ``args`` is the
        training-time config dict stored in the checkpoint and
        ``landmark_symbols`` is the optional list of landmark names
        (``None`` when the checkpoint does not carry them).
    """
    # Deserialize onto CPU first; weights are moved to the target device below.
    ckpt = secure_torch_load(checkpoint_path, map_location="cpu")
    args = ckpt["args"]
    landmark_symbols = ckpt.get("landmark_symbols", None)
    if args["model"] == "hrnet":
        model = SimpleHRNet(num_landmarks=args["num_landmarks"])
    else:
        # Any non-hrnet checkpoint is treated as a ViT heatmap model.
        model = ViTHeatmap(
            num_landmarks=args["num_landmarks"],
            model_name=args["vit_name"],
            pretrained=False,  # weights come from the checkpoint, not from timm
            img_size=(args["input_height"], args["input_width"]),
        )
    model.load_state_dict(ckpt["model_state_dict"])
    model.to(device)
    model.eval()
    return model, args, landmark_symbols
def get_symbols(n, checkpoint_symbols):
    """Return landmark names for *n* landmarks.

    Uses the checkpoint-provided names when they exist and match the
    landmark count; otherwise falls back to generic "LM_i" labels.
    """
    if checkpoint_symbols is None or len(checkpoint_symbols) != n:
        return [f"LM_{i}" for i in range(n)]
    return checkpoint_symbols
def preprocess(image, model_args, device):
    """Resize an image to the model input size and convert it to a batch tensor.

    Returns ``(tensor, (h_orig, w_orig), (h_in, w_in))`` where the tensor
    has shape (1, C, h_in, w_in) with pixel values scaled to [0, 1].
    """
    orig_h, orig_w = image.shape[0], image.shape[1]
    in_h = model_args["input_height"]
    in_w = model_args["input_width"]
    # cv2.resize takes (width, height), not (height, width).
    resized = cv2.resize(image, (in_w, in_h))
    batch = torch.from_numpy(resized).permute(2, 0, 1).float() / 255.0
    batch = batch.unsqueeze(0).to(device)
    return batch, (orig_h, orig_w), (in_h, in_w)
def decode(pred_heatmaps, orig_size, input_size):
    """Map heatmap-space landmark coordinates back to original-image pixels.

    pred_heatmaps is a (B, L, h_hm, w_hm) tensor; only batch element 0 is
    decoded. Returns an (L, 2) numpy array of (x, y) in original-image
    coordinates.
    """
    orig_h, orig_w = orig_size
    in_h, in_w = input_size
    hm_h = pred_heatmaps.shape[2]
    hm_w = pred_heatmaps.shape[3]
    # Sub-pixel peak coordinates in heatmap space (DARK decoding), batch 0.
    coords = heatmaps_to_coords_dark(pred_heatmaps)[0].clone()
    # Heatmap space -> network-input space.
    coords[:, 0] *= (in_w / hm_w)
    coords[:, 1] *= (in_h / hm_h)
    # Network-input space -> original-image space.
    coords[:, 0] *= (orig_w / in_w)
    coords[:, 1] *= (orig_h / in_h)
    return coords.cpu().numpy()
def compute_confidence(heatmaps):
    """Per-landmark confidence: the peak activation of each heatmap (batch 0)."""
    maps = heatmaps[0].detach().cpu().numpy()
    num_landmarks = maps.shape[0]
    flat = maps.reshape(num_landmarks, -1)
    return flat.max(axis=1)
def draw_points(image, coords, symbols, color=(255, 0, 0)):
    """Return a copy of *image* with each landmark drawn as a dot plus its label.

    coords is an (N, 2) array of (x, y) positions; symbols supplies one
    label per point. Points falling outside the image bounds are skipped.
    """
    canvas = image.copy()
    height, width = canvas.shape[:2]
    for idx, (px, py) in enumerate(coords):
        cx = int(round(float(px)))
        cy = int(round(float(py)))
        if not (0 <= cx < width and 0 <= cy < height):
            continue
        cv2.circle(canvas, (cx, cy), 4, color, -1, lineType=cv2.LINE_AA)
        # Label offset up-right so text does not cover the dot.
        cv2.putText(
            canvas,
            symbols[idx],
            (cx + 5, cy - 5),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.4,
            color,
            1,
            cv2.LINE_AA,
        )
    return canvas
def heatmap_overlay(image, heatmap):
    """Blend a single-channel heatmap onto an RGB image as a JET-colored overlay."""
    height, width = image.shape[:2]
    scaled = cv2.resize(heatmap, (width, height), interpolation=cv2.INTER_LINEAR)
    # Min-max normalize to [0, 1]; the epsilon guards a constant heatmap.
    scaled = (scaled - scaled.min()) / (scaled.max() - scaled.min() + 1e-6)
    colored = cv2.applyColorMap((scaled * 255).astype(np.uint8), cv2.COLORMAP_JET)
    # applyColorMap emits BGR; convert to match the RGB input image.
    colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)
    return cv2.addWeighted(image, 0.6, colored, 0.4, 0)
def make_single_landmark_view(orig, coords, symbols, hm_np, idx):
    """Overlay one landmark's heatmap on the image and mark its predicted point."""
    view = heatmap_overlay(orig, hm_np[idx])
    point = np.array([coords[idx]], dtype=np.float32)
    # White marker so it stays visible on top of the JET colormap.
    return draw_points(view, point, [symbols[idx]], color=(255, 255, 255))
def build_demo(model, model_args, checkpoint_symbols, device):
    """Build the Gradio Blocks UI wired to the given landmark model.

    Parameters
    ----------
    model : torch.nn.Module
        Heatmap landmark model, already moved to *device* and in eval mode.
    model_args : dict
        Training-time config; "num_landmarks" is read here, and
        "input_height"/"input_width" are read via preprocess().
    checkpoint_symbols : list[str] or None
        Optional landmark names; generic "LM_i" labels are used when absent
        or mismatched in length (see get_symbols).
    device : str or torch.device
        Device inference tensors are sent to.

    Returns
    -------
    gr.Blocks
        The assembled (not yet launched) demo.
    """
    default_symbols = get_symbols(model_args["num_landmarks"], checkpoint_symbols)
    def run_inference(image):
        # No image provided: clear every output; the bare gr.Dropdown()
        # leaves the selector component unchanged.
        if image is None:
            return None, None, None, None, None, None, gr.Dropdown()
        orig = image.copy()
        tensor, orig_size, input_size = preprocess(orig, model_args, device)
        with torch.no_grad():
            heatmaps = model(tensor)
        # Landmark (x, y) positions in original-image pixel coordinates.
        coords = decode(heatmaps, orig_size, input_size)
        hm_np = heatmaps[0].detach().cpu().numpy()
        conf = compute_confidence(heatmaps)
        symbols = get_symbols(len(coords), checkpoint_symbols)
        pred_overlay = draw_points(orig, coords, symbols)
        # Summing over the landmark axis gives one combined heatmap.
        summed_overlay = heatmap_overlay(orig, hm_np.sum(axis=0))
        # Default the single-landmark view to landmark 0.
        single_overlay = make_single_landmark_view(orig, coords, symbols, hm_np, 0)
        table = [
            [symbols[i], float(coords[i, 0]), float(coords[i, 1]), float(conf[i])]
            for i in range(len(symbols))
        ]
        # Cache everything the dropdown callback needs, so changing the
        # selected landmark does not re-run inference.
        cache = {
            "orig": orig,
            "coords": coords,
            "symbols": symbols,
            "heatmaps": hm_np,
            "pred_overlay": pred_overlay,
            "summed_overlay": summed_overlay,
            "table": table,
        }
        dropdown_update = gr.Dropdown(choices=symbols, value=symbols[0])
        # Order must match the outputs list of run_button.click below.
        return orig, pred_overlay, summed_overlay, single_overlay, table, cache, dropdown_update
    def update_selected_landmark(selected_landmark, cache):
        # Nothing cached yet (no inference run): clear the single view.
        if cache is None:
            return None
        symbols = cache["symbols"]
        # Fall back to landmark 0 if the selection is stale/unknown.
        idx = symbols.index(selected_landmark) if selected_landmark in symbols else 0
        return make_single_landmark_view(
            cache["orig"],
            cache["coords"],
            cache["symbols"],
            cache["heatmaps"],
            idx,
        )
    with gr.Blocks() as demo:
        gr.Markdown("## Cephalogram Landmark Inference")
        # Per-session store for the last inference results.
        cache_state = gr.State()
        with gr.Row():
            with gr.Column(scale=1, min_width=320):
                input_image = gr.Image(type="numpy", label="Input Image", height=420)
                run_button = gr.Button("Run Inference", variant="primary")
                selected_landmark = gr.Dropdown(
                    choices=default_symbols,
                    value=default_symbols[0],
                    label="Landmark Heatmap Selector",
                )
            with gr.Column(scale=2):
                with gr.Row():
                    out_orig = gr.Image(label="Original", height=284)
                    out_pred = gr.Image(label="Predictions", height=284)
                with gr.Row():
                    out_sum = gr.Image(label="All-Landmark Heatmap Overlay", height=284)
                    out_single = gr.Image(label="Selected Landmark Heatmap Overlay", height=284)
                out_table = gr.Dataframe(
                    headers=["Landmark", "X", "Y", "Confidence"],
                    label="Predictions",
                    interactive=False,
                    wrap=True,
                )
        run_button.click(
            fn=run_inference,
            inputs=[input_image],
            outputs=[
                out_orig,
                out_pred,
                out_sum,
                out_single,
                out_table,
                cache_state,
                selected_landmark,
            ],
        )
        # NOTE: this also fires when run_inference updates the dropdown value.
        selected_landmark.change(
            fn=update_selected_landmark,
            inputs=[selected_landmark, cache_state],
            outputs=[out_single],
        )
    return demo
if __name__ == "__main__":
    cli_args = parse_args()
    model, model_args, checkpoint_symbols = load_model(cli_args.checkpoint, cli_args.device)
    # TEMPORARY HARD CODE: override whatever landmark names the checkpoint
    # stored with this fixed cephalometric naming scheme.
    checkpoint_symbols = [
        "A", "ANS", "B", "Me", "N", "Or", "Pog", "PNS", "Pn", "R",
        "S", "Ar", "Co", "Gn", "Go", "Po", "LPM", "LIT", "LMT", "UPM",
        "UIA", "UIT", "UMT", "LIA", "Li", "Ls", "N`", "Pog`", "Sn"
    ]
    demo = build_demo(model, model_args, checkpoint_symbols, cli_args.device)
    # NOTE(review): the parsed server options are deliberately not passed, so
    # Gradio's defaults apply — presumably for hosted deployment; confirm
    # before re-enabling:
    # demo.launch(server_name=cli_args.server_name, server_port=cli_args.server_port,
    #             share=cli_args.share, inbrowser=cli_args.inbrowser)
    demo.launch()