import json
import os
import random
import re
import sys
import tempfile
import time
from datetime import datetime
from glob import glob
from pathlib import Path
from typing import List, Optional
from uuid import uuid4
from zipfile import ZipFile

import gradio as gr
import numpy as np
import pandas as pd
import requests
from datasets import load_dataset
from decord import VideoReader, cpu
from huggingface_hub import (
    CommitScheduler,
    HfApi,
    InferenceClient,
    login,
    snapshot_download,
)
from PIL import Image

# Module-level cache for the latest r/GamePhysics listing (10-minute TTL).
cached_latest_posts_df = None
last_fetched = None


def download_samples(video_url, num_frames):
    """Extract ``num_frames`` frames from ``video_url`` and bundle them as a zip.

    Frames are saved as JPEGs in a temporary directory and zipped while that
    directory still exists. Returns the path of the created zip file
    ("frames.zip" in the working directory) for gradio to serve as a download.
    """
    frames = extract_frames_decord(video_url, num_frames)
    zip_path = "frames.zip"
    with tempfile.TemporaryDirectory() as temp_dir:
        # Save each frame, remembering the actual paths written so the zip
        # step cannot reference a frame that was never produced.
        frame_paths = []
        for i, frame in enumerate(frames):
            frame_path = os.path.join(temp_dir, f"frame_{i}.jpg")
            frame.save(frame_path, format="JPEG", quality=85)  # adjust quality as needed
            frame_paths.append(frame_path)
        # Zip inside the `with` block: the temp dir is deleted on exit.
        with ZipFile(zip_path, "w") as zipf:
            for frame_path in frame_paths:
                zipf.write(frame_path, os.path.basename(frame_path))
    return zip_path


def extract_frames_decord(video_path, num_frames=10):
    """Decode ``num_frames`` evenly spaced frames from ``video_path``.

    Returns a list of PIL.Image objects. Raises Exception (with the original
    error chained) when the video cannot be opened or decoded.
    """
    try:
        start_time = time.time()
        print(f"Extracting {num_frames} frames from {video_path}")
        vr = VideoReader(video_path, ctx=cpu(0))
        total_frames = len(vr)
        # endpoint=False keeps the original sampling: evenly spaced offsets
        # starting at frame 0, never including the final frame index.
        frame_indices = np.linspace(
            0, total_frames - 1, num_frames, dtype=int, endpoint=False
        )
        batch_frames = vr.get_batch(frame_indices).asnumpy()
        frame_images = [Image.fromarray(frame) for frame in batch_frames]
        end_time = time.time()
        print(f"Decord extraction took {end_time - start_time} seconds")
        return frame_images
    except Exception as e:
        # Chain the cause so the root decord error stays visible in tracebacks.
        raise Exception(f"Error extracting frames from video: {e}") from e


def get_latest_pots():
    """Fetch the newest r/GamePhysics posts, cached for 10 minutes.

    NOTE: the name keeps the original "pots" typo so existing references
    stay valid. Returns a DataFrame with columns [post_id, title], or []
    when the request fails (gradio renders an empty table either way).
    """
    global cached_latest_posts_df
    global last_fetched

    # Rate-limit to one request per 10 minutes. Use total_seconds():
    # timedelta.seconds alone wraps at 24 hours and would defeat the TTL.
    now_time = datetime.now()
    if last_fetched is not None and (now_time - last_fetched).total_seconds() < 600:
        print("Using cached data")
        return cached_latest_posts_df
    last_fetched = now_time

    url = "https://www.reddit.com/r/GamePhysics/.json"
    headers = {"User-Agent": "Mozilla/5.0"}
    response = requests.get(url, headers=headers, timeout=30)
    if response.status_code != 200:
        return []
    data = response.json()

    # Extract [post_id, title] pairs from the listing payload.
    posts = data["data"]["children"]
    examples = [[post["data"]["id"], post["data"]["title"]] for post in posts]
    examples = pd.DataFrame(examples, columns=["post_id", "title"])
    cached_latest_posts_df = examples
    return examples


def row_selected(evt: gr.SelectData):
    """Map a click on the latest-posts table to that row's post id."""
    global cached_latest_posts_df
    row = evt.index[0]
    post_id = cached_latest_posts_df.iloc[row]["post_id"]
    return post_id


def load_video(url):
    """Resolve a r/GamePhysics post URL or bare post id to its dataset video URL.

    Raises gr.Error when the video is not present in the dataset repo.
    """
    # Regular expression pattern for r/GamePhysics URLs and IDs.
    pattern = r"https://www\.reddit\.com/r/GamePhysics/comments/([0-9a-zA-Z]+).*|([0-9a-zA-Z]+)"
    match = re.match(pattern, url)
    if match:
        # Extract the post ID from the URL (group 1) or use the bare id (group 2).
        post_id = match.group(1) or match.group(2)
        print(f"Valid GamePhysics post ID: {post_id}")
    else:
        post_id = url

    video_url = f"https://huggingface.co/datasets/asgaardlab/GamePhysicsDailyDump/resolve/main/data/videos/{post_id}.mp4?download=true"

    # Verify the file exists via HEAD so nothing is downloaded here.
    r = requests.head(video_url, timeout=30)
    if r.status_code not in (200, 302):
        raise gr.Error(
            f"Video is not in the repo, please try another post. - {r.status_code }"
        )
    return video_url


with gr.Blocks() as demo:
    gr.Markdown("## Preview GamePhysics")
    dummt_title = gr.Textbox(visible=False)  # hidden placeholder, currently unused
    with gr.Row():
        with gr.Column():
            reddit_id = gr.Textbox(
                lines=1, placeholder="Post url or id here", label="URL or Post ID"
            )
            load_btn = gr.Button("Load")
            video_player = gr.Video(interactive=False)
        with gr.Column():
            gr.Markdown("## Latest Posts")
            latest_post_dataframe = gr.Dataframe()
            get_latest_pots_btn = gr.Button("Refresh Latest Posts")
        with gr.Column():
            gr.Markdown("## Sampled Frames from Video")
            with gr.Row():
                num_frames = gr.Slider(minimum=1, maximum=20, step=1, value=10)
                sample_decord_btn = gr.Button("Sample decord")
            sampled_frames = gr.Gallery()
            download_samples_btn = gr.Button("Download Samples")
            output_files = gr.File()

    # Wire UI events to the handlers above.
    download_samples_btn.click(
        download_samples, inputs=[video_player, num_frames], outputs=[output_files]
    )
    sample_decord_btn.click(
        extract_frames_decord,
        inputs=[video_player, num_frames],
        outputs=[sampled_frames],
    )
    load_btn.click(load_video, inputs=[reddit_id], outputs=[video_player])
    get_latest_pots_btn.click(get_latest_pots, outputs=[latest_post_dataframe])
    demo.load(get_latest_pots, outputs=[latest_post_dataframe])
    # Selecting a row fills the id textbox, then immediately loads its video.
    latest_post_dataframe.select(fn=row_selected, outputs=[reddit_id]).then(
        load_video, inputs=[reddit_id], outputs=[video_player]
    )

demo.launch()