import os
import random
import time
import datetime
import cv2
import numpy as np
import torch
import imageio
from pathlib import Path
from tqdm import tqdm
import gradio as gr

# Import custom modules
import utils.loss
import utils.samp
import utils.data
import utils.improc
import utils.misc
import utils.saveload
from nets.blocks import InputPadder
from nets.net34 import Net
from demo_dense_visualize import Tracker
import spaces

# Set torch matmul precision (as in the original code)
torch.set_float32_matmul_precision('medium')
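# Note: 'medium' permits lower-precision (bfloat16-backed) float32 matmuls on GPUs
# that support it ('high' would use TF32). With the model pinned to CPU below this
# setting should be inert; it is kept here to match the original setup.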
# -------------------- Utility Functions -------------------- #

def count_parameters(model):
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue
        total_params += parameter.numel()
    print('Total params: %.2f M' % (total_params / 1e6))
    return total_params


def seed_everything(seed: int):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(42)
torch.set_grad_enabled(False)
# -------------------- Model Loading -------------------- #

url = "https://huggingface.co/aharley/alltracker/resolve/main/alltracker.pth"
state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu')

model = Net(16)
count_parameters(model)
model.load_state_dict(state_dict['model'], strict=True)
print('loaded ckpt')

device = 'cpu:0'
model.to(device)
for n, p in model.named_parameters():
    p.requires_grad = False
model.eval()

tracker = Tracker(
    model=model,
    mean=torch.tensor([0.485, 0.456, 0.406]).to(device).reshape(1, 3, 1, 1),
    std=torch.tensor([0.229, 0.224, 0.225]).to(device).reshape(1, 3, 1, 1),
    S=16,
    stride=8,
    inference_iters=4,
    target_res=1024,
    device=device,
)
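# Notes on the Tracker configuration (a sketch of my reading; the exact semantics
# live in demo_dense_visualize.Tracker):
# - mean/std are the standard ImageNet normalization constants.
# - S is the sliding-window length (frames processed per chunk) and target_res is
#   the long-side resize; both are overwritten by the Gradio sliders via
#   extract_with_config()/process_with_config() below.
# - stride and inference_iters are passed through to the network unchanged here.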
# -------------------- Step 1: Extract the First Frame -------------------- #

def extract_first_frame(video_path, _):
    """
    Opens the video, extracts the first frame, and resizes it so its largest
    dimension equals tracker.target_res. Returns:
      - the frame for display (to be annotated),
      - the video file path (to store in state),
      - a copy of the resized first frame (used when adding points),
      - an empty list to reset the accumulated click points.
    """
    if not video_path:
        return None, None, None, []
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None, None, None, []
    ret, frame = cap.read()
    cap.release()
    if not ret:
        return None, video_path, None, []
    # Convert BGR to RGB.
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    # Resize so the largest dimension equals the target resolution.
    scale = min(tracker.target_res / frame_rgb.shape[0], tracker.target_res / frame_rgb.shape[1])
    frame_resized = cv2.resize(frame_rgb, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
    # Return the displayed frame, the video path, a copy for point drawing, and a reset points list.
    return frame_resized, video_path, frame_resized.copy(), []
# -------------------- Callback to Add a Clicked Point -------------------- #

def add_point(orig_frame, points, evt: gr.SelectData):
    """
    Called when the user clicks on the displayed first frame.
      - orig_frame: The original (resized) first frame image as a numpy array.
      - points: The current list of point coordinates.
      - evt: Event data from the image click (evt.index is (x, y)).
    Returns the updated image (with circles drawn at all points)
    and the updated list of points.
    """
    if orig_frame is None:
        return orig_frame, points
    if points is None:
        points = []
    # evt.index contains the (x, y) pixel coordinates of the click.
    x, y = evt.index
    new_points = points.copy()
    new_points.append([x, y])
    # Draw circles on a copy of the original image.
    updated_frame = orig_frame.copy()
    for (px, py) in new_points:
        cv2.circle(updated_frame, (int(round(px)), int(round(py))), radius=5, color=(0, 255, 0), thickness=-1)
    return updated_frame, new_points
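# Note: points_state accumulates clicks as [[x1, y1], [x2, y2], ...] in the
# coordinate space of the *resized* first frame. The flow maps below are sampled
# with these same coordinates, which assumes tracker.track() returns flow at that
# display resolution.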
# -------------------- Step 2: Process Video & Track Points -------------------- #

def process_video_with_points(video_path, click_points):
    """
    Runs the dense flow prediction over the entire video, tracking the user-selected points.

    Args:
        video_path: Path to the uploaded video.
        click_points: List of [x, y] coordinates selected on the first frame.
            (Coordinates are in the same (resized) space as the displayed first frame.)

    Returns:
        A path to the output video with the tracked points overlaid,
        or an error string if processing fails.
    """
    if click_points is None or len(click_points) == 0:
        print("No points selected for tracking.")
        return "Error: No points selected for tracking."
    # Open the video.
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return "Error: Could not open video."
    fps = cap.get(cv2.CAP_PROP_FPS)
    # List to store frames with overlaid points.
    output_frames = []
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    pbar = tqdm(total=total_frames, desc="Processing video")
    tracker.reset()
    frame_disps = []
    try:
        while True:
            if 'cuda' in device:
                torch.cuda.empty_cache()
            ret, frame = cap.read()
            if not ret:
                break
            # Convert the frame from BGR to RGB, and resize a display copy so its
            # largest dimension equals the target resolution.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            scale = min(tracker.target_res / frame_rgb.shape[0], tracker.target_res / frame_rgb.shape[1])
            frame_disp = cv2.resize(frame_rgb, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
            frame_disps.append(frame_disp)
            # The tracker buffers frames internally and returns a batch of flow maps
            # once a sliding window is ready (None otherwise).
            flows = tracker.track(frame_rgb)
            if flows is not None:
                flows_np = flows[0].cpu().numpy()
                for i, flow_np in enumerate(flows_np):
                    # --- Update tracked points using the flow --- #
                    # The flow is sampled at the original click coordinates, i.e. it is
                    # treated as a displacement from the first (query) frame to this
                    # frame, so points are not chained frame-to-frame.
                    current_points = []
                    for (x, y) in click_points:
                        xi = int(round(x))
                        yi = int(round(y))
                        if 0 <= yi < flow_np.shape[1] and 0 <= xi < flow_np.shape[2]:
                            dx = flow_np[0, yi, xi]
                            dy = flow_np[1, yi, xi]
                        else:
                            dx, dy = 0.0, 0.0
                        current_points.append([x + dx, y + dy])
                    # Draw the updated points on the corresponding display frame.
                    for (x, y) in current_points:
                        cv2.circle(frame_disps[i], (int(round(x)), int(round(y))), radius=5, color=(0, 255, 0), thickness=-1)
                    output_frames.append(frame_disps[i])
                # The display buffer is consumed once per returned window.
                frame_disps = []
            pbar.update(1)
    except RuntimeError as e:
        # Check if the error message indicates an OOM error.
        if "out of memory" in str(e).lower():
            if 'cuda' in device:
                torch.cuda.empty_cache()
            pbar.close()
            cap.release()
            print("Error: Out of Memory during video processing.")
            return "Error: Out of Memory during video processing. Please try a smaller video or lower resolution."
        else:
            # Re-raise if it is another type of error.
            raise
    pbar.close()
    cap.release()
    # -------------------- Save the Output Video -------------------- #
    if len(output_frames) == 0:
        return "Error: No frames were processed."
    output_path = "tracked_output.mp4"
    print(len(output_frames), output_frames[0].shape)
    imageio.mimwrite(output_path, output_frames, fps=fps)
    return output_path
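# Note: writing .mp4 with imageio goes through ffmpeg, so the imageio-ffmpeg backend
# needs to be installed; by default that writer may resize frames to a multiple of
# its macro block size (worth checking if the output resolution looks slightly off).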
# -------------------- Wrappers to Update Tracker Based on UI Settings -------------------- #

def extract_with_config(video_path, points, resolution, window_index):
    """
    Update the tracker configuration using the slider values, then extract the first frame.
      - resolution: Target resolution from the slider (e.g., 512, 768, or 1024).
      - window_index: An index (0-3) mapped to sliding window lengths {0: 2, 1: 4, 2: 8, 3: 16}.
    """
    tracker.target_res = resolution
    mapping = {0: 2, 1: 4, 2: 8, 3: 16}
    tracker.S = mapping.get(int(window_index), 16)
    return extract_first_frame(video_path, points)


def process_with_config(video_path, click_points, resolution, window_index):
    """
    Update the tracker configuration using the slider values, then process the video.
    """
    tracker.target_res = resolution
    mapping = {0: 2, 1: 4, 2: 8, 3: 16}
    tracker.S = mapping.get(int(window_index), 16)
    return process_video_with_points(video_path, click_points)
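# Note: these wrappers mutate the single module-level tracker in place. That is fine
# for a single-user demo, but concurrent Gradio sessions would share (and race on)
# the same tracker configuration; per-session state would be needed for multi-user
# deployments.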
if __name__ == '__main__':
    # -------------------- Gradio Interface -------------------- #
    # The interface is built in two steps:
    #   1. Upload a video and extract the first frame.
    #   2. Annotate the first frame with multiple points by clicking on it,
    #      then run tracking on the video.
    with gr.Blocks() as demo:
        gr.Markdown("## Dense Flow Tracking with Clickable Points")
        with gr.Row():
            with gr.Column():
                video_input = gr.Video(label="Upload Video", value="172620-847860540_small.mp4")
                extract_btn = gr.Button("Extract First Frame")
                # Sliders for resolution and sliding window length.
                resolution_slider = gr.Slider(minimum=512, maximum=1024, step=256, value=1024, label="Target Resolution")
                # The slider below outputs an index 0-3, mapped to window lengths {0: 2, 1: 4, 2: 8, 3: 16}.
                window_slider = gr.Slider(minimum=0, maximum=3, step=1, value=3, label="Sliding Window Length (Index: 0->2, 1->4, 2->8, 3->16)")
            with gr.Column():
                # This image displays the first frame and is interactive (clickable).
                first_frame_display = gr.Image(label="First Frame (Click to add points)", interactive=True)
                clear_pts_btn = gr.Button("Clear Points")

        # Hidden states: video file path, original first frame, and accumulated click points.
        video_state = gr.State(None)
        orig_frame_state = gr.State(None)
        points_state = gr.State([])

        track_btn = gr.Button("Track Points")
        output_video = gr.Video(label="Tracked Video")

        # When "Extract First Frame" is clicked, extract and display the first frame.
        extract_btn.click(
            fn=extract_with_config,
            inputs=[video_input, points_state, resolution_slider, window_slider],
            outputs=[first_frame_display, video_state, orig_frame_state, points_state]
        )
        # Reset the display to the clean first frame and clear the accumulated points.
        clear_pts_btn.click(
            fn=lambda orig_frame, _: (orig_frame, []),
            inputs=[orig_frame_state, points_state],
            outputs=[first_frame_display, points_state]
        )
        # When the user clicks on the image, add a point.
        first_frame_display.select(
            fn=add_point,
            inputs=[orig_frame_state, points_state],
            outputs=[first_frame_display, points_state]
        )
        # When "Track Points" is clicked, process the video using the accumulated points.
        track_btn.click(
            fn=process_with_config,
            inputs=[video_state, points_state, resolution_slider, window_slider],
            outputs=output_video
        )

    demo.launch()
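    # Deployment notes (assumptions, not requirements of the original code):
    # - Calling demo.queue() before launch() is usually worthwhile here, since
    #   tracking a whole video can run for a long time per request.
    # - The unused `import spaces` suggests a Hugging Face ZeroGPU Space was intended;
    #   a @spaces.GPU-decorated function would be needed to actually get a GPU there,
    #   and `device` would then move off 'cpu:0'.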