ojaffe commited on
Commit
515523e
·
verified ·
1 Parent(s): efae84f

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. __pycache__/predict.cpython-311.pyc +0 -0
  2. predict.py +14 -17
__pycache__/predict.cpython-311.pyc CHANGED
Binary files a/__pycache__/predict.cpython-311.pyc and b/__pycache__/predict.cpython-311.pyc differ
 
predict.py CHANGED
@@ -149,31 +149,28 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
149
  context_tensor = torch.from_numpy(context).to(DEVICE)
150
  last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
151
 
152
- # Pong AR-only with hflip TTA
153
- context_flipped = torch.flip(context_tensor, dims=[3])
154
- last_flipped = torch.flip(last_tensor, dims=[3])
155
 
156
  ar_preds = []
157
  ctx = context_tensor.clone()
158
- ctx_flip = context_flipped.clone()
159
  last_t = last_tensor.clone()
160
- last_f = last_flipped.clone()
161
  for step in range(PRED_FRAMES):
162
- ar_orig = _predict_ar_frame(ens.models["pong"], ctx, last_t)
163
- ar_flip = _predict_ar_frame(ens.models["pong"], ctx_flip, last_f)
164
- ar_flip_back = torch.flip(ar_flip, dims=[3])
165
- ar_frame = (ar_orig + ar_flip_back) / 2.0
166
- ar_preds.append(ar_frame)
167
  ctx_frames = ctx.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
168
- ctx_frames = torch.cat([ctx_frames[:, 1:], ar_orig.unsqueeze(1)], dim=1)
169
  ctx = ctx_frames.reshape(1, -1, 64, 64)
170
- last_t = ar_orig
171
- ctx_flip_frames = ctx_flip.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
172
- ctx_flip_frames = torch.cat([ctx_flip_frames[:, 1:], ar_flip.unsqueeze(1)], dim=1)
173
- ctx_flip = ctx_flip_frames.reshape(1, -1, 64, 64)
174
- last_f = ar_flip
175
 
176
- predicted = torch.stack(ar_preds, dim=1)
 
 
 
 
 
 
 
 
177
 
178
  predicted_np = predicted[0].cpu().numpy()
179
  ens.direct_cache = []
 
149
  context_tensor = torch.from_numpy(context).to(DEVICE)
150
  last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
151
 
152
+ direct_pred = _predict_8frames_direct(ens.pong_direct, context_tensor, last_tensor)
 
 
153
 
154
  ar_preds = []
155
  ctx = context_tensor.clone()
 
156
  last_t = last_tensor.clone()
 
157
  for step in range(PRED_FRAMES):
158
+ predicted = _predict_ar_frame(ens.models["pong"], ctx, last_t)
159
+ ar_preds.append(predicted)
 
 
 
160
  ctx_frames = ctx.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
161
+ ctx_frames = torch.cat([ctx_frames[:, 1:], predicted.unsqueeze(1)], dim=1)
162
  ctx = ctx_frames.reshape(1, -1, 64, 64)
163
+ last_t = predicted
 
 
 
 
164
 
165
+ ar_pred = torch.stack(ar_preds, dim=1)
166
+
167
+ # Pure AR for steps 1-4, then blend for steps 5-8
168
+ pong_ar_weights = [1.0, 1.0, 1.0, 1.0, 0.70, 0.65, 0.60, 0.55]
169
+ predicted = torch.zeros_like(direct_pred)
170
+ for step in range(PRED_FRAMES):
171
+ ar_weight = pong_ar_weights[step]
172
+ direct_weight = 1.0 - ar_weight
173
+ predicted[:, step] = ar_weight * ar_pred[:, step] + direct_weight * direct_pred[:, step]
174
 
175
  predicted_np = predicted[0].cpu().numpy()
176
  ens.direct_cache = []