import cvbase from torchvision.transforms import ToTensor from get_flowNN_gradient import get_flowNN_gradient from utils.Poisson_blend_img import Poisson_blend_img from utils.region_fill import regionfill from importlib import import_module import yaml from RAFT import RAFT from RAFT import utils import torch.nn.functional as F2 import torchvision.transforms.functional as F from skimage.feature import canny import scipy.ndimage from PIL import Image import imageio import torch import numpy as np import copy import glob import cv2 import argparse import warnings import os import sys sys.path.append(os.path.abspath(os.path.join(__file__, '..', '..'))) sys.path.append(os.path.abspath(os.path.join(__file__, '..', '..', 'tool'))) sys.path.append(os.path.abspath(os.path.join( __file__, '..', '..', 'tool', 'utils'))) sys.path.append(os.path.abspath(os.path.join( __file__, '..', '..', 'tool', 'utils', 'region_fill.py'))) sys.path.append(os.path.abspath(os.path.join( __file__, '..', '..', 'tool', 'utils', 'Poisson_blend_img.py'))) sys.path.append(os.path.abspath(os.path.join(__file__, '..', '..', 'FGT'))) sys.path.append(os.path.abspath(os.path.join(__file__, '..', '..', 'LAFC'))) sys.path.append(os.path.abspath( os.path.join(os.path.dirname("__file__"), '..'))) warnings.filterwarnings("ignore") def to_tensor(img): img = Image.fromarray(img) img_t = F.to_tensor(img).float() return img_t def diffusion(flows, masks): flows_filled = [] for i in range(flows.shape[0]): flow, mask = flows[i], masks[i] flow_filled = np.zeros(flow.shape) flow_filled[:, :, 0] = regionfill(flow[:, :, 0], mask[:, :, 0]) flow_filled[:, :, 1] = regionfill(flow[:, :, 1], mask[:, :, 0]) flows_filled.append(flow_filled) return flows_filled def np2tensor(array, near='c'): if isinstance(array, list): array = np.stack(array, axis=0) # [t, h, w, c] if near == 'c': array = torch.from_numpy(np.transpose(array, (3, 0, 1, 2))).unsqueeze( 0).float() # [1, c, t, h, w] elif near == 't': array = torch.from_numpy(np.transpose( array, (0, 3, 1, 2))).unsqueeze(0).float() else: raise ValueError(f'Unknown near type: {near}') return array def tensor2np(array): array = torch.stack(array, dim=-1).squeeze(0).permute(1, 2, 0, 3).cpu().numpy() return array def gradient_mask(mask): gradient_mask = np.logical_or.reduce((mask, np.concatenate((mask[1:, :], np.zeros((1, mask.shape[1]), dtype=np.bool)), axis=0), np.concatenate((mask[:, 1:], np.zeros((mask.shape[0], 1), dtype=np.bool)), axis=1))) return gradient_mask def indicesGen(pivot, interval, frames, t): singleSide = frames // 2 results = [] for i in range(-singleSide, singleSide + 1): index = pivot + interval * i if index < 0: index = abs(index) if index > t - 1: index = 2 * (t - 1) - index results.append(index) return results def get_ref_index(f, neighbor_ids, length, ref_length, num_ref): ref_index = [] if num_ref == -1: for i in range(0, length, ref_length): if i not in neighbor_ids: ref_index.append(i) else: start_idx = max(0, f - ref_length * (num_ref // 2)) end_idx = min(length, f + ref_length * (num_ref // 2)) for i in range(start_idx, end_idx + 1, ref_length): if i not in neighbor_ids: if len(ref_index) > num_ref: break ref_index.append(i) return ref_index def save_flows(output, videoFlowF, videoFlowB): create_dir(os.path.join(output, 'completed_flow', 'forward_flo')) create_dir(os.path.join(output, 'completed_flow', 'backward_flo')) create_dir(os.path.join(output, 'completed_flow', 'forward_png')) create_dir(os.path.join(output, 'completed_flow', 'backward_png')) N = videoFlowF.shape[-1] for i in range(N): forward_flow = videoFlowF[..., i] backward_flow = videoFlowB[..., i] forward_flow_vis = cvbase.flow2rgb(forward_flow) backward_flow_vis = cvbase.flow2rgb(backward_flow) cvbase.write_flow(forward_flow, os.path.join( output, 'completed_flow', 'forward_flo', '{:05d}.flo'.format(i))) cvbase.write_flow(backward_flow, os.path.join( output, 'completed_flow', 'backward_flo', '{:05d}.flo'.format(i))) imageio.imwrite(os.path.join(output, 'completed_flow', 'forward_png', '{:05d}.png'.format(i)), forward_flow_vis) imageio.imwrite(os.path.join(output, 'completed_flow', 'backward_png', '{:05d}.png'.format(i)), backward_flow_vis) def save_fgcp(output, frames, masks): create_dir(os.path.join(output, 'prop_frames')) create_dir(os.path.join(output, 'masks_left')) create_dir(os.path.join(output, 'prop_frames_npy')) create_dir(os.path.join(output, 'masks_left_npy')) assert len(frames) == masks.shape[2] for i in range(len(frames)): cv2.imwrite(os.path.join(output, 'prop_frames', '%05d.png' % i), frames[i] * 255.) cv2.imwrite(os.path.join(output, 'masks_left', '%05d.png' % i), masks[:, :, i] * 255.) np.save(os.path.join(output, 'prop_frames_npy', '%05d.npy' % i), frames[i] * 255.) np.save(os.path.join(output, 'masks_left_npy', '%05d.npy' % i), masks[:, :, i] * 255.) def create_dir(dir): """Creates a directory if not exist. """ if not os.path.exists(dir): os.makedirs(dir) def initialize_RAFT(args, device): """Initializes the RAFT model. """ model = torch.nn.DataParallel(RAFT(args)) model.load_state_dict(torch.load(args.raft_model)) model = model.module model.to(device) model.eval() return model def initialize_LAFC(args, device): print(args.lafc_ckpts) assert len(os.listdir(args.lafc_ckpts)) == 2 checkpoint, config_file = glob.glob(os.path.join(args.lafc_ckpts, '*.tar'))[0], \ glob.glob(os.path.join(args.lafc_ckpts, '*.yaml'))[0] with open(config_file, 'r') as f: configs = yaml.full_load(f) model = configs['model'] pkg = import_module('LAFC.models.{}'.format(model)) model = pkg.Model(configs) state = torch.load(checkpoint, map_location=lambda storage, loc: storage.cuda(device)) model.load_state_dict(state['model_state_dict']) model = model.to(device) return model, configs def initialize_FGT(args, device): assert len(os.listdir(args.fgt_ckpts)) == 2 checkpoint, config_file = glob.glob(os.path.join(args.fgt_ckpts, '*.tar'))[0], \ glob.glob(os.path.join(args.fgt_ckpts, '*.yaml'))[0] with open(config_file, 'r') as f: configs = yaml.full_load(f) model = configs['model'] net = import_module('FGT.models.{}'.format(model)) model = net.Model(configs).to(device) state = torch.load(checkpoint, map_location=lambda storage, loc: storage.cuda(device)) model.load_state_dict(state['model_state_dict']) return model, configs def calculate_flow(args, model, video, mode): """Calculates optical flow. """ if mode not in ['forward', 'backward']: raise NotImplementedError imgH, imgW = args.imgH, args.imgW Flow = np.empty(((imgH, imgW, 2, 0)), dtype=np.float32) if args.vis_flows: create_dir(os.path.join(args.outroot, 'flow', mode + '_flo')) create_dir(os.path.join(args.outroot, 'flow', mode + '_png')) with torch.no_grad(): for i in range(video.shape[0] - 1): print( "Calculating {0} flow {1:2d} <---> {2:2d}".format(mode, i, i + 1), '\r', end='') if mode == 'forward': # Flow i -> i + 1 image1 = video[i, None] image2 = video[i + 1, None] elif mode == 'backward': # Flow i + 1 -> i image1 = video[i + 1, None] image2 = video[i, None] else: raise NotImplementedError _, flow = model(image1, image2, iters=20, test_mode=True) flow = flow[0].permute(1, 2, 0).cpu().numpy() # resize optical flows h, w = flow.shape[:2] if h != imgH or w != imgW: flow = cv2.resize(flow, (imgW, imgH), cv2.INTER_LINEAR) flow[:, :, 0] *= (float(imgW) / float(w)) flow[:, :, 1] *= (float(imgH) / float(h)) Flow = np.concatenate((Flow, flow[..., None]), axis=-1) if args.vis_flows: # Flow visualization. flow_img = utils.flow_viz.flow_to_image(flow) flow_img = Image.fromarray(flow_img) # Saves the flow and flow_img. flow_img.save(os.path.join(args.outroot, 'flow', mode + '_png', '%05d.png' % i)) utils.frame_utils.writeFlow(os.path.join( args.outroot, 'flow', mode + '_flo', '%05d.flo' % i), flow) return Flow def extrapolation(args, video_ori, corrFlowF_ori, corrFlowB_ori): """Prepares the data for video extrapolation. """ imgH, imgW, _, nFrame = video_ori.shape # Defines new FOV. imgH_extr = int(args.H_scale * imgH) imgW_extr = int(args.W_scale * imgW) imgH_extr = imgH_extr - imgH_extr % 4 imgW_extr = imgW_extr - imgW_extr % 4 H_start = int((imgH_extr - imgH) / 2) W_start = int((imgW_extr - imgW) / 2) # Generates the mask for missing region. flow_mask = np.ones(((imgH_extr, imgW_extr)), dtype=np.bool) flow_mask[H_start: H_start + imgH, W_start: W_start + imgW] = 0 mask_dilated = gradient_mask(flow_mask) # Extrapolates the FOV for video. video = np.zeros(((imgH_extr, imgW_extr, 3, nFrame)), dtype=np.float32) video[H_start: H_start + imgH, W_start: W_start + imgW, :, :] = video_ori for i in range(nFrame): print("Preparing frame {0}".format(i), '\r', end='') video[:, :, :, i] = cv2.inpaint((video[:, :, :, i] * 255).astype(np.uint8), flow_mask.astype(np.uint8), 3, cv2.INPAINT_TELEA).astype(np.float32) / 255. # Extrapolates the FOV for flow. corrFlowF = np.zeros( ((imgH_extr, imgW_extr, 2, nFrame - 1)), dtype=np.float32) corrFlowB = np.zeros( ((imgH_extr, imgW_extr, 2, nFrame - 1)), dtype=np.float32) corrFlowF[H_start: H_start + imgH, W_start: W_start + imgW, :] = corrFlowF_ori corrFlowB[H_start: H_start + imgH, W_start: W_start + imgW, :] = corrFlowB_ori return video, corrFlowF, corrFlowB, flow_mask, mask_dilated, (W_start, H_start), (W_start + imgW, H_start + imgH) def complete_flow(config, flow_model, flows, flow_masks, mode, device): if mode not in ['forward', 'backward']: raise NotImplementedError(f'Error flow mode {mode}') flow_masks = np.moveaxis(flow_masks, -1, 0) # [N, H, W] flows = np.moveaxis(flows, -1, 0) # [N, H, W, 2] if len(flow_masks.shape) == 3: flow_masks = flow_masks[:, :, :, np.newaxis] if mode == 'forward': flow_masks = flow_masks[0:-1] else: flow_masks = flow_masks[1:] num_flows, flow_interval = config['num_flows'], config['flow_interval'] diffused_flows = diffusion(flows, flow_masks) flows = np2tensor(flows) flow_masks = np2tensor(flow_masks) diffused_flows = np2tensor(diffused_flows) flows = flows.to(device) flow_masks = flow_masks.to(device) diffused_flows = diffused_flows.to(device) t = diffused_flows.shape[2] filled_flows = [None] * t pivot = num_flows // 2 for i in range(t): indices = indicesGen(i, flow_interval, num_flows, t) print('Indices: ', indices, '\r', end='') cand_flows = flows[:, :, indices] cand_masks = flow_masks[:, :, indices] inputs = diffused_flows[:, :, indices] pivot_mask = cand_masks[:, :, pivot] pivot_flow = cand_flows[:, :, pivot] with torch.no_grad(): output_flow = flow_model(inputs, cand_masks) if isinstance(output_flow, tuple) or isinstance(output_flow, list): output_flow = output_flow[0] comp = output_flow * pivot_mask + pivot_flow * (1 - pivot_mask) if filled_flows[i] is None: filled_flows[i] = comp assert None not in filled_flows return filled_flows def read_flow(flow_dir, video): nFrame, _, imgH, imgW = video.shape Flow = np.empty(((imgH, imgW, 2, 0)), dtype=np.float32) flows = sorted(glob.glob(os.path.join(flow_dir, '*.flo'))) for flow in flows: flow_data = cvbase.read_flow(flow) h, w = flow_data.shape[:2] flow_data = cv2.resize(flow_data, (imgW, imgH), cv2.INTER_LINEAR) flow_data[:, :, 0] *= (float(imgW) / float(w)) flow_data[:, :, 1] *= (float(imgH) / float(h)) Flow = np.concatenate((Flow, flow_data[..., None]), axis=-1) return Flow def norm_flows(flows): assert len(flows.shape) == 5, 'FLow shape: {}'.format(flows.shape) flattened_flows = flows.flatten(3) flow_max = torch.max(flattened_flows, dim=-1, keepdim=True)[0] flows = flows / flow_max.unsqueeze(-1) return flows def save_results(outdir, comp_frames): out_dir = os.path.join(outdir, 'frames') if not os.path.exists(out_dir): os.makedirs(out_dir) for i in range(len(comp_frames)): out_path = os.path.join(out_dir, '{:05d}.png'.format(i)) cv2.imwrite(out_path, comp_frames[i][:, :, ::-1]) def video_inpainting(args, imgArr, imgMaskArr): device = torch.device('cuda:{}'.format(args.gpu)) print(args) if args.opt is not None: with open(args.opt, 'r') as f: opts = yaml.full_load(f) for k in opts.keys(): if k in args: setattr(args, k, opts[k]) print(args) # Flow model. RAFT_model = initialize_RAFT(args, device) # LAFC (flow completion) LAFC_model, LAFC_config = initialize_LAFC(args, device) # FGT FGT_model, FGT_config = initialize_FGT(args, device) # Loads frames. # filename_list = glob.glob(os.path.join(args.path, '*.png')) + \ # glob.glob(os.path.join(args.path, '*.jpg')) # Obtains imgH, imgW and nFrame. imgH, imgW = args.imgH, args.imgW # nFrame = len(filename_list) nFrame = len(imgArr) if imgH < 350: flowH, flowW = imgH * 2, imgW * 2 else: flowH, flowW = imgH, imgW # Load video. video, video_flow = [], [] if args.mode == 'watermark_removal': maskname_list = glob.glob(os.path.join(args.path_mask, '*.png')) + glob.glob( os.path.join(args.path_mask, '*.jpg')) assert len(filename_list) == len(maskname_list) for filename, maskname in zip(sorted(filename_list), sorted(maskname_list)): frame = torch.from_numpy(np.array(Image.open(filename)).astype(np.uint8)).permute(2, 0, 1).float().unsqueeze(0) mask = torch.from_numpy(np.array(Image.open(maskname)).astype(np.uint8)).permute(2, 0, 1).float().unsqueeze(0) mask[mask > 0] = 1 frame = frame * (1 - mask) frame = F2.upsample(frame, size=(imgH, imgW), mode='bilinear', align_corners=False) frame_flow = F2.upsample(frame, size=( flowH, flowW), mode='bilinear', align_corners=False) video.append(frame) video_flow.append(frame_flow) else: '''for filename in sorted(filename_list): frame = torch.from_numpy(np.array(Image.open(filename)).astype(np.uint8)).permute(2, 0, 1).float().unsqueeze(0) frame = F2.upsample(frame, size=(imgH, imgW), mode='bilinear', align_corners=False) frame_flow = F2.upsample(frame, size=(flowH, flowW), mode='bilinear', align_corners=False) video.append(frame) video_flow.append(frame_flow)''' for im in imgArr: frame = torch.from_numpy(np.array(im).astype( np.uint8)).permute(2, 0, 1).float().unsqueeze(0) frame = F2.upsample(frame, size=(imgH, imgW), mode='bilinear', align_corners=False) frame_flow = F2.upsample(frame, size=( flowH, flowW), mode='bilinear', align_corners=False) video.append(frame) video_flow.append(frame_flow) video = torch.cat(video, dim=0) # [n, c, h, w] video_flow = torch.cat(video_flow, dim=0) gts = video.clone() video = video.to(device) video_flow = video_flow.to(device) # Calcutes the corrupted flow. forward_flows = calculate_flow( args, RAFT_model, video_flow, 'forward') # [B, C, 2, N] backward_flows = calculate_flow(args, RAFT_model, video_flow, 'backward') # Makes sure video is in BGR (opencv) format. video = video.permute(2, 3, 1, 0).cpu().numpy()[ :, :, ::-1, :] / 255. # np array -> [h, w, c, N] (0~1) if args.mode == 'video_extrapolation': # Creates video and flow where the extrapolated region are missing. video, forward_flows, backward_flows, flow_mask, mask_dilated, start_point, end_point = extrapolation(args, video, forward_flows, backward_flows) imgH, imgW = video.shape[:2] # mask indicating the missing region in the video. mask = np.tile(flow_mask[..., None], (1, 1, nFrame)) flow_mask = np.tile(flow_mask[..., None], (1, 1, nFrame)) mask_dilated = np.tile(mask_dilated[..., None], (1, 1, nFrame)) else: # Loads masks. filename_list = glob.glob(os.path.join(args.path_mask, '*.png')) + \ glob.glob(os.path.join(args.path_mask, '*.jpg')) mask = [] mask_dilated = [] flow_mask = [] '''for filename in sorted(filename_list): mask_img = np.array(Image.open(filename).convert('L')) mask_img = cv2.resize(mask_img, dsize=(imgW, imgH), interpolation=cv2.INTER_NEAREST) if args.flow_mask_dilates > 0: flow_mask_img = scipy.ndimage.binary_dilation(mask_img, iterations=args.flow_mask_dilates) else: flow_mask_img = mask_img flow_mask.append(flow_mask_img) if args.frame_dilates > 0: mask_img = scipy.ndimage.binary_dilation(mask_img, iterations=args.frame_dilates) mask.append(mask_img) mask_dilated.append(gradient_mask(mask_img))''' for f_mask in imgMaskArr: mask_img = np.array(f_mask) mask_img = cv2.resize(mask_img, dsize=( imgW, imgH), interpolation=cv2.INTER_NEAREST) if args.flow_mask_dilates > 0: flow_mask_img = scipy.ndimage.binary_dilation( mask_img, iterations=args.flow_mask_dilates) else: flow_mask_img = mask_img flow_mask.append(flow_mask_img) if args.frame_dilates > 0: mask_img = scipy.ndimage.binary_dilation( mask_img, iterations=args.frame_dilates) mask.append(mask_img) mask_dilated.append(gradient_mask(mask_img)) # mask indicating the missing region in the video. mask = np.stack(mask, -1).astype(np.bool) # [H, W, C, N] mask_dilated = np.stack(mask_dilated, -1).astype(np.bool) flow_mask = np.stack(flow_mask, -1).astype(np.bool) # Completes the flow. videoFlowF = complete_flow( LAFC_config, LAFC_model, forward_flows, flow_mask, 'forward', device) videoFlowB = complete_flow( LAFC_config, LAFC_model, backward_flows, flow_mask, 'backward', device) videoFlowF = tensor2np(videoFlowF) videoFlowB = tensor2np(videoFlowB) print('\nFinish flow completion.') if args.vis_completed_flows: save_flows(args.outroot, videoFlowF, videoFlowB) # Prepare gradients gradient_x = np.empty(((imgH, imgW, 3, 0)), dtype=np.float32) gradient_y = np.empty(((imgH, imgW, 3, 0)), dtype=np.float32) for indFrame in range(nFrame): img = video[:, :, :, indFrame] img[mask[:, :, indFrame], :] = 0 img = cv2.inpaint((img * 255).astype(np.uint8), mask[:, :, indFrame].astype(np.uint8), 3, cv2.INPAINT_TELEA).astype(np.float32) / 255. gradient_x_ = np.concatenate((np.diff(img, axis=1), np.zeros((imgH, 1, 3), dtype=np.float32)), axis=1) gradient_y_ = np.concatenate( (np.diff(img, axis=0), np.zeros((1, imgW, 3), dtype=np.float32)), axis=0) gradient_x = np.concatenate( (gradient_x, gradient_x_.reshape(imgH, imgW, 3, 1)), axis=-1) gradient_y = np.concatenate( (gradient_y, gradient_y_.reshape(imgH, imgW, 3, 1)), axis=-1) gradient_x[mask_dilated[:, :, indFrame], :, indFrame] = 0 gradient_y[mask_dilated[:, :, indFrame], :, indFrame] = 0 gradient_x_filled = gradient_x gradient_y_filled = gradient_y mask_gradient = mask_dilated video_comp = video # Gradient propagation. gradient_x_filled, gradient_y_filled, mask_gradient = \ get_flowNN_gradient(args, gradient_x_filled, gradient_y_filled, mask, mask_gradient, videoFlowF, videoFlowB, None, None) # if there exist holes in mask, Poisson blending will fail. So I did this trick. I sacrifice some value. Another solution is to modify Poisson blending. for indFrame in range(nFrame): mask_gradient[:, :, indFrame] = scipy.ndimage.binary_fill_holes(mask_gradient[:, :, indFrame]).astype( np.bool) # After one gradient propagation iteration # gradient --> RGB frameBlends = [] for indFrame in range(nFrame): print("Poisson blending frame {0:3d}".format(indFrame)) if mask[:, :, indFrame].sum() > 0: try: frameBlend, UnfilledMask = Poisson_blend_img(video_comp[:, :, :, indFrame], gradient_x_filled[:, 0: imgW - 1, :, indFrame], gradient_y_filled[0: imgH - 1, :, :, indFrame], mask[:, :, indFrame], mask_gradient[:, :, indFrame]) except: frameBlend, UnfilledMask = video_comp[:, :, :, indFrame], mask[:, :, indFrame] frameBlend = np.clip(frameBlend, 0, 1.0) tmp = cv2.inpaint((frameBlend * 255).astype(np.uint8), UnfilledMask.astype(np.uint8), 3, cv2.INPAINT_TELEA).astype(np.float32) / 255. frameBlend[UnfilledMask, :] = tmp[UnfilledMask, :] video_comp[:, :, :, indFrame] = frameBlend mask[:, :, indFrame] = UnfilledMask frameBlend_ = copy.deepcopy(frameBlend) # Green indicates the regions that are not filled yet. frameBlend_[mask[:, :, indFrame], :] = [0, 1., 0] else: frameBlend_ = video_comp[:, :, :, indFrame] frameBlends.append(frameBlend_) if args.vis_prop: save_fgcp(args.outroot, frameBlends, mask) video_length = len(frameBlends) for i in range(len(frameBlends)): frameBlends[i] = frameBlends[i][:, :, ::-1] frames_first = np2tensor(frameBlends, near='t').to(device) mask = np.moveaxis(mask, -1, 0) mask = mask[:, :, :, np.newaxis] masks = np2tensor(mask, near='t').to(device) normed_frames = frames_first * 2 - 1 comp_frames = [None] * video_length ref_length = args.step num_ref = args.num_ref neighbor_stride = args.neighbor_stride videoFlowF = np.moveaxis(videoFlowF, -1, 0) videoFlowF = np.concatenate([videoFlowF, videoFlowF[-1:, ...]], axis=0) flows = np2tensor(videoFlowF, near='t') flows = norm_flows(flows).to(device) for f in range(0, video_length, neighbor_stride): neighbor_ids = [i for i in range( max(0, f - neighbor_stride), min(video_length, f + neighbor_stride + 1))] ref_ids = get_ref_index( f, neighbor_ids, video_length, ref_length, num_ref) print(f, len(neighbor_ids), len(ref_ids)) selected_frames = normed_frames[:, neighbor_ids + ref_ids] selected_masks = masks[:, neighbor_ids + ref_ids] masked_frames = selected_frames * (1 - selected_masks) selected_flows = flows[:, neighbor_ids + ref_ids] with torch.no_grad(): filled_frames = FGT_model( masked_frames, selected_flows, selected_masks) filled_frames = (filled_frames + 1) / 2 filled_frames = filled_frames.cpu().permute(0, 2, 3, 1).numpy() * 255 for i in range(len(neighbor_ids)): idx = neighbor_ids[i] valid_frame = frames_first[0, idx].cpu().permute( 1, 2, 0).numpy() * 255. valid_mask = masks[0, idx].cpu().permute(1, 2, 0).numpy() comp = np.array(filled_frames[i]).astype(np.uint8) * valid_mask + \ np.array(valid_frame).astype(np.uint8) * (1 - valid_mask) if comp_frames[idx] is None: comp_frames[idx] = comp else: comp_frames[idx] = comp_frames[idx].astype( np.float32) * 0.5 + comp.astype(np.float32) * 0.5 if args.vis_frame: save_results(args.outroot, comp_frames) create_dir(args.outroot) for i in range(len(comp_frames)): comp_frames[i] = comp_frames[i].astype(np.uint8) imageio.mimwrite(os.path.join(args.outroot, 'result.mp4'), comp_frames, fps=30, quality=8) print(f'Done, please check your result in {args.outroot} ') def main(args): assert args.mode in ('object_removal', 'video_extrapolation', 'watermark_removal'), ( "Accepted modes: 'object_removal', 'video_extrapolation', and 'watermark_removal', but input is %s" ) % args.mode video_inpainting(args) if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--opt', default='configs/object_removal.yaml', help='Please select your config file for inference') # video completion parser.add_argument('--mode', default='object_removal', choices=[ 'object_removal', 'watermark_removal', 'video_extrapolation'], help="modes: object_removal / video_extrapolation") parser.add_argument( '--path', default='/myData/davis_resized/walking', help="dataset for evaluation") parser.add_argument( '--path_mask', default='/myData/dilateAnnotations_4/walking', help="mask for object removal") parser.add_argument( '--outroot', default='quick_start/walking3', help="output directory") parser.add_argument('--consistencyThres', dest='consistencyThres', default=5, type=float, help='flow consistency error threshold') parser.add_argument('--alpha', dest='alpha', default=0.1, type=float) parser.add_argument('--Nonlocal', dest='Nonlocal', default=False, type=bool) # RAFT parser.add_argument( '--raft_model', default='../LAFC/flowCheckPoint/raft-things.pth', help="restore checkpoint") parser.add_argument('--small', action='store_true', help='use small model') parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation') # LAFC parser.add_argument('--lafc_ckpts', type=str, default='../LAFC/checkpoint') # FGT parser.add_argument('--fgt_ckpts', type=str, default='../FGT/checkpoint') # extrapolation parser.add_argument('--H_scale', dest='H_scale', default=2, type=float, help='H extrapolation scale') parser.add_argument('--W_scale', dest='W_scale', default=2, type=float, help='W extrapolation scale') # Image basic information parser.add_argument('--imgH', type=int, default=256) parser.add_argument('--imgW', type=int, default=432) parser.add_argument('--flow_mask_dilates', type=int, default=8) parser.add_argument('--frame_dilates', type=int, default=0) parser.add_argument('--gpu', type=int, default=0) # FGT inference parameters parser.add_argument('--step', type=int, default=10) parser.add_argument('--num_ref', type=int, default=-1) parser.add_argument('--neighbor_stride', type=int, default=5) # visualization parser.add_argument('--vis_flows', action='store_true', help='Visualize the initialized flows') parser.add_argument('--vis_completed_flows', action='store_true', help='Visualize the completed flows') parser.add_argument('--vis_prop', action='store_true', help='Visualize the frames after stage-I filling (flow guided content propagation)') parser.add_argument('--vis_frame', action='store_true', help='Visualize frames') args = parser.parse_args() main(args)