Spaces:
Sleeping
Sleeping
import os
import sys

import numpy as np
import torch
import torch.nn.functional as F

# Make the vendored GMFlow code importable: it lives in "gmflow_module",
# a sibling of the directory that contains this file.
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
gmflow_dir = os.path.join(parent_dir, 'gmflow_module')
sys.path.insert(0, gmflow_dir)

# These two imports resolve through gmflow_dir, so they must come after the
# sys.path insertion above.
from gmflow.gmflow import GMFlow  # noqa: E702 E402 F401
from utils.utils import InputPadder  # noqa: E702 E402

import huggingface_hub  # noqa: E402

# Use the GPU when one is visible, otherwise fall back to the CPU.
global_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Download the GMFlow Sintel checkpoint at import time; with local_dir='./'
# it lands at ./models/gmflow_sintel-0c07dcb3.pth.
repo_name = 'Anonymous-sub/Rerender'
gmflow_path = huggingface_hub.hf_hub_download(
    repo_id=repo_name,
    filename='models/gmflow_sintel-0c07dcb3.pth',
    local_dir='./',
)
def coords_grid(b, h, w, homogeneous=False, device=None):
    """Build a batch of absolute pixel-coordinate grids.

    Channel 0 holds each pixel's x (column) index and channel 1 its y (row)
    index. With ``homogeneous=True`` a third channel of ones is appended.

    Args:
        b: Batch size (number of identical copies of the grid).
        h: Grid height in pixels.
        w: Grid width in pixels.
        homogeneous: Append a channel of ones when True.
        device: Device to place the grid on. ``None`` leaves it on torch's
            default device.

    Returns:
        Float tensor of shape [B, 2, H, W] ([B, 3, H, W] if homogeneous).
    """
    # Row index varies along dim 0 and column index along dim 1 ("ij"
    # layout). Broadcasting avoids torch.meshgrid's version-dependent
    # `indexing` warning while producing the same values.
    ys = torch.arange(h).view(h, 1).expand(h, w)  # [H, W]
    xs = torch.arange(w).view(1, w).expand(h, w)  # [H, W]
    stacks = [xs, ys]
    if homogeneous:
        stacks.append(torch.ones_like(xs))  # [H, W]
    grid = torch.stack(stacks, dim=0).float()  # [2, H, W] or [3, H, W]
    grid = grid[None].repeat(b, 1, 1, 1)  # [B, 2, H, W] or [B, 3, H, W]
    if device is not None:
        # Honour the caller's device rather than a module-wide default.
        grid = grid.to(device)
    return grid
def bilinear_sample(img,
                    sample_coords,
                    mode='bilinear',
                    padding_mode='zeros',
                    return_mask=False):
    """Sample ``img`` at absolute pixel coordinates using ``F.grid_sample``.

    Args:
        img: Tensor of shape [B, C, H_in, W_in] to sample from.
        sample_coords: Pixel coordinates in ``img``'s pixel scale, with x in
            slot 0 and y in slot 1. The layout is either [B, 2, H, W] or
            [B, H, W, 2], detected from dim 1. A [B, H, W, 2] input whose
            H is 2 is therefore read as channels-first.
        mode: Interpolation mode passed to ``F.grid_sample``.
        padding_mode: Padding mode passed to ``F.grid_sample``.
        return_mask: Also return a boolean [B, H, W] mask that is True
            where the coordinate falls inside ``img``.

    Returns:
        The sampled tensor [B, C, H, W], or ``(sampled, mask)`` when
        ``return_mask`` is True.
    """
    if sample_coords.size(1) != 2:  # [B, H, W, 2] -> [B, 2, H, W]
        sample_coords = sample_coords.permute(0, 3, 1, 2)

    # The coordinates are in the *input image's* pixel units, so normalise
    # by its size, not by the sampling grid's size. With align_corners=True,
    # pixel 0 maps to -1 and pixel (size - 1) maps to +1. max(..., 1) avoids
    # dividing by zero for a one-pixel-wide/tall image.
    in_h, in_w = img.shape[-2:]
    x_grid = 2 * sample_coords[:, 0] / max(in_w - 1, 1) - 1
    y_grid = 2 * sample_coords[:, 1] / max(in_h - 1, 1) - 1
    grid = torch.stack([x_grid, y_grid], dim=-1)  # [B, H, W, 2]

    img = F.grid_sample(img,
                        grid,
                        mode=mode,
                        padding_mode=padding_mode,
                        align_corners=True)

    if return_mask:
        # In-bounds test on the normalised coordinates.
        mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (
            y_grid <= 1)  # [B, H, W]
        return img, mask
    return img
def flow_warp(feature,
              flow,
              mask=False,
              mode='bilinear',
              padding_mode='zeros'):
    """Backward-warp ``feature`` with a dense displacement field.

    Output pixel (x, y) is read from ``feature`` at (x + u, y + v), where
    u is flow channel 0 and v is flow channel 1.

    Args:
        feature: Tensor of shape [B, C, H, W] to warp.
        flow: Two-channel displacement field, broadcastable against
            [B, 2, H, W].
        mask: When True, also return the in-bounds mask computed by
            ``bilinear_sample``.
        mode: Interpolation mode forwarded to ``bilinear_sample``.
        padding_mode: Padding mode forwarded to ``bilinear_sample``.

    Returns:
        The warped tensor, or ``(warped, in_bounds_mask)`` when ``mask`` is
        True.
    """
    batch, _, height, width = feature.size()
    assert flow.size(1) == 2
    # Absolute sampling positions = identity pixel grid + displacement.
    base = coords_grid(batch, height, width).to(flow.device)  # [B, 2, H, W]
    sample_at = base + flow
    return bilinear_sample(feature,
                           sample_at,
                           mode=mode,
                           padding_mode=padding_mode,
                           return_mask=mask)
def forward_backward_consistency_check(fwd_flow,
                                       bwd_flow,
                                       alpha=0.01,
                                       beta=0.5):
    """Flag pixels whose forward and backward flows do not cancel out.

    For each pixel, the flow plus the opposite flow sampled at the
    displaced position should be close to zero. A pixel is flagged when the
    norm of that round-trip residual exceeds
    ``alpha * (|fwd_flow| + |bwd_flow|) + beta``. The default alpha and
    beta follow UnFlow (https://arxiv.org/abs/1711.07837).

    Args:
        fwd_flow: Forward flow, shape [B, 2, H, W].
        bwd_flow: Backward flow, shape [B, 2, H, W].
        alpha: Weight of the flow-magnitude term in the threshold.
        beta: Constant term of the threshold.

    Returns:
        ``(fwd_occ, bwd_occ)``: float tensors of shape [B, H, W] holding
        1.0 where the check fails and 0.0 elsewhere.
    """
    assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4
    assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2

    # Tolerance grows with the combined flow magnitude; shared by both checks.
    magnitude = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1)
    threshold = alpha * magnitude + beta  # [B, H, W]

    # Round-trip residuals: each flow plus the other flow warped onto it.
    fwd_residual = torch.norm(fwd_flow + flow_warp(bwd_flow, fwd_flow), dim=1)
    bwd_residual = torch.norm(bwd_flow + flow_warp(fwd_flow, bwd_flow), dim=1)

    return ((fwd_residual > threshold).float(),
            (bwd_residual > threshold).float())
@torch.no_grad()
def get_warped_and_mask(flow_model,
                        image1,
                        image2,
                        image3=None,
                        pixel_consistency=False):
    """Estimate bidirectional flow for an image pair and warp with it.

    Runs ``flow_model`` on (image1, image2) with ``pred_bidir_flow=True``.
    It derives an occlusion mask from a forward-backward consistency check
    and warps ``image3`` (default: ``image1``) with the backward flow.
    Runs under ``torch.no_grad()``: this is inference only, and without it
    the outputs would carry an autograd graph through the whole model.

    Args:
        flow_model: GMFlow-style model. The last entry of its
            ``'flow_preds'`` output is expected to stack the forward flow
            (index 0) and backward flow (index 1) on the batch axis.
        image1: First image, shape [C, H, W]; a batch axis is added here.
        image2: Second image, shape [C, H, W].
        image3: Optional [B, C, H, W] tensor to warp instead of ``image1``.
            It is moved to the flow's device.
        pixel_consistency: When True, additionally mark as occluded pixels
            where the warped ``image1`` differs from ``image2`` by more than
            255 * 0.25 on average over channels. This presumably assumes
            0-255 pixel values; confirm against callers.

    Returns:
        ``(warped_results, bwd_occ, bwd_flow)``:
        - ``warped_results``: the warped ``image3``.
        - ``bwd_occ``: the backward occlusion mask, [1, H, W], or
          [1, 1, H, W] when ``pixel_consistency`` is set.
        - ``bwd_flow``: the backward flow, [1, 2, H, W].
    """
    # Batched, on-device copies at the original (unpadded) resolution. The
    # flows are unpadded below, so every tensor warped with them must be
    # unpadded too.
    src = image1[None].to(global_device)  # [1, C, H, W]
    dst = image2[None].to(global_device)  # [1, C, H, W]
    if image3 is None:
        image3 = src

    # Pad to a multiple of 8 only for the model call.
    padder = InputPadder(image1.shape, padding_factor=8)
    padded1, padded2 = padder.pad(src, dst)
    results_dict = flow_model(padded1,
                              padded2,
                              attn_splits_list=[2],
                              corr_radius_list=[-1],
                              prop_radius_list=[-1],
                              pred_bidir_flow=True)
    flow_pr = results_dict['flow_preds'][-1]  # [B, 2, H_pad, W_pad]
    fwd_flow = padder.unpad(flow_pr[0]).unsqueeze(0)  # [1, 2, H, W]
    bwd_flow = padder.unpad(flow_pr[1]).unsqueeze(0)  # [1, 2, H, W]
    _, bwd_occ = forward_backward_consistency_check(
        fwd_flow, bwd_flow)  # [1, H, W] float

    if pixel_consistency:
        warped_image1 = flow_warp(src, bwd_flow)
        photometric_occ = (
            (dst - warped_image1).abs().mean(dim=1) > 255 * 0.25).float()
        # The trailing unsqueeze gives this branch a [1, 1, H, W] mask.
        bwd_occ = torch.clamp(bwd_occ + photometric_occ, 0, 1).unsqueeze(0)

    warped_results = flow_warp(image3.to(bwd_flow.device), bwd_flow)
    return warped_results, bwd_occ, bwd_flow
class FlowCalc():
    """GMFlow wrapper that computes backward flow and warps numpy images."""

    def __init__(self, model_path='./models/gmflow_sintel-0c07dcb3.pth'):
        """Build the GMFlow model on ``global_device`` and load its weights.

        Args:
            model_path: Checkpoint file. It may hold a raw state dict or a
                dict with the weights under the ``'model'`` key.
        """
        flow_model = GMFlow(
            feature_channels=128,
            num_scales=1,
            upsample_factor=8,
            num_head=1,
            attention_type='swin',
            ffn_dim_expansion=4,
            num_transformer_layers=6,
        ).to(global_device)
        # map_location returns each storage unchanged, i.e. keeps the loaded
        # tensors on CPU regardless of the device they were saved from.
        checkpoint = torch.load(model_path,
                                map_location=lambda storage, loc: storage)
        weights = checkpoint['model'] if 'model' in checkpoint else checkpoint
        # strict=False: missing or unexpected keys are silently ignored.
        flow_model.load_state_dict(weights, strict=False)
        flow_model.eval()
        self.model = flow_model

    # Inference only. Without no_grad the returned flow would carry an
    # autograd graph (extra memory), and ``.numpy()`` below would raise on
    # a tensor that requires grad.
    @torch.no_grad()
    def get_flow(self, image1, image2, save_path=None):
        """Compute, or load from a cache file, the backward flow of a pair.

        Args:
            image1: First image as an [H, W, C] numpy array.
            image2: Second image as an [H, W, C] numpy array.
            save_path: Optional cache path. If it exists, the flow is loaded
                from it and the model is not run. Otherwise the computed
                flow is written there with ``np.save``, which appends
                '.npy' to paths lacking it, so use a '.npy' path for the
                cache to hit.

        Returns:
            The backward flow, shape [1, 2, H, W]. A freshly computed flow
            is on ``global_device``; a cached one is returned exactly as
            ``read_flow`` loads it (a CPU tensor).
        """
        if save_path is not None and os.path.exists(save_path):
            return read_flow(save_path)

        image1 = torch.from_numpy(image1).permute(2, 0, 1).float()
        image2 = torch.from_numpy(image2).permute(2, 0, 1).float()
        # Pad to a multiple of 8 for the model; the flow is unpadded below.
        padder = InputPadder(image1.shape, padding_factor=8)
        image1, image2 = padder.pad(image1[None].to(global_device),
                                    image2[None].to(global_device))
        results_dict = self.model(image1,
                                  image2,
                                  attn_splits_list=[2],
                                  corr_radius_list=[-1],
                                  prop_radius_list=[-1],
                                  pred_bidir_flow=True)
        flow_pr = results_dict['flow_preds'][-1]  # [B, 2, H, W]
        # Index 1 is the backward half of the bidirectional prediction.
        bwd_flow = padder.unpad(flow_pr[1]).unsqueeze(0)  # [1, 2, H, W]
        if save_path is not None:
            np.save(save_path, bwd_flow.cpu().numpy())
        return bwd_flow

    def warp(self, img, flow, mode='bilinear'):
        """Warp a numpy image with ``flow`` using ``flow_warp``.

        Args:
            img: [H, W] or [H, W, C] numpy array.
            flow: Flow tensor as returned by ``get_flow`` ([1, 2, H, W]),
                on any device.
            mode: Interpolation mode forwarded to ``flow_warp``.

        Returns:
            A numpy array with the same shape and dtype as ``img``. The
            warp itself runs in float32; the cast back to an integer dtype
            truncates.
        """
        expand = False
        if len(img.shape) == 2:
            expand = True
            img = np.expand_dims(img, 2)
        img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
        dtype = img.dtype
        # Put the image on the flow's device; grid_sample needs both on the
        # same device, and a freshly computed flow may live on the GPU.
        img = img.to(device=flow.device, dtype=torch.float)
        res = flow_warp(img, flow, mode=mode)
        res = res.to(dtype)
        res = res[0].cpu().permute(1, 2, 0).numpy()
        if expand:
            res = res[:, :, 0]
        return res
def read_flow(save_path):
    """Load a flow field that was stored with ``np.save``.

    Args:
        save_path: Path of the ``.npy`` file.

    Returns:
        A CPU torch tensor with the stored array's shape and dtype.
    """
    return torch.from_numpy(np.load(save_path))
# Shared flow estimator, built once at import time. It loads the checkpoint
# from the path returned by the hf_hub_download call above instead of a
# separately hard-coded copy of that path.
flow_calc = FlowCalc(model_path=gmflow_path)