Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	File size: 4,637 Bytes
			
			| 29b0bbc 62e3f73 29b0bbc 62e3f73 29b0bbc | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 | import os
import torch 
from gfpgan import GFPGANer
from tqdm import tqdm
import cv2
from realesrgan import RealESRGANer
from basicsr.archs.rrdbnet_arch import RRDBNet
import warnings
from enum import Enum
class EnhancementMethod(str, Enum):
    """Names of the enhancement back-ends that can be requested.

    Members are also ``str`` instances, so each one compares equal to its
    plain string value (for example ``EnhancementMethod.gfpgan == "gfpgan"``).
    """

    # Face-restoration back-ends.
    gfpgan = "gfpgan"
    RestoreFormer = "RestoreFormer"
    codeformer = "codeformer"

    # Whole-image super-resolution back-end.
    realesrgan = "realesrgan"
class Enhancer:
    """Image enhancer that wraps either a GFPGAN-style face restorer or RealESRGAN.

    When ``method`` is ``EnhancementMethod.realesrgan`` the whole image is
    processed by a ``RealESRGANer``. For every other method a ``GFPGANer`` is
    created and, if ``background_enhancement`` is set and CUDA is available,
    a ``RealESRGANer`` is handed to it as the background upsampler.
    """

    # Checkpoint used for both the background upsampler and the standalone
    # RealESRGAN enhancer.
    _REALESRGAN_MODEL_URL = 'https://huggingface.co/dtarnow/UPscaler/resolve/main/RealESRGAN_x2plus.pth'
    # The x2plus checkpoint is an x2 network: the architecture has to be built
    # with this scale or the weights will not load. The user-facing factor is
    # applied separately through ``outscale`` / ``GFPGANer(upscale=...)``.
    _REALESRGAN_NET_SCALE = 2

    def __init__(self, method: "EnhancementMethod", background_enhancement=True, upscale=2):
        """Create the enhancer(s) required for ``method``.

        Args:
            method: Which back-end to use.
            background_enhancement: For face methods, whether to also build a
                RealESRGAN background upsampler (skipped with a warning when
                CUDA is unavailable).
            upscale: Output scale factor forwarded to the underlying enhancer.

        Raises:
            ValueError: If ``method`` is realesrgan and CUDA is unavailable, or
                if ``method`` has no face-model configuration.
        """
        self.method = method
        self.background_enhancement = background_enhancement
        self.upscale = upscale
        self.bg_upsampler = None
        self.realesrgan_enhancer = None
        self.face_enhancer = None

        if self.method == EnhancementMethod.realesrgan:
            self.setup_realesrgan_enhancer()
            return

        # The background upsampler must exist *before* the face enhancer is
        # built, because GFPGANer captures ``self.bg_upsampler`` at
        # construction time.
        if self.background_enhancement:
            self.setup_background_enhancer()
        self.setup_face_enhancer()

    def _build_realesrgan(self):
        """Return a ``RealESRGANer`` configured for the x2plus checkpoint."""
        net_scale = self._REALESRGAN_NET_SCALE
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=net_scale)
        return RealESRGANer(
            scale=net_scale,
            model_path=self._REALESRGAN_MODEL_URL,
            model=model,
            tile=400,
            tile_pad=10,
            pre_pad=0,
            half=True)

    def setup_background_enhancer(self):
        """Build ``self.bg_upsampler``; warn and leave it ``None`` without CUDA."""
        if not torch.cuda.is_available():
            warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it.')
            return
        self.bg_upsampler = self._build_realesrgan()

    def setup_realesrgan_enhancer(self):
        """Build ``self.realesrgan_enhancer``.

        Raises:
            ValueError: If CUDA is not available.
        """
        if not torch.cuda.is_available():
            raise ValueError('CUDA is not available for RealESRGAN')
        self.realesrgan_enhancer = self._build_realesrgan()

    def setup_face_enhancer(self):
        """Build ``self.face_enhancer`` (a ``GFPGANer``) for ``self.method``.

        Weights are looked up in ``gfpgan/weights`` and then ``checkpoints``;
        if neither file exists the configured URL is passed to ``GFPGANer``.

        Raises:
            ValueError: If ``self.method`` has no entry in the model table.
        """
        model_configs = {
            EnhancementMethod.gfpgan: {
                'arch': 'clean',
                'channel_multiplier': 2,
                'model_name': 'GFPGANv1.4',
                'url': 'https://huggingface.co/gmk123/GFPGAN/resolve/main/GFPGANv1.4.pth'
            },
            EnhancementMethod.RestoreFormer: {
                'arch': 'RestoreFormer',
                'channel_multiplier': 2,
                'model_name': 'RestoreFormer',
                'url': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth'
            },
            # NOTE(review): confirm that the installed GFPGANer accepts
            # arch='CodeFormer'; this cannot be verified from this file.
            EnhancementMethod.codeformer: {
                'arch': 'CodeFormer',
                'channel_multiplier': 2,
                'model_name': 'CodeFormer',
                'url': 'https://huggingface.co/sinadi/aar/resolve/main/codeformer.pth'
            }
        }
        config = model_configs.get(self.method)
        if config is None:
            raise ValueError(f'Wrong model version {self.method}')

        # Prefer a local copy of the weights; fall back to the download URL.
        weights_file = config['model_name'] + '.pth'
        local_candidates = (
            os.path.join('gfpgan/weights', weights_file),
            os.path.join('checkpoints', weights_file),
        )
        model_path = next(
            (path for path in local_candidates if os.path.isfile(path)),
            config['url'],
        )

        self.face_enhancer = GFPGANer(
            model_path=model_path,
            upscale=self.upscale,
            arch=config['arch'],
            channel_multiplier=config['channel_multiplier'],
            bg_upsampler=self.bg_upsampler)

    def check_image_resolution(self, image):
        """Return ``(width, height)`` read from the first two entries of ``image.shape``.

        Only the leading two dimensions are used, so both ``(H, W)`` and
        ``(H, W, C)`` shaped inputs are accepted.
        """
        height, width = image.shape[:2]
        return width, height

    async def enhance(self, image):
        """Enhance ``image`` and report its size before and after.

        The input is converted with ``cv2.COLOR_RGB2BGR`` before processing
        and the result is converted back with ``cv2.COLOR_BGR2RGB``. The
        blocking model call runs in a worker thread via ``asyncio.to_thread``.

        Returns:
            A tuple ``(enhanced_img, (width, height), (enhanced_width,
            enhanced_height))``.
        """
        # Imported here because the module header does not import asyncio;
        # without it the to_thread calls below raise NameError.
        import asyncio

        img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        width, height = self.check_image_resolution(img)

        if self.method == EnhancementMethod.realesrgan:
            enhanced_img, _ = await asyncio.to_thread(
                self.realesrgan_enhancer.enhance, img, outscale=self.upscale)
        else:
            # Only the third element of GFPGANer.enhance's result is used.
            _, _, enhanced_img = await asyncio.to_thread(
                self.face_enhancer.enhance,
                img,
                has_aligned=False,
                only_center_face=False,
                paste_back=True)

        enhanced_img = cv2.cvtColor(enhanced_img, cv2.COLOR_BGR2RGB)
        enhanced_width, enhanced_height = self.check_image_resolution(enhanced_img)
        return enhanced_img, (width, height), (enhanced_width, enhanced_height)