import concurrent.futures
import os
import re
import tempfile
import time
from datetime import datetime
from zipfile import ZipFile

import gradio as gr
import numpy as np
import pandas as pd
import requests
from decord import VideoReader, cpu
from PIL import Image

# In-memory caches for the Reddit listings, refreshed at most every 10 minutes.
cached_latest_posts_df = None
cached_top_posts = None
last_fetched = None
last_fetched_top = None


def resize_image(image):
    # Downscale to 35% of the original size to save bandwidth in the preview.
    width, height = image.size
    new_width = int(width * 0.35)
    new_height = int(height * 0.35)
    return image.resize((new_width, new_height), Image.BILINEAR)


def get_reddit_id(url):
    # Regular expression pattern for r/GamePhysics URLs and bare post IDs
    pattern = r"https://www\.reddit\.com/r/GamePhysics/comments/([0-9a-zA-Z]+).*|([0-9a-zA-Z]+)"

    # Match the URL or ID against the pattern
    match = re.match(pattern, url)

    if match:
        # Extract the post ID from whichever alternative matched
        post_id = match.group(1) or match.group(2)
        print(f"Valid GamePhysics post ID: {post_id}")
    else:
        post_id = url

    return post_id


def download_samples(url, video_url, num_frames):
    frames = extract_frames_decord(video_url, num_frames)

    # Save each frame as a JPEG in a temporary directory
    with tempfile.TemporaryDirectory() as temp_dir:
        for i, frame in enumerate(frames):
            frame_path = os.path.join(temp_dir, f"frame_{i}.jpg")
            frame.save(frame_path, format="JPEG", quality=85)  # Adjust quality as needed

        # Create the zip file in a persistent location so Gradio can serve it
        post_id = get_reddit_id(url)
        print(f"Creating zip file for post {post_id}")
        zip_path = f"frames-{post_id}.zip"
        with ZipFile(zip_path, "w") as zipf:
            for i in range(num_frames):
                frame_path = os.path.join(temp_dir, f"frame_{i}.jpg")
                zipf.write(frame_path, os.path.basename(frame_path))

    # Return the path of the zip file
    return zip_path


def extract_frames_decord(video_path, num_frames=10):
    try:
        start_time = time.time()
        print(f"Extracting {num_frames} frames from {video_path}")

        # Load the video
        vr = VideoReader(video_path, ctx=cpu(0))

        # Evenly spaced indices for the frames to be extracted
        total_frames = len(vr)
        frame_indices = np.linspace(
            0, total_frames - 1, num_frames, dtype=int, endpoint=False
        )

        # Extract the frames in one batch, then convert them to PIL Images
        batch_frames = vr.get_batch(frame_indices).asnumpy()
        frame_images = [
            Image.fromarray(batch_frames[i]) for i in range(batch_frames.shape[0])
        ]

        end_time = time.time()
        print(f"Decord extraction took {end_time - start_time:.2f} seconds")
        return frame_images
    except Exception as e:
        raise RuntimeError(f"Error extracting frames from video: {e}") from e


def extract_frames_decord_preview(video_path, num_frames=10):
    # Same extraction as above, but resized down for the gallery preview.
    frame_images = extract_frames_decord(video_path, num_frames)
    with concurrent.futures.ThreadPoolExecutor() as executor:
        frame_images = list(executor.map(resize_image, frame_images))
    return frame_images


def get_top_posts():
    global cached_top_posts
    global last_fetched_top

    # Don't fetch too often: at most one request per 10 minutes.
    now_time = datetime.now()
    if (
        last_fetched_top is not None
        and (now_time - last_fetched_top).total_seconds() < 600
    ):
        print("Using cached data")
        return cached_top_posts
    last_fetched_top = now_time

    url = "https://www.reddit.com/r/GamePhysics/top/.json?t=month"
    headers = {"User-Agent": "Mozilla/5.0"}
    response = requests.get(url, headers=headers, timeout=10)

    if response.status_code != 200:
        return pd.DataFrame(columns=["post_id", "title"])

    data = response.json()

    # Build a [post_id, title] dataframe from the listing
    posts = data["data"]["children"]
    examples = [[post["data"]["id"], post["data"]["title"]] for post in posts]
    examples = pd.DataFrame(examples, columns=["post_id", "title"])

    cached_top_posts = examples
    return examples


def get_latest_posts():
    global cached_latest_posts_df
    global last_fetched

    # Don't fetch too often: at most one request per 10 minutes.
    now_time = datetime.now()
    if last_fetched is not None and (now_time - last_fetched).total_seconds() < 600:
        print("Using cached data")
        return cached_latest_posts_df
    last_fetched = now_time

    url = "https://www.reddit.com/r/GamePhysics/.json"
    headers = {"User-Agent": "Mozilla/5.0"}
    response = requests.get(url, headers=headers, timeout=10)

    if response.status_code != 200:
        return pd.DataFrame(columns=["post_id", "title"])

    data = response.json()

    # Build a [post_id, title] dataframe from the listing
    posts = data["data"]["children"]
    examples = [[post["data"]["id"], post["data"]["title"]] for post in posts]
    examples = pd.DataFrame(examples, columns=["post_id", "title"])

    cached_latest_posts_df = examples
    return examples


def row_selected_top(evt: gr.SelectData):
    # Map the selected dataframe row back to its post ID.
    row = evt.index[0]
    return cached_top_posts.iloc[row]["post_id"]


def row_selected_latest(evt: gr.SelectData):
    # Map the selected dataframe row back to its post ID.
    row = evt.index[0]
    return cached_latest_posts_df.iloc[row]["post_id"]


def load_video(url):
    post_id = get_reddit_id(url)

    video_url = f"https://huggingface.co/datasets/asgaardlab/GamePhysicsDailyDump/resolve/main/data/videos/{post_id}.mp4?download=true"
    video_url2 = f"https://huggingface.co/datasets/asgaardlab/GamePhysics-FullResolution/resolve/main/videos/{post_id}/{post_id}.mp4?download=true"

    # Make sure the file exists in one of the repos before returning;
    # a HEAD request checks this without downloading the file.
    r1 = requests.head(video_url, timeout=10)
    r2 = requests.head(video_url2, timeout=10)

    if r1.status_code not in (200, 302) and r2.status_code not in (200, 302):
        raise gr.Error(
            f"Video is not in the repo, please try another post (status {r1.status_code})."
        )
    if r1.status_code in (200, 302):
        return video_url
    return video_url2


with gr.Blocks() as demo:
    gr.Markdown("## Preview GamePhysics")

    with gr.Row():
        with gr.Column():
            reddit_id = gr.Textbox(
                lines=1, placeholder="Post url or id here", label="URL or Post ID"
            )
            load_btn = gr.Button("Load")
            video_player = gr.Video(interactive=False)

        with gr.Column():
            gr.Markdown("## Sample frames")
            num_frames = gr.Slider(
                minimum=1, maximum=60, step=1, value=10, label="Number of frames"
            )
            sample_decord_btn = gr.Button("Sample frames")
            sampled_frames = gr.Gallery(label="Sampled frames preview")

            download_samples_btn = gr.Button("Download Samples")
            output_files = gr.File()

            download_samples_btn.click(
                download_samples,
                inputs=[reddit_id, video_player, num_frames],
                outputs=[output_files],
            )

        with gr.Column():
            gr.Markdown("## Reddit Posts")
            with gr.Tab("Latest Posts"):
                latest_post_dataframe = gr.Dataframe()
                latest_posts_btn = gr.Button("Refresh Latest Posts")
            with gr.Tab("Top Monthly Posts"):
                top_posts_dataframe = gr.Dataframe()
                top_posts_btn = gr.Button("Refresh Top Posts")

    sample_decord_btn.click(
        extract_frames_decord_preview,
        inputs=[video_player, num_frames],
        outputs=[sampled_frames],
    )
    load_btn.click(load_video, inputs=[reddit_id], outputs=[video_player])
    latest_posts_btn.click(get_latest_posts, outputs=[latest_post_dataframe])
    top_posts_btn.click(get_top_posts, outputs=[top_posts_dataframe])

    # Populate both post listings on startup
    demo.load(get_latest_posts, outputs=[latest_post_dataframe])
    demo.load(get_top_posts, outputs=[top_posts_dataframe])

    # Selecting a row fills in its post ID, then loads the corresponding video
    latest_post_dataframe.select(fn=row_selected_latest, outputs=[reddit_id]).then(
        load_video, inputs=[reddit_id], outputs=[video_player]
    )
    top_posts_dataframe.select(fn=row_selected_top, outputs=[reddit_id]).then(
        load_video, inputs=[reddit_id], outputs=[video_player]
    )

demo.launch()