File size: 3,407 Bytes
c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
import numpy as np
from PIL import Image
from basicsr.archs.rrdbnet_arch import RRDBNet
from modules.postprocess.realesrgan_model_arch import SRVGGNetCompact
from modules.upscaler import Upscaler
from modules.shared import opts, device, log
from modules import devices

class UpscalerRealESRGAN(Upscaler):
    """Upscaler backend for the Real-ESRGAN model family.

    Each scaler returned by ``find_scalers()`` is bound to a zero-argument factory that builds
    the matching network architecture. Weights are loaded lazily in ``do_upscale``, and the
    resulting ``RealESRGANer`` wrappers are cached in ``self.models`` keyed by model file path.
    """

    def __init__(self, dirname):
        # name/user_path are set before the base initializer runs, as in the original ordering
        self.name = "RealESRGAN"
        self.user_path = dirname
        super().__init__()
        self.scalers = self.find_scalers()
        self.models = {}
        # scaler name -> (architecture factory, scale override or None to keep the scaler's own scale)
        architectures = {
            'RealESRGAN 2x+': (lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2), 2),
            'RealESRGAN 4x+ Anime6B': (lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4), None),
            'RealESRGAN 4x General V3': (lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu'), None),
            'RealESRGAN 4x General WDN V3': (lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu'), None),
            'RealESRGAN AnimeVideo V3': (lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'), None),
            'RealESRGAN 4x+': (lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4), None),
        }
        for scaler in self.scalers:
            entry = architectures.get(scaler.name)
            if entry is None:
                log.error(f"Upscaler unrecognized model: type={self.name} model={scaler.name}")
                continue
            factory, scale = entry
            scaler.model = factory
            if scale is not None:
                scaler.scale = scale

    def load_model(self, path): # pylint: disable=unused-argument
        """No-op: models are constructed and loaded on demand in ``do_upscale``."""
        pass

    def do_upscale(self, img, selected_model):
        """Upscale ``img`` with the Real-ESRGAN model identified by ``selected_model``.

        Args:
            img: input image; converted with ``np.array`` (presumably a PIL image — confirm with callers).
            selected_model: identifier passed to ``self.find_model`` to locate the model entry.

        Returns:
            A PIL ``Image`` built from the upscaled array. The input ``img`` is returned unchanged
            when the upscaler is disabled, the Real-ESRGAN module cannot be imported, or the model
            entry / its local file cannot be found.
        """
        if not self.enable:
            return img
        try:
            from modules.postprocess.realesrgan_model_arch import RealESRGANer
        except Exception as e:
            # include the cause; previously the message ended in ':' with no detail
            log.error(f"Error importing Real-ESRGAN: {e}")
            return img
        info = self.find_model(selected_model)
        if info is None or not os.path.exists(info.local_data_path):
            return img
        model_path = info.local_data_path
        upsampler = self.models.get(model_path)
        if upsampler is not None:
            log.debug(f"Upscaler cached: type={self.name} model={info.local_data_path}")
        else:
            upsampler = RealESRGANer(
                name=info.name,
                scale=info.scale,
                model_path=model_path,
                model=info.model(),
                half=not opts.no_half and not opts.upcast_sampling,
                tile=opts.upscaler_tile_size,
                tile_pad=opts.upscaler_tile_overlap,
                device=device,
            )
            self.models[model_path] = upsampler
        upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]
        if opts.upscaler_unload and model_path in self.models:
            del self.models[model_path]
            # drop the local reference too, otherwise the model is still alive while GC runs
            del upsampler
            log.debug(f"Upscaler unloaded: type={self.name} model={selected_model}")
            devices.torch_gc(force=True)
        return Image.fromarray(upsampled)