# Source: coder-model/predict.py by ojaffe
# Uploaded with huggingface_hub (commit 46c40c4, verified)
"""Inference for AR curriculum model + TTA."""
import json
import numpy as np
import torch
import sys
sys.path.insert(0, "/home/coder/code")
from flow_warp_attn_model import FlowWarpAttnUNet
def load_model(model_dir: str):
    """Build the FlowWarpAttnUNet stored in ``model_dir`` and prepare it for inference.

    ``model_dir`` must contain ``config.json`` (read for the ``in_channels``,
    ``channels`` and ``context_len`` keys) and ``model.pt`` (a state dict).

    Returns:
        A dict with:
          - ``"model"``: the network in eval mode, moved to the chosen device;
          - ``"device"``: CUDA when available, otherwise CPU;
          - ``"context_len"``: copied from the config.
    """
    config_path = f"{model_dir}/config.json"
    weights_path = f"{model_dir}/model.pt"

    with open(config_path) as handle:
        cfg = json.load(handle)

    net = FlowWarpAttnUNet(in_channels=cfg["in_channels"], channels=cfg["channels"])

    # Load on CPU first, then cast every tensor with .float(). Presumably the
    # checkpoint may be stored in reduced precision; confirm against the training script.
    state = torch.load(weights_path, map_location="cpu", weights_only=True)
    net.load_state_dict({name: tensor.float() for name, tensor in state.items()})
    net.eval()

    if torch.cuda.is_available():
        target = torch.device("cuda")
    else:
        target = torch.device("cpu")
    net = net.to(target)

    return {"model": net, "device": target, "context_len": cfg["context_len"]}
def _prepare_input(context_frames, context_len):
N = len(context_frames)
if N >= context_len:
frames = context_frames[-context_len:]
else:
pad = np.repeat(context_frames[:1], context_len - N, axis=0)
frames = np.concatenate([pad, context_frames], axis=0)
frames_f = frames.astype(np.float32) / 255.0
frames_f = np.transpose(frames_f, (0, 3, 1, 2))
context = frames_f.reshape(1, -1, 64, 64)
last_frame = frames_f[-1:]
return context, last_frame
def predict_next_frame(model_dict, context_frames: np.ndarray) -> np.ndarray:
model = model_dict["model"]
device = model_dict["device"]
context_len = model_dict["context_len"]
ctx, last = _prepare_input(context_frames, context_len)
with torch.no_grad():
ctx_t = torch.from_numpy(ctx).to(device)
last_t = torch.from_numpy(last).to(device)
pred1, _ = model(ctx_t, last_t)
flipped_frames = context_frames[:, :, ::-1, :].copy()
ctx_f, last_f = _prepare_input(flipped_frames, context_len)
with torch.no_grad():
ctx_ft = torch.from_numpy(ctx_f).to(device)
last_ft = torch.from_numpy(last_f).to(device)
pred2, _ = model(ctx_ft, last_ft)
pred2 = pred2.flip(-1)
pred = (pred1 + pred2) / 2.0
pred_np = pred[0].cpu().numpy()
pred_np = np.transpose(pred_np, (1, 2, 0))
return (pred_np * 255.0).clip(0, 255).astype(np.uint8)