"""Helpers for running a RAFT optical-flow model on image sequences and visualising the flow."""
import argparse
import glob
import os
import sys

import cv2
import numpy as np
import torch
from PIL import Image

from .raft import RAFT
from .utils import flow_viz
from .utils.utils import InputPadder

# Fall back to CPU so the helpers still run on machines without a GPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# NOTE(review): user-specific default kept only for backward compatibility;
# callers should pass ``output_path`` explicitly to ``viz``.
_DEFAULT_FLOW_PATH = '/home/chengao/test/flow.png'


def load_image(imfile):
    """Load an image file as a float tensor of shape (3, H, W) with values in 0..255.

    The image is converted to RGB first, so grayscale, palette and RGBA
    inputs all yield exactly three channels.
    """
    # ``with`` closes the underlying file handle as soon as pixels are read.
    with Image.open(imfile) as pil_img:
        arr = np.array(pil_img.convert('RGB'), dtype=np.uint8)
    # HWC (numpy/PIL layout) -> CHW (torch layout).
    return torch.from_numpy(arr).permute(2, 0, 1).float()


def load_image_list(image_files):
    """Load ``image_files`` (in sorted order) into one batched tensor on ``DEVICE``.

    The stacked batch has shape (N, 3, H, W) and is then padded with
    ``InputPadder`` (presumably to a size the RAFT model accepts — confirm
    in ``utils.utils``). All images must share the same size.

    Raises:
        ValueError: if ``image_files`` is empty.
    """
    ordered = sorted(image_files)
    if not ordered:
        raise ValueError("load_image_list() received no image files")
    images = torch.stack([load_image(f) for f in ordered], dim=0).to(DEVICE)
    padder = InputPadder(images.shape)
    return padder.pad(images)[0]


def viz(img, flo, output_path=_DEFAULT_FLOW_PATH):
    """Render the first flow field in ``flo`` as a colour image and write it to disk.

    Args:
        img: the source frame batch; currently unused, kept for API compatibility.
        flo: flow tensor, assumed (1, 2, H, W) — only index 0 is rendered.
        output_path: destination image file (overwritten on every call).

    Raises:
        OSError: if OpenCV fails to write the file.
    """
    flo = flo[0].permute(1, 2, 0).cpu().numpy()
    # Map flow vectors to an RGB visualisation.
    flo_img = flow_viz.flow_to_image(flo)
    # Reverse channel order for cv2.imwrite (presumably flow_to_image yields RGB).
    if not cv2.imwrite(output_path, flo_img[:, :, [2, 1, 0]]):
        raise OSError("cv2.imwrite failed to write %r" % (output_path,))


def demo(args):
    """Run RAFT on consecutive frame pairs found in ``args.path`` and visualise each flow.

    ``args`` must provide ``model`` (checkpoint path), ``path`` (directory
    with ``*.png`` / ``*.jpg`` frames) and whatever ``RAFT`` itself reads.
    Each pair's visualisation is written via ``viz`` to its default path,
    so the file ends up holding the last pair's flow.
    """
    model = RAFT_infer(args)

    with torch.no_grad():
        image_files = (glob.glob(os.path.join(args.path, '*.png'))
                       + glob.glob(os.path.join(args.path, '*.jpg')))
        images = load_image_list(image_files)

        for i in range(images.shape[0] - 1):
            # ``None`` indexing keeps a leading batch dimension of size 1.
            image1 = images[i, None]
            image2 = images[i + 1, None]
            _, flow_up = model(image1, image2, iters=20, test_mode=True)
            viz(image1, flow_up)


def RAFT_infer(args):
    """Build a RAFT model from ``args``, load the ``args.model`` checkpoint, return it in eval mode on ``DEVICE``."""
    # Wrapping in DataParallel lets load_state_dict accept checkpoints whose
    # keys carry the "module." prefix; the bare module is unwrapped afterwards.
    model = torch.nn.DataParallel(RAFT(args))
    # Load to CPU first so checkpoints saved on any GPU index load correctly;
    # the model is moved to DEVICE below.
    model.load_state_dict(torch.load(args.model, map_location='cpu'))
    model = model.module
    model.to(DEVICE)
    model.eval()
    return model