# --------------------------------------------------------------------------------
# StableSR for Automatic1111 WebUI
# Introducing the state-of-the-art super-resolution method: StableSR!
# The technique was originally proposed by my schoolmate Jianyi Wang et al.
# Project Page:
# Official Repo:
# Paper:
# @original author: Jianyi Wang et al.
# @migration: LI YI
# @organization: Nanyang Technological University - Singapore
# @date: 2023-05-20
# @license:
# S-Lab License 1.0 (see LICENSE file)
# CC BY-NC-SA 4.0 (required by NVIDIA SPADE module)
# @disclaimer:
# All code in this extension is for research purpose only.
# The commercial use of the code & checkpoint is strictly prohibited.
# --------------------------------------------------------------------------------
# - Please be aware that the CC BY-NC-SA 4.0 license in SPADE module
# also prohibits the commercial use of outcome images.
# - Jianyi Wang may change the SPADE module to a commercial-friendly one.
# If you want to use the outcome images for commercial purposes, please
# contact Jianyi Wang for more information.
# Please give me a star (and also Jianyi's repo) if you like this project!
# --------------------------------------------------------------------------------
import os
import torch
import gradio as gr
import numpy as np
import PIL.Image as Image
from pathlib import Path
from torch import Tensor
from tqdm import tqdm
from modules import scripts, processing, sd_samplers, devices, images, shared
from modules.processing import StableDiffusionProcessingImg2Img, Processed
from modules.shared import opts
from ldm.modules.diffusionmodules.openaimodel import UNetModel
from srmodule.spade import SPADELayers
from srmodule.struct_cond import EncoderUNetModelWT, build_unetwt
from srmodule.colorfix import adain_color_fix, wavelet_color_fix
# Working directory of the running process
# (assumes the WebUI is launched from its own root folder -- TODO confirm).
SD_WEBUI_PATH = Path.cwd()
# Folder of this extension, located under the WebUI's 'extensions' directory.
ME_PATH = SD_WEBUI_PATH.joinpath('extensions', 'sd-webui-stablesr')
# Folder scanned for StableSR checkpoint files.
MODEL_PATH = ME_PATH.joinpath('models')
# Attribute name under which a UNet's original ``forward`` is stashed while hooked.
FORWARD_CACHE_NAME = 'org_forward_stablesr'
class StableSR:
    """Bundle the StableSR structure-condition encoder with its SPADE layers.

    The instance owns two sub-networks (``struct_cond_model`` and
    ``spade_layers``) and can temporarily replace a UNet's ``forward`` so that
    a fresh structure condition is computed from ``latent_image`` on every
    call before the original forward runs.
    """

    def __init__(self, path, dtype, device):
        """Create both sub-networks and move them to ``dtype`` / ``device``.

        Args:
            path: Checkpoint file handed to ``torch.load``.
            dtype: dtype every sub-module is converted to.
            device: Device every sub-module is moved to.
        """
        # SECURITY: torch.load unpickles the file, so only load checkpoints
        # that come from a trusted source.
        state_dict = torch.load(path, map_location='cpu')
        self.struct_cond_model: EncoderUNetModelWT = build_unetwt()
        self.spade_layers: SPADELayers = SPADELayers()
        # NOTE(review): the weights have to be consumed before ``state_dict``
        # is deleted; ``load_from_dict`` is assumed to be the loader exposed
        # by both srmodule classes -- confirm against srmodule.
        self.struct_cond_model.load_from_dict(state_dict)
        self.spade_layers.load_from_dict(state_dict)
        del state_dict

        def place(module):
            # Applied to every sub-module by ``Module.apply``.
            return module.to(dtype=dtype, device=device)

        self.struct_cond_model.apply(place)
        self.spade_layers.apply(place)

        self.latent_image: Tensor = None
        self.set_image_hooks = {}
        self.struct_cond: Tensor = None

    def set_latent_image(self, latent_image):
        """Store the latent image used to compute the structure condition.

        Args:
            latent_image: Latent tensor later fed to ``struct_cond_model``.
        """
        self.latent_image = latent_image
        # NOTE(review): nothing in this file registers entries in
        # ``set_image_hooks``; callbacks are presumed to take the new latent
        # as their only argument -- confirm before relying on it.
        for hook in self.set_image_hooks.values():
            hook(latent_image)

    def hook(self, unet: UNetModel):
        """Patch ``unet.forward`` so the structure condition is refreshed per call.

        The original ``forward`` is cached on the UNet under
        ``FORWARD_CACHE_NAME`` (only once, so repeated hooking never caches
        the patched version) and the SPADE layers are attached to the UNet.

        Args:
            unet: The diffusion UNet to patch.
        """
        # hook unet to set the struct_cond
        if not hasattr(unet, FORWARD_CACHE_NAME):
            setattr(unet, FORWARD_CACHE_NAME, unet.forward)

        def unet_forward(x, timesteps=None, context=None, y=None, **kwargs):
            self.latent_image = self.latent_image.to(x.device)
            # Ensure the device of all modules layers is the same as the unet
            # This will fix the issue when user use --medvram or --lowvram
            self.spade_layers.to(x.device)
            self.struct_cond_model.to(x.device)
            timesteps = timesteps.to(x.device)
            self.struct_cond = None  # mitigate vram peak
            # Only as many timesteps as there are latent images are used.
            self.struct_cond = self.struct_cond_model(
                self.latent_image, timesteps[:self.latent_image.shape[0]]
            )
            return getattr(unet, FORWARD_CACHE_NAME)(x, timesteps, context, y, **kwargs)

        unet.forward = unet_forward

        # The SPADE layers read the latest condition lazily through the lambda.
        self.spade_layers.hook(unet, lambda: self.struct_cond)

    def unhook(self, unet: UNetModel):
        """Undo :meth:`hook` and drop every cached tensor.

        Args:
            unet: The UNet that was previously passed to :meth:`hook`.
        """
        # clean up cache
        self.latent_image = None
        self.struct_cond = None
        self.set_image_hooks = {}
        # unhook unet forward
        if hasattr(unet, FORWARD_CACHE_NAME):
            unet.forward = getattr(unet, FORWARD_CACHE_NAME)
            delattr(unet, FORWARD_CACHE_NAME)
        # unhook spade layers
        # NOTE(review): assumes SPADELayers exposes ``unhook()`` as the
        # counterpart of ``hook()`` -- confirm against srmodule.spade.
        self.spade_layers.unhook()
class Script(scripts.Script):
    """img2img script that upscales the input image with a StableSR model."""

    def __init__(self) -> None:
        """Scan the model folder and start with no model loaded."""
        self.model_list = {}
        # Populate the list up front so the dropdown built in ``ui`` has choices.
        self.load_model_list()
        self.last_path = None
        self.stablesr_model: StableSR = None

    @staticmethod
    def _round_up_to_multiple_of_8(value: int) -> int:
        """Return ``value`` rounded up to the next multiple of 8."""
        return -(-value // 8) * 8

    def load_model_list(self):
        """Rebuild ``model_list`` from the files found in ``MODEL_PATH``.

        Keys are file names and values are absolute paths; a ``'None'`` entry
        mapped to ``None`` is always present.
        """
        # traverse MODEL_PATH and add all files to the model list
        self.model_list = {}
        # Create the folder when missing so that iterdir() below cannot fail.
        if not MODEL_PATH.exists():
            MODEL_PATH.mkdir(parents=True, exist_ok=True)
        for file in MODEL_PATH.iterdir():
            if file.is_file():
                # save the absolute path
                self.model_list[file.name] = str(file.absolute())
        self.model_list['None'] = None

    def title(self):
        """Return the name shown for this script."""
        return "StableSR"

    def show(self, is_img2img):
        """Enable the script only when ``is_img2img`` is truthy."""
        return is_img2img

    def ui(self, is_img2img):
        """Build the gradio controls.

        Returns:
            The components whose values are passed to :meth:`run`, in the
            order ``[model, scale_factor, pure_noise, color_fix, save_original]``.
        """
        with gr.Row():
            model = gr.Dropdown(list(self.model_list.keys()), label="SR Model")
            refresh = gr.Button(value='↻', variant='tool')

            def refresh_fn(selected):
                # Re-scan the folder, then keep the selection only if it still exists.
                self.load_model_list()
                if selected not in self.model_list:
                    selected = 'None'
                return gr.Dropdown.update(value=selected, choices=list(self.model_list.keys()))

            refresh.click(fn=refresh_fn, inputs=model, outputs=model)
        with gr.Row():
            scale_factor = gr.Slider(minimum=1, maximum=16, step=0.1, value=2, label='Scale Factor', elem_id='StableSR-scale')
        with gr.Row():
            color_fix = gr.Dropdown(['None', 'Wavelet', 'AdaIN'], label="Color Fix", value='Wavelet', elem_id='StableSR-color-fix')
            save_original = gr.Checkbox(label='Save Original', value=False, elem_id='StableSR-save-original', visible=color_fix.value != 'None')
            # "Save Original" is only meaningful when a colour fix is applied.
            color_fix.change(fn=lambda selected: gr.Checkbox.update(visible=selected != 'None'), inputs=color_fix, outputs=save_original, show_progress=False)
            pure_noise = gr.Checkbox(label='Pure Noise', value=True, elem_id='StableSR-pure-noise')
            unload_model = gr.Button(value='Unload Model', variant='tool')

            def unload_model_fn():
                if self.stablesr_model is not None:
                    self.stablesr_model = None
                    # Release the memory held by the dropped model.
                    devices.torch_gc()
                    print('[StableSR] Model unloaded!')
                else:
                    print('[StableSR] No model loaded.')

            unload_model.click(fn=unload_model_fn)
        return [model, scale_factor, pure_noise, color_fix, save_original]

    def run(self, p: StableDiffusionProcessingImg2Img, model: str, scale_factor: float, pure_noise: bool, color_fix: str, save_original: bool) -> Processed:
        """Upscale ``p.init_images[0]`` with StableSR and optionally fix its colours.

        Args:
            p: The img2img processing object; its first init image, size and
                ``sample`` function are replaced.
            model: Key into ``model_list``; ``'None'`` disables the script.
            scale_factor: Multiplier applied to the init image size (the
                result is rounded up to a multiple of 8).
            pure_noise: Sample from pure noise (txt2img-style sampling)
                instead of img2img sampling.
            color_fix: ``'None'``, ``'Wavelet'`` or ``'AdaIN'``; any other
                value is reported and treated as ``'None'``.
            save_original: Also save the images before colour fixing.

        Returns:
            The ``Processed`` result, with the colour-fixed images appended
            when a colour fix is active, or ``None`` when ``model`` is
            ``'None'``.

        Raises:
            gr.Error: If ``model`` is unknown or its file is missing on disk.
        """
        if model == 'None':
            # do clean up
            self.stablesr_model = None
            self.last_path = None
            # NOTE(review): returning None presumably lets the WebUI fall back
            # to its regular img2img processing -- confirm against the caller.
            return None
        if model not in self.model_list:
            raise gr.Error(f"Model {model} is not in the list! Please refresh your browser!")
        model_path = self.model_list[model]
        if not os.path.exists(model_path):
            raise gr.Error(f"Model {model} is not on your disk! Please refresh the model list!")
        if color_fix not in ['None', 'Wavelet', 'AdaIN']:
            print(f'[StableSR] Invalid color fix method: {color_fix}')
            color_fix = 'None'

        # upscale the image, set the output size
        init_img: Image.Image = p.init_images[0]
        # if a target dimension is not divisible by 8, round it up
        target_width = self._round_up_to_multiple_of_8(int(init_img.width * scale_factor))
        target_height = self._round_up_to_multiple_of_8(int(init_img.height * scale_factor))
        init_img = init_img.resize((target_width, target_height), Image.LANCZOS)
        p.init_images[0] = init_img
        p.width = init_img.width
        p.height = init_img.height
        print(f'[StableSR] Target image size: {init_img.width}x{init_img.height}')

        # The first parameter of the SD model tells us which dtype/device to use.
        first_param = next(shared.sd_model.parameters())
        if self.last_path != model_path:
            # a different checkpoint was selected: drop the old model
            self.stablesr_model = None
        if self.stablesr_model is None:
            # load the model
            self.stablesr_model = StableSR(model_path, dtype=first_param.dtype, device=first_param.device)
            self.last_path = model_path
        # Keep a local reference so sampling is unaffected if the model is
        # unloaded from the UI while it runs.
        sr_model = self.stablesr_model

        def sample_custom(conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
            # Bound before the try so the finally clause can always use it.
            unet: UNetModel = shared.sd_model.model.diffusion_model
            try:
                sr_model.hook(unet)
                sr_model.set_latent_image(p.init_latent)
                x = processing.create_random_tensors(p.init_latent.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p)
                sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model)
                if pure_noise:
                    # NOTE: use txt2img instead of img2img sampling
                    samples = sampler.sample(p, x, conditioning, unconditional_conditioning, image_conditioning=p.image_conditioning)
                else:
                    if p.initial_noise_multiplier != 1.0:
                        p.extra_generation_params["Noise multiplier"] = p.initial_noise_multiplier
                        x *= p.initial_noise_multiplier
                    samples = sampler.sample_img2img(p, p.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=p.image_conditioning)
                if p.mask is not None:
                    samples = samples * p.nmask + p.init_latent * p.mask
                del x
                devices.torch_gc()
                return samples
            finally:
                sr_model.unhook(unet)
                # in --medvram and --lowvram mode, we send the model back to the initial device
                sr_model.struct_cond_model.to(device=first_param.device)
                sr_model.spade_layers.to(device=first_param.device)

        # replace the sample function
        p.sample = sample_custom

        if color_fix != 'None':
            # Samples are saved below, after the colour fix has been applied.
            p.do_not_save_samples = True

        result: Processed = processing.process_images(p)

        if color_fix != 'None':
            # fix the color
            color_fix_func = wavelet_color_fix if color_fix == 'Wavelet' else adain_color_fix
            # Each entry keeps the index of its source image so seed, prompt
            # and infotext stay aligned even when fixing one image fails.
            fixed = []
            for i, image in enumerate(result.images):
                try:
                    fixed.append((i, color_fix_func(image, init_img)))
                except Exception as e:
                    print(f'[StableSR] Error fixing color with default method: {e}')

            # save the fixed color images
            for i, image in fixed:
                try:
                    images.save_image(image, p.outpath_samples, "", p.all_seeds[i], p.all_prompts[i], opts.samples_format, info=result.infotexts[i], p=p)
                except Exception as e:
                    print(f'[StableSR] Error saving color fixed image: {e}')

            if save_original:
                for i, image in enumerate(result.images):
                    try:
                        images.save_image(image, p.outpath_samples, "", p.all_seeds[i], p.all_prompts[i], opts.samples_format, info=result.infotexts[i], p=p, suffix="-before-color-fix")
                    except Exception as e:
                        print(f'[StableSR] Error saving original image: {e}')

            result.images = result.images + [image for _, image in fixed]

        return result