ojaffe committed on
Commit
ae8a361
·
verified ·
1 Parent(s): 515523e

Upload folder using huggingface_hub

Browse files
__pycache__/predict.cpython-311.pyc CHANGED
Binary files a/__pycache__/predict.cpython-311.pyc and b/__pycache__/predict.cpython-311.pyc differ
 
predict.py CHANGED
@@ -164,19 +164,23 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
164
 
165
  ar_pred = torch.stack(ar_preds, dim=1)
166
 
167
- # Pure AR for steps 1-4, then blend for steps 5-8
168
- pong_ar_weights = [1.0, 1.0, 1.0, 1.0, 0.70, 0.65, 0.60, 0.55]
169
  predicted = torch.zeros_like(direct_pred)
170
  for step in range(PRED_FRAMES):
171
- ar_weight = pong_ar_weights[step]
172
  direct_weight = 1.0 - ar_weight
173
  predicted[:, step] = ar_weight * ar_pred[:, step] + direct_weight * direct_pred[:, step]
174
 
175
  predicted_np = predicted[0].cpu().numpy()
 
176
  ens.direct_cache = []
177
  for i in range(PRED_FRAMES):
178
  frame = np.transpose(predicted_np[i], (1, 2, 0))
179
  frame = (frame * 255).clip(0, 255).astype(np.uint8)
 
 
 
 
 
180
  ens.direct_cache.append(frame)
181
 
182
  result = ens.direct_cache[ens.cache_step]
 
164
 
165
  ar_pred = torch.stack(ar_preds, dim=1)
166
 
 
 
167
  predicted = torch.zeros_like(direct_pred)
168
  for step in range(PRED_FRAMES):
169
+ ar_weight = 0.85 - (step / (PRED_FRAMES - 1)) * 0.3
170
  direct_weight = 1.0 - ar_weight
171
  predicted[:, step] = ar_weight * ar_pred[:, step] + direct_weight * direct_pred[:, step]
172
 
173
  predicted_np = predicted[0].cpu().numpy()
174
+ last_ctx_uint8 = (last_frame * 255).clip(0, 255).astype(np.uint8) # [64,64,3]
175
  ens.direct_cache = []
176
  for i in range(PRED_FRAMES):
177
  frame = np.transpose(predicted_np[i], (1, 2, 0))
178
  frame = (frame * 255).clip(0, 255).astype(np.uint8)
179
+ # Fallback: if prediction is very different from context, blend with context
180
+ diff = np.abs(frame.astype(np.float32) - last_ctx_uint8.astype(np.float32))
181
+ mean_diff = diff.mean()
182
+ if mean_diff > 30: # very different prediction
183
+ frame = ((0.5 * frame.astype(np.float32) + 0.5 * last_ctx_uint8.astype(np.float32))).clip(0, 255).astype(np.uint8)
184
  ens.direct_cache.append(frame)
185
 
186
  result = ens.direct_cache[ens.cache_step]