taesiri commited on
Commit
0edcc33
1 Parent(s): 30bb4fd
Files changed (1) hide show
  1. app.py +86 -22
app.py CHANGED
@@ -25,7 +25,9 @@ from huggingface_hub import (
25
  from PIL import Image
26
 
27
  cached_latest_posts_df = None
 
28
  last_fetched = None
 
29
 
30
  import os
31
  import tempfile
@@ -37,7 +39,24 @@ from decord import VideoReader
37
  from decord import cpu
38
 
39
 
40
- def download_samples(video_url, num_frames):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  frames = extract_frames_decord(video_url, num_frames)
42
 
43
  # Create a temporary directory to store the images
@@ -50,7 +69,9 @@ def download_samples(video_url, num_frames):
50
  ) # Adjust quality as needed
51
 
52
  # Create a zip file in a persistent location
53
- zip_path = "frames.zip"
 
 
54
  with ZipFile(zip_path, "w") as zipf:
55
  for i in range(num_frames):
56
  frame_path = os.path.join(temp_dir, f"frame_{i}.jpg")
@@ -91,7 +112,43 @@ def extract_frames_decord(video_path, num_frames=10):
91
  raise Exception(f"Error extracting frames from video: {e}")
92
 
93
 
94
- def get_latest_pots():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  global cached_latest_posts_df
96
  global last_fetched
97
 
@@ -129,25 +186,26 @@ def get_latest_pots():
129
 
130
  def row_selected(evt: gr.SelectData):
131
  global cached_latest_posts_df
132
- row = evt.index[0]
133
- post_id = cached_latest_posts_df.iloc[row]["post_id"]
134
- return post_id
135
 
 
 
 
 
136
 
137
- def load_video(url):
138
- # Regular expression pattern for r/GamePhysics URLs and IDs
139
- pattern = r"https://www\.reddit\.com/r/GamePhysics/comments/([0-9a-zA-Z]+).*|([0-9a-zA-Z]+)"
 
 
 
140
 
141
- # Match the URL or ID against the pattern
142
- match = re.match(pattern, url)
143
 
144
- if match:
145
- # Extract the post ID from the URL
146
- post_id = match.group(1) or match.group(2)
147
- print(f"Valid GamePhysics post ID: {post_id}")
148
- else:
149
- post_id = url
150
 
 
 
151
  video_url = f"https://huggingface.co/datasets/asgaardlab/GamePhysicsDailyDump/resolve/main/data/videos/{post_id}.mp4?download=true"
152
 
153
  # make sure file exists before returning, make a request without downloading the file
@@ -175,12 +233,13 @@ with gr.Blocks() as demo:
175
  with gr.Column():
176
  gr.Markdown("## Latest Posts")
177
  latest_post_dataframe = gr.Dataframe()
178
- get_latest_pots_btn = gr.Button("Refresh Latest Posts")
 
179
 
180
  with gr.Column():
181
  gr.Markdown("## Sampled Frames from Video")
182
  with gr.Row():
183
- num_frames = gr.Slider(minimum=1, maximum=20, step=1, value=10)
184
  sample_decord_btn = gr.Button("Sample decord")
185
 
186
  sampled_frames = gr.Gallery()
@@ -189,7 +248,9 @@ with gr.Blocks() as demo:
189
  output_files = gr.File()
190
 
191
  download_samples_btn.click(
192
- download_samples, inputs=[video_player, num_frames], outputs=[output_files]
 
 
193
  )
194
 
195
  sample_decord_btn.click(
@@ -199,8 +260,11 @@ with gr.Blocks() as demo:
199
  )
200
 
201
  load_btn.click(load_video, inputs=[reddit_id], outputs=[video_player])
202
- get_latest_pots_btn.click(get_latest_pots, outputs=[latest_post_dataframe])
203
- demo.load(get_latest_pots, outputs=[latest_post_dataframe])
 
 
 
204
 
205
  latest_post_dataframe.select(fn=row_selected, outputs=[reddit_id]).then(
206
  load_video, inputs=[reddit_id], outputs=[video_player]
 
25
  from PIL import Image
26
 
27
  cached_latest_posts_df = None
28
+ cached_top_posts = None
29
  last_fetched = None
30
+ last_fetched_top = None
31
 
32
  import os
33
  import tempfile
 
39
  from decord import cpu
40
 
41
 
42
+ def get_reddit_id(url):
43
+ # Regular expression pattern for r/GamePhysics URLs and IDs
44
+ pattern = r"https://www\.reddit\.com/r/GamePhysics/comments/([0-9a-zA-Z]+).*|([0-9a-zA-Z]+)"
45
+
46
+ # Match the URL or ID against the pattern
47
+ match = re.match(pattern, url)
48
+
49
+ if match:
50
+ # Extract the post ID from the URL
51
+ post_id = match.group(1) or match.group(2)
52
+ print(f"Valid GamePhysics post ID: {post_id}")
53
+ else:
54
+ post_id = url
55
+
56
+ return post_id
57
+
58
+
59
+ def download_samples(url, video_url, num_frames):
60
  frames = extract_frames_decord(video_url, num_frames)
61
 
62
  # Create a temporary directory to store the images
 
69
  ) # Adjust quality as needed
70
 
71
  # Create a zip file in a persistent location
72
+ post_id = get_reddit_id(url)
73
+ print(f"Creating zip file for post {post_id}")
74
+ zip_path = f"frames-{post_id}.zip"
75
  with ZipFile(zip_path, "w") as zipf:
76
  for i in range(num_frames):
77
  frame_path = os.path.join(temp_dir, f"frame_{i}.jpg")
 
112
  raise Exception(f"Error extracting frames from video: {e}")
113
 
114
 
115
+ def get_top_posts():
116
+ global cached_top_posts
117
+ global last_fetched_top
118
+
119
+ # make sure we don't fetch data too often, limit to 1 request per 10 minutes
120
+ now_time = datetime.now()
121
+ if last_fetched_top is not None and (now_time - last_fetched_top).seconds < 600:
122
+ print("Using cached data")
123
+ return cached_top_posts
124
+
125
+ last_fetched_top = now_time
126
+ url = "https://www.reddit.com/r/GamePhysics/top/.json?t=month"
127
+ headers = {"User-Agent": "Mozilla/5.0"}
128
+
129
+ response = requests.get(url, headers=headers)
130
+ if response.status_code != 200:
131
+ return []
132
+
133
+ data = response.json()
134
+
135
+ # Extract posts from the data
136
+ posts = data["data"]["children"]
137
+
138
+ for post in posts:
139
+ title = post["data"]["title"]
140
+ post_id = post["data"]["id"]
141
+ # print(f"ID: {post_id}, Title: {title}")
142
+
143
+ # create [post_id, title] list
144
+ examples = [[post["data"]["id"], post["data"]["title"]] for post in posts]
145
+ # make a dataframe
146
+ examples = pd.DataFrame(examples, columns=["post_id", "title"])
147
+ cached_top_posts = examples
148
+ return examples
149
+
150
+
151
+ def get_latest_posts():
152
  global cached_latest_posts_df
153
  global last_fetched
154
 
 
186
 
187
  def row_selected(evt: gr.SelectData):
188
  global cached_latest_posts_df
189
+ global cached_top_posts
 
 
190
 
191
+ # find which dataframe was selected
192
+ string_value = evt.value
193
+ row = evt.index[0]
194
+ target_df = None
195
 
196
+ if cached_latest_posts_df.isin([string_value]).any().any():
197
+ target_df = cached_latest_posts_df
198
+ elif cached_top_posts.isin([string_value]).any().any():
199
+ target_df = cached_top_posts
200
+ else:
201
+ raise gr.Error("Could not find selected post in any dataframe")
202
 
203
+ post_id = target_df.iloc[row]["post_id"]
204
+ return post_id
205
 
 
 
 
 
 
 
206
 
207
+ def load_video(url):
208
+ post_id = get_reddit_id(url)
209
  video_url = f"https://huggingface.co/datasets/asgaardlab/GamePhysicsDailyDump/resolve/main/data/videos/{post_id}.mp4?download=true"
210
 
211
  # make sure file exists before returning, make a request without downloading the file
 
233
  with gr.Column():
234
  gr.Markdown("## Latest Posts")
235
  latest_post_dataframe = gr.Dataframe()
236
+ latest_posts_btn = gr.Button("Refresh Latest Posts")
237
+ top_posts_btn = gr.Button("Refresh Top Posts")
238
 
239
  with gr.Column():
240
  gr.Markdown("## Sampled Frames from Video")
241
  with gr.Row():
242
+ num_frames = gr.Slider(minimum=1, maximum=60, step=1, value=10)
243
  sample_decord_btn = gr.Button("Sample decord")
244
 
245
  sampled_frames = gr.Gallery()
 
248
  output_files = gr.File()
249
 
250
  download_samples_btn.click(
251
+ download_samples,
252
+ inputs=[reddit_id, video_player, num_frames],
253
+ outputs=[output_files],
254
  )
255
 
256
  sample_decord_btn.click(
 
260
  )
261
 
262
  load_btn.click(load_video, inputs=[reddit_id], outputs=[video_player])
263
+
264
+ latest_posts_btn.click(get_latest_posts, outputs=[latest_post_dataframe])
265
+ top_posts_btn.click(get_top_posts, outputs=[latest_post_dataframe])
266
+
267
+ demo.load(get_latest_posts, outputs=[latest_post_dataframe])
268
 
269
  latest_post_dataframe.select(fn=row_selected, outputs=[reddit_id]).then(
270
  load_video, inputs=[reddit_id], outputs=[video_player]