video-object-remover / FGT_codes /tool /video_inpainting.py
oguzakif's picture
add datetime to output files to prevent conflict
2838f54
raw
history blame
No virus
30.9 kB
import sys
import os
import warnings
sys.path.append(os.path.abspath(os.path.join(__file__, "..", "..")))
sys.path.append(os.path.abspath(os.path.join(__file__, "..", "..", "FGT")))
sys.path.append(os.path.abspath(os.path.join(__file__, "..", "..", "LAFC")))
warnings.filterwarnings("ignore")
import cvbase
from torchvision.transforms import ToTensor
from tool.get_flowNN_gradient import get_flowNN_gradient
from tool.utils.Poisson_blend_img import Poisson_blend_img
from tool.utils.region_fill import regionfill
from importlib import import_module
import yaml
from RAFT import RAFT
from RAFT import utils
import torch.nn.functional as F2
import torchvision.transforms.functional as F
from skimage.feature import canny
import scipy.ndimage
from PIL import Image
import imageio
import torch
import numpy as np
import copy
import glob
import cv2
import argparse
def to_tensor(img):
    """Convert a NumPy image array into a float tensor.

    The array is first wrapped in a PIL image and then passed through
    torchvision's ``to_tensor``; the result is cast to float.

    Args:
        img: Image array accepted by ``PIL.Image.fromarray``.

    Returns:
        The float tensor produced by ``torchvision.transforms.functional.to_tensor``.
    """
    pil_image = Image.fromarray(img)
    return F.to_tensor(pil_image).float()
def diffusion(flows, masks):
    """Fill the masked region of every flow field with ``regionfill``.

    Args:
        flows: Array indexed as ``flows[i][:, :, channel]``; the first two
            channels of every flow are filled.
        masks: Array indexed as ``masks[i][:, :, 0]``; channel 0 is the hole
            mask handed to ``regionfill``.

    Returns:
        A list with one filled array per flow (same shape as the input flow,
        allocated with ``np.zeros`` and therefore float64).
    """
    filled = []
    for idx in range(flows.shape[0]):
        flow = flows[idx]
        hole = masks[idx][:, :, 0]
        out = np.zeros(flow.shape)
        # Both flow components are filled independently over the same hole.
        for channel in (0, 1):
            out[:, :, channel] = regionfill(flow[:, :, channel], hole)
        filled.append(out)
    return filled
def np2tensor(array, near="c"):
    """Turn a ``[t, h, w, c]`` array (or list of ``[h, w, c]`` arrays) into a batched float tensor.

    Args:
        array: NumPy array laid out as ``[t, h, w, c]``, or a list of
            ``[h, w, c]`` arrays that is stacked along a new leading axis.
        near: ``"c"`` returns ``[1, c, t, h, w]``; ``"t"`` returns
            ``[1, t, c, h, w]``.

    Returns:
        A float tensor with a leading batch dimension of size 1.

    Raises:
        ValueError: If ``near`` is neither ``"c"`` nor ``"t"``.
    """
    if isinstance(array, list):
        array = np.stack(array, axis=0)  # [t, h, w, c]

    if near == "c":
        axis_order = (3, 0, 1, 2)  # -> [c, t, h, w]
    elif near == "t":
        axis_order = (0, 3, 1, 2)  # -> [t, c, h, w]
    else:
        raise ValueError(f"Unknown near type: {near}")

    tensor = torch.from_numpy(np.transpose(array, axis_order))
    return tensor.unsqueeze(0).float()
def tensor2np(array):
    """Stack a list of tensors along a new last axis and return a NumPy array.

    The stacked tensor has its leading dimension squeezed (if it is of size 1)
    and its next three axes reordered with ``permute(1, 2, 0, 3)``; e.g. a list
    of ``n`` tensors shaped ``[1, c, h, w]`` becomes an ``[h, w, c, n]`` array.

    Args:
        array: List of equally shaped tensors.

    Returns:
        The rearranged data as a NumPy array on the CPU.
    """
    stacked = torch.stack(array, dim=-1)
    reordered = stacked.squeeze(0).permute(1, 2, 0, 3)
    return reordered.cpu().numpy()
def gradient_mask(mask):
    """Grow a 2-D mask by one pixel towards the top and towards the left.

    The result is the element-wise OR of the mask itself, the mask shifted up
    by one row and the mask shifted left by one column (vacated cells are
    padded with ``False``).

    Args:
        mask: 2-D array; non-zero / ``True`` entries are treated as set.

    Returns:
        Boolean array with the same shape as ``mask``.
    """
    # NOTE: the builtin ``bool`` is used because the ``np.bool`` alias was
    # removed from NumPy (1.24+) and raises AttributeError there.
    pad_row = np.zeros((1, mask.shape[1]), dtype=bool)
    pad_col = np.zeros((mask.shape[0], 1), dtype=bool)
    shifted_up = np.concatenate((mask[1:, :], pad_row), axis=0)
    shifted_left = np.concatenate((mask[:, 1:], pad_col), axis=1)
    return np.logical_or.reduce((mask, shifted_up, shifted_left))
def indicesGen(pivot, interval, frames, t):
    """Build a window of frame indices centred on ``pivot``.

    Indices that fall outside ``[0, t - 1]`` are reflected back into range:
    negative ones by taking the absolute value, too-large ones by mirroring
    around ``t - 1``.

    Args:
        pivot: Centre index of the window.
        interval: Step between consecutive window entries.
        frames: Window size; ``frames // 2`` entries are taken on each side.
        t: Total number of available indices.

    Returns:
        List of ``2 * (frames // 2) + 1`` indices.
    """
    half_window = frames // 2
    last = t - 1
    window = []
    for step in range(-half_window, half_window + 1):
        idx = abs(pivot + interval * step)  # reflect at the lower border
        if idx > last:
            idx = 2 * last - idx  # reflect at the upper border
        window.append(idx)
    return window
def get_ref_index(f, neighbor_ids, length, ref_length, num_ref):
    """Pick reference frame indices that are not already neighbours.

    Args:
        f: Index of the current frame (only used when ``num_ref != -1``).
        neighbor_ids: Indices to exclude from the result.
        length: Total number of frames.
        ref_length: Sampling step between candidate references.
        num_ref: ``-1`` samples the whole sequence; otherwise candidates are
            taken from a window of ``ref_length * (num_ref // 2)`` frames on
            each side of ``f``.

    Returns:
        List of reference indices. In windowed mode collection stops once the
        list holds more than ``num_ref`` entries, so up to ``num_ref + 1``
        indices can be returned.
    """
    if num_ref == -1:
        return [i for i in range(0, length, ref_length) if i not in neighbor_ids]

    half_span = ref_length * (num_ref // 2)
    first = max(0, f - half_span)
    last = min(length, f + half_span)
    picked = []
    for candidate in range(first, last + 1, ref_length):
        if candidate in neighbor_ids:
            continue
        if len(picked) > num_ref:
            break
        picked.append(candidate)
    return picked
def save_flows(output, videoFlowF, videoFlowB):
    """Write completed forward/backward flows as ``.flo`` files and RGB previews.

    Files are written below ``<output>/completed_flow`` into the
    sub-directories ``forward_flo``, ``backward_flo``, ``forward_png`` and
    ``backward_png``, named by zero-padded frame index.

    Args:
        output: Root output directory.
        videoFlowF: Forward flows; the last axis indexes the flow number.
        videoFlowB: Backward flows, indexed like ``videoFlowF``.
    """
    flow_root = os.path.join(output, "completed_flow")
    for sub_dir in ("forward_flo", "backward_flo", "forward_png", "backward_png"):
        create_dir(os.path.join(flow_root, sub_dir))

    for i in range(videoFlowF.shape[-1]):
        fwd = videoFlowF[..., i]
        bwd = videoFlowB[..., i]
        fwd_rgb = cvbase.flow2rgb(fwd)
        bwd_rgb = cvbase.flow2rgb(bwd)

        flo_name = "{:05d}.flo".format(i)
        png_name = "{:05d}.png".format(i)

        cvbase.write_flow(fwd, os.path.join(flow_root, "forward_flo", flo_name))
        cvbase.write_flow(bwd, os.path.join(flow_root, "backward_flo", flo_name))
        imageio.imwrite(os.path.join(flow_root, "forward_png", png_name), fwd_rgb)
        imageio.imwrite(os.path.join(flow_root, "backward_png", png_name), bwd_rgb)
def save_fgcp(output, frames, masks):
    """Dump propagated frames and remaining masks as PNG and ``.npy`` files.

    Both inputs are multiplied by 255 before being written. Output goes to
    ``prop_frames``, ``masks_left``, ``prop_frames_npy`` and
    ``masks_left_npy`` below ``output``.

    Args:
        output: Root output directory.
        frames: Sequence of frames, one per time step.
        masks: Array whose third axis indexes the time step; its length along
            that axis must equal ``len(frames)`` (asserted).
    """
    for sub_dir in ("prop_frames", "masks_left", "prop_frames_npy", "masks_left_npy"):
        create_dir(os.path.join(output, sub_dir))

    assert len(frames) == masks.shape[2]

    for i in range(len(frames)):
        frame_255 = frames[i] * 255.0
        mask_255 = masks[:, :, i] * 255.0
        png_name = "%05d.png" % i
        npy_name = "%05d.npy" % i

        cv2.imwrite(os.path.join(output, "prop_frames", png_name), frame_255)
        cv2.imwrite(os.path.join(output, "masks_left", png_name), mask_255)
        np.save(os.path.join(output, "prop_frames_npy", npy_name), frame_255)
        np.save(os.path.join(output, "masks_left_npy", npy_name), mask_255)
def create_dir(dir):
    """Create directory ``dir`` (including parents) if it does not exist yet.

    ``exist_ok=True`` replaces the previous "check, then create" sequence,
    which could fail with ``FileExistsError`` when another process created
    the directory between the check and the ``makedirs`` call.

    Args:
        dir: Path of the directory to create.

    Raises:
        FileExistsError: If ``dir`` exists but is not a directory.
    """
    os.makedirs(dir, exist_ok=True)
def initialize_RAFT(args, device):
    """Build the RAFT model, load its weights and put it in eval mode.

    The checkpoint at ``args.raft_model`` is loaded into a ``DataParallel``
    wrapper, after which the wrapped module is unwrapped, moved to ``device``
    and returned.

    Args:
        args: Namespace passed to the ``RAFT`` constructor; must provide
            ``raft_model`` (checkpoint path).
        device: Target torch device.

    Returns:
        The unwrapped RAFT module in evaluation mode.
    """
    wrapped = torch.nn.DataParallel(RAFT(args))
    if torch.cuda.is_available():
        weights = torch.load(args.raft_model)
    else:
        # Without CUDA, keep every storage where it was deserialised (CPU).
        weights = torch.load(
            args.raft_model, map_location=lambda storage, loc: storage
        )
    wrapped.load_state_dict(weights)

    model = wrapped.module
    model.to(device)
    model.eval()
    return model
def initialize_LAFC(args, device):
    """Load the LAFC flow-completion model and its YAML configuration.

    ``args.lafc_ckpts`` must be a directory holding exactly two entries
    (asserted); the first ``*.tar`` file is used as checkpoint and the first
    ``*.yaml`` file as configuration.

    Args:
        args: Namespace providing ``lafc_ckpts``.
        device: Target torch device.

    Returns:
        Tuple ``(model, configs)``: the model on ``device`` with weights
        loaded, and the parsed configuration.
    """
    print(args.lafc_ckpts)
    assert len(os.listdir(args.lafc_ckpts)) == 2
    checkpoint = glob.glob(os.path.join(args.lafc_ckpts, "*.tar"))[0]
    config_file = glob.glob(os.path.join(args.lafc_ckpts, "*.yaml"))[0]

    # NOTE(review): yaml.full_load is not meant for untrusted files; the
    # config is presumably shipped together with the checkpoint.
    with open(config_file, "r") as f:
        configs = yaml.full_load(f)

    # The config names the module under LAFC.models that defines ``Model``.
    model_name = configs["model"]
    model = import_module("LAFC.models.{}".format(model_name)).Model(configs)

    if torch.cuda.is_available():
        state = torch.load(
            checkpoint, map_location=lambda storage, loc: storage.cuda(device)
        )
    else:
        state = torch.load(checkpoint, map_location=lambda storage, loc: storage)

    model.load_state_dict(state["model_state_dict"])
    return model.to(device), configs
def initialize_FGT(args, device):
    """Load the FGT model and its YAML configuration.

    ``args.fgt_ckpts`` must be a directory holding exactly two entries
    (asserted); the first ``*.tar`` file is used as checkpoint and the first
    ``*.yaml`` file as configuration.

    Args:
        args: Namespace providing ``fgt_ckpts``.
        device: Target torch device.

    Returns:
        Tuple ``(model, configs)``: the model on ``device`` with weights
        loaded, and the parsed configuration.
    """
    assert len(os.listdir(args.fgt_ckpts)) == 2
    checkpoint = glob.glob(os.path.join(args.fgt_ckpts, "*.tar"))[0]
    config_file = glob.glob(os.path.join(args.fgt_ckpts, "*.yaml"))[0]

    # NOTE(review): yaml.full_load is not meant for untrusted files; the
    # config is presumably shipped together with the checkpoint.
    with open(config_file, "r") as f:
        configs = yaml.full_load(f)

    # The config names the module under FGT.models that defines ``Model``.
    model_name = configs["model"]
    net = import_module("FGT.models.{}".format(model_name))
    model = net.Model(configs).to(device)

    if torch.cuda.is_available():
        state = torch.load(
            checkpoint, map_location=lambda storage, loc: storage.cuda(device)
        )
    else:
        state = torch.load(checkpoint, map_location=lambda storage, loc: storage)

    model.load_state_dict(state["model_state_dict"])
    return model, configs
def calculate_flow(args, model, video, mode):
    """Compute optical flow between consecutive frames with ``model``.

    Args:
        args: Namespace providing ``imgH``, ``imgW`` and ``vis_flows``
            (and ``outroot`` when ``vis_flows`` is set).
        model: Callable invoked as
            ``model(image1, image2, iters=20, test_mode=True)`` that returns a
            pair whose second element is a flow tensor indexed ``[0]`` and
            permuted from ``[2, h, w]`` to ``[h, w, 2]``.
        video: Tensor whose first axis indexes frames.
        mode: ``"forward"`` (flow ``i -> i + 1``) or ``"backward"``
            (flow ``i + 1 -> i``).

    Returns:
        Array of shape ``[imgH, imgW, 2, N - 1]`` where ``N`` is the number of
        frames; the last axis is empty for fewer than two frames.

    Raises:
        NotImplementedError: If ``mode`` is not one of the two accepted values.
    """
    if mode not in ["forward", "backward"]:
        raise NotImplementedError

    imgH, imgW = args.imgH, args.imgW

    if args.vis_flows:
        create_dir(os.path.join(args.outroot, "flow", mode + "_flo"))
        create_dir(os.path.join(args.outroot, "flow", mode + "_png"))

    # Collect per-pair flows and concatenate once at the end instead of
    # re-allocating the whole result for every frame pair.
    per_pair_flows = []
    with torch.no_grad():
        for i in range(video.shape[0] - 1):
            print(
                "Calculating {0} flow {1:2d} <---> {2:2d}".format(mode, i, i + 1),
                "\r",
                end="",
            )
            if mode == "forward":
                # Flow i -> i + 1
                image1, image2 = video[i, None], video[i + 1, None]
            else:
                # Flow i + 1 -> i
                image1, image2 = video[i + 1, None], video[i, None]

            _, flow = model(image1, image2, iters=20, test_mode=True)
            flow = flow[0].permute(1, 2, 0).cpu().numpy()

            # Resize the flow to the working resolution and rescale its
            # vectors accordingly.
            h, w = flow.shape[:2]
            if h != imgH or w != imgW:
                # ``interpolation`` must be passed by keyword: the third
                # positional parameter of cv2.resize is ``dst``.
                flow = cv2.resize(
                    flow, (imgW, imgH), interpolation=cv2.INTER_LINEAR
                )
                flow[:, :, 0] *= float(imgW) / float(w)
                flow[:, :, 1] *= float(imgH) / float(h)

            per_pair_flows.append(flow)

            if args.vis_flows:
                # Flow visualization.
                flow_img = utils.flow_viz.flow_to_image(flow)
                flow_img = Image.fromarray(flow_img)
                # Saves the flow and flow_img.
                flow_img.save(
                    os.path.join(args.outroot, "flow", mode + "_png", "%05d.png" % i)
                )
                utils.frame_utils.writeFlow(
                    os.path.join(args.outroot, "flow", mode + "_flo", "%05d.flo" % i),
                    flow,
                )

    # The empty float32 seed keeps the dtype promotion and the empty-input
    # result identical to the former incremental concatenation.
    seed = np.empty((imgH, imgW, 2, 0), dtype=np.float32)
    return np.concatenate(
        [seed] + [flow[..., None] for flow in per_pair_flows], axis=-1
    )
def extrapolation(args, video_ori, corrFlowF_ori, corrFlowB_ori):
    """Prepare video, flows and masks for video extrapolation.

    The original content is centred inside a larger canvas whose size is the
    original size scaled by ``args.H_scale`` / ``args.W_scale`` and rounded
    down to a multiple of 4. The new border of every frame is pre-filled with
    ``cv2.inpaint``; the border of the flows is left at zero.

    Args:
        args: Namespace providing ``H_scale`` and ``W_scale``.
        video_ori: Array of shape ``[imgH, imgW, 3, nFrame]``; multiplied by
            255 before inpainting, so values are presumably in ``[0, 1]``.
        corrFlowF_ori: Forward flows assigned into a
            ``[imgH, imgW, 2, nFrame - 1]`` slot.
        corrFlowB_ori: Backward flows, same layout as ``corrFlowF_ori``.

    Returns:
        Tuple ``(video, corrFlowF, corrFlowB, flow_mask, mask_dilated,
        start_point, end_point)`` where ``flow_mask`` is ``True`` on the new
        border, ``mask_dilated`` is ``gradient_mask(flow_mask)`` and the two
        points are the ``(x, y)`` corners of the original content inside the
        enlarged canvas.
    """
    imgH, imgW, _, nFrame = video_ori.shape

    # Defines new FOV.
    imgH_extr = int(args.H_scale * imgH)
    imgW_extr = int(args.W_scale * imgW)
    imgH_extr = imgH_extr - imgH_extr % 4
    imgW_extr = imgW_extr - imgW_extr % 4
    H_start = int((imgH_extr - imgH) / 2)
    W_start = int((imgW_extr - imgW) / 2)

    # Generates the mask for missing region.
    # NOTE: builtin ``bool`` replaces the ``np.bool`` alias, which was removed
    # from NumPy (1.24+).
    flow_mask = np.ones((imgH_extr, imgW_extr), dtype=bool)
    flow_mask[H_start : H_start + imgH, W_start : W_start + imgW] = 0

    mask_dilated = gradient_mask(flow_mask)

    # Extrapolates the FOV for video.
    video = np.zeros((imgH_extr, imgW_extr, 3, nFrame), dtype=np.float32)
    video[H_start : H_start + imgH, W_start : W_start + imgW, :, :] = video_ori

    # The inpainting mask is identical for every frame, so convert it once.
    inpaint_mask = flow_mask.astype(np.uint8)
    for i in range(nFrame):
        print("Preparing frame {0}".format(i), "\r", end="")
        video[:, :, :, i] = (
            cv2.inpaint(
                (video[:, :, :, i] * 255).astype(np.uint8),
                inpaint_mask,
                3,
                cv2.INPAINT_TELEA,
            ).astype(np.float32)
            / 255.0
        )

    # Extrapolates the FOV for flow.
    corrFlowF = np.zeros((imgH_extr, imgW_extr, 2, nFrame - 1), dtype=np.float32)
    corrFlowB = np.zeros((imgH_extr, imgW_extr, 2, nFrame - 1), dtype=np.float32)
    corrFlowF[H_start : H_start + imgH, W_start : W_start + imgW, :] = corrFlowF_ori
    corrFlowB[H_start : H_start + imgH, W_start : W_start + imgW, :] = corrFlowB_ori

    return (
        video,
        corrFlowF,
        corrFlowB,
        flow_mask,
        mask_dilated,
        (W_start, H_start),
        (W_start + imgW, H_start + imgH),
    )
def complete_flow(config, flow_model, flows, flow_masks, mode, device):
    """Complete the masked regions of every flow with the flow model.

    For every flow a temporal window of neighbouring flows (see
    ``indicesGen``) is fed to ``flow_model``; its prediction is kept inside
    the mask of the centre flow and the input flow is kept outside.

    Args:
        config: Mapping providing ``num_flows`` (window size) and
            ``flow_interval`` (window stride).
        flow_model: Callable invoked as ``flow_model(inputs, masks)``; if it
            returns a tuple or list, the first element is used.
        flows: Array with the flow index on the last axis.
        flow_masks: Array with the frame index on the last axis. The last
            frame's mask is dropped in forward mode, the first one in
            backward mode.
        mode: ``"forward"`` or ``"backward"``.
        device: Torch device on which the completion runs.

    Returns:
        List with one completed flow tensor per input flow.

    Raises:
        NotImplementedError: If ``mode`` is not one of the accepted values.
    """
    if mode not in ["forward", "backward"]:
        raise NotImplementedError(f"Error flow mode {mode}")

    # Move the temporal axis to the front.
    flow_masks = np.moveaxis(flow_masks, -1, 0)  # [N, H, W]
    flows = np.moveaxis(flows, -1, 0)  # [N, H, W, 2]
    if flow_masks.ndim == 3:
        flow_masks = flow_masks[..., np.newaxis]
    flow_masks = flow_masks[:-1] if mode == "forward" else flow_masks[1:]

    window_size = config["num_flows"]
    window_stride = config["flow_interval"]

    # Initial hole filling, performed before anything is moved to ``device``.
    diffused = diffusion(flows, flow_masks)

    flow_t = np2tensor(flows).to(device)
    mask_t = np2tensor(flow_masks).to(device)
    diffused_t = np2tensor(diffused).to(device)

    total = diffused_t.shape[2]
    centre = window_size // 2
    filled_flows = [None] * total

    for i in range(total):
        window_ids = indicesGen(i, window_stride, window_size, total)
        print("Indices: ", window_ids, "\r", end="")

        window_flows = flow_t[:, :, window_ids]
        window_masks = mask_t[:, :, window_ids]
        window_inputs = diffused_t[:, :, window_ids]
        centre_mask = window_masks[:, :, centre]
        centre_flow = window_flows[:, :, centre]

        with torch.no_grad():
            predicted = flow_model(window_inputs, window_masks)
            if isinstance(predicted, (tuple, list)):
                predicted = predicted[0]
            # Prediction inside the hole, original flow everywhere else.
            filled_flows[i] = predicted * centre_mask + centre_flow * (
                1 - centre_mask
            )

    assert None not in filled_flows
    return filled_flows
def read_flow(flow_dir, video):
    """Read all ``*.flo`` files of a directory, resized to the video resolution.

    Files are processed in sorted name order. Every flow is resized to the
    frame size of ``video`` and its vectors are rescaled by the same factors.

    Args:
        flow_dir: Directory containing ``*.flo`` files.
        video: Array/tensor of shape ``[nFrame, C, imgH, imgW]``; only its
            height and width are used.

    Returns:
        Array of shape ``[imgH, imgW, 2, K]`` where ``K`` is the number of
        flow files found (``K == 0`` for an empty directory).
    """
    _, _, imgH, imgW = video.shape
    resized_flows = []
    for flow_path in sorted(glob.glob(os.path.join(flow_dir, "*.flo"))):
        flow_data = cvbase.read_flow(flow_path)
        h, w = flow_data.shape[:2]
        # ``interpolation`` must be passed by keyword: the third positional
        # parameter of cv2.resize is ``dst``.
        flow_data = cv2.resize(
            flow_data, (imgW, imgH), interpolation=cv2.INTER_LINEAR
        )
        flow_data[:, :, 0] *= float(imgW) / float(w)
        flow_data[:, :, 1] *= float(imgH) / float(h)
        resized_flows.append(flow_data[..., None])

    # One concatenation at the end; the empty float32 seed preserves the
    # former dtype promotion and the empty-directory result.
    seed = np.empty((imgH, imgW, 2, 0), dtype=np.float32)
    return np.concatenate([seed] + resized_flows, axis=-1)
def norm_flows(flows):
    """Divide a 5-D flow tensor by its maximum over the last two axes.

    The maximum is taken separately for every index of the first three axes.
    It is the plain maximum, not the maximum magnitude, and no guard against
    a zero maximum is applied.

    Args:
        flows: Tensor with exactly five dimensions (asserted).

    Returns:
        Tensor of the same shape as ``flows``.
    """
    assert len(flows.shape) == 5, "FLow shape: {}".format(flows.shape)
    # Collapse the last two axes so one reduction yields the per-slice maximum.
    per_slice_max = flows.flatten(3).max(dim=-1, keepdim=True).values
    return flows / per_slice_max.unsqueeze(-1)
def save_results(outdir, comp_frames):
    """Write every completed frame as a PNG into ``<outdir>/frames``.

    The channel order of each frame is reversed before it is handed to
    ``cv2.imwrite``.

    Args:
        outdir: Root output directory.
        comp_frames: Sequence of ``[H, W, C]`` frame arrays.
    """
    frames_dir = os.path.join(outdir, "frames")
    if not os.path.exists(frames_dir):
        os.makedirs(frames_dir)

    for index, frame in enumerate(comp_frames):
        target = os.path.join(frames_dir, "{:05d}.png".format(index))
        cv2.imwrite(target, frame[:, :, ::-1])
def video_inpainting(args, imgArr, imgMaskArr):
    """Run the full inpainting pipeline on in-memory frames and write a video.

    Stages, in order:
      1. optical flow estimation between consecutive frames (RAFT),
      2. completion of the flow inside the masked regions (LAFC),
      3. gradient propagation along the completed flow followed by Poisson
         blending,
      4. synthesis of the remaining content with FGT,
      5. writing of ``os.path.join(args.outroot, args.outfilename)``.

    Args:
        args: Namespace with the run configuration. If ``args.opt`` is not
            ``None`` it is read as a YAML file and every key that already
            exists in ``args`` is overwritten. Attributes read directly here:
            ``opt``, ``mode``, ``imgH``, ``imgW``, ``path_mask`` (watermark
            removal only), ``flow_mask_dilates``, ``frame_dilates``,
            ``vis_completed_flows``, ``vis_prop``, ``vis_frame``, ``step``,
            ``num_ref``, ``neighbor_stride``, ``outroot``, ``outfilename`` and
            ``out_fps``; the helpers called from here read further ones.
        imgArr: Sequence of frames. Each one is converted with
            ``np.array(...).astype(np.uint8)`` and permuted with
            ``(2, 0, 1)``, i.e. it must be an ``H x W x C`` image.
        imgMaskArr: Sequence of 2-D masks, one per frame (not used in
            ``video_extrapolation`` mode). Non-zero pixels mark the region to
            be filled.

    Returns:
        None. The result is written to disk.
    """
    # device = torch.device("cuda:{}".format(args.gpu))
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(args)

    if args.opt is not None:
        # NOTE(review): yaml.full_load is not meant for untrusted files.
        with open(args.opt, "r") as f:
            opts = yaml.full_load(f)
        # Only options the namespace already knows are overridden.
        for k, v in opts.items():
            if k in args:
                setattr(args, k, v)

    print(args)

    # Flow model.
    RAFT_model = initialize_RAFT(args, device)
    # LAFC (flow completion)
    LAFC_model, LAFC_config = initialize_LAFC(args, device)
    # FGT
    FGT_model, FGT_config = initialize_FGT(args, device)

    # Obtains imgH, imgW and nFrame.
    imgH, imgW = args.imgH, args.imgW
    nFrame = len(imgArr)

    # Small inputs are upscaled by 2 for flow estimation only.
    if imgH < 350:
        flowH, flowW = imgH * 2, imgW * 2
    else:
        flowH, flowW = imgH, imgW

    # Load video.
    video, video_flow = [], []
    if args.mode == "watermark_removal":
        maskname_list = glob.glob(os.path.join(args.path_mask, "*.png")) + glob.glob(
            os.path.join(args.path_mask, "*.jpg")
        )
        # Frames come from ``imgArr`` (the former on-disk ``filename_list`` no
        # longer exists and raised NameError); masks for this mode are still
        # read from ``args.path_mask`` and paired with frames in sorted order.
        assert len(imgArr) == len(maskname_list)
        for im, maskname in zip(imgArr, sorted(maskname_list)):
            frame = (
                torch.from_numpy(np.array(im).astype(np.uint8))
                .permute(2, 0, 1)
                .float()
                .unsqueeze(0)
            )
            mask = (
                torch.from_numpy(np.array(Image.open(maskname)).astype(np.uint8))
                .permute(2, 0, 1)
                .float()
                .unsqueeze(0)
            )
            mask[mask > 0] = 1
            # Blank out the watermark region before flow estimation.
            frame = frame * (1 - mask)
            frame = F2.interpolate(
                frame, size=(imgH, imgW), mode="bilinear", align_corners=False
            )
            frame_flow = F2.interpolate(
                frame, size=(flowH, flowW), mode="bilinear", align_corners=False
            )
            video.append(frame)
            video_flow.append(frame_flow)
    else:
        for im in imgArr:
            frame = (
                torch.from_numpy(np.array(im).astype(np.uint8))
                .permute(2, 0, 1)
                .float()
                .unsqueeze(0)
            )
            # F2.interpolate replaces the deprecated F2.upsample alias.
            frame = F2.interpolate(
                frame, size=(imgH, imgW), mode="bilinear", align_corners=False
            )
            frame_flow = F2.interpolate(
                frame, size=(flowH, flowW), mode="bilinear", align_corners=False
            )
            video.append(frame)
            video_flow.append(frame_flow)

    video = torch.cat(video, dim=0)  # [n, c, h, w]
    video_flow = torch.cat(video_flow, dim=0)
    video = video.to(device)
    video_flow = video_flow.to(device)

    # Calculates the corrupted flow.
    forward_flows = calculate_flow(
        args, RAFT_model, video_flow, "forward"
    )  # [H, W, 2, N - 1]
    backward_flows = calculate_flow(args, RAFT_model, video_flow, "backward")

    # Makes sure video is in BGR (opencv) format.
    video = (
        video.permute(2, 3, 1, 0).cpu().numpy()[:, :, ::-1, :] / 255.0
    )  # np array -> [h, w, c, N] (0~1)

    if args.mode == "video_extrapolation":
        # Creates video and flow where the extrapolated region are missing.
        (
            video,
            forward_flows,
            backward_flows,
            flow_mask,
            mask_dilated,
            start_point,
            end_point,
        ) = extrapolation(args, video, forward_flows, backward_flows)
        imgH, imgW = video.shape[:2]

        # mask indicating the missing region in the video.
        mask = np.tile(flow_mask[..., None], (1, 1, nFrame))
        flow_mask = np.tile(flow_mask[..., None], (1, 1, nFrame))
        mask_dilated = np.tile(mask_dilated[..., None], (1, 1, nFrame))
    else:
        # Builds the per-frame masks from the in-memory mask images.
        mask = []
        mask_dilated = []
        flow_mask = []
        for f_mask in imgMaskArr:
            mask_img = np.array(f_mask)
            mask_img = cv2.resize(
                mask_img, dsize=(imgW, imgH), interpolation=cv2.INTER_NEAREST
            )

            # The mask used for flow completion may be dilated independently
            # of the mask used for the frames.
            if args.flow_mask_dilates > 0:
                flow_mask_img = scipy.ndimage.binary_dilation(
                    mask_img, iterations=args.flow_mask_dilates
                )
            else:
                flow_mask_img = mask_img
            flow_mask.append(flow_mask_img)

            if args.frame_dilates > 0:
                mask_img = scipy.ndimage.binary_dilation(
                    mask_img, iterations=args.frame_dilates
                )
            mask.append(mask_img)
            mask_dilated.append(gradient_mask(mask_img))

        # mask indicating the missing region in the video.
        # NOTE: builtin ``bool`` replaces the ``np.bool`` alias, which was
        # removed from NumPy (1.24+).
        mask = np.stack(mask, -1).astype(bool)  # [H, W, N]
        mask_dilated = np.stack(mask_dilated, -1).astype(bool)
        flow_mask = np.stack(flow_mask, -1).astype(bool)

    # Completes the flow.
    videoFlowF = complete_flow(
        LAFC_config, LAFC_model, forward_flows, flow_mask, "forward", device
    )
    videoFlowB = complete_flow(
        LAFC_config, LAFC_model, backward_flows, flow_mask, "backward", device
    )
    videoFlowF = tensor2np(videoFlowF)
    videoFlowB = tensor2np(videoFlowB)
    print("\nFinish flow completion.")

    if args.vis_completed_flows:
        save_flows(args.outroot, videoFlowF, videoFlowB)

    # Prepare gradients. The buffers are allocated once for all frames; every
    # slice along the last axis is written in the loop below.
    gradient_x = np.empty((imgH, imgW, 3, nFrame), dtype=np.float32)
    gradient_y = np.empty((imgH, imgW, 3, nFrame), dtype=np.float32)

    for indFrame in range(nFrame):
        # ``img`` is a view: zeroing the hole also modifies ``video``.
        img = video[:, :, :, indFrame]
        img[mask[:, :, indFrame], :] = 0
        img = (
            cv2.inpaint(
                (img * 255).astype(np.uint8),
                mask[:, :, indFrame].astype(np.uint8),
                3,
                cv2.INPAINT_TELEA,
            ).astype(np.float32)
            / 255.0
        )

        # Forward differences, zero-padded at the right / bottom border.
        gradient_x[:, :, :, indFrame] = np.concatenate(
            (np.diff(img, axis=1), np.zeros((imgH, 1, 3), dtype=np.float32)), axis=1
        )
        gradient_y[:, :, :, indFrame] = np.concatenate(
            (np.diff(img, axis=0), np.zeros((1, imgW, 3), dtype=np.float32)), axis=0
        )

        # Gradients touching the hole are unknown and reset to zero.
        gradient_x[mask_dilated[:, :, indFrame], :, indFrame] = 0
        gradient_y[mask_dilated[:, :, indFrame], :, indFrame] = 0

    gradient_x_filled = gradient_x
    gradient_y_filled = gradient_y
    mask_gradient = mask_dilated
    video_comp = video

    # Gradient propagation.
    gradient_x_filled, gradient_y_filled, mask_gradient = get_flowNN_gradient(
        args,
        gradient_x_filled,
        gradient_y_filled,
        mask,
        mask_gradient,
        videoFlowF,
        videoFlowB,
        None,
        None,
    )

    # If there exist holes in the mask, Poisson blending will fail, so holes
    # are filled here at the cost of some propagated values. Another solution
    # would be to modify the Poisson blending itself.
    for indFrame in range(nFrame):
        mask_gradient[:, :, indFrame] = scipy.ndimage.binary_fill_holes(
            mask_gradient[:, :, indFrame]
        ).astype(bool)

    # After one gradient propagation iteration
    # gradient --> RGB
    frameBlends = []
    for indFrame in range(nFrame):
        print("Poisson blending frame {0:3d}".format(indFrame))

        if mask[:, :, indFrame].sum() > 0:
            try:
                frameBlend, UnfilledMask = Poisson_blend_img(
                    video_comp[:, :, :, indFrame],
                    gradient_x_filled[:, 0 : imgW - 1, :, indFrame],
                    gradient_y_filled[0 : imgH - 1, :, :, indFrame],
                    mask[:, :, indFrame],
                    mask_gradient[:, :, indFrame],
                )
            except Exception as err:
                # Best effort: keep the unblended frame and its full mask, but
                # report the failure instead of hiding it (a bare ``except``
                # would also swallow KeyboardInterrupt / SystemExit).
                print(
                    "Poisson blending failed on frame {0} ({1!r}); "
                    "falling back to the unblended frame.".format(indFrame, err)
                )
                frameBlend, UnfilledMask = (
                    video_comp[:, :, :, indFrame],
                    mask[:, :, indFrame],
                )

            frameBlend = np.clip(frameBlend, 0, 1.0)
            # Pixels that are still unfilled get a plain image inpainting.
            tmp = (
                cv2.inpaint(
                    (frameBlend * 255).astype(np.uint8),
                    UnfilledMask.astype(np.uint8),
                    3,
                    cv2.INPAINT_TELEA,
                ).astype(np.float32)
                / 255.0
            )
            frameBlend[UnfilledMask, :] = tmp[UnfilledMask, :]

            video_comp[:, :, :, indFrame] = frameBlend
            mask[:, :, indFrame] = UnfilledMask

            frameBlend_ = copy.deepcopy(frameBlend)
            # Green indicates the regions that are not filled yet.
            frameBlend_[mask[:, :, indFrame], :] = [0, 1.0, 0]
        else:
            frameBlend_ = video_comp[:, :, :, indFrame]

        frameBlends.append(frameBlend_)

    if args.vis_prop:
        save_fgcp(args.outroot, frameBlends, mask)

    video_length = len(frameBlends)

    # Reverse the channel order again (undoes the BGR flip applied above).
    for i in range(len(frameBlends)):
        frameBlends[i] = frameBlends[i][:, :, ::-1]

    frames_first = np2tensor(frameBlends, near="t").to(device)
    mask = np.moveaxis(mask, -1, 0)
    mask = mask[:, :, :, np.newaxis]
    masks = np2tensor(mask, near="t").to(device)
    normed_frames = frames_first * 2 - 1  # [0, 1] -> [-1, 1]
    comp_frames = [None] * video_length

    ref_length = args.step
    num_ref = args.num_ref
    neighbor_stride = args.neighbor_stride

    # There is one flow less than frames; repeat the last one to match.
    videoFlowF = np.moveaxis(videoFlowF, -1, 0)
    videoFlowF = np.concatenate([videoFlowF, videoFlowF[-1:, ...]], axis=0)

    flows = np2tensor(videoFlowF, near="t")
    flows = norm_flows(flows).to(device)

    for f in range(0, video_length, neighbor_stride):
        neighbor_ids = list(
            range(
                max(0, f - neighbor_stride), min(video_length, f + neighbor_stride + 1)
            )
        )
        ref_ids = get_ref_index(f, neighbor_ids, video_length, ref_length, num_ref)
        print(f, len(neighbor_ids), len(ref_ids))

        selected_frames = normed_frames[:, neighbor_ids + ref_ids]
        selected_masks = masks[:, neighbor_ids + ref_ids]
        masked_frames = selected_frames * (1 - selected_masks)
        selected_flows = flows[:, neighbor_ids + ref_ids]

        with torch.no_grad():
            filled_frames = FGT_model(masked_frames, selected_flows, selected_masks)
            filled_frames = (filled_frames + 1) / 2  # [-1, 1] -> [0, 1]
            filled_frames = filled_frames.cpu().permute(0, 2, 3, 1).numpy() * 255

        # Only the neighbour frames are composited; reference frames merely
        # provide context to the model.
        for i in range(len(neighbor_ids)):
            idx = neighbor_ids[i]
            valid_frame = frames_first[0, idx].cpu().permute(1, 2, 0).numpy() * 255.0
            valid_mask = masks[0, idx].cpu().permute(1, 2, 0).numpy()
            comp = np.array(filled_frames[i]).astype(np.uint8) * valid_mask + np.array(
                valid_frame
            ).astype(np.uint8) * (1 - valid_mask)
            if comp_frames[idx] is None:
                comp_frames[idx] = comp
            else:
                # Frames covered by two overlapping windows are averaged.
                comp_frames[idx] = (
                    comp_frames[idx].astype(np.float32) * 0.5
                    + comp.astype(np.float32) * 0.5
                )

    if args.vis_frame:
        save_results(args.outroot, comp_frames)

    create_dir(args.outroot)
    for i in range(len(comp_frames)):
        comp_frames[i] = comp_frames[i].astype(np.uint8)
    imageio.mimwrite(
        os.path.join(args.outroot, args.outfilename), comp_frames, fps=args.out_fps, quality=8
    )
    print(f"Done, please check your result in {args.outroot} ")
def main(args):
    """Command-line entry point: load frames and masks from disk and inpaint.

    ``video_inpainting`` expects the frames and masks in memory, so the
    ``*.png`` / ``*.jpg`` files found in ``args.path`` and ``args.path_mask``
    are loaded here, in sorted name order, and passed on. (Previously
    ``video_inpainting`` was called with ``args`` only, which raised a
    TypeError.)

    Args:
        args: Parsed command-line namespace; must provide ``mode``, ``path``
            and ``path_mask`` in addition to everything ``video_inpainting``
            reads.

    Raises:
        AssertionError: If ``args.mode`` is not an accepted mode.
        FileNotFoundError: If ``args.path`` contains no ``.png``/``.jpg`` frame.
    """
    assert args.mode in (
        "object_removal",
        "video_extrapolation",
        "watermark_removal",
    ), (
        "Accepted modes: 'object_removal', 'video_extrapolation', and 'watermark_removal', but input is %s"
    ) % args.mode

    def _list_images(folder):
        # Sorted list of the .png and .jpg files directly inside ``folder``.
        return sorted(
            glob.glob(os.path.join(folder, "*.png"))
            + glob.glob(os.path.join(folder, "*.jpg"))
        )

    # Frames are loaded as H x W x 3 RGB arrays.
    frames = []
    for frame_path in _list_images(args.path):
        with Image.open(frame_path) as im:
            frames.append(np.array(im.convert("RGB")))
    if not frames:
        raise FileNotFoundError(
            "No .png/.jpg frames found in {}".format(args.path)
        )

    # Masks are loaded as single-channel (2-D) arrays.
    masks = []
    for mask_path in _list_images(args.path_mask):
        with Image.open(mask_path) as im:
            masks.append(np.array(im.convert("L")))

    # video_inpainting reads these two output settings; supply defaults for
    # namespaces that do not define them.
    if not hasattr(args, "outfilename"):
        args.outfilename = "result.mp4"
    if not hasattr(args, "out_fps"):
        args.out_fps = 30

    video_inpainting(args, frames, masks)
# Script entry point: parse the command-line options and run the pipeline.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--opt",
        default="configs/object_removal.yaml",
        help="Please select your config file for inference",
    )
    # video completion
    parser.add_argument(
        "--mode",
        default="object_removal",
        choices=["object_removal", "watermark_removal", "video_extrapolation"],
        help="modes: object_removal / video_extrapolation",
    )
    parser.add_argument(
        "--path", default="/myData/davis_resized/walking", help="dataset for evaluation"
    )
    parser.add_argument(
        "--path_mask",
        default="/myData/dilateAnnotations_4/walking",
        help="mask for object removal",
    )
    parser.add_argument(
        "--outroot", default="quick_start/walking3", help="output directory"
    )
    # Output video settings read by video_inpainting; without them the run
    # failed with AttributeError when the result video was written.
    parser.add_argument(
        "--outfilename",
        default="result.mp4",
        help="file name of the output video inside --outroot",
    )
    parser.add_argument(
        "--out_fps", default=30, type=float, help="frame rate of the output video"
    )
    parser.add_argument(
        "--consistencyThres",
        dest="consistencyThres",
        default=5,
        type=float,
        help="flow consistency error threshold",
    )
    parser.add_argument("--alpha", dest="alpha", default=0.1, type=float)
    # NOTE(review): with type=bool every non-empty value (including "False")
    # parses as True; only omitting the flag yields False.
    parser.add_argument("--Nonlocal", dest="Nonlocal", default=False, type=bool)

    # RAFT
    parser.add_argument(
        "--raft_model",
        default="../LAFC/flowCheckPoint/raft-things.pth",
        help="restore checkpoint",
    )
    parser.add_argument("--small", action="store_true", help="use small model")
    parser.add_argument(
        "--mixed_precision", action="store_true", help="use mixed precision"
    )
    parser.add_argument(
        "--alternate_corr",
        action="store_true",
        help="use efficent correlation implementation",
    )

    # LAFC
    parser.add_argument("--lafc_ckpts", type=str, default="../LAFC/checkpoint")

    # FGT
    parser.add_argument("--fgt_ckpts", type=str, default="../FGT/checkpoint")

    # extrapolation
    parser.add_argument(
        "--H_scale", dest="H_scale", default=2, type=float, help="H extrapolation scale"
    )
    parser.add_argument(
        "--W_scale", dest="W_scale", default=2, type=float, help="W extrapolation scale"
    )

    # Image basic information
    parser.add_argument("--imgH", type=int, default=256)
    parser.add_argument("--imgW", type=int, default=432)
    parser.add_argument("--flow_mask_dilates", type=int, default=8)
    parser.add_argument("--frame_dilates", type=int, default=0)

    parser.add_argument("--gpu", type=int, default=0)

    # FGT inference parameters
    parser.add_argument("--step", type=int, default=10)
    parser.add_argument("--num_ref", type=int, default=-1)
    parser.add_argument("--neighbor_stride", type=int, default=5)

    # visualization
    parser.add_argument(
        "--vis_flows", action="store_true", help="Visualize the initialized flows"
    )
    parser.add_argument(
        "--vis_completed_flows",
        action="store_true",
        help="Visualize the completed flows",
    )
    parser.add_argument(
        "--vis_prop",
        action="store_true",
        help="Visualize the frames after stage-I filling (flow guided content propagation)",
    )
    parser.add_argument("--vis_frame", action="store_true", help="Visualize frames")

    args = parser.parse_args()

    main(args)