Upload folder using huggingface_hub

- __pycache__/predict.cpython-311.pyc +0 -0
- predict.py +9 -7

__pycache__/predict.cpython-311.pyc
CHANGED

Binary files a/__pycache__/predict.cpython-311.pyc and b/__pycache__/predict.cpython-311.pyc differ
predict.py
CHANGED
@@ -106,11 +106,11 @@ def load_model(model_dir: str):
     return ens
 
 
-def _predict_8frames_direct(model, context_tensor, last_tensor…
+def _predict_8frames_direct(model, context_tensor, last_tensor):
     output = model(context_tensor)
     residuals = output.reshape(1, PRED_FRAMES, 3, 64, 64)
     last_expanded = last_tensor.unsqueeze(1).expand_as(residuals)
-    return torch.clamp(last_expanded +…
+    return torch.clamp(last_expanded + residuals, 0, 1)
 
 
 def _predict_ar_frame(model, context_tensor, last_tensor, residual_scale=1.0):
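The function changed above predicts all eight frames in one forward pass, as residuals on top of the last observed frame. A minimal runnable sketch of that pattern, with a toy stand-in for the real network (which this diff does not show) and an assumed context length of 4 frames:

import torch

CONTEXT_FRAMES, PRED_FRAMES = 4, 8  # PRED_FRAMES matches the reshape above; the context length is an assumption

def toy_model(flat_context):
    # Illustrative stand-in for the real network: emit small random
    # residuals of the right flattened size.
    return 0.05 * torch.randn(flat_context.shape[0], PRED_FRAMES * 3 * 64 * 64)

context_tensor = torch.rand(1, CONTEXT_FRAMES * 3 * 64 * 64)  # flattened context frames
last_tensor = torch.rand(1, 3, 64, 64)                        # last observed frame, pixels in [0, 1]

output = toy_model(context_tensor)
residuals = output.reshape(1, PRED_FRAMES, 3, 64, 64)
last_expanded = last_tensor.unsqueeze(1).expand_as(residuals)  # broadcast the last frame over all 8 slots
predicted = torch.clamp(last_expanded + residuals, 0, 1)       # residual update, clamped to valid pixels
print(predicted.shape)  # torch.Size([1, 8, 3, 64, 64])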
@@ -155,7 +155,8 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
     ctx = context_tensor.clone()
     last_t = last_tensor.clone()
     for step in range(PRED_FRAMES):
-        …
+        pong_scale = 1.06 if step >= 4 else 1.0
+        predicted = _predict_ar_frame(ens.models["pong"], ctx, last_t, residual_scale=pong_scale)
         ar_preds.append(predicted)
         ctx_frames = ctx.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
         ctx_frames = torch.cat([ctx_frames[:, 1:], predicted.unsqueeze(1)], dim=1)
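The new pong_scale line amplifies the model's residuals by 6% from step 4 of the rollout onward, compensating for the motion damping that tends to compound in autoregressive prediction. A hedged sketch of such a rollout loop; step_fn stands in for _predict_ar_frame, and the context is kept as a 5-D tensor here, whereas the code above stores it flattened and reshapes it each step:

import torch

CONTEXT_FRAMES, PRED_FRAMES = 4, 8  # assumed context length

def rollout(step_fn, ctx_frames, last_frame):
    # ctx_frames: (1, CONTEXT_FRAMES, 3, 64, 64), last_frame: (1, 3, 64, 64)
    preds = []
    for step in range(PRED_FRAMES):
        # Late steps get a >1 residual scale, as in the pong branch above.
        scale = 1.06 if step >= 4 else 1.0
        frame = step_fn(ctx_frames, last_frame, residual_scale=scale)
        preds.append(frame)
        # Slide the window: drop the oldest frame, append the new prediction.
        ctx_frames = torch.cat([ctx_frames[:, 1:], frame.unsqueeze(1)], dim=1)
        last_frame = frame
    return torch.stack(preds, dim=1)  # (1, PRED_FRAMES, 3, 64, 64)

# Toy step function: nudge the last frame slightly, scaled like a residual.
demo = rollout(lambda ctx, last, residual_scale: torch.clamp(last + 0.01 * residual_scale, 0, 1),
               torch.rand(1, CONTEXT_FRAMES, 3, 64, 64), torch.rand(1, 3, 64, 64))
print(demo.shape)  # torch.Size([1, 8, 3, 64, 64])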
@@ -213,8 +214,9 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
     for step in range(PRED_FRAMES):
         ctx_in = ctx if noise_std == 0 else torch.clamp(ctx + torch.randn_like(ctx) * noise_std, 0, 1)
         ctx_flip_in = ctx_flip if noise_std == 0 else torch.clamp(ctx_flip + torch.randn_like(ctx_flip) * noise_std, 0, 1)
-        …
-        …
+        sonic_scale = 1.12 if step >= 3 else 1.0
+        ar_orig = _predict_ar_frame(ens.sonic_ar, ctx_in, last_t, residual_scale=sonic_scale)
+        ar_flip = _predict_ar_frame(ens.sonic_ar, ctx_flip_in, last_f, residual_scale=sonic_scale)
         ar_flip_back = torch.flip(ar_flip, dims=[3])
         ar_frame = (ar_orig + ar_flip_back) / 2.0
         ar_preds_run.append(ar_frame)
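This hunk applies the same late-step scaling (1.12 from step 3 onward) inside a horizontal-flip test-time-augmentation loop: predict on the original and mirrored contexts, mirror the second prediction back, and average the pair to cancel left/right bias. A minimal sketch of one such step, with step_fn again standing in for _predict_ar_frame and last_t / last_f denoting the original and flipped last frames:

import torch

def tta_step(step_fn, ctx, ctx_flip, last_t, last_f, noise_std=0.0, scale=1.0):
    # Optional Gaussian input noise, clamped back into the valid pixel range,
    # mirroring the noise_std handling above.
    if noise_std > 0:
        ctx = torch.clamp(ctx + torch.randn_like(ctx) * noise_std, 0, 1)
        ctx_flip = torch.clamp(ctx_flip + torch.randn_like(ctx_flip) * noise_std, 0, 1)
    ar_orig = step_fn(ctx, last_t, residual_scale=scale)
    ar_flip = step_fn(ctx_flip, last_f, residual_scale=scale)
    # Undo the width flip (dim 3 of a (1, 3, 64, 64) frame) before averaging.
    return (ar_orig + torch.flip(ar_flip, dims=[3])) / 2.0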
@@ -261,10 +263,10 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
     context_tensor = torch.from_numpy(context).to(DEVICE)
     last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
 
-    predicted_orig = _predict_8frames_direct(ens.models["pole_position"], context_tensor, last_tensor…
+    predicted_orig = _predict_8frames_direct(ens.models["pole_position"], context_tensor, last_tensor)
     context_flipped = torch.flip(context_tensor, dims=[3])
     last_flipped = torch.flip(last_tensor, dims=[3])
-    predicted_flipped = _predict_8frames_direct(ens.models["pole_position"], context_flipped, last_flipped…
+    predicted_flipped = _predict_8frames_direct(ens.models["pole_position"], context_flipped, last_flipped)
     predicted_flipped = torch.flip(predicted_flipped, dims=[4])
     predicted = (predicted_orig + predicted_flipped) / 2.0
 
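One detail in this last hunk: the un-flip uses dims=[4] rather than the dims=[3] seen earlier, because the direct predictor returns a 5-D (batch, frame, channel, height, width) block, so the width axis shifts by one. A quick sanity check:

import torch

single = torch.rand(1, 3, 64, 64)    # AR output: width is dim 3
block = torch.rand(1, 8, 3, 64, 64)  # direct output: width is dim 4

assert torch.equal(torch.flip(torch.flip(single, dims=[3]), dims=[3]), single)
assert torch.equal(torch.flip(torch.flip(block, dims=[4]), dims=[4]), block)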