Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import cv2 | |
| import math | |
| import time | |
| import torch | |
| import numpy as np | |
| import gradio as gr | |
| from tqdm import tqdm | |
| from pathlib import Path | |
| from collections import deque | |
| from argparse import Namespace | |
| from torchvision import transforms | |
| # === RAFT Setup === | |
| sys.path.append("/app/preprocess/RAFT/core") | |
| from raft import RAFT | |
| from utils.utils import InputPadder | |
# === CONFIG ===
# Torch device string: CUDA when torch reports it as available, otherwise CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# RAFT checkpoint loaded by load_raft_model().
MODEL_PATH = "/app/RAFT/raft-things.pth"
# Visualisation video written by run_tracking() (tracked points drawn on frames).
OUTPUT_VIDEO = "/app/full_tracked_output.mp4"
# Raw mask video written by run_tracking() (tracked pixels 0 on a 255 background).
OUTPUT_MASK_VIDEO = "/app/mask_output.mp4"
# Default output path of stabilize_black_regions().
STABILIZED_MASK = "/app/stabilized_mask_output.mp4"
# Frame-reversed copy of the uploaded video, produced by reverse_video().
REVERSED_INPUT = "/app/reversed_input.mp4"
| # ========================================================== | |
| # === VIDEO UTILITIES ===================================== | |
| # ========================================================== | |
def reverse_video(input_path, output_path):
    """
    Write a copy of a video with its frames in reverse order.

    Every frame OpenCV can decode is kept, so the output frame count does not
    depend on the container's (sometimes off-by-one) frame-count metadata.
    All frames are buffered in memory before writing.

    Args:
        input_path: Path of the video to read.
        output_path: Path of the reversed video to write (mp4v codec).

    Returns:
        ``output_path``.

    Raises:
        FileNotFoundError: If the input video cannot be opened.
        ValueError: If no frame could be decoded.
        IOError: If the output writer cannot be opened.
    """
    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        raise FileNotFoundError(f"β Could not open video: {input_path}")

    frames = []
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
    finally:
        # Release the decoder even if reading fails part-way through.
        cap.release()

    if not frames:
        raise ValueError("No frames read from video!")

    # Take the size from a decoded frame rather than container metadata:
    # VideoWriter silently drops frames whose size differs from the one it
    # was opened with.
    height, width = frames[0].shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    if not out.isOpened():
        raise IOError(f"Could not open video writer: {output_path}")
    try:
        for frame in reversed(frames):
            out.write(frame)
    finally:
        out.release()

    # No cv2.destroyAllWindows() here: this function never opens a GUI window.
    print(f"β Reversed {len(frames)} frames β {output_path}")
    return output_path
def reverse_video_file_inplace(path_in):
    """
    Reverse a video in place, keeping every decodable frame.

    The reversed video is first written to a sibling temporary file
    (``<name>_tmp<ext>``) and then moved over ``path_in``.

    Args:
        path_in: Path of the video to overwrite with its reversed version.

    Raises:
        Whatever ``reverse_video`` raises; the temporary file is removed in
        that case.
    """
    # Split on the real extension only; str.replace(".mp4", ...) would also
    # rewrite ".mp4" occurring in a directory name.
    root, ext = os.path.splitext(path_in)
    tmp_path = f"{root}_tmp{ext or '.mp4'}"
    try:
        reverse_video(path_in, tmp_path)
        os.replace(tmp_path, path_in)
    finally:
        # After a successful os.replace the temp file no longer exists; this
        # only cleans up a partial file left behind by a failure.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print(f"π Overwrote {path_in} with reversed version (same frame count).")
| # ========================================================== | |
| # === RAFT LOADING ========================================= | |
| # ========================================================== | |
def load_raft_model(model_path):
    """
    Build a RAFT network, load its weights and return it ready for inference.

    Args:
        model_path: Path of the checkpoint passed to ``torch.load``.

    Returns:
        The unwrapped RAFT module, moved to ``DEVICE`` and switched to eval mode.
    """
    raft_settings = {
        "small": False,
        "mixed_precision": False,
        "alternate_corr": False,
        "dropout": 0.0,
        "max_depth": 16,
        "depth_network": False,
        "depth_residual": False,
        "depth_scale": 1.0,
    }
    # The weights are loaded through a DataParallel wrapper and the inner
    # module is unwrapped afterwards (presumably the checkpoint was saved
    # from a DataParallel model - confirm against the checkpoint keys).
    wrapped = torch.nn.DataParallel(RAFT(Namespace(**raft_settings)))
    state_dict = torch.load(model_path, map_location=DEVICE)
    wrapped.load_state_dict(state_dict)
    network = wrapped.module.to(DEVICE)
    return network.eval()
def to_tensor(image):
    """Convert an image with torchvision's ToTensor, add a leading batch axis and move it to ``DEVICE``."""
    converter = transforms.ToTensor()
    batched = converter(image).unsqueeze(0)
    return batched.to(DEVICE)
def compute_flow(model, img1, img2):
    """
    Estimate optical flow from ``img1`` to ``img2`` with a RAFT model.

    Args:
        model: RAFT network as returned by ``load_raft_model``.
        img1: First image, in a format accepted by ``to_tensor``.
        img2: Second image, same size as ``img1``.

    Returns:
        NumPy array with the flow channels last (the first batch element,
        permuted from channel-first); indexed as ``flow[y, x] -> (dx, dy)``
        by the tracking loop.
    """
    t1, t2 = to_tensor(img1), to_tensor(img2)
    padder = InputPadder(t1.shape)
    t1, t2 = padder.pad(t1, t2)
    # Inference only: without no_grad() every call builds an autograd graph
    # (wasting memory per frame) and .numpy() below can fail on a tensor
    # that requires grad.
    with torch.no_grad():
        _, flow = model(t1, t2, iters=30, test_mode=True)
    flow = padder.unpad(flow)[0]
    return flow.permute(1, 2, 0).cpu().numpy()
| # ========================================================== | |
| # === FRAME / MASK HELPERS ================================ | |
| # ========================================================== | |
def extract_frame(video_path, frame_number):
    """
    Read a single frame from a video and return it in RGB order.

    Args:
        video_path: Path of the video to open.
        frame_number: Frame position to seek to before reading.

    Returns:
        The frame converted from BGR to RGB, or ``None`` if the read failed.
    """
    capture = cv2.VideoCapture(video_path)
    capture.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
    ok, bgr = capture.read()
    capture.release()
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) if ok else None
def save_mask(data):
    """
    Binarise a painted mask and save it as ``user_mask.png``.

    Args:
        data: Either the mask array itself or a dict holding it under the
            ``"mask"`` key. The array may be 2-D, or 3-D with 1, 3 or 4
            channels.

    Returns:
        ``(mask_path, message)``; ``mask_path`` is ``None`` when no mask
        was supplied.
    """
    if data is None:
        return None, "β οΈ No mask data received!"

    mask = data.get("mask") if isinstance(data, dict) else data
    if mask is None:
        return None, "β οΈ Mask missing!"

    mask = np.asarray(mask)
    if mask.ndim == 3:
        # Pick the conversion matching the channel count; RGBA2GRAY alone
        # fails on 3-channel or single-channel input.
        channels = mask.shape[2]
        if channels == 4:
            mask_gray = cv2.cvtColor(mask, cv2.COLOR_RGBA2GRAY)
        elif channels == 3:
            mask_gray = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
        else:
            mask_gray = np.ascontiguousarray(mask[..., 0])
    else:
        mask_gray = mask

    # Any value above 1 counts as painted.
    _, bin_mask = cv2.threshold(mask_gray, 1, 255, cv2.THRESH_BINARY)
    mask_path = "user_mask.png"
    cv2.imwrite(mask_path, bin_mask)
    return mask_path, f"β Saved mask ({np.count_nonzero(bin_mask)} painted pixels)"
| # ========================================================== | |
| # === UPDATED DYNAMIC CROP LOGIC =========================== | |
| # ========================================================== | |
def compute_crop_box_from_mask_dynamic(first_frame_bgr, mask_path, pad=200):
    """
    Compute a square crop box around the painted mask region plus padding.

    The box is always square and always lies inside the image. When the
    padded region touches an image edge the square is shifted back inside
    instead of being truncated; its side is capped at the shorter image
    dimension, so a region larger than that cannot be covered completely.

    Args:
        first_frame_bgr: Frame the mask refers to. When its size differs
            from the mask's, the mask is resized (nearest neighbour) to the
            frame so the box is expressed in frame coordinates. May be None.
        mask_path: Path of the grayscale mask image; non-zero pixels are
            the painted region.
        pad: Padding in pixels added on every side of the painted region.

    Returns:
        ``(x_min, y_min, x_max, y_max)`` as ints, with exclusive max bounds
        suitable for slicing.

    Raises:
        FileNotFoundError: If the mask image cannot be read.
    """
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise FileNotFoundError(f"Mask not found: {mask_path}")

    if first_frame_bgr is not None and first_frame_bgr.shape[:2] != mask.shape[:2]:
        frame_h, frame_w = first_frame_bgr.shape[:2]
        mask = cv2.resize(mask, (frame_w, frame_h), interpolation=cv2.INTER_NEAREST)

    H, W = mask.shape[:2]
    ys, xs = np.nonzero(mask)

    # Fallback to a centred crop if nothing was painted.
    if xs.size == 0:
        cx, cy = W // 2, H // 2
        size = min(W, H) // 2
        return cx - size // 2, cy - size // 2, cx + size // 2, cy + size // 2

    # Padded bounding box; "+ 1" turns the inclusive max index into an
    # exclusive bound so the last painted row/column stays inside.
    x_min = max(0, int(xs.min()) - pad)
    y_min = max(0, int(ys.min()) - pad)
    x_max = min(W, int(xs.max()) + 1 + pad)
    y_max = min(H, int(ys.max()) + 1 + pad)

    # Square side: the longer box edge, limited to what fits in the image.
    side = min(max(x_max - x_min, y_max - y_min), W, H)
    cx = (x_min + x_max) // 2
    cy = (y_min + y_max) // 2

    # Centre the square on the box, then shift it back inside the image so
    # clamping never shrinks one dimension.
    x0 = min(max(0, cx - side // 2), W - side)
    y0 = min(max(0, cy - side // 2), H - side)
    return int(x0), int(y0), int(x0 + side), int(y0 + side)
def draw_crop_preview_on_frame(frame_rgb, crop_box, color=(0, 255, 0), thickness=2):
    """
    Return a copy of the frame with the crop box outlined.

    Args:
        frame_rgb: Frame to annotate; it is copied, not modified.
        crop_box: ``(x0, y0, x1, y1)`` corners of the rectangle.
        color: Rectangle colour passed to ``cv2.rectangle``.
        thickness: Line thickness passed to ``cv2.rectangle``.
    """
    left, top, right, bottom = crop_box
    annotated = frame_rgb.copy()
    top_left = (left, top)
    bottom_right = (right, bottom)
    cv2.rectangle(annotated, top_left, bottom_right, color, thickness)
    return annotated
| # ========================================================== | |
| # === STABILIZATION ======================================== | |
| # ========================================================== | |
def stabilize_black_regions(input_video, output_path=STABILIZED_MASK, blend=0.3, sample_frames=10):
    """
    Clean up and temporally smooth a black-on-white mask video.

    Per frame: the dark region is closed/opened to bridge gaps, re-drawn at a
    uniform distance around the repaired region, cleaned with a fixed
    kernel, edge-reinforced, median-filtered and finally blended with the
    previous output frame and re-binarised.

    Args:
        input_video (str): Path to input mask video (black/white).
        output_path (str): Path to save stabilized video.
        blend (float): Weight of the previous output frame in the temporal
            blend (0.0-1.0).
        sample_frames (int): Number of initial frames sampled to estimate
            the region thickness that fixes the morphology parameters.

    Returns:
        ``output_path``.

    Raises:
        FileNotFoundError: If the input video cannot be opened.
    """
    cap = cv2.VideoCapture(input_video)
    if not cap.isOpened():
        raise FileNotFoundError(f"Could not open video: {input_video}")

    out = None
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

        # === Step 1: Estimate global morphology parameters from first N frames ===
        thickness_samples = []
        count = 0
        while count < sample_frames:
            ret, frame = cap.read()
            if not ret:
                break
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            # Inverted threshold: dark pixels become the 255 foreground.
            _, mask = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY_INV)
            foreground = mask > 0
            if np.any(foreground):
                dist = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
                thickness_samples.append(np.mean(dist[foreground]))
            count += 1
        cap.set(cv2.CAP_PROP_POS_FRAMES, 0)  # rewind

        avg_thickness = np.median(thickness_samples) if thickness_samples else 5
        k = int(np.clip(avg_thickness / 2.0, 3, 9))
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
        min_area = (width * height) * 0.0005
        print(f"π§ Fixed morphology parameters β kernel={k} | min_area={min_area:.1f}")

        # Constant structuring elements, built once instead of per frame.
        bridge_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
        fill_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
        fine_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))

        prev_mask = None

        # === Step 2: Process all frames ===
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            _, mask = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY_INV)

            # --- (A) Connectivity repair: bridge gaps & fill ---
            repaired = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, bridge_kernel, iterations=2)
            filled = cv2.morphologyEx(repaired, cv2.MORPH_CLOSE, fill_kernel, iterations=2)
            filled = cv2.morphologyEx(filled, cv2.MORPH_OPEN, fine_kernel, iterations=1)

            # --- (B) Edge thickness normalization ---
            # Distance of every pixel to the repaired region; pixels closer
            # than 1.2 x the sampled thickness are treated as region.
            dist = cv2.distanceTransform(cv2.bitwise_not(filled), cv2.DIST_L2, 3)
            normalized = (dist < avg_thickness * 1.2).astype(np.uint8) * 255
            # Back to the input polarity: region 0, everything else 255.
            base_clean = cv2.bitwise_not(normalized)

            # --- (C) Morphological cleanup (fixed parameters) ---
            base_clean = cv2.morphologyEx(base_clean, cv2.MORPH_CLOSE, kernel, iterations=2)
            num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
                base_clean, connectivity=8)
            # Every component is kept, so their union is base_clean itself;
            # components below min_area are additionally dilated. Dilation
            # distributes over union, so one dilation of the combined
            # small-component mask equals dilating each one separately.
            filtered_mask = base_clean.copy()
            is_small = stats[:, cv2.CC_STAT_AREA] < min_area
            is_small[0] = False  # label 0 is the background
            if is_small.any():
                small_mask = is_small[labels].astype(np.uint8) * 255
                merge_mask = cv2.dilate(small_mask, kernel, iterations=2)
                filtered_mask = cv2.bitwise_or(filtered_mask, merge_mask)

            # --- (D) Edge reinforcement ---
            edges = cv2.morphologyEx(filtered_mask, cv2.MORPH_GRADIENT, fine_kernel)
            reinforced = cv2.bitwise_or(filtered_mask, edges)
            reinforced = cv2.morphologyEx(reinforced, cv2.MORPH_CLOSE, kernel, iterations=2)
            reinforced = cv2.medianBlur(reinforced, 3)

            # --- (E) Temporal stabilization ---
            if prev_mask is not None:
                reinforced = cv2.addWeighted(reinforced, 1 - blend, prev_mask, blend, 0)
                reinforced = (reinforced > 127).astype(np.uint8) * 255  # re-binarize
            prev_mask = reinforced.copy()

            # `reinforced` already has the input polarity (see step B), so
            # it is written without a further inversion.
            out.write(cv2.cvtColor(reinforced, cv2.COLOR_GRAY2BGR))
    finally:
        # Release decoder and encoder even when a frame fails to process.
        cap.release()
        if out is not None:
            out.release()

    print(f"β Visually stable and connected mask saved: {output_path}")
    return output_path
| # ========================================================== | |
| # === TRACKING ============================================= | |
| # ========================================================== | |
def run_tracking(video_path, mask_path, selection_mode="All Pixels"):
    """
    Track the masked pixels through the reversed video with RAFT optical flow.

    The input is reversed, the masked pixels of its first frame are moved
    frame by frame along the flow computed inside a square crop, and two
    videos are written: a visualisation and a mask video (tracked pixels 0
    on a 255 background). The mask video is then stabilised and all three
    outputs are reversed back.

    Painting stops for good once none of the last ``HISTORY_LEN`` frames had
    a tracked point on a black pixel.

    Args:
        video_path: Path of the input video.
        mask_path: Path of the saved mask image.
        selection_mode: ``"All Pixels"`` tracks every masked pixel; any
            other value tracks only masked pixels that are black in the
            first frame.

    Returns:
        ``(message, OUTPUT_VIDEO, OUTPUT_MASK_VIDEO, STABILIZED_MASK)``, or
        ``(message, None, None, None)`` when the first frame cannot be read.
    """
    BLACK_THRESH = 1
    HISTORY_LEN = 5
    CROP_PAD = 200

    # --- Reverse input for backward tracking ---
    reversed_path = reverse_video(video_path, REVERSED_INPUT)
    cap = cv2.VideoCapture(reversed_path)
    out_vis = None
    out_mask = None
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        print(f"ποΈ Input video: {total_frames} frames at {fps:.2f} FPS")

        ret, first_frame = cap.read()
        if not ret:
            return "β Could not read first frame.", None, None, None
        H, W = first_frame.shape[:2]

        # --- Compute dynamic square crop from mask ---
        x0, y0, x1, y1 = compute_crop_box_from_mask_dynamic(first_frame, mask_path, pad=CROP_PAD)
        cw, ch = x1 - x0, y1 - y0

        # Load the network only after the inputs proved usable.
        model = load_raft_model(MODEL_PATH)

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out_vis = cv2.VideoWriter(OUTPUT_VIDEO, fourcc, fps, (W, H))
        out_mask = cv2.VideoWriter(OUTPUT_MASK_VIDEO, fourcc, fps, (W, H), isColor=False)

        full_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        full_mask = cv2.resize(full_mask, (W, H), interpolation=cv2.INTER_NEAREST)
        crop_mask = full_mask[y0:y1, x0:x1]

        if selection_mode == "All Pixels":
            ys, xs = np.where(crop_mask > 0)
        else:
            gray_first = cv2.cvtColor(first_frame, cv2.COLOR_BGR2GRAY)
            black_pixels = (gray_first[y0:y1, x0:x1] < BLACK_THRESH)
            combined = (crop_mask > 0) & black_pixels
            ys, xs = np.where(combined)

        # Points are (x, y) in crop coordinates.
        tracked_points = np.vstack((xs, ys)).T.astype(np.float32)
        prev_full_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
        prev_crop_rgb = prev_full_rgb[y0:y1, x0:x1]

        history = deque([True] * HISTORY_LEN, maxlen=HISTORY_LEN)
        stopped = False
        frame_idx = 0
        curr_full_rgb = None
        mask_full = None

        # === Main tracking loop ===
        while True:
            ret, curr_frame = cap.read()
            if not ret:
                break
            frame_idx += 1

            curr_full_rgb = cv2.cvtColor(curr_frame, cv2.COLOR_BGR2RGB)
            curr_crop_rgb = curr_full_rgb[y0:y1, x0:x1]
            gray_crop = cv2.cvtColor(curr_crop_rgb, cv2.COLOR_RGB2GRAY)

            # --- Optical flow between prev and curr ---
            flow_crop = compute_flow(model, prev_crop_rgb, curr_crop_rgb)

            vis_full = curr_full_rgb.copy()
            mask_full = np.full((H, W), 255, dtype=np.uint8)

            # --- Move tracked points ---
            # Look up the flow at each point's integer cell, drop points
            # outside the crop, and clamp the moved points to the crop.
            cells = tracked_points.astype(int)
            ix, iy = cells[:, 0], cells[:, 1]
            inside = (ix >= 0) & (ix < cw) & (iy >= 0) & (iy < ch)
            moved = tracked_points[inside] + flow_crop[iy[inside], ix[inside]]
            moved[:, 0] = np.clip(moved[:, 0], 0, cw - 1)
            moved[:, 1] = np.clip(moved[:, 1], 0, ch - 1)
            tracked_points = moved.astype(np.float32)

            # --- Detect black pixels ---
            black_mask = gray_crop < BLACK_THRESH
            cells = tracked_points.astype(int)
            ix, iy = cells[:, 0], cells[:, 1]
            inside = (ix >= 0) & (ix < cw) & (iy >= 0) & (iy < ch)
            has_black = bool(black_mask[iy[inside], ix[inside]].any())
            history.append(has_black)

            # --- Painting logic ---
            if stopped:
                paint = False
            elif has_black:
                paint = True
            elif not any(history):  # last N all False
                stopped = True
                paint = False
            else:
                paint = True

            # --- Paint or skip ---
            if paint:
                for pt in tracked_points:
                    fx, fy = int(pt[0] + x0), int(pt[1] + y0)
                    if 0 <= fx < W and 0 <= fy < H:
                        cv2.circle(vis_full, (fx, fy), 1, (0, 255, 0), -1)
                        mask_full[fy, fx] = 0

            # A frame is written every iteration, painted or not.
            out_vis.write(cv2.cvtColor(vis_full, cv2.COLOR_RGB2BGR))
            out_mask.write(mask_full)

            prev_crop_rgb = curr_crop_rgb
            if frame_idx % 10 == 0:
                print(f"Frame {frame_idx}: {'PAINT' if paint else 'NO-PAINT'} | has_black={has_black} | stopped={stopped}")

        # === Add final static frame to preserve frame count ===
        # The loop writes one frame per frame read after the first, so the
        # last frame is repeated once to make up for the unwritten first one.
        if curr_full_rgb is not None:
            try:
                out_vis.write(cv2.cvtColor(curr_full_rgb, cv2.COLOR_RGB2BGR))
                out_mask.write(mask_full)
                print("π§© Added final frame to preserve total frame count.")
            except cv2.error as e:
                print(f"β οΈ Could not add final frame: {e}")
    finally:
        # Also runs on the early return above and on any exception, so the
        # capture and writers are never leaked.
        cap.release()
        if out_vis is not None:
            out_vis.release()
        if out_mask is not None:
            out_mask.release()

    # === Post-process: stabilization + reversal ===
    stabilize_black_regions(OUTPUT_MASK_VIDEO)
    reverse_video_file_inplace(OUTPUT_VIDEO)
    reverse_video_file_inplace(OUTPUT_MASK_VIDEO)
    reverse_video_file_inplace(STABILIZED_MASK)

    # === Verify output frame counts ===
    for path in [OUTPUT_VIDEO, OUTPUT_MASK_VIDEO, STABILIZED_MASK]:
        cap_test = cv2.VideoCapture(path)
        n = int(cap_test.get(cv2.CAP_PROP_FRAME_COUNT))
        cap_test.release()
        print(f"β Verified {os.path.basename(path)} β {n} frames")

    return (
        f"β Tracking complete ({selection_mode}).\n"
        f"Square Crop {cw}x{ch} @ ({x0},{y0}) with padding={CROP_PAD}\n"
        f"Painting stopped={'Yes' if stopped else 'No'} after {frame_idx} processed frames.\n"
        f"All outputs now match input frame count ({total_frames}).",
        OUTPUT_VIDEO,
        OUTPUT_MASK_VIDEO,
        STABILIZED_MASK
    )
| # ========================================================== | |
| # === GRADIO APP =========================================== | |
| # ========================================================== | |
def build_app():
    """
    Build the Gradio Blocks UI for the pixel tracker and return it.

    Workflow wired up here: upload a video, load a frame of the reversed
    video for annotation, paint and save a mask, optionally preview the
    crop box, then run tracking.

    NOTE(review): the nesting of components inside the gr.Row() contexts
    was reconstructed from a copy of the source without indentation -
    confirm the layout against the running app.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# π― Pixel Tracker (Dynamic Square Crop)")
        with gr.Row():
            video_in = gr.Video(label="ποΈ Upload Video")
            # Hidden frame index passed to the frame loader.
            frame_num = gr.Number(value=0, visible=False)
        load_btn = gr.Button("πΈ Load Frame for Annotation")
        annot = gr.Image(label="ποΈ Paint ROI Mask", tool="sketch", type="numpy", image_mode="RGBA", height=1000)
        save_btn = gr.Button("πΎ Save Mask")
        log = gr.Textbox(label="Logs", lines=8)
        preview_btn = gr.Button("π Preview Crop", visible=True)
        with gr.Row():
            preview_frame = gr.Image(label="Preview Frame", visible=False)
            preview_crop = gr.Image(label="Cropped Region", visible=True)
        run_btn = gr.Button("π Run Tracking")
        with gr.Row():
            result_video = gr.Video(label="π¬ Result (Forward)")
            mask_video = gr.Video(label="β¬ Mask (Forward)")
            stabilized_video = gr.Video(label="π§± Stabilized (Forward)")

        def load_reversed_frame(v, f):
            # The mask is painted on a frame of the REVERSED video, which is
            # also what run_tracking() processes.
            reversed_path = reverse_video(v.name if hasattr(v, "name") else v, REVERSED_INPUT)
            return extract_frame(reversed_path, int(f))

        load_btn.click(load_reversed_frame, [video_in, frame_num], annot)
        # The returned mask path goes into a throwaway State; the handlers
        # below look for "user_mask.png" on disk instead.
        save_btn.click(save_mask, annot, [gr.State(), log])

        def preview_crop_fn(v):
            # Returns (frame with box drawn, cropped region, log message).
            reversed_path = reverse_video(v.name if hasattr(v, "name") else v, REVERSED_INPUT)
            frame0 = extract_frame(reversed_path, 0)
            if frame0 is None or not os.path.exists("user_mask.png"):
                return None, None, "β οΈ Paint and Save Mask first."
            x0,y0,x1,y1 = compute_crop_box_from_mask_dynamic(cv2.cvtColor(frame0, cv2.COLOR_RGB2BGR), "user_mask.png", pad=200)
            frame_box = draw_crop_preview_on_frame(frame0, (x0,y0,x1,y1))
            return frame_box, frame0[y0:y1, x0:x1], f"Square crop {x1-x0}x{y1-y0} at ({x0},{y0})"

        preview_btn.click(preview_crop_fn, video_in, [preview_frame, preview_crop, log])

        def run_btn_fn(v, m):
            # v: uploaded video (path or object with .name); m: selection mode.
            if not os.path.exists("user_mask.png"):
                return "β οΈ Save Mask first.", None, None, None
            return run_tracking(v.name if hasattr(v, "name") else v, "user_mask.png", m)

        # The selection-mode dropdown is created inline as the second input.
        run_btn.click(run_btn_fn, [video_in, gr.Dropdown(["All Pixels", "Only Black Pixels"], value="All Pixels")],
                      [log, result_video, mask_video, stabilized_video])
    return demo
if __name__ == "__main__":
    # Serve on all interfaces, port 7860, with Gradio debug mode on.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "debug": True,
    }
    app = build_app()
    app.launch(**launch_options)