ojaffe committed
Commit 65aa516 · verified · 1 Parent(s): e586e5c

Upload folder using huggingface_hub

__pycache__/predict.cpython-311.pyc CHANGED
Binary files a/__pycache__/predict.cpython-311.pyc and b/__pycache__/predict.cpython-311.pyc differ
 
predict.py CHANGED
@@ -106,11 +106,11 @@ def load_model(model_dir: str):
     return ens
 
 
-def _predict_8frames_direct(model, context_tensor, last_tensor, residual_scale=1.0):
+def _predict_8frames_direct(model, context_tensor, last_tensor):
     output = model(context_tensor)
     residuals = output.reshape(1, PRED_FRAMES, 3, 64, 64)
     last_expanded = last_tensor.unsqueeze(1).expand_as(residuals)
-    return torch.clamp(last_expanded + residual_scale * residuals, 0, 1)
+    return torch.clamp(last_expanded + residuals, 0, 1)
 
 
 def _predict_ar_frame(model, context_tensor, last_tensor, residual_scale=1.0):
@@ -155,7 +155,8 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
     ctx = context_tensor.clone()
     last_t = last_tensor.clone()
     for step in range(PRED_FRAMES):
-        predicted = _predict_ar_frame(ens.models["pong"], ctx, last_t, residual_scale=1.03)
+        pong_scale = 1.06 if step >= 4 else 1.0
+        predicted = _predict_ar_frame(ens.models["pong"], ctx, last_t, residual_scale=pong_scale)
         ar_preds.append(predicted)
         ctx_frames = ctx.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
         ctx_frames = torch.cat([ctx_frames[:, 1:], predicted.unsqueeze(1)], dim=1)
@@ -213,8 +214,9 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
     for step in range(PRED_FRAMES):
         ctx_in = ctx if noise_std == 0 else torch.clamp(ctx + torch.randn_like(ctx) * noise_std, 0, 1)
         ctx_flip_in = ctx_flip if noise_std == 0 else torch.clamp(ctx_flip + torch.randn_like(ctx_flip) * noise_std, 0, 1)
-        ar_orig = _predict_ar_frame(ens.sonic_ar, ctx_in, last_t, residual_scale=1.08)
-        ar_flip = _predict_ar_frame(ens.sonic_ar, ctx_flip_in, last_f, residual_scale=1.08)
+        sonic_scale = 1.12 if step >= 3 else 1.0
+        ar_orig = _predict_ar_frame(ens.sonic_ar, ctx_in, last_t, residual_scale=sonic_scale)
+        ar_flip = _predict_ar_frame(ens.sonic_ar, ctx_flip_in, last_f, residual_scale=sonic_scale)
         ar_flip_back = torch.flip(ar_flip, dims=[3])
         ar_frame = (ar_orig + ar_flip_back) / 2.0
         ar_preds_run.append(ar_frame)
@@ -261,10 +263,10 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
     context_tensor = torch.from_numpy(context).to(DEVICE)
     last_tensor = torch.from_numpy(last_frame_t).to(DEVICE)
 
-    predicted_orig = _predict_8frames_direct(ens.models["pole_position"], context_tensor, last_tensor, residual_scale=1.03)
+    predicted_orig = _predict_8frames_direct(ens.models["pole_position"], context_tensor, last_tensor)
     context_flipped = torch.flip(context_tensor, dims=[3])
     last_flipped = torch.flip(last_tensor, dims=[3])
-    predicted_flipped = _predict_8frames_direct(ens.models["pole_position"], context_flipped, last_flipped, residual_scale=1.03)
+    predicted_flipped = _predict_8frames_direct(ens.models["pole_position"], context_flipped, last_flipped)
     predicted_flipped = torch.flip(predicted_flipped, dims=[4])
     predicted = (predicted_orig + predicted_flipped) / 2.0
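
Note on the change: the commit replaces the flat residual_scale applied on every rollout step with a step-dependent boost (1.06 from step 4 for pong, 1.12 from step 3 for sonic) and drops the scale from the direct 8-frame path. The sketch below is a minimal illustration of that pattern (late-step residual boosting plus horizontal-flip test-time averaging), not the repository's actual code; the helper names, constants, and dummy model are assumptions for demonstration.

    # Minimal sketch: autoregressive rollout with a residual boost applied only
    # on later steps, each step averaged with a horizontally flipped pass.
    import torch

    PRED_FRAMES = 8      # assumed, matching reshape(1, PRED_FRAMES, 3, 64, 64) in the diff
    CONTEXT_FRAMES = 4   # assumed context window length

    def predict_ar_frame(model, context, last, residual_scale=1.0):
        # Single-step residual prediction: the model outputs a residual on top of the last frame.
        residual = model(context).reshape(1, 3, 64, 64)
        return torch.clamp(last + residual_scale * residual, 0, 1)

    def rollout_with_flip_tta(model, context, last, late_scale=1.12, late_from=3):
        # context: (1, CONTEXT_FRAMES * 3, 64, 64) flattened frames; last: (1, 3, 64, 64)
        preds = []
        ctx, ctx_flip = context.clone(), torch.flip(context, dims=[3])
        last_t, last_f = last.clone(), torch.flip(last, dims=[3])
        for step in range(PRED_FRAMES):
            scale = late_scale if step >= late_from else 1.0  # boost residuals only late in the rollout
            orig = predict_ar_frame(model, ctx, last_t, residual_scale=scale)
            flip = predict_ar_frame(model, ctx_flip, last_f, residual_scale=scale)
            frame = (orig + torch.flip(flip, dims=[3])) / 2.0  # average the two symmetric views
            preds.append(frame)
            # Slide the context window forward by one predicted frame.
            frames = ctx.reshape(1, CONTEXT_FRAMES, 3, 64, 64)
            frames = torch.cat([frames[:, 1:], frame.unsqueeze(1)], dim=1)
            ctx = frames.reshape(1, CONTEXT_FRAMES * 3, 64, 64)
            ctx_flip = torch.flip(ctx, dims=[3])
            last_t, last_f = frame, torch.flip(frame, dims=[3])
        return torch.stack(preds, dim=1)  # (1, PRED_FRAMES, 3, 64, 64)

    # Usage with a dummy model that predicts zero residuals:
    dummy = lambda x: torch.zeros(1, 3 * 64 * 64)
    out = rollout_with_flip_tta(dummy, torch.rand(1, CONTEXT_FRAMES * 3, 64, 64), torch.rand(1, 3, 64, 64))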