Spaces:
Running
Running
import os | |
import cv2 | |
import numpy as np | |
from PIL import Image | |
import glob | |
import torch | |
import tqdm | |
import shutil | |
import argparse | |
from third_party.GPEN.face_enhancement import FaceEnhancement | |
make_abs_path = lambda fn: os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), fn)) | |
class GPENImageInfer(object): | |
def __init__(self, device): | |
super(GPENImageInfer, self).__init__() | |
model = { | |
"name": "GPEN-BFR-512", | |
"in_size": 512, | |
"out_size": 512, | |
"channel_multiplier": 2, | |
"narrow": 1, | |
} | |
faceenhancer = FaceEnhancement( | |
base_dir=make_abs_path('./'), | |
use_sr=True, | |
in_size=model["in_size"], | |
out_size=model["out_size"], | |
model=model["name"], | |
channel_multiplier=model["channel_multiplier"], | |
narrow=model["narrow"], | |
device=device, | |
) | |
self.faceenhancer = faceenhancer | |
def image_infer(self, in_img: np.ndarray): | |
""" | |
:param in_img: np.ndarray, (H,W,BGR), in [0,255] | |
:return: out_img: np.ndarray, (H,W,BGR), in [0,255] | |
""" | |
h, w, _ = in_img.shape | |
out_img, orig_faces, enhanced_faces = self.faceenhancer.process(in_img) | |
out_img = cv2.resize(out_img, (w, h)) | |
return out_img | |
def ndarray_infer(self, in_ndarray: np.ndarray, | |
save_folder: str = 'demo_images/out/', | |
save_name: str = 'reen.png', | |
): | |
""" | |
:param in_ndarray: np.ndarray, (N,H,W,BGR), in [0,255] | |
:param save_folder: not used | |
:param save_name: not used | |
:return: out_ndarray: np.ndarray, (N,H,W,BGR), in [0,255] | |
""" | |
B, H, W, C = in_ndarray.shape | |
out_ndarray = np.zeros_like(in_ndarray, dtype=np.uint8) # (N,H,W,BGR) | |
for b_idx in range(B): | |
single_img = in_ndarray[b_idx] | |
out_img = self.image_infer(single_img) # (H,W,BGR), in [0,255] | |
out_ndarray[b_idx] = out_img | |
return out_ndarray | |
def batch_infer(self, in_batch: torch.Tensor, | |
save_folder: str = 'demo_images/out/', | |
save_name: str = 'reen.png', | |
save_batch_idx: int = 0, | |
): | |
""" | |
:param in_batch: (N,RGB,H,W), in [-1,1] | |
:return: out_batch: (N,RGB,H,W), in [-1,1] | |
""" | |
B, C, H, W = in_batch.shape | |
device = in_batch.device | |
in_batch = ((in_batch + 1.) * 127.5).permute(0, 2, 3, 1) | |
in_batch = in_batch.cpu().numpy().astype(np.uint8) # (N,H,W,RGB), in [0,255] | |
in_batch = in_batch[:, :, :, ::-1] # (N,H,W,BGR) | |
out_batch = np.zeros_like(in_batch, dtype=np.uint8) # (N,H,W,BGR) | |
for b_idx in range(B): | |
single_img = in_batch[b_idx] | |
out_img = self.image_infer(single_img) # (H,W,BGR), in [0,255] | |
out_batch[b_idx] = out_img[:, :, ::-1] | |
if save_batch_idx is not None and b_idx == save_batch_idx: | |
cv2.imwrite(os.path.join(save_folder, save_name), out_img) | |
out_batch = torch.FloatTensor(out_batch).to(device) | |
out_batch = out_batch / 127.5 - 1. # (N,H,W,RGB) | |
out_batch = out_batch.permute(0, 3, 1, 2) # (N,RGB,H,W) | |
out_batch = out_batch.clamp(-1, 1) | |
return out_batch | |
if __name__ == '__main__': | |
gpen = GPENImageInfer() | |
in_folder = 'examples/imgs/' | |
img_list = os.listdir(in_folder) | |
for img_name in img_list: | |
if 'gpen' in img_name: | |
continue | |
in_path = os.path.join(in_folder, img_name) | |
out_path = in_path.replace('.png', '_gpen.png') | |
out_path = in_path.replace('.jpg', '_gpen.jpg') | |
im = cv2.imread(in_path, cv2.IMREAD_COLOR) # BGR | |
img = gpen.image_infer(im) | |
cv2.imwrite(out_path, img) | |