|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | import functools | 
					
						
						|  | import math | 
					
						
						|  | import re | 
					
						
						|  | from collections import OrderedDict | 
					
						
						|  |  | 
					
						
						|  | import torch | 
					
						
						|  | import torch.nn as nn | 
					
						
						|  | import torch.nn.functional as F | 
					
						
						|  |  | 
					
						
						|  | from . import block as B | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class RRDBNet(nn.Module): | 
					
						
						|  | def __init__( | 
					
						
						|  | self, | 
					
						
						|  | state_dict, | 
					
						
						|  | norm=None, | 
					
						
						|  | act: str = "leakyrelu", | 
					
						
						|  | upsampler: str = "upconv", | 
					
						
						|  | mode: B.ConvMode = "CNA", | 
					
						
						|  | ) -> None: | 
					
						
						|  | """ | 
					
						
						|  | ESRGAN - Enhanced Super-Resolution Generative Adversarial Networks. | 
					
						
						|  | By Xintao Wang, Ke Yu, Shixiang Wu, Jinjin Gu, Yihao Liu, Chao Dong, Yu Qiao, | 
					
						
						|  | and Chen Change Loy. | 
					
						
						|  | This is old-arch Residual in Residual Dense Block Network and is not | 
					
						
						|  | the newest revision that's available at github.com/xinntao/ESRGAN. | 
					
						
						|  | This is on purpose, the newest Network has severely limited the | 
					
						
						|  | potential use of the Network with no benefits. | 
					
						
						|  | This network supports model files from both new and old-arch. | 
					
						
						|  | Args: | 
					
						
						|  | norm: Normalization layer | 
					
						
						|  | act: Activation layer | 
					
						
						|  | upsampler: Upsample layer. upconv, pixel_shuffle | 
					
						
						|  | mode: Convolution mode | 
					
						
						|  | """ | 
					
						
						|  | super(RRDBNet, self).__init__() | 
					
						
						|  | self.model_arch = "ESRGAN" | 
					
						
						|  | self.sub_type = "SR" | 
					
						
						|  |  | 
					
						
						|  | self.state = state_dict | 
					
						
						|  | self.norm = norm | 
					
						
						|  | self.act = act | 
					
						
						|  | self.upsampler = upsampler | 
					
						
						|  | self.mode = mode | 
					
						
						|  |  | 
					
						
						|  | self.state_map = { | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | "model.0.weight": ("conv_first.weight",), | 
					
						
						|  | "model.0.bias": ("conv_first.bias",), | 
					
						
						|  | "model.1.sub./NB/.weight": ("trunk_conv.weight", "conv_body.weight"), | 
					
						
						|  | "model.1.sub./NB/.bias": ("trunk_conv.bias", "conv_body.bias"), | 
					
						
						|  | r"model.1.sub.\1.RDB\2.conv\3.0.\4": ( | 
					
						
						|  | r"RRDB_trunk\.(\d+)\.RDB(\d)\.conv(\d+)\.(weight|bias)", | 
					
						
						|  | r"body\.(\d+)\.rdb(\d)\.conv(\d+)\.(weight|bias)", | 
					
						
						|  | ), | 
					
						
						|  | } | 
					
						
						|  | if "params_ema" in self.state: | 
					
						
						|  | self.state = self.state["params_ema"] | 
					
						
						|  |  | 
					
						
						|  | self.num_blocks = self.get_num_blocks() | 
					
						
						|  | self.plus = any("conv1x1" in k for k in self.state.keys()) | 
					
						
						|  | if self.plus: | 
					
						
						|  | self.model_arch = "ESRGAN+" | 
					
						
						|  |  | 
					
						
						|  | self.state = self.new_to_old_arch(self.state) | 
					
						
						|  |  | 
					
						
						|  | self.key_arr = list(self.state.keys()) | 
					
						
						|  |  | 
					
						
						|  | self.in_nc: int = self.state[self.key_arr[0]].shape[1] | 
					
						
						|  | self.out_nc: int = self.state[self.key_arr[-1]].shape[0] | 
					
						
						|  |  | 
					
						
						|  | self.scale: int = self.get_scale() | 
					
						
						|  | self.num_filters: int = self.state[self.key_arr[0]].shape[0] | 
					
						
						|  |  | 
					
						
						|  | c2x2 = False | 
					
						
						|  | if self.state["model.0.weight"].shape[-2] == 2: | 
					
						
						|  | c2x2 = True | 
					
						
						|  | self.scale = round(math.sqrt(self.scale / 4)) | 
					
						
						|  | self.model_arch = "ESRGAN-2c2" | 
					
						
						|  |  | 
					
						
						|  | self.supports_fp16 = True | 
					
						
						|  | self.supports_bfp16 = True | 
					
						
						|  | self.min_size_restriction = None | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if self.in_nc in (self.out_nc * 4, self.out_nc * 16) and self.out_nc in ( | 
					
						
						|  | self.in_nc / 4, | 
					
						
						|  | self.in_nc / 16, | 
					
						
						|  | ): | 
					
						
						|  | self.shuffle_factor = int(math.sqrt(self.in_nc / self.out_nc)) | 
					
						
						|  | else: | 
					
						
						|  | self.shuffle_factor = None | 
					
						
						|  |  | 
					
						
						|  | upsample_block = { | 
					
						
						|  | "upconv": B.upconv_block, | 
					
						
						|  | "pixel_shuffle": B.pixelshuffle_block, | 
					
						
						|  | }.get(self.upsampler) | 
					
						
						|  | if upsample_block is None: | 
					
						
						|  | raise NotImplementedError(f"Upsample mode [{self.upsampler}] is not found") | 
					
						
						|  |  | 
					
						
						|  | if self.scale == 3: | 
					
						
						|  | upsample_blocks = upsample_block( | 
					
						
						|  | in_nc=self.num_filters, | 
					
						
						|  | out_nc=self.num_filters, | 
					
						
						|  | upscale_factor=3, | 
					
						
						|  | act_type=self.act, | 
					
						
						|  | c2x2=c2x2, | 
					
						
						|  | ) | 
					
						
						|  | else: | 
					
						
						|  | upsample_blocks = [ | 
					
						
						|  | upsample_block( | 
					
						
						|  | in_nc=self.num_filters, | 
					
						
						|  | out_nc=self.num_filters, | 
					
						
						|  | act_type=self.act, | 
					
						
						|  | c2x2=c2x2, | 
					
						
						|  | ) | 
					
						
						|  | for _ in range(int(math.log(self.scale, 2))) | 
					
						
						|  | ] | 
					
						
						|  |  | 
					
						
						|  | self.model = B.sequential( | 
					
						
						|  |  | 
					
						
						|  | B.conv_block( | 
					
						
						|  | in_nc=self.in_nc, | 
					
						
						|  | out_nc=self.num_filters, | 
					
						
						|  | kernel_size=3, | 
					
						
						|  | norm_type=None, | 
					
						
						|  | act_type=None, | 
					
						
						|  | c2x2=c2x2, | 
					
						
						|  | ), | 
					
						
						|  | B.ShortcutBlock( | 
					
						
						|  | B.sequential( | 
					
						
						|  |  | 
					
						
						|  | *[ | 
					
						
						|  | B.RRDB( | 
					
						
						|  | nf=self.num_filters, | 
					
						
						|  | kernel_size=3, | 
					
						
						|  | gc=32, | 
					
						
						|  | stride=1, | 
					
						
						|  | bias=True, | 
					
						
						|  | pad_type="zero", | 
					
						
						|  | norm_type=self.norm, | 
					
						
						|  | act_type=self.act, | 
					
						
						|  | mode="CNA", | 
					
						
						|  | plus=self.plus, | 
					
						
						|  | c2x2=c2x2, | 
					
						
						|  | ) | 
					
						
						|  | for _ in range(self.num_blocks) | 
					
						
						|  | ], | 
					
						
						|  |  | 
					
						
						|  | B.conv_block( | 
					
						
						|  | in_nc=self.num_filters, | 
					
						
						|  | out_nc=self.num_filters, | 
					
						
						|  | kernel_size=3, | 
					
						
						|  | norm_type=self.norm, | 
					
						
						|  | act_type=None, | 
					
						
						|  | mode=self.mode, | 
					
						
						|  | c2x2=c2x2, | 
					
						
						|  | ), | 
					
						
						|  | ) | 
					
						
						|  | ), | 
					
						
						|  | *upsample_blocks, | 
					
						
						|  |  | 
					
						
						|  | B.conv_block( | 
					
						
						|  | in_nc=self.num_filters, | 
					
						
						|  | out_nc=self.num_filters, | 
					
						
						|  | kernel_size=3, | 
					
						
						|  | norm_type=None, | 
					
						
						|  | act_type=self.act, | 
					
						
						|  | c2x2=c2x2, | 
					
						
						|  | ), | 
					
						
						|  |  | 
					
						
						|  | B.conv_block( | 
					
						
						|  | in_nc=self.num_filters, | 
					
						
						|  | out_nc=self.out_nc, | 
					
						
						|  | kernel_size=3, | 
					
						
						|  | norm_type=None, | 
					
						
						|  | act_type=None, | 
					
						
						|  | c2x2=c2x2, | 
					
						
						|  | ), | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if self.shuffle_factor: | 
					
						
						|  | self.in_nc //= self.shuffle_factor**2 | 
					
						
						|  | self.scale //= self.shuffle_factor | 
					
						
						|  |  | 
					
						
						|  | self.load_state_dict(self.state, strict=False) | 
					
						
						|  |  | 
					
						
						|  | def new_to_old_arch(self, state): | 
					
						
						|  | """Convert a new-arch model state dictionary to an old-arch dictionary.""" | 
					
						
						|  | if "params_ema" in state: | 
					
						
						|  | state = state["params_ema"] | 
					
						
						|  |  | 
					
						
						|  | if "conv_first.weight" not in state: | 
					
						
						|  |  | 
					
						
						|  | return state | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | for kind in ("weight", "bias"): | 
					
						
						|  | self.state_map[f"model.1.sub.{self.num_blocks}.{kind}"] = self.state_map[ | 
					
						
						|  | f"model.1.sub./NB/.{kind}" | 
					
						
						|  | ] | 
					
						
						|  | del self.state_map[f"model.1.sub./NB/.{kind}"] | 
					
						
						|  |  | 
					
						
						|  | old_state = OrderedDict() | 
					
						
						|  | for old_key, new_keys in self.state_map.items(): | 
					
						
						|  | for new_key in new_keys: | 
					
						
						|  | if r"\1" in old_key: | 
					
						
						|  | for k, v in state.items(): | 
					
						
						|  | sub = re.sub(new_key, old_key, k) | 
					
						
						|  | if sub != k: | 
					
						
						|  | old_state[sub] = v | 
					
						
						|  | else: | 
					
						
						|  | if new_key in state: | 
					
						
						|  | old_state[old_key] = state[new_key] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | max_upconv = 0 | 
					
						
						|  | for key in state.keys(): | 
					
						
						|  | match = re.match(r"(upconv|conv_up)(\d)\.(weight|bias)", key) | 
					
						
						|  | if match is not None: | 
					
						
						|  | _, key_num, key_type = match.groups() | 
					
						
						|  | old_state[f"model.{int(key_num) * 3}.{key_type}"] = state[key] | 
					
						
						|  | max_upconv = max(max_upconv, int(key_num) * 3) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | for key in state.keys(): | 
					
						
						|  | if key in ("HRconv.weight", "conv_hr.weight"): | 
					
						
						|  | old_state[f"model.{max_upconv + 2}.weight"] = state[key] | 
					
						
						|  | elif key in ("HRconv.bias", "conv_hr.bias"): | 
					
						
						|  | old_state[f"model.{max_upconv + 2}.bias"] = state[key] | 
					
						
						|  | elif key in ("conv_last.weight",): | 
					
						
						|  | old_state[f"model.{max_upconv + 4}.weight"] = state[key] | 
					
						
						|  | elif key in ("conv_last.bias",): | 
					
						
						|  | old_state[f"model.{max_upconv + 4}.bias"] = state[key] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def compare(item1, item2): | 
					
						
						|  | parts1 = item1.split(".") | 
					
						
						|  | parts2 = item2.split(".") | 
					
						
						|  | int1 = int(parts1[1]) | 
					
						
						|  | int2 = int(parts2[1]) | 
					
						
						|  | return int1 - int2 | 
					
						
						|  |  | 
					
						
						|  | sorted_keys = sorted(old_state.keys(), key=functools.cmp_to_key(compare)) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | out_dict = OrderedDict((k, old_state[k]) for k in sorted_keys) | 
					
						
						|  |  | 
					
						
						|  | return out_dict | 
					
						
						|  |  | 
					
						
						|  | def get_scale(self, min_part: int = 6) -> int: | 
					
						
						|  | n = 0 | 
					
						
						|  | for part in list(self.state): | 
					
						
						|  | parts = part.split(".")[1:] | 
					
						
						|  | if len(parts) == 2: | 
					
						
						|  | part_num = int(parts[0]) | 
					
						
						|  | if part_num > min_part and parts[1] == "weight": | 
					
						
						|  | n += 1 | 
					
						
						|  | return 2**n | 
					
						
						|  |  | 
					
						
						|  | def get_num_blocks(self) -> int: | 
					
						
						|  | nbs = [] | 
					
						
						|  | state_keys = self.state_map[r"model.1.sub.\1.RDB\2.conv\3.0.\4"] + ( | 
					
						
						|  | r"model\.\d+\.sub\.(\d+)\.RDB(\d+)\.conv(\d+)\.0\.(weight|bias)", | 
					
						
						|  | ) | 
					
						
						|  | for state_key in state_keys: | 
					
						
						|  | for k in self.state: | 
					
						
						|  | m = re.search(state_key, k) | 
					
						
						|  | if m: | 
					
						
						|  | nbs.append(int(m.group(1))) | 
					
						
						|  | if nbs: | 
					
						
						|  | break | 
					
						
						|  | return max(*nbs) + 1 | 
					
						
						|  |  | 
					
						
						|  | def forward(self, x): | 
					
						
						|  | if self.shuffle_factor: | 
					
						
						|  | _, _, h, w = x.size() | 
					
						
						|  | mod_pad_h = ( | 
					
						
						|  | self.shuffle_factor - h % self.shuffle_factor | 
					
						
						|  | ) % self.shuffle_factor | 
					
						
						|  | mod_pad_w = ( | 
					
						
						|  | self.shuffle_factor - w % self.shuffle_factor | 
					
						
						|  | ) % self.shuffle_factor | 
					
						
						|  | x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect") | 
					
						
						|  | x = torch.pixel_unshuffle(x, downscale_factor=self.shuffle_factor) | 
					
						
						|  | x = self.model(x) | 
					
						
						|  | return x[:, :, : h * self.scale, : w * self.scale] | 
					
						
						|  | return self.model(x) | 
					
						
						|  |  |