Upload folder using huggingface_hub
Browse files- __pycache__/predict.cpython-311.pyc +0 -0
- predict.py +7 -3
__pycache__/predict.cpython-311.pyc
CHANGED
|
Binary files a/__pycache__/predict.cpython-311.pyc and b/__pycache__/predict.cpython-311.pyc differ
|
|
|
predict.py
CHANGED
|
@@ -164,19 +164,23 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
|
|
| 164 |
|
| 165 |
ar_pred = torch.stack(ar_preds, dim=1)
|
| 166 |
|
| 167 |
-
# Pure AR for steps 1-4, then blend for steps 5-8
|
| 168 |
-
pong_ar_weights = [1.0, 1.0, 1.0, 1.0, 0.70, 0.65, 0.60, 0.55]
|
| 169 |
predicted = torch.zeros_like(direct_pred)
|
| 170 |
for step in range(PRED_FRAMES):
|
| 171 |
-
ar_weight =
|
| 172 |
direct_weight = 1.0 - ar_weight
|
| 173 |
predicted[:, step] = ar_weight * ar_pred[:, step] + direct_weight * direct_pred[:, step]
|
| 174 |
|
| 175 |
predicted_np = predicted[0].cpu().numpy()
|
|
|
|
| 176 |
ens.direct_cache = []
|
| 177 |
for i in range(PRED_FRAMES):
|
| 178 |
frame = np.transpose(predicted_np[i], (1, 2, 0))
|
| 179 |
frame = (frame * 255).clip(0, 255).astype(np.uint8)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
ens.direct_cache.append(frame)
|
| 181 |
|
| 182 |
result = ens.direct_cache[ens.cache_step]
|
|
|
|
| 164 |
|
| 165 |
ar_pred = torch.stack(ar_preds, dim=1)
|
| 166 |
|
|
|
|
|
|
|
| 167 |
predicted = torch.zeros_like(direct_pred)
|
| 168 |
for step in range(PRED_FRAMES):
|
| 169 |
+
ar_weight = 0.85 - (step / (PRED_FRAMES - 1)) * 0.3
|
| 170 |
direct_weight = 1.0 - ar_weight
|
| 171 |
predicted[:, step] = ar_weight * ar_pred[:, step] + direct_weight * direct_pred[:, step]
|
| 172 |
|
| 173 |
predicted_np = predicted[0].cpu().numpy()
|
| 174 |
+
last_ctx_uint8 = (last_frame * 255).clip(0, 255).astype(np.uint8) # [64,64,3]
|
| 175 |
ens.direct_cache = []
|
| 176 |
for i in range(PRED_FRAMES):
|
| 177 |
frame = np.transpose(predicted_np[i], (1, 2, 0))
|
| 178 |
frame = (frame * 255).clip(0, 255).astype(np.uint8)
|
| 179 |
+
# Fallback: if prediction is very different from context, blend with context
|
| 180 |
+
diff = np.abs(frame.astype(np.float32) - last_ctx_uint8.astype(np.float32))
|
| 181 |
+
mean_diff = diff.mean()
|
| 182 |
+
if mean_diff > 30: # very different prediction
|
| 183 |
+
frame = ((0.5 * frame.astype(np.float32) + 0.5 * last_ctx_uint8.astype(np.float32))).clip(0, 255).astype(np.uint8)
|
| 184 |
ens.direct_cache.append(frame)
|
| 185 |
|
| 186 |
result = ens.direct_cache[ens.cache_step]
|