# Source: echotracker / app.py — Hugging Face Space by riponazad
# Commit: "Update README with Taming paper, citations, and like request" (077fa26)
import gradio as gr
import os
import torch
import cv2
import numpy as np
import random
from PIL import Image
from utils import points_to_tensor
from utils import visualize_tracking
import mediapy as media
# ── Colormap (matches your viz_utils.get_colors logic) ───────────────────────
def get_colors(n):
    """Generate n random but unique colors in RGB 0-255."""
    random.seed(42)  # fixed seed -> reproducible palette; remove for fresh colors
    # Evenly spaced hues over OpenCV's HSV hue range [0, 179], then shuffled.
    step = max(1, 180 // n)
    hue_pool = list(range(0, 180, step))[:n]
    random.shuffle(hue_pool)
    palette = []
    for h in hue_pool:
        # Slightly randomized saturation/value give extra visual variety.
        s = random.randint(180, 255)
        v = random.randint(180, 255)
        hsv_pixel = np.uint8([[[h, s, v]]])
        r, g, b = cv2.cvtColor(hsv_pixel, cv2.COLOR_HSV2RGB)[0][0]
        palette.append((int(r), int(g), int(b)))
    return palette
# Maximum number of query points a user may place on a frame.
N_POINTS = 100
# One distinct RGB color per possible point index.
COLORMAP = get_colors(N_POINTS)
# Module-level store of the user's clicked points, shared by all handlers.
select_points = []  # will hold np.array([x, y]) entries
# ── Video helpers ─────────────────────────────────────────────────────────────
def get_frame(video_path: str, frame_idx: int) -> np.ndarray:
    """Read the frame at `frame_idx` from the video, returned as an RGB array."""
    capture = cv2.VideoCapture(video_path)
    capture.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
    ok, bgr = capture.read()
    capture.release()
    if not ok:
        raise ValueError(f"Could not read frame {frame_idx}")
    # OpenCV decodes to BGR; the UI expects RGB.
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
def get_total_frames(video_path: str) -> int:
    """Return the frame count reported by the video container."""
    capture = cv2.VideoCapture(video_path)
    try:
        return int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        capture.release()
# ── Draw points on frame ──────────────────────────────────────────────────────
def draw_points(frame: np.ndarray, points: list) -> np.ndarray:
    """Return a copy of `frame` with a colored, numbered marker per point."""
    canvas = frame.copy()
    for idx, pt in enumerate(points):
        x, y = int(pt[0]), int(pt[1])
        r, g, b = COLORMAP[idx % N_POINTS]  # stored as RGB
        # cv2 drawing primitives expect BGR channel order.
        cv2.circle(canvas, (x, y), radius=6, color=(b, g, r), thickness=-1)
        # White outline keeps the marker visible on dark tissue.
        cv2.circle(canvas, (x, y), radius=6, color=(255, 255, 255), thickness=2)
        # 1-based index label next to the marker.
        cv2.putText(canvas, str(idx + 1), (x + 10, y - 6),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
    return canvas
# Directory holding the bundled example echo clips, resolved relative to this file.
_SAMPLES_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "example_samples")
# JS injected into gr.Blocks — controls download availability on video players.
# The hidden #download_ctrl textbox is the Python→JS signal: setting it to "1"
# enables example mode (hide download affordances), any other value restores them.
_DOWNLOAD_CTRL_JS = """
(function () {
const EXAMPLE_IDS = ['video_upload_player', 'out_video_player'];
const USER_IDS = ['out_video_player'];
function applyNoDownload(ids) {
ids.forEach(function (id) {
var el = document.getElementById(id);
if (!el) return;
el.querySelectorAll('video').forEach(function (v) {
v.setAttribute('controlsList', 'nodownload');
v.oncontextmenu = function (e) { e.preventDefault(); };
});
el.querySelectorAll('a').forEach(function (a) {
a.style.cssText = 'display:none!important;pointer-events:none!important';
});
el.querySelectorAll('button').forEach(function (btn) {
var lbl = (btn.getAttribute('aria-label') || btn.getAttribute('title') || '').toLowerCase();
if (lbl.includes('download') || lbl.includes('save')) {
btn.style.cssText = 'display:none!important;pointer-events:none!important';
}
});
});
}
function clearNoDownload(ids) {
ids.forEach(function (id) {
var el = document.getElementById(id);
if (!el) return;
el.querySelectorAll('video').forEach(function (v) {
v.removeAttribute('controlsList');
v.oncontextmenu = null;
});
el.querySelectorAll('a').forEach(function (a) { a.style.cssText = ''; });
el.querySelectorAll('button').forEach(function (btn) { btn.style.cssText = ''; });
});
}
window._isExampleMode = false;
function applyCurrentMode() {
if (window._isExampleMode) applyNoDownload(EXAMPLE_IDS);
else clearNoDownload(USER_IDS);
}
/* Watch both containers for DOM changes (e.g. when video src updates) */
EXAMPLE_IDS.concat(['out_video_player']).forEach(function (id) {
(function tryObserve() {
var el = document.getElementById(id);
if (!el) { setTimeout(tryObserve, 400); return; }
new MutationObserver(applyCurrentMode)
.observe(el, { childList: true, subtree: true });
})();
});
/* Intercept value setter on hidden textbox to receive mode signal from Python */
function hookTrigger() {
var container = document.querySelector('#download_ctrl textarea');
if (!container) { setTimeout(hookTrigger, 300); return; }
var desc = Object.getOwnPropertyDescriptor(HTMLTextAreaElement.prototype, 'value');
Object.defineProperty(container, 'value', {
get: function () { return desc.get.call(this); },
set: function (v) {
desc.set.call(this, v);
window._isExampleMode = (v === '1');
applyCurrentMode();
},
configurable: true,
});
}
setTimeout(hookTrigger, 500);
})();
"""
# label → (path, is_ood); is_ood marks clips that are out-of-distribution
# relative to the training data (different scanner or view).
EXAMPLE_VIDEOS = {
    "A4C": (os.path.join(_SAMPLES_DIR, "input1.mp4"), False),
    "A4C (OOD)": (os.path.join(_SAMPLES_DIR, "input2.mp4"), True),
    "RV (OOD)": (os.path.join(_SAMPLES_DIR, "input3_RV.mp4"), True),
    "PSAX (OOD)": (os.path.join(_SAMPLES_DIR, "psax_video_crop.mp4"), True),
}
def _get_thumbnail(video_path: str) -> np.ndarray | None:
    """Extract a single frame near the middle of the video for use as a thumbnail.

    Returns:
        The frame as an RGB array, or None if the video cannot be read.
    """
    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Seek ~40% in: skips any blank lead-in while staying representative.
        cap.set(cv2.CAP_PROP_POS_FRAMES, max(0, int(total * 0.4)))
        ret, frame = cap.read()
        if ret:
            return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    except Exception:
        # Best-effort: a missing thumbnail is non-fatal for the demo UI.
        pass
    finally:
        # Always release the capture — the original leaked it when an
        # exception fired between open and release.
        if cap is not None:
            cap.release()
    return None
# Pre-extracted preview frame per example clip (None when a clip is unreadable).
THUMBNAILS = {label: _get_thumbnail(path) for label, (path, _) in EXAMPLE_VIDEOS.items()}
# ── Gradio event handlers ─────────────────────────────────────────────────────
def on_video_upload(video_path):
    """Handle a newly uploaded video: show the 72% frame and reset selection.

    Returns one value per wired output:
    (frame_display, frame_slider, status_text, video_state, download_ctrl).
    """
    if video_path is None:
        # BUG FIX: this handler is wired to 5 outputs, so the cleared-upload
        # path must return 5 values — returning a bare None made Gradio error.
        return (
            None,
            gr.update(value=0, interactive=False),
            "No video loaded.",
            None,
            "",
        )
    total = get_total_frames(video_path)
    idx_72 = int(total * 0.72)  # jump to 72% through the clip — presumably a
                                # useful default query frame; confirm with authors
    frame = get_frame(video_path, idx_72)
    frame_display_update = gr.update(
        value=frame,
        interactive=True,  # enables click events via gr.SelectData
    )
    slider_update = gr.update(
        value=idx_72,
        minimum=0,
        maximum=total - 1,
        step=1,
        interactive=True,
        label=f"Frame selector (total: {total} frames)"
    )
    select_points.clear()  # clear any existing points when a new video is loaded
    status = f"πŸ“Ή Loaded β€” {total} frames | 🎞️ Showing frame {idx_72} (72%)"
    # Last value resets the download-control mode (user upload → downloads allowed).
    return frame_display_update, slider_update, status, video_path, ""
def load_example(video_path):
    """Load an example video, reset all output/selection fields, and hide downloads."""
    # Reuse the upload handler for the frame/slider/status/state updates.
    frame_upd, slider_upd, status, state, _ = on_video_upload(video_path)
    updates = (
        gr.update(value=video_path),                 # video_upload
        frame_upd,                                   # frame_display
        slider_upd,                                  # frame_slider
        status,                                      # status_text
        state,                                       # video_state
        gr.update(value=None),                       # out_video: clear previous result
        gr.update(value="No points selected yet."),  # points_display
        "1",                                         # download_ctrl: disable downloads
    )
    return updates
def on_slider_release(frame_idx, video_path, points_display):
    """Show the frame at the released slider position; selection is reset."""
    if video_path is None:
        return None, "No video loaded.", points_display
    idx = int(frame_idx)
    frame = get_frame(video_path, idx)
    # Moving to a different frame invalidates previously clicked points.
    select_points.clear()
    points_display = gr.update(
        value="No points selected yet.",
        label="πŸ“‹ Selected Points",
        lines=5,
        interactive=False,
    )
    pct = idx / get_total_frames(video_path) * 100
    status = f"🎞️ Showing Frame {idx} ({pct:.1f}%) | {len(select_points)} point(s) selected"
    return frame, status, points_display
def on_point_select(frame_idx, video_path, evt: gr.SelectData):
    """Record the clicked pixel as a query point and redraw the frame."""
    if video_path is None:
        return None, "Upload a video first.", format_points()
    idx = int(frame_idx)
    # Enforce the point budget before accepting the click.
    if len(select_points) >= N_POINTS:
        status = f"⚠️ Max {N_POINTS} points reached."
        redrawn = draw_points(get_frame(video_path, idx), select_points)
        return redrawn, status, format_points()
    x, y = int(evt.index[0]), int(evt.index[1])
    select_points.append(np.array([x, y]))
    drawn = draw_points(get_frame(video_path, idx), select_points)
    status = f"βœ… Point {len(select_points)} added at ({x}, {y}) | Frame {idx}"
    return drawn, status, format_points()
def on_clear_points(frame_idx, video_path):
    """Drop every selected point and redraw the current frame."""
    select_points.clear()
    if video_path is None:
        return None, "Points cleared.", format_points()
    redrawn = draw_points(get_frame(video_path, int(frame_idx)), select_points)
    return redrawn, "πŸ—‘οΈ All points cleared.", format_points()
def on_undo_point(frame_idx, video_path):
    """Remove the most recently added point (if any) and redraw."""
    if not select_points:
        msg = "No points to undo."
    else:
        removed = select_points.pop()
        msg = f"↩️ Removed point at ({removed[0]}, {removed[1]})"
    if video_path is None:
        return None, msg, format_points()
    redrawn = draw_points(get_frame(video_path, int(frame_idx)), select_points)
    return redrawn, msg, format_points()
def format_points():
    """Render the shared select_points list as a 1-indexed textual listing."""
    if not select_points:
        return "No points selected yet."
    body = "\n".join(
        f" [{i+1}] x={p[0]}, y={p[1]}" for i, p in enumerate(select_points)
    )
    return "select_points:\n" + body
def track(video_path, frame_idx, out_video, target_size=(256, 256)):
    """Run the EchoTracker model on the selected points and render the result.

    Args:
        video_path: path to the loaded clip, or None when nothing is loaded.
        frame_idx: index of the query frame the points were placed on.
        out_video: current output-video component value (unused; kept so the
            wired Gradio inputs keep matching the signature).
        target_size: (W, H) frames are resized to before inference.

    Returns:
        (out_video_update, status) — one value per wired output.
    """
    # BUG FIX: the button is wired to [out_video, status_text], so the guard
    # clauses must return 2 values — returning a bare status made Gradio error.
    if video_path is None:
        status = "⚠️ No video loaded. Cannot run the tracker."
        return gr.update(), status
    if len(select_points) < 1:
        status = "⚠️ No points selected. Please select at least one point to track."
        return gr.update(), status
    tracker, device = load_model("echotracker_cvamd_ts.pt")
    cap = cv2.VideoCapture(video_path)
    W = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    H = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    frames = []        # grayscale, resized frames fed to the model
    paint_frames = []  # original-resolution BGR frames used for visualization
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        paint_frames.append(frame)
        frame = cv2.resize(frame, target_size)
        frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)))
    cap.release()
    paint_frames = np.array(paint_frames)
    # shape: [B=1, T, C=1, H, W]
    frames = torch.from_numpy(np.array(frames)).unsqueeze(0).unsqueeze(2).float().to(device)
    # Query points mapped into the model's 256x256 input space; shape: [1, N, 3]
    q_points = points_to_tensor(select_points, frame_idx, H, W, 256).to(device)
    with torch.no_grad():
        output = tracker(frames, q_points)
    trajs_e = output[-1].cpu().permute(0, 2, 1, 3)
    # Normalize coordinates to [0, 1] before visualization.
    q_points[..., 1] /= 256 - 1
    q_points[..., 2] /= 256 - 1
    trajs_e[..., 0] /= 256 - 1
    trajs_e[..., 1] /= 256 - 1
    paint_frames = visualize_tracking(
        frames=paint_frames, points=trajs_e.squeeze().cpu().numpy(),
        vis_color='random',
        thickness=5,
        track_length=30,
    )
    # Encode the painted frames and point the output player at the result.
    out_vid = "outputs/output.mp4"
    os.makedirs("outputs", exist_ok=True)
    media.write_video(out_vid, paint_frames, fps=25)
    status = "βœ… Tracking completed! The output is visualized below."
    out_video = gr.update(value=out_vid, autoplay=True, loop=True)
    return out_video, status
def load_model(model_path: str, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
    """Load a TorchScript model onto the chosen device.

    Args:
        model_path (str): path to the TorchScript weights file.
        device (str, optional): target device; defaults to "cuda" when
            available, else "cpu".

    Returns:
        tuple: (model in eval mode, device string).
    """
    scripted = torch.jit.load(model_path, map_location=device)
    return scripted.eval(), device
# ── Gradio UI ─────────────────────────────────────────────────────────────────
# Static HTML banner: title, author lists, and badge links to the papers,
# preprints, code, and project page.
HEADER = """
<div style="text-align:center; padding: 20px 0 8px;">
<h1 style="font-size:2.2rem; font-weight:700; margin-bottom:4px;">πŸ«€ EchoTracker</h1>
<p style="font-size:1.05rem; color:var(--echo-muted); margin:4px 0 0;">
Advancing Myocardial Point Tracking in Echocardiography
</p>
<p style="font-size:0.9rem; color:var(--echo-subtle); margin:2px 0 0;">
MICCAI 2024 &nbsp;Β·&nbsp;
Azad, Chernyshov, Nyberg, Tveten, Lovstakken, Dalen, Grenne, Østvik
</p>
<p style="font-size:0.9rem; color:var(--echo-subtle); margin:4px 0 0;">
Model weights from: <em>Taming Modern Point Tracking for Speckle Tracking Echocardiography via Impartial Motion</em>
&nbsp;Β·&nbsp; ICCV 2025 Workshop &nbsp;Β·&nbsp;
Azad, Nyberg, Dalen, Grenne, Lovstakken, Østvik
</p>
<div style="margin-top:12px; display:flex; justify-content:center; gap:10px; flex-wrap:wrap;">
<a href="https://link.springer.com/chapter/10.1007/978-3-031-72083-3_60"
target="_blank"
style="display:inline-flex;align-items:center;gap:5px;padding:5px 14px;border-radius:6px;
background:#2563eb;color:white;font-size:0.85rem;text-decoration:none;font-weight:500;">
πŸ“„ Paper (MICCAI 2024)
</a>
<a href="https://openaccess.thecvf.com/content/ICCV2025W/CVAMD/papers/Azad_Taming_Modern_Point_Tracking_for_Speckle_Tracking_Echocardiography_via_Impartial_CVAMD_2025_paper.pdf"
target="_blank"
style="display:inline-flex;align-items:center;gap:5px;padding:5px 14px;border-radius:6px;
background:#2563eb;color:white;font-size:0.85rem;text-decoration:none;font-weight:500;">
πŸ“„ Paper (ICCV 2025 Workshop)
</a>
<a href="https://arxiv.org/abs/2405.08587" target="_blank"
style="display:inline-flex;align-items:center;gap:5px;padding:5px 14px;border-radius:6px;
background:#dc2626;color:white;font-size:0.85rem;text-decoration:none;font-weight:500;">
πŸ“ ArXiv (EchoTracker)
</a>
<a href="https://arxiv.org/abs/2507.10127" target="_blank"
style="display:inline-flex;align-items:center;gap:5px;padding:5px 14px;border-radius:6px;
background:#dc2626;color:white;font-size:0.85rem;text-decoration:none;font-weight:500;">
πŸ“ ArXiv (Taming)
</a>
<a href="https://github.com/riponazad/echotracker" target="_blank"
style="display:inline-flex;align-items:center;gap:5px;padding:5px 14px;border-radius:6px;
background:#1f2937;color:white;font-size:0.85rem;text-decoration:none;font-weight:500;">
πŸ’» GitHub
</a>
<a href="https://riponazad.github.io/echotracker/" target="_blank"
style="display:inline-flex;align-items:center;gap:5px;padding:5px 14px;border-radius:6px;
background:#7c3aed;color:white;font-size:0.85rem;text-decoration:none;font-weight:500;">
🌐 Project Page
</a>
</div>
</div>
"""
# BibTeX shown in the Citation accordion. Double backslashes (e.g. {\\aa})
# render as single backslashes inside the markdown code block.
CITATION_MD = """
If you use EchoTracker or the model weights in this demo, please cite both papers:
```bibtex
@InProceedings{azad2024echotracker,
author = {Azad, Md Abulkalam and Chernyshov, Artem and Nyberg, John
and Tveten, Ingrid and Lovstakken, Lasse and Dalen, H{\\aa}vard
and Grenne, Bj{\\o}rnar and {\\O}stvik, Andreas},
title = {EchoTracker: Advancing Myocardial Point Tracking in Echocardiography},
booktitle = {Medical Image Computing and Computer Assisted Intervention -- MICCAI 2024},
year = {2024},
publisher = {Springer Nature Switzerland},
doi = {10.1007/978-3-031-72083-3_60}
}
@InProceedings{Azad_2025_ICCV,
author = {Azad, Md Abulkalam and Nyberg, John and Dalen, H{\\aa}vard
and Grenne, Bj{\\o}rnar and Lovstakken, Lasse and {\\O}stvik, Andreas},
title = {Taming Modern Point Tracking for Speckle Tracking Echocardiography via Impartial Motion},
booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV) Workshops},
month = {October},
year = {2025},
pages = {1115--1124}
}
```
"""
# Build the full demo layout and wire the event handlers. Component creation
# order determines on-page order, so statements here are position-sensitive.
with gr.Blocks(title="EchoTracker", theme=gr.themes.Soft(),
    css="""
.gr-button { font-weight: 600; }
:root { --echo-muted: #444; --echo-subtle: #666; }
.dark { --echo-muted: #c0c0c0; --echo-subtle: #a8a8a8; }
""",
    js=_DOWNLOAD_CTRL_JS) as demo:
    gr.HTML(HEADER)
    gr.Markdown("---")
    # ── Instructions ──────────────────────────────────────────────────────────
    with gr.Accordion("ℹ️ How to use", open=False):
        gr.Markdown("""
1. **Load a video** β€” upload your own echocardiography clip, or click one of the provided example videos below the panel.
2. **Navigate** to the desired query frame using the frame slider.
3. **Click** on the frame image to place tracking points on cardiac tissue surfaces (e.g. LV/RV walls, myocardium).
4. Use **Undo** or **Clear All** to adjust your selection.
5. Press **β–Ά Run EchoTracker** to generate tracked trajectories for all selected points.
> **Tip:** Select points at the *end-diastolic* frame for best results. Up to 100 points are supported.
> Example clips cover apical 4-chamber (A4C), right-ventricle (RV), and parasternal short-axis (PSAX) views.
> Clips marked **OOD** (πŸ”Ά) are out-of-distribution β€” different scanner or view not seen during training, showcasing EchoTracker's generalisation ability.
""")
    # hidden state: path of the currently loaded video (None until one is loaded)
    video_state = gr.State(value=None)
    # Hidden textbox watched by _DOWNLOAD_CTRL_JS; "1" hides download buttons
    # on example videos, any other value restores them.
    download_ctrl = gr.Textbox(value="0", visible=False, elem_id="download_ctrl")
    gr.Markdown("### Step 1 β€” Upload & Select Query Points")
    gr.Markdown(
        "Upload your own echocardiography video, or click one of the **example clips** below to get started."
    )
    with gr.Row(equal_height=False):
        # ── Left column: input + points ───────────────────────────────────────
        with gr.Column(scale=1, min_width=300):
            video_upload = gr.Video(
                label="Echocardiography Video β€” upload yours or use an example below",
                sources="upload",
                include_audio=False,
                autoplay=True,
                loop=True,
                elem_id="video_upload_player",
            )
            points_display = gr.Textbox(
                value="No points selected yet.",
                label="πŸ“‹ Selected Query Points",
                lines=5,
                max_lines=5,
                interactive=False,
            )
            gr.Markdown(
                "<small style='color:var(--echo-subtle)'>Coordinates are stored as "
                "<code>np.array([x, y])</code> and passed to the tracker.</small>"
            )
        # ── Right column: frame viewer + controls ─────────────────────────────
        with gr.Column(scale=2, min_width=400):
            frame_display = gr.Image(
                label="Query Frame β€” click to place tracking points",
                interactive=True,
                type="numpy",
                sources=[],
            )
            frame_slider = gr.Slider(
                minimum=0, maximum=100, value=0, step=1,
                label="Frame",
                interactive=False,
            )
            status_text = gr.Textbox(
                label="Status", lines=1, interactive=False, show_label=False,
                placeholder="Status messages will appear here…",
            )
            with gr.Row():
                undo_btn = gr.Button("↩ Undo", scale=1)
                clear_btn = gr.Button("πŸ—‘ Clear All", variant="stop", scale=1)
    gr.Markdown("---")
    gr.Markdown("### Step 2 β€” Run Tracker & View Output")
    with gr.Row():
        with gr.Column(scale=1):
            run_btn = gr.Button("β–Ά Run EchoTracker", variant="primary", size="lg")
        with gr.Column(scale=2):
            out_video = gr.Video(
                label="Tracking Output",
                sources=[],
                include_audio=False,
                interactive=False,
                autoplay=True,
                loop=True,
                elem_id="out_video_player",
            )
    gr.Markdown("---")
    gr.Markdown(
        "**Or try an example clip** "
        "<small style='color:var(--echo-subtle)'>β€” OOD = out-of-distribution (different scanner / view not seen during training)</small>"
    )
    gr.Markdown(
        "> ⚠️ **Example videos are provided for demonstration purposes only. "
        "They should not be downloaded, reproduced, or used for any purpose outside this demo.**"
    )
    # One thumbnail + button per example clip, laid out in a single row.
    ex_btns = []
    with gr.Row(equal_height=True):
        for label, (path, is_ood) in EXAMPLE_VIDEOS.items():
            with gr.Column(min_width=120):
                gr.Image(
                    value=THUMBNAILS[label],
                    show_label=False,
                    interactive=False,
                    height=110,
                    container=False,
                )
                btn_label = f"{label} πŸ”Ά" if is_ood else label
                ex_btns.append(gr.Button(btn_label, size="sm"))
    # ── Like request ──────────────────────────────────────────────────────────
    gr.Markdown(
        "<div style='text-align:center; padding: 8px 0;'>"
        "If you find this demo useful, please click the ❀️ <b>Like</b> button at the top of this Space β€” "
        "it helps others discover this work and supports open research in cardiac image analysis."
        "</div>"
    )
    # ── Citation ──────────────────────────────────────────────────────────────
    with gr.Accordion("πŸ“ Citation", open=False):
        gr.Markdown(CITATION_MD)
    # ── Wire events ───────────────────────────────────────────────────────────
    video_upload.upload(
        fn=on_video_upload,
        inputs=[video_upload],
        outputs=[frame_display, frame_slider, status_text, video_state, download_ctrl]
    )
    frame_slider.release(
        fn=on_slider_release,
        inputs=[frame_slider, video_state, points_display],
        outputs=[frame_display, status_text, points_display]
    )
    frame_display.select(
        fn=on_point_select,
        inputs=[frame_slider, video_state],
        outputs=[frame_display, status_text, points_display]
    )
    undo_btn.click(
        fn=on_undo_point,
        inputs=[frame_slider, video_state],
        outputs=[frame_display, status_text, points_display]
    )
    clear_btn.click(
        fn=on_clear_points,
        inputs=[frame_slider, video_state],
        outputs=[frame_display, status_text, points_display]
    )
    # gr.State(path) captures each clip's path per button (avoids late binding).
    for btn, (path, _) in zip(ex_btns, EXAMPLE_VIDEOS.values()):
        btn.click(
            fn=load_example,
            inputs=gr.State(path),
            outputs=[video_upload, frame_display, frame_slider, status_text, video_state,
                     out_video, points_display, download_ctrl]
        )
    run_btn.click(
        fn=track,
        inputs=[video_state, frame_slider, out_video],
        outputs=[out_video, status_text]
    )
demo.launch(share=False)