| import os |
| import torch |
|
|
| from gfpgan import GFPGANer |
|
|
| from tqdm import tqdm |
|
|
| from src.utils.videoio import load_video_to_cv2 |
|
|
| import cv2 |
|
|
|
|
class GeneratorWithLen:
    """Pair an iterator with an explicitly supplied length.

    Lets a lazily produced sequence be handed to code that calls ``len()``
    (e.g. progress bars or writers that pre-size their output) while the
    elements themselves are still produced on demand.
    Adapted from https://stackoverflow.com/a/7460929.

    ``__iter__`` returns the wrapped object itself, so when ``gen`` is a
    generator the wrapper can only be iterated once.
    """

    def __init__(self, gen, length):
        # Both values are stored verbatim; nothing is validated here.
        self.gen, self.length = gen, length

    def __len__(self):
        """Return the length given at construction time."""
        declared_length = self.length
        return declared_length

    def __iter__(self):
        """Return the wrapped iterator unchanged."""
        wrapped = self.gen
        return wrapped
|
|
def enhancer_list(images, method='gfpgan', bg_upsampler='realesrgan'):
    """Enhance every frame eagerly and return all results as a list.

    This drains ``enhancer_generator_no_len``, so every enhanced frame is
    held in memory at once; prefer the generator variants for long videos.

    Args:
        images: frames to enhance, or a path accepted by
            ``enhancer_generator_no_len``.
        method: face-restoration model name, forwarded unchanged.
        bg_upsampler: background upsampler name, forwarded unchanged.

    Returns:
        A list with one enhanced frame per input frame.
    """
    enhanced_frames = enhancer_generator_no_len(
        images,
        method=method,
        bg_upsampler=bg_upsampler,
    )
    return [*enhanced_frames]
|
|
def enhancer_generator_with_len(images, method='gfpgan', bg_upsampler='realesrgan'):
    """Return a lazily evaluated enhancer iterator that also supports ``len()``.

    The result can be passed to functions that call ``len()`` while enhanced
    frames are still produced one at a time.

    Args:
        images: a sequence of frames, or a path to a video file that is
            decoded with ``load_video_to_cv2`` first.
        method: face-restoration model name, forwarded to
            ``enhancer_generator_no_len``.
        bg_upsampler: background upsampler name, forwarded to
            ``enhancer_generator_no_len``.

    Returns:
        A ``GeneratorWithLen`` whose length is the number of input frames.
    """
    # Only probe the filesystem for path-like inputs: os.path.isfile() on a
    # list/array of frames raises TypeError because os.stat rejects it.
    if isinstance(images, (str, bytes, os.PathLike)) and os.path.isfile(images):
        # Decode here (not only inside the generator) so that len() reports
        # the number of frames instead of the length of the path string.
        images = load_video_to_cv2(images)

    gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler)
    return GeneratorWithLen(gen, len(images))
|
|
# method name -> (arch, channel_multiplier, model_name, download url) for GFPGANer.
_RESTORER_CONFIGS = {
    'gfpgan': (
        'clean', 2, 'GFPGANv1.4',
        'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth'),
    'RestoreFormer': (
        'RestoreFormer', 2, 'RestoreFormer',
        'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth'),
    'codeformer': (
        'CodeFormer', 2, 'CodeFormer',
        'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'),
}


def _restorer_config(method):
    """Look up the GFPGANer settings for ``method``; raise ValueError if unknown."""
    try:
        return _RESTORER_CONFIGS[method]
    except KeyError:
        raise ValueError(f'Wrong model version {method}.') from None


def _build_bg_upsampler(bg_upsampler):
    """Build the RealESRGAN x2 background upsampler, or return None if not used."""
    if bg_upsampler != 'realesrgan':
        return None
    if not torch.cuda.is_available():
        import warnings
        warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. '
                      'If you really want to use it, please modify the corresponding codes.')
        return None
    # Imported lazily so these packages are only needed when a GPU is present.
    from basicsr.archs.rrdbnet_arch import RRDBNet
    from realesrgan import RealESRGANer
    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
    return RealESRGANer(
        scale=2,
        model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
        model=model,
        tile=400,
        tile_pad=10,
        pre_pad=0,
        half=True)


def _resolve_model_path(model_name, url):
    """Return the first local ``<model_name>.pth`` found, else the download ``url``."""
    for weights_dir in ('gfpgan/weights', 'checkpoints'):
        candidate = os.path.join(weights_dir, model_name + '.pth')
        if os.path.isfile(candidate):
            return candidate
    # No local copy: hand the release URL to GFPGANer as its model_path.
    return url


def enhancer_generator_no_len(images, method='gfpgan', bg_upsampler='realesrgan'):
    """Yield face-enhanced frames one at a time.

    Enhanced frames are produced lazily, so they never all need to be held
    in memory at once (unlike ``enhancer_list``). Because this is a
    generator function, nothing runs until the first frame is requested.
    That includes validating ``method`` and loading models.

    Args:
        images: a sequence of RGB frames, or a path to a video file that is
            decoded with ``load_video_to_cv2``.
        method: one of ``'gfpgan'``, ``'RestoreFormer'`` or ``'codeformer'``.
        bg_upsampler: ``'realesrgan'`` to upsample the background with
            RealESRGAN (only when CUDA is available); anything else disables it.

    Yields:
        Each enhanced frame, converted back to RGB.

    Raises:
        ValueError: if ``method`` is not a supported model name (raised on
            first iteration).
    """
    print('face enhancer....')
    # Only path-like values can name a video file; probing a list/array with
    # os.path.isfile() would raise TypeError.
    if isinstance(images, (str, bytes, os.PathLike)) and os.path.isfile(images):
        images = load_video_to_cv2(images)

    arch, channel_multiplier, model_name, url = _restorer_config(method)
    background = _build_bg_upsampler(bg_upsampler)
    model_path = _resolve_model_path(model_name, url)

    restorer = GFPGANer(
        model_path=model_path,
        upscale=2,
        arch=arch,
        channel_multiplier=channel_multiplier,
        bg_upsampler=background)

    for frame in tqdm(images, 'Face Enhancer:'):
        # The restorer is fed BGR frames; callers supply and receive RGB.
        bgr_frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        # enhance() returns (cropped_faces, restored_faces, restored_img);
        # with paste_back=True only the full restored image is needed.
        _, _, restored = restorer.enhance(
            bgr_frame,
            has_aligned=False,
            only_center_face=False,
            paste_back=True)
        yield cv2.cvtColor(restored, cv2.COLOR_BGR2RGB)
|
|