diff --git "a/upscaler/RealESRGAN/model.py" "b/upscaler/RealESRGAN/model.py" --- "a/upscaler/RealESRGAN/model.py" +++ "b/upscaler/RealESRGAN/model.py" @@ -1,4049 +1,113 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - Swapm/upscaler/RealESRGAN/model.py at main · G-force78/Swapm - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
- - - -
- Skip to content - - - - - - - - - - - - - - - - - - - -
-
- - - - - - - - - - -
- - -
-
- - - - - - -
-
-
-

- Global navigation -

-
-
-
-
- -
-
-
- -
- - -
- - - - -
-
-
- - -

© 2024 GitHub, Inc.

- -
-
-
- -
- -
-
- - - -
-
- - - - -
-
-
-

- Navigate back to -

-
-
- -
-
-
- - - - -
-
- -
- -
-
- -
-
- - -
- - - Create new... - - - - - - - -Issues - - -Pull requests - -
- - - - - - - Notifications - - - - -
- - - - - - - -
-
-
-

- Account menu -

-
-
- - - G-force78 - -
- - - Create new... - - - - - -
-
-
-
- -
-
-
- -
- - - -
-
-
- -
-
- -
-
- -
- -
- - - - - -
-
-
- - - -
- - - - -
- -
- - - - - - - - -
- - - - - - -
- - - - - - - - - - -
- - - - - - - - - - - - - - - - - -
- -
- - - - G-force78  /   - Swapm  /   - -
-
- - - -
- - -
-
- Clear Command Palette -
-
- - - -
-
- Tip: - Type # to search pull requests -
-
- Type ? for help and tips -
-
-
- -
-
- Tip: - Type # to search issues -
-
- Type ? for help and tips -
-
-
- -
-
- Tip: - Type # to search discussions -
-
- Type ? for help and tips -
-
-
- -
-
- Tip: - Type ! to search projects -
-
- Type ? for help and tips -
-
-
- -
-
- Tip: - Type @ to search teams -
-
- Type ? for help and tips -
-
-
- -
-
- Tip: - Type @ to search people and organizations -
-
- Type ? for help and tips -
-
-
- -
-
- Tip: - Type > to activate command mode -
-
- Type ? for help and tips -
-
-
- -
-
- Tip: - Go to your accessibility settings to change your keyboard shortcuts -
-
- Type ? for help and tips -
-
-
- -
-
- Tip: - Type author:@me to search your content -
-
- Type ? for help and tips -
-
-
- -
-
- Tip: - Type is:pr to filter to pull requests -
-
- Type ? for help and tips -
-
-
- -
-
- Tip: - Type is:issue to filter to issues -
-
- Type ? for help and tips -
-
-
- -
-
- Tip: - Type is:project to filter to projects -
-
- Type ? for help and tips -
-
-
- -
-
- Tip: - Type is:open to filter to open content -
-
- Type ? for help and tips -
-
-
- -
- -
-
- We’ve encountered an error and some results aren't available at this time. Type a new search or try again later. -
-
- - No results matched your search - - - - - - - - - - -
- - - - - Search for issues and pull requests - - # - - - - Search for issues, pull requests, discussions, and projects - - # - - - - Search for organizations, repositories, and users - - @ - - - - Search for projects - - ! - - - - Search for files - - / - - - - Activate command mode - - > - - - - Search your issues, pull requests, and discussions - - # author:@me - - - - Search your issues, pull requests, and discussions - - # author:@me - - - - Filter to pull requests - - # is:pr - - - - Filter to issues - - # is:issue - - - - Filter to discussions - - # is:discussion - - - - Filter to projects - - # is:project - - - - Filter to open issues, pull requests, and discussions - - # is:open - - - - - - - - - - - - - - - - -
-
-
- -
- - - - - - - - - - -
- - -
-
-
- - - -
- This repository has been archived by the owner on Feb 17, 2024. It is now read-only. -
- - - - - - - - - - - - - -
- Open in github.dev - Open in a new github.dev tab - Open in codespace - - - - - - - - - - - - - - - -

Files

t

Latest commit

 

History

History
93 lines (80 loc) · 3.33 KB

model.py

File metadata and controls

93 lines (80 loc) · 3.33 KB

Symbols

Find definitions and references for functions and other symbols in this file by clicking a symbol below or in the code.
r
  • const
    HF_MODELS
  • class
    RealESRGAN
    • func
      __init__
    • func
      load_weights
    • func
      predict
-
- - - - -
- -
- -
-
- -
- - - - - - - - - - - - - - - - - - - - - - -
- -
-
- - - +import os +import torch +from torch.nn import functional as F +from PIL import Image +import numpy as np +import cv2 + +from .rrdbnet_arch import RRDBNet +from .utils import ( + pad_reflect, + split_image_into_overlapping_patches, + stich_together, + unpad_image, +) + + +HF_MODELS = { + 2: dict( + repo_id="sberbank-ai/Real-ESRGAN", + filename="RealESRGAN_x2.pth", + ), + 4: dict( + repo_id="sberbank-ai/Real-ESRGAN", + filename="RealESRGAN_x4.pth", + ), + 8: dict( + repo_id="sberbank-ai/Real-ESRGAN", + filename="RealESRGAN_x8.pth", + ), +} + + +class RealESRGAN: + def __init__(self, device, scale=4): + self.device = device + self.scale = scale + self.model = RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_block=23, + num_grow_ch=32, + scale=scale, + ) + + def load_weights(self, model_path, download=True): + if not os.path.exists(model_path) and download: + from huggingface_hub import hf_hub_url, cached_download + + assert self.scale in [ + 2, + 4, + 8, + ], "You can download models only with scales: 2, 4, 8" + config = HF_MODELS[self.scale] + cache_dir = os.path.dirname(model_path) + local_filename = os.path.basename(model_path) + config_file_url = hf_hub_url( + repo_id=config["repo_id"], filename=config["filename"] + ) + cached_download( + config_file_url, cache_dir=cache_dir, force_filename=local_filename + ) + print("Weights downloaded to:", os.path.join(cache_dir, local_filename)) + + if self.device == "cpu": + loadnet = torch.load(model_path, map_location="cpu") + else: + loadnet = torch.load(model_path) + if "params" in loadnet: + self.model.load_state_dict(loadnet["params"], strict=True) + elif "params_ema" in loadnet: + self.model.load_state_dict(loadnet["params_ema"], strict=True) + else: + self.model.load_state_dict(loadnet, strict=True) + self.model.eval() + self.model.to(self.device) + + @torch.cuda.amp.autocast() + def predict( + self, lr_image, batch_size=4, patches_size=192, padding=24, pad_size=15 + ): + scale = self.scale + device = self.device + lr_image = np.array(lr_image) + lr_image = pad_reflect(lr_image, pad_size) + + patches, p_shape = split_image_into_overlapping_patches( + lr_image, patch_size=patches_size, padding_size=padding + ) + img = torch.FloatTensor(patches / 255).permute((0, 3, 1, 2)).to(device).detach() + + with torch.no_grad(): + res = self.model(img[0:batch_size]) + for i in range(batch_size, img.shape[0], batch_size): + res = torch.cat((res, self.model(img[i : i + batch_size])), 0) + + sr_image = res.permute((0, 2, 3, 1)).clamp_(0, 1).cpu() + np_sr_image = sr_image.numpy() + + padded_size_scaled = tuple(np.multiply(p_shape[0:2], scale)) + (3,) + scaled_image_shape = tuple(np.multiply(lr_image.shape[0:2], scale)) + (3,) + np_sr_image = stich_together( + np_sr_image, + padded_image_shape=padded_size_scaled, + target_shape=scaled_image_shape, + padding_size=padding * scale, + ) + sr_img = (np_sr_image * 255).astype(np.uint8) + sr_img = unpad_image(sr_img, pad_size * scale) + # sr_img = Image.fromarray(sr_img) + + return sr_img