Upload folder using huggingface_hub
Browse files- __pycache__/predict.cpython-311.pyc +0 -0
- predict.py +14 -17
__pycache__/predict.cpython-311.pyc
CHANGED
|
Binary files a/__pycache__/predict.cpython-311.pyc and b/__pycache__/predict.cpython-311.pyc differ
|
|
|
predict.py
CHANGED
|
@@ -149,31 +149,28 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
|
|
| 149 |
context_tensor = torch.from_numpy(context).to(DEVICE)
|
| 150 |
last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
|
| 151 |
|
| 152 |
-
|
| 153 |
-
context_flipped = torch.flip(context_tensor, dims=[3])
|
| 154 |
-
last_flipped = torch.flip(last_tensor, dims=[3])
|
| 155 |
|
| 156 |
ar_preds = []
|
| 157 |
ctx = context_tensor.clone()
|
| 158 |
-
ctx_flip = context_flipped.clone()
|
| 159 |
last_t = last_tensor.clone()
|
| 160 |
-
last_f = last_flipped.clone()
|
| 161 |
for step in range(PRED_FRAMES):
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
ar_flip_back = torch.flip(ar_flip, dims=[3])
|
| 165 |
-
ar_frame = (ar_orig + ar_flip_back) / 2.0
|
| 166 |
-
ar_preds.append(ar_frame)
|
| 167 |
ctx_frames = ctx.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
|
| 168 |
-
ctx_frames = torch.cat([ctx_frames[:, 1:],
|
| 169 |
ctx = ctx_frames.reshape(1, -1, 64, 64)
|
| 170 |
-
last_t =
|
| 171 |
-
ctx_flip_frames = ctx_flip.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
|
| 172 |
-
ctx_flip_frames = torch.cat([ctx_flip_frames[:, 1:], ar_flip.unsqueeze(1)], dim=1)
|
| 173 |
-
ctx_flip = ctx_flip_frames.reshape(1, -1, 64, 64)
|
| 174 |
-
last_f = ar_flip
|
| 175 |
|
| 176 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
|
| 178 |
predicted_np = predicted[0].cpu().numpy()
|
| 179 |
ens.direct_cache = []
|
|
|
|
| 149 |
context_tensor = torch.from_numpy(context).to(DEVICE)
|
| 150 |
last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
|
| 151 |
|
| 152 |
+
direct_pred = _predict_8frames_direct(ens.pong_direct, context_tensor, last_tensor)
|
|
|
|
|
|
|
| 153 |
|
| 154 |
ar_preds = []
|
| 155 |
ctx = context_tensor.clone()
|
|
|
|
| 156 |
last_t = last_tensor.clone()
|
|
|
|
| 157 |
for step in range(PRED_FRAMES):
|
| 158 |
+
predicted = _predict_ar_frame(ens.models["pong"], ctx, last_t)
|
| 159 |
+
ar_preds.append(predicted)
|
|
|
|
|
|
|
|
|
|
| 160 |
ctx_frames = ctx.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
|
| 161 |
+
ctx_frames = torch.cat([ctx_frames[:, 1:], predicted.unsqueeze(1)], dim=1)
|
| 162 |
ctx = ctx_frames.reshape(1, -1, 64, 64)
|
| 163 |
+
last_t = predicted
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
|
| 165 |
+
ar_pred = torch.stack(ar_preds, dim=1)
|
| 166 |
+
|
| 167 |
+
# Pure AR for steps 1-4, then blend for steps 5-8
|
| 168 |
+
pong_ar_weights = [1.0, 1.0, 1.0, 1.0, 0.70, 0.65, 0.60, 0.55]
|
| 169 |
+
predicted = torch.zeros_like(direct_pred)
|
| 170 |
+
for step in range(PRED_FRAMES):
|
| 171 |
+
ar_weight = pong_ar_weights[step]
|
| 172 |
+
direct_weight = 1.0 - ar_weight
|
| 173 |
+
predicted[:, step] = ar_weight * ar_pred[:, step] + direct_weight * direct_pred[:, step]
|
| 174 |
|
| 175 |
predicted_np = predicted[0].cpu().numpy()
|
| 176 |
ens.direct_cache = []
|