Spaces:
Runtime error
Runtime error
| from numpy.lib.npyio import load | |
| from torch._C import device | |
| import sys | |
| sys.path.append('/scratch/shared/beegfs/szwu/projects/video3d/RAFT') | |
| from core.raft import RAFT | |
| from .utils import InputPadder | |
| import torch | |
class AttrDict(dict):
    """Dictionary whose keys can also be read and written as attributes.

    The instance's attribute namespace is the dict itself, so ``d["k"]``
    and ``d.k`` always refer to the same entry.
    """

    def __init__(self, *args, **kwargs):
        # Accept exactly the same arguments as the built-in ``dict``.
        super().__init__(*args, **kwargs)
        # Alias the attribute namespace to the mapping: attribute get/set
        # now reads and writes dictionary entries.
        self.__dict__ = self
class FlowModel():
    """Wrapper around a pretrained RAFT network for frame-pair optical flow.

    The checkpoint is loaded once at construction. ``compute_flow`` then
    takes two numpy frames and returns the estimated flow as a numpy array.
    """

    def __init__(self, model, device):
        """Build RAFT and load its weights.

        Args:
            model: checkpoint path; stored under the ``'model'`` key and
                passed to ``torch.load``.
            device: torch device (or device string) that the network and
                all input frames are moved to.
        """
        # Configuration for RAFT, wrapped in AttrDict so it can be read as
        # attributes (args.small, args.model, ...) as well as keys.
        args = AttrDict({'model': model, 'small': False, 'mixed_precision': False, 'alternate_corr': False})
        self.model = self.load_model(args, device)
        self.device = device

    @staticmethod
    def load_model(args, device):
        """Construct RAFT from ``args`` and load the checkpoint at ``args.model``.

        Declared as a staticmethod because it uses no instance state but is
        called as ``self.load_model(args, device)``. As a plain function
        without ``self``, that call passed three positional arguments and
        raised ``TypeError``.

        Args:
            args: attribute-accessible config; ``args.model`` is the
                checkpoint path.
            device: torch device the returned network is moved to.

        Returns:
            The unwrapped RAFT module on ``device``, switched to eval mode.
        """
        # Wrap in DataParallel only so the checkpoint's parameter names match,
        # then unwrap below. Presumably the checkpoint was saved from a
        # DataParallel model ("module." key prefix); TODO confirm.
        model = torch.nn.DataParallel(RAFT(args))
        # The parameters are still on CPU here (they are moved to `device`
        # below), so load the tensors onto CPU as well. This also works on
        # hosts that lack the GPU the checkpoint was saved from.
        model.load_state_dict(torch.load(args.model, map_location='cpu'))
        model = model.module
        model.to(device)
        model.eval()
        return model

    def preprocess_image(self, image):
        """Turn one frame into a padded, batched tensor on ``self.device``.

        Args:
            image: rank-3 channels-last numpy array, indexed (H, W, C), as
                required by the ``permute(2, 0, 1)`` below. It is used as-is:
                no channel reordering, value scaling or resizing is applied.

        Returns:
            ``(tensor, padder)``: the padded float tensor of shape
            (1, C, H', W') and the ``InputPadder`` needed to undo the padding.
        """
        image = torch.from_numpy(image).permute(2, 0, 1).float()
        image = image.to(self.device)
        image = image[None]  # add a leading batch dimension
        padder = InputPadder(image.shape)
        return padder.pad(image)[0], padder

    @torch.no_grad()
    def compute_flow(self, frame, next_frame, iters=20):
        """Estimate optical flow from ``frame`` to ``next_frame``.

        Runs under ``torch.no_grad()``. The model's parameters require grad,
        so without it the output would be attached to an autograd graph and
        ``.numpy()`` would raise.

        Args:
            frame: first frame, a (H, W, C) numpy array.
            next_frame: second frame; assumed to have the same shape as
                ``frame``.
            iters: number of refinement iterations forwarded to RAFT.

        Returns:
            numpy array of shape (H, W, K) with the padding removed. K is the
            number of flow channels RAFT emits (presumably 2: dx, dy).
        """
        frame, _ = self.preprocess_image(frame)
        # Padding depends only on the frame shape, so the second frame's
        # padder can undo the padding of the flow.
        next_frame, padder = self.preprocess_image(next_frame)
        _, flow = self.model(frame, next_frame, iters=iters, test_mode=True)
        return padder.unpad(flow)[0].permute(1, 2, 0).cpu().numpy()