Rerender / flow /flow_utils.py
Anonymous-sub's picture
Update flow/flow_utils.py
dc716b9
raw
history blame
7.37 kB
import os
import sys
import numpy as np
import torch
import torch.nn.functional as F
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
gmflow_dir = os.path.join(parent_dir, 'gmflow_module')
sys.path.insert(0, gmflow_dir)
from gmflow.gmflow import GMFlow # noqa: E702 E402 F401
from utils.utils import InputPadder # noqa: E702 E402
import huggingface_hub
# Hugging Face Hub repository hosting the pretrained GMFlow checkpoint.
repo_name = 'Anonymous-sub/Rerender'
# Device the flow model and its input images are moved to in this module:
# CUDA when available, otherwise CPU.
global_device = 'cuda' if torch.cuda.is_available() else 'cpu'
# NOTE: runs at import time and performs a (possibly network) download of the
# GMFlow Sintel checkpoint into the current working directory. Presumably the
# file lands at './models/gmflow_sintel-0c07dcb3.pth' (repo path preserved
# under local_dir), which matches FlowCalc's default model_path — confirm.
gmflow_path = huggingface_hub.hf_hub_download(
    repo_name, 'models/gmflow_sintel-0c07dcb3.pth', local_dir='./')
def coords_grid(b, h, w, homogeneous=False, device=None):
    """Build a batch of pixel-coordinate grids.

    Args:
        b: Number of (identical) grids in the batch.
        h: Grid height in pixels.
        w: Grid width in pixels.
        homogeneous: If True, append a third channel filled with ones.
        device: Device of the returned tensor; ``None`` uses torch's default
            device.

    Returns:
        Float tensor of shape [B, 2, H, W] ([B, 3, H, W] when
        ``homogeneous``). Channel 0 holds the x (column) index and channel 1
        the y (row) index of each pixel.
    """
    # Same result as torch.meshgrid(arange(h), arange(w)) with 'ij' indexing,
    # without relying on meshgrid's implicit-indexing default (which warns).
    y = torch.arange(h, device=device).view(h, 1).expand(h, w)  # [H, W]
    x = torch.arange(w, device=device).view(1, w).expand(h, w)  # [H, W]
    stacks = [x, y]
    if homogeneous:
        stacks.append(torch.ones_like(x))  # [H, W]
    grid = torch.stack(stacks, dim=0).float()  # [2, H, W] or [3, H, W]
    # Built directly on `device`, so the caller's device is honoured instead
    # of being replaced by a module-level default.
    return grid[None].repeat(b, 1, 1, 1)  # [B, 2, H, W] or [B, 3, H, W]
def bilinear_sample(img,
                    sample_coords,
                    mode='bilinear',
                    padding_mode='zeros',
                    return_mask=False):
    """Sample ``img`` at absolute pixel coordinates.

    Args:
        img: Tensor of shape [B, C, H_in, W_in] to sample from.
        sample_coords: Pixel-space (x, y) positions in ``img``'s coordinate
            frame, as [B, 2, H, W] or [B, H, W, 2]. The layout is detected
            from ``size(1)``, so a channels-last tensor whose H is 2 is
            treated as channels-first.
        mode: Interpolation mode forwarded to ``F.grid_sample``.
        padding_mode: Padding mode forwarded to ``F.grid_sample``.
        return_mask: If True, also return a boolean [B, H, W] mask that is
            True where the sample position lies inside ``img``.

    Returns:
        The sampled [B, C, H, W] tensor, or ``(sampled, mask)`` when
        ``return_mask`` is True.
    """
    if sample_coords.size(1) != 2:  # [B, H, W, 2] -> [B, 2, H, W]
        sample_coords = sample_coords.permute(0, 3, 1, 2)

    # grid_sample normalizes positions by the *input* image size, so use
    # img's H/W (not the sample grid's). With align_corners=True, -1 and 1 are
    # the first and last pixel centres, hence the (size - 1) divisor; clamp it
    # to 1 so a 1-pixel dimension does not divide by zero.
    in_h, in_w = img.shape[-2:]
    x_grid = 2 * sample_coords[:, 0] / max(in_w - 1, 1) - 1
    y_grid = 2 * sample_coords[:, 1] / max(in_h - 1, 1) - 1
    grid = torch.stack([x_grid, y_grid], dim=-1)  # [B, H, W, 2]

    sampled = F.grid_sample(img,
                            grid,
                            mode=mode,
                            padding_mode=padding_mode,
                            align_corners=True)
    if not return_mask:
        return sampled

    mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (
        y_grid <= 1)  # [B, H, W]
    return sampled, mask
def flow_warp(feature,
              flow,
              mask=False,
              mode='bilinear',
              padding_mode='zeros'):
    """Backward-warp ``feature`` with a dense pixel-displacement field.

    The output at pixel (x, y) is ``feature`` sampled at
    (x + flow[:, 0], y + flow[:, 1]).

    Args:
        feature: Tensor of shape [B, C, H, W] to sample from.
        flow: Displacement field with 2 channels (x, y), added to a
            [B, 2, H, W] pixel-coordinate grid.
        mask: If True, also return the in-bounds mask from ``bilinear_sample``.
        mode: Interpolation mode forwarded to ``bilinear_sample``.
        padding_mode: Padding mode forwarded to ``bilinear_sample``.

    Returns:
        The warped tensor, or ``(warped, in_bounds_mask)`` when ``mask``.
    """
    batch, _channels, height, width = feature.size()
    assert flow.size(1) == 2

    base_grid = coords_grid(batch, height, width).to(flow.device)
    target = base_grid + flow  # [B, 2, H, W] absolute sample positions

    return bilinear_sample(
        feature,
        target,
        mode=mode,
        padding_mode=padding_mode,
        return_mask=mask,
    )
def forward_backward_consistency_check(fwd_flow,
                                       bwd_flow,
                                       alpha=0.01,
                                       beta=0.5):
    """Flag pixels where forward and backward flow fail to cancel out.

    A pixel is marked in ``fwd_occ`` when
    ``|f + warp(b, f)| > alpha * (|f| + |b|) + beta`` with ``f = fwd_flow`` and
    ``b = bwd_flow``. ``bwd_occ`` uses the symmetric test with the roles
    swapped. The default alpha and beta follow UnFlow
    (https://arxiv.org/abs/1711.07837).

    Args:
        fwd_flow: Forward flow, [B, 2, H, W].
        bwd_flow: Backward flow, [B, 2, H, W].
        alpha: Weight of the summed flow magnitudes in the threshold.
        beta: Constant offset added to the threshold.

    Returns:
        ``(fwd_occ, bwd_occ)``: float tensors of shape [B, H, W], 1.0 where
        the consistency test fails and 0.0 elsewhere.
    """
    assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4
    assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2

    magnitude = fwd_flow.norm(dim=1) + bwd_flow.norm(dim=1)  # [B, H, W]
    threshold = alpha * magnitude + beta

    bwd_on_fwd = flow_warp(bwd_flow, fwd_flow)  # [B, 2, H, W]
    fwd_on_bwd = flow_warp(fwd_flow, bwd_flow)  # [B, 2, H, W]

    fwd_residual = (fwd_flow + bwd_on_fwd).norm(dim=1)  # [B, H, W]
    bwd_residual = (bwd_flow + fwd_on_bwd).norm(dim=1)  # [B, H, W]

    fwd_occ = (fwd_residual > threshold).float()
    bwd_occ = (bwd_residual > threshold).float()
    return fwd_occ, bwd_occ
@torch.no_grad()
def get_warped_and_mask(flow_model,
                        image1,
                        image2,
                        image3=None,
                        pixel_consistency=False):
    """Estimate flow between two frames, warp an image, and mask occlusions.

    Args:
        flow_model: GMFlow-style model called with the keyword arguments
            below; its result dict must provide ``'flow_preds'`` whose last
            entry holds the forward flow at index 0 and the backward flow at
            index 1.
        image1: First frame, a [C, H, W] tensor (a batch axis is added here).
        image2: Second frame, same shape as ``image1``.
        image3: Image to warp with the backward flow, [C, H, W] or
            [1, C, H, W]; defaults to ``image1``. Moved to the flow's device.
        pixel_consistency: If True, additionally mark as occluded the pixels
            where the mean absolute difference between warped ``image1`` and
            ``image2`` exceeds ``255 * 0.25`` (assumes pixel values in
            [0, 255] — TODO confirm with callers).

    Returns:
        ``(warped, bwd_occ, bwd_flow)`` at the unpadded resolution:
        ``image3`` warped by the backward flow ([1, C, H, W]); a float
        occlusion mask ([1, H, W], or [1, 1, H, W] when
        ``pixel_consistency``); and the backward flow ([1, 2, H, W]).
    """
    if image3 is None:
        image3 = image1
    padder = InputPadder(image1.shape, padding_factor=8)
    # Batched, *unpadded* copies on the compute device. The flow is unpadded
    # below, so the pixel-consistency check must compare against these; the
    # padded images have a different size whenever H or W is not a multiple
    # of 8.
    frame1 = image1[None].to(global_device)
    frame2 = image2[None].to(global_device)
    padded1, padded2 = padder.pad(frame1, frame2)
    results_dict = flow_model(padded1,
                              padded2,
                              attn_splits_list=[2],
                              corr_radius_list=[-1],
                              prop_radius_list=[-1],
                              pred_bidir_flow=True)
    flow_pr = results_dict['flow_preds'][-1]  # forward/backward, padded size
    fwd_flow = padder.unpad(flow_pr[0]).unsqueeze(0)  # [1, 2, H, W]
    bwd_flow = padder.unpad(flow_pr[1]).unsqueeze(0)  # [1, 2, H, W]
    _, bwd_occ = forward_backward_consistency_check(
        fwd_flow, bwd_flow)  # [1, H, W] float
    if pixel_consistency:
        warped_image1 = flow_warp(frame1, bwd_flow)
        photometric_occ = (abs(frame2 - warped_image1).mean(dim=1) >
                           255 * 0.25).float()  # [1, H, W]
        bwd_occ = torch.clamp(bwd_occ + photometric_occ, 0, 1).unsqueeze(0)
    # flow_warp unpacks a 4-D input and samples it on a grid built on the
    # flow's device; the default image3 (= image1) is 3-D and may be on CPU.
    if image3.dim() == 3:
        image3 = image3[None]
    warped_results = flow_warp(image3.to(bwd_flow.device), bwd_flow)
    return warped_results, bwd_occ, bwd_flow
class FlowCalc():
    """Wrapper around a pretrained GMFlow model for HxWxC numpy images."""

    def __init__(self, model_path='./models/gmflow_sintel-0c07dcb3.pth'):
        """Build GMFlow on ``global_device`` and load weights from disk.

        Args:
            model_path: Path to a GMFlow checkpoint. If the loaded object has
                a ``'model'`` entry, that entry is used as the state dict;
                otherwise the loaded object itself is.
        """
        flow_model = GMFlow(
            feature_channels=128,
            num_scales=1,
            upsample_factor=8,
            num_head=1,
            attention_type='swin',
            ffn_dim_expansion=4,
            num_transformer_layers=6,
        ).to(global_device)
        # map_location keeps the loaded tensors on CPU storage; load_state_dict
        # then copies them into the (already device-resident) model.
        checkpoint = torch.load(model_path,
                                map_location=lambda storage, loc: storage)
        weights = checkpoint['model'] if 'model' in checkpoint else checkpoint
        # strict=False: missing/unexpected checkpoint keys are ignored.
        flow_model.load_state_dict(weights, strict=False)
        flow_model.eval()
        self.model = flow_model

    @torch.no_grad()
    def get_flow(self, image1, image2, save_path=None):
        """Return the backward flow predicted for an image pair.

        Args:
            image1: First image, HxWxC numpy array.
            image2: Second image, same shape as ``image1``.
            save_path: Optional cache file. If it exists, the flow is loaded
                from it instead of being computed; otherwise the computed
                flow is written to exactly this path.

        Returns:
            Flow tensor of shape [1, 2, H, W] — index 1 of GMFlow's
            bidirectional prediction (presumably image2 -> image1; confirm
            against GMFlow). NOTE: a cached flow comes back on CPU, a freshly
            computed one on ``global_device``.
        """
        if save_path is not None and os.path.exists(save_path):
            return read_flow(save_path)
        image1 = torch.from_numpy(image1).permute(2, 0, 1).float()
        image2 = torch.from_numpy(image2).permute(2, 0, 1).float()
        padder = InputPadder(image1.shape, padding_factor=8)
        image1, image2 = padder.pad(image1[None].to(global_device),
                                    image2[None].to(global_device))
        results_dict = self.model(image1,
                                  image2,
                                  attn_splits_list=[2],
                                  corr_radius_list=[-1],
                                  prop_radius_list=[-1],
                                  pred_bidir_flow=True)
        flow_pr = results_dict['flow_preds'][-1]  # forward/backward, padded
        bwd_flow = padder.unpad(flow_pr[1]).unsqueeze(0)  # [1, 2, H, W]
        if save_path is not None:
            # np.save() given a path string appends '.npy' when the name lacks
            # it, which would make the os.path.exists(save_path) cache check
            # above never hit. Writing through a file handle keeps the exact
            # name; np.load() detects the format from the file contents.
            with open(save_path, 'wb') as f:
                np.save(f, bwd_flow.cpu().numpy())
        return bwd_flow

    def warp(self, img, flow, mode='bilinear'):
        """Backward-warp a numpy image with ``flow`` via ``flow_warp``.

        Args:
            img: HxW or HxWxC numpy array.
            flow: Flow tensor of shape [1, 2, H, W], on any device.
            mode: Interpolation mode forwarded to ``flow_warp``.

        Returns:
            Numpy array with the same shape and dtype as ``img`` (values are
            cast back to the input dtype after sampling in float).
        """
        expand = len(img.shape) == 2
        if expand:
            img = np.expand_dims(img, 2)
        tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
        dtype = tensor.dtype
        # flow_warp builds its sampling grid on flow.device and grid_sample
        # needs the image on the same device, so follow the flow's device.
        tensor = tensor.to(device=flow.device, dtype=torch.float)
        res = flow_warp(tensor, flow, mode=mode)
        # Cast back on CPU, where the original implementation did it.
        res = res[0].cpu().to(dtype)
        res = res.permute(1, 2, 0).numpy()
        if expand:
            res = res[:, :, 0]
        return res
def read_flow(save_path):
    """Load a flow array stored with ``np.save`` and return it as a tensor.

    Args:
        save_path: Path of the saved array file.

    Returns:
        CPU ``torch.Tensor`` sharing memory with the loaded numpy array.
    """
    return torch.from_numpy(np.load(save_path))
# Module-level singleton: constructing FlowCalc here builds GMFlow and loads
# its weights from FlowCalc's default model_path as soon as this module is
# imported.
flow_calc = FlowCalc()