Track-Anything / inpainter /base_inpainter.py
watchtowerss's picture
operation prompt version
3c7c9f9
raw history blame
No virus
7.01 kB
import os
import glob
from PIL import Image
import torch
import yaml
import cv2
import importlib
import numpy as np
from tqdm import tqdm
from inpainter.util.tensor_util import resize_frames, resize_masks
class BaseInpainter:
    """Video inpainter wrapping the E2FGVI-HQ generator.

    The generator class is imported from ``inpainter.model.e2fgvi_hq``, its
    weights are restored from a checkpoint, and the frame-sampling settings
    (``neighbor_stride``, ``num_ref``, ``step``) are read from a YAML file.
    """

    def __init__(self, E2FGVI_checkpoint, device,
                 config_path="inpainter/config/config.yaml") -> None:
        """
        E2FGVI_checkpoint: checkpoint of inpainter (version hq, with multi-resolution support)
        device: device the model and the input tensors are moved to
        config_path: YAML file providing 'neighbor_stride', 'num_ref' and 'step';
            the default is resolved relative to the current working directory
        """
        net = importlib.import_module('inpainter.model.e2fgvi_hq')
        self.model = net.InpaintGenerator().to(device)
        # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
        self.model.load_state_dict(torch.load(E2FGVI_checkpoint, map_location=device))
        self.model.eval()
        self.device = device
        # load configurations
        with open(config_path, 'r') as stream:
            config = yaml.safe_load(stream)
        # half-width of the window of neighbouring frames, and the stride of the main loop
        self.neighbor_stride = config['neighbor_stride']
        # maximum number of reference frames; -1 means "sample from the whole video"
        self.num_ref = config['num_ref']
        # distance between two consecutive reference-frame candidates
        self.step = config['step']

    # sample reference frames from the whole video
    def get_ref_index(self, f, neighbor_ids, length):
        """
        Pick reference frame indices for the window centred on frame ``f``.

        f: index of the current frame
        neighbor_ids: indices already used as neighbours (never returned)
        length: number of frames in the video

        Output:
        ref_index: list of frame indices, every one of them in [0, length)
        """
        neighbors = set(neighbor_ids)
        if self.num_ref == -1:
            # no limit: every `step`-th frame of the video that is not a neighbour
            return [i for i in range(0, length, self.step) if i not in neighbors]

        half_span = self.step * (self.num_ref // 2)
        start_idx = max(0, f - half_span)
        # clamp to the last valid index; clamping to `length` would allow the
        # out-of-range index `length` to be returned
        end_idx = min(length - 1, f + half_span)
        ref_index = []
        for i in range(start_idx, end_idx + 1, self.step):
            # stop as soon as num_ref references were collected
            if len(ref_index) >= self.num_ref:
                break
            if i not in neighbors:
                ref_index.append(i)
        return ref_index

    @staticmethod
    def _target_size(height, width, ratio):
        """Return the even-valued (w, h) to down-sample to, or None when ratio == 1."""
        if ratio == 1:
            return None

        def to_even(value):
            # only consider even values
            return value + 1 if value % 2 > 0 else value

        size = [to_even(int(width * ratio)), to_even(int(height * ratio))]
        # shortest side should be larger than 50
        if min(size) < 50:
            ratio = 50. / min(height, width)
            # re-apply the even rounding: int() truncation alone may produce an
            # odd value, or 49 for the shortest side
            size = [to_even(int(width * ratio)), to_even(int(height * ratio))]
        return tuple(size)

    def inpaint(self, frames, masks, dilate_radius=15, ratio=1):
        """
        Inpaint the masked regions of a video.

        frames: numpy array, T, H, W, 3
        masks: numpy array, T, H, W (values are clipped to [0, 1]; non-zero marks pixels to fill)
        dilate_radius: side length of the elliptical kernel used to dilate the masks
        ratio: down-sample ratio, in (0, 1]

        Output:
        inpainted_frames: numpy array (uint8), T, h, w, 3, where (h, w) equals (H, W)
            when ratio == 1 and is the down-sampled size otherwise (the result is
            not resized back to the input resolution)
        """
        assert frames.shape[:3] == masks.shape, 'different size between frames and masks'
        assert ratio > 0 and ratio <= 1, 'ratio must in (0, 1]'

        # np.clip returns a new array, so the caller's masks are left untouched
        masks = np.clip(masks, 0, 1)
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dilate_radius, dilate_radius))
        masks = np.stack([cv2.dilate(mask, kernel) for mask in masks], 0)
        T, H, W = masks.shape
        masks = np.expand_dims(masks, axis=3)  # expand to T, H, W, 1

        # size: (w, h)
        size = self._target_size(H, W, ratio)
        if size is None:
            binary_masks = masks
        else:
            # presumably resize_masks keeps the trailing channel axis (T, h, w, 1),
            # as required by the permute below -- TODO confirm in tensor_util
            binary_masks = resize_masks(masks, size)
            frames = resize_frames(frames, size)  # T, H, W, 3

        # frames and binary_masks are numpy arrays
        h, w = frames.shape[1:3]
        video_length = T

        # convert to tensor; images are scaled from [0, 255] to [-1, 1]
        imgs = (torch.from_numpy(frames).permute(0, 3, 1, 2).contiguous().unsqueeze(0).float().div(255)) * 2 - 1
        masks = torch.from_numpy(binary_masks).permute(0, 3, 1, 2).contiguous().unsqueeze(0)
        imgs, masks = imgs.to(self.device), masks.to(self.device)

        # the model input is padded up to the next multiple of these sizes;
        # the amounts depend only on (h, w), so compute them once
        mod_size_h = 60
        mod_size_w = 108
        h_pad = (mod_size_h - h % mod_size_h) % mod_size_h
        w_pad = (mod_size_w - w % mod_size_w) % mod_size_w

        comp_frames = [None] * video_length
        for f in tqdm(range(0, video_length, self.neighbor_stride), desc='Inpainting image'):
            neighbor_ids = list(range(max(0, f - self.neighbor_stride),
                                      min(video_length, f + self.neighbor_stride + 1)))
            ref_ids = self.get_ref_index(f, neighbor_ids, video_length)
            selected_imgs = imgs[:1, neighbor_ids + ref_ids, :, :, :]
            selected_masks = masks[:1, neighbor_ids + ref_ids, :, :, :]
            with torch.no_grad():
                # blank out the region to be inpainted
                masked_imgs = selected_imgs * (1 - selected_masks)
                # pad by mirroring: append the flipped tensor, then crop to the padded size
                masked_imgs = torch.cat(
                    [masked_imgs, torch.flip(masked_imgs, [3])],
                    3)[:, :, :, :h + h_pad, :]
                masked_imgs = torch.cat(
                    [masked_imgs, torch.flip(masked_imgs, [4])],
                    4)[:, :, :, :, :w + w_pad]
                pred_imgs, _ = self.model(masked_imgs, len(neighbor_ids))
                # drop the padding and map the prediction from [-1, 1] back to [0, 255]
                pred_imgs = pred_imgs[:, :, :h, :w]
                pred_imgs = (pred_imgs + 1) / 2
                pred_imgs = pred_imgs.cpu().permute(0, 2, 3, 1).numpy() * 255
            for i, idx in enumerate(neighbor_ids):
                # prediction inside the mask, original pixels outside of it
                img = pred_imgs[i].astype(np.uint8) * binary_masks[idx] + frames[idx] * (
                    1 - binary_masks[idx])
                if comp_frames[idx] is None:
                    comp_frames[idx] = img
                else:
                    # frame already produced by an overlapping window: average both results
                    comp_frames[idx] = comp_frames[idx].astype(
                        np.float32) * 0.5 + img.astype(np.float32) * 0.5

        inpainted_frames = np.stack(comp_frames, 0)
        return inpainted_frames.astype(np.uint8)
if __name__ == '__main__':
    # Demo: load the frames and annotation masks of one sequence, inpaint the
    # masked regions and write the result as numbered JPEG files.
    frame_path = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/parkour', '*.jpg'))
    frame_path.sort()
    mask_path = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/Annotations/480p/parkour', "*.png"))
    mask_path.sort()
    save_path = '/ssd1/gaomingqi/results/inpainting/parkour'
    # also creates missing parent directories and tolerates an existing target
    os.makedirs(save_path, exist_ok=True)

    frames = []
    masks = []
    for fid, mid in zip(frame_path, mask_path):
        # convert to arrays right away so every image file is closed again
        with Image.open(fid) as frame_img:
            frames.append(np.array(frame_img.convert('RGB')))
        with Image.open(mid) as mask_img:
            masks.append(np.array(mask_img.convert('P')))
    frames = np.stack(frames, 0)
    masks = np.stack(masks, 0)

    # ----------------------------------------------
    # how to use
    # ----------------------------------------------
    # 1/3: set checkpoint and device
    checkpoint = '/ssd1/gaomingqi/checkpoints/E2FGVI-HQ-CVPR22.pth'
    device = 'cuda:6'
    # 2/3: initialise inpainter
    base_inpainter = BaseInpainter(checkpoint, device)
    # 3/3: inpainting (frames: numpy array, T, H, W, 3; masks: numpy array, T, H, W)
    # ratio: (0, 1], ratio for down sample, default value is 1
    inpainted_frames = base_inpainter.inpaint(frames, masks, ratio=0.01)  # numpy array, T, H, W, 3
    # ----------------------------------------------
    # end
    # ----------------------------------------------

    # save
    for ti, inpainted_frame in enumerate(inpainted_frames):
        frame = Image.fromarray(inpainted_frame).convert('RGB')
        frame.save(os.path.join(save_path, f'{ti:05d}.jpg'))