| """Inference for AR curriculum model + TTA.""" |
| import json |
| import numpy as np |
| import torch |
| import sys |
| sys.path.insert(0, "/home/coder/code") |
| from flow_warp_attn_model import FlowWarpAttnUNet |
|
|
|
|
def load_model(model_dir: str):
    """Construct the FlowWarpAttnUNet stored in ``model_dir`` and load its weights.

    ``model_dir`` is expected to contain ``config.json`` (read for the keys
    ``in_channels``, ``channels`` and ``context_len``) and ``model.pt``
    (a state dict loaded with ``weights_only=True``).

    Returns:
        dict with keys ``"model"`` (the network in eval mode, moved to the
        selected device), ``"device"`` (CUDA when available, otherwise CPU)
        and ``"context_len"`` (taken verbatim from the config).
    """
    config_path = f"{model_dir}/config.json"
    weights_path = f"{model_dir}/model.pt"

    with open(config_path) as fh:
        cfg = json.load(fh)

    net = FlowWarpAttnUNet(in_channels=cfg["in_channels"], channels=cfg["channels"])

    # Weights are materialised on CPU first and every tensor is cast to float32
    # (presumably because the checkpoint may be stored in reduced precision).
    raw_state = torch.load(weights_path, map_location="cpu", weights_only=True)
    net.load_state_dict({name: tensor.float() for name, tensor in raw_state.items()})
    net.eval()

    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = net.to(target)

    return {"model": net, "device": target, "context_len": cfg["context_len"]}
|
|
|
|
| def _prepare_input(context_frames, context_len): |
| N = len(context_frames) |
| if N >= context_len: |
| frames = context_frames[-context_len:] |
| else: |
| pad = np.repeat(context_frames[:1], context_len - N, axis=0) |
| frames = np.concatenate([pad, context_frames], axis=0) |
| frames_f = frames.astype(np.float32) / 255.0 |
| frames_f = np.transpose(frames_f, (0, 3, 1, 2)) |
| context = frames_f.reshape(1, -1, 64, 64) |
| last_frame = frames_f[-1:] |
| return context, last_frame |
|
|
|
|
def _run_single_pass(model, device, frames, context_len):
    """Prepare ``frames``, run one no-grad forward pass and return the model's first output."""
    ctx, last = _prepare_input(frames, context_len)
    with torch.no_grad():
        ctx_t = torch.from_numpy(ctx).to(device)
        last_t = torch.from_numpy(last).to(device)
        pred, _ = model(ctx_t, last_t)
    return pred


def predict_next_frame(model_dict, context_frames: np.ndarray) -> np.ndarray:
    """Predict the next frame using horizontal-flip test-time augmentation.

    The model is run once on the original context and once on a copy flipped
    along the width axis. The second prediction is flipped back, and the two
    are averaged.

    Args:
        model_dict: mapping with keys ``"model"``, ``"device"`` and
            ``"context_len"`` (the structure returned by ``load_model``).
        context_frames: frames indexed as (num_frames, H, W, C); axis 2 is
            the one flipped for augmentation.

    Returns:
        uint8 array of shape (H, W, C). The model output is scaled by 255,
        rounded to the nearest integer and clipped to [0, 255].
    """
    model = model_dict["model"]
    device = model_dict["device"]
    context_len = model_dict["context_len"]

    frames = np.asarray(context_frames)

    pred_plain = _run_single_pass(model, device, frames, context_len)

    # Flip the input along W (axis 2 of (N, H, W, C)); the prediction is laid
    # out as (1, C, H, W), so W is its last axis and flip(-1) undoes the flip.
    flipped = np.ascontiguousarray(frames[:, :, ::-1, :])
    pred_flip = _run_single_pass(model, device, flipped, context_len).flip(-1)

    pred = (pred_plain + pred_flip) / 2.0
    pred_np = np.transpose(pred[0].cpu().numpy(), (1, 2, 0))  # -> (H, W, C)
    # Round instead of truncating: astype(uint8) drops the fraction, which
    # biases every pixel down by ~0.5 and turns float noise such as 127.9999
    # (from x / 255 * 255) into 127 instead of 128.
    return np.rint(pred_np * 255.0).clip(0, 255).astype(np.uint8)
|
|