import json
import os
import random
import re
import sys
import tempfile
import time
from datetime import datetime
from glob import glob
from pathlib import Path
from typing import List, Optional
from uuid import uuid4
from zipfile import ZipFile

import gradio as gr
import numpy as np
import pandas as pd
import requests
from datasets import load_dataset
from decord import VideoReader, cpu
from huggingface_hub import (
    CommitScheduler,
    HfApi,
    InferenceClient,
    login,
    snapshot_download,
)
from PIL import Image

# Module-level caches so the Reddit listings are fetched at most once per
# rate-limit window (see get_top_posts / get_latest_posts below).
cached_latest_posts_df = None
cached_top_posts = None
last_fetched = None
last_fetched_top = None


def get_reddit_id(url):
    """Extract the post ID from an r/GamePhysics URL, or pass a bare ID through."""
    # Matches either a full r/GamePhysics permalink or a bare post ID.
    pattern = r"https://www\.reddit\.com/r/GamePhysics/comments/([0-9a-zA-Z]+).*|([0-9a-zA-Z]+)"
    match = re.match(pattern, url)
    if match:
        # Prefer the ID captured from the permalink; fall back to the bare-ID group.
        post_id = match.group(1) or match.group(2)
        print(f"Valid GamePhysics post ID: {post_id}")
    else:
        # Nothing matched; assume the caller passed a usable ID as-is.
        post_id = url
    return post_id
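
# Illustrative behaviour (the ID below is made up): a full permalink and a bare
# ID both resolve to the same post ID.
#   get_reddit_id("https://www.reddit.com/r/GamePhysics/comments/abc123/some_title/")  # -> "abc123"
#   get_reddit_id("abc123")  # -> "abc123"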


def download_samples(url, video_url, num_frames):
    frames = extract_frames_decord(video_url, num_frames)
    # Save each frame as a JPEG image in a temporary directory.
    with tempfile.TemporaryDirectory() as temp_dir:
        for i, frame in enumerate(frames):
            frame_path = os.path.join(temp_dir, f"frame_{i}.jpg")
            frame.save(frame_path, format="JPEG", quality=85)  # adjust quality as needed
        # Create the zip file in a persistent location so it outlives temp_dir.
        post_id = get_reddit_id(url)
        print(f"Creating zip file for post {post_id}")
        zip_path = f"frames-{post_id}.zip"
        with ZipFile(zip_path, "w") as zipf:
            for i in range(num_frames):
                frame_path = os.path.join(temp_dir, f"frame_{i}.jpg")
                zipf.write(frame_path, os.path.basename(frame_path))
    # Return the path of the zip file for the gr.File output.
    return zip_path
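
# Rough usage outside the UI (the post ID is a placeholder): resolve the video
# URL first, then zip the sampled frames.
#   video_url = load_video("abc123")
#   zip_path = download_samples("abc123", video_url, num_frames=10)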


def extract_frames_decord(video_path, num_frames=10):
    try:
        start_time = time.time()
        print(f"Extracting {num_frames} frames from {video_path}")
        # Load the video
        vr = VideoReader(video_path, ctx=cpu(0))
        # Calculate evenly spaced indices for the frames to be extracted
        total_frames = len(vr)
        frame_indices = np.linspace(
            0, total_frames - 1, num_frames, dtype=int, endpoint=False
        )
        # Decode the selected frames in a single batch
        batch_frames = vr.get_batch(frame_indices).asnumpy()
        # Convert frames to PIL Images
        frame_images = [
            Image.fromarray(batch_frames[i]) for i in range(batch_frames.shape[0])
        ]
        end_time = time.time()
        print(f"Decord extraction took {end_time - start_time:.2f} seconds")
        return frame_images
    except Exception as e:
        # Chain the original error so the traceback stays useful.
        raise RuntimeError(f"Error extracting frames from video: {e}") from e
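
# For example, with total_frames=100 and num_frames=10 the linspace call above
# yields [0, 9, 19, 29, 39, 49, 59, 69, 79, 89]: evenly spaced indices that are
# always in range for the clip.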


def get_top_posts():
    global cached_top_posts
    global last_fetched_top
    # Rate-limit fetches to at most one request per 10 minutes.
    now_time = datetime.now()
    if (
        last_fetched_top is not None
        and (now_time - last_fetched_top).total_seconds() < 600
    ):
        print("Using cached data")
        return cached_top_posts
    last_fetched_top = now_time
    url = "https://www.reddit.com/r/GamePhysics/top/.json?t=month"
    headers = {"User-Agent": "Mozilla/5.0"}
    response = requests.get(url, headers=headers)
    if response.status_code != 200:
        return []
    data = response.json()
    # Extract the posts and build a [post_id, title] dataframe.
    posts = data["data"]["children"]
    examples = [[post["data"]["id"], post["data"]["title"]] for post in posts]
    examples = pd.DataFrame(examples, columns=["post_id", "title"])
    cached_top_posts = examples
    return examples
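
# Shape of the listing JSON consumed above (values are placeholders); only the
# "id" and "title" fields are read:
#   {"data": {"children": [{"data": {"id": "abc123", "title": "..."}}, ...]}}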


def get_latest_posts():
    global cached_latest_posts_df
    global last_fetched
    # Rate-limit fetches to at most one request per 10 minutes.
    now_time = datetime.now()
    if last_fetched is not None and (now_time - last_fetched).total_seconds() < 600:
        print("Using cached data")
        return cached_latest_posts_df
    last_fetched = now_time
    url = "https://www.reddit.com/r/GamePhysics/.json"
    headers = {"User-Agent": "Mozilla/5.0"}
    response = requests.get(url, headers=headers)
    if response.status_code != 200:
        return []
    data = response.json()
    # Extract the posts and build a [post_id, title] dataframe.
    posts = data["data"]["children"]
    examples = [[post["data"]["id"], post["data"]["title"]] for post in posts]
    examples = pd.DataFrame(examples, columns=["post_id", "title"])
    cached_latest_posts_df = examples
    return examples
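
# get_top_posts and get_latest_posts differ only in the URL and the cache slot,
# so a shared helper is a natural refactor. Sketch only (hypothetical name, not
# wired into the UI below):
#
#   def fetch_listing(url):
#       response = requests.get(url, headers={"User-Agent": "Mozilla/5.0"})
#       if response.status_code != 200:
#           return []
#       posts = response.json()["data"]["children"]
#       return pd.DataFrame(
#           [[p["data"]["id"], p["data"]["title"]] for p in posts],
#           columns=["post_id", "title"],
#       )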


def row_selected(evt: gr.SelectData):
    global cached_latest_posts_df
    global cached_top_posts
    # Work out which cached dataframe the selected cell came from.
    string_value = evt.value
    row = evt.index[0]
    target_df = None
    if cached_latest_posts_df.isin([string_value]).any().any():
        target_df = cached_latest_posts_df
    elif cached_top_posts.isin([string_value]).any().any():
        target_df = cached_top_posts
    else:
        raise gr.Error("Could not find selected post in any dataframe")
    post_id = target_df.iloc[row]["post_id"]
    return post_id
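
# Note: for a Dataframe selection, gr.SelectData carries the clicked cell's
# value (evt.value) and its [row, column] position (evt.index); the value picks
# the cached table and the row index recovers that row's post_id.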


def load_video(url):
    post_id = get_reddit_id(url)
    video_url = f"https://huggingface.co/datasets/asgaardlab/GamePhysicsDailyDump/resolve/main/data/videos/{post_id}.mp4?download=true"
    # Confirm the file exists with a HEAD request before handing the URL to the player.
    r = requests.head(video_url)
    if r.status_code not in (200, 302):
        raise gr.Error(
            f"Video is not in the repo, please try another post. - {r.status_code}"
        )
    return video_url
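
# A 302 is accepted above because Hugging Face `resolve/` URLs for LFS-tracked
# files typically answer HEAD requests with a redirect to the CDN rather than
# a plain 200.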


with gr.Blocks() as demo:
    gr.Markdown("## Preview GamePhysics")
    dummy_title = gr.Textbox(visible=False)
    with gr.Row():
        with gr.Column():
            reddit_id = gr.Textbox(
                lines=1, placeholder="Post url or id here", label="URL or Post ID"
            )
            load_btn = gr.Button("Load")
            video_player = gr.Video(interactive=False)
        with gr.Column():
            gr.Markdown("## Latest Posts")
            latest_post_dataframe = gr.Dataframe()
            latest_posts_btn = gr.Button("Refresh Latest Posts")
            top_posts_btn = gr.Button("Refresh Top Posts")
        with gr.Column():
            gr.Markdown("## Sampled Frames from Video")
            with gr.Row():
                num_frames = gr.Slider(
                    minimum=1, maximum=60, step=1, value=10, label="Number of frames"
                )
                sample_decord_btn = gr.Button("Sample decord")
            sampled_frames = gr.Gallery()
            download_samples_btn = gr.Button("Download Samples")
            output_files = gr.File()

    download_samples_btn.click(
        download_samples,
        inputs=[reddit_id, video_player, num_frames],
        outputs=[output_files],
    )
    sample_decord_btn.click(
        extract_frames_decord,
        inputs=[video_player, num_frames],
        outputs=[sampled_frames],
    )
    load_btn.click(load_video, inputs=[reddit_id], outputs=[video_player])
    latest_posts_btn.click(get_latest_posts, outputs=[latest_post_dataframe])
    top_posts_btn.click(get_top_posts, outputs=[latest_post_dataframe])
    demo.load(get_latest_posts, outputs=[latest_post_dataframe])
    latest_post_dataframe.select(fn=row_selected, outputs=[reddit_id]).then(
        load_video, inputs=[reddit_id], outputs=[video_player]
    )

demo.launch()