# Spaces: Running on Zero
# https://github.com/xinntao/facexlib/blob/master/inference/inference_matting.py
import argparse
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from facexlib.matting import init_matting_model
from facexlib.utils import img2tensor
from torchvision.transforms.functional import normalize
from tqdm import tqdm, trange
# Reference edge length used when choosing the network input resolution.
_REF_SIZE = 512
# The resized input is rounded down to a multiple of this value.
_SIZE_MULTIPLE = 32
# Locations that were hard-coded in the original batch script; kept as fallbacks
# so existing invocations keep producing the same output layout.
_FFHQ_SRC_ROOT = '/mnt/lustre/share/yslan/ffhq/unzipped_ffhq_512/'
_FFHQ_MATTE_ROOT = '/mnt/lustre/share/yslan/ffhq/unzipped_ffhq_matte/'


def _read_image(img_path):
    """Read an image with OpenCV as a 3-channel float array scaled by 1/255.

    Args:
        img_path: Path (``str`` or ``Path``) of the image to load.

    Returns:
        An ``(H, W, 3)`` float array in OpenCV channel order.

    Raises:
        IOError: If OpenCV cannot read the file (``cv2.imread`` returned None).
    """
    img = cv2.imread(str(img_path))
    if img is None:
        # Without this check the failure shows up as an opaque
        # "unsupported operand" TypeError on the division below.
        raise IOError(f'cannot read image: {img_path}')
    img = img / 255.

    # unify image channels to 3
    if img.ndim == 2:
        img = img[:, :, None]
    if img.shape[2] == 1:
        img = np.repeat(img, 3, axis=2)
    elif img.shape[2] == 4:
        img = img[:, :, 0:3]
    return img


def _input_size(im_h, im_w, ref_size=_REF_SIZE):
    """Return the ``(height, width)`` the image is resized to before inference.

    When both sides are below ``ref_size`` or both are above it, the shorter
    side becomes ``ref_size`` and the other side is scaled by the aspect
    ratio; otherwise the size is kept. Both sides are then rounded down to a
    multiple of ``_SIZE_MULTIPLE`` and never allowed to drop below it, so a
    very thin image does not end up with a zero-sized dimension.
    """
    if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
        if im_w >= im_h:
            im_rh = ref_size
            im_rw = int(im_w / im_h * ref_size)
        else:
            im_rw = ref_size
            im_rh = int(im_h / im_w * ref_size)
    else:
        im_rh, im_rw = im_h, im_w

    im_rh = max(_SIZE_MULTIPLE, im_rh - im_rh % _SIZE_MULTIPLE)
    im_rw = max(_SIZE_MULTIPLE, im_rw - im_rw % _SIZE_MULTIPLE)
    return im_rh, im_rw


def _predict_matte(modnet, img):
    """Run the matting model on ``img`` and return the matte at image size.

    Args:
        modnet: Model returned by ``init_matting_model()``.
        img: ``(H, W, 3)`` float array as produced by ``_read_image``.

    Returns:
        A 2-D float NumPy array with the same height and width as ``img``.
    """
    img_t = img2tensor(img, bgr2rgb=True, float32=True)
    normalize(img_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
    img_t = img_t.unsqueeze(0).cuda()

    _, _, im_h, im_w = img_t.shape
    im_rh, im_rw = _input_size(im_h, im_w)

    # Inference only: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        # resize image for input
        img_t = F.interpolate(img_t, size=(im_rh, im_rw), mode='area')
        _, _, matte = modnet(img_t, True)
        # resize the matte back to the original resolution
        matte = F.interpolate(matte, size=(im_h, im_w), mode='area')
    return matte[0][0].detach().cpu().numpy()


def _write_image(path, img):
    """Write ``img`` to ``path`` with OpenCV, raising if the write fails."""
    if not cv2.imwrite(str(path), img):
        raise IOError(f'failed to write image: {path}')


def _relative_image_path(img_path, img_dir):
    """Return the sub-path of ``img_path`` that is mirrored under the output root.

    The original hard-coded FFHQ root is tried first so that runs on that
    dataset keep their directory layout; any other input falls back to a path
    relative to the directory that was actually scanned.
    """
    try:
        return img_path.relative_to(_FFHQ_SRC_ROOT)
    except ValueError:
        return img_path.relative_to(img_dir)


def matt_single(args):
    """Compute the matte and white-background foreground of one image.

    Args:
        args: Namespace with ``img_path`` (input image) and ``save_path``
            (output matte file). The foreground is written next to the matte
            with ``_fg`` inserted before the file extension.

    Raises:
        IOError: If the input cannot be read or an output cannot be written.
    """
    modnet = init_matting_model()
    img = _read_image(args.img_path)
    matte = _predict_matte(modnet, img)

    # save matte
    _write_image(args.save_path, (matte * 255).astype('uint8'))

    # get foreground: blend the image over a white (all-ones) background
    matte = matte[:, :, None]
    foreground = img * matte + (1 - matte)

    # Derive the name from the real extension; a plain str.replace('.png', ...)
    # leaves the name unchanged (and overwrites the matte) for other suffixes.
    save_path = Path(args.save_path)
    fg_path = save_path.with_name(f'{save_path.stem}_fg{save_path.suffix}')
    _write_image(fg_path, foreground * 255)


def matt_directory(args):  # for extracting ffhq imgs foreground
    """Compute mattes for every ``*.png`` found recursively in a directory.

    Args:
        args: Namespace with ``img_dir_path`` (directory to scan). An optional
            ``save_dir_path`` attribute selects the output root; when absent
            or empty the original hard-coded FFHQ matte directory is used.

    Each matte is saved under the output root at the same relative path as its
    source image, creating parent directories as needed.

    Raises:
        IOError: If an image cannot be read or a matte cannot be written.
    """
    modnet = init_matting_model()

    img_dir = Path(args.img_dir_path)
    all_imgs = list(img_dir.rglob('*.png'))
    print('all imgs: ', len(all_imgs))

    tgt_dir = Path(getattr(args, 'save_dir_path', None) or _FFHQ_MATTE_ROOT)

    for img_path in tqdm(all_imgs):
        tgt_save_path = tgt_dir / _relative_image_path(img_path, img_dir)
        tgt_save_path.parent.mkdir(parents=True, exist_ok=True)

        img = _read_image(img_path)
        matte = _predict_matte(modnet, img)

        # Explicit check instead of `assert`, which is stripped under -O.
        _write_image(tgt_save_path, (matte * 255).astype('uint8'))
if __name__ == '__main__':
    # Command-line entry point: parse the three path options and run the
    # directory (batch) matting job.
    cli_parser = argparse.ArgumentParser()
    for flag, fallback in (
        ('--img_path', 'assets/test.jpg'),
        ('--save_path', 'test_matting.png'),
    ):
        cli_parser.add_argument(flag, type=str, default=fallback)
    cli_parser.add_argument(
        '--img_dir_path', type=str, default='assets', required=False)
    args = cli_parser.parse_args()

    # Single-image mode is available via matt_single(args); batch mode is the
    # one that is currently enabled.
    matt_directory(args)