# coding=utf-8
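"""Computes forward and backward optical flow for folders of video frames with RAFT.

For every video sub-directory under --path, the frames are loaded (optionally
resized to --width x --height), RAFT is run on every consecutive frame pair in
both temporal directions, and the resulting flow fields are written as .flo
files to --outroot/<video>/forward_flo and --outroot/<video>/backward_flo.
"""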
import os
import sys
sys.path.append(os.path.abspath(os.path.join(__file__, '..', '..')))
import argparse
import cv2
import glob
import copy
import numpy as np
import torch
from PIL import Image
import scipy.ndimage
import torchvision.transforms.functional as F
import torch.nn.functional as F2
from RAFT import utils
from RAFT import RAFT
import utils.region_fill as rf
from torchvision.transforms import ToTensor
import time


def to_tensor(img):
    img = Image.fromarray(img)
    img_t = F.to_tensor(img).float()
    return img_t


def gradient_mask(mask):  # mask used for gradient computation
    gradient_mask = np.logical_or.reduce((
        mask,
        np.concatenate((mask[1:, :], np.zeros((1, mask.shape[1]), dtype=bool)), axis=0),
        np.concatenate((mask[:, 1:], np.zeros((mask.shape[0], 1), dtype=bool)), axis=1)))
    return gradient_mask
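
# Example: with m = np.zeros((3, 3), dtype=bool) and m[1, 1] = True,
# gradient_mask(m) is True at (1, 1), (0, 1) and (1, 0), i.e. the mask is
# dilated by one pixel towards the top and the left.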


def create_dir(dir):
    """Creates a directory if it does not exist.
    """
    if not os.path.exists(dir):
        os.makedirs(dir)


def initialize_RAFT(args):
    """Initializes the RAFT model.
    """
    model = torch.nn.DataParallel(RAFT(args))
    model.load_state_dict(torch.load(args.model))

    model = model.module
    model.to('cuda')
    model.eval()

    return model


def calculate_flow(args, model, vid, video, mode):
    """Calculates optical flow.
    """
    if mode not in ['forward', 'backward']:
        raise NotImplementedError

    nFrame, _, imgH, imgW = video.shape
    Flow = np.empty((imgH, imgW, 2, 0), dtype=np.float32)

    create_dir(os.path.join(args.outroot, vid, mode + '_flo'))
    # create_dir(os.path.join(args.outroot, vid, 'flow', mode + '_png'))

    with torch.no_grad():
        for i in range(video.shape[0] - 1):
            print("Calculating {0} flow {1:2d} <---> {2:2d}".format(mode, i, i + 1), '\r', end='')
            if mode == 'forward':
                # Flow i -> i + 1
                image1 = video[i, None]
                image2 = video[i + 1, None]
            elif mode == 'backward':
                # Flow i + 1 -> i
                image1 = video[i + 1, None]
                image2 = video[i, None]
            else:
                raise NotImplementedError

            _, flow = model(image1, image2, iters=20, test_mode=True)
            flow = flow[0].permute(1, 2, 0).cpu().numpy()
            # Flow = np.concatenate((Flow, flow[..., None]), axis=-1)

            # Flow visualization.
            # flow_img = utils.flow_viz.flow_to_image(flow)
            # flow_img = Image.fromarray(flow_img)

            # Saves the flow and flow_img.
            # flow_img.save(os.path.join(args.outroot, vid, 'flow', mode + '_png', '%05d.png' % i))
            utils.frame_utils.writeFlow(os.path.join(args.outroot, vid, mode + '_flo', '%05d.flo' % i), flow)
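
# Note: each call to calculate_flow writes nFrame - 1 flow fields for the given
# direction, one .flo file per consecutive frame pair, via utils.frame_utils.writeFlow.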


def main(args):
    # Flow model.
    RAFT_model = initialize_RAFT(args)

    videos = os.listdir(args.path)
    videoLen = len(videos)

    # Videos whose names appear under --expdir are skipped.
    if args.expdir and os.path.isdir(args.expdir):
        exceptList = os.listdir(args.expdir)
    else:
        exceptList = []

    v = 0
    for vid in videos:
        v += 1
        print('[{}]/[{}] Video {} is being processed'.format(v, len(videos), vid))
        if vid in exceptList:
            print('Video: {} skipped'.format(vid))
            continue

        # Loads frames.
        filename_list = glob.glob(os.path.join(args.path, vid, '*.png')) + \
                        glob.glob(os.path.join(args.path, vid, '*.jpg'))

        # Obtains imgH, imgW and nFrame.
        imgH, imgW = np.array(Image.open(filename_list[0])).shape[:2]
        nFrame = len(filename_list)
        print('{} frames found'.format(nFrame))

        # Loads video.
        video = []
        for filename in sorted(filename_list):
            print(filename)
            img = np.array(Image.open(filename))
            if args.width != 0 and args.height != 0:
                img = cv2.resize(img, (args.width, args.height), interpolation=cv2.INTER_LINEAR)
            video.append(torch.from_numpy(img.astype(np.uint8)).permute(2, 0, 1).float())
        video = torch.stack(video, dim=0)
        video = video.to('cuda')

        # Calculates the corrupted flow.
        start = time.time()
        calculate_flow(args, RAFT_model, vid, video, 'forward')
        calculate_flow(args, RAFT_model, vid, video, 'backward')
        end = time.time()
        sumTime = end - start
        print('{}/{}, video {} is finished. {} frames took {:.2f}s ({:.3f}s/frame).'.format(
            v, videoLen, vid, nFrame, sumTime, sumTime / (2 * nFrame)))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Flow basic setting.
    parser.add_argument('--path', required=True, type=str, help='root directory with one sub-directory of frames per video')
    parser.add_argument('--expdir', type=str, help='optional directory whose entries name videos to skip')
    parser.add_argument('--outroot', required=True, type=str, help='output root for the .flo files')
    parser.add_argument('--width', type=int, default=432, help='resize width (no resizing if --width or --height is 0)')
    parser.add_argument('--height', type=int, default=256, help='resize height (no resizing if --width or --height is 0)')

    # RAFT
    parser.add_argument('--model', default='../weight/raft-things.pth', help="restore checkpoint")
    parser.add_argument('--small', action='store_true', help='use small model')
    parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
    parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')

    args = parser.parse_args()
    main(args)
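
# Example invocation (illustrative; adjust the script name and paths to your setup):
#   python get_flow.py --path ./frames --outroot ./flow --model ../weight/raft-things.pth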