oguzakif's picture
init repo
d4b77ac
raw
history blame
5.57 kB
# coding=utf-8
import os
import sys
sys.path.append(os.path.abspath(os.path.join(__file__, '..', '..')))
import argparse
import os
import cv2
import glob
import copy
import numpy as np
import torch
from PIL import Image
import scipy.ndimage
import torchvision.transforms.functional as F
import torch.nn.functional as F2
from RAFT import utils
from RAFT import RAFT
import utils.region_fill as rf
from torchvision.transforms import ToTensor
import time
def to_tensor(img):
    """Converts an image array into a float tensor.

    The array is first wrapped in a PIL image with ``Image.fromarray`` and
    then passed through torchvision's functional ``to_tensor``.

    Args:
        img: Array accepted by ``PIL.Image.fromarray``.

    Returns:
        The tensor produced by ``to_tensor``, cast to float.
    """
    pil_image = Image.fromarray(img)
    return F.to_tensor(pil_image).float()
def gradient_mask(mask):
    """Generates the gradient mask for a 2-D mask.

    A pixel is set in the result if it is set in ``mask`` itself, in the
    pixel directly below it, or in the pixel directly to its right. Pixels
    past the bottom/right border are treated as unset.

    Args:
        mask: 2-D array (indexed as ``mask[row, col]``).

    Returns:
        Boolean array with the same shape as ``mask``.
    """
    # Builtin ``bool`` is used instead of the ``np.bool`` alias, which was
    # removed in NumPy 1.24 and raises AttributeError there.
    pad_row = np.zeros((1, mask.shape[1]), dtype=bool)
    pad_col = np.zeros((mask.shape[0], 1), dtype=bool)

    # mask shifted up by one row / left by one column, zero-padded at the end.
    shifted_up = np.concatenate((mask[1:, :], pad_row), axis=0)
    shifted_left = np.concatenate((mask[:, 1:], pad_col), axis=1)

    return np.logical_or.reduce((mask, shifted_up, shifted_left))
def create_dir(dir):
    """Creates a directory (including missing parents) if it does not exist.

    Args:
        dir: Path of the directory to create.

    Raises:
        FileExistsError: If ``dir`` exists but is not a directory.
    """
    # exist_ok avoids the race between a separate existence check and the
    # creation when several processes write to the same output root.
    os.makedirs(dir, exist_ok=True)
def initialize_RAFT(args):
    """Builds the RAFT model, loads its weights and prepares it for inference.

    Args:
        args: Namespace passed to the ``RAFT`` constructor; ``args.model`` is
            the path of the checkpoint to restore.

    Returns:
        The unwrapped RAFT module, moved to CUDA and switched to eval mode.
    """
    # The state dict is loaded through a DataParallel wrapper and the inner
    # module is extracted afterwards (presumably the checkpoint was saved
    # from a DataParallel model -- confirm against the weight file).
    wrapped = torch.nn.DataParallel(RAFT(args))
    state_dict = torch.load(args.model)
    wrapped.load_state_dict(state_dict)

    raft = wrapped.module
    raft.to('cuda')
    raft.eval()
    return raft
def calculate_flow(args, model, vid, video, mode):
    """Calculates optical flow between consecutive frames and saves it.

    For every pair of neighbouring frames the model is run with ``iters=20``
    and ``test_mode=True``; the resulting flow is written as
    ``<args.outroot>/<vid>/<mode>_flo/%05d.flo`` where the number is the
    index of the earlier frame of the pair.

    Args:
        args: Namespace providing ``outroot``, the root output directory.
        model: Callable flow model; its second return value is used as flow.
        vid: Name of the video, used as the output sub-directory.
        video: Frames indexed along the first axis (``main`` passes a
            ``(frames, channels, height, width)`` tensor).
        mode: ``'forward'`` (flow i -> i + 1) or ``'backward'``
            (flow i + 1 -> i).

    Raises:
        NotImplementedError: If ``mode`` is neither ``'forward'`` nor
            ``'backward'``.
    """
    if mode not in ['forward', 'backward']:
        raise NotImplementedError

    # The output directory does not depend on the frame index: build it once.
    flow_dir = os.path.join(args.outroot, vid, mode + '_flo')
    create_dir(flow_dir)

    with torch.no_grad():
        for i in range(video.shape[0] - 1):
            print("Calculating {0} flow {1:2d} <---> {2:2d}".format(mode, i, i + 1), '\r', end='')

            if mode == 'forward':
                # Flow i -> i + 1
                src, dst = i, i + 1
            else:
                # Flow i + 1 -> i (mode was validated above)
                src, dst = i + 1, i

            # Indexing with None keeps a leading dimension of size 1.
            image1 = video[src, None]
            image2 = video[dst, None]

            _, flow = model(image1, image2, iters=20, test_mode=True)
            # Take the single batch entry and move its first axis last
            # before converting to a NumPy array for writing.
            flow = flow[0].permute(1, 2, 0).cpu().numpy()

            utils.frame_utils.writeFlow(os.path.join(flow_dir, '%05d.flo' % i), flow)
def main(args):
    """Computes forward and backward flow for every video under ``args.path``.

    Each entry of ``args.path`` is treated as a folder of ``.png``/``.jpg``
    frames. Entries whose name also appears in ``args.expdir`` are skipped,
    as are entries that contain no frames.

    Args:
        args: Parsed command-line namespace. Uses ``path``, ``expdir``,
            ``outroot``, ``width``, ``height`` and the RAFT options consumed
            by ``initialize_RAFT``.
    """
    # Flow model.
    RAFT_model = initialize_RAFT(args)

    videos = os.listdir(args.path)
    videoLen = len(videos)

    # os.listdir(None) would list the current working directory, so an
    # omitted --expdir must be handled explicitly instead of being passed on.
    exceptList = []
    if args.expdir is not None:
        try:
            exceptList = os.listdir(args.expdir)
        except OSError:
            # A missing or unreadable directory means nothing is skipped.
            exceptList = []
    skipped = set(exceptList)

    for v, vid in enumerate(videos, start=1):
        print('[{}]/[{}] Video {} is being processed'.format(v, videoLen, vid))
        if vid in skipped:
            print('Video: {} skipped'.format(vid))
            continue

        # Collects the frame paths.
        filename_list = sorted(
            glob.glob(os.path.join(args.path, vid, '*.png')) +
            glob.glob(os.path.join(args.path, vid, '*.jpg'))
        )
        nFrame = len(filename_list)
        if nFrame == 0:
            # Nothing to process (empty folder or not a frame directory).
            print('Video: {} skipped, no .png/.jpg frames found'.format(vid))
            continue
        print('images are loaded')

        # Loads video.
        video = []
        for filename in filename_list:
            print(filename)
            img = np.array(Image.open(filename))
            if args.width != 0 and args.height != 0:
                # interpolation must be passed by keyword: the third
                # positional parameter of cv2.resize is ``dst``.
                img = cv2.resize(img, (args.width, args.height),
                                 interpolation=cv2.INTER_LINEAR)
            video.append(torch.from_numpy(img.astype(np.uint8)).permute(2, 0, 1).float())

        video = torch.stack(video, dim=0)
        video = video.to('cuda')

        # Calculates the forward and backward flow and times both passes.
        start = time.time()
        calculate_flow(args, RAFT_model, vid, video, 'forward')
        calculate_flow(args, RAFT_model, vid, video, 'backward')
        end = time.time()
        sumTime = end - start
        print('{}/{}, video {} is finished. {} frames takes {}s, {}s/frame.'.format(v, videoLen, vid, nFrame, sumTime,
                                                                                   sumTime / (2 * nFrame)))
if __name__ == '__main__':
    # Command-line entry point: parse the options and run the extraction.
    cli = argparse.ArgumentParser()

    # Basic flow settings: input videos, videos to skip, output root and the
    # size frames are resized to (no resizing when either value is 0).
    cli.add_argument('--path', type=str, required=True)
    cli.add_argument('--expdir', type=str)
    cli.add_argument('--outroot', type=str, required=True)
    cli.add_argument('--width', default=432, type=int)
    cli.add_argument('--height', default=256, type=int)

    # RAFT options.
    cli.add_argument('--model', help="restore checkpoint", default='../weight/raft-things.pth')
    for flag, description in (
        ('--small', 'use small model'),
        ('--mixed_precision', 'use mixed precision'),
        ('--alternate_corr', 'use efficent correlation implementation'),
    ):
        cli.add_argument(flag, action='store_true', help=description)

    args = cli.parse_args()
    main(args)