a76b4ccb8bb91c9e097ac910cbd005a5983176e8658571782cad9f042c9de31e
- extensions-builtin/ScuNET/__pycache__/scunet_model_arch.cpython-310.pyc +0 -0
- extensions-builtin/ScuNET/preload.py +6 -0
- extensions-builtin/ScuNET/scripts/__pycache__/scunet_model.cpython-310.pyc +0 -0
- extensions-builtin/ScuNET/scripts/scunet_model.py +144 -0
- extensions-builtin/ScuNET/scunet_model_arch.py +268 -0
- extensions-builtin/SwinIR/__pycache__/preload.cpython-310.pyc +0 -0
- extensions-builtin/SwinIR/__pycache__/swinir_model_arch.cpython-310.pyc +0 -0
- extensions-builtin/SwinIR/__pycache__/swinir_model_arch_v2.cpython-310.pyc +0 -0
- extensions-builtin/SwinIR/preload.py +6 -0
- extensions-builtin/SwinIR/scripts/__pycache__/swinir_model.cpython-310.pyc +0 -0
- extensions-builtin/SwinIR/scripts/swinir_model.py +192 -0
- extensions-builtin/SwinIR/swinir_model_arch.py +867 -0
- extensions-builtin/SwinIR/swinir_model_arch_v2.py +1017 -0
- extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js +776 -0
- extensions-builtin/canvas-zoom-and-pan/scripts/__pycache__/hotkey_config.cpython-310.pyc +0 -0
- extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py +14 -0
- extensions-builtin/canvas-zoom-and-pan/style.css +63 -0
- extensions-builtin/extra-options-section/scripts/__pycache__/extra_options_section.cpython-310.pyc +0 -0
- extensions-builtin/extra-options-section/scripts/extra_options_section.py +48 -0
- extensions-builtin/mobile/javascript/mobile.js +26 -0
- extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js +42 -0
- extensions-builtin/sd_theme_editor/install.py +1 -0
- extensions-builtin/sd_theme_editor/javascript/ui_theme.js +435 -0
- extensions-builtin/sd_theme_editor/scripts/__pycache__/ui_theme.cpython-310.pyc +0 -0
- extensions-builtin/sd_theme_editor/scripts/ui_theme.py +177 -0
- extensions-builtin/sd_theme_editor/style.css +113 -0
- extensions-builtin/sd_theme_editor/themes/Golde.css +1 -0
- extensions-builtin/sd_theme_editor/themes/backup.css +1 -0
- extensions-builtin/sd_theme_editor/themes/d-230-52-94.css +1 -0
- extensions-builtin/sd_theme_editor/themes/default.css +1 -0
- extensions-builtin/sd_theme_editor/themes/default_cyan.css +1 -0
- extensions-builtin/sd_theme_editor/themes/default_orange.css +1 -0
- extensions-builtin/sd_theme_editor/themes/fun.css +1 -0
- extensions-builtin/sd_theme_editor/themes/minimal.css +1 -0
- extensions-builtin/sd_theme_editor/themes/minimal_orange.css +1 -0
- extensions-builtin/sd_theme_editor/themes/moonlight.css +1 -0
- extensions-builtin/sd_theme_editor/themes/ogxBGreen.css +1 -0
- extensions-builtin/sd_theme_editor/themes/ogxCyan.css +1 -0
- extensions-builtin/sd_theme_editor/themes/ogxCyanInvert.css +1 -0
- extensions-builtin/sd_theme_editor/themes/ogxGreen.css +1 -0
- extensions-builtin/sd_theme_editor/themes/ogxRed.css +1 -0
- extensions-builtin/sd_theme_editor/themes/retrog.css +1 -0
- extensions-builtin/sd_theme_editor/themes/tron.css +1 -0
- extensions-builtin/sd_theme_editor/themes/tron2.css +1 -0
- html/200w.webp +0 -0
- html/card-no-preview.png +0 -0
- html/extra-networks-card.html +18 -0
- html/extra-networks-no-cards.html +8 -0
- html/favicon.ico +0 -0
- html/footer.html +55 -0
extensions-builtin/ScuNET/__pycache__/scunet_model_arch.cpython-310.pyc
ADDED
Binary file (9.27 kB).
extensions-builtin/ScuNET/preload.py
ADDED
@@ -0,0 +1,6 @@
import os
from modules import paths


def preload(parser):
    parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(paths.models_path, 'ScuNET'))
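Note (not part of the commit): a preload() hook of this shape simply registers extra command-line flags on an argparse parser before the main program parses them. A minimal, self-contained sketch below exercises such a hook with plain argparse; the modules.paths dependency is replaced by a hypothetical local default so the snippet runs on its own.

import argparse
import os


def preload(parser):
    # same shape as the hook above; "models/ScuNET" is a stand-in default, not the webui's path
    parser.add_argument("--scunet-models-path", type=str,
                        help="Path to directory with ScuNET model file(s).",
                        default=os.path.join("models", "ScuNET"))


parser = argparse.ArgumentParser()
preload(parser)
args = parser.parse_args([])            # parse no CLI args, take the default
print(args.scunet_models_path)          # -> models/ScuNET (models\ScuNET on Windows)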
extensions-builtin/ScuNET/scripts/__pycache__/scunet_model.cpython-310.pyc
ADDED
Binary file (5.02 kB).
extensions-builtin/ScuNET/scripts/scunet_model.py
ADDED
@@ -0,0 +1,144 @@
import sys

import PIL.Image
import numpy as np
import torch
from tqdm import tqdm

import modules.upscaler
from modules import devices, modelloader, script_callbacks, errors
from scunet_model_arch import SCUNet

from modules.modelloader import load_file_from_url
from modules.shared import opts


class UpscalerScuNET(modules.upscaler.Upscaler):
    def __init__(self, dirname):
        self.name = "ScuNET"
        self.model_name = "ScuNET GAN"
        self.model_name2 = "ScuNET PSNR"
        self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_gan.pth"
        self.model_url2 = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth"
        self.user_path = dirname
        super().__init__()
        model_paths = self.find_models(ext_filter=[".pth"])
        scalers = []
        add_model2 = True
        for file in model_paths:
            if file.startswith("http"):
                name = self.model_name
            else:
                name = modelloader.friendly_name(file)
            if name == self.model_name2 or file == self.model_url2:
                add_model2 = False
            try:
                scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
                scalers.append(scaler_data)
            except Exception:
                errors.report(f"Error loading ScuNET model: {file}", exc_info=True)
        if add_model2:
            scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self)
            scalers.append(scaler_data2)
        self.scalers = scalers

    @staticmethod
    @torch.no_grad()
    def tiled_inference(img, model):
        # test the image tile by tile
        h, w = img.shape[2:]
        tile = opts.SCUNET_tile
        tile_overlap = opts.SCUNET_tile_overlap
        if tile == 0:
            return model(img)

        device = devices.get_device_for('scunet')
        assert tile % 8 == 0, "tile size should be a multiple of window_size"
        sf = 1

        stride = tile - tile_overlap
        h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
        w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
        E = torch.zeros(1, 3, h * sf, w * sf, dtype=img.dtype, device=device)
        W = torch.zeros_like(E, dtype=devices.dtype, device=device)

        with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="ScuNET tiles") as pbar:
            for h_idx in h_idx_list:

                for w_idx in w_idx_list:

                    in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]

                    out_patch = model(in_patch)
                    out_patch_mask = torch.ones_like(out_patch)

                    E[
                        ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
                    ].add_(out_patch)
                    W[
                        ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
                    ].add_(out_patch_mask)
                    pbar.update(1)
        output = E.div_(W)

        return output

    def do_upscale(self, img: PIL.Image.Image, selected_file):

        devices.torch_gc()

        try:
            model = self.load_model(selected_file)
        except Exception as e:
            print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr)
            return img

        device = devices.get_device_for('scunet')
        tile = opts.SCUNET_tile
        h, w = img.height, img.width
        np_img = np.array(img)
        np_img = np_img[:, :, ::-1]  # RGB to BGR
        np_img = np_img.transpose((2, 0, 1)) / 255  # HWC to CHW
        torch_img = torch.from_numpy(np_img).float().unsqueeze(0).to(device)  # type: ignore

        if tile > h or tile > w:
            _img = torch.zeros(1, 3, max(h, tile), max(w, tile), dtype=torch_img.dtype, device=torch_img.device)
            _img[:, :, :h, :w] = torch_img  # pad image
            torch_img = _img

        torch_output = self.tiled_inference(torch_img, model).squeeze(0)
        torch_output = torch_output[:, :h * 1, :w * 1]  # remove padding, if any
        np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
        del torch_img, torch_output
        devices.torch_gc()

        output = np_output.transpose((1, 2, 0))  # CHW to HWC
        output = output[:, :, ::-1]  # BGR to RGB
        return PIL.Image.fromarray((output * 255).astype(np.uint8))

    def load_model(self, path: str):
        device = devices.get_device_for('scunet')
        if path.startswith("http"):
            # TODO: this doesn't use `path` at all?
            filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
        else:
            filename = path
        model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
        model.load_state_dict(torch.load(filename), strict=True)
        model.eval()
        for _, v in model.named_parameters():
            v.requires_grad = False
        model = model.to(device)

        return model


def on_ui_settings():
    import gradio as gr
    from modules import shared

    shared.opts.add_option("SCUNET_tile", shared.OptionInfo(256, "Tile size for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")).info("0 = no tiling"))
    shared.opts.add_option("SCUNET_tile_overlap", shared.OptionInfo(8, "Tile overlap for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=('upscaling', "Upscaling")).info("Low values = visible seam"))


script_callbacks.on_ui_settings(on_ui_settings)
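Note (not part of the commit): tiled_inference above stitches tiles together by accumulating model outputs into E and a mask of ones into W, then dividing, so overlapping regions are averaged. A minimal sketch of that weighting scheme with an identity stand-in for the network (sf = 1, as in ScuNET):

import torch

h, w, tile, tile_overlap = 16, 16, 8, 4
img = torch.rand(1, 3, h, w)


def model(patch):
    # identity stand-in for the real denoising network
    return patch


stride = tile - tile_overlap
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
E = torch.zeros_like(img)  # accumulated outputs
W = torch.zeros_like(img)  # accumulated weights
for h_idx in h_idx_list:
    for w_idx in w_idx_list:
        patch = img[..., h_idx:h_idx + tile, w_idx:w_idx + tile]
        E[..., h_idx:h_idx + tile, w_idx:w_idx + tile] += model(patch)
        W[..., h_idx:h_idx + tile, w_idx:w_idx + tile] += torch.ones_like(patch)
output = E / W                        # overlapping regions average out
assert torch.allclose(output, img)    # the identity model reproduces the input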
extensions-builtin/ScuNET/scunet_model_arch.py
ADDED
@@ -0,0 +1,268 @@
# -*- coding: utf-8 -*-
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Rearrange
from timm.models.layers import trunc_normal_, DropPath


class WMSA(nn.Module):
    """ Self-attention module in Swin Transformer
    """

    def __init__(self, input_dim, output_dim, head_dim, window_size, type):
        super(WMSA, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.head_dim = head_dim
        self.scale = self.head_dim ** -0.5
        self.n_heads = input_dim // head_dim
        self.window_size = window_size
        self.type = type
        self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True)

        self.relative_position_params = nn.Parameter(
            torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads))

        self.linear = nn.Linear(self.input_dim, self.output_dim)

        trunc_normal_(self.relative_position_params, std=.02)
        self.relative_position_params = torch.nn.Parameter(
            self.relative_position_params.view(2 * window_size - 1, 2 * window_size - 1, self.n_heads).transpose(1, 2).transpose(0, 1))

    def generate_mask(self, h, w, p, shift):
        """ generating the mask of SW-MSA
        Args:
            shift: shift parameters in CyclicShift.
        Returns:
            attn_mask: should be (1 1 w p p),
        """
        # supporting square.
        attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device)
        if self.type == 'W':
            return attn_mask

        s = p - shift
        attn_mask[-1, :, :s, :, s:, :] = True
        attn_mask[-1, :, s:, :, :s, :] = True
        attn_mask[:, -1, :, :s, :, s:] = True
        attn_mask[:, -1, :, s:, :, :s] = True
        attn_mask = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)')
        return attn_mask

    def forward(self, x):
        """ Forward pass of Window Multi-head Self-attention module.
        Args:
            x: input tensor with shape of [b h w c];
            attn_mask: attention mask, fill -inf where the value is True;
        Returns:
            output: tensor shape [b h w c]
        """
        if self.type != 'W':
            x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))

        x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
        h_windows = x.size(1)
        w_windows = x.size(2)
        # square validation
        # assert h_windows == w_windows

        x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size)
        qkv = self.embedding_layer(x)
        q, k, v = rearrange(qkv, 'b nw np (threeh c) -> threeh b nw np c', c=self.head_dim).chunk(3, dim=0)
        sim = torch.einsum('hbwpc,hbwqc->hbwpq', q, k) * self.scale
        # Adding learnable relative embedding
        sim = sim + rearrange(self.relative_embedding(), 'h p q -> h 1 1 p q')
        # Using Attn Mask to distinguish different subwindows.
        if self.type != 'W':
            attn_mask = self.generate_mask(h_windows, w_windows, self.window_size, shift=self.window_size // 2)
            sim = sim.masked_fill_(attn_mask, float("-inf"))

        probs = nn.functional.softmax(sim, dim=-1)
        output = torch.einsum('hbwij,hbwjc->hbwic', probs, v)
        output = rearrange(output, 'h b w p c -> b w p (h c)')
        output = self.linear(output)
        output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)

        if self.type != 'W':
            output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), dims=(1, 2))

        return output

    def relative_embedding(self):
        cord = torch.tensor(np.array([[i, j] for i in range(self.window_size) for j in range(self.window_size)]))
        relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1
        # negative is allowed
        return self.relative_position_params[:, relation[:, :, 0].long(), relation[:, :, 1].long()]


class Block(nn.Module):
    def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
        """ SwinTransformer Block
        """
        super(Block, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        assert type in ['W', 'SW']
        self.type = type
        if input_resolution <= window_size:
            self.type = 'W'

        self.ln1 = nn.LayerNorm(input_dim)
        self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.ln2 = nn.LayerNorm(input_dim)
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, 4 * input_dim),
            nn.GELU(),
            nn.Linear(4 * input_dim, output_dim),
        )

    def forward(self, x):
        x = x + self.drop_path(self.msa(self.ln1(x)))
        x = x + self.drop_path(self.mlp(self.ln2(x)))
        return x


class ConvTransBlock(nn.Module):
    def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
        """ SwinTransformer and Conv Block
        """
        super(ConvTransBlock, self).__init__()
        self.conv_dim = conv_dim
        self.trans_dim = trans_dim
        self.head_dim = head_dim
        self.window_size = window_size
        self.drop_path = drop_path
        self.type = type
        self.input_resolution = input_resolution

        assert self.type in ['W', 'SW']
        if self.input_resolution <= self.window_size:
            self.type = 'W'

        self.trans_block = Block(self.trans_dim, self.trans_dim, self.head_dim, self.window_size, self.drop_path,
                                 self.type, self.input_resolution)
        self.conv1_1 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
        self.conv1_2 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)

        self.conv_block = nn.Sequential(
            nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False)
        )

    def forward(self, x):
        conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1)
        conv_x = self.conv_block(conv_x) + conv_x
        trans_x = Rearrange('b c h w -> b h w c')(trans_x)
        trans_x = self.trans_block(trans_x)
        trans_x = Rearrange('b h w c -> b c h w')(trans_x)
        res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
        x = x + res

        return x


class SCUNet(nn.Module):
    # def __init__(self, in_nc=3, config=[2, 2, 2, 2, 2, 2, 2], dim=64, drop_path_rate=0.0, input_resolution=256):
    def __init__(self, in_nc=3, config=None, dim=64, drop_path_rate=0.0, input_resolution=256):
        super(SCUNet, self).__init__()
        if config is None:
            config = [2, 2, 2, 2, 2, 2, 2]
        self.config = config
        self.dim = dim
        self.head_dim = 32
        self.window_size = 8

        # drop path rate for each layer
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))]

        self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)]

        begin = 0
        self.m_down1 = [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
                                       'W' if not i % 2 else 'SW', input_resolution)
                        for i in range(config[0])] + \
                       [nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)]

        begin += config[0]
        self.m_down2 = [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
                                       'W' if not i % 2 else 'SW', input_resolution // 2)
                        for i in range(config[1])] + \
                       [nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)]

        begin += config[1]
        self.m_down3 = [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
                                       'W' if not i % 2 else 'SW', input_resolution // 4)
                        for i in range(config[2])] + \
                       [nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)]

        begin += config[2]
        self.m_body = [ConvTransBlock(4 * dim, 4 * dim, self.head_dim, self.window_size, dpr[i + begin],
                                      'W' if not i % 2 else 'SW', input_resolution // 8)
                       for i in range(config[3])]

        begin += config[3]
        self.m_up3 = [nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False), ] + \
                     [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
                                     'W' if not i % 2 else 'SW', input_resolution // 4)
                      for i in range(config[4])]

        begin += config[4]
        self.m_up2 = [nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False), ] + \
                     [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
                                     'W' if not i % 2 else 'SW', input_resolution // 2)
                      for i in range(config[5])]

        begin += config[5]
        self.m_up1 = [nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False), ] + \
                     [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
                                     'W' if not i % 2 else 'SW', input_resolution)
                      for i in range(config[6])]

        self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)]

        self.m_head = nn.Sequential(*self.m_head)
        self.m_down1 = nn.Sequential(*self.m_down1)
        self.m_down2 = nn.Sequential(*self.m_down2)
        self.m_down3 = nn.Sequential(*self.m_down3)
        self.m_body = nn.Sequential(*self.m_body)
        self.m_up3 = nn.Sequential(*self.m_up3)
        self.m_up2 = nn.Sequential(*self.m_up2)
        self.m_up1 = nn.Sequential(*self.m_up1)
        self.m_tail = nn.Sequential(*self.m_tail)
        # self.apply(self._init_weights)

    def forward(self, x0):

        h, w = x0.size()[-2:]
        paddingBottom = int(np.ceil(h / 64) * 64 - h)
        paddingRight = int(np.ceil(w / 64) * 64 - w)
        x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0)

        x1 = self.m_head(x0)
        x2 = self.m_down1(x1)
        x3 = self.m_down2(x2)
        x4 = self.m_down3(x3)
        x = self.m_body(x4)
        x = self.m_up3(x + x4)
        x = self.m_up2(x + x3)
        x = self.m_up1(x + x2)
        x = self.m_tail(x + x1)

        x = x[..., :h, :w]

        return x

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
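Note (not part of the commit): SCUNet.forward pads its input with ReplicationPad2d up to the next multiple of 64 before the encoder/decoder and crops the result back to the original height and width, which is why arbitrary image sizes are accepted. A small sketch of just that size arithmetic, using a made-up 70x130 input:

import numpy as np
import torch
import torch.nn as nn

x0 = torch.rand(1, 3, 70, 130)
h, w = x0.size()[-2:]
padding_bottom = int(np.ceil(h / 64) * 64 - h)  # 128 - 70 = 58
padding_right = int(np.ceil(w / 64) * 64 - w)   # 192 - 130 = 62
padded = nn.ReplicationPad2d((0, padding_right, 0, padding_bottom))(x0)
assert padded.shape[-2:] == (128, 192)
cropped = padded[..., :h, :w]  # the crop at the end of forward()
assert torch.equal(cropped, x0)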
extensions-builtin/SwinIR/__pycache__/preload.cpython-310.pyc
ADDED
Binary file (491 Bytes).
extensions-builtin/SwinIR/__pycache__/swinir_model_arch.cpython-310.pyc
ADDED
Binary file (27.8 kB).
extensions-builtin/SwinIR/__pycache__/swinir_model_arch_v2.cpython-310.pyc
ADDED
Binary file (31.3 kB).
extensions-builtin/SwinIR/preload.py
ADDED
@@ -0,0 +1,6 @@
import os
from modules import paths


def preload(parser):
    parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(paths.models_path, 'SwinIR'))
extensions-builtin/SwinIR/scripts/__pycache__/swinir_model.cpython-310.pyc
ADDED
Binary file (6.07 kB).
extensions-builtin/SwinIR/scripts/swinir_model.py
ADDED
@@ -0,0 +1,192 @@
import sys
import platform

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from modules import modelloader, devices, script_callbacks, shared
from modules.shared import opts, state
from swinir_model_arch import SwinIR
from swinir_model_arch_v2 import Swin2SR
from modules.upscaler import Upscaler, UpscalerData

SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"

device_swinir = devices.get_device_for('swinir')


class UpscalerSwinIR(Upscaler):
    def __init__(self, dirname):
        self._cached_model = None  # keep the model when SWIN_torch_compile is on to prevent re-compile every runs
        self._cached_model_config = None  # to clear '_cached_model' when changing model (v1/v2) or settings
        self.name = "SwinIR"
        self.model_url = SWINIR_MODEL_URL
        self.model_name = "SwinIR 4x"
        self.user_path = dirname
        super().__init__()
        scalers = []
        model_files = self.find_models(ext_filter=[".pt", ".pth"])
        for model in model_files:
            if model.startswith("http"):
                name = self.model_name
            else:
                name = modelloader.friendly_name(model)
            model_data = UpscalerData(name, model, self)
            scalers.append(model_data)
        self.scalers = scalers

    def do_upscale(self, img, model_file):
        use_compile = hasattr(opts, 'SWIN_torch_compile') and opts.SWIN_torch_compile \
            and int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows"
        current_config = (model_file, opts.SWIN_tile)

        if use_compile and self._cached_model_config == current_config:
            model = self._cached_model
        else:
            self._cached_model = None
            try:
                model = self.load_model(model_file)
            except Exception as e:
                print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
                return img
            model = model.to(device_swinir, dtype=devices.dtype)
            if use_compile:
                model = torch.compile(model)
                self._cached_model = model
                self._cached_model_config = current_config
        img = upscale(img, model)
        devices.torch_gc()
        return img

    def load_model(self, path, scale=4):
        if path.startswith("http"):
            filename = modelloader.load_file_from_url(
                url=path,
                model_dir=self.model_download_path,
                file_name=f"{self.model_name.replace(' ', '_')}.pth",
            )
        else:
            filename = path
        if filename.endswith(".v2.pth"):
            model = Swin2SR(
                upscale=scale,
                in_chans=3,
                img_size=64,
                window_size=8,
                img_range=1.0,
                depths=[6, 6, 6, 6, 6, 6],
                embed_dim=180,
                num_heads=[6, 6, 6, 6, 6, 6],
                mlp_ratio=2,
                upsampler="nearest+conv",
                resi_connection="1conv",
            )
            params = None
        else:
            model = SwinIR(
                upscale=scale,
                in_chans=3,
                img_size=64,
                window_size=8,
                img_range=1.0,
                depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
                embed_dim=240,
                num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
                mlp_ratio=2,
                upsampler="nearest+conv",
                resi_connection="3conv",
            )
            params = "params_ema"

        pretrained_model = torch.load(filename)
        if params is not None:
            model.load_state_dict(pretrained_model[params], strict=True)
        else:
            model.load_state_dict(pretrained_model, strict=True)
        return model


def upscale(
        img,
        model,
        tile=None,
        tile_overlap=None,
        window_size=8,
        scale=4,
):
    tile = tile or opts.SWIN_tile
    tile_overlap = tile_overlap or opts.SWIN_tile_overlap

    img = np.array(img)
    img = img[:, :, ::-1]
    img = np.moveaxis(img, 2, 0) / 255
    img = torch.from_numpy(img).float()
    img = img.unsqueeze(0).to(device_swinir, dtype=devices.dtype)
    with torch.no_grad(), devices.autocast():
        _, _, h_old, w_old = img.size()
        h_pad = (h_old // window_size + 1) * window_size - h_old
        w_pad = (w_old // window_size + 1) * window_size - w_old
        img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
        img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
        output = inference(img, model, tile, tile_overlap, window_size, scale)
        output = output[..., : h_old * scale, : w_old * scale]
        output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        if output.ndim == 3:
            output = np.transpose(
                output[[2, 1, 0], :, :], (1, 2, 0)
            )  # CHW-RGB to HCW-BGR
        output = (output * 255.0).round().astype(np.uint8)  # float32 to uint8
        return Image.fromarray(output, "RGB")


def inference(img, model, tile, tile_overlap, window_size, scale):
    # test the image tile by tile
    b, c, h, w = img.size()
    tile = min(tile, h, w)
    assert tile % window_size == 0, "tile size should be a multiple of window_size"
    sf = scale

    stride = tile - tile_overlap
    h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
    w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
    E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device_swinir).type_as(img)
    W = torch.zeros_like(E, dtype=devices.dtype, device=device_swinir)

    with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
        for h_idx in h_idx_list:
            if state.interrupted or state.skipped:
                break

            for w_idx in w_idx_list:
                if state.interrupted or state.skipped:
                    break

                in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
                out_patch = model(in_patch)
                out_patch_mask = torch.ones_like(out_patch)

                E[
                    ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
                ].add_(out_patch)
                W[
                    ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
                ].add_(out_patch_mask)
                pbar.update(1)
    output = E.div_(W)

    return output


def on_ui_settings():
    import gradio as gr

    shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")))
    shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling")))
    if int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows":  # torch.compile() require pytorch 2.0 or above, and not on Windows
        shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run"))


script_callbacks.on_ui_settings(on_ui_settings)
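Note (not part of the commit): before tiling, upscale() above mirror-pads the image so its height and width become multiples of window_size, by concatenating a flipped copy along each axis and then cropping. The padding arithmetic in isolation, with an arbitrary 70x130 input:

import torch

window_size = 8
img = torch.rand(1, 3, 70, 130)
_, _, h_old, w_old = img.size()
h_pad = (h_old // window_size + 1) * window_size - h_old  # 72 - 70 = 2
w_pad = (w_old // window_size + 1) * window_size - w_old  # 136 - 130 = 6
img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
assert img.shape[-2] % window_size == 0 and img.shape[-1] % window_size == 0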
extensions-builtin/SwinIR/swinir_model_arch.py
ADDED
@@ -0,0 +1,867 @@
# -----------------------------------------------------------------------------------
# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
# Originally Written by Ze Liu, Modified by Jingyun Liang.
# -----------------------------------------------------------------------------------

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)

        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops


class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            attn_mask = self.calculate_mask(self.input_resolution)
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def calculate_mask(self, x_size):
        # calculate attention mask for SW-MSA
        H, W = x_size
        img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

        return attn_mask

    def forward(self, x, x_size):
        H, W = x_size
        B, L, C = x.shape
        # assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
        if self.input_resolution == x_size:
            attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
        else:
            attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops


class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        H, W = self.input_resolution
        flops = H * W * self.dim
        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return flops


class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x, x_size):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, x_size)
            else:
                x = blk(x, x_size)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops


class RSTB(nn.Module):
    """Residual Swin Transformer Block (RSTB).

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        img_size: Input image size.
        patch_size: Patch size.
        resi_connection: The convolutional block before residual connection.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
                 img_size=224, patch_size=4, resi_connection='1conv'):
        super(RSTB, self).__init__()

        self.dim = dim
        self.input_resolution = input_resolution

        self.residual_group = BasicLayer(dim=dim,
                                         input_resolution=input_resolution,
                                         depth=depth,
                                         num_heads=num_heads,
                                         window_size=window_size,
                                         mlp_ratio=mlp_ratio,
                                         qkv_bias=qkv_bias, qk_scale=qk_scale,
                                         drop=drop, attn_drop=attn_drop,
                                         drop_path=drop_path,
                                         norm_layer=norm_layer,
                                         downsample=downsample,
                                         use_checkpoint=use_checkpoint)

        if resi_connection == '1conv':
            self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
        elif resi_connection == '3conv':
            # to save parameters and memory
            self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                      nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
                                      nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                      nn.Conv2d(dim // 4, dim, 3, 1, 1))

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
            norm_layer=None)

        self.patch_unembed = PatchUnEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
            norm_layer=None)

    def forward(self, x, x_size):
        return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x

    def flops(self):
        flops = 0
        flops += self.residual_group.flops()
        H, W = self.input_resolution
        flops += H * W * self.dim * self.dim * 9
        flops += self.patch_embed.flops()
        flops += self.patch_unembed.flops()

        return flops


class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        x = x.flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

    def flops(self):
        flops = 0
        H, W = self.img_size
        if self.norm is not None:
            flops += H * W * self.embed_dim
        return flops


class PatchUnEmbed(nn.Module):
    r""" Image to Patch Unembedding

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim
+
|
39 |
+
Returns:
|
40 |
+
windows: (num_windows*B, window_size, window_size, C)
|
41 |
+
"""
|
42 |
+
B, H, W, C = x.shape
|
43 |
+
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
44 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
45 |
+
return windows
|
46 |
+
|
47 |
+
|
48 |
+
def window_reverse(windows, window_size, H, W):
|
49 |
+
"""
|
50 |
+
Args:
|
51 |
+
windows: (num_windows*B, window_size, window_size, C)
|
52 |
+
window_size (int): Window size
|
53 |
+
H (int): Height of image
|
54 |
+
W (int): Width of image
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
x: (B, H, W, C)
|
58 |
+
"""
|
59 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
60 |
+
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
61 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
62 |
+
return x
|
63 |
+
|
64 |
+
|
65 |
+
class WindowAttention(nn.Module):
|
66 |
+
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
67 |
+
It supports both of shifted and non-shifted window.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
dim (int): Number of input channels.
|
71 |
+
window_size (tuple[int]): The height and width of the window.
|
72 |
+
num_heads (int): Number of attention heads.
|
73 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
74 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
75 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
76 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
77 |
+
"""
|
78 |
+
|
79 |
+
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
|
80 |
+
|
81 |
+
super().__init__()
|
82 |
+
self.dim = dim
|
83 |
+
self.window_size = window_size # Wh, Ww
|
84 |
+
self.num_heads = num_heads
|
85 |
+
head_dim = dim // num_heads
|
86 |
+
self.scale = qk_scale or head_dim ** -0.5
|
87 |
+
|
88 |
+
# define a parameter table of relative position bias
|
89 |
+
self.relative_position_bias_table = nn.Parameter(
|
90 |
+
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
91 |
+
|
92 |
+
# get pair-wise relative position index for each token inside the window
|
93 |
+
coords_h = torch.arange(self.window_size[0])
|
94 |
+
coords_w = torch.arange(self.window_size[1])
|
95 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
96 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
97 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
98 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
99 |
+
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
|
100 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
101 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
102 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
103 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
104 |
+
|
105 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
106 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
107 |
+
self.proj = nn.Linear(dim, dim)
|
108 |
+
|
109 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
110 |
+
|
111 |
+
trunc_normal_(self.relative_position_bias_table, std=.02)
|
112 |
+
self.softmax = nn.Softmax(dim=-1)
|
113 |
+
|
114 |
+
def forward(self, x, mask=None):
|
115 |
+
"""
|
116 |
+
Args:
|
117 |
+
x: input features with shape of (num_windows*B, N, C)
|
118 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
119 |
+
"""
|
120 |
+
B_, N, C = x.shape
|
121 |
+
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
122 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
123 |
+
|
124 |
+
q = q * self.scale
|
125 |
+
attn = (q @ k.transpose(-2, -1))
|
126 |
+
|
127 |
+
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
128 |
+
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
129 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
130 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
131 |
+
|
132 |
+
if mask is not None:
|
133 |
+
nW = mask.shape[0]
|
134 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
135 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
136 |
+
attn = self.softmax(attn)
|
137 |
+
else:
|
138 |
+
attn = self.softmax(attn)
|
139 |
+
|
140 |
+
attn = self.attn_drop(attn)
|
141 |
+
|
142 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
143 |
+
x = self.proj(x)
|
144 |
+
x = self.proj_drop(x)
|
145 |
+
return x
|
146 |
+
|
147 |
+
def extra_repr(self) -> str:
|
148 |
+
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
|
149 |
+
|
150 |
+
def flops(self, N):
|
151 |
+
# calculate flops for 1 window with token length of N
|
152 |
+
flops = 0
|
153 |
+
# qkv = self.qkv(x)
|
154 |
+
flops += N * self.dim * 3 * self.dim
|
155 |
+
# attn = (q @ k.transpose(-2, -1))
|
156 |
+
flops += self.num_heads * N * (self.dim // self.num_heads) * N
|
157 |
+
# x = (attn @ v)
|
158 |
+
flops += self.num_heads * N * N * (self.dim // self.num_heads)
|
159 |
+
# x = self.proj(x)
|
160 |
+
flops += N * self.dim * self.dim
|
161 |
+
return flops
|
162 |
+
|
163 |
+
|
164 |
+
class SwinTransformerBlock(nn.Module):
|
165 |
+
r""" Swin Transformer Block.
|
166 |
+
|
167 |
+
Args:
|
168 |
+
dim (int): Number of input channels.
|
169 |
+
input_resolution (tuple[int]): Input resolution.
|
170 |
+
num_heads (int): Number of attention heads.
|
171 |
+
window_size (int): Window size.
|
172 |
+
shift_size (int): Shift size for SW-MSA.
|
173 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
174 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
175 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
176 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
177 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
178 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
179 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
180 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
181 |
+
"""
|
182 |
+
|
183 |
+
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
|
184 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
|
185 |
+
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
186 |
+
super().__init__()
|
187 |
+
self.dim = dim
|
188 |
+
self.input_resolution = input_resolution
|
189 |
+
self.num_heads = num_heads
|
190 |
+
self.window_size = window_size
|
191 |
+
self.shift_size = shift_size
|
192 |
+
self.mlp_ratio = mlp_ratio
|
193 |
+
if min(self.input_resolution) <= self.window_size:
|
194 |
+
# if window size is larger than input resolution, we don't partition windows
|
195 |
+
self.shift_size = 0
|
196 |
+
self.window_size = min(self.input_resolution)
|
197 |
+
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
|
198 |
+
|
199 |
+
self.norm1 = norm_layer(dim)
|
200 |
+
self.attn = WindowAttention(
|
201 |
+
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
|
202 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
203 |
+
|
204 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
205 |
+
self.norm2 = norm_layer(dim)
|
206 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
207 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
208 |
+
|
209 |
+
if self.shift_size > 0:
|
210 |
+
attn_mask = self.calculate_mask(self.input_resolution)
|
211 |
+
else:
|
212 |
+
attn_mask = None
|
213 |
+
|
214 |
+
self.register_buffer("attn_mask", attn_mask)
|
215 |
+
|
216 |
+
def calculate_mask(self, x_size):
|
217 |
+
# calculate attention mask for SW-MSA
|
218 |
+
H, W = x_size
|
219 |
+
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
|
220 |
+
h_slices = (slice(0, -self.window_size),
|
221 |
+
slice(-self.window_size, -self.shift_size),
|
222 |
+
slice(-self.shift_size, None))
|
223 |
+
w_slices = (slice(0, -self.window_size),
|
224 |
+
slice(-self.window_size, -self.shift_size),
|
225 |
+
slice(-self.shift_size, None))
|
226 |
+
cnt = 0
|
227 |
+
for h in h_slices:
|
228 |
+
for w in w_slices:
|
229 |
+
img_mask[:, h, w, :] = cnt
|
230 |
+
cnt += 1
|
231 |
+
|
232 |
+
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
|
233 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
234 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
235 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
236 |
+
|
237 |
+
return attn_mask
|
238 |
+
|
239 |
+
def forward(self, x, x_size):
|
240 |
+
H, W = x_size
|
241 |
+
B, L, C = x.shape
|
242 |
+
# assert L == H * W, "input feature has wrong size"
|
243 |
+
|
244 |
+
shortcut = x
|
245 |
+
x = self.norm1(x)
|
246 |
+
x = x.view(B, H, W, C)
|
247 |
+
|
248 |
+
# cyclic shift
|
249 |
+
if self.shift_size > 0:
|
250 |
+
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
251 |
+
else:
|
252 |
+
shifted_x = x
|
253 |
+
|
254 |
+
# partition windows
|
255 |
+
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
256 |
+
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
|
257 |
+
|
258 |
+
# W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
|
259 |
+
if self.input_resolution == x_size:
|
260 |
+
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
|
261 |
+
else:
|
262 |
+
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
|
263 |
+
|
264 |
+
# merge windows
|
265 |
+
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
266 |
+
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
|
267 |
+
|
268 |
+
# reverse cyclic shift
|
269 |
+
if self.shift_size > 0:
|
270 |
+
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
271 |
+
else:
|
272 |
+
x = shifted_x
|
273 |
+
x = x.view(B, H * W, C)
|
274 |
+
|
275 |
+
# FFN
|
276 |
+
x = shortcut + self.drop_path(x)
|
277 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
278 |
+
|
279 |
+
return x
|
280 |
+
|
281 |
+
def extra_repr(self) -> str:
|
282 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
|
283 |
+
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
|
284 |
+
|
285 |
+
def flops(self):
|
286 |
+
flops = 0
|
287 |
+
H, W = self.input_resolution
|
288 |
+
# norm1
|
289 |
+
flops += self.dim * H * W
|
290 |
+
# W-MSA/SW-MSA
|
291 |
+
nW = H * W / self.window_size / self.window_size
|
292 |
+
flops += nW * self.attn.flops(self.window_size * self.window_size)
|
293 |
+
# mlp
|
294 |
+
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
|
295 |
+
# norm2
|
296 |
+
flops += self.dim * H * W
|
297 |
+
return flops
|
298 |
+
|
299 |
+
|
300 |
+
class PatchMerging(nn.Module):
|
301 |
+
r""" Patch Merging Layer.
|
302 |
+
|
303 |
+
Args:
|
304 |
+
input_resolution (tuple[int]): Resolution of input feature.
|
305 |
+
dim (int): Number of input channels.
|
306 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
307 |
+
"""
|
308 |
+
|
309 |
+
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
|
310 |
+
super().__init__()
|
311 |
+
self.input_resolution = input_resolution
|
312 |
+
self.dim = dim
|
313 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
314 |
+
self.norm = norm_layer(4 * dim)
|
315 |
+
|
316 |
+
def forward(self, x):
|
317 |
+
"""
|
318 |
+
x: B, H*W, C
|
319 |
+
"""
|
320 |
+
H, W = self.input_resolution
|
321 |
+
B, L, C = x.shape
|
322 |
+
assert L == H * W, "input feature has wrong size"
|
323 |
+
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
|
324 |
+
|
325 |
+
x = x.view(B, H, W, C)
|
326 |
+
|
327 |
+
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
328 |
+
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
329 |
+
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
330 |
+
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
331 |
+
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
332 |
+
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
333 |
+
|
334 |
+
x = self.norm(x)
|
335 |
+
x = self.reduction(x)
|
336 |
+
|
337 |
+
return x
|
338 |
+
|
339 |
+
def extra_repr(self) -> str:
|
340 |
+
return f"input_resolution={self.input_resolution}, dim={self.dim}"
|
341 |
+
|
342 |
+
def flops(self):
|
343 |
+
H, W = self.input_resolution
|
344 |
+
flops = H * W * self.dim
|
345 |
+
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
|
346 |
+
return flops
|
347 |
+
|
348 |
+
|
349 |
+
class BasicLayer(nn.Module):
|
350 |
+
""" A basic Swin Transformer layer for one stage.
|
351 |
+
|
352 |
+
Args:
|
353 |
+
dim (int): Number of input channels.
|
354 |
+
input_resolution (tuple[int]): Input resolution.
|
355 |
+
depth (int): Number of blocks.
|
356 |
+
num_heads (int): Number of attention heads.
|
357 |
+
window_size (int): Local window size.
|
358 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
359 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
360 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
361 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
362 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
363 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
364 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
365 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
366 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
367 |
+
"""
|
368 |
+
|
369 |
+
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
|
370 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
|
371 |
+
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
|
372 |
+
|
373 |
+
super().__init__()
|
374 |
+
self.dim = dim
|
375 |
+
self.input_resolution = input_resolution
|
376 |
+
self.depth = depth
|
377 |
+
self.use_checkpoint = use_checkpoint
|
378 |
+
|
379 |
+
# build blocks
|
380 |
+
self.blocks = nn.ModuleList([
|
381 |
+
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
|
382 |
+
num_heads=num_heads, window_size=window_size,
|
383 |
+
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
384 |
+
mlp_ratio=mlp_ratio,
|
385 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
386 |
+
drop=drop, attn_drop=attn_drop,
|
387 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
388 |
+
norm_layer=norm_layer)
|
389 |
+
for i in range(depth)])
|
390 |
+
|
391 |
+
# patch merging layer
|
392 |
+
if downsample is not None:
|
393 |
+
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
|
394 |
+
else:
|
395 |
+
self.downsample = None
|
396 |
+
|
397 |
+
def forward(self, x, x_size):
|
398 |
+
for blk in self.blocks:
|
399 |
+
if self.use_checkpoint:
|
400 |
+
x = checkpoint.checkpoint(blk, x, x_size)
|
401 |
+
else:
|
402 |
+
x = blk(x, x_size)
|
403 |
+
if self.downsample is not None:
|
404 |
+
x = self.downsample(x)
|
405 |
+
return x
|
406 |
+
|
407 |
+
def extra_repr(self) -> str:
|
408 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
|
409 |
+
|
410 |
+
def flops(self):
|
411 |
+
flops = 0
|
412 |
+
for blk in self.blocks:
|
413 |
+
flops += blk.flops()
|
414 |
+
if self.downsample is not None:
|
415 |
+
flops += self.downsample.flops()
|
416 |
+
return flops
|
417 |
+
|
418 |
+
|
419 |
+
class RSTB(nn.Module):
|
420 |
+
"""Residual Swin Transformer Block (RSTB).
|
421 |
+
|
422 |
+
Args:
|
423 |
+
dim (int): Number of input channels.
|
424 |
+
input_resolution (tuple[int]): Input resolution.
|
425 |
+
depth (int): Number of blocks.
|
426 |
+
num_heads (int): Number of attention heads.
|
427 |
+
window_size (int): Local window size.
|
428 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
429 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
430 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
431 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
432 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
433 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
434 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
435 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
436 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
437 |
+
img_size: Input image size.
|
438 |
+
patch_size: Patch size.
|
439 |
+
resi_connection: The convolutional block before residual connection.
|
440 |
+
"""
|
441 |
+
|
442 |
+
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
|
443 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
|
444 |
+
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
|
445 |
+
img_size=224, patch_size=4, resi_connection='1conv'):
|
446 |
+
super(RSTB, self).__init__()
|
447 |
+
|
448 |
+
self.dim = dim
|
449 |
+
self.input_resolution = input_resolution
|
450 |
+
|
451 |
+
self.residual_group = BasicLayer(dim=dim,
|
452 |
+
input_resolution=input_resolution,
|
453 |
+
depth=depth,
|
454 |
+
num_heads=num_heads,
|
455 |
+
window_size=window_size,
|
456 |
+
mlp_ratio=mlp_ratio,
|
457 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
458 |
+
drop=drop, attn_drop=attn_drop,
|
459 |
+
drop_path=drop_path,
|
460 |
+
norm_layer=norm_layer,
|
461 |
+
downsample=downsample,
|
462 |
+
use_checkpoint=use_checkpoint)
|
463 |
+
|
464 |
+
if resi_connection == '1conv':
|
465 |
+
self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
|
466 |
+
elif resi_connection == '3conv':
|
467 |
+
# to save parameters and memory
|
468 |
+
self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
469 |
+
nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
|
470 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
471 |
+
nn.Conv2d(dim // 4, dim, 3, 1, 1))
|
472 |
+
|
473 |
+
self.patch_embed = PatchEmbed(
|
474 |
+
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
|
475 |
+
norm_layer=None)
|
476 |
+
|
477 |
+
self.patch_unembed = PatchUnEmbed(
|
478 |
+
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
|
479 |
+
norm_layer=None)
|
480 |
+
|
481 |
+
def forward(self, x, x_size):
|
482 |
+
return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
|
483 |
+
|
484 |
+
def flops(self):
|
485 |
+
flops = 0
|
486 |
+
flops += self.residual_group.flops()
|
487 |
+
H, W = self.input_resolution
|
488 |
+
flops += H * W * self.dim * self.dim * 9
|
489 |
+
flops += self.patch_embed.flops()
|
490 |
+
flops += self.patch_unembed.flops()
|
491 |
+
|
492 |
+
return flops
|
493 |
+
|
494 |
+
|
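RSTB.forward above folds the token sequence back into an image so the 3x3 convolution can run, then flattens it again before the residual add. A rough shape walk-through of that path (a sketch, not part of the diff; B/H/W/C values are arbitrary):

import torch
import torch.nn as nn

B, H, W, C = 1, 16, 16, 96
tokens = torch.randn(B, H * W, C)

image = tokens.transpose(1, 2).view(B, C, H, W)    # what PatchUnEmbed does
conv_out = nn.Conv2d(C, C, 3, 1, 1)(image)         # the '1conv' residual branch
back = conv_out.flatten(2).transpose(1, 2)         # what PatchEmbed does
out = back + tokens                                # residual connection
print(out.shape)                                   # torch.Size([1, 256, 96])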
495 |
+
class PatchEmbed(nn.Module):
|
496 |
+
r""" Image to Patch Embedding
|
497 |
+
|
498 |
+
Args:
|
499 |
+
img_size (int): Image size. Default: 224.
|
500 |
+
patch_size (int): Patch token size. Default: 4.
|
501 |
+
in_chans (int): Number of input image channels. Default: 3.
|
502 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
503 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
504 |
+
"""
|
505 |
+
|
506 |
+
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
507 |
+
super().__init__()
|
508 |
+
img_size = to_2tuple(img_size)
|
509 |
+
patch_size = to_2tuple(patch_size)
|
510 |
+
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
|
511 |
+
self.img_size = img_size
|
512 |
+
self.patch_size = patch_size
|
513 |
+
self.patches_resolution = patches_resolution
|
514 |
+
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
515 |
+
|
516 |
+
self.in_chans = in_chans
|
517 |
+
self.embed_dim = embed_dim
|
518 |
+
|
519 |
+
if norm_layer is not None:
|
520 |
+
self.norm = norm_layer(embed_dim)
|
521 |
+
else:
|
522 |
+
self.norm = None
|
523 |
+
|
524 |
+
def forward(self, x):
|
525 |
+
x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
|
526 |
+
if self.norm is not None:
|
527 |
+
x = self.norm(x)
|
528 |
+
return x
|
529 |
+
|
530 |
+
def flops(self):
|
531 |
+
flops = 0
|
532 |
+
H, W = self.img_size
|
533 |
+
if self.norm is not None:
|
534 |
+
flops += H * W * self.embed_dim
|
535 |
+
return flops
|
536 |
+
|
537 |
+
|
538 |
+
class PatchUnEmbed(nn.Module):
|
539 |
+
r""" Image to Patch Unembedding
|
540 |
+
|
541 |
+
Args:
|
542 |
+
img_size (int): Image size. Default: 224.
|
543 |
+
patch_size (int): Patch token size. Default: 4.
|
544 |
+
in_chans (int): Number of input image channels. Default: 3.
|
545 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
546 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
547 |
+
"""
|
548 |
+
|
549 |
+
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
550 |
+
super().__init__()
|
551 |
+
img_size = to_2tuple(img_size)
|
552 |
+
patch_size = to_2tuple(patch_size)
|
553 |
+
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
|
554 |
+
self.img_size = img_size
|
555 |
+
self.patch_size = patch_size
|
556 |
+
self.patches_resolution = patches_resolution
|
557 |
+
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
558 |
+
|
559 |
+
self.in_chans = in_chans
|
560 |
+
self.embed_dim = embed_dim
|
561 |
+
|
562 |
+
def forward(self, x, x_size):
|
563 |
+
B, HW, C = x.shape
|
564 |
+
x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1])  # B C Ph Pw
|
565 |
+
return x
|
566 |
+
|
567 |
+
def flops(self):
|
568 |
+
flops = 0
|
569 |
+
return flops
|
570 |
+
|
571 |
+
|
572 |
+
class Upsample(nn.Sequential):
|
573 |
+
"""Upsample module.
|
574 |
+
|
575 |
+
Args:
|
576 |
+
scale (int): Scale factor. Supported scales: 2^n and 3.
|
577 |
+
num_feat (int): Channel number of intermediate features.
|
578 |
+
"""
|
579 |
+
|
580 |
+
def __init__(self, scale, num_feat):
|
581 |
+
m = []
|
582 |
+
if (scale & (scale - 1)) == 0: # scale = 2^n
|
583 |
+
for _ in range(int(math.log(scale, 2))):
|
584 |
+
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
|
585 |
+
m.append(nn.PixelShuffle(2))
|
586 |
+
elif scale == 3:
|
587 |
+
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
|
588 |
+
m.append(nn.PixelShuffle(3))
|
589 |
+
else:
|
590 |
+
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
|
591 |
+
super(Upsample, self).__init__(*m)
|
592 |
+
|
593 |
+
|
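Upsample above stacks one conv + PixelShuffle(2) pair per factor of two (or a single PixelShuffle(3) for x3), and (scale & (scale - 1)) == 0 is the usual power-of-two test. A small sketch of one x2 stage (not part of the diff; tensor sizes are illustrative):

import torch
import torch.nn as nn

num_feat = 64
x = torch.randn(1, num_feat, 32, 32)

stage = nn.Sequential(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1), nn.PixelShuffle(2))
print(stage(x).shape)               # torch.Size([1, 64, 64, 64]): 4*C channels become 2x spatial
print((4 & 3) == 0, (6 & 5) == 0)   # True False: 4 is a power of two, 6 is not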
594 |
+
class UpsampleOneStep(nn.Sequential):
|
595 |
+
"""UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
|
596 |
+
Used in lightweight SR to save parameters.
|
597 |
+
|
598 |
+
Args:
|
599 |
+
scale (int): Scale factor. Supported scales: 2^n and 3.
|
600 |
+
num_feat (int): Channel number of intermediate features.
|
601 |
+
|
602 |
+
"""
|
603 |
+
|
604 |
+
def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
|
605 |
+
self.num_feat = num_feat
|
606 |
+
self.input_resolution = input_resolution
|
607 |
+
m = []
|
608 |
+
m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
|
609 |
+
m.append(nn.PixelShuffle(scale))
|
610 |
+
super(UpsampleOneStep, self).__init__(*m)
|
611 |
+
|
612 |
+
def flops(self):
|
613 |
+
H, W = self.input_resolution
|
614 |
+
flops = H * W * self.num_feat * 3 * 9
|
615 |
+
return flops
|
616 |
+
|
617 |
+
|
618 |
+
class SwinIR(nn.Module):
|
619 |
+
r""" SwinIR
|
620 |
+
A PyTorch implementation of `SwinIR: Image Restoration Using Swin Transformer`, based on the Swin Transformer.
|
621 |
+
|
622 |
+
Args:
|
623 |
+
img_size (int | tuple(int)): Input image size. Default 64
|
624 |
+
patch_size (int | tuple(int)): Patch size. Default: 1
|
625 |
+
in_chans (int): Number of input image channels. Default: 3
|
626 |
+
embed_dim (int): Patch embedding dimension. Default: 96
|
627 |
+
depths (tuple(int)): Depth of each Swin Transformer layer.
|
628 |
+
num_heads (tuple(int)): Number of attention heads in different layers.
|
629 |
+
window_size (int): Window size. Default: 7
|
630 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
|
631 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
632 |
+
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
|
633 |
+
drop_rate (float): Dropout rate. Default: 0
|
634 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0
|
635 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.1
|
636 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
637 |
+
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
|
638 |
+
patch_norm (bool): If True, add normalization after patch embedding. Default: True
|
639 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
|
640 |
+
upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
|
641 |
+
img_range: Image range. 1. or 255.
|
642 |
+
upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
|
643 |
+
resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
|
644 |
+
"""
|
645 |
+
|
646 |
+
def __init__(self, img_size=64, patch_size=1, in_chans=3,
|
647 |
+
embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
|
648 |
+
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
|
649 |
+
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
650 |
+
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
651 |
+
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
|
652 |
+
**kwargs):
|
653 |
+
super(SwinIR, self).__init__()
|
654 |
+
num_in_ch = in_chans
|
655 |
+
num_out_ch = in_chans
|
656 |
+
num_feat = 64
|
657 |
+
self.img_range = img_range
|
658 |
+
if in_chans == 3:
|
659 |
+
rgb_mean = (0.4488, 0.4371, 0.4040)
|
660 |
+
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
|
661 |
+
else:
|
662 |
+
self.mean = torch.zeros(1, 1, 1, 1)
|
663 |
+
self.upscale = upscale
|
664 |
+
self.upsampler = upsampler
|
665 |
+
self.window_size = window_size
|
666 |
+
|
667 |
+
#####################################################################################################
|
668 |
+
################################### 1, shallow feature extraction ###################################
|
669 |
+
self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
|
670 |
+
|
671 |
+
#####################################################################################################
|
672 |
+
################################### 2, deep feature extraction ######################################
|
673 |
+
self.num_layers = len(depths)
|
674 |
+
self.embed_dim = embed_dim
|
675 |
+
self.ape = ape
|
676 |
+
self.patch_norm = patch_norm
|
677 |
+
self.num_features = embed_dim
|
678 |
+
self.mlp_ratio = mlp_ratio
|
679 |
+
|
680 |
+
# split image into non-overlapping patches
|
681 |
+
self.patch_embed = PatchEmbed(
|
682 |
+
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
|
683 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
684 |
+
num_patches = self.patch_embed.num_patches
|
685 |
+
patches_resolution = self.patch_embed.patches_resolution
|
686 |
+
self.patches_resolution = patches_resolution
|
687 |
+
|
688 |
+
# merge non-overlapping patches into image
|
689 |
+
self.patch_unembed = PatchUnEmbed(
|
690 |
+
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
|
691 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
692 |
+
|
693 |
+
# absolute position embedding
|
694 |
+
if self.ape:
|
695 |
+
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
|
696 |
+
trunc_normal_(self.absolute_pos_embed, std=.02)
|
697 |
+
|
698 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
699 |
+
|
700 |
+
# stochastic depth
|
701 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
702 |
+
|
703 |
+
# build Residual Swin Transformer blocks (RSTB)
|
704 |
+
self.layers = nn.ModuleList()
|
705 |
+
for i_layer in range(self.num_layers):
|
706 |
+
layer = RSTB(dim=embed_dim,
|
707 |
+
input_resolution=(patches_resolution[0],
|
708 |
+
patches_resolution[1]),
|
709 |
+
depth=depths[i_layer],
|
710 |
+
num_heads=num_heads[i_layer],
|
711 |
+
window_size=window_size,
|
712 |
+
mlp_ratio=self.mlp_ratio,
|
713 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
714 |
+
drop=drop_rate, attn_drop=attn_drop_rate,
|
715 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
|
716 |
+
norm_layer=norm_layer,
|
717 |
+
downsample=None,
|
718 |
+
use_checkpoint=use_checkpoint,
|
719 |
+
img_size=img_size,
|
720 |
+
patch_size=patch_size,
|
721 |
+
resi_connection=resi_connection
|
722 |
+
|
723 |
+
)
|
724 |
+
self.layers.append(layer)
|
725 |
+
self.norm = norm_layer(self.num_features)
|
726 |
+
|
727 |
+
# build the last conv layer in deep feature extraction
|
728 |
+
if resi_connection == '1conv':
|
729 |
+
self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
|
730 |
+
elif resi_connection == '3conv':
|
731 |
+
# to save parameters and memory
|
732 |
+
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
|
733 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
734 |
+
nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
|
735 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
736 |
+
nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
|
737 |
+
|
738 |
+
#####################################################################################################
|
739 |
+
################################ 3, high quality image reconstruction ################################
|
740 |
+
if self.upsampler == 'pixelshuffle':
|
741 |
+
# for classical SR
|
742 |
+
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
743 |
+
nn.LeakyReLU(inplace=True))
|
744 |
+
self.upsample = Upsample(upscale, num_feat)
|
745 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
746 |
+
elif self.upsampler == 'pixelshuffledirect':
|
747 |
+
# for lightweight SR (to save parameters)
|
748 |
+
self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
|
749 |
+
(patches_resolution[0], patches_resolution[1]))
|
750 |
+
elif self.upsampler == 'nearest+conv':
|
751 |
+
# for real-world SR (less artifacts)
|
752 |
+
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
753 |
+
nn.LeakyReLU(inplace=True))
|
754 |
+
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
755 |
+
if self.upscale == 4:
|
756 |
+
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
757 |
+
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
758 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
759 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
760 |
+
else:
|
761 |
+
# for image denoising and JPEG compression artifact reduction
|
762 |
+
self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
|
763 |
+
|
764 |
+
self.apply(self._init_weights)
|
765 |
+
|
766 |
+
def _init_weights(self, m):
|
767 |
+
if isinstance(m, nn.Linear):
|
768 |
+
trunc_normal_(m.weight, std=.02)
|
769 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
770 |
+
nn.init.constant_(m.bias, 0)
|
771 |
+
elif isinstance(m, nn.LayerNorm):
|
772 |
+
nn.init.constant_(m.bias, 0)
|
773 |
+
nn.init.constant_(m.weight, 1.0)
|
774 |
+
|
775 |
+
@torch.jit.ignore
|
776 |
+
def no_weight_decay(self):
|
777 |
+
return {'absolute_pos_embed'}
|
778 |
+
|
779 |
+
@torch.jit.ignore
|
780 |
+
def no_weight_decay_keywords(self):
|
781 |
+
return {'relative_position_bias_table'}
|
782 |
+
|
783 |
+
def check_image_size(self, x):
|
784 |
+
_, _, h, w = x.size()
|
785 |
+
mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
|
786 |
+
mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
|
787 |
+
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
|
788 |
+
return x
|
789 |
+
|
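check_image_size above reflect-pads the right and bottom edges so both spatial dimensions become multiples of window_size, which the window partitioning requires. The padding arithmetic, as a tiny sketch (not part of the diff):

window_size = 8
for h in (64, 65, 71):
    mod_pad_h = (window_size - h % window_size) % window_size
    print(h, "->", h + mod_pad_h)   # 64 -> 64, 65 -> 72, 71 -> 72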
790 |
+
def forward_features(self, x):
|
791 |
+
x_size = (x.shape[2], x.shape[3])
|
792 |
+
x = self.patch_embed(x)
|
793 |
+
if self.ape:
|
794 |
+
x = x + self.absolute_pos_embed
|
795 |
+
x = self.pos_drop(x)
|
796 |
+
|
797 |
+
for layer in self.layers:
|
798 |
+
x = layer(x, x_size)
|
799 |
+
|
800 |
+
x = self.norm(x) # B L C
|
801 |
+
x = self.patch_unembed(x, x_size)
|
802 |
+
|
803 |
+
return x
|
804 |
+
|
805 |
+
def forward(self, x):
|
806 |
+
H, W = x.shape[2:]
|
807 |
+
x = self.check_image_size(x)
|
808 |
+
|
809 |
+
self.mean = self.mean.type_as(x)
|
810 |
+
x = (x - self.mean) * self.img_range
|
811 |
+
|
812 |
+
if self.upsampler == 'pixelshuffle':
|
813 |
+
# for classical SR
|
814 |
+
x = self.conv_first(x)
|
815 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
816 |
+
x = self.conv_before_upsample(x)
|
817 |
+
x = self.conv_last(self.upsample(x))
|
818 |
+
elif self.upsampler == 'pixelshuffledirect':
|
819 |
+
# for lightweight SR
|
820 |
+
x = self.conv_first(x)
|
821 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
822 |
+
x = self.upsample(x)
|
823 |
+
elif self.upsampler == 'nearest+conv':
|
824 |
+
# for real-world SR
|
825 |
+
x = self.conv_first(x)
|
826 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
827 |
+
x = self.conv_before_upsample(x)
|
828 |
+
x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
|
829 |
+
if self.upscale == 4:
|
830 |
+
x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
|
831 |
+
x = self.conv_last(self.lrelu(self.conv_hr(x)))
|
832 |
+
else:
|
833 |
+
# for image denoising and JPEG compression artifact reduction
|
834 |
+
x_first = self.conv_first(x)
|
835 |
+
res = self.conv_after_body(self.forward_features(x_first)) + x_first
|
836 |
+
x = x + self.conv_last(res)
|
837 |
+
|
838 |
+
x = x / self.img_range + self.mean
|
839 |
+
|
840 |
+
return x[:, :, :H*self.upscale, :W*self.upscale]
|
841 |
+
|
842 |
+
def flops(self):
|
843 |
+
flops = 0
|
844 |
+
H, W = self.patches_resolution
|
845 |
+
flops += H * W * 3 * self.embed_dim * 9
|
846 |
+
flops += self.patch_embed.flops()
|
847 |
+
for layer in self.layers:
|
848 |
+
flops += layer.flops()
|
849 |
+
flops += H * W * 3 * self.embed_dim * self.embed_dim
|
850 |
+
flops += self.upsample.flops()
|
851 |
+
return flops
|
852 |
+
|
853 |
+
|
854 |
+
if __name__ == '__main__':
|
855 |
+
upscale = 4
|
856 |
+
window_size = 8
|
857 |
+
height = (1024 // upscale // window_size + 1) * window_size
|
858 |
+
width = (720 // upscale // window_size + 1) * window_size
|
859 |
+
model = SwinIR(upscale=2, img_size=(height, width),
|
860 |
+
window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
|
861 |
+
embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
|
862 |
+
print(model)
|
863 |
+
print(height, width, model.flops() / 1e9)
|
864 |
+
|
865 |
+
x = torch.randn((1, 3, height, width))
|
866 |
+
x = model(x)
|
867 |
+
print(x.shape)
|
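Beyond the self-test above, a minimal inference sketch for this architecture (assumptions, not part of the diff: randomly initialised weights, CPU, no pretrained checkpoint); because of check_image_size the low-resolution input does not need to be a multiple of the window size:

import torch

model = SwinIR(upscale=4, img_size=64, window_size=8, img_range=1.,
               depths=[6, 6, 6, 6], embed_dim=60, num_heads=[6, 6, 6, 6],
               mlp_ratio=2, upsampler='pixelshuffledirect')
model.eval()

with torch.no_grad():
    lr = torch.rand(1, 3, 50, 70)   # arbitrary low-resolution size
    sr = model(lr)
print(sr.shape)                     # torch.Size([1, 3, 200, 280])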
extensions-builtin/SwinIR/swinir_model_arch_v2.py
ADDED
@@ -0,0 +1,1017 @@
1 |
+
# -----------------------------------------------------------------------------------
|
2 |
+
# Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/
|
3 |
+
# Written by Conde and Choi et al.
|
4 |
+
# -----------------------------------------------------------------------------------
|
5 |
+
|
6 |
+
import math
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
import torch.utils.checkpoint as checkpoint
|
12 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
13 |
+
|
14 |
+
|
15 |
+
class Mlp(nn.Module):
|
16 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
17 |
+
super().__init__()
|
18 |
+
out_features = out_features or in_features
|
19 |
+
hidden_features = hidden_features or in_features
|
20 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
21 |
+
self.act = act_layer()
|
22 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
23 |
+
self.drop = nn.Dropout(drop)
|
24 |
+
|
25 |
+
def forward(self, x):
|
26 |
+
x = self.fc1(x)
|
27 |
+
x = self.act(x)
|
28 |
+
x = self.drop(x)
|
29 |
+
x = self.fc2(x)
|
30 |
+
x = self.drop(x)
|
31 |
+
return x
|
32 |
+
|
33 |
+
|
34 |
+
def window_partition(x, window_size):
|
35 |
+
"""
|
36 |
+
Args:
|
37 |
+
x: (B, H, W, C)
|
38 |
+
window_size (int): window size
|
39 |
+
Returns:
|
40 |
+
windows: (num_windows*B, window_size, window_size, C)
|
41 |
+
"""
|
42 |
+
B, H, W, C = x.shape
|
43 |
+
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
44 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
45 |
+
return windows
|
46 |
+
|
47 |
+
|
48 |
+
def window_reverse(windows, window_size, H, W):
|
49 |
+
"""
|
50 |
+
Args:
|
51 |
+
windows: (num_windows*B, window_size, window_size, C)
|
52 |
+
window_size (int): Window size
|
53 |
+
H (int): Height of image
|
54 |
+
W (int): Width of image
|
55 |
+
Returns:
|
56 |
+
x: (B, H, W, C)
|
57 |
+
"""
|
58 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
59 |
+
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
60 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
61 |
+
return x
|
62 |
+
|
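window_partition and window_reverse above are exact inverses for matching arguments; a quick self-check sketch (not part of the diff):

import torch

x = torch.arange(2 * 16 * 16 * 3, dtype=torch.float32).view(2, 16, 16, 3)
windows = window_partition(x, 8)                    # (2*4, 8, 8, 3)
restored = window_reverse(windows, 8, 16, 16)
print(windows.shape, torch.equal(x, restored))      # torch.Size([8, 8, 8, 3]) True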
63 |
+
class WindowAttention(nn.Module):
|
64 |
+
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
65 |
+
It supports both shifted and non-shifted windows.
|
66 |
+
Args:
|
67 |
+
dim (int): Number of input channels.
|
68 |
+
window_size (tuple[int]): The height and width of the window.
|
69 |
+
num_heads (int): Number of attention heads.
|
70 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
71 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
72 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
73 |
+
pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
|
74 |
+
"""
|
75 |
+
|
76 |
+
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
|
77 |
+
pretrained_window_size=(0, 0)):
|
78 |
+
|
79 |
+
super().__init__()
|
80 |
+
self.dim = dim
|
81 |
+
self.window_size = window_size # Wh, Ww
|
82 |
+
self.pretrained_window_size = pretrained_window_size
|
83 |
+
self.num_heads = num_heads
|
84 |
+
|
85 |
+
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
|
86 |
+
|
87 |
+
# mlp to generate continuous relative position bias
|
88 |
+
self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
|
89 |
+
nn.ReLU(inplace=True),
|
90 |
+
nn.Linear(512, num_heads, bias=False))
|
91 |
+
|
92 |
+
# get relative_coords_table
|
93 |
+
relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
|
94 |
+
relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
|
95 |
+
relative_coords_table = torch.stack(
|
96 |
+
torch.meshgrid([relative_coords_h,
|
97 |
+
relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
|
98 |
+
if pretrained_window_size[0] > 0:
|
99 |
+
relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
|
100 |
+
relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
|
101 |
+
else:
|
102 |
+
relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
|
103 |
+
relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
|
104 |
+
relative_coords_table *= 8 # normalize to -8, 8
|
105 |
+
relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
|
106 |
+
torch.abs(relative_coords_table) + 1.0) / np.log2(8)
|
107 |
+
|
108 |
+
self.register_buffer("relative_coords_table", relative_coords_table)
|
109 |
+
|
110 |
+
# get pair-wise relative position index for each token inside the window
|
111 |
+
coords_h = torch.arange(self.window_size[0])
|
112 |
+
coords_w = torch.arange(self.window_size[1])
|
113 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
114 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
115 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
116 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
117 |
+
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
|
118 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
119 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
120 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
121 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
122 |
+
|
123 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=False)
|
124 |
+
if qkv_bias:
|
125 |
+
self.q_bias = nn.Parameter(torch.zeros(dim))
|
126 |
+
self.v_bias = nn.Parameter(torch.zeros(dim))
|
127 |
+
else:
|
128 |
+
self.q_bias = None
|
129 |
+
self.v_bias = None
|
130 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
131 |
+
self.proj = nn.Linear(dim, dim)
|
132 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
133 |
+
self.softmax = nn.Softmax(dim=-1)
|
134 |
+
|
135 |
+
def forward(self, x, mask=None):
|
136 |
+
"""
|
137 |
+
Args:
|
138 |
+
x: input features with shape of (num_windows*B, N, C)
|
139 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
140 |
+
"""
|
141 |
+
B_, N, C = x.shape
|
142 |
+
qkv_bias = None
|
143 |
+
if self.q_bias is not None:
|
144 |
+
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
|
145 |
+
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
146 |
+
qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
147 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
148 |
+
|
149 |
+
# cosine attention
|
150 |
+
attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
|
151 |
+
logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale.device)).exp()
|
152 |
+
attn = attn * logit_scale
|
153 |
+
|
154 |
+
relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
|
155 |
+
relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
156 |
+
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
157 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
158 |
+
relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
|
159 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
160 |
+
|
161 |
+
if mask is not None:
|
162 |
+
nW = mask.shape[0]
|
163 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
164 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
165 |
+
attn = self.softmax(attn)
|
166 |
+
else:
|
167 |
+
attn = self.softmax(attn)
|
168 |
+
|
169 |
+
attn = self.attn_drop(attn)
|
170 |
+
|
171 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
172 |
+
x = self.proj(x)
|
173 |
+
x = self.proj_drop(x)
|
174 |
+
return x
|
175 |
+
|
176 |
+
def extra_repr(self) -> str:
|
177 |
+
return f'dim={self.dim}, window_size={self.window_size}, ' \
|
178 |
+
f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'
|
179 |
+
|
180 |
+
def flops(self, N):
|
181 |
+
# calculate flops for 1 window with token length of N
|
182 |
+
flops = 0
|
183 |
+
# qkv = self.qkv(x)
|
184 |
+
flops += N * self.dim * 3 * self.dim
|
185 |
+
# attn = (q @ k.transpose(-2, -1))
|
186 |
+
flops += self.num_heads * N * (self.dim // self.num_heads) * N
|
187 |
+
# x = (attn @ v)
|
188 |
+
flops += self.num_heads * N * N * (self.dim // self.num_heads)
|
189 |
+
# x = self.proj(x)
|
190 |
+
flops += N * self.dim * self.dim
|
191 |
+
return flops
|
192 |
+
|
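WindowAttention above uses SwinV2-style cosine attention: queries and keys are L2-normalised, so the raw scores are cosine similarities bounded in [-1, 1], and a learnable, clamped logit scale takes over the role of the fixed 1/sqrt(d) temperature. A condensed sketch of just that scoring step (not part of the diff; shapes are illustrative):

import torch
import torch.nn.functional as F

q = torch.randn(1, 4, 64, 32)   # (batch, heads, tokens, head_dim)
k = torch.randn(1, 4, 64, 32)

scores = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
print(float(scores.abs().max()) <= 1.0 + 1e-6)      # True: cosine similarities are bounded
logit_scale = torch.clamp(torch.log(10 * torch.ones(4, 1, 1)),
                          max=torch.log(torch.tensor(1. / 0.01))).exp()
attn = (scores * logit_scale).softmax(dim=-1)
print(attn.shape)                                    # torch.Size([1, 4, 64, 64])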
193 |
+
class SwinTransformerBlock(nn.Module):
|
194 |
+
r""" Swin Transformer Block.
|
195 |
+
Args:
|
196 |
+
dim (int): Number of input channels.
|
197 |
+
input_resolution (tuple[int]): Input resolution.
|
198 |
+
num_heads (int): Number of attention heads.
|
199 |
+
window_size (int): Window size.
|
200 |
+
shift_size (int): Shift size for SW-MSA.
|
201 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
202 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
203 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
204 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
205 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
206 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
207 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
208 |
+
pretrained_window_size (int): Window size in pre-training.
|
209 |
+
"""
|
210 |
+
|
211 |
+
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
|
212 |
+
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
|
213 |
+
act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
|
214 |
+
super().__init__()
|
215 |
+
self.dim = dim
|
216 |
+
self.input_resolution = input_resolution
|
217 |
+
self.num_heads = num_heads
|
218 |
+
self.window_size = window_size
|
219 |
+
self.shift_size = shift_size
|
220 |
+
self.mlp_ratio = mlp_ratio
|
221 |
+
if min(self.input_resolution) <= self.window_size:
|
222 |
+
# if window size is larger than input resolution, we don't partition windows
|
223 |
+
self.shift_size = 0
|
224 |
+
self.window_size = min(self.input_resolution)
|
225 |
+
assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
|
226 |
+
|
227 |
+
self.norm1 = norm_layer(dim)
|
228 |
+
self.attn = WindowAttention(
|
229 |
+
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
|
230 |
+
qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
|
231 |
+
pretrained_window_size=to_2tuple(pretrained_window_size))
|
232 |
+
|
233 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
234 |
+
self.norm2 = norm_layer(dim)
|
235 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
236 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
237 |
+
|
238 |
+
if self.shift_size > 0:
|
239 |
+
attn_mask = self.calculate_mask(self.input_resolution)
|
240 |
+
else:
|
241 |
+
attn_mask = None
|
242 |
+
|
243 |
+
self.register_buffer("attn_mask", attn_mask)
|
244 |
+
|
245 |
+
def calculate_mask(self, x_size):
|
246 |
+
# calculate attention mask for SW-MSA
|
247 |
+
H, W = x_size
|
248 |
+
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
|
249 |
+
h_slices = (slice(0, -self.window_size),
|
250 |
+
slice(-self.window_size, -self.shift_size),
|
251 |
+
slice(-self.shift_size, None))
|
252 |
+
w_slices = (slice(0, -self.window_size),
|
253 |
+
slice(-self.window_size, -self.shift_size),
|
254 |
+
slice(-self.shift_size, None))
|
255 |
+
cnt = 0
|
256 |
+
for h in h_slices:
|
257 |
+
for w in w_slices:
|
258 |
+
img_mask[:, h, w, :] = cnt
|
259 |
+
cnt += 1
|
260 |
+
|
261 |
+
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
|
262 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
263 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
264 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
265 |
+
|
266 |
+
return attn_mask
|
267 |
+
|
268 |
+
def forward(self, x, x_size):
|
269 |
+
H, W = x_size
|
270 |
+
B, L, C = x.shape
|
271 |
+
#assert L == H * W, "input feature has wrong size"
|
272 |
+
|
273 |
+
shortcut = x
|
274 |
+
x = x.view(B, H, W, C)
|
275 |
+
|
276 |
+
# cyclic shift
|
277 |
+
if self.shift_size > 0:
|
278 |
+
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
279 |
+
else:
|
280 |
+
shifted_x = x
|
281 |
+
|
282 |
+
# partition windows
|
283 |
+
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
284 |
+
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
|
285 |
+
|
286 |
+
# W-MSA/SW-MSA (compatible with testing on images whose spatial sizes are multiples of the window size)
|
287 |
+
if self.input_resolution == x_size:
|
288 |
+
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
|
289 |
+
else:
|
290 |
+
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
|
291 |
+
|
292 |
+
# merge windows
|
293 |
+
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
294 |
+
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
|
295 |
+
|
296 |
+
# reverse cyclic shift
|
297 |
+
if self.shift_size > 0:
|
298 |
+
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
299 |
+
else:
|
300 |
+
x = shifted_x
|
301 |
+
x = x.view(B, H * W, C)
|
302 |
+
x = shortcut + self.drop_path(self.norm1(x))
|
303 |
+
|
304 |
+
# FFN
|
305 |
+
x = x + self.drop_path(self.norm2(self.mlp(x)))
|
306 |
+
|
307 |
+
return x
|
308 |
+
|
309 |
+
def extra_repr(self) -> str:
|
310 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
|
311 |
+
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
|
312 |
+
|
313 |
+
def flops(self):
|
314 |
+
flops = 0
|
315 |
+
H, W = self.input_resolution
|
316 |
+
# norm1
|
317 |
+
flops += self.dim * H * W
|
318 |
+
# W-MSA/SW-MSA
|
319 |
+
nW = H * W / self.window_size / self.window_size
|
320 |
+
flops += nW * self.attn.flops(self.window_size * self.window_size)
|
321 |
+
# mlp
|
322 |
+
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
|
323 |
+
# norm2
|
324 |
+
flops += self.dim * H * W
|
325 |
+
return flops
|
326 |
+
|
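SwinTransformerBlock.forward above implements SW-MSA by cyclically rolling the feature map before windowing and rolling it back afterwards, with calculate_mask blocking attention between tokens that were not spatial neighbours before the roll. A tiny sketch of the roll itself (not part of the diff):

import torch

x = torch.arange(16).view(1, 4, 4, 1)
shifted = torch.roll(x, shifts=(-2, -2), dims=(1, 2))
restored = torch.roll(shifted, shifts=(2, 2), dims=(1, 2))
print(torch.equal(x, restored))   # True: the reverse shift undoes it exactly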
327 |
+
class PatchMerging(nn.Module):
|
328 |
+
r""" Patch Merging Layer.
|
329 |
+
Args:
|
330 |
+
input_resolution (tuple[int]): Resolution of input feature.
|
331 |
+
dim (int): Number of input channels.
|
332 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
333 |
+
"""
|
334 |
+
|
335 |
+
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
|
336 |
+
super().__init__()
|
337 |
+
self.input_resolution = input_resolution
|
338 |
+
self.dim = dim
|
339 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
340 |
+
self.norm = norm_layer(2 * dim)
|
341 |
+
|
342 |
+
def forward(self, x):
|
343 |
+
"""
|
344 |
+
x: B, H*W, C
|
345 |
+
"""
|
346 |
+
H, W = self.input_resolution
|
347 |
+
B, L, C = x.shape
|
348 |
+
assert L == H * W, "input feature has wrong size"
|
349 |
+
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
|
350 |
+
|
351 |
+
x = x.view(B, H, W, C)
|
352 |
+
|
353 |
+
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
354 |
+
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
355 |
+
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
356 |
+
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
357 |
+
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
358 |
+
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
359 |
+
|
360 |
+
x = self.reduction(x)
|
361 |
+
x = self.norm(x)
|
362 |
+
|
363 |
+
return x
|
364 |
+
|
365 |
+
def extra_repr(self) -> str:
|
366 |
+
return f"input_resolution={self.input_resolution}, dim={self.dim}"
|
367 |
+
|
368 |
+
def flops(self):
|
369 |
+
H, W = self.input_resolution
|
370 |
+
flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
|
371 |
+
flops += H * W * self.dim // 2
|
372 |
+
return flops
|
373 |
+
|
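PatchMerging above concatenates each 2x2 neighbourhood of tokens (giving 4*C channels) and linearly reduces it to 2*C, halving the spatial grid while doubling the channel width. A shape sketch (not part of the diff; sizes are arbitrary):

import torch
import torch.nn as nn

H, W, C = 8, 8, 96
merge = PatchMerging(input_resolution=(H, W), dim=C, norm_layer=nn.LayerNorm)
x = torch.randn(1, H * W, C)
print(merge(x).shape)   # torch.Size([1, 16, 192]) == (B, H/2 * W/2, 2*C)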
374 |
+
class BasicLayer(nn.Module):
|
375 |
+
""" A basic Swin Transformer layer for one stage.
|
376 |
+
Args:
|
377 |
+
dim (int): Number of input channels.
|
378 |
+
input_resolution (tuple[int]): Input resolution.
|
379 |
+
depth (int): Number of blocks.
|
380 |
+
num_heads (int): Number of attention heads.
|
381 |
+
window_size (int): Local window size.
|
382 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
383 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
384 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
385 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
386 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
387 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
388 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
389 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
390 |
+
pretrained_window_size (int): Local window size in pre-training.
|
391 |
+
"""
|
392 |
+
|
393 |
+
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
|
394 |
+
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
|
395 |
+
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
|
396 |
+
pretrained_window_size=0):
|
397 |
+
|
398 |
+
super().__init__()
|
399 |
+
self.dim = dim
|
400 |
+
self.input_resolution = input_resolution
|
401 |
+
self.depth = depth
|
402 |
+
self.use_checkpoint = use_checkpoint
|
403 |
+
|
404 |
+
# build blocks
|
405 |
+
self.blocks = nn.ModuleList([
|
406 |
+
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
|
407 |
+
num_heads=num_heads, window_size=window_size,
|
408 |
+
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
409 |
+
mlp_ratio=mlp_ratio,
|
410 |
+
qkv_bias=qkv_bias,
|
411 |
+
drop=drop, attn_drop=attn_drop,
|
412 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
413 |
+
norm_layer=norm_layer,
|
414 |
+
pretrained_window_size=pretrained_window_size)
|
415 |
+
for i in range(depth)])
|
416 |
+
|
417 |
+
# patch merging layer
|
418 |
+
if downsample is not None:
|
419 |
+
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
|
420 |
+
else:
|
421 |
+
self.downsample = None
|
422 |
+
|
423 |
+
def forward(self, x, x_size):
|
424 |
+
for blk in self.blocks:
|
425 |
+
if self.use_checkpoint:
|
426 |
+
x = checkpoint.checkpoint(blk, x, x_size)
|
427 |
+
else:
|
428 |
+
x = blk(x, x_size)
|
429 |
+
if self.downsample is not None:
|
430 |
+
x = self.downsample(x)
|
431 |
+
return x
|
432 |
+
|
433 |
+
def extra_repr(self) -> str:
|
434 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
|
435 |
+
|
436 |
+
def flops(self):
|
437 |
+
flops = 0
|
438 |
+
for blk in self.blocks:
|
439 |
+
flops += blk.flops()
|
440 |
+
if self.downsample is not None:
|
441 |
+
flops += self.downsample.flops()
|
442 |
+
return flops
|
443 |
+
|
444 |
+
def _init_respostnorm(self):
|
445 |
+
for blk in self.blocks:
|
446 |
+
nn.init.constant_(blk.norm1.bias, 0)
|
447 |
+
nn.init.constant_(blk.norm1.weight, 0)
|
448 |
+
nn.init.constant_(blk.norm2.bias, 0)
|
449 |
+
nn.init.constant_(blk.norm2.weight, 0)
|
450 |
+
|
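_init_respostnorm above zero-initialises the weights and biases of norm1/norm2. Because this variant applies the norm to the residual branch's output (res-post-norm), the zeroed norms make every branch contribute nothing at initialisation, so a fresh block starts as the identity map. A minimal illustration (not part of the diff):

import torch
import torch.nn as nn

norm = nn.LayerNorm(96)
nn.init.constant_(norm.weight, 0)
nn.init.constant_(norm.bias, 0)

x = torch.randn(2, 64, 96)
print(torch.equal(x + norm(x), x))   # True: the zeroed norm nulls the branch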
451 |
+
class PatchEmbed(nn.Module):
|
452 |
+
r""" Image to Patch Embedding
|
453 |
+
Args:
|
454 |
+
img_size (int): Image size. Default: 224.
|
455 |
+
patch_size (int): Patch token size. Default: 4.
|
456 |
+
in_chans (int): Number of input image channels. Default: 3.
|
457 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
458 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
459 |
+
"""
|
460 |
+
|
461 |
+
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
462 |
+
super().__init__()
|
463 |
+
img_size = to_2tuple(img_size)
|
464 |
+
patch_size = to_2tuple(patch_size)
|
465 |
+
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
|
466 |
+
self.img_size = img_size
|
467 |
+
self.patch_size = patch_size
|
468 |
+
self.patches_resolution = patches_resolution
|
469 |
+
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
470 |
+
|
471 |
+
self.in_chans = in_chans
|
472 |
+
self.embed_dim = embed_dim
|
473 |
+
|
474 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
475 |
+
if norm_layer is not None:
|
476 |
+
self.norm = norm_layer(embed_dim)
|
477 |
+
else:
|
478 |
+
self.norm = None
|
479 |
+
|
480 |
+
def forward(self, x):
|
481 |
+
B, C, H, W = x.shape
|
482 |
+
# FIXME look at relaxing size constraints
|
483 |
+
# assert H == self.img_size[0] and W == self.img_size[1],
|
484 |
+
# f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
485 |
+
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
|
486 |
+
if self.norm is not None:
|
487 |
+
x = self.norm(x)
|
488 |
+
return x
|
489 |
+
|
490 |
+
def flops(self):
|
491 |
+
Ho, Wo = self.patches_resolution
|
492 |
+
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
|
493 |
+
if self.norm is not None:
|
494 |
+
flops += Ho * Wo * self.embed_dim
|
495 |
+
return flops
|
496 |
+
|
497 |
+
class RSTB(nn.Module):
|
498 |
+
"""Residual Swin Transformer Block (RSTB).
|
499 |
+
|
500 |
+
Args:
|
501 |
+
dim (int): Number of input channels.
|
502 |
+
input_resolution (tuple[int]): Input resolution.
|
503 |
+
depth (int): Number of blocks.
|
504 |
+
num_heads (int): Number of attention heads.
|
505 |
+
window_size (int): Local window size.
|
506 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
507 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
508 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
509 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
510 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
511 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
512 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
513 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
514 |
+
img_size: Input image size.
|
515 |
+
patch_size: Patch size.
|
516 |
+
resi_connection: The convolutional block before residual connection.
|
517 |
+
"""
|
518 |
+
|
519 |
+
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
|
520 |
+
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
|
521 |
+
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
|
522 |
+
img_size=224, patch_size=4, resi_connection='1conv'):
|
523 |
+
super(RSTB, self).__init__()
|
524 |
+
|
525 |
+
self.dim = dim
|
526 |
+
self.input_resolution = input_resolution
|
527 |
+
|
528 |
+
self.residual_group = BasicLayer(dim=dim,
|
529 |
+
input_resolution=input_resolution,
|
530 |
+
depth=depth,
|
531 |
+
num_heads=num_heads,
|
532 |
+
window_size=window_size,
|
533 |
+
mlp_ratio=mlp_ratio,
|
534 |
+
qkv_bias=qkv_bias,
|
535 |
+
drop=drop, attn_drop=attn_drop,
|
536 |
+
drop_path=drop_path,
|
537 |
+
norm_layer=norm_layer,
|
538 |
+
downsample=downsample,
|
539 |
+
use_checkpoint=use_checkpoint)
|
540 |
+
|
541 |
+
if resi_connection == '1conv':
|
542 |
+
self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
|
543 |
+
elif resi_connection == '3conv':
|
544 |
+
# to save parameters and memory
|
545 |
+
self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
546 |
+
nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
|
547 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
548 |
+
nn.Conv2d(dim // 4, dim, 3, 1, 1))
|
549 |
+
|
550 |
+
self.patch_embed = PatchEmbed(
|
551 |
+
img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
|
552 |
+
norm_layer=None)
|
553 |
+
|
554 |
+
self.patch_unembed = PatchUnEmbed(
|
555 |
+
img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
|
556 |
+
norm_layer=None)
|
557 |
+
|
558 |
+
def forward(self, x, x_size):
|
559 |
+
return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
|
560 |
+
|
561 |
+
def flops(self):
|
562 |
+
flops = 0
|
563 |
+
flops += self.residual_group.flops()
|
564 |
+
H, W = self.input_resolution
|
565 |
+
flops += H * W * self.dim * self.dim * 9
|
566 |
+
flops += self.patch_embed.flops()
|
567 |
+
flops += self.patch_unembed.flops()
|
568 |
+
|
569 |
+
return flops
|
570 |
+
|
571 |
+
class PatchUnEmbed(nn.Module):
|
572 |
+
r""" Image to Patch Unembedding
|
573 |
+
|
574 |
+
Args:
|
575 |
+
img_size (int): Image size. Default: 224.
|
576 |
+
patch_size (int): Patch token size. Default: 4.
|
577 |
+
in_chans (int): Number of input image channels. Default: 3.
|
578 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
579 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
580 |
+
"""
|
581 |
+
|
582 |
+
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
583 |
+
super().__init__()
|
584 |
+
img_size = to_2tuple(img_size)
|
585 |
+
patch_size = to_2tuple(patch_size)
|
586 |
+
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
|
587 |
+
self.img_size = img_size
|
588 |
+
self.patch_size = patch_size
|
589 |
+
self.patches_resolution = patches_resolution
|
590 |
+
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
591 |
+
|
592 |
+
self.in_chans = in_chans
|
593 |
+
self.embed_dim = embed_dim
|
594 |
+
|
595 |
+
def forward(self, x, x_size):
|
596 |
+
B, HW, C = x.shape
|
597 |
+
x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1])  # B C Ph Pw
|
598 |
+
return x
|
599 |
+
|
600 |
+
def flops(self):
|
601 |
+
flops = 0
|
602 |
+
return flops
|
603 |
+
|
604 |
+
|
605 |
+
class Upsample(nn.Sequential):
|
606 |
+
"""Upsample module.
|
607 |
+
|
608 |
+
Args:
|
609 |
+
scale (int): Scale factor. Supported scales: 2^n and 3.
|
610 |
+
num_feat (int): Channel number of intermediate features.
|
611 |
+
"""
|
612 |
+
|
613 |
+
def __init__(self, scale, num_feat):
|
614 |
+
m = []
|
615 |
+
if (scale & (scale - 1)) == 0: # scale = 2^n
|
616 |
+
for _ in range(int(math.log(scale, 2))):
|
617 |
+
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
|
618 |
+
m.append(nn.PixelShuffle(2))
|
619 |
+
elif scale == 3:
|
620 |
+
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
|
621 |
+
m.append(nn.PixelShuffle(3))
|
622 |
+
else:
|
623 |
+
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
|
624 |
+
super(Upsample, self).__init__(*m)
|
625 |
+
|
626 |
+
class Upsample_hf(nn.Sequential):
|
627 |
+
"""Upsample module.
|
628 |
+
|
629 |
+
Args:
|
630 |
+
scale (int): Scale factor. Supported scales: 2^n and 3.
|
631 |
+
num_feat (int): Channel number of intermediate features.
|
632 |
+
"""
|
633 |
+
|
634 |
+
def __init__(self, scale, num_feat):
|
635 |
+
m = []
|
636 |
+
if (scale & (scale - 1)) == 0: # scale = 2^n
|
637 |
+
for _ in range(int(math.log(scale, 2))):
|
638 |
+
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
|
639 |
+
m.append(nn.PixelShuffle(2))
|
640 |
+
elif scale == 3:
|
641 |
+
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
|
642 |
+
m.append(nn.PixelShuffle(3))
|
643 |
+
else:
|
644 |
+
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
|
645 |
+
super(Upsample_hf, self).__init__(*m)
|
646 |
+
|
647 |
+
|
648 |
+
class UpsampleOneStep(nn.Sequential):
|
649 |
+
"""UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
|
650 |
+
Used in lightweight SR to save parameters.
|
651 |
+
|
652 |
+
Args:
|
653 |
+
scale (int): Scale factor. Supported scales: 2^n and 3.
|
654 |
+
num_feat (int): Channel number of intermediate features.
|
655 |
+
|
656 |
+
"""
|
657 |
+
|
658 |
+
def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
|
659 |
+
self.num_feat = num_feat
|
660 |
+
self.input_resolution = input_resolution
|
661 |
+
m = []
|
662 |
+
m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
|
663 |
+
m.append(nn.PixelShuffle(scale))
|
664 |
+
super(UpsampleOneStep, self).__init__(*m)
|
665 |
+
|
666 |
+
def flops(self):
|
667 |
+
H, W = self.input_resolution
|
668 |
+
flops = H * W * self.num_feat * 3 * 9
|
669 |
+
return flops
|
670 |
+
|
671 |
+
|
672 |
+
|
673 |
+
class Swin2SR(nn.Module):
|
674 |
+
r""" Swin2SR
|
675 |
+
A PyTorch implementation of `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`.
|
676 |
+
|
677 |
+
Args:
|
678 |
+
img_size (int | tuple(int)): Input image size. Default 64
|
679 |
+
patch_size (int | tuple(int)): Patch size. Default: 1
|
680 |
+
in_chans (int): Number of input image channels. Default: 3
|
681 |
+
embed_dim (int): Patch embedding dimension. Default: 96
|
682 |
+
depths (tuple(int)): Depth of each Swin Transformer layer.
|
683 |
+
num_heads (tuple(int)): Number of attention heads in different layers.
|
684 |
+
window_size (int): Window size. Default: 7
|
685 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
|
686 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
687 |
+
drop_rate (float): Dropout rate. Default: 0
|
688 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0
|
689 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.1
|
690 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
691 |
+
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
|
692 |
+
patch_norm (bool): If True, add normalization after patch embedding. Default: True
|
693 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
|
694 |
+
upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
|
695 |
+
img_range: Image range. 1. or 255.
|
696 |
+
upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
|
697 |
+
resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
|
698 |
+
"""
|
699 |
+
|
700 |
+
def __init__(self, img_size=64, patch_size=1, in_chans=3,
|
701 |
+
embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
|
702 |
+
window_size=7, mlp_ratio=4., qkv_bias=True,
|
703 |
+
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
704 |
+
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
705 |
+
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
|
706 |
+
**kwargs):
|
707 |
+
super(Swin2SR, self).__init__()
|
708 |
+
num_in_ch = in_chans
|
709 |
+
num_out_ch = in_chans
|
710 |
+
num_feat = 64
|
711 |
+
self.img_range = img_range
|
712 |
+
if in_chans == 3:
|
713 |
+
rgb_mean = (0.4488, 0.4371, 0.4040)
|
714 |
+
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
|
715 |
+
else:
|
716 |
+
self.mean = torch.zeros(1, 1, 1, 1)
|
717 |
+
self.upscale = upscale
|
718 |
+
self.upsampler = upsampler
|
719 |
+
self.window_size = window_size
|
720 |
+
|
721 |
+
#####################################################################################################
|
722 |
+
################################### 1, shallow feature extraction ###################################
|
723 |
+
self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
|
724 |
+
|
725 |
+
#####################################################################################################
|
726 |
+
################################### 2, deep feature extraction ######################################
|
727 |
+
self.num_layers = len(depths)
|
728 |
+
self.embed_dim = embed_dim
|
729 |
+
self.ape = ape
|
730 |
+
self.patch_norm = patch_norm
|
731 |
+
self.num_features = embed_dim
|
732 |
+
self.mlp_ratio = mlp_ratio
|
733 |
+
|
734 |
+
# split image into non-overlapping patches
|
735 |
+
self.patch_embed = PatchEmbed(
|
736 |
+
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
|
737 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
738 |
+
num_patches = self.patch_embed.num_patches
|
739 |
+
patches_resolution = self.patch_embed.patches_resolution
|
740 |
+
self.patches_resolution = patches_resolution
|
741 |
+
|
742 |
+
# merge non-overlapping patches into image
|
743 |
+
self.patch_unembed = PatchUnEmbed(
|
744 |
+
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
|
745 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
746 |
+
|
747 |
+
# absolute position embedding
|
748 |
+
if self.ape:
|
749 |
+
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
|
750 |
+
trunc_normal_(self.absolute_pos_embed, std=.02)
|
751 |
+
|
752 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
753 |
+
|
754 |
+
# stochastic depth
|
755 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
756 |
+
|
757 |
+
# build Residual Swin Transformer blocks (RSTB)
|
758 |
+
self.layers = nn.ModuleList()
|
759 |
+
for i_layer in range(self.num_layers):
|
760 |
+
layer = RSTB(dim=embed_dim,
|
761 |
+
input_resolution=(patches_resolution[0],
|
762 |
+
patches_resolution[1]),
|
763 |
+
depth=depths[i_layer],
|
764 |
+
num_heads=num_heads[i_layer],
|
765 |
+
window_size=window_size,
|
766 |
+
mlp_ratio=self.mlp_ratio,
|
767 |
+
qkv_bias=qkv_bias,
|
768 |
+
drop=drop_rate, attn_drop=attn_drop_rate,
|
769 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
|
770 |
+
norm_layer=norm_layer,
|
771 |
+
downsample=None,
|
772 |
+
use_checkpoint=use_checkpoint,
|
773 |
+
img_size=img_size,
|
774 |
+
patch_size=patch_size,
|
775 |
+
resi_connection=resi_connection
|
776 |
+
|
777 |
+
)
|
778 |
+
self.layers.append(layer)
|
779 |
+
|
780 |
+
if self.upsampler == 'pixelshuffle_hf':
|
781 |
+
self.layers_hf = nn.ModuleList()
|
782 |
+
for i_layer in range(self.num_layers):
|
783 |
+
layer = RSTB(dim=embed_dim,
|
784 |
+
input_resolution=(patches_resolution[0],
|
785 |
+
patches_resolution[1]),
|
786 |
+
depth=depths[i_layer],
|
787 |
+
num_heads=num_heads[i_layer],
|
788 |
+
window_size=window_size,
|
789 |
+
mlp_ratio=self.mlp_ratio,
|
790 |
+
qkv_bias=qkv_bias,
|
791 |
+
drop=drop_rate, attn_drop=attn_drop_rate,
|
792 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
|
793 |
+
norm_layer=norm_layer,
|
794 |
+
downsample=None,
|
795 |
+
use_checkpoint=use_checkpoint,
|
796 |
+
img_size=img_size,
|
797 |
+
patch_size=patch_size,
|
798 |
+
resi_connection=resi_connection
|
799 |
+
|
800 |
+
)
|
801 |
+
self.layers_hf.append(layer)
|
802 |
+
|
803 |
+
self.norm = norm_layer(self.num_features)
|
804 |
+
|
805 |
+
# build the last conv layer in deep feature extraction
|
806 |
+
if resi_connection == '1conv':
|
807 |
+
self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
|
808 |
+
elif resi_connection == '3conv':
|
809 |
+
# to save parameters and memory
|
810 |
+
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
|
811 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
812 |
+
nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
|
813 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
814 |
+
nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
|
815 |
+
|
816 |
+
#####################################################################################################
|
817 |
+
################################ 3, high quality image reconstruction ################################
|
818 |
+
if self.upsampler == 'pixelshuffle':
|
819 |
+
# for classical SR
|
820 |
+
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
821 |
+
nn.LeakyReLU(inplace=True))
|
822 |
+
self.upsample = Upsample(upscale, num_feat)
|
823 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
824 |
+
elif self.upsampler == 'pixelshuffle_aux':
|
825 |
+
self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
|
826 |
+
self.conv_before_upsample = nn.Sequential(
|
827 |
+
nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
828 |
+
nn.LeakyReLU(inplace=True))
|
829 |
+
self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
830 |
+
self.conv_after_aux = nn.Sequential(
|
831 |
+
nn.Conv2d(3, num_feat, 3, 1, 1),
|
832 |
+
nn.LeakyReLU(inplace=True))
|
833 |
+
self.upsample = Upsample(upscale, num_feat)
|
834 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
835 |
+
|
836 |
+
elif self.upsampler == 'pixelshuffle_hf':
|
837 |
+
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
838 |
+
nn.LeakyReLU(inplace=True))
|
839 |
+
self.upsample = Upsample(upscale, num_feat)
|
840 |
+
self.upsample_hf = Upsample_hf(upscale, num_feat)
|
841 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
842 |
+
self.conv_first_hf = nn.Sequential(nn.Conv2d(num_feat, embed_dim, 3, 1, 1),
|
843 |
+
nn.LeakyReLU(inplace=True))
|
844 |
+
self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
|
845 |
+
self.conv_before_upsample_hf = nn.Sequential(
|
846 |
+
nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
847 |
+
nn.LeakyReLU(inplace=True))
|
848 |
+
self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
849 |
+
|
850 |
+
elif self.upsampler == 'pixelshuffledirect':
|
851 |
+
# for lightweight SR (to save parameters)
|
852 |
+
self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
|
853 |
+
(patches_resolution[0], patches_resolution[1]))
|
854 |
+
elif self.upsampler == 'nearest+conv':
|
855 |
+
# for real-world SR (less artifacts)
|
856 |
+
assert self.upscale == 4, 'only support x4 now.'
|
857 |
+
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
858 |
+
nn.LeakyReLU(inplace=True))
|
859 |
+
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
860 |
+
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
861 |
+
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
862 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
863 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
864 |
+
else:
|
865 |
+
# for image denoising and JPEG compression artifact reduction
|
866 |
+
self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
|
867 |
+
|
868 |
+
self.apply(self._init_weights)
|
869 |
+
|
870 |
+
def _init_weights(self, m):
|
871 |
+
if isinstance(m, nn.Linear):
|
872 |
+
trunc_normal_(m.weight, std=.02)
|
873 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
874 |
+
nn.init.constant_(m.bias, 0)
|
875 |
+
elif isinstance(m, nn.LayerNorm):
|
876 |
+
nn.init.constant_(m.bias, 0)
|
877 |
+
nn.init.constant_(m.weight, 1.0)
|
878 |
+
|
879 |
+
@torch.jit.ignore
|
880 |
+
def no_weight_decay(self):
|
881 |
+
return {'absolute_pos_embed'}
|
882 |
+
|
883 |
+
@torch.jit.ignore
|
884 |
+
def no_weight_decay_keywords(self):
|
885 |
+
return {'relative_position_bias_table'}
|
886 |
+
|
887 |
+
def check_image_size(self, x):
|
888 |
+
_, _, h, w = x.size()
|
889 |
+
mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
|
890 |
+
mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
|
891 |
+
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
|
892 |
+
return x
|
893 |
+
|
894 |
+
def forward_features(self, x):
|
895 |
+
x_size = (x.shape[2], x.shape[3])
|
896 |
+
x = self.patch_embed(x)
|
897 |
+
if self.ape:
|
898 |
+
x = x + self.absolute_pos_embed
|
899 |
+
x = self.pos_drop(x)
|
900 |
+
|
901 |
+
for layer in self.layers:
|
902 |
+
x = layer(x, x_size)
|
903 |
+
|
904 |
+
x = self.norm(x) # B L C
|
905 |
+
x = self.patch_unembed(x, x_size)
|
906 |
+
|
907 |
+
return x
|
908 |
+
|
909 |
+
def forward_features_hf(self, x):
|
910 |
+
x_size = (x.shape[2], x.shape[3])
|
911 |
+
x = self.patch_embed(x)
|
912 |
+
if self.ape:
|
913 |
+
x = x + self.absolute_pos_embed
|
914 |
+
x = self.pos_drop(x)
|
915 |
+
|
916 |
+
for layer in self.layers_hf:
|
917 |
+
x = layer(x, x_size)
|
918 |
+
|
919 |
+
x = self.norm(x) # B L C
|
920 |
+
x = self.patch_unembed(x, x_size)
|
921 |
+
|
922 |
+
return x
|
923 |
+
|
924 |
+
def forward(self, x):
|
925 |
+
H, W = x.shape[2:]
|
926 |
+
x = self.check_image_size(x)
|
927 |
+
|
928 |
+
self.mean = self.mean.type_as(x)
|
929 |
+
x = (x - self.mean) * self.img_range
|
930 |
+
|
931 |
+
if self.upsampler == 'pixelshuffle':
|
932 |
+
# for classical SR
|
933 |
+
x = self.conv_first(x)
|
934 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
935 |
+
x = self.conv_before_upsample(x)
|
936 |
+
x = self.conv_last(self.upsample(x))
|
937 |
+
elif self.upsampler == 'pixelshuffle_aux':
|
938 |
+
bicubic = F.interpolate(x, size=(H * self.upscale, W * self.upscale), mode='bicubic', align_corners=False)
|
939 |
+
bicubic = self.conv_bicubic(bicubic)
|
940 |
+
x = self.conv_first(x)
|
941 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
942 |
+
x = self.conv_before_upsample(x)
|
943 |
+
aux = self.conv_aux(x) # b, 3, LR_H, LR_W
|
944 |
+
x = self.conv_after_aux(aux)
|
945 |
+
x = self.upsample(x)[:, :, :H * self.upscale, :W * self.upscale] + bicubic[:, :, :H * self.upscale, :W * self.upscale]
|
946 |
+
x = self.conv_last(x)
|
947 |
+
aux = aux / self.img_range + self.mean
|
948 |
+
elif self.upsampler == 'pixelshuffle_hf':
|
949 |
+
# for classical SR with HF
|
950 |
+
x = self.conv_first(x)
|
951 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
952 |
+
x_before = self.conv_before_upsample(x)
|
953 |
+
x_out = self.conv_last(self.upsample(x_before))
|
954 |
+
|
955 |
+
x_hf = self.conv_first_hf(x_before)
|
956 |
+
x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
|
957 |
+
x_hf = self.conv_before_upsample_hf(x_hf)
|
958 |
+
x_hf = self.conv_last_hf(self.upsample_hf(x_hf))
|
959 |
+
x = x_out + x_hf
|
960 |
+
x_hf = x_hf / self.img_range + self.mean
|
961 |
+
|
962 |
+
elif self.upsampler == 'pixelshuffledirect':
|
963 |
+
# for lightweight SR
|
964 |
+
x = self.conv_first(x)
|
965 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
966 |
+
x = self.upsample(x)
|
967 |
+
elif self.upsampler == 'nearest+conv':
|
968 |
+
# for real-world SR
|
969 |
+
x = self.conv_first(x)
|
970 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
971 |
+
x = self.conv_before_upsample(x)
|
972 |
+
x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
|
973 |
+
x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
|
974 |
+
x = self.conv_last(self.lrelu(self.conv_hr(x)))
|
975 |
+
else:
|
976 |
+
# for image denoising and JPEG compression artifact reduction
|
977 |
+
x_first = self.conv_first(x)
|
978 |
+
res = self.conv_after_body(self.forward_features(x_first)) + x_first
|
979 |
+
x = x + self.conv_last(res)
|
980 |
+
|
981 |
+
x = x / self.img_range + self.mean
|
982 |
+
if self.upsampler == "pixelshuffle_aux":
|
983 |
+
return x[:, :, :H*self.upscale, :W*self.upscale], aux
|
984 |
+
|
985 |
+
elif self.upsampler == "pixelshuffle_hf":
|
986 |
+
x_out = x_out / self.img_range + self.mean
|
987 |
+
return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]
|
988 |
+
|
989 |
+
else:
|
990 |
+
return x[:, :, :H*self.upscale, :W*self.upscale]
|
991 |
+
|
992 |
+
def flops(self):
|
993 |
+
flops = 0
|
994 |
+
H, W = self.patches_resolution
|
995 |
+
flops += H * W * 3 * self.embed_dim * 9
|
996 |
+
flops += self.patch_embed.flops()
|
997 |
+
for layer in self.layers:
|
998 |
+
flops += layer.flops()
|
999 |
+
flops += H * W * 3 * self.embed_dim * self.embed_dim
|
1000 |
+
flops += self.upsample.flops()
|
1001 |
+
return flops
|
1002 |
+
|
1003 |
+
|
1004 |
+
if __name__ == '__main__':
|
1005 |
+
upscale = 4
|
1006 |
+
window_size = 8
|
1007 |
+
height = (1024 // upscale // window_size + 1) * window_size
|
1008 |
+
width = (720 // upscale // window_size + 1) * window_size
|
1009 |
+
model = Swin2SR(upscale=2, img_size=(height, width),
|
1010 |
+
window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
|
1011 |
+
embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
|
1012 |
+
print(model)
|
1013 |
+
print(height, width, model.flops() / 1e9)
|
1014 |
+
|
1015 |
+
x = torch.randn((1, 3, height, width))
|
1016 |
+
x = model(x)
|
1017 |
+
print(x.shape)
|
extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js
ADDED
@@ -0,0 +1,776 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
onUiLoaded(async() => {
|
2 |
+
const elementIDs = {
|
3 |
+
img2imgTabs: "#mode_img2img .tab-nav",
|
4 |
+
inpaint: "#img2maskimg",
|
5 |
+
inpaintSketch: "#inpaint_sketch",
|
6 |
+
rangeGroup: "#img2img_column_size",
|
7 |
+
sketch: "#img2img_sketch"
|
8 |
+
};
|
9 |
+
const tabNameToElementId = {
|
10 |
+
"Inpaint sketch": elementIDs.inpaintSketch,
|
11 |
+
"Inpaint": elementIDs.inpaint,
|
12 |
+
"Sketch": elementIDs.sketch
|
13 |
+
};
|
14 |
+
|
15 |
+
// Helper functions
|
16 |
+
// Get active tab
|
17 |
+
function getActiveTab(elements, all = false) {
|
18 |
+
const tabs = elements.img2imgTabs.querySelectorAll("button");
|
19 |
+
|
20 |
+
if (all) return tabs;
|
21 |
+
|
22 |
+
for (let tab of tabs) {
|
23 |
+
if (tab.classList.contains("selected")) {
|
24 |
+
return tab;
|
25 |
+
}
|
26 |
+
}
|
27 |
+
}
|
28 |
+
|
29 |
+
// Get tab ID
|
30 |
+
function getTabId(elements) {
|
31 |
+
const activeTab = getActiveTab(elements);
|
32 |
+
return tabNameToElementId[activeTab.innerText];
|
33 |
+
}
|
34 |
+
|
35 |
+
// Wait until opts loaded
|
36 |
+
async function waitForOpts() {
|
37 |
+
for (;;) {
|
38 |
+
if (window.opts && Object.keys(window.opts).length) {
|
39 |
+
return window.opts;
|
40 |
+
}
|
41 |
+
await new Promise(resolve => setTimeout(resolve, 100));
|
42 |
+
}
|
43 |
+
}
|
44 |
+
|
45 |
+
// Function for defining the "Ctrl", "Shift" and "Alt" keys
|
46 |
+
function isModifierKey(event, key) {
|
47 |
+
switch (key) {
|
48 |
+
case "Ctrl":
|
49 |
+
return event.ctrlKey;
|
50 |
+
case "Shift":
|
51 |
+
return event.shiftKey;
|
52 |
+
case "Alt":
|
53 |
+
return event.altKey;
|
54 |
+
default:
|
55 |
+
return false;
|
56 |
+
}
|
57 |
+
}
|
58 |
+
|
59 |
+
// Check if hotkey is valid
|
60 |
+
function isValidHotkey(value) {
|
61 |
+
const specialKeys = ["Ctrl", "Alt", "Shift", "Disable"];
|
62 |
+
return (
|
63 |
+
(typeof value === "string" &&
|
64 |
+
value.length === 1 &&
|
65 |
+
/[a-z]/i.test(value)) ||
|
66 |
+
specialKeys.includes(value)
|
67 |
+
);
|
68 |
+
}
|
69 |
+
|
70 |
+
// Normalize hotkey
|
71 |
+
function normalizeHotkey(hotkey) {
|
72 |
+
return hotkey.length === 1 ? "Key" + hotkey.toUpperCase() : hotkey;
|
73 |
+
}
|
74 |
+
|
75 |
+
// Format hotkey for display
|
76 |
+
function formatHotkeyForDisplay(hotkey) {
|
77 |
+
return hotkey.startsWith("Key") ? hotkey.slice(3) : hotkey;
|
78 |
+
}
|
79 |
+
|
80 |
+
// Create hotkey configuration with the provided options
|
81 |
+
function createHotkeyConfig(defaultHotkeysConfig, hotkeysConfigOpts) {
|
82 |
+
const result = {}; // Resulting hotkey configuration
|
83 |
+
const usedKeys = new Set(); // Set of used hotkeys
|
84 |
+
|
85 |
+
// Iterate through defaultHotkeysConfig keys
|
86 |
+
for (const key in defaultHotkeysConfig) {
|
87 |
+
const userValue = hotkeysConfigOpts[key]; // User-provided hotkey value
|
88 |
+
const defaultValue = defaultHotkeysConfig[key]; // Default hotkey value
|
89 |
+
|
90 |
+
// Apply appropriate value for undefined, boolean, or object userValue
|
91 |
+
if (
|
92 |
+
userValue === undefined ||
|
93 |
+
typeof userValue === "boolean" ||
|
94 |
+
typeof userValue === "object" ||
|
95 |
+
userValue === "disable"
|
96 |
+
) {
|
97 |
+
result[key] =
|
98 |
+
userValue === undefined ? defaultValue : userValue;
|
99 |
+
} else if (isValidHotkey(userValue)) {
|
100 |
+
const normalizedUserValue = normalizeHotkey(userValue);
|
101 |
+
|
102 |
+
// Check for conflicting hotkeys
|
103 |
+
if (!usedKeys.has(normalizedUserValue)) {
|
104 |
+
usedKeys.add(normalizedUserValue);
|
105 |
+
result[key] = normalizedUserValue;
|
106 |
+
} else {
|
107 |
+
console.error(
|
108 |
+
`Hotkey: ${formatHotkeyForDisplay(
|
109 |
+
userValue
|
110 |
+
)} for ${key} is repeated and conflicts with another hotkey. The default hotkey is used: ${formatHotkeyForDisplay(
|
111 |
+
defaultValue
|
112 |
+
)}`
|
113 |
+
);
|
114 |
+
result[key] = defaultValue;
|
115 |
+
}
|
116 |
+
} else {
|
117 |
+
console.error(
|
118 |
+
`Hotkey: ${formatHotkeyForDisplay(
|
119 |
+
userValue
|
120 |
+
)} for ${key} is not valid. The default hotkey is used: ${formatHotkeyForDisplay(
|
121 |
+
defaultValue
|
122 |
+
)}`
|
123 |
+
);
|
124 |
+
result[key] = defaultValue;
|
125 |
+
}
|
126 |
+
}
|
127 |
+
|
128 |
+
return result;
|
129 |
+
}
|
130 |
+
|
131 |
+
// Disables functions in the config object based on the provided list of function names
|
132 |
+
function disableFunctions(config, disabledFunctions) {
|
133 |
+
// Bind the hasOwnProperty method to the functionMap object to avoid errors
|
134 |
+
const hasOwnProperty =
|
135 |
+
Object.prototype.hasOwnProperty.bind(functionMap);
|
136 |
+
|
137 |
+
// Loop through the disabledFunctions array and disable the corresponding functions in the config object
|
138 |
+
disabledFunctions.forEach(funcName => {
|
139 |
+
if (hasOwnProperty(funcName)) {
|
140 |
+
const key = functionMap[funcName];
|
141 |
+
config[key] = "disable";
|
142 |
+
}
|
143 |
+
});
|
144 |
+
|
145 |
+
// Return the updated config object
|
146 |
+
return config;
|
147 |
+
}
|
148 |
+
|
149 |
+
/**
|
150 |
+
* The restoreImgRedMask function displays a red mask around an image to indicate the aspect ratio.
|
151 |
+
* If the image display property is set to 'none', the mask breaks. To fix this, the function
|
152 |
+
* temporarily sets the display property to 'block' and then hides the mask again after 300 milliseconds
|
153 |
+
* to avoid breaking the canvas. Additionally, the function adjusts the mask to work correctly on
|
154 |
+
* very long images.
|
155 |
+
*/
|
156 |
+
function restoreImgRedMask(elements) {
|
157 |
+
const mainTabId = getTabId(elements);
|
158 |
+
|
159 |
+
if (!mainTabId) return;
|
160 |
+
|
161 |
+
const mainTab = gradioApp().querySelector(mainTabId);
|
162 |
+
const img = mainTab.querySelector("img");
|
163 |
+
const imageARPreview = gradioApp().querySelector("#imageARPreview");
|
164 |
+
|
165 |
+
if (!img || !imageARPreview) return;
|
166 |
+
|
167 |
+
imageARPreview.style.transform = "";
|
168 |
+
if (parseFloat(mainTab.style.width) > 865) {
|
169 |
+
const transformString = mainTab.style.transform;
|
170 |
+
const scaleMatch = transformString.match(
|
171 |
+
/scale\(([-+]?[0-9]*\.?[0-9]+)\)/
|
172 |
+
);
|
173 |
+
let zoom = 1; // default zoom
|
174 |
+
|
175 |
+
if (scaleMatch && scaleMatch[1]) {
|
176 |
+
zoom = Number(scaleMatch[1]);
|
177 |
+
}
|
178 |
+
|
179 |
+
imageARPreview.style.transformOrigin = "0 0";
|
180 |
+
imageARPreview.style.transform = `scale(${zoom})`;
|
181 |
+
}
|
182 |
+
|
183 |
+
if (img.style.display !== "none") return;
|
184 |
+
|
185 |
+
img.style.display = "block";
|
186 |
+
|
187 |
+
setTimeout(() => {
|
188 |
+
img.style.display = "none";
|
189 |
+
}, 400);
|
190 |
+
}
|
191 |
+
|
192 |
+
const hotkeysConfigOpts = await waitForOpts();
|
193 |
+
|
194 |
+
// Default config
|
195 |
+
const defaultHotkeysConfig = {
|
196 |
+
canvas_hotkey_zoom: "Alt",
|
197 |
+
canvas_hotkey_adjust: "Ctrl",
|
198 |
+
canvas_hotkey_reset: "KeyR",
|
199 |
+
canvas_hotkey_fullscreen: "KeyS",
|
200 |
+
canvas_hotkey_move: "KeyF",
|
201 |
+
canvas_hotkey_overlap: "KeyO",
|
202 |
+
canvas_disabled_functions: [],
|
203 |
+
canvas_show_tooltip: true,
|
204 |
+
canvas_blur_prompt: false
|
205 |
+
};
|
206 |
+
|
207 |
+
const functionMap = {
|
208 |
+
"Zoom": "canvas_hotkey_zoom",
|
209 |
+
"Adjust brush size": "canvas_hotkey_adjust",
|
210 |
+
"Moving canvas": "canvas_hotkey_move",
|
211 |
+
"Fullscreen": "canvas_hotkey_fullscreen",
|
212 |
+
"Reset Zoom": "canvas_hotkey_reset",
|
213 |
+
"Overlap": "canvas_hotkey_overlap"
|
214 |
+
};
|
215 |
+
|
216 |
+
// Loading the configuration from opts
|
217 |
+
const preHotkeysConfig = createHotkeyConfig(
|
218 |
+
defaultHotkeysConfig,
|
219 |
+
hotkeysConfigOpts
|
220 |
+
);
|
221 |
+
|
222 |
+
// Disable functions that are not needed by the user
|
223 |
+
const hotkeysConfig = disableFunctions(
|
224 |
+
preHotkeysConfig,
|
225 |
+
preHotkeysConfig.canvas_disabled_functions
|
226 |
+
);
|
227 |
+
|
228 |
+
let isMoving = false;
|
229 |
+
let mouseX, mouseY;
|
230 |
+
let activeElement;
|
231 |
+
|
232 |
+
const elements = Object.fromEntries(
|
233 |
+
Object.keys(elementIDs).map(id => [
|
234 |
+
id,
|
235 |
+
gradioApp().querySelector(elementIDs[id])
|
236 |
+
])
|
237 |
+
);
|
238 |
+
const elemData = {};
|
239 |
+
|
240 |
+
// Apply functionality to the range inputs. Restore redmask and correct for long images.
|
241 |
+
const rangeInputs = elements.rangeGroup ?
|
242 |
+
Array.from(elements.rangeGroup.querySelectorAll("input")) :
|
243 |
+
[
|
244 |
+
gradioApp().querySelector("#img2img_width input[type='range']"),
|
245 |
+
gradioApp().querySelector("#img2img_height input[type='range']")
|
246 |
+
];
|
247 |
+
|
248 |
+
for (const input of rangeInputs) {
|
249 |
+
input?.addEventListener("input", () => restoreImgRedMask(elements));
|
250 |
+
}
|
251 |
+
|
252 |
+
function applyZoomAndPan(elemId) {
|
253 |
+
const targetElement = gradioApp().querySelector(elemId);
|
254 |
+
|
255 |
+
if (!targetElement) {
|
256 |
+
console.log("Element not found");
|
257 |
+
return;
|
258 |
+
}
|
259 |
+
|
260 |
+
targetElement.style.transformOrigin = "0 0";
|
261 |
+
|
262 |
+
elemData[elemId] = {
|
263 |
+
zoom: 1,
|
264 |
+
panX: 0,
|
265 |
+
panY: 0
|
266 |
+
};
|
267 |
+
let fullScreenMode = false;
|
268 |
+
|
269 |
+
// Create tooltip
|
270 |
+
function createTooltip() {
|
271 |
+
const toolTipElemnt =
|
272 |
+
targetElement.querySelector(".image-container");
|
273 |
+
const tooltip = document.createElement("div");
|
274 |
+
tooltip.className = "canvas-tooltip";
|
275 |
+
|
276 |
+
// Creating an item of information
|
277 |
+
const info = document.createElement("i");
|
278 |
+
info.className = "canvas-tooltip-info";
|
279 |
+
info.textContent = "";
|
280 |
+
|
281 |
+
// Create a container for the contents of the tooltip
|
282 |
+
const tooltipContent = document.createElement("div");
|
283 |
+
tooltipContent.className = "canvas-tooltip-content";
|
284 |
+
|
285 |
+
// Define an array with hotkey information and their actions
|
286 |
+
const hotkeysInfo = [
|
287 |
+
{
|
288 |
+
configKey: "canvas_hotkey_zoom",
|
289 |
+
action: "Zoom canvas",
|
290 |
+
keySuffix: " + wheel"
|
291 |
+
},
|
292 |
+
{
|
293 |
+
configKey: "canvas_hotkey_adjust",
|
294 |
+
action: "Adjust brush size",
|
295 |
+
keySuffix: " + wheel"
|
296 |
+
},
|
297 |
+
{configKey: "canvas_hotkey_reset", action: "Reset zoom"},
|
298 |
+
{
|
299 |
+
configKey: "canvas_hotkey_fullscreen",
|
300 |
+
action: "Fullscreen mode"
|
301 |
+
},
|
302 |
+
{configKey: "canvas_hotkey_move", action: "Move canvas"},
|
303 |
+
{configKey: "canvas_hotkey_overlap", action: "Overlap"}
|
304 |
+
];
|
305 |
+
|
306 |
+
// Create hotkeys array with disabled property based on the config values
|
307 |
+
const hotkeys = hotkeysInfo.map(info => {
|
308 |
+
const configValue = hotkeysConfig[info.configKey];
|
309 |
+
const key = info.keySuffix ?
|
310 |
+
`${configValue}${info.keySuffix}` :
|
311 |
+
configValue.charAt(configValue.length - 1);
|
312 |
+
return {
|
313 |
+
key,
|
314 |
+
action: info.action,
|
315 |
+
disabled: configValue === "disable"
|
316 |
+
};
|
317 |
+
});
|
318 |
+
|
319 |
+
for (const hotkey of hotkeys) {
|
320 |
+
if (hotkey.disabled) {
|
321 |
+
continue;
|
322 |
+
}
|
323 |
+
|
324 |
+
const p = document.createElement("p");
|
325 |
+
p.innerHTML = `<b>${hotkey.key}</b> - ${hotkey.action}`;
|
326 |
+
tooltipContent.appendChild(p);
|
327 |
+
}
|
328 |
+
|
329 |
+
// Add information and content elements to the tooltip element
|
330 |
+
tooltip.appendChild(info);
|
331 |
+
tooltip.appendChild(tooltipContent);
|
332 |
+
|
333 |
+
// Add a hint element to the target element
|
334 |
+
toolTipElemnt.appendChild(tooltip);
|
335 |
+
}
|
336 |
+
|
337 |
+
//Show tool tip if setting enable
|
338 |
+
if (hotkeysConfig.canvas_show_tooltip) {
|
339 |
+
createTooltip();
|
340 |
+
}
|
341 |
+
|
342 |
+
// In the course of research, it was found that the tag img is very harmful when zooming and creates white canvases. This hack allows you to almost never think about this problem, it has no effect on webui.
|
343 |
+
function fixCanvas() {
|
344 |
+
const activeTab = getActiveTab(elements).textContent.trim();
|
345 |
+
|
346 |
+
if (activeTab !== "img2img") {
|
347 |
+
const img = targetElement.querySelector(`${elemId} img`);
|
348 |
+
|
349 |
+
if (img && img.style.display !== "none") {
|
350 |
+
img.style.display = "none";
|
351 |
+
img.style.visibility = "hidden";
|
352 |
+
}
|
353 |
+
}
|
354 |
+
}
|
355 |
+
|
356 |
+
// Reset the zoom level and pan position of the target element to their initial values
|
357 |
+
function resetZoom() {
|
358 |
+
elemData[elemId] = {
|
359 |
+
zoomLevel: 1,
|
360 |
+
panX: 0,
|
361 |
+
panY: 0
|
362 |
+
};
|
363 |
+
|
364 |
+
fixCanvas();
|
365 |
+
targetElement.style.transform = `scale(${elemData[elemId].zoomLevel}) translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px)`;
|
366 |
+
|
367 |
+
const canvas = gradioApp().querySelector(
|
368 |
+
`${elemId} canvas[key="interface"]`
|
369 |
+
);
|
370 |
+
|
371 |
+
toggleOverlap("off");
|
372 |
+
fullScreenMode = false;
|
373 |
+
|
374 |
+
if (
|
375 |
+
canvas &&
|
376 |
+
parseFloat(canvas.style.width) > 865 &&
|
377 |
+
parseFloat(targetElement.style.width) > 865
|
378 |
+
) {
|
379 |
+
fitToElement();
|
380 |
+
return;
|
381 |
+
}
|
382 |
+
|
383 |
+
targetElement.style.width = "";
|
384 |
+
if (canvas) {
|
385 |
+
targetElement.style.height = canvas.style.height;
|
386 |
+
}
|
387 |
+
}
|
388 |
+
|
389 |
+
// Toggle the zIndex of the target element between two values, allowing it to overlap or be overlapped by other elements
|
390 |
+
function toggleOverlap(forced = "") {
|
391 |
+
const zIndex1 = "0";
|
392 |
+
const zIndex2 = "998";
|
393 |
+
|
394 |
+
targetElement.style.zIndex =
|
395 |
+
targetElement.style.zIndex !== zIndex2 ? zIndex2 : zIndex1;
|
396 |
+
|
397 |
+
if (forced === "off") {
|
398 |
+
targetElement.style.zIndex = zIndex1;
|
399 |
+
} else if (forced === "on") {
|
400 |
+
targetElement.style.zIndex = zIndex2;
|
401 |
+
}
|
402 |
+
}
|
403 |
+
|
404 |
+
// Adjust the brush size based on the deltaY value from a mouse wheel event
|
405 |
+
function adjustBrushSize(
|
406 |
+
elemId,
|
407 |
+
deltaY,
|
408 |
+
withoutValue = false,
|
409 |
+
percentage = 5
|
410 |
+
) {
|
411 |
+
const input =
|
412 |
+
gradioApp().querySelector(
|
413 |
+
`${elemId} input[aria-label='Brush radius']`
|
414 |
+
) ||
|
415 |
+
gradioApp().querySelector(
|
416 |
+
`${elemId} button[aria-label="Use brush"]`
|
417 |
+
);
|
418 |
+
|
419 |
+
if (input) {
|
420 |
+
input.click();
|
421 |
+
if (!withoutValue) {
|
422 |
+
const maxValue =
|
423 |
+
parseFloat(input.getAttribute("max")) || 100;
|
424 |
+
const changeAmount = maxValue * (percentage / 100);
|
425 |
+
const newValue =
|
426 |
+
parseFloat(input.value) +
|
427 |
+
(deltaY > 0 ? -changeAmount : changeAmount);
|
428 |
+
input.value = Math.min(Math.max(newValue, 0), maxValue);
|
429 |
+
input.dispatchEvent(new Event("change"));
|
430 |
+
}
|
431 |
+
}
|
432 |
+
}
|
433 |
+
|
434 |
+
// Reset zoom when uploading a new image
|
435 |
+
const fileInput = gradioApp().querySelector(
|
436 |
+
`${elemId} input[type="file"][accept="image/*"].svelte-116rqfv`
|
437 |
+
);
|
438 |
+
fileInput.addEventListener("click", resetZoom);
|
439 |
+
|
440 |
+
// Update the zoom level and pan position of the target element based on the values of the zoomLevel, panX and panY variables
|
441 |
+
function updateZoom(newZoomLevel, mouseX, mouseY) {
|
442 |
+
newZoomLevel = Math.max(0.5, Math.min(newZoomLevel, 15));
|
443 |
+
|
444 |
+
elemData[elemId].panX +=
|
445 |
+
mouseX - (mouseX * newZoomLevel) / elemData[elemId].zoomLevel;
|
446 |
+
elemData[elemId].panY +=
|
447 |
+
mouseY - (mouseY * newZoomLevel) / elemData[elemId].zoomLevel;
|
448 |
+
|
449 |
+
targetElement.style.transformOrigin = "0 0";
|
450 |
+
targetElement.style.transform = `translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px) scale(${newZoomLevel})`;
|
451 |
+
|
452 |
+
toggleOverlap("on");
|
453 |
+
return newZoomLevel;
|
454 |
+
}
|
455 |
+
|
456 |
+
// Change the zoom level based on user interaction
|
457 |
+
function changeZoomLevel(operation, e) {
|
458 |
+
if (isModifierKey(e, hotkeysConfig.canvas_hotkey_zoom)) {
|
459 |
+
e.preventDefault();
|
460 |
+
|
461 |
+
let zoomPosX, zoomPosY;
|
462 |
+
let delta = 0.2;
|
463 |
+
if (elemData[elemId].zoomLevel > 7) {
|
464 |
+
delta = 0.9;
|
465 |
+
} else if (elemData[elemId].zoomLevel > 2) {
|
466 |
+
delta = 0.6;
|
467 |
+
}
|
468 |
+
|
469 |
+
zoomPosX = e.clientX;
|
470 |
+
zoomPosY = e.clientY;
|
471 |
+
|
472 |
+
fullScreenMode = false;
|
473 |
+
elemData[elemId].zoomLevel = updateZoom(
|
474 |
+
elemData[elemId].zoomLevel +
|
475 |
+
(operation === "+" ? delta : -delta),
|
476 |
+
zoomPosX - targetElement.getBoundingClientRect().left,
|
477 |
+
zoomPosY - targetElement.getBoundingClientRect().top
|
478 |
+
);
|
479 |
+
}
|
480 |
+
}
|
481 |
+
|
482 |
+
/**
|
483 |
+
* This function fits the target element to the screen by calculating
|
484 |
+
* the required scale and offsets. It also updates the global variables
|
485 |
+
* zoomLevel, panX, and panY to reflect the new state.
|
486 |
+
*/
|
487 |
+
|
488 |
+
function fitToElement() {
|
489 |
+
//Reset Zoom
|
490 |
+
targetElement.style.transform = `translate(${0}px, ${0}px) scale(${1})`;
|
491 |
+
|
492 |
+
// Get element and screen dimensions
|
493 |
+
const elementWidth = targetElement.offsetWidth;
|
494 |
+
const elementHeight = targetElement.offsetHeight;
|
495 |
+
const parentElement = targetElement.parentElement;
|
496 |
+
const screenWidth = parentElement.clientWidth;
|
497 |
+
const screenHeight = parentElement.clientHeight;
|
498 |
+
|
499 |
+
// Get element's coordinates relative to the parent element
|
500 |
+
const elementRect = targetElement.getBoundingClientRect();
|
501 |
+
const parentRect = parentElement.getBoundingClientRect();
|
502 |
+
const elementX = elementRect.x - parentRect.x;
|
503 |
+
|
504 |
+
// Calculate scale and offsets
|
505 |
+
const scaleX = screenWidth / elementWidth;
|
506 |
+
const scaleY = screenHeight / elementHeight;
|
507 |
+
const scale = Math.min(scaleX, scaleY);
|
508 |
+
|
509 |
+
const transformOrigin =
|
510 |
+
window.getComputedStyle(targetElement).transformOrigin;
|
511 |
+
const [originX, originY] = transformOrigin.split(" ");
|
512 |
+
const originXValue = parseFloat(originX);
|
513 |
+
const originYValue = parseFloat(originY);
|
514 |
+
|
515 |
+
const offsetX =
|
516 |
+
(screenWidth - elementWidth * scale) / 2 -
|
517 |
+
originXValue * (1 - scale);
|
518 |
+
const offsetY =
|
519 |
+
(screenHeight - elementHeight * scale) / 2.5 -
|
520 |
+
originYValue * (1 - scale);
|
521 |
+
|
522 |
+
// Apply scale and offsets to the element
|
523 |
+
targetElement.style.transform = `translate(${offsetX}px, ${offsetY}px) scale(${scale})`;
|
524 |
+
|
525 |
+
// Update global variables
|
526 |
+
elemData[elemId].zoomLevel = scale;
|
527 |
+
elemData[elemId].panX = offsetX;
|
528 |
+
elemData[elemId].panY = offsetY;
|
529 |
+
|
530 |
+
fullScreenMode = false;
|
531 |
+
toggleOverlap("off");
|
532 |
+
}
|
533 |
+
|
534 |
+
/**
|
535 |
+
* This function fits the target element to the screen by calculating
|
536 |
+
* the required scale and offsets. It also updates the global variables
|
537 |
+
* zoomLevel, panX, and panY to reflect the new state.
|
538 |
+
*/
|
539 |
+
|
540 |
+
// Fullscreen mode
|
541 |
+
function fitToScreen() {
|
542 |
+
const canvas = gradioApp().querySelector(
|
543 |
+
`${elemId} canvas[key="interface"]`
|
544 |
+
);
|
545 |
+
|
546 |
+
if (!canvas) return;
|
547 |
+
|
548 |
+
if (canvas.offsetWidth > 862) {
|
549 |
+
targetElement.style.width = canvas.offsetWidth + "px";
|
550 |
+
}
|
551 |
+
|
552 |
+
if (fullScreenMode) {
|
553 |
+
resetZoom();
|
554 |
+
fullScreenMode = false;
|
555 |
+
return;
|
556 |
+
}
|
557 |
+
|
558 |
+
//Reset Zoom
|
559 |
+
targetElement.style.transform = `translate(${0}px, ${0}px) scale(${1})`;
|
560 |
+
|
561 |
+
// Get scrollbar width to right-align the image
|
562 |
+
const scrollbarWidth =
|
563 |
+
window.innerWidth - document.documentElement.clientWidth;
|
564 |
+
|
565 |
+
// Get element and screen dimensions
|
566 |
+
const elementWidth = targetElement.offsetWidth;
|
567 |
+
const elementHeight = targetElement.offsetHeight;
|
568 |
+
const screenWidth = window.innerWidth - scrollbarWidth;
|
569 |
+
const screenHeight = window.innerHeight;
|
570 |
+
|
571 |
+
// Get element's coordinates relative to the page
|
572 |
+
const elementRect = targetElement.getBoundingClientRect();
|
573 |
+
const elementY = elementRect.y;
|
574 |
+
const elementX = elementRect.x;
|
575 |
+
|
576 |
+
// Calculate scale and offsets
|
577 |
+
const scaleX = screenWidth / elementWidth;
|
578 |
+
const scaleY = screenHeight / elementHeight;
|
579 |
+
const scale = Math.min(scaleX, scaleY);
|
580 |
+
|
581 |
+
// Get the current transformOrigin
|
582 |
+
const computedStyle = window.getComputedStyle(targetElement);
|
583 |
+
const transformOrigin = computedStyle.transformOrigin;
|
584 |
+
const [originX, originY] = transformOrigin.split(" ");
|
585 |
+
const originXValue = parseFloat(originX);
|
586 |
+
const originYValue = parseFloat(originY);
|
587 |
+
|
588 |
+
// Calculate offsets with respect to the transformOrigin
|
589 |
+
const offsetX =
|
590 |
+
(screenWidth - elementWidth * scale) / 2 -
|
591 |
+
elementX -
|
592 |
+
originXValue * (1 - scale);
|
593 |
+
const offsetY =
|
594 |
+
(screenHeight - elementHeight * scale) / 2 -
|
595 |
+
elementY -
|
596 |
+
originYValue * (1 - scale);
|
597 |
+
|
598 |
+
// Apply scale and offsets to the element
|
599 |
+
targetElement.style.transform = `translate(${offsetX}px, ${offsetY}px) scale(${scale})`;
|
600 |
+
|
601 |
+
// Update global variables
|
602 |
+
elemData[elemId].zoomLevel = scale;
|
603 |
+
elemData[elemId].panX = offsetX;
|
604 |
+
elemData[elemId].panY = offsetY;
|
605 |
+
|
606 |
+
fullScreenMode = true;
|
607 |
+
toggleOverlap("on");
|
608 |
+
}
|
609 |
+
|
610 |
+
// Handle keydown events
|
611 |
+
function handleKeyDown(event) {
|
612 |
+
// Disable key locks to make pasting from the buffer work correctly
|
613 |
+
if ((event.ctrlKey && event.code === 'KeyV') || (event.ctrlKey && event.code === 'KeyC') || event.code === "F5") {
|
614 |
+
return;
|
615 |
+
}
|
616 |
+
|
617 |
+
// before activating shortcut, ensure user is not actively typing in an input field
|
618 |
+
if (!hotkeysConfig.canvas_blur_prompt) {
|
619 |
+
if (event.target.nodeName === 'TEXTAREA' || event.target.nodeName === 'INPUT') {
|
620 |
+
return;
|
621 |
+
}
|
622 |
+
}
|
623 |
+
|
624 |
+
|
625 |
+
const hotkeyActions = {
|
626 |
+
[hotkeysConfig.canvas_hotkey_reset]: resetZoom,
|
627 |
+
[hotkeysConfig.canvas_hotkey_overlap]: toggleOverlap,
|
628 |
+
[hotkeysConfig.canvas_hotkey_fullscreen]: fitToScreen
|
629 |
+
};
|
630 |
+
|
631 |
+
const action = hotkeyActions[event.code];
|
632 |
+
if (action) {
|
633 |
+
event.preventDefault();
|
634 |
+
action(event);
|
635 |
+
}
|
636 |
+
|
637 |
+
if (
|
638 |
+
isModifierKey(event, hotkeysConfig.canvas_hotkey_zoom) ||
|
639 |
+
isModifierKey(event, hotkeysConfig.canvas_hotkey_adjust)
|
640 |
+
) {
|
641 |
+
event.preventDefault();
|
642 |
+
}
|
643 |
+
}
|
644 |
+
|
645 |
+
// Get Mouse position
|
646 |
+
function getMousePosition(e) {
|
647 |
+
mouseX = e.offsetX;
|
648 |
+
mouseY = e.offsetY;
|
649 |
+
}
|
650 |
+
|
651 |
+
targetElement.addEventListener("mousemove", getMousePosition);
|
652 |
+
|
653 |
+
// Handle events only inside the targetElement
|
654 |
+
let isKeyDownHandlerAttached = false;
|
655 |
+
|
656 |
+
function handleMouseMove() {
|
657 |
+
if (!isKeyDownHandlerAttached) {
|
658 |
+
document.addEventListener("keydown", handleKeyDown);
|
659 |
+
isKeyDownHandlerAttached = true;
|
660 |
+
|
661 |
+
activeElement = elemId;
|
662 |
+
}
|
663 |
+
}
|
664 |
+
|
665 |
+
function handleMouseLeave() {
|
666 |
+
if (isKeyDownHandlerAttached) {
|
667 |
+
document.removeEventListener("keydown", handleKeyDown);
|
668 |
+
isKeyDownHandlerAttached = false;
|
669 |
+
|
670 |
+
activeElement = null;
|
671 |
+
}
|
672 |
+
}
|
673 |
+
|
674 |
+
// Add mouse event handlers
|
675 |
+
targetElement.addEventListener("mousemove", handleMouseMove);
|
676 |
+
targetElement.addEventListener("mouseleave", handleMouseLeave);
|
677 |
+
|
678 |
+
// Reset zoom when click on another tab
|
679 |
+
elements.img2imgTabs.addEventListener("click", resetZoom);
|
680 |
+
elements.img2imgTabs.addEventListener("click", () => {
|
681 |
+
// targetElement.style.width = "";
|
682 |
+
if (parseInt(targetElement.style.width) > 865) {
|
683 |
+
setTimeout(fitToElement, 0);
|
684 |
+
}
|
685 |
+
});
|
686 |
+
|
687 |
+
targetElement.addEventListener("wheel", e => {
|
688 |
+
// change zoom level
|
689 |
+
const operation = e.deltaY > 0 ? "-" : "+";
|
690 |
+
changeZoomLevel(operation, e);
|
691 |
+
|
692 |
+
// Handle brush size adjustment with ctrl key pressed
|
693 |
+
if (isModifierKey(e, hotkeysConfig.canvas_hotkey_adjust)) {
|
694 |
+
e.preventDefault();
|
695 |
+
|
696 |
+
// Increase or decrease brush size based on scroll direction
|
697 |
+
adjustBrushSize(elemId, e.deltaY);
|
698 |
+
}
|
699 |
+
});
|
700 |
+
|
701 |
+
// Handle the move event for pan functionality. Updates the panX and panY variables and applies the new transform to the target element.
|
702 |
+
function handleMoveKeyDown(e) {
|
703 |
+
|
704 |
+
// Disable key locks to make pasting from the buffer work correctly
|
705 |
+
if ((e.ctrlKey && e.code === 'KeyV') || (e.ctrlKey && event.code === 'KeyC') || e.code === "F5") {
|
706 |
+
return;
|
707 |
+
}
|
708 |
+
|
709 |
+
// before activating shortcut, ensure user is not actively typing in an input field
|
710 |
+
if (!hotkeysConfig.canvas_blur_prompt) {
|
711 |
+
if (e.target.nodeName === 'TEXTAREA' || e.target.nodeName === 'INPUT') {
|
712 |
+
return;
|
713 |
+
}
|
714 |
+
}
|
715 |
+
|
716 |
+
|
717 |
+
if (e.code === hotkeysConfig.canvas_hotkey_move) {
|
718 |
+
if (!e.ctrlKey && !e.metaKey && isKeyDownHandlerAttached) {
|
719 |
+
e.preventDefault();
|
720 |
+
document.activeElement.blur();
|
721 |
+
isMoving = true;
|
722 |
+
}
|
723 |
+
}
|
724 |
+
}
|
725 |
+
|
726 |
+
function handleMoveKeyUp(e) {
|
727 |
+
if (e.code === hotkeysConfig.canvas_hotkey_move) {
|
728 |
+
isMoving = false;
|
729 |
+
}
|
730 |
+
}
|
731 |
+
|
732 |
+
document.addEventListener("keydown", handleMoveKeyDown);
|
733 |
+
document.addEventListener("keyup", handleMoveKeyUp);
|
734 |
+
|
735 |
+
// Detect zoom level and update the pan speed.
|
736 |
+
function updatePanPosition(movementX, movementY) {
|
737 |
+
let panSpeed = 2;
|
738 |
+
|
739 |
+
if (elemData[elemId].zoomLevel > 8) {
|
740 |
+
panSpeed = 3.5;
|
741 |
+
}
|
742 |
+
|
743 |
+
elemData[elemId].panX += movementX * panSpeed;
|
744 |
+
elemData[elemId].panY += movementY * panSpeed;
|
745 |
+
|
746 |
+
// Delayed redraw of an element
|
747 |
+
requestAnimationFrame(() => {
|
748 |
+
targetElement.style.transform = `translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px) scale(${elemData[elemId].zoomLevel})`;
|
749 |
+
toggleOverlap("on");
|
750 |
+
});
|
751 |
+
}
|
752 |
+
|
753 |
+
function handleMoveByKey(e) {
|
754 |
+
if (isMoving && elemId === activeElement) {
|
755 |
+
updatePanPosition(e.movementX, e.movementY);
|
756 |
+
targetElement.style.pointerEvents = "none";
|
757 |
+
} else {
|
758 |
+
targetElement.style.pointerEvents = "auto";
|
759 |
+
}
|
760 |
+
}
|
761 |
+
|
762 |
+
// Prevents sticking to the mouse
|
763 |
+
window.onblur = function() {
|
764 |
+
isMoving = false;
|
765 |
+
};
|
766 |
+
|
767 |
+
gradioApp().addEventListener("mousemove", handleMoveByKey);
|
768 |
+
}
|
769 |
+
|
770 |
+
applyZoomAndPan(elementIDs.sketch);
|
771 |
+
applyZoomAndPan(elementIDs.inpaint);
|
772 |
+
applyZoomAndPan(elementIDs.inpaintSketch);
|
773 |
+
|
774 |
+
// Make the function global so that other extensions can take advantage of this solution
|
775 |
+
window.applyZoomAndPan = applyZoomAndPan;
|
776 |
+
});
|
extensions-builtin/canvas-zoom-and-pan/scripts/__pycache__/hotkey_config.cpython-310.pyc
ADDED
Binary file (1.46 kB). View file
|
|
extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from modules import shared
|
3 |
+
|
4 |
+
shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas Hotkeys"), {
|
5 |
+
"canvas_hotkey_zoom": shared.OptionInfo("Alt", "Zoom canvas", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"),
|
6 |
+
"canvas_hotkey_adjust": shared.OptionInfo("Ctrl", "Adjust brush size", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"),
|
7 |
+
"canvas_hotkey_move": shared.OptionInfo("F", "Moving the canvas").info("To work correctly in firefox, turn off 'Automatically search the page text when typing' in the browser settings"),
|
8 |
+
"canvas_hotkey_fullscreen": shared.OptionInfo("S", "Fullscreen Mode, maximizes the picture so that it fits into the screen and stretches it to its full width "),
|
9 |
+
"canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"),
|
10 |
+
"canvas_hotkey_overlap": shared.OptionInfo("O", "Toggle overlap").info("Technical button, neededs for testing"),
|
11 |
+
"canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"),
|
12 |
+
"canvas_blur_prompt": shared.OptionInfo(False, "Take the focus off the prompt when working with a canvas"),
|
13 |
+
"canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size", "Moving canvas","Fullscreen","Reset Zoom","Overlap"]}),
|
14 |
+
}))
|
extensions-builtin/canvas-zoom-and-pan/style.css
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.canvas-tooltip-info {
|
2 |
+
position: absolute;
|
3 |
+
top: 10px;
|
4 |
+
left: 10px;
|
5 |
+
cursor: help;
|
6 |
+
background-color: rgba(0, 0, 0, 0.3);
|
7 |
+
width: 20px;
|
8 |
+
height: 20px;
|
9 |
+
border-radius: 50%;
|
10 |
+
display: flex;
|
11 |
+
align-items: center;
|
12 |
+
justify-content: center;
|
13 |
+
flex-direction: column;
|
14 |
+
|
15 |
+
z-index: 100;
|
16 |
+
}
|
17 |
+
|
18 |
+
.canvas-tooltip-info::after {
|
19 |
+
content: '';
|
20 |
+
display: block;
|
21 |
+
width: 2px;
|
22 |
+
height: 7px;
|
23 |
+
background-color: white;
|
24 |
+
margin-top: 2px;
|
25 |
+
}
|
26 |
+
|
27 |
+
.canvas-tooltip-info::before {
|
28 |
+
content: '';
|
29 |
+
display: block;
|
30 |
+
width: 2px;
|
31 |
+
height: 2px;
|
32 |
+
background-color: white;
|
33 |
+
}
|
34 |
+
|
35 |
+
.canvas-tooltip-content {
|
36 |
+
display: none;
|
37 |
+
background-color: #f9f9f9;
|
38 |
+
color: #333;
|
39 |
+
border: 1px solid #ddd;
|
40 |
+
padding: 15px;
|
41 |
+
position: absolute;
|
42 |
+
top: 40px;
|
43 |
+
left: 10px;
|
44 |
+
width: 250px;
|
45 |
+
font-size: 16px;
|
46 |
+
opacity: 0;
|
47 |
+
border-radius: 8px;
|
48 |
+
box-shadow: 0px 8px 16px 0px rgba(0,0,0,0.2);
|
49 |
+
|
50 |
+
z-index: 100;
|
51 |
+
}
|
52 |
+
|
53 |
+
.canvas-tooltip:hover .canvas-tooltip-content {
|
54 |
+
display: block;
|
55 |
+
animation: fadeIn 0.5s;
|
56 |
+
opacity: 1;
|
57 |
+
}
|
58 |
+
|
59 |
+
@keyframes fadeIn {
|
60 |
+
from {opacity: 0;}
|
61 |
+
to {opacity: 1;}
|
62 |
+
}
|
63 |
+
|
extensions-builtin/extra-options-section/scripts/__pycache__/extra_options_section.cpython-310.pyc
ADDED
Binary file (2.84 kB). View file
|
|
extensions-builtin/extra-options-section/scripts/extra_options_section.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from modules import scripts, shared, ui_components, ui_settings
|
3 |
+
from modules.ui_components import FormColumn
|
4 |
+
|
5 |
+
|
6 |
+
class ExtraOptionsSection(scripts.Script):
|
7 |
+
section = "extra_options"
|
8 |
+
|
9 |
+
def __init__(self):
|
10 |
+
self.comps = None
|
11 |
+
self.setting_names = None
|
12 |
+
|
13 |
+
def title(self):
|
14 |
+
return "Extra options"
|
15 |
+
|
16 |
+
def show(self, is_img2img):
|
17 |
+
return scripts.AlwaysVisible
|
18 |
+
|
19 |
+
def ui(self, is_img2img):
|
20 |
+
self.comps = []
|
21 |
+
self.setting_names = []
|
22 |
+
|
23 |
+
with gr.Blocks() as interface:
|
24 |
+
with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and shared.opts.extra_options else gr.Group(), gr.Row():
|
25 |
+
for setting_name in shared.opts.extra_options:
|
26 |
+
with FormColumn():
|
27 |
+
comp = ui_settings.create_setting_component(setting_name)
|
28 |
+
|
29 |
+
self.comps.append(comp)
|
30 |
+
self.setting_names.append(setting_name)
|
31 |
+
|
32 |
+
def get_settings_values():
|
33 |
+
return [ui_settings.get_value_for_setting(key) for key in self.setting_names]
|
34 |
+
|
35 |
+
interface.load(fn=get_settings_values, inputs=[], outputs=self.comps, queue=False, show_progress=False)
|
36 |
+
|
37 |
+
return self.comps
|
38 |
+
|
39 |
+
def before_process(self, p, *args):
|
40 |
+
for name, value in zip(self.setting_names, args):
|
41 |
+
if name not in p.override_settings:
|
42 |
+
p.override_settings[name] = value
|
43 |
+
|
44 |
+
|
45 |
+
shared.options_templates.update(shared.options_section(('ui', "User interface"), {
|
46 |
+
"extra_options": shared.OptionInfo([], "Options in main UI", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img/img2img interfaces").needs_restart(),
|
47 |
+
"extra_options_accordion": shared.OptionInfo(False, "Place options in main UI into an accordion")
|
48 |
+
}))
|
extensions-builtin/mobile/javascript/mobile.js
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
var isSetupForMobile = false;
|
2 |
+
|
3 |
+
function isMobile() {
|
4 |
+
for (var tab of ["txt2img", "img2img"]) {
|
5 |
+
var imageTab = gradioApp().getElementById(tab + '_results');
|
6 |
+
if (imageTab && imageTab.offsetParent && imageTab.offsetLeft == 0) {
|
7 |
+
return true;
|
8 |
+
}
|
9 |
+
}
|
10 |
+
|
11 |
+
return false;
|
12 |
+
}
|
13 |
+
|
14 |
+
function reportWindowSize() {
|
15 |
+
var currentlyMobile = isMobile();
|
16 |
+
if (currentlyMobile == isSetupForMobile) return;
|
17 |
+
isSetupForMobile = currentlyMobile;
|
18 |
+
|
19 |
+
for (var tab of ["txt2img", "img2img"]) {
|
20 |
+
var button = gradioApp().getElementById(tab + '_generate_box');
|
21 |
+
var target = gradioApp().getElementById(currentlyMobile ? tab + '_results' : tab + '_actions_column');
|
22 |
+
target.insertBefore(button, target.firstElementChild);
|
23 |
+
}
|
24 |
+
}
|
25 |
+
|
26 |
+
window.addEventListener("resize", reportWindowSize);
|
extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Stable Diffusion WebUI - Bracket checker
|
2 |
+
// By Hingashi no Florin/Bwin4L & @akx
|
3 |
+
// Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs.
|
4 |
+
// If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.
|
5 |
+
|
6 |
+
function checkBrackets(textArea, counterElt) {
|
7 |
+
var counts = {};
|
8 |
+
(textArea.value.match(/[(){}[\]]/g) || []).forEach(bracket => {
|
9 |
+
counts[bracket] = (counts[bracket] || 0) + 1;
|
10 |
+
});
|
11 |
+
var errors = [];
|
12 |
+
|
13 |
+
function checkPair(open, close, kind) {
|
14 |
+
if (counts[open] !== counts[close]) {
|
15 |
+
errors.push(
|
16 |
+
`${open}...${close} - Detected ${counts[open] || 0} opening and ${counts[close] || 0} closing ${kind}.`
|
17 |
+
);
|
18 |
+
}
|
19 |
+
}
|
20 |
+
|
21 |
+
checkPair('(', ')', 'round brackets');
|
22 |
+
checkPair('[', ']', 'square brackets');
|
23 |
+
checkPair('{', '}', 'curly brackets');
|
24 |
+
counterElt.title = errors.join('\n');
|
25 |
+
counterElt.classList.toggle('error', errors.length !== 0);
|
26 |
+
}
|
27 |
+
|
28 |
+
function setupBracketChecking(id_prompt, id_counter) {
|
29 |
+
var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
|
30 |
+
var counter = gradioApp().getElementById(id_counter);
|
31 |
+
|
32 |
+
if (textarea && counter) {
|
33 |
+
textarea.addEventListener("input", () => checkBrackets(textarea, counter));
|
34 |
+
}
|
35 |
+
}
|
36 |
+
|
37 |
+
onUiLoaded(function() {
|
38 |
+
setupBracketChecking('txt2img_prompt', 'txt2img_token_counter');
|
39 |
+
setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter');
|
40 |
+
setupBracketChecking('img2img_prompt', 'img2img_token_counter');
|
41 |
+
setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter');
|
42 |
+
});
|
extensions-builtin/sd_theme_editor/install.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
import launch
|
extensions-builtin/sd_theme_editor/javascript/ui_theme.js
ADDED
@@ -0,0 +1,435 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
function hexToRgb(color) {
  let hex = color[0] === "#" ? color.slice(1) : color;
  let c;

  // expand the short hex by doubling each character, fc0 -> ffcc00
  if (hex.length !== 6) {
    hex = (() => {
      const result = [];
      for (c of Array.from(hex)) {
        result.push(`${c}${c}`);
      }
      return result;
    })().join("");
  }
  const colorStr = hex.match(/#?(.{2})(.{2})(.{2})/).slice(1);
  const rgb = colorStr.map((col) => parseInt(col, 16));
  rgb.push(1);
  return rgb;
}

function rgbToHsl(rgb) {
  const r = rgb[0] / 255;
  const g = rgb[1] / 255;
  const b = rgb[2] / 255;

  const max = Math.max(r, g, b);
  const min = Math.min(r, g, b);
  const diff = max - min;
  const add = max + min;

  const hue =
    min === max
      ? 0
      : r === max
      ? ((60 * (g - b)) / diff + 360) % 360
      : g === max
      ? (60 * (b - r)) / diff + 120
      : (60 * (r - g)) / diff + 240;

  const lum = 0.5 * add;

  const sat =
    lum === 0 ? 0 : lum === 1 ? 1 : lum <= 0.5 ? diff / add : diff / (2 - add);

  const h = Math.round(hue);
  const s = Math.round(sat * 100);
  const l = Math.round(lum * 100);
  const a = rgb[3] || 1;

  return [h, s, l, a];
}

function hexToHsl(color) {
  const rgb = hexToRgb(color);
  const hsl = rgbToHsl(rgb);
  return "hsl(" + hsl[0] + "deg " + hsl[1] + "% " + hsl[2] + "%)";
}

function hslToHex(h, s, l) {
  l /= 100;
  const a = (s * Math.min(l, 1 - l)) / 100;
  const f = (n) => {
    const k = (n + h / 30) % 12;
    const color = l - a * Math.max(Math.min(k - 3, 9 - k, 1), -1);
    return Math.round(255 * Math.max(0, Math.min(color, 1)))
      .toString(16)
      .padStart(2, "0"); // convert to Hex and prefix "0" if needed
  };
  return `#${f(0)}${f(8)}${f(4)}`;
}

function hsl2rgb(h, s, l) {
  let a = s * Math.min(l, 1 - l);
  let f = (n, k = (n + h / 30) % 12) =>
    l - a * Math.max(Math.min(k - 3, 9 - k, 1), -1);
  return [f(0), f(8), f(4)];
}

function invertColor(hex) {
  if (hex.indexOf("#") === 0) {
    hex = hex.slice(1);
  }
  // convert 3-digit hex to 6-digits.
  if (hex.length === 3) {
    hex = hex[0] + hex[0] + hex[1] + hex[1] + hex[2] + hex[2];
  }
  if (hex.length !== 6) {
    throw new Error("Invalid HEX color.");
  }
  // invert color components
  var r = (255 - parseInt(hex.slice(0, 2), 16)).toString(16),
    g = (255 - parseInt(hex.slice(2, 4), 16)).toString(16),
    b = (255 - parseInt(hex.slice(4, 6), 16)).toString(16);
  // pad each with zeros and return
  return "#" + padZero(r) + padZero(g) + padZero(b);
}

function padZero(str, len) {
  len = len || 2;
  var zeros = new Array(len).join("0");
  return (zeros + str).slice(-len);
}

function getValsWrappedIn(str, c1, c2) {
  var rg = new RegExp("(?<=\\" + c1 + ")(.*?)(?=\\" + c2 + ")", "g");
  return str.match(rg);
}

let styleobj = {};
let hslobj = {};
let isColorsInv;

const toHSLArray = (hslStr) => hslStr.match(/\d+/g).map(Number);

function offsetColorsHSV(ohsl) {
  let inner_styles = "";

  for (const key in styleobj) {
    let keyVal = styleobj[key];

    if (keyVal.indexOf("#") != -1 || keyVal.indexOf("hsl") != -1) {
      let colcomp = gradioApp().querySelector("#" + key + " input");
      if (colcomp) {
        let hsl;

        if (keyVal.indexOf("#") != -1) {
          keyVal = keyVal.replace(/\s+/g, "");
          //inv ? keyVal = invertColor(keyVal) : 0;
          if (isColorsInv) {
            keyVal = invertColor(keyVal);
            styleobj[key] = keyVal;
          }
          hsl = rgbToHsl(hexToRgb(keyVal));
        } else {
          if (isColorsInv) {
            let c = toHSLArray(keyVal);
            let hex = hslToHex(c[0], c[1], c[2]);
            keyVal = invertColor(hex);
            styleobj[key] = keyVal;
            hsl = rgbToHsl(hexToRgb(keyVal));
          } else {
            hsl = toHSLArray(keyVal);
          }
        }

        let h = (parseInt(hsl[0]) + parseInt(ohsl[0])) % 360;
        let s = parseInt(hsl[1]) + parseInt(ohsl[1]);
        let l = parseInt(hsl[2]) + parseInt(ohsl[2]);

        let hex = hslToHex(
          h,
          Math.min(Math.max(s, 0), 100),
          Math.min(Math.max(l, 0), 100)
        );

        colcomp.value = hex;

        hslobj[key] = "hsl(" + h + "deg " + s + "% " + l + "%)";
        inner_styles += key + ":" + hslobj[key] + ";";
      }
    } else {
      inner_styles += key + ":" + styleobj[key] + ";";
    }
  }

  isColorsInv = false;

  const preview_styles = gradioApp().querySelector("#preview-styles");
  preview_styles.innerHTML = ":root {" + inner_styles + "}";
  preview_styles.innerHTML +=
    "@media only screen and (max-width: 860px) {:root{--ae-outside-gap-size: var(--ae-mobile-outside-gap-size);--ae-inside-padding-size: var(--ae-mobile-inside-padding-size);}}";

  const vars_textarea = gradioApp().querySelector("#theme_vars textarea");
  vars_textarea.value = inner_styles;

  const inputEvent = new Event("input");
  Object.defineProperty(inputEvent, "target", { value: vars_textarea });
  vars_textarea.dispatchEvent(inputEvent);
}

function updateTheme(vars) {
  let inner_styles = "";

  for (let i = 0; i < vars.length - 1; i++) {
    let key = vars[i].split(":");
    let id = key[0].replace(/\s+/g, "");
    let val = key[1].trim();

    styleobj[id] = val;
    inner_styles += id + ":" + val + ";";

    gradioApp()
      .querySelectorAll("#" + id + " input")
      .forEach((elem) => {
        if (val.indexOf("hsl") != -1) {
          let hsl = toHSLArray(val);
          let hex = hslToHex(hsl[0], hsl[1], hsl[2]);
          elem.value = hex;
        } else {
          elem.value = val.split("px")[0];
        }
      });
  }

  const preview_styles = gradioApp().querySelector("#preview-styles");

  if (preview_styles) {
    preview_styles.innerHTML = ":root {" + inner_styles + "}";
    preview_styles.innerHTML +=
      "@media only screen and (max-width: 860px) {:root{--ae-outside-gap-size: var(--ae-mobile-outside-gap-size);--ae-inside-padding-size: var(--ae-mobile-inside-padding-size);}}";
  } else {
    const r = gradioApp();
    const style = document.createElement("style");
    style.id = "preview-styles";
    style.innerHTML = ":root {" + inner_styles + "}";
    style.innerHTML +=
      "@media only screen and (max-width: 860px) {:root{--ae-outside-gap-size: var(--ae-mobile-outside-gap-size);--ae-inside-padding-size: var(--ae-mobile-inside-padding-size);}}";
    r.appendChild(style);
  }

  const vars_textarea = gradioApp().querySelector("#theme_vars textarea");
  const css_textarea = gradioApp().querySelector("#theme_css textarea");

  vars_textarea.value = inner_styles;
  css_textarea.value = css_textarea.value;

  //console.log(Object);

  const vEvent = new Event("input");
  const cEvent = new Event("input");
  Object.defineProperty(vEvent, "target", { value: vars_textarea });
  Object.defineProperty(cEvent, "target", { value: css_textarea });
  vars_textarea.dispatchEvent(vEvent);
  css_textarea.dispatchEvent(cEvent);
}

function applyTheme() {
  console.log("apply");
}

function initTheme() {
  const current_style = gradioApp().querySelector(".gradio-container > style");
  //console.log(current_style);
  //const head = document.head;
  //head.appendChild(current_style);

  const css_styles = current_style.innerHTML.split(
    "/*BREAKPOINT_CSS_CONTENT*/"
  );
  let init_css_vars = css_styles[0].split("}")[0].split("{")[1];
  init_css_vars = init_css_vars.replace(/\n|\r/g, "");

  let init_vars = init_css_vars.split(";");
  let vars = init_vars;

  //console.log(vars);

  const vars_textarea = gradioApp().querySelector("#theme_vars textarea");
  const css_textarea = gradioApp().querySelector("#theme_css textarea");
  //const result_textarea = gradioApp().querySelector('#theme_result textarea');
  vars_textarea.value = init_css_vars;
  css_textarea.value =
    "/*BREAKPOINT_CSS_CONTENT*/" + css_styles[1] + "/*BREAKPOINT_CSS_CONTENT*/";

  updateTheme(vars);

  //vars_textarea.addEventListener("change", function(e) {
  //e.preventDefault();
  //e.stopPropagation();
  //vars = vars_textarea.value.split(";");
  //console.log(e);
  //updateTheme(vars);

  //})

  const preview_styles = gradioApp().querySelector("#preview-styles");
  let intervalChange;

  gradioApp()
    .querySelectorAll("#ui_theme_settings input")
    .forEach((elem) => {
      elem.addEventListener("input", function (e) {
        let celem = e.currentTarget;
        let val = e.currentTarget.value;
        let curr_val;

        switch (e.currentTarget.type) {
          case "range":
            celem = celem.parentElement;
            val = e.currentTarget.value + "px";
            break;
          case "color":
            celem = celem.parentElement.parentElement;
            val = e.currentTarget.value;
            break;
          case "number":
            celem = celem.parentElement.parentElement.parentElement;
            val = e.currentTarget.value + "px";
            break;
        }

        styleobj[celem.id] = val;

        //console.log(styleobj);

        if (intervalChange != null) clearInterval(intervalChange);
        intervalChange = setTimeout(() => {
          let inner_styles = "";

          for (const key in styleobj) {
            inner_styles += key + ":" + styleobj[key] + ";";
          }

          vars = inner_styles.split(";");
          preview_styles.innerHTML = ":root {" + inner_styles + "}";
          preview_styles.innerHTML +=
            "@media only screen and (max-width: 860px) {:root{--ae-outside-gap-size: var(--ae-mobile-outside-gap-size);--ae-inside-padding-size: var(--ae-mobile-inside-padding-size);}}";

          vars_textarea.value = inner_styles;
          const vEvent = new Event("input");
          Object.defineProperty(vEvent, "target", { value: vars_textarea });
          vars_textarea.dispatchEvent(vEvent);

          offsetColorsHSV(hsloffset);
        }, 1000);
      });
    });

  const reset_btn = gradioApp().getElementById("theme_reset_btn");
  reset_btn.addEventListener("click", function (e) {
    e.preventDefault();
    e.stopPropagation();
    gradioApp()
      .querySelectorAll("#ui_theme_hsv input")
      .forEach((elem) => {
        elem.value = 0;
      });
    hsloffset = [0, 0, 0];
    updateTheme(init_vars);
  });

  /*
  const apply_btn = gradioApp().getElementById('theme_apply_btn');
  apply_btn.addEventListener("click", function(e) {
    e.preventDefault();
    e.stopPropagation();
    init_css_vars = vars_textarea.value.replace(/\n|\r/g, "");
    vars_textarea.value = init_css_vars;

    init_vars = init_css_vars.split(";");
    vars = init_vars;
    updateTheme(vars);
  })
  */

  let intervalCheck;
  function dropDownOnChange() {
    if (init_css_vars != vars_textarea.value) {
      clearInterval(intervalCheck);
      init_css_vars = vars_textarea.value.replace(/\n|\r/g, "");
      vars_textarea.value = init_css_vars;
      init_vars = init_css_vars.split(";");
      vars = init_vars;
      updateTheme(vars);
    }
  }

  const drop_down = gradioApp().querySelector("#themes_drop_down");
  drop_down.addEventListener("click", function (e) {
    if (intervalCheck != null) clearInterval(intervalCheck);
    intervalCheck = setInterval(dropDownOnChange, 100);
    //console.log("ok");
  });

  let hsloffset = [0, 0, 0];

  const hue = gradioApp()
    .querySelectorAll("#theme_hue input")
    .forEach((elem) => {
      elem.addEventListener("change", function (e) {
        e.preventDefault();
        e.stopPropagation();
        hsloffset[0] = e.currentTarget.value;
        offsetColorsHSV(hsloffset);
      });
    });

  const sat = gradioApp()
    .querySelectorAll("#theme_sat input")
    .forEach((elem) => {
      elem.addEventListener("change", function (e) {
        e.preventDefault();
        e.stopPropagation();
        hsloffset[1] = e.currentTarget.value;
        offsetColorsHSV(hsloffset);
      });
    });

  const brt = gradioApp()
    .querySelectorAll("#theme_brt input")
    .forEach((elem) => {
      elem.addEventListener("change", function (e) {
        e.preventDefault();
        e.stopPropagation();
        hsloffset[2] = e.currentTarget.value;
        offsetColorsHSV(hsloffset);
      });
    });

  const inv_btn = gradioApp().getElementById("theme_invert_btn");
  inv_btn.addEventListener("click", function (e) {
    e.preventDefault();
    e.stopPropagation();
    isColorsInv = !isColorsInv;
    offsetColorsHSV(hsloffset);
  });
}

function observeGradioApp() {
  const observer = new MutationObserver(() => {
    const block = gradioApp().getElementById("tab_ui_theme");
    if (block) {
      observer.disconnect();

      setTimeout(() => {
        initTheme();
      }, "500");
    }
  });
  observer.observe(gradioApp(), { childList: true, subtree: true });
}

document.addEventListener("DOMContentLoaded", () => {
  observeGradioApp();
});
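
For illustration (not part of the diff): a minimal sketch of how the color helpers above compose. It assumes hexToRgb, rgbToHsl and hslToHex from ui_theme.js are already in scope (for example, pasted into the browser console while the UI is open); the sample color is arbitrary.

// Hex -> [r, g, b, a] -> [h, s, l, a], i.e. what hexToHsl() does in two steps.
const hsl = rgbToHsl(hexToRgb("#23c9a5")); // -> [167, 70, 46, 1]
const css = "hsl(" + hsl[0] + "deg " + hsl[1] + "% " + hsl[2] + "%)"; // "hsl(167deg 70% 46%)"
// ...and back to hex, as offsetColorsHSV() does when it updates a color picker.
// Because H/S/L are rounded to integers, the round trip is close but not exact ("#23c7a4" here).
const hexAgain = hslToHex(hsl[0], hsl[1], hsl[2]);
console.log(css, hexAgain);
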
extensions-builtin/sd_theme_editor/scripts/__pycache__/ui_theme.cpython-310.pyc
ADDED
Binary file (6.32 kB). View file
extensions-builtin/sd_theme_editor/scripts/ui_theme.py
ADDED
@@ -0,0 +1,177 @@
import os
import shutil
from pathlib import Path
import gradio as gr
import modules.scripts as scripts
from modules import script_callbacks, shared

basedir = scripts.basedir()
webui_dir = Path(basedir).parents[1]

themes_folder = os.path.join(basedir, "themes")
javascript_folder = os.path.join(basedir, "javascript")
webui_style_path = os.path.join(webui_dir, "style.css")

def get_files(folder, file_filter=[], file_list=[], split=False):
    file_list = [file_name if not split else os.path.splitext(file_name)[0] for file_name in os.listdir(folder) if os.path.isfile(os.path.join(folder, file_name)) and file_name not in file_filter]
    return file_list


def on_ui_tabs():

    with gr.Blocks(analytics_enabled=False) as ui_theme:
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    themes_dropdown = gr.Dropdown(label="Themes", elem_id="themes_drop_down", interactive=True, choices=get_files(themes_folder,[".css, .txt"]), type="value")
                    save_as_filename = gr.Text(label="Save / Save as")
                with gr.Row():
                    reset_button = gr.Button(elem_id="theme_reset_btn", value="Reset", variant="primary")
                    #apply_button = gr.Button(elem_id="theme_apply_btn", value="Apply", variant="primary")
                    save_button = gr.Button(value="Save", variant="primary")
                    #delete_button = gr.Button(value="Delete", variant="primary")

            #with gr.Accordion(label="Debug View", open=True):
            with gr.Row(elem_id="theme_hidden"):
                vars_text = gr.Textbox(label="Vars", elem_id="theme_vars", show_label=True, lines=7, interactive=False, visible=True)
                css_text = gr.Textbox(label="Css", elem_id="theme_css", show_label=True, lines=7, interactive=False, visible=True)
                #result_text = gr.Text(elem_id="theme_result", interactive=False, visible=False)
        with gr.Column(elem_id="theme_overflow_container"):
            with gr.Accordion(label="Theme Color adjustments", open=True):
                with gr.Row():
                    with gr.Column(scale=6, elem_id="ui_theme_hsv"):
                        gr.Slider(elem_id="theme_hue", label='Hue', minimum=0, maximum=360, step=1)
                        gr.Slider(elem_id="theme_sat", label='Saturation', minimum=-100, maximum=100, step=1, value=0, interactive=True)
                        gr.Slider(elem_id="theme_brt", label='Lightness', minimum=-50, maximum=50, step=1, value=0, interactive=True)

                    gr.Button(elem_id="theme_invert_btn", value="Invert", variant="primary")


            with gr.Row(elem_id="ui_theme_settings"):
                with gr.Column():
                    with gr.Column():
                        with gr.Accordion(label="Main", open=True):
                            gr.ColorPicker(elem_id="--ae-main-bg-color", interactive=True, label="Background color")
                            gr.ColorPicker(elem_id="--ae-primary-color", label="Primary color")

                        with gr.Accordion(label="Focus", open=True):
                            gr.ColorPicker(elem_id="--ae-textarea-focus-color", label="Textarea color")
                            gr.ColorPicker(elem_id="--ae-input-focus-color", label="Input color")

                        with gr.Accordion(label="Spacing", open=True):
                            gr.Slider(elem_id="--ae-outside-gap-size", label='Gap size', minimum=1, maximum=16, step=1, interactive=True)
                            gr.Slider(elem_id="--ae-inside-padding-size", label='Padding size', minimum=1, maximum=16, step=1, interactive=True)

                        with gr.Accordion(label="Spacing (Mobile)", open=True):
                            gr.Slider(elem_id="--ae-mobile-outside-gap-size", label='Mobile Gap size', minimum=1, maximum=16, step=1, interactive=True)
                            gr.Slider(elem_id="--ae-mobile-inside-padding-size", label='Mobile Padding size', minimum=1, maximum=16, step=1, interactive=True)

                        with gr.Accordion(label="Panel", open=True):
                            gr.ColorPicker(elem_id="--ae-label-color", label="Label color")
                            gr.ColorPicker(elem_id="--ae-frame-bg-color", label="Frame Background color")
                            gr.ColorPicker(elem_id="--ae-panel-bg-color", label="Background color")
                            gr.ColorPicker(elem_id="--ae-panel-border-color", label="Border color")
                            gr.Slider(elem_id="--ae-panel-border-radius", label='Border radius', minimum=0, maximum=16, step=1)

                            gr.ColorPicker(elem_id="--ae-input-color", label="Input text color")
                            gr.ColorPicker(elem_id="--ae-input-bg-color", label="Input background color")
                            gr.ColorPicker(elem_id="--ae-input-border-color", label="Input border color")
                with gr.Column():
                    with gr.Row(elem_id="theme_sub-panel"):

                        with gr.Accordion(label="SubPanel", open=True):
                            gr.ColorPicker(elem_id="--ae-subgroup-bg-color", label="Subgoup background color")
                            #gr.ColorPicker(elem_id="--ae-subgroup-label-color", label="Label color", value="#000000")
                            gr.ColorPicker(elem_id="--ae-subpanel-bg-color", label="Background color")
                            gr.ColorPicker(elem_id="--ae-subpanel-border-color", label="Border color")
                            gr.Slider(elem_id="--ae-subpanel-border-radius", label='Border radius', minimum=0, maximum=16, step=1)

                            gr.ColorPicker(elem_id="--ae-subgroup-input-color", label="Input text color")
                            gr.ColorPicker(elem_id="--ae-subgroup-input-bg-color", label="Input background color")
                            gr.ColorPicker(elem_id="--ae-subgroup-input-border-color", label="Input border color")

                    with gr.Row():
                        with gr.Column():
                            with gr.Accordion(label="Navigation menu", open=True):
                                gr.ColorPicker(elem_id="--ae-nav-bg-color", label="Background color")
                                gr.ColorPicker(elem_id="--ae-nav-color", label="Text color")
                                gr.ColorPicker(elem_id="--ae-nav-hover-color", label="Hover color")

                            with gr.Accordion(label="Icon", open=True):
                                gr.ColorPicker(elem_id="--ae-icon-color", label="Color")
                                gr.ColorPicker(elem_id="--ae-icon-hover-color", label="Hover color")

                            with gr.Accordion(label="Other", open=True):
                                gr.ColorPicker(elem_id="--ae-text-color", label="Text color")
                                gr.ColorPicker(elem_id="--ae-placeholder-color", label="Placeholder color")
                                gr.ColorPicker(elem_id="--ae-cancel-color", label="Cancel/Interrupt color")

                            with gr.Accordion(label="Modal", open=True):
                                gr.ColorPicker(elem_id="--ae-modal-bg-color", label="Background color")
                                gr.ColorPicker(elem_id="--ae-modal-icon-color", label="Icon color")



        def save_theme(vars_text, css_text, filename):
            style_data = ":root{" + vars_text + "}" + css_text
            with open(os.path.join(themes_folder, f"{filename}.css"), 'w', encoding="utf-8") as file:
                file.write(vars_text)
                file.close()
            with open(webui_style_path, 'w', encoding="utf-8") as file:
                file.write(style_data)
                file.close()
            themes_dropdown.choices = get_files(themes_folder, [".css, .txt"])
            return gr.update(choices=themes_dropdown.choices, value=f"{filename}.css")

        def open_theme(filename, css_text):
            with open(os.path.join(themes_folder, f"{filename}"), 'r') as file:
                vars_text = file.read()
            no_ext = filename.rsplit('.', 1)[0]
            #save_theme(vars_text, css_text, no_ext)
            # shared.state.interrupt()
            # shared.state.need_restart = True
            return [vars_text, no_ext]

        # def delete_theme(filename):
        #     try:
        #         os.remove(os.path.join(themes_folder, filename))
        #     except FileNotFoundError:
        #         pass

        # delete_button.click(
        #     fn = lambda: delete_theme()
        # )

        save_button.click(
            fn=save_theme,
            inputs=[vars_text, css_text, save_as_filename],
            outputs=themes_dropdown
        )

        themes_dropdown.change(
            fn=open_theme,
            #_js = "applyTheme",
            inputs=[themes_dropdown, css_text],
            outputs=[vars_text, save_as_filename]
        )

        # apply_button.click(
        #     fn=None,
        #     _js = "applyTheme"
        # )

        # vars_text.change(
        #     fn=None,
        #     _js = "applyTheme",
        #     inputs=[],
        #     outputs=[vars_text, css_text]
        # )



    return (ui_theme, 'Theme', 'ui_theme'),


script_callbacks.on_ui_tabs(on_ui_tabs)
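
For illustration (not part of the diff): the Vars textarea that save_theme() and open_theme() read and write holds the same flat "--property:value;" string that the JavaScript above assembles, and the theme files under themes/ below store exactly that string. A rough sketch of parsing it back into a key/value map and rebuilding the ":root{...}" block that save_theme() prepends to the css text in style.css; the two declarations in themeVars are just a shortened sample.

// Shortened sample of a theme file's contents (see themes/default.css below for a full one).
const themeVars =
  "--ae-main-bg-color:hsl(0deg 0% 10%);--ae-primary-color:hsl(168deg 97% 41%);";

const styleMap = {};
for (const decl of themeVars.split(";")) {
  if (!decl.trim()) continue; // skip the empty piece after the trailing ";"
  const [prop, value] = decl.split(":"); // same "prop:value" split updateTheme() performs
  styleMap[prop.trim()] = value.trim();
}

// Rebuild the ":root{...}" block in the same shape save_theme() writes to style.css.
const rootBlock =
  ":root{" +
  Object.entries(styleMap)
    .map(([k, v]) => k + ":" + v + ";")
    .join("") +
  "}";
console.log(styleMap, rootBlock);
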
extensions-builtin/sd_theme_editor/style.css
ADDED
@@ -0,0 +1,113 @@
#theme_menu {
  z-index: 9999;
  background-color: var(--ae-input-bg-color);
  position: relative;
  width: 38px;
  height: 38px;
  border-radius: 100%;
  cursor: pointer;
  min-width: unset;
  max-width: 38px;
  align-self: center;
}

#theme_menu::before {
  content: " ";
  display: inline-block;
  -webkit-mask-size: cover;
  mask-size: cover;
  background-color: var(--ae-icon-color);
  width: var(--ae-icon-size);
  height: var(--ae-icon-size);
  -webkit-mask: url(./file=html/svg/contrast-drop-2-line.svg) no-repeat 50% 50%;
  mask: url(./file=html/svg/contrast-drop-2-line.svg) no-repeat 50% 50%;
  cursor: pointer;
  position: relative;
  left: 50%;
  top: 50%;
  transform: translate(-50%, -50%) scale(1);
}

#theme_menu.fixed,
#theme_menu:hover {
  background-color: var(--ae-icon-color);
}

#theme_menu.fixed::before,
#theme_menu:hover::before {
  background-color: var(--ae-icon-hover-color);
}

#theme_overflow_container {
  overflow-y: auto;
  height: calc(
    100vh - var(--ae-top-header-height) - (var(--ae-outside-gap-size) * 2) -
      (var(--ae-inside-padding-size) * 4) - 96px
  );
  overflow-x: hidden;
}

#tab_ui_theme.open {
  transform: translateX(0);
  box-shadow: rgba(0, 0, 0, 0.4) -30px 0 30px -30px;
}

#tab_ui_theme.aside {
  display: block !important;
}

#tab_ui_theme.aside {
  position: fixed;
  top: var(--ae-top-header-height);
  width: 90%;
  right: 0;
  height: calc(100% - var(--ae-top-header-height));
  max-width: 480px;
  z-index: 9999;
  transform: translateX(100%);
  transition: all 0.25s ease 0s;
  box-shadow: rgba(0, 0, 0, 0) -30px 0 30px -30px;
  padding: calc(1rem - var(--ae-outside-gap-size));
  background-color: var(--ae-main-bg-color) !important;
}
#tab_ui_theme.aside.open {
  transform: translateX(0);
  box-shadow: rgba(0, 0, 0, 0.4) -30px 0 30px -30px;
}

#theme_hidden,
#setting_ui_header_tabs .theme,
#setting_ui_hidden_tabs .theme {
  display: none !important;
}

#tab_ui_theme [id*="color"] label {
  display: flex;
  align-items: center;
  pointer-events: none;
}
#tab_ui_theme [id*="color"] label span {
  min-width: 50% !important;
}
#tab_ui_theme [id*="color"] label input {
  flex-grow: 1;
  pointer-events: all;
  cursor: pointer;
}

#settings_ui_theme > div > div {
  flex-direction: row;
  flex-wrap: wrap;
}
#settings_ui_theme > div > div > div {
  max-width: 30%;
}

#tab_ui_theme > div {
  padding: 16px !important;
  padding-top: 0 !important;
}

#ui_theme_hsv + button {
  min-width: unset;
}
extensions-builtin/sd_theme_editor/themes/Golde.css
ADDED
@@ -0,0 +1 @@
--ae-main-bg-color:hsl(99deg 11% 8%);--ae-primary-color:hsl(44deg 63% 55%);--ae-input-bg-color:hsl(106deg 8% 12%);--ae-input-border-color:hsl(104deg 9% 32%);--ae-panel-bg-color:hsl(104deg 9% 20%);--ae-panel-border-color:hsl(104deg 9% 32%);--ae-panel-border-radius:4px;--ae-subgroup-bg-color:hsl(99deg 11% 8%);--ae-subgroup-input-bg-color:hsl(99deg 11% 8%);--ae-subgroup-input-border-color:hsl(104deg 9% 32%);--ae-subpanel-bg-color:hsl(106deg 8% 12%);--ae-subpanel-border-color:hsl(104deg 9% 32%);--ae-subpanel-border-radius:8px;--ae-textarea-focus-color:hsl(56deg 30% 36%);--ae-input-focus-color:hsl(44deg 63% 55%);--ae-outside-gap-size:8px;--ae-inside-padding-size:8px;--ae-tool-button-size:34px;--ae-tool-button-radius:16px;--ae-generate-button-height:70px;--ae-cancel-color:hsl(104deg 9% 32%);--ae-max-padding:max(var(--ae-outside-gap-size),var(--ae-inside-padding-size));--ae-icon-color:hsl(105deg 9% 77%);--ae-icon-hover-color:hsl(99deg 11% 8%);--ae-icon-size:22px;--ae-nav-bg-color:hsl(98deg 9% 4%);--ae-nav-color:hsl(105deg 9% 77%);--ae-nav-hover-color:hsl(98deg 9% 4%);--ae-input-color:hsl(44deg 63% 55%);--ae-label-color:hsl(105deg 9% 77%);--ae-subgroup-input-color:hsl(44deg 63% 55%);--ae-placeholder-color:hsl(104deg 9% 32%);--ae-text-color:hsl(105deg 9% 77%);--ae-mobile-outside-gap-size:2px;--ae-mobile-inside-padding-size:2px;--ae-frame-bg-color:hsl(108deg 8% 12%);--ae-modal-bg-color:hsl(96deg 12% 8%);--ae-modal-icon-color:hsl(44deg 63% 55%);
extensions-builtin/sd_theme_editor/themes/backup.css
ADDED
@@ -0,0 +1 @@
--ae-main-bg-color:hsl(0deg 0% 10%);--ae-primary-color:hsl(168deg 96% 42%);--ae-input-bg-color:hsl(225deg 6% 13%);--ae-input-border-color:hsl(214deg 5% 30%);--ae-panel-bg-color:hsl(225deg 5% 17%);--ae-panel-border-color:hsl(214deg 5% 30%);--ae-panel-border-radius:0px;--ae-subgroup-bg-color:hsl(0deg 0% 10%);--ae-subgroup-input-bg-color:hsl(225deg 6% 13%);--ae-subgroup-input-border-color:hsl(214deg 5% 30%);--ae-subpanel-bg-color:hsl(220deg 4% 14%);--ae-subpanel-border-color:hsl(214deg 5% 30%);--ae-subpanel-border-radius:8px;--ae-textarea-focus-color:hsl(210deg 3% 36%);--ae-input-focus-color:hsl(168deg 97% 41%);--ae-outside-gap-size:8px;--ae-inside-padding-size:8px;--ae-tool-button-size:34px;--ae-tool-button-radius:16px;--ae-generate-button-height:70px;--ae-cancel-color:hsl(0deg 84% 60%);--ae-max-padding:max(var(--ae-outside-gap-size),var(--ae-inside-padding-size));--ae-icon-color:hsl(168deg 96% 42%);--ae-icon-hover-color:hsl(0deg 0% 10%);--ae-icon-size:22px;--ae-nav-bg-color:hsl(0deg 0% 4%);--ae-nav-color:hsl(210deg 4% 80%);--ae-nav-hover-color:hsl(0deg 0% 4%);--ae-input-color:hsl(210deg 4% 80%);--ae-label-color:hsl(210deg 4% 80%);--ae-subgroup-input-color:hsl(0deg 100% 100%);--ae-placeholder-color:hsl(214deg 5% 30%);--ae-text-color:hsl(210deg 4% 80%);--ae-mobile-outside-gap-size:3px;--ae-mobile-inside-padding-size:3px;--ae-frame-bg-color:hsl(225deg 6% 13%);--ae-modal-bg-color:hsl(0deg 0% 10%);--ae-modal-icon-color:hsl(168deg 97% 41%);
extensions-builtin/sd_theme_editor/themes/d-230-52-94.css
ADDED
@@ -0,0 +1 @@
--ae-main-bg-color:hsl(230deg 52% 4%);--ae-primary-color:hsl(38deg 148% 36%);--ae-input-bg-color:hsl(95deg 58% 7%);--ae-input-border-color:hsl(84deg 57% 24%);--ae-panel-bg-color:hsl(95deg 57% 11%);--ae-panel-border-color:hsl(84deg 57% 24%);--ae-panel-border-radius:0px;--ae-subgroup-bg-color:hsl(230deg 52% 4%);--ae-subgroup-input-bg-color:hsl(95deg 58% 7%);--ae-subgroup-input-border-color:hsl(84deg 57% 24%);--ae-subpanel-bg-color:hsl(90deg 56% 8%);--ae-subpanel-border-color:hsl(84deg 57% 24%);--ae-subpanel-border-radius:8px;--ae-textarea-focus-color:hsl(80deg 55% 30%);--ae-input-focus-color:hsl(38deg 149% 35%);--ae-outside-gap-size:8px;--ae-inside-padding-size:8px;--ae-tool-button-size:34px;--ae-tool-button-radius:16px;--ae-generate-button-height:70px;--ae-cancel-color:hsl(230deg 136% 54%);--ae-max-padding:max(var(--ae-outside-gap-size),var(--ae-inside-padding-size));--ae-icon-color:hsl(80deg 56% 74%);--ae-icon-hover-color:hsl(230deg 52% 4%);--ae-icon-size:22px;--ae-nav-bg-color:hsl(230deg 52% 98%);--ae-nav-color:hsl(80deg 56% 74%);--ae-nav-hover-color:hsl(230deg 52% 98%);--ae-input-color:hsl(80deg 56% 74%);--ae-label-color:hsl(80deg 56% 74%);--ae-subgroup-input-color:hsl(230deg 152% 94%);--ae-placeholder-color:hsl(84deg 57% 24%);--ae-text-color:hsl(80deg 56% 74%);--ae-mobile-outside-gap-size:3px;--ae-mobile-inside-padding-size:3px;--ae-frame-bg-color:hsl(94deg 60% 7%);--ae-modal-bg-color:hsl(229deg 52% 4%);--ae-modal-icon-color:hsl(38deg 100% 36%);
extensions-builtin/sd_theme_editor/themes/default.css
ADDED
@@ -0,0 +1 @@
--ae-main-bg-color:hsl(0deg 0% 10%);--ae-primary-color:hsl(168deg 97% 41%);--ae-input-bg-color:hsl(225deg 6% 13%);--ae-input-border-color:hsl(214deg 5% 30%);--ae-panel-bg-color:hsl(225deg 5% 17%);--ae-panel-border-color:hsl(214deg 5% 30%);--ae-panel-border-radius:0px;--ae-subgroup-bg-color:hsl(0deg 0% 10%);--ae-subgroup-input-bg-color:hsl(225deg 6% 13%);--ae-subgroup-input-border-color:hsl(214deg 5% 30%);--ae-subpanel-bg-color:hsl(220deg 4% 14%);--ae-subpanel-border-color:hsl(214deg 5% 30%);--ae-subpanel-border-radius:8px;--ae-textarea-focus-color:hsl(210deg 3% 36%);--ae-input-focus-color:hsl(168deg 97% 41%);--ae-outside-gap-size:8px;--ae-inside-padding-size:8px;--ae-tool-button-size:34px;--ae-tool-button-radius:16px;--ae-generate-button-height:70px;--ae-cancel-color:hsl(0deg 84% 60%);--ae-max-padding:max(var(--ae-outside-gap-size),var(--ae-inside-padding-size));--ae-icon-color:hsl(168deg 97% 41%);--ae-icon-hover-color:hsl(0deg 0% 10%);--ae-icon-size:22px;--ae-nav-bg-color:hsl(0deg 0% 4%);--ae-nav-color:hsl(210deg 4% 80%);--ae-nav-hover-color:hsl(0deg 0% 4%);--ae-input-color:hsl(210deg 4% 80%);--ae-label-color:hsl(210deg 4% 80%);--ae-subgroup-input-color:hsl(210deg 4% 80%);--ae-placeholder-color:hsl(214deg 5% 30%);--ae-text-color:hsl(210deg 4% 80%);--ae-mobile-outside-gap-size:2px;--ae-mobile-inside-padding-size:2px;--ae-frame-bg-color:hsl(225deg 6% 13%);--ae-modal-bg-color:hsl(0deg 0% 10%);--ae-modal-icon-color:hsl(168deg 97% 41%);
extensions-builtin/sd_theme_editor/themes/default_cyan.css
ADDED
@@ -0,0 +1 @@
--ae-main-bg-color:hsl(0deg 0% 10%);--ae-primary-color:hsl(199deg 60% 60%);--ae-input-bg-color:hsl(225deg 6% 13%);--ae-input-border-color:hsl(214deg 5% 30%);--ae-panel-bg-color:hsl(225deg 5% 17%);--ae-panel-border-color:hsl(214deg 5% 30%);--ae-panel-border-radius:0px;--ae-subgroup-bg-color:hsl(0deg 0% 10%);--ae-subgroup-input-bg-color:hsl(225deg 6% 13%);--ae-subgroup-input-border-color:hsl(214deg 5% 30%);--ae-subpanel-bg-color:hsl(220deg 4% 14%);--ae-subpanel-border-color:hsl(214deg 5% 30%);--ae-subpanel-border-radius:8px;--ae-textarea-focus-color:hsl(210deg 3% 36%);--ae-input-focus-color:hsl(199deg 60% 60%);--ae-outside-gap-size:8px;--ae-inside-padding-size:8px;--ae-tool-button-size:34px;--ae-tool-button-radius:16px;--ae-generate-button-height:70px;--ae-cancel-color:hsl(357deg 50% 57%);--ae-max-padding:max(var(--ae-outside-gap-size),var(--ae-inside-padding-size));--ae-icon-color:hsl(210deg 4% 80%);--ae-icon-hover-color:hsl(0deg 0% 10%);--ae-icon-size:22px;--ae-nav-bg-color:hsl(0deg 0% 4%);--ae-nav-color:hsl(210deg 4% 80%);--ae-nav-hover-color:hsl(0deg 0% 4%);--ae-input-color:hsl(210deg 4% 80%);--ae-label-color:hsl(210deg 4% 80%);--ae-subgroup-input-color:hsl(210deg 4% 80%);--ae-placeholder-color:hsl(214deg 5% 30%);--ae-text-color:hsl(210deg 4% 80%);--ae-mobile-outside-gap-size:2px;--ae-mobile-inside-padding-size:2px;--ae-frame-bg-color:hsl(225deg 6% 13%);--ae-modal-bg-color:hsl(0deg 0% 10%);--ae-modal-icon-color:hsl(199deg 60% 60%);
extensions-builtin/sd_theme_editor/themes/default_orange.css
ADDED
@@ -0,0 +1 @@
--ae-main-bg-color:hsl(0deg 0% 10%);--ae-primary-color:hsl(16deg 77% 60%);--ae-input-bg-color:hsl(225deg 6% 13%);--ae-input-border-color:hsl(214deg 5% 30%);--ae-panel-bg-color:hsl(225deg 5% 17%);--ae-panel-border-color:hsl(214deg 5% 30%);--ae-panel-border-radius:8px;--ae-subgroup-bg-color:hsl(0deg 0% 10%);--ae-subgroup-input-bg-color:hsl(225deg 6% 13%);--ae-subgroup-input-border-color:hsl(214deg 5% 30%);--ae-subpanel-bg-color:hsl(220deg 4% 14%);--ae-subpanel-border-color:hsl(214deg 5% 30%);--ae-subpanel-border-radius:8px;--ae-textarea-focus-color:hsl(210deg 3% 36%);--ae-input-focus-color:hsl(16deg 77% 60%);--ae-outside-gap-size:8px;--ae-inside-padding-size:8px;--ae-tool-button-size:34px;--ae-tool-button-radius:16px;--ae-generate-button-height:70px;--ae-cancel-color:hsl(193deg 54% 55%);--ae-max-padding:max(var(--ae-outside-gap-size),var(--ae-inside-padding-size));--ae-icon-color:hsl(210deg 4% 80%);--ae-icon-hover-color:hsl(0deg 0% 10%);--ae-icon-size:22px;--ae-nav-bg-color:hsl(0deg 0% 4%);--ae-nav-color:hsl(210deg 4% 80%);--ae-nav-hover-color:hsl(0deg 0% 4%);--ae-input-color:hsl(210deg 4% 80%);--ae-label-color:hsl(210deg 4% 80%);--ae-subgroup-input-color:hsl(210deg 4% 80%);--ae-placeholder-color:hsl(214deg 5% 30%);--ae-text-color:hsl(210deg 4% 80%);--ae-mobile-outside-gap-size:2px;--ae-mobile-inside-padding-size:2px;--ae-frame-bg-color:hsl(225deg 6% 13%);--ae-modal-bg-color:hsl(0deg 0% 10%);--ae-modal-icon-color:hsl(16deg 77% 60%);
extensions-builtin/sd_theme_editor/themes/fun.css
ADDED
@@ -0,0 +1 @@
--ae-main-bg-color:hsl(253deg 22% 8%);--ae-primary-color:hsl(76deg 96% 55%);--ae-input-bg-color:hsl(260deg 25% 12%);--ae-input-border-color:hsl(258deg 24% 32%);--ae-panel-bg-color:hsl(258deg 24% 20%);--ae-panel-border-color:hsl(258deg 24% 32%);--ae-panel-border-radius:4px;--ae-subgroup-bg-color:hsl(253deg 22% 8%);--ae-subgroup-input-bg-color:hsl(258deg 24% 8%);--ae-subgroup-input-border-color:hsl(258deg 24% 32%);--ae-subpanel-bg-color:hsl(260deg 25% 12%);--ae-subpanel-border-color:hsl(258deg 24% 32%);--ae-subpanel-border-radius:8px;--ae-textarea-focus-color:hsl(210deg 3% 36%);--ae-input-focus-color:hsl(296deg 96% 55%);--ae-outside-gap-size:8px;--ae-inside-padding-size:8px;--ae-tool-button-size:34px;--ae-tool-button-radius:16px;--ae-generate-button-height:70px;--ae-cancel-color:hsl(258deg 24% 32%);--ae-max-padding:max(var(--ae-outside-gap-size),var(--ae-inside-padding-size));--ae-icon-color:hsl(259deg 24% 77%);--ae-icon-hover-color:hsl(253deg 22% 8%);--ae-icon-size:22px;--ae-nav-bg-color:hsl(252deg 24% 4%);--ae-nav-color:hsl(259deg 24% 77%);--ae-nav-hover-color:hsl(252deg 24% 4%);--ae-input-color:hsl(305deg 96% 55%);--ae-label-color:hsl(259deg 24% 77%);--ae-subgroup-input-color:hsl(76deg 96% 55%);--ae-placeholder-color:hsl(258deg 24% 32%);--ae-text-color:hsl(259deg 24% 77%);--ae-mobile-outside-gap-size:2px;--ae-mobile-inside-padding-size:2px;--ae-frame-bg-color:hsl(260deg 25% 12%);--ae-modal-bg-color:hsl(253deg 22% 8%);--ae-modal-icon-color:hsl(76deg 96% 55%);
extensions-builtin/sd_theme_editor/themes/minimal.css
ADDED
@@ -0,0 +1 @@
--ae-main-bg-color:hsl(0deg 0% 8%);--ae-primary-color:hsl(168deg 96% 42%);--ae-input-bg-color:hsl(0deg 0% 10%);--ae-input-border-color:hsl(0deg 0% 10%);--ae-panel-bg-color:hsl(0deg 0% 17%);--ae-panel-border-color:hsl(0deg 0% 17%);--ae-panel-border-radius:4px;--ae-subgroup-bg-color:hsl(0deg 0% 10%);--ae-subgroup-input-bg-color:hsl(0deg 0% 10%);--ae-subgroup-input-border-color:hsl(0deg 0% 10%);--ae-subpanel-bg-color:hsl(0deg 0% 14%);--ae-subpanel-border-color:hsl(0deg 0% 15%);--ae-subpanel-border-radius:4px;--ae-textarea-focus-color:hsl(0deg 0% 36%);--ae-input-focus-color:hsl(168deg 97% 41%);--ae-outside-gap-size:1px;--ae-inside-padding-size:5px;--ae-tool-button-size:34px;--ae-tool-button-radius:16px;--ae-generate-button-height:70px;--ae-cancel-color:hsl(0deg 84% 60%);--ae-max-padding:max(var(--ae-outside-gap-size),var(--ae-inside-padding-size));--ae-icon-color:hsl(168deg 96% 42%);--ae-icon-hover-color:hsl(0deg 0% 10%);--ae-icon-size:22px;--ae-nav-bg-color:hsl(0deg 0% 4%);--ae-nav-color:hsl(0deg 0% 80%);--ae-nav-hover-color:hsl(0deg 0% 4%);--ae-input-color:hsl(210deg 4% 80%);--ae-label-color:hsl(0deg 0% 65%);--ae-subgroup-input-color:hsl(0deg 100% 100%);--ae-placeholder-color:hsl(0deg 0% 30%);--ae-text-color:hsl(0deg 0% 80%);--ae-mobile-outside-gap-size:3px;--ae-mobile-inside-padding-size:3px;--ae-frame-bg-color:hsl(0deg 0% 14%);--ae-modal-bg-color:hsl(0deg 0% 10%);--ae-modal-icon-color:hsl(168deg 97% 41%);
extensions-builtin/sd_theme_editor/themes/minimal_orange.css
ADDED
@@ -0,0 +1 @@
--ae-main-bg-color:hsl(210deg 28% 8%);--ae-primary-color:hsl(18deg 124% 42%);--ae-input-bg-color:hsl(210deg 28% 10%);--ae-input-border-color:hsl(210deg 28% 10%);--ae-panel-bg-color:hsl(210deg 28% 17%);--ae-panel-border-color:hsl(210deg 28% 17%);--ae-panel-border-radius:4px;--ae-subgroup-bg-color:hsl(210deg 28% 10%);--ae-subgroup-input-bg-color:hsl(210deg 28% 10%);--ae-subgroup-input-border-color:hsl(210deg 28% 10%);--ae-subpanel-bg-color:hsl(210deg 28% 14%);--ae-subpanel-border-color:hsl(210deg 28% 15%);--ae-subpanel-border-radius:4px;--ae-textarea-focus-color:hsl(210deg 28% 36%);--ae-input-focus-color:hsl(18deg 125% 41%);--ae-outside-gap-size:8px;--ae-inside-padding-size:8px;--ae-tool-button-size:34px;--ae-tool-button-radius:16px;--ae-generate-button-height:70px;--ae-cancel-color:hsl(210deg 112% 60%);--ae-max-padding:max(var(--ae-outside-gap-size),var(--ae-inside-padding-size));--ae-icon-color:hsl(18deg 124% 42%);--ae-icon-hover-color:hsl(210deg 28% 10%);--ae-icon-size:22px;--ae-nav-bg-color:hsl(210deg 28% 4%);--ae-nav-color:hsl(210deg 28% 80%);--ae-nav-hover-color:hsl(210deg 28% 4%);--ae-input-color:hsl(60deg 32% 80%);--ae-label-color:hsl(210deg 28% 65%);--ae-subgroup-input-color:hsl(210deg 128% 100%);--ae-placeholder-color:hsl(210deg 28% 30%);--ae-text-color:hsl(210deg 28% 80%);--ae-mobile-outside-gap-size:3px;--ae-mobile-inside-padding-size:3px;--ae-frame-bg-color:hsl(210deg 28% 14%);--ae-modal-bg-color:hsl(210deg 28% 10%);--ae-modal-icon-color:hsl(18deg 125% 41%);
extensions-builtin/sd_theme_editor/themes/moonlight.css
ADDED
@@ -0,0 +1 @@
--ae-main-bg-color:hsl(240deg 16% 6%);--ae-primary-color:hsl(222deg 75% 62%);--ae-input-bg-color:hsl(240deg 17% 8%);--ae-input-border-color:hsl(240deg 20% 16%);--ae-panel-bg-color:hsl(240deg 18% 12%);--ae-panel-border-color:hsl(240deg 16% 16%);--ae-panel-border-radius:0px;--ae-subgroup-bg-color:hsl(240deg 17% 8%);--ae-subgroup-input-bg-color:hsl(240deg 17% 8%);--ae-subgroup-input-border-color:hsl(240deg 20% 16%);--ae-subpanel-bg-color:hsl(240deg 18% 10%);--ae-subpanel-border-color:hsl(240deg 20% 16%);--ae-subpanel-border-radius:8px;--ae-textarea-focus-color:hsl(210deg 3% 36%);--ae-input-focus-color:hsl(222deg 75% 62%);--ae-outside-gap-size:8px;--ae-inside-padding-size:8px;--ae-tool-button-size:34px;--ae-tool-button-radius:16px;--ae-generate-button-height:70px;--ae-cancel-color:hsl(222deg 75% 62%);--ae-max-padding:max(var(--ae-outside-gap-size),var(--ae-inside-padding-size));--ae-icon-color:hsl(222deg 75% 62%);--ae-icon-hover-color:hsl(240deg 18% 12%);--ae-icon-size:22px;--ae-nav-bg-color:hsl(240deg 16% 6%);--ae-nav-color:hsl(185deg 66% 85%);--ae-nav-hover-color:hsl(0deg 0% 4%);--ae-input-color:hsl(185deg 66% 85%);--ae-label-color:hsl(185deg 66% 85%);--ae-subgroup-input-color:hsl(185deg 66% 85%);--ae-placeholder-color:hsl(240deg 20% 24%);--ae-text-color:hsl(185deg 66% 85%);--ae-mobile-outside-gap-size:2px;--ae-mobile-inside-padding-size:2px;--ae-frame-bg-color:hsl(240deg 18% 10%);--ae-modal-bg-color:hsl(240deg 16% 6%);--ae-modal-icon-color:hsl(222deg 75% 62%);
extensions-builtin/sd_theme_editor/themes/ogxBGreen.css
ADDED
@@ -0,0 +1 @@
--ae-main-bg-color:hsl(195deg 22% 8%);--ae-primary-color:hsl(159deg 96% 55%);--ae-input-bg-color:hsl(202deg 25% 12%);--ae-input-border-color:hsl(200deg 24% 32%);--ae-panel-bg-color:hsl(200deg 24% 20%);--ae-panel-border-color:hsl(200deg 24% 32%);--ae-panel-border-radius:4px;--ae-subgroup-bg-color:hsl(195deg 22% 8%);--ae-subgroup-input-bg-color:hsl(200deg 24% 8%);--ae-subgroup-input-border-color:hsl(200deg 24% 32%);--ae-subpanel-bg-color:hsl(202deg 25% 12%);--ae-subpanel-border-color:hsl(200deg 24% 32%);--ae-subpanel-border-radius:8px;--ae-textarea-focus-color:hsl(152deg 3% 36%);--ae-input-focus-color:hsl(159deg 96% 55%);--ae-outside-gap-size:8px;--ae-inside-padding-size:8px;--ae-tool-button-size:34px;--ae-tool-button-radius:16px;--ae-generate-button-height:70px;--ae-cancel-color:hsl(200deg 24% 32%);--ae-max-padding:max(var(--ae-outside-gap-size),var(--ae-inside-padding-size));--ae-icon-color:hsl(159deg 96% 55%);--ae-icon-hover-color:hsl(195deg 22% 8%);--ae-icon-size:22px;--ae-nav-bg-color:hsl(194deg 24% 4%);--ae-nav-color:hsl(201deg 24% 77%);--ae-nav-hover-color:hsl(194deg 24% 4%);--ae-input-color:hsl(159deg 96% 55%);--ae-label-color:hsl(201deg 24% 77%);--ae-subgroup-input-color:hsl(159deg 96% 55%);--ae-placeholder-color:hsl(200deg 24% 32%);--ae-text-color:hsl(201deg 24% 77%);--ae-mobile-outside-gap-size:3px;--ae-mobile-inside-padding-size:3px;--ae-frame-bg-color:hsl(200deg 25% 12%);--ae-modal-bg-color:hsl(193deg 22% 8%);--ae-modal-icon-color:hsl(159deg 96% 55%);
extensions-builtin/sd_theme_editor/themes/ogxCyan.css
ADDED
@@ -0,0 +1 @@
--ae-main-bg-color:hsl(253deg 22% 8%);--ae-primary-color:hsl(198deg 96% 55%);--ae-input-bg-color:hsl(260deg 25% 12%);--ae-input-border-color:hsl(258deg 24% 32%);--ae-panel-bg-color:hsl(258deg 24% 20%);--ae-panel-border-color:hsl(258deg 24% 32%);--ae-panel-border-radius:4px;--ae-subgroup-bg-color:hsl(253deg 22% 8%);--ae-subgroup-input-bg-color:hsl(258deg 24% 8%);--ae-subgroup-input-border-color:hsl(258deg 24% 32%);--ae-subpanel-bg-color:hsl(260deg 25% 12%);--ae-subpanel-border-color:hsl(258deg 24% 32%);--ae-subpanel-border-radius:8px;--ae-textarea-focus-color:hsl(210deg 3% 36%);--ae-input-focus-color:hsl(198deg 96% 55%);--ae-outside-gap-size:8px;--ae-inside-padding-size:8px;--ae-tool-button-size:34px;--ae-tool-button-radius:16px;--ae-generate-button-height:70px;--ae-cancel-color:hsl(258deg 24% 32%);--ae-max-padding:max(var(--ae-outside-gap-size),var(--ae-inside-padding-size));--ae-icon-color:hsl(259deg 24% 77%);--ae-icon-hover-color:hsl(253deg 22% 8%);--ae-icon-size:22px;--ae-nav-bg-color:hsl(252deg 24% 4%);--ae-nav-color:hsl(259deg 24% 77%);--ae-nav-hover-color:hsl(252deg 24% 4%);--ae-input-color:hsl(198deg 96% 55%);--ae-label-color:hsl(259deg 24% 77%);--ae-subgroup-input-color:hsl(198deg 96% 55%);--ae-placeholder-color:hsl(258deg 24% 32%);--ae-text-color:hsl(259deg 24% 77%);--ae-mobile-outside-gap-size:2px;--ae-mobile-inside-padding-size:2px;--ae-frame-bg-color:hsl(260deg 25% 12%);--ae-modal-bg-color:hsl(253deg 22% 8%);--ae-modal-icon-color:hsl(198deg 96% 55%);
extensions-builtin/sd_theme_editor/themes/ogxCyanInvert.css
ADDED
@@ -0,0 +1 @@
--ae-main-bg-color:hsl(73deg 22% 92%);--ae-primary-color:hsl(18deg 96% 45%);--ae-input-bg-color:hsl(80deg 25% 88%);--ae-input-border-color:hsl(78deg 24% 68%);--ae-panel-bg-color:hsl(78deg 24% 80%);--ae-panel-border-color:hsl(78deg 24% 68%);--ae-panel-border-radius:4px;--ae-subgroup-bg-color:hsl(73deg 22% 92%);--ae-subgroup-input-bg-color:hsl(73deg 22% 92%);--ae-subgroup-input-border-color:hsl(78deg 24% 68%);--ae-subpanel-bg-color:hsl(80deg 25% 88%);--ae-subpanel-border-color:hsl(78deg 24% 68%);--ae-subpanel-border-radius:8px;--ae-textarea-focus-color:hsl(30deg 3% 64%);--ae-input-focus-color:hsl(18deg 96% 45%);--ae-outside-gap-size:8px;--ae-inside-padding-size:8px;--ae-tool-button-size:34px;--ae-tool-button-radius:16px;--ae-generate-button-height:70px;--ae-cancel-color:hsl(78deg 24% 68%);--ae-max-padding:max(var(--ae-outside-gap-size),var(--ae-inside-padding-size));--ae-icon-color:hsl(79deg 24% 23%);--ae-icon-hover-color:hsl(73deg 22% 92%);--ae-icon-size:22px;--ae-nav-bg-color:hsl(72deg 24% 96%);--ae-nav-color:hsl(79deg 24% 23%);--ae-nav-hover-color:hsl(72deg 24% 96%);--ae-input-color:hsl(18deg 96% 45%);--ae-label-color:hsl(79deg 24% 23%);--ae-subgroup-input-color:hsl(18deg 96% 45%);--ae-placeholder-color:hsl(78deg 24% 68%);--ae-text-color:hsl(79deg 24% 23%);--ae-mobile-outside-gap-size:2px;--ae-mobile-inside-padding-size:2px;--ae-frame-bg-color:hsl(80deg 25% 88%);--ae-modal-bg-color:hsl(73deg 22% 92%);--ae-modal-icon-color:hsl(18deg 96% 45%);
extensions-builtin/sd_theme_editor/themes/ogxGreen.css
ADDED
@@ -0,0 +1 @@
--ae-main-bg-color:hsl(253deg 22% 8%);--ae-primary-color:hsl(149deg 96% 55%);--ae-input-bg-color:hsl(260deg 25% 12%);--ae-input-border-color:hsl(258deg 24% 32%);--ae-panel-bg-color:hsl(258deg 24% 20%);--ae-panel-border-color:hsl(258deg 24% 32%);--ae-panel-border-radius:4px;--ae-subgroup-bg-color:hsl(253deg 22% 8%);--ae-subgroup-input-bg-color:hsl(258deg 24% 8%);--ae-subgroup-input-border-color:hsl(258deg 24% 32%);--ae-subpanel-bg-color:hsl(260deg 25% 12%);--ae-subpanel-border-color:hsl(258deg 24% 32%);--ae-subpanel-border-radius:8px;--ae-textarea-focus-color:hsl(210deg 3% 36%);--ae-input-focus-color:hsl(149deg 96% 55%);--ae-outside-gap-size:8px;--ae-inside-padding-size:8px;--ae-tool-button-size:34px;--ae-tool-button-radius:16px;--ae-generate-button-height:70px;--ae-cancel-color:hsl(258deg 24% 32%);--ae-max-padding:max(var(--ae-outside-gap-size),var(--ae-inside-padding-size));--ae-icon-color:hsl(259deg 24% 77%);--ae-icon-hover-color:hsl(253deg 22% 8%);--ae-icon-size:22px;--ae-nav-bg-color:hsl(252deg 24% 4%);--ae-nav-color:hsl(259deg 24% 77%);--ae-nav-hover-color:hsl(252deg 24% 4%);--ae-input-color:hsl(149deg 96% 55%);--ae-label-color:hsl(259deg 24% 77%);--ae-subgroup-input-color:hsl(149deg 96% 55%);--ae-placeholder-color:hsl(258deg 24% 32%);--ae-text-color:hsl(259deg 24% 77%);--ae-mobile-outside-gap-size:3px;--ae-mobile-inside-padding-size:3px;--ae-frame-bg-color:hsl(260deg 25% 12%);--ae-modal-bg-color:hsl(253deg 22% 8%);--ae-modal-icon-color:hsl(253deg 22% 8%);
extensions-builtin/sd_theme_editor/themes/ogxRed.css
ADDED
@@ -0,0 +1 @@
--ae-main-bg-color:hsl(253deg 22% 8%);--ae-primary-color:hsl(347deg 96% 55%);--ae-input-bg-color:hsl(260deg 25% 12%);--ae-input-border-color:hsl(258deg 24% 32%);--ae-panel-bg-color:hsl(258deg 24% 20%);--ae-panel-border-color:hsl(258deg 24% 32%);--ae-panel-border-radius:4px;--ae-subgroup-bg-color:hsl(253deg 22% 8%);--ae-subgroup-input-bg-color:hsl(258deg 24% 8%);--ae-subgroup-input-border-color:hsl(258deg 24% 32%);--ae-subpanel-bg-color:hsl(260deg 25% 12%);--ae-subpanel-border-color:hsl(258deg 24% 32%);--ae-subpanel-border-radius:8px;--ae-textarea-focus-color:hsl(210deg 3% 36%);--ae-input-focus-color:hsl(347deg 96% 55%);--ae-outside-gap-size:8px;--ae-inside-padding-size:8px;--ae-tool-button-size:34px;--ae-tool-button-radius:16px;--ae-generate-button-height:70px;--ae-cancel-color:hsl(258deg 24% 32%);--ae-max-padding:max(var(--ae-outside-gap-size),var(--ae-inside-padding-size));--ae-icon-color:hsl(259deg 24% 77%);--ae-icon-hover-color:hsl(253deg 22% 8%);--ae-icon-size:22px;--ae-nav-bg-color:hsl(252deg 24% 4%);--ae-nav-color:hsl(259deg 24% 77%);--ae-nav-hover-color:hsl(252deg 24% 4%);--ae-input-color:hsl(347deg 96% 55%);--ae-label-color:hsl(259deg 24% 77%);--ae-subgroup-input-color:hsl(347deg 96% 55%);--ae-placeholder-color:hsl(258deg 24% 32%);--ae-text-color:hsl(259deg 24% 77%);--ae-mobile-outside-gap-size:3px;--ae-mobile-inside-padding-size:3px;--ae-frame-bg-color:hsl(260deg 25% 12%);--ae-modal-bg-color:hsl(253deg 22% 8%);--ae-modal-icon-color:hsl(347deg 96% 55%);
extensions-builtin/sd_theme_editor/themes/retrog.css
ADDED
@@ -0,0 +1 @@
--ae-main-bg-color:hsl(197deg 97% 14%);--ae-primary-color:hsl(27deg 99% 50%);--ae-input-bg-color:hsl(197deg 98% 16%);--ae-input-border-color:hsl(166deg 62% 33%);--ae-panel-bg-color:hsl(196deg 98% 18%);--ae-panel-border-color:hsl(166deg 62% 33%);--ae-panel-border-radius:0px;--ae-subgroup-bg-color:hsl(197deg 97% 14%);--ae-subgroup-input-bg-color:hsl(197deg 97% 14%);--ae-subgroup-input-border-color:hsl(166deg 62% 33%);--ae-subpanel-bg-color:hsl(197deg 98% 16%);--ae-subpanel-border-color:hsl(166deg 62% 33%);--ae-subpanel-border-radius:8px;--ae-textarea-focus-color:hsl(210deg 3% 36%);--ae-input-focus-color:hsl(222deg 75% 62%);--ae-outside-gap-size:6px;--ae-inside-padding-size:6px;--ae-tool-button-size:34px;--ae-tool-button-radius:16px;--ae-generate-button-height:70px;--ae-cancel-color:hsl(70deg 69% 54%);--ae-max-padding:max(var(--ae-outside-gap-size),var(--ae-inside-padding-size));--ae-icon-color:hsl(27deg 99% 50%);--ae-icon-hover-color:hsl(196deg 98% 18%);--ae-icon-size:22px;--ae-nav-bg-color:hsl(197deg 97% 14%);--ae-nav-color:hsl(70deg 69% 54%);--ae-nav-hover-color:hsl(197deg 97% 14%);--ae-input-color:hsl(70deg 69% 54%);--ae-label-color:hsl(70deg 69% 54%);--ae-subgroup-input-color:hsl(27deg 99% 50%);--ae-placeholder-color:hsl(166deg 62% 33%);--ae-text-color:hsl(185deg 66% 85%);--ae-mobile-outside-gap-size:3px;--ae-mobile-inside-padding-size:3px;--ae-frame-bg-color:hsl(197deg 98% 16%);--ae-modal-bg-color:hsl(197deg 97% 14%);--ae-modal-icon-color:hsl(27deg 99% 50%);
extensions-builtin/sd_theme_editor/themes/tron.css
ADDED
@@ -0,0 +1 @@
--ae-main-bg-color:hsl(185deg 75% 3%);--ae-primary-color:hsl(182deg 95% 51%);--ae-input-bg-color:hsl(185deg 73% 3%);--ae-input-border-color:hsl(185deg 72% 25%);--ae-panel-bg-color:hsl(180deg 76% 5%);--ae-panel-border-color:hsl(185deg 72% 25%);--ae-panel-border-radius:0px;--ae-subgroup-bg-color:hsl(185deg 73% 3%);--ae-subgroup-input-bg-color:hsl(185deg 73% 3%);--ae-subgroup-input-border-color:hsl(185deg 72% 25%);--ae-subpanel-bg-color:hsl(185deg 73% 3%);--ae-subpanel-border-color:hsl(185deg 72% 25%);--ae-subpanel-border-radius:0px;--ae-textarea-focus-color:hsl(182deg 95% 51%);--ae-input-focus-color:hsl(182deg 95% 51%);--ae-outside-gap-size:2px;--ae-inside-padding-size:8px;--ae-tool-button-size:34px;--ae-tool-button-radius:16px;--ae-generate-button-height:70px;--ae-cancel-color:hsl(182deg 95% 51%);--ae-max-padding:max(var(--ae-outside-gap-size),var(--ae-inside-padding-size));--ae-icon-color:hsl(182deg 95% 51%);--ae-icon-hover-color:hsl(185deg 73% 3%);--ae-icon-size:22px;--ae-nav-bg-color:hsl(185deg 73% 3%);--ae-nav-color:hsl(182deg 95% 75%);--ae-nav-hover-color:hsl(0deg 100% 50%);--ae-input-color:hsl(182deg 95% 75%);--ae-label-color:hsl(182deg 95% 75%);--ae-subgroup-input-color:hsl(182deg 95% 51%);--ae-placeholder-color:hsl(185deg 72% 25%);--ae-text-color:hsl(182deg 95% 75%);--ae-mobile-outside-gap-size:2px;--ae-mobile-inside-padding-size:8px;--ae-frame-bg-color:hsl(180deg 76% 5%);--ae-modal-bg-color:hsl(185deg 73% 3%);--ae-modal-icon-color:hsl(182deg 95% 51%);
extensions-builtin/sd_theme_editor/themes/tron2.css
ADDED
@@ -0,0 +1 @@
--ae-main-bg-color:hsl(185deg 75% 3%);--ae-primary-color:hsl(182deg 95% 51%);--ae-input-bg-color:hsl(184deg 89% 7%);--ae-input-border-color:hsl(185deg 72% 25%);--ae-panel-bg-color:hsl(185deg 73% 3%);--ae-panel-border-color:hsl(185deg 72% 25%);--ae-panel-border-radius:0px;--ae-subgroup-bg-color:hsl(185deg 73% 3%);--ae-subgroup-input-bg-color:hsl(185deg 73% 3%);--ae-subgroup-input-border-color:hsl(185deg 72% 25%);--ae-subpanel-bg-color:hsl(185deg 73% 3%);--ae-subpanel-border-color:hsl(185deg 72% 25%);--ae-subpanel-border-radius:0px;--ae-textarea-focus-color:hsl(182deg 95% 51%);--ae-input-focus-color:hsl(182deg 95% 51%);--ae-outside-gap-size:2px;--ae-inside-padding-size:8px;--ae-tool-button-size:34px;--ae-tool-button-radius:16px;--ae-generate-button-height:70px;--ae-cancel-color:hsl(182deg 95% 51%);--ae-max-padding:max(var(--ae-outside-gap-size),var(--ae-inside-padding-size));--ae-icon-color:hsl(182deg 95% 51%);--ae-icon-hover-color:hsl(185deg 73% 3%);--ae-icon-size:22px;--ae-nav-bg-color:hsl(185deg 73% 3%);--ae-nav-color:hsl(182deg 95% 75%);--ae-nav-hover-color:hsl(0deg 100% 50%);--ae-input-color:hsl(182deg 95% 75%);--ae-label-color:hsl(182deg 95% 75%);--ae-subgroup-input-color:hsl(182deg 95% 51%);--ae-placeholder-color:hsl(185deg 72% 25%);--ae-text-color:hsl(182deg 95% 75%);--ae-mobile-outside-gap-size:2px;--ae-mobile-inside-padding-size:8px;--ae-frame-bg-color:hsl(185deg 73% 3%);
html/200w.webp
ADDED
html/card-no-preview.png
ADDED
html/extra-networks-card.html
ADDED
@@ -0,0 +1,18 @@
<div class='card-container' >
  <div class='card' onclick={card_clicked} data-name="{name}" {sort_keys}>
    <img src={preview_image} loading="lazy">
    {metadata_button}
    {edit_button}
    <div class='actions'>
      <div class='additional'>
        <ul>
          <a href="#" title="replace preview image with currently selected in gallery" onclick={save_card_preview}>replace preview</a>
        </ul>
        <span style="display:none" class='search_term{search_only}'>{search_term}</span>
      </div>
      <span class='name'>{name}</span>
    </div>
    <span class='description'>{description}</span>
  </div>
  <i class="image-icon"></i>
</div>
html/extra-networks-no-cards.html
ADDED
@@ -0,0 +1,8 @@
<div class='nocards'>
  <h1>Nothing here. Add some content to the following directories:</h1>

  <ul>
    {dirs}
  </ul>
</div>
html/favicon.ico
ADDED
html/footer.html
ADDED
@@ -0,0 +1,55 @@
<div class="footer-wrapper">
  <ul class="footer-links">
    <li>
      <a href="/docs" data-tooltip="Use via API" data-position="top" class="top">
        <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" width="24" height="24"><path fill="none" d="M0 0h24v24H0z"/><path d="M6.2,9.6c0.5-0.7,1-1.4,1.6-2c3.1-3.1,7.5-4,11.1-2.5c1.4,3.6,0.6,8-2.5,11.1c-0.6,0.6-1.3,1.1-2,1.6l0.1,2.3c0,0.2-0.1,0.4-0.3,0.4l-4,1c-0.2,0.1-0.5-0.1-0.5-0.3c0,0,0-0.1,0-0.1v-2.7c0-0.2-0.1-0.4-0.3-0.6l-3.2-3.2c-0.2-0.2-0.4-0.3-0.6-0.3H2.9c-0.2,0-0.4-0.2-0.4-0.4c0,0,0-0.1,0-0.1l1-4c0.1-0.2,0.2-0.3,0.4-0.3C3.9,9.5,6.2,9.6,6.2,9.6zM12.1,11.9c0.7,0.7,1.8,0.7,2.4,0c0.7-0.7,0.7-1.8,0-2.4c-0.7-0.7-1.8-0.7-2.4,0C11.4,10.2,11.4,11.3,12.1,11.9z"/></svg>
      </a>
    </li>
    <li>
      <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui" data-tooltip="Github" data-position="top" class="top">
        <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" width="24" height="24"><path fill="none" d="M0 0h24v24H0z"/><path d="M12 2C6.475 2 2 6.475 2 12a9.994 9.994 0 0 0 6.838 9.488c.5.087.687-.213.687-.476 0-.237-.013-1.024-.013-1.862-2.512.463-3.162-.612-3.362-1.175-.113-.288-.6-1.175-1.025-1.413-.35-.187-.85-.65-.013-.662.788-.013 1.35.725 1.538 1.025.9 1.512 2.338 1.087 2.912.825.088-.65.35-1.087.638-1.337-2.225-.25-4.55-1.113-4.55-4.938 0-1.088.387-1.987 1.025-2.688-.1-.25-.45-1.275.1-2.65 0 0 .837-.262 2.75 1.026a9.28 9.28 0 0 1 2.5-.338c.85 0 1.7.112 2.5.337 1.912-1.3 2.75-1.024 2.75-1.024.55 1.375.2 2.4.1 2.65.637.7 1.025 1.587 1.025 2.687 0 3.838-2.337 4.688-4.562 4.938.362.312.675.912.675 1.85 0 1.337-.013 2.412-.013 2.75 0 .262.188.574.688.474A10.016 10.016 0 0 0 22 12c0-5.525-4.475-10-10-10z"/></svg></a>
    </li>
    <li>
      <a href="https://gradio.app" data-tooltip="Gradio" data-position="top" class="top">
        <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" width="24" height="24"><path d="M0,0H24V24H0V0Z" style="fill: none;"/><path d="M21.14,10.92h0s0-2.91,0-2.91L12.04,3.02v2.91l6.45,3.54-1.95,1.07-4.5-2.47v2.91l1.85,1.01-1.88,1.03-1.85-1.02,1.89-1.03v-2.93l-4.56,2.49-1.94-1.06,6.5-3.55V3L2.86,8.01h0s0,0,0,0v2.93l1.94,1.06-1.94,1.06h0s0,0,0,0v2.93l9.14,5.01,9.14-5.01v-.02h0s0-2.91,0-2.91l-1.94-1.06,1.93-1.06v-.02Zm-2.65,3.6l-6.49,3.56-6.46-3.54,1.94-1.06,4.53,2.48h0s4.55-2.5,4.55-2.5l1.93,1.06Z"/></svg>
      </a>
    </li>
    <li>
      <a href="https://github.com/anapnoe/stable-diffusion-webui-ux" data-tooltip="UI/UX design by anapnoe" data-position="top" class="top">
        <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" width="24" height="24">
          <polygon points="19.5,13.7 21.6,9.9 11.8,9.9 16.7,18.5 17.7,16.7 17.1,15.7 16.7,16.5 13.5,10.9 19.9,10.9 18.9,12.7 "/>
          <polygon points="17.3,11.7 15.8,11.7 20.2,19.3 3.8,19.3 12,5.2 14.2,9 15.8,9 12,2.5 1.5,20.7 22.5,20.7 "/>
        </svg>
    </li>
    <li>
      <a href="/" onclick="javascript:gradioApp().getElementById('settings_restart_gradio').click(); return false" data-tooltip="Reload UI" data-position="top" class="top">
        <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" width="24" height="24"><path fill="none" d="M0 0h24v24H0z"/><path d="M12 22C6.477 22 2 17.523 2 12S6.477 2 12 2s10 4.477 10 10-4.477 10-10 10zm4.82-4.924a7 7 0 1 0-1.852 1.266l-.975-1.755A5 5 0 1 1 17 12h-3l2.82 5.076z"/></svg>
      </a>
    </li>
    <li>
      <div class="tooltip-html">
        <i class="icon-info">
          <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" width="24" height="24"><path fill="none" d="M0 0h24v24H0z"/><path d="M13 18v2h4v2H7v-2h4v-2H2.992A.998.998 0 0 1 2 16.993V4.007C2 3.451 2.455 3 2.992 3h18.016c.548 0 .992.449.992 1.007v12.986c0 .556-.455 1.007-.992 1.007H13z"/></svg>
        </i>
        <div class="top center">
          {versions}
          <i></i>
        </div>
      </div>
    </li>
    <li class="coffee-circle">
      <div class="tooltip-html">
        <i class="coffee">
          <a href="https://buymeacoffee.com/dayanbayah">
            <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" width="24" height="24"><path fill="red" d="M12.001 4.529c2.349-2.109 5.979-2.039 8.242.228 2.262 2.268 2.34 5.88.236 8.236l-8.48 8.492-8.478-8.492c-2.104-2.356-2.025-5.974.236-8.236 2.265-2.264 5.888-2.34 8.244-.228z"/></svg>
          </a>
        </i>
        <div class="top">
          <img src="./file=html/200w.webp" width="160" height="90" alt="thanks for your support">
          <i></i>
          <p style="color:var(--ae-primary-color);">Your donation can help fund the continued development of Stable Diffusion web UI/UX project. Enjoy!</p>
        </div>
      </div>
    </li>
  </ul>
</div>