oguzakif's picture
init repo
d4b77ac
raw
history blame
1.86 kB
import sys
import argparse
import os
import cv2
import glob
import numpy as np
import torch
from PIL import Image
from .raft import RAFT
from .utils import flow_viz
from .utils.utils import InputPadder
# Run on the GPU when torch can see one, otherwise fall back to the CPU so the
# helpers in this module stay usable on CPU-only machines.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
def load_image(imfile):
    """Load an image file as a float tensor in channel-first layout.

    The image is converted to 3-channel RGB, so grayscale, palette and RGBA
    files all produce the same layout. Pixel values are those of the 8-bit
    RGB image (0-255) cast to float32; no normalisation is applied here.

    Args:
        imfile: path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        A float32 ``torch.Tensor`` of shape ``(3, H, W)``.
    """
    # The context manager closes the underlying file handle promptly instead
    # of leaving it open until garbage collection.
    with Image.open(imfile) as pil_img:
        # convert('RGB') guarantees a trailing channel axis of size 3; without
        # it a grayscale/palette image yields a 2-D array (permute would fail)
        # and an RGBA image yields 4 channels.
        img = np.array(pil_img.convert('RGB'), dtype=np.uint8)
    return torch.from_numpy(img).permute(2, 0, 1).float()
def load_image_list(image_files):
    """Load, batch and pad a collection of image files.

    Files are loaded in sorted path order with ``load_image``, stacked into a
    single batch tensor, moved to ``DEVICE`` and padded with ``InputPadder``.

    Args:
        image_files: iterable of image paths. All images must have the same
            size, because they are stacked into one tensor.

    Returns:
        The padded batch tensor of shape ``(N, C, H', W')`` on ``DEVICE``,
        where ``H'``/``W'`` are the sizes after ``InputPadder`` padding.

    Raises:
        ValueError: if ``image_files`` is empty.
    """
    images = [load_image(path) for path in sorted(image_files)]
    if not images:
        # torch.stack rejects an empty list with an opaque RuntimeError;
        # fail early with a message that names the actual problem.
        raise ValueError('load_image_list() received no image files')
    batch = torch.stack(images, dim=0).to(DEVICE)
    padder = InputPadder(batch.shape)
    # pad() returns an indexable sequence; its first element is the padded
    # batch (presumably one entry per tensor passed in — see InputPadder).
    return padder.pad(batch)[0]
def viz(img, flo, out_path='/home/chengao/test/flow.png'):
    """Colour-code the first flow field of a batch and save it as an image.

    Args:
        img: batch of input frames. It is not used for the saved image and is
            kept only so existing ``viz(image, flow)`` call sites keep working.
        flo: batch of flow fields. Element ``[0]`` is permuted with
            ``(1, 2, 0)``, so it is assumed to be ``(2, H, W)`` — TODO confirm.
        out_path: destination file. The default preserves the historical
            hard-coded location; pass a path explicitly on other machines.

    Raises:
        OSError: if OpenCV fails to write ``out_path``. ``cv2.imwrite`` only
            signals failure through its return value, which was previously
            ignored.
    """
    flow_hw2 = flo[0].permute(1, 2, 0).cpu().numpy()
    # Map the flow field to a colour image (presumably RGB uint8 — verify in
    # utils/flow_viz).
    flow_rgb = flow_viz.flow_to_image(flow_hw2)
    # OpenCV writes images in BGR channel order, hence the [2, 1, 0] swap.
    if not cv2.imwrite(out_path, flow_rgb[:, :, [2, 1, 0]]):
        raise OSError(f'cv2.imwrite could not write flow image to {out_path!r}')
def demo(args):
    """Run RAFT on consecutive frames found in ``args.path``.

    Every ``*.png`` and ``*.jpg`` file in the directory is loaded (in sorted
    path order via ``load_image_list``). The model is applied to each adjacent
    pair ``(frame[i], frame[i + 1])`` with ``iters=20, test_mode=True``, and
    the second output (``flow_up``) of each pair is passed to ``viz``.

    Args:
        args: parsed command-line namespace. ``args.path`` is read here;
            ``args.model`` and the RAFT options are consumed by ``RAFT_infer``.
    """
    # Reuse the shared loader rather than duplicating the checkpoint logic.
    model = RAFT_infer(args)

    with torch.no_grad():
        image_files = (glob.glob(os.path.join(args.path, '*.png'))
                       + glob.glob(os.path.join(args.path, '*.jpg')))
        frames = load_image_list(image_files)

        for i in range(frames.shape[0] - 1):
            # Slicing with i:i + 1 keeps the leading batch dimension of size 1.
            image1 = frames[i:i + 1]
            image2 = frames[i + 1:i + 2]
            _, flow_up = model(image1, image2, iters=20, test_mode=True)
            # NOTE: viz saves every pair to the same file, so only the last
            # pair's visualisation remains on disk.
            viz(image1, flow_up)
def RAFT_infer(args):
    """Build a RAFT model, load its checkpoint and prepare it for inference.

    Args:
        args: namespace passed straight to ``RAFT``; ``args.model`` is the
            path of the checkpoint file to load.

    Returns:
        The bare ``RAFT`` module (unwrapped from ``DataParallel``), moved to
        ``DEVICE`` and switched to eval mode.
    """
    # The state dict is loaded through a DataParallel wrapper — presumably
    # because the checkpoint keys carry DataParallel's "module." prefix;
    # verify against the checkpoint before simplifying this.
    wrapper = torch.nn.DataParallel(RAFT(args))
    # map_location lets a checkpoint saved from GPU tensors be loaded onto
    # whatever DEVICE is, including a CPU-only machine.
    wrapper.load_state_dict(torch.load(args.model, map_location=DEVICE))
    model = wrapper.module
    model.to(DEVICE)
    model.eval()
    return model