| """Prediction interface for per-game Flow-Warp-Mask models v12 with motion encoding + TTA.""" |
| import sys |
| import os |
| import numpy as np |
| import torch |
|
|
| sys.path.insert(0, "/home/coder/code") |
| from flowmask_model import FlowWarpMaskUNet |
| from flownet_model import differentiable_warp |
|
|
# How many of the most recent frames the model consumes; shorter histories
# are padded up to this length by _prepare_context.
CONTEXT_LEN = 4
# Per-game model settings, keyed by the names detect_game() can return:
#   channels -> channel widths passed to FlowWarpMaskUNet
#   file     -> checkpoint file name, resolved relative to load_model's model_dir
GAME_CONFIGS = dict(
    pong=dict(channels=[32, 64, 128], file="pong_model.pt"),
    sonic=dict(channels=[40, 80, 160], file="sonic_model.pt"),
    pole_position=dict(channels=[24, 48, 96], file="pole_model.pt"),
)
|
|
|
|
def detect_game(context_frames):
    """Classify which game the frames come from by their overall brightness.

    The mean over every element of ``context_frames`` is compared against
    fixed cut-offs: below 10 -> "pong", below 80 -> "sonic", otherwise
    "pole_position". This is a heuristic; the cut-offs assume pixel values
    on a 0-255 scale (TODO confirm against the capture pipeline).

    Args:
        context_frames: array supporting ``.mean()`` (e.g. a numpy array of
            stacked frames).

    Returns:
        One of "pong", "sonic" or "pole_position".
    """
    brightness = context_frames.mean()
    # Ordered (upper bound, game) pairs; first bound the mean falls below wins.
    for upper_bound, game_name in ((10, "pong"), (80, "sonic")):
        if brightness < upper_bound:
            return game_name
    return "pole_position"
|
|
|
|
def load_model(model_dir: str):
    """Instantiate and load one FlowWarpMaskUNet per game in GAME_CONFIGS.

    Each network is built with 12 input channels and the game's configured
    channel widths. Its checkpoint is read from ``model_dir`` with
    ``weights_only=True`` and every stored tensor is cast to float32 before
    loading. The network is moved to the selected device and switched to
    eval mode.

    Args:
        model_dir: directory containing the checkpoint files named in
            GAME_CONFIGS.

    Returns:
        ``{"models": {game_name: model}, "device": torch.device}``. The
        device is CUDA when available, otherwise CPU.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    def _build(cfg):
        # One game's network: construct, load float32 weights, place, freeze.
        net = FlowWarpMaskUNet(in_channels=12, channels=cfg["channels"])
        checkpoint_path = os.path.join(model_dir, cfg["file"])
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
        net.load_state_dict({name: tensor.float() for name, tensor in checkpoint.items()})
        net.to(device)
        net.eval()
        return net

    models = {game: _build(cfg) for game, cfg in GAME_CONFIGS.items()}
    return {"models": models, "device": device}
|
|
|
|
def _make_motion_input(frames):
    """Stack the newest frame with three consecutive frame differences.

    Channel layout of the result (concatenated along dim 0):
    newest frame, then (f[-1]-f[-2]), (f[-2]-f[-3]), (f[-3]-f[-4]).
    Only the last four entries along dim 0 are used.

    Args:
        frames: (T, 3, H, W) tensor, T >= 4, values in [0, 1].

    Returns:
        (12, H, W) tensor.
    """
    # Differences between each of the three newest frames and its predecessor.
    motion_terms = [frames[idx] - frames[idx - 1] for idx in (-1, -2, -3)]
    return torch.cat([frames[-1], *motion_terms], dim=0)
|
|
|
|
def _prepare_context(context_frames, context_len=None):
    """Turn raw frames into a fixed-length float tensor context.

    Keeps the most recent ``context_len`` frames. If fewer are available, the
    oldest frame is repeated at the front until the length is reached. Pixel
    values are divided by 255 and the layout is permuted from
    (T, H, W, C) to (T, C, H, W).

    Args:
        context_frames: stacked frames of shape (T, H, W, C), or any sequence
            of (H, W, C) frames accepted by ``np.asarray``. Values are
            expected on a 0-255 scale.
        context_len: number of frames to return. Defaults to CONTEXT_LEN.

    Returns:
        float32 tensor of shape (context_len, C, H, W).

    Raises:
        ValueError: if ``context_frames`` is empty or ``context_len`` < 1.
    """
    if context_len is None:
        context_len = CONTEXT_LEN
    if context_len < 1:
        raise ValueError(f"context_len must be >= 1, got {context_len}")

    # Accept lists of frames as well as a pre-stacked array (no copy for arrays).
    frames = np.asarray(context_frames)
    if len(frames) == 0:
        raise ValueError("context_frames must contain at least one frame")

    if len(frames) >= context_len:
        frames = frames[-context_len:]
    else:
        # Too little history: replicate the oldest frame at the front.
        padding = np.repeat(frames[:1], context_len - len(frames), axis=0)
        frames = np.concatenate([padding, frames], axis=0)

    frames_t = torch.from_numpy(frames.astype(np.float32) / 255.0)
    # (T, H, W, C) -> (T, C, H, W)
    return frames_t.permute(0, 3, 1, 2)
|
|
|
|
def _run_model(model, frames_t, device):
    """Run one forward pass and composite the predicted next frame.

    The model gets the 12-channel motion encoding of ``frames_t`` as a batch
    of one and returns ``(flow, mask, gen_frame)``. The newest context frame
    is warped by ``flow`` with ``differentiable_warp``. The result is then
    blended as ``mask * warped + (1 - mask) * gen_frame`` and clamped to
    [0, 1].

    Args:
        model: network called as ``model(inp) -> (flow, mask, gen_frame)``.
        frames_t: (T, C, H, W) context tensor. T >= 4 is required because the
            motion encoding reads the last four frames.
        device: device the inputs are moved to before the forward pass.

    Returns:
        Batched prediction tensor (batch dimension first). Presumably
        (1, C, H, W); TODO confirm against the model's output shapes.
    """
    net_input = _make_motion_input(frames_t).unsqueeze(0).to(device)
    # Slice (not index) keeps a leading batch dim of size 1.
    newest = frames_t[-1:].to(device)

    flow, mask, gen_frame = model(net_input)
    # mask presumably lies in [0, 1] and weights warped vs. generated pixels.
    blended = mask * differentiable_warp(newest, flow) + (1 - mask) * gen_frame
    return blended.clamp(0, 1)
|
|
|
|
def predict_next_frame(model_dict, context_frames: np.ndarray) -> np.ndarray:
    """Predict the next frame using the per-game model with flip TTA.

    The game is chosen by ``detect_game`` and the matching model is run
    twice: once on the context and once on the horizontally mirrored context.
    The mirrored prediction is flipped back and the two are averaged. The
    result is converted to uint8 in (H, W, C) layout. For "pong", pixel
    values below 5 are forced to 0.

    Args:
        model_dict: dict as returned by ``load_model``, with keys "models"
            (game name -> model) and "device".
        context_frames: stacked frames (T, H, W, C) on a 0-255 scale, or a
            sequence of such frames.

    Returns:
        uint8 array of shape (H, W, C).
    """
    # detect_game calls .mean(), so make sure list input becomes an array.
    context_frames = np.asarray(context_frames)
    models = model_dict["models"]
    device = model_dict["device"]

    game = detect_game(context_frames)
    model = models[game]
    frames_t = _prepare_context(context_frames)

    with torch.no_grad():
        pred_plain = _run_model(model, frames_t, device)
        # Test-time augmentation: predict on the mirrored input (flip width),
        # then mirror the result back so it aligns with pred_plain.
        pred_mirror = _run_model(model, frames_t.flip(-1), device).flip(-1)
        pred = (pred_plain + pred_mirror) / 2.0

    pred = pred[0].cpu().permute(1, 2, 0).numpy()
    # Round to nearest before casting: astype(uint8) truncates, so a value
    # that lands just below an integer (e.g. 99.99) would lose a whole level.
    pred = np.clip(np.rint(pred * 255.0), 0, 255).astype(np.uint8)

    if game == "pong":
        # Force faint values to pure black (presumably to clean residue on
        # pong's dark background).
        pred[pred < 5] = 0

    return pred
|
|