#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Power by Zongsheng Yue 2022-06-09 14:59:55

"""Recover masked (missing) image regions by integrating image gradients.

The reconstruction inverts finite differences of the form
``gradx[:, j] = im[:, j] - im[:, j-1]`` and ``grady[i, :] = im[i, :] - im[i-1, :]``
with cumulative sums, and assumes the missing pixels of the masked input image
are zero.
"""

import torch
import random
import numpy as np


def batch_inpainging_from_grad(im_in, mask, gradx, grady):
    '''
    Recovering from gradient for batch data (torch tensor).
    Input:
        im_in: N x c x h x w, torch tensor, masked image (missing pixels are 0)
        mask: N x 1 x h x w, torch tensor, 1 represents missing value
        gradx, grady: N x c x h x w, torch tensor, image gradient
    Output:
        im_out: N x c x h x w, torch tensor on im_in's device with im_in's dtype
    '''
    # einops is only needed by this torch wrapper; importing it lazily keeps
    # the NumPy helpers in this module usable without it.
    from einops import rearrange

    im_out = torch.zeros_like(im_in.data)
    for ii in range(im_in.shape[0]):
        # detach() so that tensors which require grad can be converted to NumPy.
        im_current, gradx_current, grady_current = [
            rearrange(x[ii,].detach().cpu().numpy(), 'c h w -> h w c')
            for x in [im_in, gradx, grady]
        ]
        mask_current = mask[ii, 0,].detach().cpu().numpy()
        out_current = inpainting_from_grad(im_current, mask_current, gradx_current, grady_current)
        im_out[ii,] = torch.from_numpy(rearrange(out_current, 'h w c -> c h w')).to(
            device=im_in.device, dtype=im_in.dtype
        )
    return im_out


def inpainting_from_grad(im_in, mask, gradx, grady):
    '''
    Input:
        im_in: h x w x c (or h x w), masked image, numpy array; missing pixels must be 0
        mask: h x w, image mask, 1 represents missing value
        gradx: same shape as im_in, gradient along x-axis (axis 1), numpy array
        grady: same shape as im_in, gradient along y-axis (axis 0), numpy array
    Output:
        im_out: recovered image with the shape of im_in. The inputs are not modified.
    Raises:
        ValueError: if no interior row/column is fully known and no interior
            column (1 .. w-2) contains any known pixel.
    '''
    # Work on copies: fill_line() writes into its arguments in place, and the
    # caller's arrays may share memory with torch tensors (see the batch wrapper).
    im_in = im_in.copy()
    mask = mask.copy()

    h, w = im_in.shape[:2]
    counts_h = np.sum(1 - mask, axis=0, keepdims=False)  # known pixels per column
    counts_w = np.sum(1 - mask, axis=1, keepdims=False)  # known pixels per row

    if np.any(counts_h[1:-1,] == h):
        # An interior column is fully known: integrate gradx left/right from it.
        idx = find_first_index(counts_h[1:-1,], h) + 1
        im_out = fill_image_from_gradx(im_in, mask, gradx, idx)
    elif np.any(counts_w[1:-1,] == w):
        # An interior row is fully known: swap the two spatial axes so the row
        # becomes a column and grady becomes the gradient along axis 1,
        # integrate from it, then swap the result back.
        idx = find_first_index(counts_w[1:-1,], w) + 1
        im_out = np.swapaxes(
            fill_image_from_gradx(
                np.swapaxes(im_in, 0, 1),
                np.swapaxes(mask, 0, 1),
                np.swapaxes(grady, 0, 1),
                idx,
            ),
            0, 1,
        )
    else:
        # No complete interior line: draw an interior column weighted by its
        # number of known pixels, complete it from grady, then integrate gradx.
        weights = counts_h[1:-1]
        if not np.any(weights > 0):
            raise ValueError('inpainting_from_grad: no interior column contains a known pixel')
        idx = random.choices(list(range(1, w - 1)), k=1, weights=weights)[0]
        line = fill_line(im_in[:, idx,], mask[:, idx,], grady[:, idx,])
        im_in[:, idx,] = line
        mask[:, idx,] = 0  # the chosen column is now fully known
        im_out = fill_image_from_gradx(im_in, mask, gradx, idx)

    if im_in.ndim > mask.ndim:
        mask = mask[:, :, None]
    # Keep known pixels from im_in; take the reconstruction only where mask == 1.
    im_out = im_in + im_out * mask
    return im_out


def fill_image_from_gradx(im_in, mask, gradx, idx):
    '''
    Rebuild every column by integrating gradx horizontally from column idx.
    Input:
        im_in: h x w x c (or h x w) numpy array whose column idx is fully known
        mask: unused; kept for interface compatibility
        gradx: same shape as im_in, gradx[:, j] = im[:, j] - im[:, j-1]
        idx: int, index of the known column, 0 <= idx < w
    Output:
        im_out: array with the shape of im_in; column idx equals im_in[:, idx]
    '''
    init = np.zeros_like(im_in)
    init[:, idx,] = im_in[:, idx,]
    # Columns idx+1 .. w-1: im[:, j] = im[:, idx] + sum_{k=idx+1..j} gradx[:, k].
    right = np.cumsum(init[:, idx:-1,] + gradx[:, idx + 1:,], axis=1)
    # Columns idx-1 .. 0: walk leftwards subtracting gradx[:, k] for k = idx .. 1,
    # then flip back into left-to-right order.
    left = np.cumsum(init[:, idx:0:-1,] - gradx[:, idx:0:-1,], axis=1)[:, ::-1]
    center = im_in[:, idx,][:, None]  # h x 1 (x c)
    im_out = np.concatenate((left, center, right), axis=1)
    return im_out


def fill_line(xx, mm, grad):
    '''
    Fill one line from grad.
    Input:
        xx: n x c (or (n,)) array, masked vector; missing entries must be 0.
            It is filled in place.
        mm: (n,) array, mask, 1 represent missing value; entries are set to 0
            in place as the corresponding values of xx are filled.
        grad: array shaped like xx, grad[i] = x[i] - x[i-1] (grad[0] is unused)
    Output:
        xx: the same array object, with every missing entry filled
    '''
    n = xx.shape[0]
    assert mm.sum() < n
    # Each pass fills one run of missing entries. Looping (rather than
    # recursing once per run) avoids the recursion limit on masks with many gaps.
    while True:
        idx1 = find_first_index(mm, 1)  # start of the first missing run
        if idx1 >= n:
            return xx
        if idx1 == 0:
            # Leading run [0, idx2): integrate backwards from the first known
            # entry xx[idx2].
            idx2 = find_first_index(mm, 0)
            subx = xx[idx2::-1,].copy()
            subx -= grad[idx2::-1,]
            xx[:idx2,] = np.cumsum(subx, axis=0)[idx2 - 1::-1,]
        else:
            # Run [idx1, idx2) preceded by the known entry xx[idx1-1]:
            # integrate forwards.
            idx2 = find_first_index(mm[idx1:,], 0) + idx1
            subx = xx[idx1 - 1:idx2 - 1,].copy()
            subx += grad[idx1:idx2,]
            xx[idx1:idx2,] = np.cumsum(subx, axis=0)
        mm[idx1:idx2,] = 0


def find_first_index(mm, value):
    '''
    Input:
        mm: (n, ) array
        value: scalar
    Output:
        index of the first entry of mm equal to value, or mm.shape[0] if none
    '''
    # Vectorized search; np.argwhere reports matches in C order, the same order
    # an element-by-element scan would visit them.
    hits = np.argwhere(np.asarray(mm) == value)
    return int(hits[0][0]) if len(hits) else mm.shape[0]


if __name__ == '__main__':
    import sys
    from pathlib import Path
    sys.path.append(str(Path(__file__).resolve().parents[1]))

    from utils import util_image
    from datapipe.masks.train import process_mask

    # mask_file_names = [x for x in Path('../lama/LaMa_test_images').glob('*mask*.png')]
    mask_file_names = [x for x in Path('./testdata/inpainting/val/places/').glob('*mask*.png')]
    file_names = [x.parents[0] / (x.stem.rsplit('_mask', 1)[0] + '.png') for x in mask_file_names]

    for im_path, mask_path in zip(file_names, mask_file_names):
        im = util_image.imread(im_path, chn='rgb', dtype='float32')
        mask = process_mask(util_image.imread(mask_path, chn='rgb', dtype='float32')[:, :, 0])
        grad_dict = util_image.imgrad(im)

        im_masked = im * (1 - mask[:, :, None])
        im_recover = inpainting_from_grad(im_masked, mask, grad_dict['gradx'], grad_dict['grady'])
        error_max = np.abs(im_recover - im).max()
        print('Error Max: {:.2e}'.format(error_max))