#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import functools
import math
import re
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

from . import block as B


# Borrowed from https://github.com/rlaphoenix/VSGAN/blob/master/vsgan/archs/esrgan.py
# Which enhanced stuff that was already here
class RRDBNet(nn.Module):
    def __init__(
        self,
        state_dict,
        norm=None,
        act: str = "leakyrelu",
        upsampler: str = "upconv",
        mode: B.ConvMode = "CNA",
    ) -> None:
        """
        ESRGAN - Enhanced Super-Resolution Generative Adversarial Networks.
        By Xintao Wang, Ke Yu, Shixiang Wu, Jinjin Gu, Yihao Liu, Chao Dong, Yu Qiao,
        and Chen Change Loy.
        This is old-arch Residual in Residual Dense Block Network and is not
        the newest revision that's available at github.com/xinntao/ESRGAN.
        This is on purpose, the newest Network has severely limited the
        potential use of the Network with no benefits.
        This network supports model files from both new and old-arch.

        The network hyper-parameters (channel counts, number of RRDB blocks,
        scale) are not passed in; they are derived from the key names and
        tensor shapes found in ``state_dict``.

        Args:
            state_dict: Mapping of parameter names to tensors. If it contains a
                "params_ema" entry, that nested mapping is used instead.
            norm: Normalization layer
            act: Activation layer
            upsampler: Upsample layer. upconv, pixel_shuffle
            mode: Convolution mode

        Raises:
            NotImplementedError: If ``upsampler`` is not "upconv" or
                "pixel_shuffle".
            ValueError: If no RRDB block keys can be found in ``state_dict``.
        """
        super().__init__()
        self.model_arch = "ESRGAN"
        self.sub_type = "SR"

        self.state = state_dict
        self.norm = norm
        self.act = act
        self.upsampler = upsampler
        self.mode = mode

        # Maps old-arch key (or regex replacement template) -> candidate
        # new-arch keys (or regex patterns). The "/NB/" placeholder is replaced
        # with the real block count in new_to_old_arch().
        self.state_map = {
            # currently supports old, new, and newer RRDBNet arch models
            # ESRGAN, BSRGAN/RealSR, Real-ESRGAN
            "model.0.weight": ("conv_first.weight",),
            "model.0.bias": ("conv_first.bias",),
            "model.1.sub./NB/.weight": ("trunk_conv.weight", "conv_body.weight"),
            "model.1.sub./NB/.bias": ("trunk_conv.bias", "conv_body.bias"),
            r"model.1.sub.\1.RDB\2.conv\3.0.\4": (
                r"RRDB_trunk\.(\d+)\.RDB(\d)\.conv(\d+)\.(weight|bias)",
                r"body\.(\d+)\.rdb(\d)\.conv(\d+)\.(weight|bias)",
            ),
        }
        if "params_ema" in self.state:
            self.state = self.state["params_ema"]
            # self.model_arch = "RealESRGAN"
        self.num_blocks = self.get_num_blocks()
        self.plus = any("conv1x1" in k for k in self.state.keys())
        if self.plus:
            self.model_arch = "ESRGAN+"

        # From here on the state uses old-arch ("model.N...") key names.
        self.state = self.new_to_old_arch(self.state)

        self.key_arr = list(self.state.keys())

        # First key is the first conv weight, last key belongs to the last conv.
        self.in_nc: int = self.state[self.key_arr[0]].shape[1]
        self.out_nc: int = self.state[self.key_arr[-1]].shape[0]

        self.scale: int = self.get_scale()
        self.num_filters: int = self.state[self.key_arr[0]].shape[0]

        c2x2 = False
        if self.state["model.0.weight"].shape[-2] == 2:
            # A 2-wide first kernel marks the "2c2" variant; the scale counted
            # from the keys is corrected accordingly.
            c2x2 = True
            self.scale = round(math.sqrt(self.scale / 4))
            self.model_arch = "ESRGAN-2c2"

        self.supports_fp16 = True
        self.supports_bfp16 = True
        self.min_size_restriction = None

        # Detect if pixelunshuffle was used (Real-ESRGAN): the input channel
        # count is then 4x or 16x the output channel count.
        if self.in_nc in (self.out_nc * 4, self.out_nc * 16):
            self.shuffle_factor = int(math.sqrt(self.in_nc / self.out_nc))
        else:
            self.shuffle_factor = None

        upsample_block = {
            "upconv": B.upconv_block,
            "pixel_shuffle": B.pixelshuffle_block,
        }.get(self.upsampler)
        if upsample_block is None:
            raise NotImplementedError(f"Upsample mode [{self.upsampler}] is not found")

        if self.scale == 3:
            upsample_blocks = upsample_block(
                in_nc=self.num_filters,
                out_nc=self.num_filters,
                upscale_factor=3,
                act_type=self.act,
                c2x2=c2x2,
            )
        else:
            # One upsample stage per factor of two. log2 is used instead of
            # log(x, 2) so that exact powers of two never round down.
            upsample_blocks = [
                upsample_block(
                    in_nc=self.num_filters,
                    out_nc=self.num_filters,
                    act_type=self.act,
                    c2x2=c2x2,
                )
                for _ in range(int(math.log2(self.scale)))
            ]

        self.model = B.sequential(
            # fea conv
            B.conv_block(
                in_nc=self.in_nc,
                out_nc=self.num_filters,
                kernel_size=3,
                norm_type=None,
                act_type=None,
                c2x2=c2x2,
            ),
            B.ShortcutBlock(
                B.sequential(
                    # rrdb blocks
                    *[
                        B.RRDB(
                            nf=self.num_filters,
                            kernel_size=3,
                            gc=32,
                            stride=1,
                            bias=True,
                            pad_type="zero",
                            norm_type=self.norm,
                            act_type=self.act,
                            mode="CNA",
                            plus=self.plus,
                            c2x2=c2x2,
                        )
                        for _ in range(self.num_blocks)
                    ],
                    # lr conv
                    B.conv_block(
                        in_nc=self.num_filters,
                        out_nc=self.num_filters,
                        kernel_size=3,
                        norm_type=self.norm,
                        act_type=None,
                        mode=self.mode,
                        c2x2=c2x2,
                    ),
                )
            ),
            *upsample_blocks,
            # hr_conv0
            B.conv_block(
                in_nc=self.num_filters,
                out_nc=self.num_filters,
                kernel_size=3,
                norm_type=None,
                act_type=self.act,
                c2x2=c2x2,
            ),
            # hr_conv1
            B.conv_block(
                in_nc=self.num_filters,
                out_nc=self.out_nc,
                kernel_size=3,
                norm_type=None,
                act_type=None,
                c2x2=c2x2,
            ),
        )

        # Adjust these properties for calculations outside of the model
        if self.shuffle_factor:
            self.in_nc //= self.shuffle_factor**2
            self.scale //= self.shuffle_factor

        self.load_state_dict(self.state, strict=False)

    def new_to_old_arch(self, state):
        """Convert a new-arch model state dictionary to an old-arch dictionary.

        Args:
            state: Mapping of parameter names to tensors. A nested
                "params_ema" mapping is unwrapped first.

        Returns:
            ``state`` itself when it has no "conv_first.weight" key (treated
            as already old-arch); otherwise a new OrderedDict with old-arch
            key names, ordered by the layer index in each key.

        Note:
            On conversion this replaces the "/NB/" placeholder entries of
            ``self.state_map`` in place, using ``self.num_blocks``.
        """
        if "params_ema" in state:
            state = state["params_ema"]

        if "conv_first.weight" not in state:
            # model is already old arch, this is a loose check, but should be sufficient
            return state

        # add nb to state keys
        for kind in ("weight", "bias"):
            self.state_map[f"model.1.sub.{self.num_blocks}.{kind}"] = self.state_map[
                f"model.1.sub./NB/.{kind}"
            ]
            del self.state_map[f"model.1.sub./NB/.{kind}"]

        old_state = OrderedDict()
        for old_key, new_keys in self.state_map.items():
            for new_key in new_keys:
                if r"\1" in old_key:
                    # Regex entry: rewrite every key the pattern matches.
                    for k, v in state.items():
                        sub = re.sub(new_key, old_key, k)
                        if sub != k:
                            old_state[sub] = v
                else:
                    # Literal entry: straight rename when present.
                    if new_key in state:
                        old_state[old_key] = state[new_key]

        # upconv layers: upconvN / conv_upN is stored at model index N * 3
        max_upconv = 0
        for key in state.keys():
            match = re.match(r"(upconv|conv_up)(\d)\.(weight|bias)", key)
            if match is not None:
                _, key_num, key_type = match.groups()
                old_state[f"model.{int(key_num) * 3}.{key_type}"] = state[key]
                max_upconv = max(max_upconv, int(key_num) * 3)

        # final layers, placed relative to the last upconv index
        for key in state.keys():
            if key in ("HRconv.weight", "conv_hr.weight"):
                old_state[f"model.{max_upconv + 2}.weight"] = state[key]
            elif key in ("HRconv.bias", "conv_hr.bias"):
                old_state[f"model.{max_upconv + 2}.bias"] = state[key]
            elif key in ("conv_last.weight",):
                old_state[f"model.{max_upconv + 4}.weight"] = state[key]
            elif key in ("conv_last.bias",):
                old_state[f"model.{max_upconv + 4}.bias"] = state[key]

        # Sort by first numeric value of each layer
        def compare(item1, item2):
            parts1 = item1.split(".")
            parts2 = item2.split(".")
            int1 = int(parts1[1])
            int2 = int(parts2[1])
            return int1 - int2

        sorted_keys = sorted(old_state.keys(), key=functools.cmp_to_key(compare))

        # Rebuild the output dict in the right order
        out_dict = OrderedDict((k, old_state[k]) for k in sorted_keys)

        return out_dict

    def get_scale(self, min_part: int = 6) -> int:
        """Derive the upscale factor from the keys of ``self.state``.

        Counts keys of the form "<prefix>.<index>.weight" whose index is
        greater than ``min_part`` and returns two to the power of that count.

        Args:
            min_part: Layer indices up to and including this value are ignored.

        Returns:
            ``2 ** n`` where ``n`` is the number of counted keys.
        """
        n = 0
        for part in self.state:
            parts = part.split(".")[1:]
            if len(parts) == 2:
                part_num = int(parts[0])
                if part_num > min_part and parts[1] == "weight":
                    n += 1
        return 2**n

    def get_num_blocks(self) -> int:
        """Return the number of RRDB blocks described by ``self.state``.

        Tries the new-arch key patterns from ``self.state_map`` first, then the
        old-arch pattern, and uses the first pattern that matches any key.

        Returns:
            The highest matched block index plus one.

        Raises:
            ValueError: If no key matches any known RRDB pattern.
        """
        nbs = []
        state_keys = self.state_map[r"model.1.sub.\1.RDB\2.conv\3.0.\4"] + (
            r"model\.\d+\.sub\.(\d+)\.RDB(\d+)\.conv(\d+)\.0\.(weight|bias)",
        )
        for state_key in state_keys:
            for k in self.state:
                m = re.search(state_key, k)
                if m:
                    nbs.append(int(m.group(1)))
            if nbs:
                break
        if not nbs:
            raise ValueError(
                "Could not determine the number of RRDB blocks: "
                "no matching keys found in the state dict"
            )
        # max(nbs), not max(*nbs): the unpacked form fails for a single match.
        return max(nbs) + 1

    def forward(self, x):
        """Run the network on ``x``.

        When a pixel-unshuffle factor was detected, the input is reflect-padded
        on the right/bottom so its height and width are multiples of that
        factor, pixel-unshuffled, run through the model, and the output is
        cropped back to the unpadded size times ``self.scale``.
        """
        if self.shuffle_factor:
            _, _, h, w = x.size()
            mod_pad_h = (
                self.shuffle_factor - h % self.shuffle_factor
            ) % self.shuffle_factor
            mod_pad_w = (
                self.shuffle_factor - w % self.shuffle_factor
            ) % self.shuffle_factor
            x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
            x = torch.pixel_unshuffle(x, downscale_factor=self.shuffle_factor)
            x = self.model(x)
            return x[:, :, : h * self.scale, : w * self.scale]
        return self.model(x)