ojaffe committed (verified)
Commit 7d8154b · 1 Parent(s): ae8a361

Upload folder using huggingface_hub
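For reference, a commit like this is typically produced with the upload_folder helper from the huggingface_hub library. A minimal sketch, assuming a placeholder repo id and local path (neither is taken from this commit):

# Hypothetical reproduction of this kind of upload; repo_id and
# folder_path are placeholders, not values from this commit.
from huggingface_hub import HfApi

api = HfApi()  # picks up the token saved by `huggingface-cli login`
api.upload_folder(
    folder_path=".",                # local folder to push
    repo_id="ojaffe/example-repo",  # placeholder target repository
    repo_type="model",
    commit_message="Upload folder using huggingface_hub",
)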

Files changed (2)
  1. __pycache__/predict.cpython-311.pyc +0 -0
  2. predict.py +44 -38
__pycache__/predict.cpython-311.pyc CHANGED
Binary files a/__pycache__/predict.cpython-311.pyc and b/__pycache__/predict.cpython-311.pyc differ
 
predict.py CHANGED
@@ -171,16 +171,10 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
             predicted[:, step] = ar_weight * ar_pred[:, step] + direct_weight * direct_pred[:, step]
 
         predicted_np = predicted[0].cpu().numpy()
-        last_ctx_uint8 = (last_frame * 255).clip(0, 255).astype(np.uint8)  # [64,64,3]
         ens.direct_cache = []
         for i in range(PRED_FRAMES):
             frame = np.transpose(predicted_np[i], (1, 2, 0))
             frame = (frame * 255).clip(0, 255).astype(np.uint8)
-            # Fallback: if prediction is very different from context, blend with context
-            diff = np.abs(frame.astype(np.float32) - last_ctx_uint8.astype(np.float32))
-            mean_diff = diff.mean()
-            if mean_diff > 30:  # very different prediction
-                frame = ((0.5 * frame.astype(np.float32) + 0.5 * last_ctx_uint8.astype(np.float32))).clip(0, 255).astype(np.uint8)
             ens.direct_cache.append(frame)
 
         result = ens.direct_cache[ens.cache_step]
@@ -196,6 +190,14 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
         ens.reset_cache()
         return result
 
+    # Detect scene transitions in context frames
+    scene_transition = False
+    for i in range(len(frames) - 1):
+        diff = np.abs(frames[i].astype(np.float32) - frames[i + 1].astype(np.float32)).mean()
+        if diff > 30.0 / 255.0:  # frames are normalized to 0-1
+            scene_transition = True
+            break
+
     ens.reset_cache()
     with torch.no_grad():
         context_tensor = torch.from_numpy(context).to(DEVICE)
@@ -208,39 +210,43 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
         direct_flipped = torch.flip(direct_flipped, dims=[4])
         direct_pred = (direct_orig + direct_flipped) / 2.0
 
-        # Multi-run AR with noise diversity
-        all_ar_runs = []
-        for noise_std in [0.0, 1.0/255.0, 2.0/255.0]:
-            ar_preds_run = []
-            ctx = context_tensor.clone()
-            ctx_flip = context_flipped.clone()
-            last_t = last_tensor.clone()
-            last_f = last_flipped.clone()
+        if scene_transition:
+            # Scene transition: use direct-only (AR produces garbage after scene changes)
+            predicted = direct_pred
+        else:
+            # Normal scene: full AR+direct blend with noise diversity
+            all_ar_runs = []
+            for noise_std in [0.0, 1.0/255.0, 2.0/255.0]:
+                ar_preds_run = []
+                ctx = context_tensor.clone()
+                ctx_flip = context_flipped.clone()
+                last_t = last_tensor.clone()
+                last_f = last_flipped.clone()
+                for step in range(PRED_FRAMES):
+                    ctx_in = ctx if noise_std == 0 else torch.clamp(ctx + torch.randn_like(ctx) * noise_std, 0, 1)
+                    ctx_flip_in = ctx_flip if noise_std == 0 else torch.clamp(ctx_flip + torch.randn_like(ctx_flip) * noise_std, 0, 1)
+                    ar_orig = _predict_ar_frame(ens.sonic_ar, ctx_in, last_t)
+                    ar_flip = _predict_ar_frame(ens.sonic_ar, ctx_flip_in, last_f)
+                    ar_flip_back = torch.flip(ar_flip, dims=[3])
+                    ar_frame = (ar_orig + ar_flip_back) / 2.0
+                    ar_preds_run.append(ar_frame)
+                    ctx_frames = ctx.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
+                    ctx_frames = torch.cat([ctx_frames[:, 1:], ar_orig.unsqueeze(1)], dim=1)
+                    ctx = ctx_frames.reshape(1, -1, 64, 64)
+                    last_t = ar_orig
+                    ctx_flip_frames = ctx_flip.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
+                    ctx_flip_frames = torch.cat([ctx_flip_frames[:, 1:], ar_flip.unsqueeze(1)], dim=1)
+                    ctx_flip = ctx_flip_frames.reshape(1, -1, 64, 64)
+                    last_f = ar_flip
+                all_ar_runs.append(torch.stack(ar_preds_run, dim=1))
+
+            ar_pred = sum(all_ar_runs) / len(all_ar_runs)
+
+            predicted = torch.zeros_like(direct_pred)
             for step in range(PRED_FRAMES):
-                ctx_in = ctx if noise_std == 0 else torch.clamp(ctx + torch.randn_like(ctx) * noise_std, 0, 1)
-                ctx_flip_in = ctx_flip if noise_std == 0 else torch.clamp(ctx_flip + torch.randn_like(ctx_flip) * noise_std, 0, 1)
-                ar_orig = _predict_ar_frame(ens.sonic_ar, ctx_in, last_t)
-                ar_flip = _predict_ar_frame(ens.sonic_ar, ctx_flip_in, last_f)
-                ar_flip_back = torch.flip(ar_flip, dims=[3])
-                ar_frame = (ar_orig + ar_flip_back) / 2.0
-                ar_preds_run.append(ar_frame)
-                ctx_frames = ctx.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
-                ctx_frames = torch.cat([ctx_frames[:, 1:], ar_orig.unsqueeze(1)], dim=1)
-                ctx = ctx_frames.reshape(1, -1, 64, 64)
-                last_t = ar_orig
-                ctx_flip_frames = ctx_flip.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
-                ctx_flip_frames = torch.cat([ctx_flip_frames[:, 1:], ar_flip.unsqueeze(1)], dim=1)
-                ctx_flip = ctx_flip_frames.reshape(1, -1, 64, 64)
-                last_f = ar_flip
-            all_ar_runs.append(torch.stack(ar_preds_run, dim=1))
-
-        ar_pred = sum(all_ar_runs) / len(all_ar_runs)
-
-        predicted = torch.zeros_like(direct_pred)
-        for step in range(PRED_FRAMES):
-            ar_weight = 0.65 - (step / (PRED_FRAMES - 1)) * 0.3
-            direct_weight = 1.0 - ar_weight
-            predicted[:, step] = ar_weight * ar_pred[:, step] + direct_weight * direct_pred[:, step]
+                ar_weight = 0.65 - (step / (PRED_FRAMES - 1)) * 0.3
+                direct_weight = 1.0 - ar_weight
+                predicted[:, step] = ar_weight * ar_pred[:, step] + direct_weight * direct_pred[:, step]
 
         predicted_np = predicted[0].cpu().numpy()
         ens.direct_cache = []
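Note on the change above: the new gate is a plain mean-absolute-difference scene-cut check over adjacent context frames, with direct_pred used on its own whenever it fires. A self-contained sketch of the same check, assuming [0, 1]-normalized float frames (the function name, shapes, and demo frames are illustrative; only the 30/255 threshold comes from the diff):

import numpy as np

def has_scene_transition(frames: np.ndarray, threshold: float = 30.0 / 255.0) -> bool:
    # Flag a cut if any adjacent pair of frames differs by more than
    # `threshold` on average; frames are floats normalized to [0, 1].
    for a, b in zip(frames[:-1], frames[1:]):
        if np.abs(a.astype(np.float32) - b.astype(np.float32)).mean() > threshold:
            return True
    return False

# Illustrative usage: a hard black-to-white cut trips the check.
ctx = np.zeros((4, 64, 64, 3), dtype=np.float32)
ctx[2:] = 1.0
assert has_scene_transition(ctx)          # mean diff 1.0 > 30/255
assert not has_scene_transition(ctx[:2])  # identical frames, no cut

Skipping the autoregressive branch after a detected cut avoids conditioning it on a context that straddles two scenes, which is the failure mode the in-diff comment calls out.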