"""Gradio demo for AutoGaze: efficient video understanding via autoregressive gazing."""

try:
    # On ZeroGPU Spaces, `spaces` must be imported before any CUDA-touching
    # modules.
    import spaces

    ZEROGPU_AVAILABLE = True
except ImportError:
    ZEROGPU_AVAILABLE = False
    print("Warning: spaces module not available. Running without ZeroGPU support.")

import gradio as gr
import tempfile
import os
import torch
import gc
from demo_utils import load_model, process_video, save_video, image_to_video
import av
from PIL import Image
import numpy as np
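

# One model instance is loaded per device and reused across requests.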
model_cache = {}


def get_model(device):
    if device not in model_cache:
        model_cache[device] = load_model(device=device)
    return model_cache[device]
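

# Pick the compute device once at import time. Under ZeroGPU a CUDA device is
# attached lazily, inside calls wrapped with spaces.GPU.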
if ZEROGPU_AVAILABLE:
    device = "cuda"
    print("Using ZeroGPU (CUDA device will be allocated on demand)")
elif torch.cuda.is_available():
    device = "cuda"
    print(f"Using CUDA GPU: {torch.cuda.get_device_name(0)}")
else:
    device = "cpu"
    print("No GPU available, using CPU")


def cleanup_gpu():
    """Clean up GPU memory."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()


def extract_metadata(file):
    """Read basic metadata from an uploaded video or image.

    Returns (info_text, path, total_frames, fps, width, height); images are
    first wrapped into a one-frame video.
    """
    if file is None:
        return "", None, None, None, None, None

    file_path = file.name if hasattr(file, "name") else str(file)
    file_extension = os.path.splitext(file_path)[1].lower()
    is_image = file_extension in [".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"]

    if is_image:
        # Convert the image to a one-frame video so the rest of the pipeline
        # only has to deal with video input.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_video:
            tmp_path = tmp_video.name

        metadata = image_to_video(file_path, tmp_path, fps=1.0)

        total_frames = metadata["frames"]
        fps = metadata["fps"]
        original_height = metadata["height"]
        original_width = metadata["width"]
        info_text = f"{original_width}×{original_height} | Image (1 frame)"
    else:
        tmp_path = file_path

        container = av.open(tmp_path)
        video_stream = container.streams.video[0]
        total_frames = video_stream.frames
        fps = float(video_stream.average_rate)
        original_height = video_stream.height
        original_width = video_stream.width
        container.close()
        info_text = f"{original_width}×{original_height} | {total_frames} frames @ {fps:.1f} FPS"

    return info_text, tmp_path, total_frames, fps, original_width, original_height


def handle_file_upload(file):
    metadata = extract_metadata(file)

    if metadata[1] is None:
        return "", None, None

    info_text, _, _, fps, _, _ = metadata
    return info_text, metadata, fps


def _process_video_impl(
    file_info, gazing_ratio, task_loss_requirement, output_fps, progress=None
):
    """Run the model on the uploaded file, yielding incremental UI updates."""
    if file_info is None:
        yield None, None, None, None, None, None, None, "No file uploaded"
        return

    _, tmp_path, total_frames, fps, _, _ = file_info

    if tmp_path is None:
        yield None, None, None, None, None, None, None, "Invalid file"
        return

    yield None, None, None, None, None, None, None, "Loading model..."

    if progress:
        progress(0.0, desc="Loading model...")
    setup = get_model(device)

    yield None, None, None, None, None, None, None, "Processing video..."

    if progress:
        progress(0.1, desc="Processing video...")

    status_messages = []

    def update_progress(pct, msg):
        if progress:
            progress(pct, desc=msg)
        status_messages.append(msg)
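
    # The slider expresses the gazing ratio relative to the base scale's 196
    # patches; the model counts all 265 patches across scales, hence the
    # 196/265 rescaling (and the 265/196 ≈ 1.35 slider maximum).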
    model_gazing_ratio = gazing_ratio * (196 / 265)

    for results in process_video(
        tmp_path,
        setup,
        gazing_ratio=model_gazing_ratio,
        task_loss_requirement=task_loss_requirement,
        progress_callback=update_progress,
        spatial_batch_size=2,
    ):
        if status_messages:
            yield None, None, None, None, None, None, None, status_messages[-1]

    yield None, None, None, None, None, None, None, "Saving output videos..."
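
    # The TemporaryDirectory is deleted when the with-block exits, so each
    # rendered video is read into memory and rewritten to a standalone
    # NamedTemporaryFile(delete=False) that Gradio can keep serving.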
    with tempfile.TemporaryDirectory() as tmpdir:
        original_path = os.path.join(tmpdir, "original.mp4")
        gazing_path = os.path.join(tmpdir, "gazing.mp4")
        recon_path = os.path.join(tmpdir, "reconstruction.mp4")
        scales_stitch_path = os.path.join(tmpdir, "scales_stitch.mp4")

        fps_to_use = output_fps if output_fps is not None else results["fps"]

        save_video(results["original_frames"], original_path, fps_to_use)
        save_video(results["gazing_frames"], gazing_path, fps_to_use)
        save_video(results["reconstruction_frames"], recon_path, fps_to_use)
        save_video(results["scales_stitch_frames"], scales_stitch_path, fps_to_use)

        with open(original_path, "rb") as f:
            original_data = f.read()
        with open(gazing_path, "rb") as f:
            gazing_data = f.read()
        with open(recon_path, "rb") as f:
            recon_data = f.read()
        with open(scales_stitch_path, "rb") as f:
            scales_stitch_data = f.read()

    original_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
    original_file.write(original_data)
    original_file.close()

    gazing_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
    gazing_file.write(gazing_data)
    gazing_file.close()

    recon_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
    recon_file.write(recon_data)
    recon_file.close()

    scales_stitch_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
    scales_stitch_file.write(scales_stitch_data)
    scales_stitch_file.close()

    gazing_pct_text = f"{results['gazing_pct']:.2%}"
    gazing_tokens_text = f"{results['total_gazing_tokens']:,}"
    total_tokens_text = f"{results['total_possible_tokens']:,}"

    yield (
        gazing_pct_text,
        gazing_tokens_text,
        total_tokens_text,
        original_file.name,
        gazing_file.name,
        recon_file.name,
        scales_stitch_file.name,
        "Processing complete!",
    )
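

# On ZeroGPU Spaces, wrap the handler so a GPU is attached for each call
# (here capped at 120 seconds); otherwise use the implementation directly.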
if ZEROGPU_AVAILABLE:
    process_video_ui = spaces.GPU(duration=120)(_process_video_impl)
else:
    process_video_ui = _process_video_impl


def _process_video_api_impl(
    file_path, gazing_ratio, task_loss_requirement, output_fps, progress=None
):
    """API-friendly endpoint that takes a file path string instead of gr.File.

    Returns the gazing video as a gr.File output for proper file serving.
    """
    if not file_path or not os.path.exists(file_path):
        raise gr.Error("file not found")

    metadata = extract_metadata(file_path)
    if metadata[1] is None:
        raise gr.Error("could not read file")

    _, tmp_path, total_frames, fps, _, _ = metadata

    yield gr.update()

    if progress:
        progress(0.0, desc="Loading model...")
    setup = get_model(device)

    yield gr.update()

    if progress:
        progress(0.1, desc="Processing video...")

    def update_progress(pct, msg):
        if progress:
            progress(pct, desc=msg)
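
    # Same UI-to-model gazing-ratio conversion as in _process_video_impl.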
    model_gazing_ratio = gazing_ratio * (196 / 265)

    for results in process_video(
        tmp_path,
        setup,
        gazing_ratio=model_gazing_ratio,
        task_loss_requirement=task_loss_requirement,
        progress_callback=update_progress,
        spatial_batch_size=2,
    ):
        yield gr.update()

    fps_to_use = output_fps if output_fps is not None else results["fps"]

    gazing_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
    save_video(results["gazing_frames"], gazing_file.name, fps_to_use)
    gazing_file.close()

    yield gr.File(value=gazing_file.name)


if ZEROGPU_AVAILABLE:
    process_video_api = spaces.GPU(duration=120)(_process_video_api_impl)
else:
    process_video_api = _process_video_api_impl


def extract_first_frame_thumbnail(
    video_path, output_path, size=(200, 200), force=False
):
    """Extract the first frame of a video and save it as a fixed-size square thumbnail."""
    if os.path.exists(output_path) and not force:
        return
    container = av.open(video_path)
    for frame in container.decode(video=0):
        img = frame.to_image()

        # Center-crop to a square, then resize, so every thumbnail has the
        # same aspect ratio regardless of the source video.
        width, height = img.size
        min_dim = min(width, height)
        left = (width - min_dim) // 2
        top = (height - min_dim) // 2
        img_cropped = img.crop((left, top, left + min_dim, top + min_dim))
        img_resized = img_cropped.resize(size, Image.LANCZOS)
        img_resized.save(output_path)
        break
    container.close()
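

# Pre-generate square thumbnails for the example gallery.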
example_videos = [
    "example_inputs/doorbell.mp4",
    "example_inputs/tomjerry.mp4",
    "example_inputs/security.mp4",
]

for video_path in example_videos:
    if os.path.exists(video_path):
        thumb_path = video_path.replace(".mp4", "_thumb.png")
        extract_first_frame_thumbnail(
            video_path, thumb_path, size=(100, 100), force=True
        )

doorbell_thumb_img = np.array(Image.open("example_inputs/doorbell_thumb.png"))
tomjerry_thumb_img = np.array(Image.open("example_inputs/tomjerry_thumb.png"))
security_thumb_img = np.array(Image.open("example_inputs/security_thumb.png"))


with gr.Blocks(title="AutoGaze Demo", delete_cache=(86400, 86400)) as demo:
    gr.Markdown("# AutoGaze Official Demo")
    gr.Markdown(
        "## **Attend Before Attention: Efficient and Scalable Video Understanding via Autoregressive Gazing**"
    )
    gr.Markdown("""
<div style="text-align: left; margin: 10px 0; font-size: 1.2em; font-weight: 600;">
📄 <a href="https://arxiv.org/abs/2603.12254" target="_blank" style="text-decoration: none; color: inherit;">Paper</a>    🌐 <a href="https://autogaze.github.io" target="_blank" style="text-decoration: none; color: inherit;">Project Website</a>
</div>
""")

    file_metadata = gr.State()

    with gr.Row():
        with gr.Column(scale=2):
            uploaded_file = gr.File(
                label="Upload Video or Image", file_types=["video", "image"]
            )
        with gr.Column(scale=1):
            file_info = gr.Textbox(label="File Info", interactive=False)
            process_button = gr.Button("Process Video", variant="primary")

    def load_example_video(evt: gr.SelectData):
        video_map = {
            0: "example_inputs/doorbell.mp4",
            1: "example_inputs/tomjerry.mp4",
            2: "example_inputs/security.mp4",
        }
        return video_map[evt.index]
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Example Videos - Click Thumbnail to Load")
            example_gallery = gr.Gallery(
                value=[
                    (doorbell_thumb_img, "doorbell.mp4"),
                    (tomjerry_thumb_img, "tomjerry.mp4"),
                    (security_thumb_img, "security.mp4"),
                ],
                label="",
                show_label=False,
                columns=3,
                rows=1,
                height=200,
                object_fit="contain",
                allow_preview=False,
            )
            gr.Markdown("### Settings")

            with gr.Accordion("Output Settings", open=True):
                fps_slider = gr.Number(
                    label="Output FPS",
                    value=None,
                    minimum=1,
                    maximum=120,
                    info="Frames per second for displaying output videos (only affects playback speed)",
                )

            with gr.Accordion("Model Parameters", open=True):
                gazing_ratio_slider = gr.Slider(
                    label="Gazing Ratio",
                    minimum=round(1 / 196, 2),
                    maximum=round(265 / 196, 2),
                    step=0.01,
                    value=0.75,
                    info="Max fraction of patches to gaze at per frame",
                )
                task_loss_slider = gr.Slider(
                    label="Task Loss Requirement",
                    minimum=0.0,
                    maximum=1.5,
                    step=0.05,
                    value=0.7,
                    info="Reconstruction loss threshold",
                )
            with gr.Accordion("FAQ", open=False):
                gr.Markdown("""
**What file formats are supported?**

The app supports common video formats (MP4, AVI, MOV, etc.) and image formats (JPG, PNG, etc.).

**What is the Gazing Ratio?**

The gazing ratio controls how many patches the model looks at per frame; higher values select more patches. The range extends past 1.0 because of multi-scale gazing: if all patches at all scales are selected, the ratio can reach 1.35.

**What is the Task Loss Requirement?**

This threshold determines when the model stops gazing at a frame, based on the predicted reconstruction loss from the patches gazed at so far. Lower = more gazing, higher = less gazing.

**How do Gazing Ratio and Task Loss interact?**

The two parameters independently bound the number of gazed patches, and the demo applies the stricter of the two. For example, if the gazing ratio allows gazing at 15% of patches but the task loss requirement is already met at 7%, only 7% of the patches are gazed at. To use only one of the two parameters, set the other to its maximum value.
""")

        with gr.Column(scale=2):
            gr.Markdown("### Results")

            status_text = gr.Markdown("Ready")

            with gr.Row():
                gazing_pct = gr.Textbox(label="Gazing %", interactive=False)
                gazing_tokens = gr.Textbox(label="# Gazed Patches", interactive=False)
                total_tokens = gr.Textbox(label="Total Patches", interactive=False)

            with gr.Row():
                original_video = gr.Video(label="Original", autoplay=False, loop=True)
                gazing_video = gr.Video(
                    label="Gazing Pattern (all scales)", autoplay=False, loop=True
                )
                reconstruction_video = gr.Video(
                    label="Reconstruction", autoplay=False, loop=True
                )

            with gr.Row():
                scales_stitch_video = gr.Video(
                    label="Gazing Pattern (individual scales)",
                    autoplay=False,
                    loop=True,
                )
    example_gallery.select(load_example_video, outputs=uploaded_file)
    uploaded_file.change(
        fn=handle_file_upload,
        inputs=[uploaded_file],
        outputs=[file_info, file_metadata, fps_slider],
    )

    # Kick off processing, then release cached GPU memory once it finishes.
    process_button.click(
        fn=process_video_ui,
        inputs=[file_metadata, gazing_ratio_slider, task_loss_slider, fps_slider],
        outputs=[
            gazing_pct,
            gazing_tokens,
            total_tokens,
            original_video,
            gazing_video,
            reconstruction_video,
            scales_stitch_video,
            status_text,
        ],
    ).then(fn=cleanup_gpu, inputs=None, outputs=None)
    # Hidden tab that exposes a path-based endpoint for API clients.
    with gr.Tab("API", visible=False):
        api_file_path = gr.Textbox(label="File Path")
        api_gazing_ratio = gr.Slider(
            minimum=round(1 / 196, 2),
            maximum=round(265 / 196, 2),
            step=0.01,
            value=0.75,
            label="Gazing Ratio",
        )
        api_task_loss = gr.Slider(
            minimum=0.0,
            maximum=1.5,
            step=0.05,
            value=0.7,
            label="Task Loss Requirement",
        )
        api_output_fps = gr.Number(label="Output FPS", value=None)
        api_button = gr.Button("Process (API)")
        api_result = gr.File(label="Result")
        api_button.click(
            fn=process_video_api,
            inputs=[api_file_path, api_gazing_ratio, api_task_loss, api_output_fps],
            outputs=[api_result],
            api_name="process_video_api",
        )
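
        # A sketch of calling this endpoint with gradio_client (the Space id
        # below is hypothetical):
        #
        #     from gradio_client import Client
        #     client = Client("user/autogaze-demo")
        #     out = client.predict(
        #         "/path/on/server.mp4",  # file path visible to the server
        #         0.75,                   # gazing ratio
        #         0.7,                    # task loss requirement
        #         None,                   # output FPS (None = input FPS)
        #         api_name="/process_video_api",
        #     )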

    demo.unload(cleanup_gpu)


# Start from a clean slate in case the process was reloaded with stale state.
print("Clearing model cache and GPU memory at startup...")
model_cache.clear()
cleanup_gpu()
print("Startup cleanup complete.")


if __name__ == "__main__":
    demo.launch(share=True, show_error=True)