|
import os |
|
import torch |
|
|
|
from gfpgan import GFPGANer |
|
|
|
from tqdm import tqdm |
|
|
|
from src.utils.videoio import load_video_to_cv2 |
|
|
|
import cv2 |
|
|
|
|
|
class GeneratorWithLen(object): |
|
""" From https://stackoverflow.com/a/7460929 """ |
|
|
|
def __init__(self, gen, length): |
|
self.gen = gen |
|
self.length = length |
|
|
|
def __len__(self): |
|
return self.length |
|
|
|
def __iter__(self): |
|
return self.gen |
|
|
|
def enhancer_list(images, method='gfpgan', bg_upsampler='realesrgan'): |
|
gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler) |
|
return list(gen) |
|
|
|
def enhancer_generator_with_len(images, method='gfpgan', bg_upsampler='realesrgan'): |
|
""" Provide a generator with a __len__ method so that it can passed to functions that |
|
call len()""" |
|
|
|
if os.path.isfile(images): |
|
|
|
images = load_video_to_cv2(images) |
|
|
|
gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler) |
|
gen_with_len = GeneratorWithLen(gen, len(images)) |
|
return gen_with_len |
|
|
|
def enhancer_generator_no_len(images, method='gfpgan', bg_upsampler='realesrgan'): |
|
""" Provide a generator function so that all of the enhanced images don't need |
|
to be stored in memory at the same time. This can save tons of RAM compared to |
|
the enhancer function. """ |
|
|
|
print('face enhancer....') |
|
if not isinstance(images, list) and os.path.isfile(images): |
|
images = load_video_to_cv2(images) |
|
|
|
|
|
if method == 'gfpgan': |
|
arch = 'clean' |
|
channel_multiplier = 2 |
|
model_name = 'GFPGANv1.4' |
|
url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth' |
|
elif method == 'RestoreFormer': |
|
arch = 'RestoreFormer' |
|
channel_multiplier = 2 |
|
model_name = 'RestoreFormer' |
|
url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth' |
|
elif method == 'codeformer': |
|
arch = 'CodeFormer' |
|
channel_multiplier = 2 |
|
model_name = 'CodeFormer' |
|
url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' |
|
else: |
|
raise ValueError(f'Wrong model version {method}.') |
|
|
|
|
|
|
|
if bg_upsampler == 'realesrgan': |
|
if not torch.cuda.is_available(): |
|
import warnings |
|
warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. ' |
|
'If you really want to use it, please modify the corresponding codes.') |
|
bg_upsampler = None |
|
else: |
|
from basicsr.archs.rrdbnet_arch import RRDBNet |
|
from realesrgan import RealESRGANer |
|
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2) |
|
bg_upsampler = RealESRGANer( |
|
scale=2, |
|
model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth', |
|
model=model, |
|
tile=400, |
|
tile_pad=10, |
|
pre_pad=0, |
|
half=True) |
|
else: |
|
bg_upsampler = None |
|
|
|
|
|
model_path = os.path.join('gfpgan/weights', model_name + '.pth') |
|
|
|
if not os.path.isfile(model_path): |
|
model_path = os.path.join('checkpoints', model_name + '.pth') |
|
|
|
if not os.path.isfile(model_path): |
|
|
|
model_path = url |
|
|
|
restorer = GFPGANer( |
|
model_path=model_path, |
|
upscale=2, |
|
arch=arch, |
|
channel_multiplier=channel_multiplier, |
|
bg_upsampler=bg_upsampler) |
|
|
|
|
|
for idx in tqdm(range(len(images)), 'Face Enhancer:'): |
|
|
|
img = cv2.cvtColor(images[idx], cv2.COLOR_RGB2BGR) |
|
|
|
|
|
cropped_faces, restored_faces, r_img = restorer.enhance( |
|
img, |
|
has_aligned=False, |
|
only_center_face=False, |
|
paste_back=True) |
|
|
|
r_img = cv2.cvtColor(r_img, cv2.COLOR_BGR2RGB) |
|
yield r_img |
|
|