toto10 committed on
Commit
6018960
1 Parent(s): 55ce6e1

a76b4ccb8bb91c9e097ac910cbd005a5983176e8658571782cad9f042c9de31e

Files changed (50)
  1. extensions-builtin/ScuNET/__pycache__/scunet_model_arch.cpython-310.pyc +0 -0
  2. extensions-builtin/ScuNET/preload.py +6 -0
  3. extensions-builtin/ScuNET/scripts/__pycache__/scunet_model.cpython-310.pyc +0 -0
  4. extensions-builtin/ScuNET/scripts/scunet_model.py +144 -0
  5. extensions-builtin/ScuNET/scunet_model_arch.py +268 -0
  6. extensions-builtin/SwinIR/__pycache__/preload.cpython-310.pyc +0 -0
  7. extensions-builtin/SwinIR/__pycache__/swinir_model_arch.cpython-310.pyc +0 -0
  8. extensions-builtin/SwinIR/__pycache__/swinir_model_arch_v2.cpython-310.pyc +0 -0
  9. extensions-builtin/SwinIR/preload.py +6 -0
  10. extensions-builtin/SwinIR/scripts/__pycache__/swinir_model.cpython-310.pyc +0 -0
  11. extensions-builtin/SwinIR/scripts/swinir_model.py +192 -0
  12. extensions-builtin/SwinIR/swinir_model_arch.py +867 -0
  13. extensions-builtin/SwinIR/swinir_model_arch_v2.py +1017 -0
  14. extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js +776 -0
  15. extensions-builtin/canvas-zoom-and-pan/scripts/__pycache__/hotkey_config.cpython-310.pyc +0 -0
  16. extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py +14 -0
  17. extensions-builtin/canvas-zoom-and-pan/style.css +63 -0
  18. extensions-builtin/extra-options-section/scripts/__pycache__/extra_options_section.cpython-310.pyc +0 -0
  19. extensions-builtin/extra-options-section/scripts/extra_options_section.py +48 -0
  20. extensions-builtin/mobile/javascript/mobile.js +26 -0
  21. extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js +42 -0
  22. extensions-builtin/sd_theme_editor/install.py +1 -0
  23. extensions-builtin/sd_theme_editor/javascript/ui_theme.js +435 -0
  24. extensions-builtin/sd_theme_editor/scripts/__pycache__/ui_theme.cpython-310.pyc +0 -0
  25. extensions-builtin/sd_theme_editor/scripts/ui_theme.py +177 -0
  26. extensions-builtin/sd_theme_editor/style.css +113 -0
  27. extensions-builtin/sd_theme_editor/themes/Golde.css +1 -0
  28. extensions-builtin/sd_theme_editor/themes/backup.css +1 -0
  29. extensions-builtin/sd_theme_editor/themes/d-230-52-94.css +1 -0
  30. extensions-builtin/sd_theme_editor/themes/default.css +1 -0
  31. extensions-builtin/sd_theme_editor/themes/default_cyan.css +1 -0
  32. extensions-builtin/sd_theme_editor/themes/default_orange.css +1 -0
  33. extensions-builtin/sd_theme_editor/themes/fun.css +1 -0
  34. extensions-builtin/sd_theme_editor/themes/minimal.css +1 -0
  35. extensions-builtin/sd_theme_editor/themes/minimal_orange.css +1 -0
  36. extensions-builtin/sd_theme_editor/themes/moonlight.css +1 -0
  37. extensions-builtin/sd_theme_editor/themes/ogxBGreen.css +1 -0
  38. extensions-builtin/sd_theme_editor/themes/ogxCyan.css +1 -0
  39. extensions-builtin/sd_theme_editor/themes/ogxCyanInvert.css +1 -0
  40. extensions-builtin/sd_theme_editor/themes/ogxGreen.css +1 -0
  41. extensions-builtin/sd_theme_editor/themes/ogxRed.css +1 -0
  42. extensions-builtin/sd_theme_editor/themes/retrog.css +1 -0
  43. extensions-builtin/sd_theme_editor/themes/tron.css +1 -0
  44. extensions-builtin/sd_theme_editor/themes/tron2.css +1 -0
  45. html/200w.webp +0 -0
  46. html/card-no-preview.png +0 -0
  47. html/extra-networks-card.html +18 -0
  48. html/extra-networks-no-cards.html +8 -0
  49. html/favicon.ico +0 -0
  50. html/footer.html +55 -0
extensions-builtin/ScuNET/__pycache__/scunet_model_arch.cpython-310.pyc ADDED
Binary file (9.27 kB).
 
extensions-builtin/ScuNET/preload.py ADDED
@@ -0,0 +1,6 @@
+ import os
+ from modules import paths
+
+
+ def preload(parser):
+     parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(paths.models_path, 'ScuNET'))
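Note: the preload hook above only registers an extra command-line flag. A minimal standalone sketch of how such a hook is exercised, using only the standard library (MODELS_PATH is a hypothetical stand-in for modules.paths.models_path so the snippet is self-contained):

import argparse
import os

MODELS_PATH = os.path.join(os.getcwd(), "models")  # stand-in for modules.paths.models_path


def preload(parser):
    # mirrors the extension's hook: it only adds one CLI option with a default path
    parser.add_argument("--scunet-models-path", type=str,
                        help="Path to directory with ScuNET model file(s).",
                        default=os.path.join(MODELS_PATH, "ScuNET"))


parser = argparse.ArgumentParser()
preload(parser)
args = parser.parse_args(["--scunet-models-path", "/tmp/ScuNET"])
print(args.scunet_models_path)  # -> /tmp/ScuNET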
extensions-builtin/ScuNET/scripts/__pycache__/scunet_model.cpython-310.pyc ADDED
Binary file (5.02 kB).
 
extensions-builtin/ScuNET/scripts/scunet_model.py ADDED
@@ -0,0 +1,144 @@
+ import sys
+
+ import PIL.Image
+ import numpy as np
+ import torch
+ from tqdm import tqdm
+
+ import modules.upscaler
+ from modules import devices, modelloader, script_callbacks, errors
+ from scunet_model_arch import SCUNet
+
+ from modules.modelloader import load_file_from_url
+ from modules.shared import opts
+
+
+ class UpscalerScuNET(modules.upscaler.Upscaler):
+     def __init__(self, dirname):
+         self.name = "ScuNET"
+         self.model_name = "ScuNET GAN"
+         self.model_name2 = "ScuNET PSNR"
+         self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_gan.pth"
+         self.model_url2 = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth"
+         self.user_path = dirname
+         super().__init__()
+         model_paths = self.find_models(ext_filter=[".pth"])
+         scalers = []
+         add_model2 = True
+         for file in model_paths:
+             if file.startswith("http"):
+                 name = self.model_name
+             else:
+                 name = modelloader.friendly_name(file)
+             if name == self.model_name2 or file == self.model_url2:
+                 add_model2 = False
+             try:
+                 scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
+                 scalers.append(scaler_data)
+             except Exception:
+                 errors.report(f"Error loading ScuNET model: {file}", exc_info=True)
+         if add_model2:
+             scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self)
+             scalers.append(scaler_data2)
+         self.scalers = scalers
+
+     @staticmethod
+     @torch.no_grad()
+     def tiled_inference(img, model):
+         # test the image tile by tile
+         h, w = img.shape[2:]
+         tile = opts.SCUNET_tile
+         tile_overlap = opts.SCUNET_tile_overlap
+         if tile == 0:
+             return model(img)
+
+         device = devices.get_device_for('scunet')
+         assert tile % 8 == 0, "tile size should be a multiple of window_size"
+         sf = 1
+
+         stride = tile - tile_overlap
+         h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
+         w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
+         E = torch.zeros(1, 3, h * sf, w * sf, dtype=img.dtype, device=device)
+         W = torch.zeros_like(E, dtype=devices.dtype, device=device)
+
+         with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="ScuNET tiles") as pbar:
+             for h_idx in h_idx_list:
+
+                 for w_idx in w_idx_list:
+
+                     in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
+
+                     out_patch = model(in_patch)
+                     out_patch_mask = torch.ones_like(out_patch)
+
+                     E[
+                         ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
+                     ].add_(out_patch)
+                     W[
+                         ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
+                     ].add_(out_patch_mask)
+                     pbar.update(1)
+         output = E.div_(W)
+
+         return output
+
+     def do_upscale(self, img: PIL.Image.Image, selected_file):
+
+         devices.torch_gc()
+
+         try:
+             model = self.load_model(selected_file)
+         except Exception as e:
+             print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr)
+             return img
+
+         device = devices.get_device_for('scunet')
+         tile = opts.SCUNET_tile
+         h, w = img.height, img.width
+         np_img = np.array(img)
+         np_img = np_img[:, :, ::-1]  # RGB to BGR
+         np_img = np_img.transpose((2, 0, 1)) / 255  # HWC to CHW
+         torch_img = torch.from_numpy(np_img).float().unsqueeze(0).to(device)  # type: ignore
+
+         if tile > h or tile > w:
+             _img = torch.zeros(1, 3, max(h, tile), max(w, tile), dtype=torch_img.dtype, device=torch_img.device)
+             _img[:, :, :h, :w] = torch_img  # pad image
+             torch_img = _img
+
+         torch_output = self.tiled_inference(torch_img, model).squeeze(0)
+         torch_output = torch_output[:, :h * 1, :w * 1]  # remove padding, if any
+         np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
+         del torch_img, torch_output
+         devices.torch_gc()
+
+         output = np_output.transpose((1, 2, 0))  # CHW to HWC
+         output = output[:, :, ::-1]  # BGR to RGB
+         return PIL.Image.fromarray((output * 255).astype(np.uint8))
+
+     def load_model(self, path: str):
+         device = devices.get_device_for('scunet')
+         if path.startswith("http"):
+             # TODO: this doesn't use `path` at all?
+             filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
+         else:
+             filename = path
+         model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
+         model.load_state_dict(torch.load(filename), strict=True)
+         model.eval()
+         for _, v in model.named_parameters():
+             v.requires_grad = False
+         model = model.to(device)
+
+         return model
+
+
+ def on_ui_settings():
+     import gradio as gr
+     from modules import shared
+
+     shared.opts.add_option("SCUNET_tile", shared.OptionInfo(256, "Tile size for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")).info("0 = no tiling"))
+     shared.opts.add_option("SCUNET_tile_overlap", shared.OptionInfo(8, "Tile overlap for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=('upscaling', "Upscaling")).info("Low values = visible seam"))
+
+
+ script_callbacks.on_ui_settings(on_ui_settings)
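Note: the tiled_inference method above blends overlapping tiles by accumulating model outputs in E and a per-pixel hit count in W, then dividing. A minimal sketch of that overlap-averaging scheme with an identity stand-in for the model, assuming only torch:

import torch

h = w = 20
tile, tile_overlap = 8, 2
stride = tile - tile_overlap
img = torch.rand(1, 3, h, w)

h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]

E = torch.zeros(1, 3, h, w)   # accumulated outputs
W = torch.zeros_like(E)       # how many tiles covered each pixel

for h_idx in h_idx_list:
    for w_idx in w_idx_list:
        patch = img[..., h_idx:h_idx + tile, w_idx:w_idx + tile]
        out = patch  # identity "model" keeps the example checkable
        E[..., h_idx:h_idx + tile, w_idx:w_idx + tile] += out
        W[..., h_idx:h_idx + tile, w_idx:w_idx + tile] += 1

blended = E / W
assert torch.allclose(blended, img)  # overlap-averaging reproduces the identity output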
extensions-builtin/ScuNET/scunet_model_arch.py ADDED
@@ -0,0 +1,268 @@
+ # -*- coding: utf-8 -*-
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from einops import rearrange
+ from einops.layers.torch import Rearrange
+ from timm.models.layers import trunc_normal_, DropPath
+
+
+ class WMSA(nn.Module):
+     """ Self-attention module in Swin Transformer
+     """
+
+     def __init__(self, input_dim, output_dim, head_dim, window_size, type):
+         super(WMSA, self).__init__()
+         self.input_dim = input_dim
+         self.output_dim = output_dim
+         self.head_dim = head_dim
+         self.scale = self.head_dim ** -0.5
+         self.n_heads = input_dim // head_dim
+         self.window_size = window_size
+         self.type = type
+         self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True)
+
+         self.relative_position_params = nn.Parameter(
+             torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads))
+
+         self.linear = nn.Linear(self.input_dim, self.output_dim)
+
+         trunc_normal_(self.relative_position_params, std=.02)
+         self.relative_position_params = torch.nn.Parameter(
+             self.relative_position_params.view(2 * window_size - 1, 2 * window_size - 1, self.n_heads).transpose(1,
+                                                                                                                   2).transpose(
+                 0, 1))
+
+     def generate_mask(self, h, w, p, shift):
+         """ generating the mask of SW-MSA
+         Args:
+             shift: shift parameters in CyclicShift.
+         Returns:
+             attn_mask: should be (1 1 w p p),
+         """
+         # supporting square.
+         attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device)
+         if self.type == 'W':
+             return attn_mask
+
+         s = p - shift
+         attn_mask[-1, :, :s, :, s:, :] = True
+         attn_mask[-1, :, s:, :, :s, :] = True
+         attn_mask[:, -1, :, :s, :, s:] = True
+         attn_mask[:, -1, :, s:, :, :s] = True
+         attn_mask = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)')
+         return attn_mask
+
+     def forward(self, x):
+         """ Forward pass of Window Multi-head Self-attention module.
+         Args:
+             x: input tensor with shape of [b h w c];
+             attn_mask: attention mask, fill -inf where the value is True;
+         Returns:
+             output: tensor shape [b h w c]
+         """
+         if self.type != 'W':
+             x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
+
+         x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
+         h_windows = x.size(1)
+         w_windows = x.size(2)
+         # square validation
+         # assert h_windows == w_windows
+
+         x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size)
+         qkv = self.embedding_layer(x)
+         q, k, v = rearrange(qkv, 'b nw np (threeh c) -> threeh b nw np c', c=self.head_dim).chunk(3, dim=0)
+         sim = torch.einsum('hbwpc,hbwqc->hbwpq', q, k) * self.scale
+         # Adding learnable relative embedding
+         sim = sim + rearrange(self.relative_embedding(), 'h p q -> h 1 1 p q')
+         # Using Attn Mask to distinguish different subwindows.
+         if self.type != 'W':
+             attn_mask = self.generate_mask(h_windows, w_windows, self.window_size, shift=self.window_size // 2)
+             sim = sim.masked_fill_(attn_mask, float("-inf"))
+
+         probs = nn.functional.softmax(sim, dim=-1)
+         output = torch.einsum('hbwij,hbwjc->hbwic', probs, v)
+         output = rearrange(output, 'h b w p c -> b w p (h c)')
+         output = self.linear(output)
+         output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
+
+         if self.type != 'W':
+             output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), dims=(1, 2))
+
+         return output
+
+     def relative_embedding(self):
+         cord = torch.tensor(np.array([[i, j] for i in range(self.window_size) for j in range(self.window_size)]))
+         relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1
+         # negative is allowed
+         return self.relative_position_params[:, relation[:, :, 0].long(), relation[:, :, 1].long()]
+
+
+ class Block(nn.Module):
+     def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
+         """ SwinTransformer Block
+         """
+         super(Block, self).__init__()
+         self.input_dim = input_dim
+         self.output_dim = output_dim
+         assert type in ['W', 'SW']
+         self.type = type
+         if input_resolution <= window_size:
+             self.type = 'W'
+
+         self.ln1 = nn.LayerNorm(input_dim)
+         self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type)
+         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+         self.ln2 = nn.LayerNorm(input_dim)
+         self.mlp = nn.Sequential(
+             nn.Linear(input_dim, 4 * input_dim),
+             nn.GELU(),
+             nn.Linear(4 * input_dim, output_dim),
+         )
+
+     def forward(self, x):
+         x = x + self.drop_path(self.msa(self.ln1(x)))
+         x = x + self.drop_path(self.mlp(self.ln2(x)))
+         return x
+
+
+ class ConvTransBlock(nn.Module):
+     def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
+         """ SwinTransformer and Conv Block
+         """
+         super(ConvTransBlock, self).__init__()
+         self.conv_dim = conv_dim
+         self.trans_dim = trans_dim
+         self.head_dim = head_dim
+         self.window_size = window_size
+         self.drop_path = drop_path
+         self.type = type
+         self.input_resolution = input_resolution
+
+         assert self.type in ['W', 'SW']
+         if self.input_resolution <= self.window_size:
+             self.type = 'W'
+
+         self.trans_block = Block(self.trans_dim, self.trans_dim, self.head_dim, self.window_size, self.drop_path,
+                                  self.type, self.input_resolution)
+         self.conv1_1 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
+         self.conv1_2 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
+
+         self.conv_block = nn.Sequential(
+             nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
+             nn.ReLU(True),
+             nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False)
+         )
+
+     def forward(self, x):
+         conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1)
+         conv_x = self.conv_block(conv_x) + conv_x
+         trans_x = Rearrange('b c h w -> b h w c')(trans_x)
+         trans_x = self.trans_block(trans_x)
+         trans_x = Rearrange('b h w c -> b c h w')(trans_x)
+         res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
+         x = x + res
+
+         return x
+
+
+ class SCUNet(nn.Module):
+     # def __init__(self, in_nc=3, config=[2, 2, 2, 2, 2, 2, 2], dim=64, drop_path_rate=0.0, input_resolution=256):
+     def __init__(self, in_nc=3, config=None, dim=64, drop_path_rate=0.0, input_resolution=256):
+         super(SCUNet, self).__init__()
+         if config is None:
+             config = [2, 2, 2, 2, 2, 2, 2]
+         self.config = config
+         self.dim = dim
+         self.head_dim = 32
+         self.window_size = 8
+
+         # drop path rate for each layer
+         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))]
+
+         self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)]
+
+         begin = 0
+         self.m_down1 = [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
+                                        'W' if not i % 2 else 'SW', input_resolution)
+                         for i in range(config[0])] + \
+                        [nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)]
+
+         begin += config[0]
+         self.m_down2 = [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
+                                        'W' if not i % 2 else 'SW', input_resolution // 2)
+                         for i in range(config[1])] + \
+                        [nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)]
+
+         begin += config[1]
+         self.m_down3 = [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
+                                        'W' if not i % 2 else 'SW', input_resolution // 4)
+                         for i in range(config[2])] + \
+                        [nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)]
+
+         begin += config[2]
+         self.m_body = [ConvTransBlock(4 * dim, 4 * dim, self.head_dim, self.window_size, dpr[i + begin],
+                                       'W' if not i % 2 else 'SW', input_resolution // 8)
+                        for i in range(config[3])]
+
+         begin += config[3]
+         self.m_up3 = [nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False), ] + \
+                      [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
+                                      'W' if not i % 2 else 'SW', input_resolution // 4)
+                       for i in range(config[4])]
+
+         begin += config[4]
+         self.m_up2 = [nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False), ] + \
+                      [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
+                                      'W' if not i % 2 else 'SW', input_resolution // 2)
+                       for i in range(config[5])]
+
+         begin += config[5]
+         self.m_up1 = [nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False), ] + \
+                      [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
+                                      'W' if not i % 2 else 'SW', input_resolution)
+                       for i in range(config[6])]
+
+         self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)]
+
+         self.m_head = nn.Sequential(*self.m_head)
+         self.m_down1 = nn.Sequential(*self.m_down1)
+         self.m_down2 = nn.Sequential(*self.m_down2)
+         self.m_down3 = nn.Sequential(*self.m_down3)
+         self.m_body = nn.Sequential(*self.m_body)
+         self.m_up3 = nn.Sequential(*self.m_up3)
+         self.m_up2 = nn.Sequential(*self.m_up2)
+         self.m_up1 = nn.Sequential(*self.m_up1)
+         self.m_tail = nn.Sequential(*self.m_tail)
+         # self.apply(self._init_weights)
+
+     def forward(self, x0):
+
+         h, w = x0.size()[-2:]
+         paddingBottom = int(np.ceil(h / 64) * 64 - h)
+         paddingRight = int(np.ceil(w / 64) * 64 - w)
+         x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0)
+
+         x1 = self.m_head(x0)
+         x2 = self.m_down1(x1)
+         x3 = self.m_down2(x2)
+         x4 = self.m_down3(x3)
+         x = self.m_body(x4)
+         x = self.m_up3(x + x4)
+         x = self.m_up2(x + x3)
+         x = self.m_up1(x + x2)
+         x = self.m_tail(x + x1)
+
+         x = x[..., :h, :w]
+
+         return x
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=.02)
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
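Note: a small usage sketch of the network defined above. forward() replication-pads the input to a multiple of 64, runs the U-shaped conv/Swin stack, and crops back, so the output keeps the input's spatial size. This assumes torch, einops, and timm are installed and that this file is importable as scunet_model_arch:

import torch
from scunet_model_arch import SCUNet

model = SCUNet(in_nc=3, config=[2, 2, 2, 2, 2, 2, 2], dim=64).eval()

# 70x90 is not a multiple of 64; forward() replication-pads to 128x128,
# runs the encoder/body/decoder, then crops back to 70x90
x = torch.rand(1, 3, 70, 90)
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 3, 70, 90])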
extensions-builtin/SwinIR/__pycache__/preload.cpython-310.pyc ADDED
Binary file (491 Bytes).
 
extensions-builtin/SwinIR/__pycache__/swinir_model_arch.cpython-310.pyc ADDED
Binary file (27.8 kB).
 
extensions-builtin/SwinIR/__pycache__/swinir_model_arch_v2.cpython-310.pyc ADDED
Binary file (31.3 kB).
 
extensions-builtin/SwinIR/preload.py ADDED
@@ -0,0 +1,6 @@
+ import os
+ from modules import paths
+
+
+ def preload(parser):
+     parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(paths.models_path, 'SwinIR'))
extensions-builtin/SwinIR/scripts/__pycache__/swinir_model.cpython-310.pyc ADDED
Binary file (6.07 kB).
 
extensions-builtin/SwinIR/scripts/swinir_model.py ADDED
@@ -0,0 +1,192 @@
+ import sys
+ import platform
+
+ import numpy as np
+ import torch
+ from PIL import Image
+ from tqdm import tqdm
+
+ from modules import modelloader, devices, script_callbacks, shared
+ from modules.shared import opts, state
+ from swinir_model_arch import SwinIR
+ from swinir_model_arch_v2 import Swin2SR
+ from modules.upscaler import Upscaler, UpscalerData
+
+ SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
+
+ device_swinir = devices.get_device_for('swinir')
+
+
+ class UpscalerSwinIR(Upscaler):
+     def __init__(self, dirname):
+         self._cached_model = None  # keep the model when SWIN_torch_compile is on to prevent re-compiling on every run
+         self._cached_model_config = None  # to clear '_cached_model' when changing model (v1/v2) or settings
+         self.name = "SwinIR"
+         self.model_url = SWINIR_MODEL_URL
+         self.model_name = "SwinIR 4x"
+         self.user_path = dirname
+         super().__init__()
+         scalers = []
+         model_files = self.find_models(ext_filter=[".pt", ".pth"])
+         for model in model_files:
+             if model.startswith("http"):
+                 name = self.model_name
+             else:
+                 name = modelloader.friendly_name(model)
+             model_data = UpscalerData(name, model, self)
+             scalers.append(model_data)
+         self.scalers = scalers
+
+     def do_upscale(self, img, model_file):
+         use_compile = hasattr(opts, 'SWIN_torch_compile') and opts.SWIN_torch_compile \
+             and int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows"
+         current_config = (model_file, opts.SWIN_tile)
+
+         if use_compile and self._cached_model_config == current_config:
+             model = self._cached_model
+         else:
+             self._cached_model = None
+             try:
+                 model = self.load_model(model_file)
+             except Exception as e:
+                 print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
+                 return img
+             model = model.to(device_swinir, dtype=devices.dtype)
+             if use_compile:
+                 model = torch.compile(model)
+                 self._cached_model = model
+                 self._cached_model_config = current_config
+         img = upscale(img, model)
+         devices.torch_gc()
+         return img
+
+     def load_model(self, path, scale=4):
+         if path.startswith("http"):
+             filename = modelloader.load_file_from_url(
+                 url=path,
+                 model_dir=self.model_download_path,
+                 file_name=f"{self.model_name.replace(' ', '_')}.pth",
+             )
+         else:
+             filename = path
+         if filename.endswith(".v2.pth"):
+             model = Swin2SR(
+                 upscale=scale,
+                 in_chans=3,
+                 img_size=64,
+                 window_size=8,
+                 img_range=1.0,
+                 depths=[6, 6, 6, 6, 6, 6],
+                 embed_dim=180,
+                 num_heads=[6, 6, 6, 6, 6, 6],
+                 mlp_ratio=2,
+                 upsampler="nearest+conv",
+                 resi_connection="1conv",
+             )
+             params = None
+         else:
+             model = SwinIR(
+                 upscale=scale,
+                 in_chans=3,
+                 img_size=64,
+                 window_size=8,
+                 img_range=1.0,
+                 depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
+                 embed_dim=240,
+                 num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
+                 mlp_ratio=2,
+                 upsampler="nearest+conv",
+                 resi_connection="3conv",
+             )
+             params = "params_ema"
+
+         pretrained_model = torch.load(filename)
+         if params is not None:
+             model.load_state_dict(pretrained_model[params], strict=True)
+         else:
+             model.load_state_dict(pretrained_model, strict=True)
+         return model
+
+
+ def upscale(
+         img,
+         model,
+         tile=None,
+         tile_overlap=None,
+         window_size=8,
+         scale=4,
+ ):
+     tile = tile or opts.SWIN_tile
+     tile_overlap = tile_overlap or opts.SWIN_tile_overlap
+
+     img = np.array(img)
+     img = img[:, :, ::-1]
+     img = np.moveaxis(img, 2, 0) / 255
+     img = torch.from_numpy(img).float()
+     img = img.unsqueeze(0).to(device_swinir, dtype=devices.dtype)
+     with torch.no_grad(), devices.autocast():
+         _, _, h_old, w_old = img.size()
+         h_pad = (h_old // window_size + 1) * window_size - h_old
+         w_pad = (w_old // window_size + 1) * window_size - w_old
+         img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
+         img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
+         output = inference(img, model, tile, tile_overlap, window_size, scale)
+         output = output[..., : h_old * scale, : w_old * scale]
+         output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
+         if output.ndim == 3:
+             output = np.transpose(
+                 output[[2, 1, 0], :, :], (1, 2, 0)
+             )  # CHW-BGR to HWC-RGB
+         output = (output * 255.0).round().astype(np.uint8)  # float32 to uint8
+         return Image.fromarray(output, "RGB")
+
+
+ def inference(img, model, tile, tile_overlap, window_size, scale):
+     # test the image tile by tile
+     b, c, h, w = img.size()
+     tile = min(tile, h, w)
+     assert tile % window_size == 0, "tile size should be a multiple of window_size"
+     sf = scale
+
+     stride = tile - tile_overlap
+     h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
+     w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
+     E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device_swinir).type_as(img)
+     W = torch.zeros_like(E, dtype=devices.dtype, device=device_swinir)
+
+     with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
+         for h_idx in h_idx_list:
+             if state.interrupted or state.skipped:
+                 break
+
+             for w_idx in w_idx_list:
+                 if state.interrupted or state.skipped:
+                     break
+
+                 in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
+                 out_patch = model(in_patch)
+                 out_patch_mask = torch.ones_like(out_patch)
+
+                 E[
+                     ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
+                 ].add_(out_patch)
+                 W[
+                     ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
+                 ].add_(out_patch_mask)
+                 pbar.update(1)
+     output = E.div_(W)
+
+     return output
+
+
+ def on_ui_settings():
+     import gradio as gr
+
+     shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")))
+     shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling")))
+     if int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows":  # torch.compile() requires pytorch 2.0 or above, and does not work on Windows
+         shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run"))
+
+
+ script_callbacks.on_ui_settings(on_ui_settings)
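Note: the upscale() helper above mirror-pads the image so that height and width become multiples of window_size before inference, then crops the result back to h_old*scale by w_old*scale. A minimal sketch of just that padding arithmetic, assuming only torch:

import torch

window_size, scale = 8, 4
img = torch.rand(1, 3, 30, 45)

_, _, h_old, w_old = img.size()
h_pad = (h_old // window_size + 1) * window_size - h_old  # 30 -> pad by 2 -> 32
w_pad = (w_old // window_size + 1) * window_size - w_old  # 45 -> pad by 3 -> 48

# mirror-pad by concatenating a flipped copy, then slicing to the padded size
img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
print(img.shape)  # torch.Size([1, 3, 32, 48]) -- both divisible by window_size

# after the model runs, the output is cropped back to the unpadded size times the scale:
# output = output[..., : h_old * scale, : w_old * scale]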
extensions-builtin/SwinIR/swinir_model_arch.py ADDED
@@ -0,0 +1,867 @@
1
+ # -----------------------------------------------------------------------------------
2
+ # SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
3
+ # Originally Written by Ze Liu, Modified by Jingyun Liang.
4
+ # -----------------------------------------------------------------------------------
5
+
6
+ import math
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torch.utils.checkpoint as checkpoint
11
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
12
+
13
+
14
+ class Mlp(nn.Module):
15
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
16
+ super().__init__()
17
+ out_features = out_features or in_features
18
+ hidden_features = hidden_features or in_features
19
+ self.fc1 = nn.Linear(in_features, hidden_features)
20
+ self.act = act_layer()
21
+ self.fc2 = nn.Linear(hidden_features, out_features)
22
+ self.drop = nn.Dropout(drop)
23
+
24
+ def forward(self, x):
25
+ x = self.fc1(x)
26
+ x = self.act(x)
27
+ x = self.drop(x)
28
+ x = self.fc2(x)
29
+ x = self.drop(x)
30
+ return x
31
+
32
+
33
+ def window_partition(x, window_size):
34
+ """
35
+ Args:
36
+ x: (B, H, W, C)
37
+ window_size (int): window size
38
+
39
+ Returns:
40
+ windows: (num_windows*B, window_size, window_size, C)
41
+ """
42
+ B, H, W, C = x.shape
43
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
44
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
45
+ return windows
46
+
47
+
48
+ def window_reverse(windows, window_size, H, W):
49
+ """
50
+ Args:
51
+ windows: (num_windows*B, window_size, window_size, C)
52
+ window_size (int): Window size
53
+ H (int): Height of image
54
+ W (int): Width of image
55
+
56
+ Returns:
57
+ x: (B, H, W, C)
58
+ """
59
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
60
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
61
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
62
+ return x
63
+
64
+
65
+ class WindowAttention(nn.Module):
66
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
67
+ It supports both of shifted and non-shifted window.
68
+
69
+ Args:
70
+ dim (int): Number of input channels.
71
+ window_size (tuple[int]): The height and width of the window.
72
+ num_heads (int): Number of attention heads.
73
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
74
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
75
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
76
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
77
+ """
78
+
79
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
80
+
81
+ super().__init__()
82
+ self.dim = dim
83
+ self.window_size = window_size # Wh, Ww
84
+ self.num_heads = num_heads
85
+ head_dim = dim // num_heads
86
+ self.scale = qk_scale or head_dim ** -0.5
87
+
88
+ # define a parameter table of relative position bias
89
+ self.relative_position_bias_table = nn.Parameter(
90
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
91
+
92
+ # get pair-wise relative position index for each token inside the window
93
+ coords_h = torch.arange(self.window_size[0])
94
+ coords_w = torch.arange(self.window_size[1])
95
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
96
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
97
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
98
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
99
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
100
+ relative_coords[:, :, 1] += self.window_size[1] - 1
101
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
102
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
103
+ self.register_buffer("relative_position_index", relative_position_index)
104
+
105
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
106
+ self.attn_drop = nn.Dropout(attn_drop)
107
+ self.proj = nn.Linear(dim, dim)
108
+
109
+ self.proj_drop = nn.Dropout(proj_drop)
110
+
111
+ trunc_normal_(self.relative_position_bias_table, std=.02)
112
+ self.softmax = nn.Softmax(dim=-1)
113
+
114
+ def forward(self, x, mask=None):
115
+ """
116
+ Args:
117
+ x: input features with shape of (num_windows*B, N, C)
118
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
119
+ """
120
+ B_, N, C = x.shape
121
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
122
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
123
+
124
+ q = q * self.scale
125
+ attn = (q @ k.transpose(-2, -1))
126
+
127
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
128
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
129
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
130
+ attn = attn + relative_position_bias.unsqueeze(0)
131
+
132
+ if mask is not None:
133
+ nW = mask.shape[0]
134
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
135
+ attn = attn.view(-1, self.num_heads, N, N)
136
+ attn = self.softmax(attn)
137
+ else:
138
+ attn = self.softmax(attn)
139
+
140
+ attn = self.attn_drop(attn)
141
+
142
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
143
+ x = self.proj(x)
144
+ x = self.proj_drop(x)
145
+ return x
146
+
147
+ def extra_repr(self) -> str:
148
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
149
+
150
+ def flops(self, N):
151
+ # calculate flops for 1 window with token length of N
152
+ flops = 0
153
+ # qkv = self.qkv(x)
154
+ flops += N * self.dim * 3 * self.dim
155
+ # attn = (q @ k.transpose(-2, -1))
156
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
157
+ # x = (attn @ v)
158
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
159
+ # x = self.proj(x)
160
+ flops += N * self.dim * self.dim
161
+ return flops
162
+
163
+
164
+ class SwinTransformerBlock(nn.Module):
165
+ r""" Swin Transformer Block.
166
+
167
+ Args:
168
+ dim (int): Number of input channels.
169
+ input_resolution (tuple[int]): Input resolution.
170
+ num_heads (int): Number of attention heads.
171
+ window_size (int): Window size.
172
+ shift_size (int): Shift size for SW-MSA.
173
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
174
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
175
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
176
+ drop (float, optional): Dropout rate. Default: 0.0
177
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
178
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
179
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
180
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
181
+ """
182
+
183
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
184
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
185
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
186
+ super().__init__()
187
+ self.dim = dim
188
+ self.input_resolution = input_resolution
189
+ self.num_heads = num_heads
190
+ self.window_size = window_size
191
+ self.shift_size = shift_size
192
+ self.mlp_ratio = mlp_ratio
193
+ if min(self.input_resolution) <= self.window_size:
194
+ # if window size is larger than input resolution, we don't partition windows
195
+ self.shift_size = 0
196
+ self.window_size = min(self.input_resolution)
197
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
198
+
199
+ self.norm1 = norm_layer(dim)
200
+ self.attn = WindowAttention(
201
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
202
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
203
+
204
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
205
+ self.norm2 = norm_layer(dim)
206
+ mlp_hidden_dim = int(dim * mlp_ratio)
207
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
208
+
209
+ if self.shift_size > 0:
210
+ attn_mask = self.calculate_mask(self.input_resolution)
211
+ else:
212
+ attn_mask = None
213
+
214
+ self.register_buffer("attn_mask", attn_mask)
215
+
216
+ def calculate_mask(self, x_size):
217
+ # calculate attention mask for SW-MSA
218
+ H, W = x_size
219
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
220
+ h_slices = (slice(0, -self.window_size),
221
+ slice(-self.window_size, -self.shift_size),
222
+ slice(-self.shift_size, None))
223
+ w_slices = (slice(0, -self.window_size),
224
+ slice(-self.window_size, -self.shift_size),
225
+ slice(-self.shift_size, None))
226
+ cnt = 0
227
+ for h in h_slices:
228
+ for w in w_slices:
229
+ img_mask[:, h, w, :] = cnt
230
+ cnt += 1
231
+
232
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
233
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
234
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
235
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
236
+
237
+ return attn_mask
238
+
239
+ def forward(self, x, x_size):
240
+ H, W = x_size
241
+ B, L, C = x.shape
242
+ # assert L == H * W, "input feature has wrong size"
243
+
244
+ shortcut = x
245
+ x = self.norm1(x)
246
+ x = x.view(B, H, W, C)
247
+
248
+ # cyclic shift
249
+ if self.shift_size > 0:
250
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
251
+ else:
252
+ shifted_x = x
253
+
254
+ # partition windows
255
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
256
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
257
+
258
+ # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
259
+ if self.input_resolution == x_size:
260
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
261
+ else:
262
+ attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
263
+
264
+ # merge windows
265
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
266
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
267
+
268
+ # reverse cyclic shift
269
+ if self.shift_size > 0:
270
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
271
+ else:
272
+ x = shifted_x
273
+ x = x.view(B, H * W, C)
274
+
275
+ # FFN
276
+ x = shortcut + self.drop_path(x)
277
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
278
+
279
+ return x
280
+
281
+ def extra_repr(self) -> str:
282
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
283
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
284
+
285
+ def flops(self):
286
+ flops = 0
287
+ H, W = self.input_resolution
288
+ # norm1
289
+ flops += self.dim * H * W
290
+ # W-MSA/SW-MSA
291
+ nW = H * W / self.window_size / self.window_size
292
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
293
+ # mlp
294
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
295
+ # norm2
296
+ flops += self.dim * H * W
297
+ return flops
298
+
299
+
300
+ class PatchMerging(nn.Module):
301
+ r""" Patch Merging Layer.
302
+
303
+ Args:
304
+ input_resolution (tuple[int]): Resolution of input feature.
305
+ dim (int): Number of input channels.
306
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
307
+ """
308
+
309
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
310
+ super().__init__()
311
+ self.input_resolution = input_resolution
312
+ self.dim = dim
313
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
314
+ self.norm = norm_layer(4 * dim)
315
+
316
+ def forward(self, x):
317
+ """
318
+ x: B, H*W, C
319
+ """
320
+ H, W = self.input_resolution
321
+ B, L, C = x.shape
322
+ assert L == H * W, "input feature has wrong size"
323
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
324
+
325
+ x = x.view(B, H, W, C)
326
+
327
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
328
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
329
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
330
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
331
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
332
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
333
+
334
+ x = self.norm(x)
335
+ x = self.reduction(x)
336
+
337
+ return x
338
+
339
+ def extra_repr(self) -> str:
340
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
341
+
342
+ def flops(self):
343
+ H, W = self.input_resolution
344
+ flops = H * W * self.dim
345
+ flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
346
+ return flops
347
+
348
+
349
+ class BasicLayer(nn.Module):
350
+ """ A basic Swin Transformer layer for one stage.
351
+
352
+ Args:
353
+ dim (int): Number of input channels.
354
+ input_resolution (tuple[int]): Input resolution.
355
+ depth (int): Number of blocks.
356
+ num_heads (int): Number of attention heads.
357
+ window_size (int): Local window size.
358
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
359
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
360
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
361
+ drop (float, optional): Dropout rate. Default: 0.0
362
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
363
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
364
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
365
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
366
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
367
+ """
368
+
369
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
370
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
371
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
372
+
373
+ super().__init__()
374
+ self.dim = dim
375
+ self.input_resolution = input_resolution
376
+ self.depth = depth
377
+ self.use_checkpoint = use_checkpoint
378
+
379
+ # build blocks
380
+ self.blocks = nn.ModuleList([
381
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
382
+ num_heads=num_heads, window_size=window_size,
383
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
384
+ mlp_ratio=mlp_ratio,
385
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
386
+ drop=drop, attn_drop=attn_drop,
387
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
388
+ norm_layer=norm_layer)
389
+ for i in range(depth)])
390
+
391
+ # patch merging layer
392
+ if downsample is not None:
393
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
394
+ else:
395
+ self.downsample = None
396
+
397
+ def forward(self, x, x_size):
398
+ for blk in self.blocks:
399
+ if self.use_checkpoint:
400
+ x = checkpoint.checkpoint(blk, x, x_size)
401
+ else:
402
+ x = blk(x, x_size)
403
+ if self.downsample is not None:
404
+ x = self.downsample(x)
405
+ return x
406
+
407
+ def extra_repr(self) -> str:
408
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
409
+
410
+ def flops(self):
411
+ flops = 0
412
+ for blk in self.blocks:
413
+ flops += blk.flops()
414
+ if self.downsample is not None:
415
+ flops += self.downsample.flops()
416
+ return flops
417
+
418
+
419
+ class RSTB(nn.Module):
420
+ """Residual Swin Transformer Block (RSTB).
421
+
422
+ Args:
423
+ dim (int): Number of input channels.
424
+ input_resolution (tuple[int]): Input resolution.
425
+ depth (int): Number of blocks.
426
+ num_heads (int): Number of attention heads.
427
+ window_size (int): Local window size.
428
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
429
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
430
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
431
+ drop (float, optional): Dropout rate. Default: 0.0
432
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
433
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
434
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
435
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
436
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
437
+ img_size: Input image size.
438
+ patch_size: Patch size.
439
+ resi_connection: The convolutional block before residual connection.
440
+ """
441
+
442
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
443
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
444
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
445
+ img_size=224, patch_size=4, resi_connection='1conv'):
446
+ super(RSTB, self).__init__()
447
+
448
+ self.dim = dim
449
+ self.input_resolution = input_resolution
450
+
451
+ self.residual_group = BasicLayer(dim=dim,
452
+ input_resolution=input_resolution,
453
+ depth=depth,
454
+ num_heads=num_heads,
455
+ window_size=window_size,
456
+ mlp_ratio=mlp_ratio,
457
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
458
+ drop=drop, attn_drop=attn_drop,
459
+ drop_path=drop_path,
460
+ norm_layer=norm_layer,
461
+ downsample=downsample,
462
+ use_checkpoint=use_checkpoint)
463
+
464
+ if resi_connection == '1conv':
465
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
466
+ elif resi_connection == '3conv':
467
+ # to save parameters and memory
468
+ self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
469
+ nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
470
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
471
+ nn.Conv2d(dim // 4, dim, 3, 1, 1))
472
+
473
+ self.patch_embed = PatchEmbed(
474
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
475
+ norm_layer=None)
476
+
477
+ self.patch_unembed = PatchUnEmbed(
478
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
479
+ norm_layer=None)
480
+
481
+ def forward(self, x, x_size):
482
+ return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
483
+
484
+ def flops(self):
485
+ flops = 0
486
+ flops += self.residual_group.flops()
487
+ H, W = self.input_resolution
488
+ flops += H * W * self.dim * self.dim * 9
489
+ flops += self.patch_embed.flops()
490
+ flops += self.patch_unembed.flops()
491
+
492
+ return flops
493
+
494
+
495
+ class PatchEmbed(nn.Module):
496
+ r""" Image to Patch Embedding
497
+
498
+ Args:
499
+ img_size (int): Image size. Default: 224.
500
+ patch_size (int): Patch token size. Default: 4.
501
+ in_chans (int): Number of input image channels. Default: 3.
502
+ embed_dim (int): Number of linear projection output channels. Default: 96.
503
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
504
+ """
505
+
506
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
507
+ super().__init__()
508
+ img_size = to_2tuple(img_size)
509
+ patch_size = to_2tuple(patch_size)
510
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
511
+ self.img_size = img_size
512
+ self.patch_size = patch_size
513
+ self.patches_resolution = patches_resolution
514
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
515
+
516
+ self.in_chans = in_chans
517
+ self.embed_dim = embed_dim
518
+
519
+ if norm_layer is not None:
520
+ self.norm = norm_layer(embed_dim)
521
+ else:
522
+ self.norm = None
523
+
524
+ def forward(self, x):
525
+ x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
526
+ if self.norm is not None:
527
+ x = self.norm(x)
528
+ return x
529
+
530
+ def flops(self):
531
+ flops = 0
532
+ H, W = self.img_size
533
+ if self.norm is not None:
534
+ flops += H * W * self.embed_dim
535
+ return flops
536
+
537
+
538
+ class PatchUnEmbed(nn.Module):
539
+ r""" Image to Patch Unembedding
540
+
541
+ Args:
542
+ img_size (int): Image size. Default: 224.
543
+ patch_size (int): Patch token size. Default: 4.
544
+ in_chans (int): Number of input image channels. Default: 3.
545
+ embed_dim (int): Number of linear projection output channels. Default: 96.
546
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
547
+ """
548
+
549
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
550
+ super().__init__()
551
+ img_size = to_2tuple(img_size)
552
+ patch_size = to_2tuple(patch_size)
553
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
554
+ self.img_size = img_size
555
+ self.patch_size = patch_size
556
+ self.patches_resolution = patches_resolution
557
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
558
+
559
+ self.in_chans = in_chans
560
+ self.embed_dim = embed_dim
561
+
562
+ def forward(self, x, x_size):
563
+ B, HW, C = x.shape
564
+ x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
565
+ return x
566
+
567
+ def flops(self):
568
+ flops = 0
569
+ return flops
570
+
571
+
572
+ class Upsample(nn.Sequential):
573
+ """Upsample module.
574
+
575
+ Args:
576
+ scale (int): Scale factor. Supported scales: 2^n and 3.
577
+ num_feat (int): Channel number of intermediate features.
578
+ """
579
+
580
+ def __init__(self, scale, num_feat):
581
+ m = []
582
+ if (scale & (scale - 1)) == 0: # scale = 2^n
583
+ for _ in range(int(math.log(scale, 2))):
584
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
585
+ m.append(nn.PixelShuffle(2))
586
+ elif scale == 3:
587
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
588
+ m.append(nn.PixelShuffle(3))
589
+ else:
590
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
591
+ super(Upsample, self).__init__(*m)
592
+
593
+
594
+ class UpsampleOneStep(nn.Sequential):
595
+ """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
596
+ Used in lightweight SR to save parameters.
597
+
598
+ Args:
599
+ scale (int): Scale factor. Supported scales: 2^n and 3.
600
+ num_feat (int): Channel number of intermediate features.
601
+
602
+ """
603
+
604
+ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
605
+ self.num_feat = num_feat
606
+ self.input_resolution = input_resolution
607
+ m = []
608
+ m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
609
+ m.append(nn.PixelShuffle(scale))
610
+ super(UpsampleOneStep, self).__init__(*m)
611
+
612
+ def flops(self):
613
+ H, W = self.input_resolution
614
+ flops = H * W * self.num_feat * 3 * 9
615
+ return flops
616
+
617
+
618
+ class SwinIR(nn.Module):
619
+ r""" SwinIR
620
+ A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.
621
+
622
+ Args:
623
+ img_size (int | tuple(int)): Input image size. Default 64
624
+ patch_size (int | tuple(int)): Patch size. Default: 1
625
+ in_chans (int): Number of input image channels. Default: 3
626
+ embed_dim (int): Patch embedding dimension. Default: 96
627
+ depths (tuple(int)): Depth of each Swin Transformer layer.
628
+ num_heads (tuple(int)): Number of attention heads in different layers.
629
+ window_size (int): Window size. Default: 7
630
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
631
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
632
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
633
+ drop_rate (float): Dropout rate. Default: 0
634
+ attn_drop_rate (float): Attention dropout rate. Default: 0
635
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
636
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
637
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
638
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
639
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
640
+ upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
641
+ img_range: Image range. 1. or 255.
642
+ upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
643
+ resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
644
+ """
645
+
646
+ def __init__(self, img_size=64, patch_size=1, in_chans=3,
647
+ embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
648
+ window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
649
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
650
+ norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
651
+ use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
652
+ **kwargs):
653
+ super(SwinIR, self).__init__()
654
+ num_in_ch = in_chans
655
+ num_out_ch = in_chans
656
+ num_feat = 64
657
+ self.img_range = img_range
658
+ if in_chans == 3:
659
+ rgb_mean = (0.4488, 0.4371, 0.4040)
660
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
661
+ else:
662
+ self.mean = torch.zeros(1, 1, 1, 1)
663
+ self.upscale = upscale
664
+ self.upsampler = upsampler
665
+ self.window_size = window_size
666
+
667
+ #####################################################################################################
668
+ ################################### 1, shallow feature extraction ###################################
669
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
670
+
671
+ #####################################################################################################
672
+ ################################### 2, deep feature extraction ######################################
673
+ self.num_layers = len(depths)
674
+ self.embed_dim = embed_dim
675
+ self.ape = ape
676
+ self.patch_norm = patch_norm
677
+ self.num_features = embed_dim
678
+ self.mlp_ratio = mlp_ratio
679
+
680
+ # split image into non-overlapping patches
681
+ self.patch_embed = PatchEmbed(
682
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
683
+ norm_layer=norm_layer if self.patch_norm else None)
684
+ num_patches = self.patch_embed.num_patches
685
+ patches_resolution = self.patch_embed.patches_resolution
686
+ self.patches_resolution = patches_resolution
687
+
688
+ # merge non-overlapping patches into image
689
+ self.patch_unembed = PatchUnEmbed(
690
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
691
+ norm_layer=norm_layer if self.patch_norm else None)
692
+
693
+ # absolute position embedding
694
+ if self.ape:
695
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
696
+ trunc_normal_(self.absolute_pos_embed, std=.02)
697
+
698
+ self.pos_drop = nn.Dropout(p=drop_rate)
699
+
700
+ # stochastic depth
701
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
702
+
703
+ # build Residual Swin Transformer blocks (RSTB)
704
+ self.layers = nn.ModuleList()
705
+ for i_layer in range(self.num_layers):
706
+ layer = RSTB(dim=embed_dim,
707
+ input_resolution=(patches_resolution[0],
708
+ patches_resolution[1]),
709
+ depth=depths[i_layer],
710
+ num_heads=num_heads[i_layer],
711
+ window_size=window_size,
712
+ mlp_ratio=self.mlp_ratio,
713
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
714
+ drop=drop_rate, attn_drop=attn_drop_rate,
715
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
716
+ norm_layer=norm_layer,
717
+ downsample=None,
718
+ use_checkpoint=use_checkpoint,
719
+ img_size=img_size,
720
+ patch_size=patch_size,
721
+ resi_connection=resi_connection
722
+
723
+ )
724
+ self.layers.append(layer)
725
+ self.norm = norm_layer(self.num_features)
726
+
727
+ # build the last conv layer in deep feature extraction
728
+ if resi_connection == '1conv':
729
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
730
+ elif resi_connection == '3conv':
731
+ # to save parameters and memory
732
+ self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
733
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
734
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
735
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
736
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
737
+
738
+ #####################################################################################################
739
+ ################################ 3, high quality image reconstruction ################################
740
+ if self.upsampler == 'pixelshuffle':
741
+ # for classical SR
742
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
743
+ nn.LeakyReLU(inplace=True))
744
+ self.upsample = Upsample(upscale, num_feat)
745
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
746
+ elif self.upsampler == 'pixelshuffledirect':
747
+ # for lightweight SR (to save parameters)
748
+ self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
749
+ (patches_resolution[0], patches_resolution[1]))
750
+ elif self.upsampler == 'nearest+conv':
751
+ # for real-world SR (less artifacts)
752
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
753
+ nn.LeakyReLU(inplace=True))
754
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
755
+ if self.upscale == 4:
756
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
757
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
758
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
759
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
760
+ else:
761
+ # for image denoising and JPEG compression artifact reduction
762
+ self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
763
+
764
+ self.apply(self._init_weights)
765
+
766
+ def _init_weights(self, m):
767
+ if isinstance(m, nn.Linear):
768
+ trunc_normal_(m.weight, std=.02)
769
+ if isinstance(m, nn.Linear) and m.bias is not None:
770
+ nn.init.constant_(m.bias, 0)
771
+ elif isinstance(m, nn.LayerNorm):
772
+ nn.init.constant_(m.bias, 0)
773
+ nn.init.constant_(m.weight, 1.0)
774
+
775
+ @torch.jit.ignore
776
+ def no_weight_decay(self):
777
+ return {'absolute_pos_embed'}
778
+
779
+ @torch.jit.ignore
780
+ def no_weight_decay_keywords(self):
781
+ return {'relative_position_bias_table'}
782
+
783
+ def check_image_size(self, x):
784
+ _, _, h, w = x.size()
785
+ mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
786
+ mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
787
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
788
+ return x
789
+
790
+ def forward_features(self, x):
791
+ x_size = (x.shape[2], x.shape[3])
792
+ x = self.patch_embed(x)
793
+ if self.ape:
794
+ x = x + self.absolute_pos_embed
795
+ x = self.pos_drop(x)
796
+
797
+ for layer in self.layers:
798
+ x = layer(x, x_size)
799
+
800
+ x = self.norm(x) # B L C
801
+ x = self.patch_unembed(x, x_size)
802
+
803
+ return x
804
+
805
+ def forward(self, x):
806
+ H, W = x.shape[2:]
807
+ x = self.check_image_size(x)
808
+
809
+ self.mean = self.mean.type_as(x)
810
+ x = (x - self.mean) * self.img_range
811
+
812
+ if self.upsampler == 'pixelshuffle':
813
+ # for classical SR
814
+ x = self.conv_first(x)
815
+ x = self.conv_after_body(self.forward_features(x)) + x
816
+ x = self.conv_before_upsample(x)
817
+ x = self.conv_last(self.upsample(x))
818
+ elif self.upsampler == 'pixelshuffledirect':
819
+ # for lightweight SR
820
+ x = self.conv_first(x)
821
+ x = self.conv_after_body(self.forward_features(x)) + x
822
+ x = self.upsample(x)
823
+ elif self.upsampler == 'nearest+conv':
824
+ # for real-world SR
825
+ x = self.conv_first(x)
826
+ x = self.conv_after_body(self.forward_features(x)) + x
827
+ x = self.conv_before_upsample(x)
828
+ x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
829
+ if self.upscale == 4:
830
+ x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
831
+ x = self.conv_last(self.lrelu(self.conv_hr(x)))
832
+ else:
833
+ # for image denoising and JPEG compression artifact reduction
834
+ x_first = self.conv_first(x)
835
+ res = self.conv_after_body(self.forward_features(x_first)) + x_first
836
+ x = x + self.conv_last(res)
837
+
838
+ x = x / self.img_range + self.mean
839
+
840
+ return x[:, :, :H*self.upscale, :W*self.upscale]
841
+
842
+ def flops(self):
843
+ flops = 0
844
+ H, W = self.patches_resolution
845
+ flops += H * W * 3 * self.embed_dim * 9
846
+ flops += self.patch_embed.flops()
847
+ for layer in self.layers:
848
+ flops += layer.flops()
849
+ flops += H * W * 3 * self.embed_dim * self.embed_dim
850
+ flops += self.upsample.flops()
851
+ return flops
852
+
853
+
854
+ if __name__ == '__main__':
855
+ upscale = 4
856
+ window_size = 8
857
+ height = (1024 // upscale // window_size + 1) * window_size
858
+ width = (720 // upscale // window_size + 1) * window_size
859
+ model = SwinIR(upscale=2, img_size=(height, width),
860
+ window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
861
+ embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
862
+ print(model)
863
+ print(height, width, model.flops() / 1e9)
864
+
865
+ x = torch.randn((1, 3, height, width))
866
+ x = model(x)
867
+ print(x.shape)
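
A quick, hedged illustration (not part of the committed file): `check_image_size` reflect-pads the input up to a multiple of `window_size`, and `forward` crops the result back to H*upscale x W*upscale, so inputs whose sides are not multiples of the window size still work. A minimal sketch reusing the constructor arguments from the test block above; the 70x53 input size is chosen only for illustration:

import torch

# Minimal sketch: arbitrary input sizes are handled by reflect padding + cropping.
model = SwinIR(upscale=2, img_size=64, window_size=8, img_range=1.,
               depths=[6, 6, 6, 6], embed_dim=60, num_heads=[6, 6, 6, 6],
               mlp_ratio=2, upsampler='pixelshuffledirect')
with torch.no_grad():
    out = model(torch.randn(1, 3, 70, 53))  # 70 and 53 are not multiples of 8
print(out.shape)  # torch.Size([1, 3, 140, 106]) -- cropped back to H*upscale x W*upscale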
extensions-builtin/SwinIR/swinir_model_arch_v2.py ADDED
@@ -0,0 +1,1017 @@
1
+ # -----------------------------------------------------------------------------------
2
+ # Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/
3
+ # Written by Conde and Choi et al.
4
+ # -----------------------------------------------------------------------------------
5
+
6
+ import math
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import torch.utils.checkpoint as checkpoint
12
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
13
+
14
+
15
+ class Mlp(nn.Module):
16
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
17
+ super().__init__()
18
+ out_features = out_features or in_features
19
+ hidden_features = hidden_features or in_features
20
+ self.fc1 = nn.Linear(in_features, hidden_features)
21
+ self.act = act_layer()
22
+ self.fc2 = nn.Linear(hidden_features, out_features)
23
+ self.drop = nn.Dropout(drop)
24
+
25
+ def forward(self, x):
26
+ x = self.fc1(x)
27
+ x = self.act(x)
28
+ x = self.drop(x)
29
+ x = self.fc2(x)
30
+ x = self.drop(x)
31
+ return x
32
+
33
+
34
+ def window_partition(x, window_size):
35
+ """
36
+ Args:
37
+ x: (B, H, W, C)
38
+ window_size (int): window size
39
+ Returns:
40
+ windows: (num_windows*B, window_size, window_size, C)
41
+ """
42
+ B, H, W, C = x.shape
43
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
44
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
45
+ return windows
46
+
47
+
48
+ def window_reverse(windows, window_size, H, W):
49
+ """
50
+ Args:
51
+ windows: (num_windows*B, window_size, window_size, C)
52
+ window_size (int): Window size
53
+ H (int): Height of image
54
+ W (int): Width of image
55
+ Returns:
56
+ x: (B, H, W, C)
57
+ """
58
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
59
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
60
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
61
+ return x
62
+
63
+ class WindowAttention(nn.Module):
64
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
65
+ It supports both of shifted and non-shifted window.
66
+ Args:
67
+ dim (int): Number of input channels.
68
+ window_size (tuple[int]): The height and width of the window.
69
+ num_heads (int): Number of attention heads.
70
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
71
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
72
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
73
+ pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
74
+ """
75
+
76
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
77
+ pretrained_window_size=(0, 0)):
78
+
79
+ super().__init__()
80
+ self.dim = dim
81
+ self.window_size = window_size # Wh, Ww
82
+ self.pretrained_window_size = pretrained_window_size
83
+ self.num_heads = num_heads
84
+
85
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
86
+
87
+ # mlp to generate continuous relative position bias
88
+ self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
89
+ nn.ReLU(inplace=True),
90
+ nn.Linear(512, num_heads, bias=False))
91
+
92
+ # get relative_coords_table
93
+ relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
94
+ relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
95
+ relative_coords_table = torch.stack(
96
+ torch.meshgrid([relative_coords_h,
97
+ relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
98
+ if pretrained_window_size[0] > 0:
99
+ relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
100
+ relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
101
+ else:
102
+ relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
103
+ relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
104
+ relative_coords_table *= 8 # normalize to -8, 8
105
+ relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
106
+ torch.abs(relative_coords_table) + 1.0) / np.log2(8)
107
+
108
+ self.register_buffer("relative_coords_table", relative_coords_table)
109
+
110
+ # get pair-wise relative position index for each token inside the window
111
+ coords_h = torch.arange(self.window_size[0])
112
+ coords_w = torch.arange(self.window_size[1])
113
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
114
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
115
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
116
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
117
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
118
+ relative_coords[:, :, 1] += self.window_size[1] - 1
119
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
120
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
121
+ self.register_buffer("relative_position_index", relative_position_index)
122
+
123
+ self.qkv = nn.Linear(dim, dim * 3, bias=False)
124
+ if qkv_bias:
125
+ self.q_bias = nn.Parameter(torch.zeros(dim))
126
+ self.v_bias = nn.Parameter(torch.zeros(dim))
127
+ else:
128
+ self.q_bias = None
129
+ self.v_bias = None
130
+ self.attn_drop = nn.Dropout(attn_drop)
131
+ self.proj = nn.Linear(dim, dim)
132
+ self.proj_drop = nn.Dropout(proj_drop)
133
+ self.softmax = nn.Softmax(dim=-1)
134
+
135
+ def forward(self, x, mask=None):
136
+ """
137
+ Args:
138
+ x: input features with shape of (num_windows*B, N, C)
139
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
140
+ """
141
+ B_, N, C = x.shape
142
+ qkv_bias = None
143
+ if self.q_bias is not None:
144
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
145
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
146
+ qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
147
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
148
+
149
+ # cosine attention
150
+ attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
151
+ logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale.device)).exp()
152
+ attn = attn * logit_scale
153
+
154
+ relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
155
+ relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
156
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
157
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
158
+ relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
159
+ attn = attn + relative_position_bias.unsqueeze(0)
160
+
161
+ if mask is not None:
162
+ nW = mask.shape[0]
163
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
164
+ attn = attn.view(-1, self.num_heads, N, N)
165
+ attn = self.softmax(attn)
166
+ else:
167
+ attn = self.softmax(attn)
168
+
169
+ attn = self.attn_drop(attn)
170
+
171
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
172
+ x = self.proj(x)
173
+ x = self.proj_drop(x)
174
+ return x
175
+
176
+ def extra_repr(self) -> str:
177
+ return f'dim={self.dim}, window_size={self.window_size}, ' \
178
+ f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'
179
+
180
+ def flops(self, N):
181
+ # calculate flops for 1 window with token length of N
182
+ flops = 0
183
+ # qkv = self.qkv(x)
184
+ flops += N * self.dim * 3 * self.dim
185
+ # attn = (q @ k.transpose(-2, -1))
186
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
187
+ # x = (attn @ v)
188
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
189
+ # x = self.proj(x)
190
+ flops += N * self.dim * self.dim
191
+ return flops
192
+
193
+ class SwinTransformerBlock(nn.Module):
194
+ r""" Swin Transformer Block.
195
+ Args:
196
+ dim (int): Number of input channels.
197
+ input_resolution (tuple[int]): Input resolution.
198
+ num_heads (int): Number of attention heads.
199
+ window_size (int): Window size.
200
+ shift_size (int): Shift size for SW-MSA.
201
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
202
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
203
+ drop (float, optional): Dropout rate. Default: 0.0
204
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
205
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
206
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
207
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
208
+ pretrained_window_size (int): Window size in pre-training.
209
+ """
210
+
211
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
212
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
213
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
214
+ super().__init__()
215
+ self.dim = dim
216
+ self.input_resolution = input_resolution
217
+ self.num_heads = num_heads
218
+ self.window_size = window_size
219
+ self.shift_size = shift_size
220
+ self.mlp_ratio = mlp_ratio
221
+ if min(self.input_resolution) <= self.window_size:
222
+ # if window size is larger than input resolution, we don't partition windows
223
+ self.shift_size = 0
224
+ self.window_size = min(self.input_resolution)
225
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
226
+
227
+ self.norm1 = norm_layer(dim)
228
+ self.attn = WindowAttention(
229
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
230
+ qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
231
+ pretrained_window_size=to_2tuple(pretrained_window_size))
232
+
233
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
234
+ self.norm2 = norm_layer(dim)
235
+ mlp_hidden_dim = int(dim * mlp_ratio)
236
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
237
+
238
+ if self.shift_size > 0:
239
+ attn_mask = self.calculate_mask(self.input_resolution)
240
+ else:
241
+ attn_mask = None
242
+
243
+ self.register_buffer("attn_mask", attn_mask)
244
+
245
+ def calculate_mask(self, x_size):
246
+ # calculate attention mask for SW-MSA
247
+ H, W = x_size
248
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
249
+ h_slices = (slice(0, -self.window_size),
250
+ slice(-self.window_size, -self.shift_size),
251
+ slice(-self.shift_size, None))
252
+ w_slices = (slice(0, -self.window_size),
253
+ slice(-self.window_size, -self.shift_size),
254
+ slice(-self.shift_size, None))
255
+ cnt = 0
256
+ for h in h_slices:
257
+ for w in w_slices:
258
+ img_mask[:, h, w, :] = cnt
259
+ cnt += 1
260
+
261
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
262
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
263
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
264
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
265
+
266
+ return attn_mask
267
+
268
+ def forward(self, x, x_size):
269
+ H, W = x_size
270
+ B, L, C = x.shape
271
+ #assert L == H * W, "input feature has wrong size"
272
+
273
+ shortcut = x
274
+ x = x.view(B, H, W, C)
275
+
276
+ # cyclic shift
277
+ if self.shift_size > 0:
278
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
279
+ else:
280
+ shifted_x = x
281
+
282
+ # partition windows
283
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
284
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
285
+
286
+ # W-MSA/SW-MSA (to be compatible with testing on images whose shapes are a multiple of the window size)
287
+ if self.input_resolution == x_size:
288
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
289
+ else:
290
+ attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
291
+
292
+ # merge windows
293
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
294
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
295
+
296
+ # reverse cyclic shift
297
+ if self.shift_size > 0:
298
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
299
+ else:
300
+ x = shifted_x
301
+ x = x.view(B, H * W, C)
302
+ x = shortcut + self.drop_path(self.norm1(x))
303
+
304
+ # FFN
305
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
306
+
307
+ return x
308
+
309
+ def extra_repr(self) -> str:
310
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
311
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
312
+
313
+ def flops(self):
314
+ flops = 0
315
+ H, W = self.input_resolution
316
+ # norm1
317
+ flops += self.dim * H * W
318
+ # W-MSA/SW-MSA
319
+ nW = H * W / self.window_size / self.window_size
320
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
321
+ # mlp
322
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
323
+ # norm2
324
+ flops += self.dim * H * W
325
+ return flops
326
+
327
+ class PatchMerging(nn.Module):
328
+ r""" Patch Merging Layer.
329
+ Args:
330
+ input_resolution (tuple[int]): Resolution of input feature.
331
+ dim (int): Number of input channels.
332
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
333
+ """
334
+
335
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
336
+ super().__init__()
337
+ self.input_resolution = input_resolution
338
+ self.dim = dim
339
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
340
+ self.norm = norm_layer(2 * dim)
341
+
342
+ def forward(self, x):
343
+ """
344
+ x: B, H*W, C
345
+ """
346
+ H, W = self.input_resolution
347
+ B, L, C = x.shape
348
+ assert L == H * W, "input feature has wrong size"
349
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
350
+
351
+ x = x.view(B, H, W, C)
352
+
353
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
354
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
355
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
356
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
357
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
358
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
359
+
360
+ x = self.reduction(x)
361
+ x = self.norm(x)
362
+
363
+ return x
364
+
365
+ def extra_repr(self) -> str:
366
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
367
+
368
+ def flops(self):
369
+ H, W = self.input_resolution
370
+ flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
371
+ flops += H * W * self.dim // 2
372
+ return flops
373
+
374
+ class BasicLayer(nn.Module):
375
+ """ A basic Swin Transformer layer for one stage.
376
+ Args:
377
+ dim (int): Number of input channels.
378
+ input_resolution (tuple[int]): Input resolution.
379
+ depth (int): Number of blocks.
380
+ num_heads (int): Number of attention heads.
381
+ window_size (int): Local window size.
382
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
383
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
384
+ drop (float, optional): Dropout rate. Default: 0.0
385
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
386
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
387
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
388
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
389
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
390
+ pretrained_window_size (int): Local window size in pre-training.
391
+ """
392
+
393
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
394
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
395
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
396
+ pretrained_window_size=0):
397
+
398
+ super().__init__()
399
+ self.dim = dim
400
+ self.input_resolution = input_resolution
401
+ self.depth = depth
402
+ self.use_checkpoint = use_checkpoint
403
+
404
+ # build blocks
405
+ self.blocks = nn.ModuleList([
406
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
407
+ num_heads=num_heads, window_size=window_size,
408
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
409
+ mlp_ratio=mlp_ratio,
410
+ qkv_bias=qkv_bias,
411
+ drop=drop, attn_drop=attn_drop,
412
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
413
+ norm_layer=norm_layer,
414
+ pretrained_window_size=pretrained_window_size)
415
+ for i in range(depth)])
416
+
417
+ # patch merging layer
418
+ if downsample is not None:
419
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
420
+ else:
421
+ self.downsample = None
422
+
423
+ def forward(self, x, x_size):
424
+ for blk in self.blocks:
425
+ if self.use_checkpoint:
426
+ x = checkpoint.checkpoint(blk, x, x_size)
427
+ else:
428
+ x = blk(x, x_size)
429
+ if self.downsample is not None:
430
+ x = self.downsample(x)
431
+ return x
432
+
433
+ def extra_repr(self) -> str:
434
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
435
+
436
+ def flops(self):
437
+ flops = 0
438
+ for blk in self.blocks:
439
+ flops += blk.flops()
440
+ if self.downsample is not None:
441
+ flops += self.downsample.flops()
442
+ return flops
443
+
444
+ def _init_respostnorm(self):
445
+ for blk in self.blocks:
446
+ nn.init.constant_(blk.norm1.bias, 0)
447
+ nn.init.constant_(blk.norm1.weight, 0)
448
+ nn.init.constant_(blk.norm2.bias, 0)
449
+ nn.init.constant_(blk.norm2.weight, 0)
450
+
451
+ class PatchEmbed(nn.Module):
452
+ r""" Image to Patch Embedding
453
+ Args:
454
+ img_size (int): Image size. Default: 224.
455
+ patch_size (int): Patch token size. Default: 4.
456
+ in_chans (int): Number of input image channels. Default: 3.
457
+ embed_dim (int): Number of linear projection output channels. Default: 96.
458
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
459
+ """
460
+
461
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
462
+ super().__init__()
463
+ img_size = to_2tuple(img_size)
464
+ patch_size = to_2tuple(patch_size)
465
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
466
+ self.img_size = img_size
467
+ self.patch_size = patch_size
468
+ self.patches_resolution = patches_resolution
469
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
470
+
471
+ self.in_chans = in_chans
472
+ self.embed_dim = embed_dim
473
+
474
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
475
+ if norm_layer is not None:
476
+ self.norm = norm_layer(embed_dim)
477
+ else:
478
+ self.norm = None
479
+
480
+ def forward(self, x):
481
+ B, C, H, W = x.shape
482
+ # FIXME look at relaxing size constraints
483
+ # assert H == self.img_size[0] and W == self.img_size[1],
484
+ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
485
+ x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
486
+ if self.norm is not None:
487
+ x = self.norm(x)
488
+ return x
489
+
490
+ def flops(self):
491
+ Ho, Wo = self.patches_resolution
492
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
493
+ if self.norm is not None:
494
+ flops += Ho * Wo * self.embed_dim
495
+ return flops
496
+
497
+ class RSTB(nn.Module):
498
+ """Residual Swin Transformer Block (RSTB).
499
+
500
+ Args:
501
+ dim (int): Number of input channels.
502
+ input_resolution (tuple[int]): Input resolution.
503
+ depth (int): Number of blocks.
504
+ num_heads (int): Number of attention heads.
505
+ window_size (int): Local window size.
506
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
507
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
508
+ drop (float, optional): Dropout rate. Default: 0.0
509
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
510
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
511
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
512
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
513
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
514
+ img_size: Input image size.
515
+ patch_size: Patch size.
516
+ resi_connection: The convolutional block before residual connection.
517
+ """
518
+
519
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
520
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
521
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
522
+ img_size=224, patch_size=4, resi_connection='1conv'):
523
+ super(RSTB, self).__init__()
524
+
525
+ self.dim = dim
526
+ self.input_resolution = input_resolution
527
+
528
+ self.residual_group = BasicLayer(dim=dim,
529
+ input_resolution=input_resolution,
530
+ depth=depth,
531
+ num_heads=num_heads,
532
+ window_size=window_size,
533
+ mlp_ratio=mlp_ratio,
534
+ qkv_bias=qkv_bias,
535
+ drop=drop, attn_drop=attn_drop,
536
+ drop_path=drop_path,
537
+ norm_layer=norm_layer,
538
+ downsample=downsample,
539
+ use_checkpoint=use_checkpoint)
540
+
541
+ if resi_connection == '1conv':
542
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
543
+ elif resi_connection == '3conv':
544
+ # to save parameters and memory
545
+ self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
546
+ nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
547
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
548
+ nn.Conv2d(dim // 4, dim, 3, 1, 1))
549
+
550
+ self.patch_embed = PatchEmbed(
551
+ img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
552
+ norm_layer=None)
553
+
554
+ self.patch_unembed = PatchUnEmbed(
555
+ img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
556
+ norm_layer=None)
557
+
558
+ def forward(self, x, x_size):
559
+ return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
560
+
561
+ def flops(self):
562
+ flops = 0
563
+ flops += self.residual_group.flops()
564
+ H, W = self.input_resolution
565
+ flops += H * W * self.dim * self.dim * 9
566
+ flops += self.patch_embed.flops()
567
+ flops += self.patch_unembed.flops()
568
+
569
+ return flops
570
+
571
+ class PatchUnEmbed(nn.Module):
572
+ r""" Image to Patch Unembedding
573
+
574
+ Args:
575
+ img_size (int): Image size. Default: 224.
576
+ patch_size (int): Patch token size. Default: 4.
577
+ in_chans (int): Number of input image channels. Default: 3.
578
+ embed_dim (int): Number of linear projection output channels. Default: 96.
579
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
580
+ """
581
+
582
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
583
+ super().__init__()
584
+ img_size = to_2tuple(img_size)
585
+ patch_size = to_2tuple(patch_size)
586
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
587
+ self.img_size = img_size
588
+ self.patch_size = patch_size
589
+ self.patches_resolution = patches_resolution
590
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
591
+
592
+ self.in_chans = in_chans
593
+ self.embed_dim = embed_dim
594
+
595
+ def forward(self, x, x_size):
596
+ B, HW, C = x.shape
597
+ x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
598
+ return x
599
+
600
+ def flops(self):
601
+ flops = 0
602
+ return flops
603
+
604
+
605
+ class Upsample(nn.Sequential):
606
+ """Upsample module.
607
+
608
+ Args:
609
+ scale (int): Scale factor. Supported scales: 2^n and 3.
610
+ num_feat (int): Channel number of intermediate features.
611
+ """
612
+
613
+ def __init__(self, scale, num_feat):
614
+ m = []
615
+ if (scale & (scale - 1)) == 0: # scale = 2^n
616
+ for _ in range(int(math.log(scale, 2))):
617
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
618
+ m.append(nn.PixelShuffle(2))
619
+ elif scale == 3:
620
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
621
+ m.append(nn.PixelShuffle(3))
622
+ else:
623
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
624
+ super(Upsample, self).__init__(*m)
625
+
626
+ class Upsample_hf(nn.Sequential):
627
+ """Upsample module.
628
+
629
+ Args:
630
+ scale (int): Scale factor. Supported scales: 2^n and 3.
631
+ num_feat (int): Channel number of intermediate features.
632
+ """
633
+
634
+ def __init__(self, scale, num_feat):
635
+ m = []
636
+ if (scale & (scale - 1)) == 0: # scale = 2^n
637
+ for _ in range(int(math.log(scale, 2))):
638
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
639
+ m.append(nn.PixelShuffle(2))
640
+ elif scale == 3:
641
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
642
+ m.append(nn.PixelShuffle(3))
643
+ else:
644
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
645
+ super(Upsample_hf, self).__init__(*m)
646
+
647
+
648
+ class UpsampleOneStep(nn.Sequential):
649
+ """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
650
+ Used in lightweight SR to save parameters.
651
+
652
+ Args:
653
+ scale (int): Scale factor. Supported scales: 2^n and 3.
654
+ num_feat (int): Channel number of intermediate features.
655
+
656
+ """
657
+
658
+ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
659
+ self.num_feat = num_feat
660
+ self.input_resolution = input_resolution
661
+ m = []
662
+ m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
663
+ m.append(nn.PixelShuffle(scale))
664
+ super(UpsampleOneStep, self).__init__(*m)
665
+
666
+ def flops(self):
667
+ H, W = self.input_resolution
668
+ flops = H * W * self.num_feat * 3 * 9
669
+ return flops
670
+
671
+
672
+
673
+ class Swin2SR(nn.Module):
674
+ r""" Swin2SR
675
+ A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`.
676
+
677
+ Args:
678
+ img_size (int | tuple(int)): Input image size. Default 64
679
+ patch_size (int | tuple(int)): Patch size. Default: 1
680
+ in_chans (int): Number of input image channels. Default: 3
681
+ embed_dim (int): Patch embedding dimension. Default: 96
682
+ depths (tuple(int)): Depth of each Swin Transformer layer.
683
+ num_heads (tuple(int)): Number of attention heads in different layers.
684
+ window_size (int): Window size. Default: 7
685
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
686
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
687
+ drop_rate (float): Dropout rate. Default: 0
688
+ attn_drop_rate (float): Attention dropout rate. Default: 0
689
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
690
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
691
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
692
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
693
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
694
+ upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
695
+ img_range: Image range. 1. or 255.
696
+ upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
697
+ resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
698
+ """
699
+
700
+ def __init__(self, img_size=64, patch_size=1, in_chans=3,
701
+ embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
702
+ window_size=7, mlp_ratio=4., qkv_bias=True,
703
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
704
+ norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
705
+ use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
706
+ **kwargs):
707
+ super(Swin2SR, self).__init__()
708
+ num_in_ch = in_chans
709
+ num_out_ch = in_chans
710
+ num_feat = 64
711
+ self.img_range = img_range
712
+ if in_chans == 3:
713
+ rgb_mean = (0.4488, 0.4371, 0.4040)
714
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
715
+ else:
716
+ self.mean = torch.zeros(1, 1, 1, 1)
717
+ self.upscale = upscale
718
+ self.upsampler = upsampler
719
+ self.window_size = window_size
720
+
721
+ #####################################################################################################
722
+ ################################### 1, shallow feature extraction ###################################
723
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
724
+
725
+ #####################################################################################################
726
+ ################################### 2, deep feature extraction ######################################
727
+ self.num_layers = len(depths)
728
+ self.embed_dim = embed_dim
729
+ self.ape = ape
730
+ self.patch_norm = patch_norm
731
+ self.num_features = embed_dim
732
+ self.mlp_ratio = mlp_ratio
733
+
734
+ # split image into non-overlapping patches
735
+ self.patch_embed = PatchEmbed(
736
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
737
+ norm_layer=norm_layer if self.patch_norm else None)
738
+ num_patches = self.patch_embed.num_patches
739
+ patches_resolution = self.patch_embed.patches_resolution
740
+ self.patches_resolution = patches_resolution
741
+
742
+ # merge non-overlapping patches into image
743
+ self.patch_unembed = PatchUnEmbed(
744
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
745
+ norm_layer=norm_layer if self.patch_norm else None)
746
+
747
+ # absolute position embedding
748
+ if self.ape:
749
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
750
+ trunc_normal_(self.absolute_pos_embed, std=.02)
751
+
752
+ self.pos_drop = nn.Dropout(p=drop_rate)
753
+
754
+ # stochastic depth
755
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
756
+
757
+ # build Residual Swin Transformer blocks (RSTB)
758
+ self.layers = nn.ModuleList()
759
+ for i_layer in range(self.num_layers):
760
+ layer = RSTB(dim=embed_dim,
761
+ input_resolution=(patches_resolution[0],
762
+ patches_resolution[1]),
763
+ depth=depths[i_layer],
764
+ num_heads=num_heads[i_layer],
765
+ window_size=window_size,
766
+ mlp_ratio=self.mlp_ratio,
767
+ qkv_bias=qkv_bias,
768
+ drop=drop_rate, attn_drop=attn_drop_rate,
769
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
770
+ norm_layer=norm_layer,
771
+ downsample=None,
772
+ use_checkpoint=use_checkpoint,
773
+ img_size=img_size,
774
+ patch_size=patch_size,
775
+ resi_connection=resi_connection
776
+
777
+ )
778
+ self.layers.append(layer)
779
+
780
+ if self.upsampler == 'pixelshuffle_hf':
781
+ self.layers_hf = nn.ModuleList()
782
+ for i_layer in range(self.num_layers):
783
+ layer = RSTB(dim=embed_dim,
784
+ input_resolution=(patches_resolution[0],
785
+ patches_resolution[1]),
786
+ depth=depths[i_layer],
787
+ num_heads=num_heads[i_layer],
788
+ window_size=window_size,
789
+ mlp_ratio=self.mlp_ratio,
790
+ qkv_bias=qkv_bias,
791
+ drop=drop_rate, attn_drop=attn_drop_rate,
792
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
793
+ norm_layer=norm_layer,
794
+ downsample=None,
795
+ use_checkpoint=use_checkpoint,
796
+ img_size=img_size,
797
+ patch_size=patch_size,
798
+ resi_connection=resi_connection
799
+
800
+ )
801
+ self.layers_hf.append(layer)
802
+
803
+ self.norm = norm_layer(self.num_features)
804
+
805
+ # build the last conv layer in deep feature extraction
806
+ if resi_connection == '1conv':
807
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
808
+ elif resi_connection == '3conv':
809
+ # to save parameters and memory
810
+ self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
811
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
812
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
813
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
814
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
815
+
816
+ #####################################################################################################
817
+ ################################ 3, high quality image reconstruction ################################
818
+ if self.upsampler == 'pixelshuffle':
819
+ # for classical SR
820
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
821
+ nn.LeakyReLU(inplace=True))
822
+ self.upsample = Upsample(upscale, num_feat)
823
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
824
+ elif self.upsampler == 'pixelshuffle_aux':
825
+ self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
826
+ self.conv_before_upsample = nn.Sequential(
827
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
828
+ nn.LeakyReLU(inplace=True))
829
+ self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
830
+ self.conv_after_aux = nn.Sequential(
831
+ nn.Conv2d(3, num_feat, 3, 1, 1),
832
+ nn.LeakyReLU(inplace=True))
833
+ self.upsample = Upsample(upscale, num_feat)
834
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
835
+
836
+ elif self.upsampler == 'pixelshuffle_hf':
837
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
838
+ nn.LeakyReLU(inplace=True))
839
+ self.upsample = Upsample(upscale, num_feat)
840
+ self.upsample_hf = Upsample_hf(upscale, num_feat)
841
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
842
+ self.conv_first_hf = nn.Sequential(nn.Conv2d(num_feat, embed_dim, 3, 1, 1),
843
+ nn.LeakyReLU(inplace=True))
844
+ self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
845
+ self.conv_before_upsample_hf = nn.Sequential(
846
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
847
+ nn.LeakyReLU(inplace=True))
848
+ self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
849
+
850
+ elif self.upsampler == 'pixelshuffledirect':
851
+ # for lightweight SR (to save parameters)
852
+ self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
853
+ (patches_resolution[0], patches_resolution[1]))
854
+ elif self.upsampler == 'nearest+conv':
855
+ # for real-world SR (less artifacts)
856
+ assert self.upscale == 4, 'only support x4 now.'
857
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
858
+ nn.LeakyReLU(inplace=True))
859
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
860
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
861
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
862
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
863
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
864
+ else:
865
+ # for image denoising and JPEG compression artifact reduction
866
+ self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
867
+
868
+ self.apply(self._init_weights)
869
+
870
+ def _init_weights(self, m):
871
+ if isinstance(m, nn.Linear):
872
+ trunc_normal_(m.weight, std=.02)
873
+ if isinstance(m, nn.Linear) and m.bias is not None:
874
+ nn.init.constant_(m.bias, 0)
875
+ elif isinstance(m, nn.LayerNorm):
876
+ nn.init.constant_(m.bias, 0)
877
+ nn.init.constant_(m.weight, 1.0)
878
+
879
+ @torch.jit.ignore
880
+ def no_weight_decay(self):
881
+ return {'absolute_pos_embed'}
882
+
883
+ @torch.jit.ignore
884
+ def no_weight_decay_keywords(self):
885
+ return {'relative_position_bias_table'}
886
+
887
+ def check_image_size(self, x):
888
+ _, _, h, w = x.size()
889
+ mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
890
+ mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
891
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
892
+ return x
893
+
894
+ def forward_features(self, x):
895
+ x_size = (x.shape[2], x.shape[3])
896
+ x = self.patch_embed(x)
897
+ if self.ape:
898
+ x = x + self.absolute_pos_embed
899
+ x = self.pos_drop(x)
900
+
901
+ for layer in self.layers:
902
+ x = layer(x, x_size)
903
+
904
+ x = self.norm(x) # B L C
905
+ x = self.patch_unembed(x, x_size)
906
+
907
+ return x
908
+
909
+ def forward_features_hf(self, x):
910
+ x_size = (x.shape[2], x.shape[3])
911
+ x = self.patch_embed(x)
912
+ if self.ape:
913
+ x = x + self.absolute_pos_embed
914
+ x = self.pos_drop(x)
915
+
916
+ for layer in self.layers_hf:
917
+ x = layer(x, x_size)
918
+
919
+ x = self.norm(x) # B L C
920
+ x = self.patch_unembed(x, x_size)
921
+
922
+ return x
923
+
924
+ def forward(self, x):
925
+ H, W = x.shape[2:]
926
+ x = self.check_image_size(x)
927
+
928
+ self.mean = self.mean.type_as(x)
929
+ x = (x - self.mean) * self.img_range
930
+
931
+ if self.upsampler == 'pixelshuffle':
932
+ # for classical SR
933
+ x = self.conv_first(x)
934
+ x = self.conv_after_body(self.forward_features(x)) + x
935
+ x = self.conv_before_upsample(x)
936
+ x = self.conv_last(self.upsample(x))
937
+ elif self.upsampler == 'pixelshuffle_aux':
938
+ bicubic = F.interpolate(x, size=(H * self.upscale, W * self.upscale), mode='bicubic', align_corners=False)
939
+ bicubic = self.conv_bicubic(bicubic)
940
+ x = self.conv_first(x)
941
+ x = self.conv_after_body(self.forward_features(x)) + x
942
+ x = self.conv_before_upsample(x)
943
+ aux = self.conv_aux(x) # b, 3, LR_H, LR_W
944
+ x = self.conv_after_aux(aux)
945
+ x = self.upsample(x)[:, :, :H * self.upscale, :W * self.upscale] + bicubic[:, :, :H * self.upscale, :W * self.upscale]
946
+ x = self.conv_last(x)
947
+ aux = aux / self.img_range + self.mean
948
+ elif self.upsampler == 'pixelshuffle_hf':
949
+ # for classical SR with HF
950
+ x = self.conv_first(x)
951
+ x = self.conv_after_body(self.forward_features(x)) + x
952
+ x_before = self.conv_before_upsample(x)
953
+ x_out = self.conv_last(self.upsample(x_before))
954
+
955
+ x_hf = self.conv_first_hf(x_before)
956
+ x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
957
+ x_hf = self.conv_before_upsample_hf(x_hf)
958
+ x_hf = self.conv_last_hf(self.upsample_hf(x_hf))
959
+ x = x_out + x_hf
960
+ x_hf = x_hf / self.img_range + self.mean
961
+
962
+ elif self.upsampler == 'pixelshuffledirect':
963
+ # for lightweight SR
964
+ x = self.conv_first(x)
965
+ x = self.conv_after_body(self.forward_features(x)) + x
966
+ x = self.upsample(x)
967
+ elif self.upsampler == 'nearest+conv':
968
+ # for real-world SR
969
+ x = self.conv_first(x)
970
+ x = self.conv_after_body(self.forward_features(x)) + x
971
+ x = self.conv_before_upsample(x)
972
+ x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
973
+ x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
974
+ x = self.conv_last(self.lrelu(self.conv_hr(x)))
975
+ else:
976
+ # for image denoising and JPEG compression artifact reduction
977
+ x_first = self.conv_first(x)
978
+ res = self.conv_after_body(self.forward_features(x_first)) + x_first
979
+ x = x + self.conv_last(res)
980
+
981
+ x = x / self.img_range + self.mean
982
+ if self.upsampler == "pixelshuffle_aux":
983
+ return x[:, :, :H*self.upscale, :W*self.upscale], aux
984
+
985
+ elif self.upsampler == "pixelshuffle_hf":
986
+ x_out = x_out / self.img_range + self.mean
987
+ return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]
988
+
989
+ else:
990
+ return x[:, :, :H*self.upscale, :W*self.upscale]
991
+
992
+ def flops(self):
993
+ flops = 0
994
+ H, W = self.patches_resolution
995
+ flops += H * W * 3 * self.embed_dim * 9
996
+ flops += self.patch_embed.flops()
997
+ for layer in self.layers:
998
+ flops += layer.flops()
999
+ flops += H * W * 3 * self.embed_dim * self.embed_dim
1000
+ flops += self.upsample.flops()
1001
+ return flops
1002
+
1003
+
1004
+ if __name__ == '__main__':
1005
+ upscale = 4
1006
+ window_size = 8
1007
+ height = (1024 // upscale // window_size + 1) * window_size
1008
+ width = (720 // upscale // window_size + 1) * window_size
1009
+ model = Swin2SR(upscale=2, img_size=(height, width),
1010
+ window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
1011
+ embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
1012
+ print(model)
1013
+ print(height, width, model.flops() / 1e9)
1014
+
1015
+ x = torch.randn((1, 3, height, width))
1016
+ x = model(x)
1017
+ print(x.shape)
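
A hedged note on the return values (not part of the committed file): what `Swin2SR.forward` returns depends on the `upsampler` mode implemented above. 'pixelshuffle_aux' returns the super-resolved image together with a low-resolution auxiliary image, 'pixelshuffle_hf' returns three tensors (the base SR output, the base output plus the high-frequency branch, and the high-frequency branch itself), and every other mode returns a single tensor. A minimal sketch for the auxiliary mode; the 48x48 input is chosen only so that no padding is needed:

import torch

# Minimal sketch: 'pixelshuffle_aux' yields (sr, aux); other modes differ (see forward()).
model = Swin2SR(upscale=2, img_size=48, window_size=8, img_range=1.,
                depths=[6, 6, 6, 6], embed_dim=60, num_heads=[6, 6, 6, 6],
                mlp_ratio=2, upsampler='pixelshuffle_aux')
with torch.no_grad():
    sr, aux = model(torch.randn(1, 3, 48, 48))
print(sr.shape, aux.shape)  # torch.Size([1, 3, 96, 96]) torch.Size([1, 3, 48, 48])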
extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js ADDED
@@ -0,0 +1,776 @@
1
+ onUiLoaded(async() => {
2
+ const elementIDs = {
3
+ img2imgTabs: "#mode_img2img .tab-nav",
4
+ inpaint: "#img2maskimg",
5
+ inpaintSketch: "#inpaint_sketch",
6
+ rangeGroup: "#img2img_column_size",
7
+ sketch: "#img2img_sketch"
8
+ };
9
+ const tabNameToElementId = {
10
+ "Inpaint sketch": elementIDs.inpaintSketch,
11
+ "Inpaint": elementIDs.inpaint,
12
+ "Sketch": elementIDs.sketch
13
+ };
14
+
15
+ // Helper functions
16
+ // Get active tab
17
+ function getActiveTab(elements, all = false) {
18
+ const tabs = elements.img2imgTabs.querySelectorAll("button");
19
+
20
+ if (all) return tabs;
21
+
22
+ for (let tab of tabs) {
23
+ if (tab.classList.contains("selected")) {
24
+ return tab;
25
+ }
26
+ }
27
+ }
28
+
29
+ // Get tab ID
30
+ function getTabId(elements) {
31
+ const activeTab = getActiveTab(elements);
32
+ return tabNameToElementId[activeTab.innerText];
33
+ }
34
+
35
+ // Wait until opts loaded
36
+ async function waitForOpts() {
37
+ for (;;) {
38
+ if (window.opts && Object.keys(window.opts).length) {
39
+ return window.opts;
40
+ }
41
+ await new Promise(resolve => setTimeout(resolve, 100));
42
+ }
43
+ }
44
+
45
+ // Check whether the given modifier key ("Ctrl", "Shift" or "Alt") is pressed in the event
46
+ function isModifierKey(event, key) {
47
+ switch (key) {
48
+ case "Ctrl":
49
+ return event.ctrlKey;
50
+ case "Shift":
51
+ return event.shiftKey;
52
+ case "Alt":
53
+ return event.altKey;
54
+ default:
55
+ return false;
56
+ }
57
+ }
58
+
59
+ // Check if hotkey is valid
60
+ function isValidHotkey(value) {
61
+ const specialKeys = ["Ctrl", "Alt", "Shift", "Disable"];
62
+ return (
63
+ (typeof value === "string" &&
64
+ value.length === 1 &&
65
+ /[a-z]/i.test(value)) ||
66
+ specialKeys.includes(value)
67
+ );
68
+ }
69
+
70
+ // Normalize hotkey
71
+ function normalizeHotkey(hotkey) {
72
+ return hotkey.length === 1 ? "Key" + hotkey.toUpperCase() : hotkey;
73
+ }
74
+
75
+ // Format hotkey for display
76
+ function formatHotkeyForDisplay(hotkey) {
77
+ return hotkey.startsWith("Key") ? hotkey.slice(3) : hotkey;
78
+ }
79
+
80
+ // Create hotkey configuration with the provided options
81
+ function createHotkeyConfig(defaultHotkeysConfig, hotkeysConfigOpts) {
82
+ const result = {}; // Resulting hotkey configuration
83
+ const usedKeys = new Set(); // Set of used hotkeys
84
+
85
+ // Iterate through defaultHotkeysConfig keys
86
+ for (const key in defaultHotkeysConfig) {
87
+ const userValue = hotkeysConfigOpts[key]; // User-provided hotkey value
88
+ const defaultValue = defaultHotkeysConfig[key]; // Default hotkey value
89
+
90
+ // Apply appropriate value for undefined, boolean, or object userValue
91
+ if (
92
+ userValue === undefined ||
93
+ typeof userValue === "boolean" ||
94
+ typeof userValue === "object" ||
95
+ userValue === "disable"
96
+ ) {
97
+ result[key] =
98
+ userValue === undefined ? defaultValue : userValue;
99
+ } else if (isValidHotkey(userValue)) {
100
+ const normalizedUserValue = normalizeHotkey(userValue);
101
+
102
+ // Check for conflicting hotkeys
103
+ if (!usedKeys.has(normalizedUserValue)) {
104
+ usedKeys.add(normalizedUserValue);
105
+ result[key] = normalizedUserValue;
106
+ } else {
107
+ console.error(
108
+ `Hotkey: ${formatHotkeyForDisplay(
109
+ userValue
110
+ )} for ${key} is repeated and conflicts with another hotkey. The default hotkey is used: ${formatHotkeyForDisplay(
111
+ defaultValue
112
+ )}`
113
+ );
114
+ result[key] = defaultValue;
115
+ }
116
+ } else {
117
+ console.error(
118
+ `Hotkey: ${formatHotkeyForDisplay(
119
+ userValue
120
+ )} for ${key} is not valid. The default hotkey is used: ${formatHotkeyForDisplay(
121
+ defaultValue
122
+ )}`
123
+ );
124
+ result[key] = defaultValue;
125
+ }
126
+ }
127
+
128
+ return result;
129
+ }
130
+
131
+ // Disables functions in the config object based on the provided list of function names
132
+ function disableFunctions(config, disabledFunctions) {
133
+ // Bind the hasOwnProperty method to the functionMap object to avoid errors
134
+ const hasOwnProperty =
135
+ Object.prototype.hasOwnProperty.bind(functionMap);
136
+
137
+ // Loop through the disabledFunctions array and disable the corresponding functions in the config object
138
+ disabledFunctions.forEach(funcName => {
139
+ if (hasOwnProperty(funcName)) {
140
+ const key = functionMap[funcName];
141
+ config[key] = "disable";
142
+ }
143
+ });
144
+
145
+ // Return the updated config object
146
+ return config;
147
+ }
148
+
149
+ /**
150
+ * The restoreImgRedMask function displays a red mask around an image to indicate the aspect ratio.
151
+ * If the image display property is set to 'none', the mask breaks. To fix this, the function
152
+ * temporarily sets the display property to 'block' and then hides the image again after 400 milliseconds
153
+ * to avoid breaking the canvas. Additionally, the function adjusts the mask to work correctly on
154
+ * very long images.
155
+ */
156
+ function restoreImgRedMask(elements) {
157
+ const mainTabId = getTabId(elements);
158
+
159
+ if (!mainTabId) return;
160
+
161
+ const mainTab = gradioApp().querySelector(mainTabId);
162
+ const img = mainTab.querySelector("img");
163
+ const imageARPreview = gradioApp().querySelector("#imageARPreview");
164
+
165
+ if (!img || !imageARPreview) return;
166
+
167
+ imageARPreview.style.transform = "";
168
+ if (parseFloat(mainTab.style.width) > 865) {
169
+ const transformString = mainTab.style.transform;
170
+ const scaleMatch = transformString.match(
171
+ /scale\(([-+]?[0-9]*\.?[0-9]+)\)/
172
+ );
173
+ let zoom = 1; // default zoom
174
+
175
+ if (scaleMatch && scaleMatch[1]) {
176
+ zoom = Number(scaleMatch[1]);
177
+ }
178
+
179
+ imageARPreview.style.transformOrigin = "0 0";
180
+ imageARPreview.style.transform = `scale(${zoom})`;
181
+ }
182
+
183
+ if (img.style.display !== "none") return;
184
+
185
+ img.style.display = "block";
186
+
187
+ setTimeout(() => {
188
+ img.style.display = "none";
189
+ }, 400);
190
+ }
191
+
192
+ const hotkeysConfigOpts = await waitForOpts();
193
+
194
+ // Default config
195
+ const defaultHotkeysConfig = {
196
+ canvas_hotkey_zoom: "Alt",
197
+ canvas_hotkey_adjust: "Ctrl",
198
+ canvas_hotkey_reset: "KeyR",
199
+ canvas_hotkey_fullscreen: "KeyS",
200
+ canvas_hotkey_move: "KeyF",
201
+ canvas_hotkey_overlap: "KeyO",
202
+ canvas_disabled_functions: [],
203
+ canvas_show_tooltip: true,
204
+ canvas_blur_prompt: false
205
+ };
206
+
207
+ const functionMap = {
208
+ "Zoom": "canvas_hotkey_zoom",
209
+ "Adjust brush size": "canvas_hotkey_adjust",
210
+ "Moving canvas": "canvas_hotkey_move",
211
+ "Fullscreen": "canvas_hotkey_fullscreen",
212
+ "Reset Zoom": "canvas_hotkey_reset",
213
+ "Overlap": "canvas_hotkey_overlap"
214
+ };
215
+
216
+ // Loading the configuration from opts
217
+ const preHotkeysConfig = createHotkeyConfig(
218
+ defaultHotkeysConfig,
219
+ hotkeysConfigOpts
220
+ );
221
+
222
+ // Disable functions that are not needed by the user
223
+ const hotkeysConfig = disableFunctions(
224
+ preHotkeysConfig,
225
+ preHotkeysConfig.canvas_disabled_functions
226
+ );
227
+
228
+ let isMoving = false;
229
+ let mouseX, mouseY;
230
+ let activeElement;
231
+
232
+ const elements = Object.fromEntries(
233
+ Object.keys(elementIDs).map(id => [
234
+ id,
235
+ gradioApp().querySelector(elementIDs[id])
236
+ ])
237
+ );
238
+ const elemData = {};
239
+
240
+ // Apply functionality to the range inputs. Restore the red mask and correct for long images.
241
+ const rangeInputs = elements.rangeGroup ?
242
+ Array.from(elements.rangeGroup.querySelectorAll("input")) :
243
+ [
244
+ gradioApp().querySelector("#img2img_width input[type='range']"),
245
+ gradioApp().querySelector("#img2img_height input[type='range']")
246
+ ];
247
+
248
+ for (const input of rangeInputs) {
249
+ input?.addEventListener("input", () => restoreImgRedMask(elements));
250
+ }
251
+
252
+ function applyZoomAndPan(elemId) {
253
+ const targetElement = gradioApp().querySelector(elemId);
254
+
255
+ if (!targetElement) {
256
+ console.log("Element not found");
257
+ return;
258
+ }
259
+
260
+ targetElement.style.transformOrigin = "0 0";
261
+
262
+ elemData[elemId] = {
263
+ zoom: 1,
264
+ panX: 0,
265
+ panY: 0
266
+ };
267
+ let fullScreenMode = false;
268
+
269
+ // Create tooltip
270
+ function createTooltip() {
271
+ const toolTipElement =
272
+ targetElement.querySelector(".image-container");
273
+ const tooltip = document.createElement("div");
274
+ tooltip.className = "canvas-tooltip";
275
+
276
+ // Creating an item of information
277
+ const info = document.createElement("i");
278
+ info.className = "canvas-tooltip-info";
279
+ info.textContent = "";
280
+
281
+ // Create a container for the contents of the tooltip
282
+ const tooltipContent = document.createElement("div");
283
+ tooltipContent.className = "canvas-tooltip-content";
284
+
285
+ // Define an array with hotkey information and their actions
286
+ const hotkeysInfo = [
287
+ {
288
+ configKey: "canvas_hotkey_zoom",
289
+ action: "Zoom canvas",
290
+ keySuffix: " + wheel"
291
+ },
292
+ {
293
+ configKey: "canvas_hotkey_adjust",
294
+ action: "Adjust brush size",
295
+ keySuffix: " + wheel"
296
+ },
297
+ {configKey: "canvas_hotkey_reset", action: "Reset zoom"},
298
+ {
299
+ configKey: "canvas_hotkey_fullscreen",
300
+ action: "Fullscreen mode"
301
+ },
302
+ {configKey: "canvas_hotkey_move", action: "Move canvas"},
303
+ {configKey: "canvas_hotkey_overlap", action: "Overlap"}
304
+ ];
305
+
306
+ // Create hotkeys array with disabled property based on the config values
307
+ const hotkeys = hotkeysInfo.map(info => {
308
+ const configValue = hotkeysConfig[info.configKey];
309
+ const key = info.keySuffix ?
310
+ `${configValue}${info.keySuffix}` :
311
+ configValue.charAt(configValue.length - 1);
312
+ return {
313
+ key,
314
+ action: info.action,
315
+ disabled: configValue === "disable"
316
+ };
317
+ });
318
+
319
+ for (const hotkey of hotkeys) {
320
+ if (hotkey.disabled) {
321
+ continue;
322
+ }
323
+
324
+ const p = document.createElement("p");
325
+ p.innerHTML = `<b>${hotkey.key}</b> - ${hotkey.action}`;
326
+ tooltipContent.appendChild(p);
327
+ }
328
+
329
+ // Add information and content elements to the tooltip element
330
+ tooltip.appendChild(info);
331
+ tooltip.appendChild(tooltipContent);
332
+
333
+ // Add a hint element to the target element
334
+ toolTipElement.appendChild(tooltip);
335
+ }
336
+
337
+ // Show the tooltip if the setting is enabled
338
+ if (hotkeysConfig.canvas_show_tooltip) {
339
+ createTooltip();
340
+ }
341
+
342
+ // Testing showed that the img tag interferes badly with zooming and creates white canvases. This hack hides it so the problem almost never comes up; it has no effect on the webui.
343
+ function fixCanvas() {
344
+ const activeTab = getActiveTab(elements).textContent.trim();
345
+
346
+ if (activeTab !== "img2img") {
347
+ const img = targetElement.querySelector(`${elemId} img`);
348
+
349
+ if (img && img.style.display !== "none") {
350
+ img.style.display = "none";
351
+ img.style.visibility = "hidden";
352
+ }
353
+ }
354
+ }
355
+
356
+ // Reset the zoom level and pan position of the target element to their initial values
357
+ function resetZoom() {
358
+ elemData[elemId] = {
359
+ zoomLevel: 1,
360
+ panX: 0,
361
+ panY: 0
362
+ };
363
+
364
+ fixCanvas();
365
+ targetElement.style.transform = `scale(${elemData[elemId].zoomLevel}) translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px)`;
366
+
367
+ const canvas = gradioApp().querySelector(
368
+ `${elemId} canvas[key="interface"]`
369
+ );
370
+
371
+ toggleOverlap("off");
372
+ fullScreenMode = false;
373
+
374
+ if (
375
+ canvas &&
376
+ parseFloat(canvas.style.width) > 865 &&
377
+ parseFloat(targetElement.style.width) > 865
378
+ ) {
379
+ fitToElement();
380
+ return;
381
+ }
382
+
383
+ targetElement.style.width = "";
384
+ if (canvas) {
385
+ targetElement.style.height = canvas.style.height;
386
+ }
387
+ }
388
+
389
+ // Toggle the zIndex of the target element between two values, allowing it to overlap or be overlapped by other elements
390
+ function toggleOverlap(forced = "") {
391
+ const zIndex1 = "0";
392
+ const zIndex2 = "998";
393
+
394
+ targetElement.style.zIndex =
395
+ targetElement.style.zIndex !== zIndex2 ? zIndex2 : zIndex1;
396
+
397
+ if (forced === "off") {
398
+ targetElement.style.zIndex = zIndex1;
399
+ } else if (forced === "on") {
400
+ targetElement.style.zIndex = zIndex2;
401
+ }
402
+ }
403
+
404
+ // Adjust the brush size based on the deltaY value from a mouse wheel event
405
+ function adjustBrushSize(
406
+ elemId,
407
+ deltaY,
408
+ withoutValue = false,
409
+ percentage = 5
410
+ ) {
411
+ const input =
412
+ gradioApp().querySelector(
413
+ `${elemId} input[aria-label='Brush radius']`
414
+ ) ||
415
+ gradioApp().querySelector(
416
+ `${elemId} button[aria-label="Use brush"]`
417
+ );
418
+
419
+ if (input) {
420
+ input.click();
421
+ if (!withoutValue) {
422
+ const maxValue =
423
+ parseFloat(input.getAttribute("max")) || 100;
424
+ const changeAmount = maxValue * (percentage / 100);
425
+ const newValue =
426
+ parseFloat(input.value) +
427
+ (deltaY > 0 ? -changeAmount : changeAmount);
428
+ input.value = Math.min(Math.max(newValue, 0), maxValue);
429
+ input.dispatchEvent(new Event("change"));
430
+ }
431
+ }
432
+ }
433
+
434
+ // Reset zoom when uploading a new image
435
+ const fileInput = gradioApp().querySelector(
436
+ `${elemId} input[type="file"][accept="image/*"].svelte-116rqfv`
437
+ );
438
+ fileInput.addEventListener("click", resetZoom);
439
+
440
+ // Update the zoom level and pan position of the target element based on the values of the zoomLevel, panX and panY variables
441
+ function updateZoom(newZoomLevel, mouseX, mouseY) {
442
+ newZoomLevel = Math.max(0.5, Math.min(newZoomLevel, 15));
443
+
444
+ elemData[elemId].panX +=
445
+ mouseX - (mouseX * newZoomLevel) / elemData[elemId].zoomLevel;
446
+ elemData[elemId].panY +=
447
+ mouseY - (mouseY * newZoomLevel) / elemData[elemId].zoomLevel;
448
+
449
+ targetElement.style.transformOrigin = "0 0";
450
+ targetElement.style.transform = `translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px) scale(${newZoomLevel})`;
451
+
452
+ toggleOverlap("on");
453
+ return newZoomLevel;
454
+ }
455
+
456
+ // Change the zoom level based on user interaction
457
+ function changeZoomLevel(operation, e) {
458
+ if (isModifierKey(e, hotkeysConfig.canvas_hotkey_zoom)) {
459
+ e.preventDefault();
460
+
461
+ let zoomPosX, zoomPosY;
462
+ let delta = 0.2;
463
+ if (elemData[elemId].zoomLevel > 7) {
464
+ delta = 0.9;
465
+ } else if (elemData[elemId].zoomLevel > 2) {
466
+ delta = 0.6;
467
+ }
468
+
469
+ zoomPosX = e.clientX;
470
+ zoomPosY = e.clientY;
471
+
472
+ fullScreenMode = false;
473
+ elemData[elemId].zoomLevel = updateZoom(
474
+ elemData[elemId].zoomLevel +
475
+ (operation === "+" ? delta : -delta),
476
+ zoomPosX - targetElement.getBoundingClientRect().left,
477
+ zoomPosY - targetElement.getBoundingClientRect().top
478
+ );
479
+ }
480
+ }
481
+
482
+ /**
483
+ * This function fits the target element to its parent element by calculating
484
+ * the required scale and offsets. It also updates the global variables
485
+ * zoomLevel, panX, and panY to reflect the new state.
486
+ */
487
+
488
+ function fitToElement() {
489
+ //Reset Zoom
490
+ targetElement.style.transform = `translate(${0}px, ${0}px) scale(${1})`;
491
+
492
+ // Get element and screen dimensions
493
+ const elementWidth = targetElement.offsetWidth;
494
+ const elementHeight = targetElement.offsetHeight;
495
+ const parentElement = targetElement.parentElement;
496
+ const screenWidth = parentElement.clientWidth;
497
+ const screenHeight = parentElement.clientHeight;
498
+
499
+ // Get element's coordinates relative to the parent element
500
+ const elementRect = targetElement.getBoundingClientRect();
501
+ const parentRect = parentElement.getBoundingClientRect();
502
+ const elementX = elementRect.x - parentRect.x;
503
+
504
+ // Calculate scale and offsets
505
+ const scaleX = screenWidth / elementWidth;
506
+ const scaleY = screenHeight / elementHeight;
507
+ const scale = Math.min(scaleX, scaleY);
508
+
509
+ const transformOrigin =
510
+ window.getComputedStyle(targetElement).transformOrigin;
511
+ const [originX, originY] = transformOrigin.split(" ");
512
+ const originXValue = parseFloat(originX);
513
+ const originYValue = parseFloat(originY);
514
+
515
+ const offsetX =
516
+ (screenWidth - elementWidth * scale) / 2 -
517
+ originXValue * (1 - scale);
518
+ const offsetY =
519
+ (screenHeight - elementHeight * scale) / 2.5 -
520
+ originYValue * (1 - scale);
521
+
522
+ // Apply scale and offsets to the element
523
+ targetElement.style.transform = `translate(${offsetX}px, ${offsetY}px) scale(${scale})`;
524
+
525
+ // Update global variables
526
+ elemData[elemId].zoomLevel = scale;
527
+ elemData[elemId].panX = offsetX;
528
+ elemData[elemId].panY = offsetY;
529
+
530
+ fullScreenMode = false;
531
+ toggleOverlap("off");
532
+ }
533
+
534
+ /**
535
+ * This function fits the target element to the screen by calculating
536
+ * the required scale and offsets. It also updates the global variables
537
+ * zoomLevel, panX, and panY to reflect the new state.
538
+ */
539
+
540
+ // Fullscreen mode
541
+ function fitToScreen() {
542
+ const canvas = gradioApp().querySelector(
543
+ `${elemId} canvas[key="interface"]`
544
+ );
545
+
546
+ if (!canvas) return;
547
+
548
+ if (canvas.offsetWidth > 862) {
549
+ targetElement.style.width = canvas.offsetWidth + "px";
550
+ }
551
+
552
+ if (fullScreenMode) {
553
+ resetZoom();
554
+ fullScreenMode = false;
555
+ return;
556
+ }
557
+
558
+ //Reset Zoom
559
+ targetElement.style.transform = `translate(${0}px, ${0}px) scale(${1})`;
560
+
561
+ // Get scrollbar width to right-align the image
562
+ const scrollbarWidth =
563
+ window.innerWidth - document.documentElement.clientWidth;
564
+
565
+ // Get element and screen dimensions
566
+ const elementWidth = targetElement.offsetWidth;
567
+ const elementHeight = targetElement.offsetHeight;
568
+ const screenWidth = window.innerWidth - scrollbarWidth;
569
+ const screenHeight = window.innerHeight;
570
+
571
+ // Get element's coordinates relative to the page
572
+ const elementRect = targetElement.getBoundingClientRect();
573
+ const elementY = elementRect.y;
574
+ const elementX = elementRect.x;
575
+
576
+ // Calculate scale and offsets
577
+ const scaleX = screenWidth / elementWidth;
578
+ const scaleY = screenHeight / elementHeight;
579
+ const scale = Math.min(scaleX, scaleY);
580
+
581
+ // Get the current transformOrigin
582
+ const computedStyle = window.getComputedStyle(targetElement);
583
+ const transformOrigin = computedStyle.transformOrigin;
584
+ const [originX, originY] = transformOrigin.split(" ");
585
+ const originXValue = parseFloat(originX);
586
+ const originYValue = parseFloat(originY);
587
+
588
+ // Calculate offsets with respect to the transformOrigin
589
+ const offsetX =
590
+ (screenWidth - elementWidth * scale) / 2 -
591
+ elementX -
592
+ originXValue * (1 - scale);
593
+ const offsetY =
594
+ (screenHeight - elementHeight * scale) / 2 -
595
+ elementY -
596
+ originYValue * (1 - scale);
597
+
598
+ // Apply scale and offsets to the element
599
+ targetElement.style.transform = `translate(${offsetX}px, ${offsetY}px) scale(${scale})`;
600
+
601
+ // Update global variables
602
+ elemData[elemId].zoomLevel = scale;
603
+ elemData[elemId].panX = offsetX;
604
+ elemData[elemId].panY = offsetY;
605
+
606
+ fullScreenMode = true;
607
+ toggleOverlap("on");
608
+ }
609
+
610
+ // Handle keydown events
611
+ function handleKeyDown(event) {
612
+ // Disable key locks to make pasting from the buffer work correctly
613
+ if ((event.ctrlKey && event.code === 'KeyV') || (event.ctrlKey && event.code === 'KeyC') || event.code === "F5") {
614
+ return;
615
+ }
616
+
617
+ // before activating shortcut, ensure user is not actively typing in an input field
618
+ if (!hotkeysConfig.canvas_blur_prompt) {
619
+ if (event.target.nodeName === 'TEXTAREA' || event.target.nodeName === 'INPUT') {
620
+ return;
621
+ }
622
+ }
623
+
624
+
625
+ const hotkeyActions = {
626
+ [hotkeysConfig.canvas_hotkey_reset]: resetZoom,
627
+ [hotkeysConfig.canvas_hotkey_overlap]: toggleOverlap,
628
+ [hotkeysConfig.canvas_hotkey_fullscreen]: fitToScreen
629
+ };
630
+
631
+ const action = hotkeyActions[event.code];
632
+ if (action) {
633
+ event.preventDefault();
634
+ action(event);
635
+ }
636
+
637
+ if (
638
+ isModifierKey(event, hotkeysConfig.canvas_hotkey_zoom) ||
639
+ isModifierKey(event, hotkeysConfig.canvas_hotkey_adjust)
640
+ ) {
641
+ event.preventDefault();
642
+ }
643
+ }
644
+
645
+ // Get Mouse position
646
+ function getMousePosition(e) {
647
+ mouseX = e.offsetX;
648
+ mouseY = e.offsetY;
649
+ }
650
+
651
+ targetElement.addEventListener("mousemove", getMousePosition);
652
+
653
+ // Handle events only inside the targetElement
654
+ let isKeyDownHandlerAttached = false;
655
+
656
+ function handleMouseMove() {
657
+ if (!isKeyDownHandlerAttached) {
658
+ document.addEventListener("keydown", handleKeyDown);
659
+ isKeyDownHandlerAttached = true;
660
+
661
+ activeElement = elemId;
662
+ }
663
+ }
664
+
665
+ function handleMouseLeave() {
666
+ if (isKeyDownHandlerAttached) {
667
+ document.removeEventListener("keydown", handleKeyDown);
668
+ isKeyDownHandlerAttached = false;
669
+
670
+ activeElement = null;
671
+ }
672
+ }
673
+
674
+ // Add mouse event handlers
675
+ targetElement.addEventListener("mousemove", handleMouseMove);
676
+ targetElement.addEventListener("mouseleave", handleMouseLeave);
677
+
678
+ // Reset zoom when click on another tab
679
+ elements.img2imgTabs.addEventListener("click", resetZoom);
680
+ elements.img2imgTabs.addEventListener("click", () => {
681
+ // targetElement.style.width = "";
682
+ if (parseInt(targetElement.style.width) > 865) {
683
+ setTimeout(fitToElement, 0);
684
+ }
685
+ });
686
+
687
+ targetElement.addEventListener("wheel", e => {
688
+ // change zoom level
689
+ const operation = e.deltaY > 0 ? "-" : "+";
690
+ changeZoomLevel(operation, e);
691
+
692
+ // Handle brush size adjustment with ctrl key pressed
693
+ if (isModifierKey(e, hotkeysConfig.canvas_hotkey_adjust)) {
694
+ e.preventDefault();
695
+
696
+ // Increase or decrease brush size based on scroll direction
697
+ adjustBrushSize(elemId, e.deltaY);
698
+ }
699
+ });
700
+
701
+ // Handle the move event for pan functionality. Updates the panX and panY variables and applies the new transform to the target element.
702
+ function handleMoveKeyDown(e) {
703
+
704
+ // Disable key locks to make pasting from the buffer work correctly
705
+ if ((e.ctrlKey && e.code === 'KeyV') || (e.ctrlKey && e.code === 'KeyC') || e.code === "F5") {
706
+ return;
707
+ }
708
+
709
+ // before activating shortcut, ensure user is not actively typing in an input field
710
+ if (!hotkeysConfig.canvas_blur_prompt) {
711
+ if (e.target.nodeName === 'TEXTAREA' || e.target.nodeName === 'INPUT') {
712
+ return;
713
+ }
714
+ }
715
+
716
+
717
+ if (e.code === hotkeysConfig.canvas_hotkey_move) {
718
+ if (!e.ctrlKey && !e.metaKey && isKeyDownHandlerAttached) {
719
+ e.preventDefault();
720
+ document.activeElement.blur();
721
+ isMoving = true;
722
+ }
723
+ }
724
+ }
725
+
726
+ function handleMoveKeyUp(e) {
727
+ if (e.code === hotkeysConfig.canvas_hotkey_move) {
728
+ isMoving = false;
729
+ }
730
+ }
731
+
732
+ document.addEventListener("keydown", handleMoveKeyDown);
733
+ document.addEventListener("keyup", handleMoveKeyUp);
734
+
735
+ // Detect zoom level and update the pan speed.
736
+ function updatePanPosition(movementX, movementY) {
737
+ let panSpeed = 2;
738
+
739
+ if (elemData[elemId].zoomLevel > 8) {
740
+ panSpeed = 3.5;
741
+ }
742
+
743
+ elemData[elemId].panX += movementX * panSpeed;
744
+ elemData[elemId].panY += movementY * panSpeed;
745
+
746
+ // Delayed redraw of an element
747
+ requestAnimationFrame(() => {
748
+ targetElement.style.transform = `translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px) scale(${elemData[elemId].zoomLevel})`;
749
+ toggleOverlap("on");
750
+ });
751
+ }
752
+
753
+ function handleMoveByKey(e) {
754
+ if (isMoving && elemId === activeElement) {
755
+ updatePanPosition(e.movementX, e.movementY);
756
+ targetElement.style.pointerEvents = "none";
757
+ } else {
758
+ targetElement.style.pointerEvents = "auto";
759
+ }
760
+ }
761
+
762
+ // Prevents sticking to the mouse
763
+ window.onblur = function() {
764
+ isMoving = false;
765
+ };
766
+
767
+ gradioApp().addEventListener("mousemove", handleMoveByKey);
768
+ }
769
+
770
+ applyZoomAndPan(elementIDs.sketch);
771
+ applyZoomAndPan(elementIDs.inpaint);
772
+ applyZoomAndPan(elementIDs.inpaintSketch);
773
+
774
+ // Make the function global so that other extensions can take advantage of this solution
775
+ window.applyZoomAndPan = applyZoomAndPan;
776
+ });
extensions-builtin/canvas-zoom-and-pan/scripts/__pycache__/hotkey_config.cpython-310.pyc ADDED
Binary file (1.46 kB)
 
extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py ADDED
@@ -0,0 +1,14 @@
1
+ import gradio as gr
2
+ from modules import shared
3
+
4
+ shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas Hotkeys"), {
5
+ "canvas_hotkey_zoom": shared.OptionInfo("Alt", "Zoom canvas", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"),
6
+ "canvas_hotkey_adjust": shared.OptionInfo("Ctrl", "Adjust brush size", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"),
7
+ "canvas_hotkey_move": shared.OptionInfo("F", "Moving the canvas").info("To work correctly in firefox, turn off 'Automatically search the page text when typing' in the browser settings"),
8
+ "canvas_hotkey_fullscreen": shared.OptionInfo("S", "Fullscreen Mode, maximizes the picture so that it fits into the screen and stretches it to its full width "),
9
+ "canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"),
10
+ "canvas_hotkey_overlap": shared.OptionInfo("O", "Toggle overlap").info("Technical button, neededs for testing"),
11
+ "canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"),
12
+ "canvas_blur_prompt": shared.OptionInfo(False, "Take the focus off the prompt when working with a canvas"),
13
+ "canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size", "Moving canvas","Fullscreen","Reset Zoom","Overlap"]}),
14
+ }))
extensions-builtin/canvas-zoom-and-pan/style.css ADDED
@@ -0,0 +1,63 @@
1
+ .canvas-tooltip-info {
2
+ position: absolute;
3
+ top: 10px;
4
+ left: 10px;
5
+ cursor: help;
6
+ background-color: rgba(0, 0, 0, 0.3);
7
+ width: 20px;
8
+ height: 20px;
9
+ border-radius: 50%;
10
+ display: flex;
11
+ align-items: center;
12
+ justify-content: center;
13
+ flex-direction: column;
14
+
15
+ z-index: 100;
16
+ }
17
+
18
+ .canvas-tooltip-info::after {
19
+ content: '';
20
+ display: block;
21
+ width: 2px;
22
+ height: 7px;
23
+ background-color: white;
24
+ margin-top: 2px;
25
+ }
26
+
27
+ .canvas-tooltip-info::before {
28
+ content: '';
29
+ display: block;
30
+ width: 2px;
31
+ height: 2px;
32
+ background-color: white;
33
+ }
34
+
35
+ .canvas-tooltip-content {
36
+ display: none;
37
+ background-color: #f9f9f9;
38
+ color: #333;
39
+ border: 1px solid #ddd;
40
+ padding: 15px;
41
+ position: absolute;
42
+ top: 40px;
43
+ left: 10px;
44
+ width: 250px;
45
+ font-size: 16px;
46
+ opacity: 0;
47
+ border-radius: 8px;
48
+ box-shadow: 0px 8px 16px 0px rgba(0,0,0,0.2);
49
+
50
+ z-index: 100;
51
+ }
52
+
53
+ .canvas-tooltip:hover .canvas-tooltip-content {
54
+ display: block;
55
+ animation: fadeIn 0.5s;
56
+ opacity: 1;
57
+ }
58
+
59
+ @keyframes fadeIn {
60
+ from {opacity: 0;}
61
+ to {opacity: 1;}
62
+ }
63
+
extensions-builtin/extra-options-section/scripts/__pycache__/extra_options_section.cpython-310.pyc ADDED
Binary file (2.84 kB)
 
extensions-builtin/extra-options-section/scripts/extra_options_section.py ADDED
@@ -0,0 +1,48 @@
1
+ import gradio as gr
2
+ from modules import scripts, shared, ui_components, ui_settings
3
+ from modules.ui_components import FormColumn
4
+
5
+
6
+ class ExtraOptionsSection(scripts.Script):
7
+ section = "extra_options"
8
+
9
+ def __init__(self):
10
+ self.comps = None
11
+ self.setting_names = None
12
+
13
+ def title(self):
14
+ return "Extra options"
15
+
16
+ def show(self, is_img2img):
17
+ return scripts.AlwaysVisible
18
+
19
+ def ui(self, is_img2img):
20
+ self.comps = []
21
+ self.setting_names = []
22
+
23
+ with gr.Blocks() as interface:
24
+ with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and shared.opts.extra_options else gr.Group(), gr.Row():
25
+ for setting_name in shared.opts.extra_options:
26
+ with FormColumn():
27
+ comp = ui_settings.create_setting_component(setting_name)
28
+
29
+ self.comps.append(comp)
30
+ self.setting_names.append(setting_name)
31
+
32
+ def get_settings_values():
33
+ return [ui_settings.get_value_for_setting(key) for key in self.setting_names]
34
+
35
+ interface.load(fn=get_settings_values, inputs=[], outputs=self.comps, queue=False, show_progress=False)
36
+
37
+ return self.comps
38
+
39
+ def before_process(self, p, *args):
40
+ for name, value in zip(self.setting_names, args):
41
+ if name not in p.override_settings:
42
+ p.override_settings[name] = value
43
+
44
+
45
+ shared.options_templates.update(shared.options_section(('ui', "User interface"), {
46
+ "extra_options": shared.OptionInfo([], "Options in main UI", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img/img2img interfaces").needs_restart(),
47
+ "extra_options_accordion": shared.OptionInfo(False, "Place options in main UI into an accordion")
48
+ }))
extensions-builtin/mobile/javascript/mobile.js ADDED
@@ -0,0 +1,26 @@
1
+ var isSetupForMobile = false;
2
+
3
+ function isMobile() {
4
+ for (var tab of ["txt2img", "img2img"]) {
5
+ var imageTab = gradioApp().getElementById(tab + '_results');
6
+ if (imageTab && imageTab.offsetParent && imageTab.offsetLeft == 0) {
7
+ return true;
8
+ }
9
+ }
10
+
11
+ return false;
12
+ }
13
+
14
+ function reportWindowSize() {
15
+ var currentlyMobile = isMobile();
16
+ if (currentlyMobile == isSetupForMobile) return;
17
+ isSetupForMobile = currentlyMobile;
18
+
19
+ for (var tab of ["txt2img", "img2img"]) {
20
+ var button = gradioApp().getElementById(tab + '_generate_box');
21
+ var target = gradioApp().getElementById(currentlyMobile ? tab + '_results' : tab + '_actions_column');
22
+ target.insertBefore(button, target.firstElementChild);
23
+ }
24
+ }
25
+
26
+ window.addEventListener("resize", reportWindowSize);
extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js ADDED
@@ -0,0 +1,42 @@
1
+ // Stable Diffusion WebUI - Bracket checker
2
+ // By Hingashi no Florin/Bwin4L & @akx
3
+ // Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs.
4
+ // If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.
5
+
6
+ function checkBrackets(textArea, counterElt) {
7
+ var counts = {};
8
+ (textArea.value.match(/[(){}[\]]/g) || []).forEach(bracket => {
9
+ counts[bracket] = (counts[bracket] || 0) + 1;
10
+ });
11
+ var errors = [];
12
+
13
+ function checkPair(open, close, kind) {
14
+ if (counts[open] !== counts[close]) {
15
+ errors.push(
16
+ `${open}...${close} - Detected ${counts[open] || 0} opening and ${counts[close] || 0} closing ${kind}.`
17
+ );
18
+ }
19
+ }
20
+
21
+ checkPair('(', ')', 'round brackets');
22
+ checkPair('[', ']', 'square brackets');
23
+ checkPair('{', '}', 'curly brackets');
24
+ counterElt.title = errors.join('\n');
25
+ counterElt.classList.toggle('error', errors.length !== 0);
26
+ }
27
+
28
+ function setupBracketChecking(id_prompt, id_counter) {
29
+ var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
30
+ var counter = gradioApp().getElementById(id_counter);
31
+
32
+ if (textarea && counter) {
33
+ textarea.addEventListener("input", () => checkBrackets(textarea, counter));
34
+ }
35
+ }
36
+
37
+ onUiLoaded(function() {
38
+ setupBracketChecking('txt2img_prompt', 'txt2img_token_counter');
39
+ setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter');
40
+ setupBracketChecking('img2img_prompt', 'img2img_token_counter');
41
+ setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter');
42
+ });
extensions-builtin/sd_theme_editor/install.py ADDED
@@ -0,0 +1 @@
1
+ import launch
extensions-builtin/sd_theme_editor/javascript/ui_theme.js ADDED
@@ -0,0 +1,435 @@
1
+ function hexToRgb(color) {
2
+ let hex = color[0] === "#" ? color.slice(1) : color;
3
+ let c;
4
+
5
+ // expand the short hex by doubling each character, fc0 -> ffcc00
6
+ if (hex.length !== 6) {
7
+ hex = (() => {
8
+ const result = [];
9
+ for (c of Array.from(hex)) {
10
+ result.push(`${c}${c}`);
11
+ }
12
+ return result;
13
+ })().join("");
14
+ }
15
+ const colorStr = hex.match(/#?(.{2})(.{2})(.{2})/).slice(1);
16
+ const rgb = colorStr.map((col) => parseInt(col, 16));
17
+ rgb.push(1);
18
+ return rgb;
19
+ }
20
+
21
+ function rgbToHsl(rgb) {
22
+ const r = rgb[0] / 255;
23
+ const g = rgb[1] / 255;
24
+ const b = rgb[2] / 255;
25
+
26
+ const max = Math.max(r, g, b);
27
+ const min = Math.min(r, g, b);
28
+ const diff = max - min;
29
+ const add = max + min;
30
+
31
+ const hue =
32
+ min === max
33
+ ? 0
34
+ : r === max
35
+ ? ((60 * (g - b)) / diff + 360) % 360
36
+ : g === max
37
+ ? (60 * (b - r)) / diff + 120
38
+ : (60 * (r - g)) / diff + 240;
39
+
40
+ const lum = 0.5 * add;
41
+
42
+ const sat =
43
+ lum === 0 ? 0 : lum === 1 ? 1 : lum <= 0.5 ? diff / add : diff / (2 - add);
44
+
45
+ const h = Math.round(hue);
46
+ const s = Math.round(sat * 100);
47
+ const l = Math.round(lum * 100);
48
+ const a = rgb[3] || 1;
49
+
50
+ return [h, s, l, a];
51
+ }
52
+
53
+ function hexToHsl(color) {
54
+ const rgb = hexToRgb(color);
55
+ const hsl = rgbToHsl(rgb);
56
+ return "hsl(" + hsl[0] + "deg " + hsl[1] + "% " + hsl[2] + "%)";
57
+ }
58
+
59
+ function hslToHex(h, s, l) {
60
+ l /= 100;
61
+ const a = (s * Math.min(l, 1 - l)) / 100;
62
+ const f = (n) => {
63
+ const k = (n + h / 30) % 12;
64
+ const color = l - a * Math.max(Math.min(k - 3, 9 - k, 1), -1);
65
+ return Math.round(255 * Math.max(0, Math.min(color, 1)))
66
+ .toString(16)
67
+ .padStart(2, "0"); // convert to Hex and prefix "0" if needed
68
+ };
69
+ return `#${f(0)}${f(8)}${f(4)}`;
70
+ }
71
+
72
+ function hsl2rgb(h, s, l) {
73
+ let a = s * Math.min(l, 1 - l);
74
+ let f = (n, k = (n + h / 30) % 12) =>
75
+ l - a * Math.max(Math.min(k - 3, 9 - k, 1), -1);
76
+ return [f(0), f(8), f(4)];
77
+ }
78
+
79
+ function invertColor(hex) {
80
+ if (hex.indexOf("#") === 0) {
81
+ hex = hex.slice(1);
82
+ }
83
+ // convert 3-digit hex to 6-digits.
84
+ if (hex.length === 3) {
85
+ hex = hex[0] + hex[0] + hex[1] + hex[1] + hex[2] + hex[2];
86
+ }
87
+ if (hex.length !== 6) {
88
+ throw new Error("Invalid HEX color.");
89
+ }
90
+ // invert color components
91
+ var r = (255 - parseInt(hex.slice(0, 2), 16)).toString(16),
92
+ g = (255 - parseInt(hex.slice(2, 4), 16)).toString(16),
93
+ b = (255 - parseInt(hex.slice(4, 6), 16)).toString(16);
94
+ // pad each with zeros and return
95
+ return "#" + padZero(r) + padZero(g) + padZero(b);
96
+ }
97
+
98
+ function padZero(str, len) {
99
+ len = len || 2;
100
+ var zeros = new Array(len).join("0");
101
+ return (zeros + str).slice(-len);
102
+ }
103
+
104
+ function getValsWrappedIn(str, c1, c2) {
105
+ var rg = new RegExp("(?<=\\" + c1 + ")(.*?)(?=\\" + c2 + ")", "g");
106
+ return str.match(rg);
107
+ }
108
+
109
+ let styleobj = {};
110
+ let hslobj = {};
111
+ let isColorsInv;
112
+
113
+ const toHSLArray = (hslStr) => hslStr.match(/\d+/g).map(Number);
114
+
115
+ function offsetColorsHSV(ohsl) {
116
+ let inner_styles = "";
117
+
118
+ for (const key in styleobj) {
119
+ let keyVal = styleobj[key];
120
+
121
+ if (keyVal.indexOf("#") != -1 || keyVal.indexOf("hsl") != -1) {
122
+ let colcomp = gradioApp().querySelector("#" + key + " input");
123
+ if (colcomp) {
124
+ let hsl;
125
+
126
+ if (keyVal.indexOf("#") != -1) {
127
+ keyVal = keyVal.replace(/\s+/g, "");
128
+ //inv ? keyVal = invertColor(keyVal) : 0;
129
+ if (isColorsInv) {
130
+ keyVal = invertColor(keyVal);
131
+ styleobj[key] = keyVal;
132
+ }
133
+ hsl = rgbToHsl(hexToRgb(keyVal));
134
+ } else {
135
+ if (isColorsInv) {
136
+ let c = toHSLArray(keyVal);
137
+ let hex = hslToHex(c[0], c[1], c[2]);
138
+ keyVal = invertColor(hex);
139
+ styleobj[key] = keyVal;
140
+ hsl = rgbToHsl(hexToRgb(keyVal));
141
+ } else {
142
+ hsl = toHSLArray(keyVal);
143
+ }
144
+ }
145
+
146
+ let h = (parseInt(hsl[0]) + parseInt(ohsl[0])) % 360;
147
+ let s = parseInt(hsl[1]) + parseInt(ohsl[1]);
148
+ let l = parseInt(hsl[2]) + parseInt(ohsl[2]);
149
+
150
+ let hex = hslToHex(
151
+ h,
152
+ Math.min(Math.max(s, 0), 100),
153
+ Math.min(Math.max(l, 0), 100)
154
+ );
155
+
156
+ colcomp.value = hex;
157
+
158
+ hslobj[key] = "hsl(" + h + "deg " + s + "% " + l + "%)";
159
+ inner_styles += key + ":" + hslobj[key] + ";";
160
+ }
161
+ } else {
162
+ inner_styles += key + ":" + styleobj[key] + ";";
163
+ }
164
+ }
165
+
166
+ isColorsInv = false;
167
+
168
+ const preview_styles = gradioApp().querySelector("#preview-styles");
169
+ preview_styles.innerHTML = ":root {" + inner_styles + "}";
170
+ preview_styles.innerHTML +=
171
+ "@media only screen and (max-width: 860px) {:root{--ae-outside-gap-size: var(--ae-mobile-outside-gap-size);--ae-inside-padding-size: var(--ae-mobile-inside-padding-size);}}";
172
+
173
+ const vars_textarea = gradioApp().querySelector("#theme_vars textarea");
174
+ vars_textarea.value = inner_styles;
175
+
176
+ const inputEvent = new Event("input");
177
+ Object.defineProperty(inputEvent, "target", { value: vars_textarea });
178
+ vars_textarea.dispatchEvent(inputEvent);
179
+ }
180
+
181
+ function updateTheme(vars) {
182
+ let inner_styles = "";
183
+
184
+ for (let i = 0; i < vars.length - 1; i++) {
185
+ let key = vars[i].split(":");
186
+ let id = key[0].replace(/\s+/g, "");
187
+ let val = key[1].trim();
188
+
189
+ styleobj[id] = val;
190
+ inner_styles += id + ":" + val + ";";
191
+
192
+ gradioApp()
193
+ .querySelectorAll("#" + id + " input")
194
+ .forEach((elem) => {
195
+ if (val.indexOf("hsl") != -1) {
196
+ let hsl = toHSLArray(val);
197
+ let hex = hslToHex(hsl[0], hsl[1], hsl[2]);
198
+ elem.value = hex;
199
+ } else {
200
+ elem.value = val.split("px")[0];
201
+ }
202
+ });
203
+ }
204
+
205
+ const preview_styles = gradioApp().querySelector("#preview-styles");
206
+
207
+ if (preview_styles) {
208
+ preview_styles.innerHTML = ":root {" + inner_styles + "}";
209
+ preview_styles.innerHTML +=
210
+ "@media only screen and (max-width: 860px) {:root{--ae-outside-gap-size: var(--ae-mobile-outside-gap-size);--ae-inside-padding-size: var(--ae-mobile-inside-padding-size);}}";
211
+ } else {
212
+ const r = gradioApp();
213
+ const style = document.createElement("style");
214
+ style.id = "preview-styles";
215
+ style.innerHTML = ":root {" + inner_styles + "}";
216
+ style.innerHTML +=
217
+ "@media only screen and (max-width: 860px) {:root{--ae-outside-gap-size: var(--ae-mobile-outside-gap-size);--ae-inside-padding-size: var(--ae-mobile-inside-padding-size);}}";
218
+ r.appendChild(style);
219
+ }
220
+
221
+ const vars_textarea = gradioApp().querySelector("#theme_vars textarea");
222
+ const css_textarea = gradioApp().querySelector("#theme_css textarea");
223
+
224
+ vars_textarea.value = inner_styles;
225
+ css_textarea.value = css_textarea.value;
226
+
227
+ //console.log(Object);
228
+
229
+ const vEvent = new Event("input");
230
+ const cEvent = new Event("input");
231
+ Object.defineProperty(vEvent, "target", { value: vars_textarea });
232
+ Object.defineProperty(cEvent, "target", { value: css_textarea });
233
+ vars_textarea.dispatchEvent(vEvent);
234
+ css_textarea.dispatchEvent(cEvent);
235
+ }
236
+
237
+ function applyTheme() {
238
+ console.log("apply");
239
+ }
240
+
241
+ function initTheme() {
242
+ const current_style = gradioApp().querySelector(".gradio-container > style");
243
+ //console.log(current_style);
244
+ //const head = document.head;
245
+ //head.appendChild(current_style);
246
+
247
+ const css_styles = current_style.innerHTML.split(
248
+ "/*BREAKPOINT_CSS_CONTENT*/"
249
+ );
250
+ let init_css_vars = css_styles[0].split("}")[0].split("{")[1];
251
+ init_css_vars = init_css_vars.replace(/\n|\r/g, "");
252
+
253
+ let init_vars = init_css_vars.split(";");
254
+ let vars = init_vars;
255
+
256
+ //console.log(vars);
257
+
258
+ const vars_textarea = gradioApp().querySelector("#theme_vars textarea");
259
+ const css_textarea = gradioApp().querySelector("#theme_css textarea");
260
+ //const result_textarea = gradioApp().querySelector('#theme_result textarea');
261
+ vars_textarea.value = init_css_vars;
262
+ css_textarea.value =
263
+ "/*BREAKPOINT_CSS_CONTENT*/" + css_styles[1] + "/*BREAKPOINT_CSS_CONTENT*/";
264
+
265
+ updateTheme(vars);
266
+
267
+ //vars_textarea.addEventListener("change", function(e) {
268
+ //e.preventDefault();
269
+ //e.stopPropagation();
270
+ //vars = vars_textarea.value.split(";");
271
+ //console.log(e);
272
+ //updateTheme(vars);
273
+
274
+ //})
275
+
276
+ const preview_styles = gradioApp().querySelector("#preview-styles");
277
+ let intervalChange;
278
+
279
+ gradioApp()
280
+ .querySelectorAll("#ui_theme_settings input")
281
+ .forEach((elem) => {
282
+ elem.addEventListener("input", function (e) {
283
+ let celem = e.currentTarget;
284
+ let val = e.currentTarget.value;
285
+ let curr_val;
286
+
287
+ switch (e.currentTarget.type) {
288
+ case "range":
289
+ celem = celem.parentElement;
290
+ val = e.currentTarget.value + "px";
291
+ break;
292
+ case "color":
293
+ celem = celem.parentElement.parentElement;
294
+ val = e.currentTarget.value;
295
+ break;
296
+ case "number":
297
+ celem = celem.parentElement.parentElement.parentElement;
298
+ val = e.currentTarget.value + "px";
299
+ break;
300
+ }
301
+
302
+ styleobj[celem.id] = val;
303
+
304
+ //console.log(styleobj);
305
+
306
+ if (intervalChange != null) clearTimeout(intervalChange);
307
+ intervalChange = setTimeout(() => {
308
+ let inner_styles = "";
309
+
310
+ for (const key in styleobj) {
311
+ inner_styles += key + ":" + styleobj[key] + ";";
312
+ }
313
+
314
+ vars = inner_styles.split(";");
315
+ preview_styles.innerHTML = ":root {" + inner_styles + "}";
316
+ preview_styles.innerHTML +=
317
+ "@media only screen and (max-width: 860px) {:root{--ae-outside-gap-size: var(--ae-mobile-outside-gap-size);--ae-inside-padding-size: var(--ae-mobile-inside-padding-size);}}";
318
+
319
+ vars_textarea.value = inner_styles;
320
+ const vEvent = new Event("input");
321
+ Object.defineProperty(vEvent, "target", { value: vars_textarea });
322
+ vars_textarea.dispatchEvent(vEvent);
323
+
324
+ offsetColorsHSV(hsloffset);
325
+ }, 1000);
326
+ });
327
+ });
328
+
329
+ const reset_btn = gradioApp().getElementById("theme_reset_btn");
330
+ reset_btn.addEventListener("click", function (e) {
331
+ e.preventDefault();
332
+ e.stopPropagation();
333
+ gradioApp()
334
+ .querySelectorAll("#ui_theme_hsv input")
335
+ .forEach((elem) => {
336
+ elem.value = 0;
337
+ });
338
+ hsloffset = [0, 0, 0];
339
+ updateTheme(init_vars);
340
+ });
341
+
342
+ /*
343
+ const apply_btn = gradioApp().getElementById('theme_apply_btn');
344
+ apply_btn.addEventListener("click", function(e) {
345
+ e.preventDefault();
346
+ e.stopPropagation();
347
+ init_css_vars = vars_textarea.value.replace(/\n|\r/g, "");
348
+ vars_textarea.value = init_css_vars;
349
+
350
+ init_vars = init_css_vars.split(";");
351
+ vars = init_vars;
352
+ updateTheme(vars);
353
+ })
354
+ */
355
+
356
+ let intervalCheck;
357
+ function dropDownOnChange() {
358
+ if (init_css_vars != vars_textarea.value) {
359
+ clearInterval(intervalCheck);
360
+ init_css_vars = vars_textarea.value.replace(/\n|\r/g, "");
361
+ vars_textarea.value = init_css_vars;
362
+ init_vars = init_css_vars.split(";");
363
+ vars = init_vars;
364
+ updateTheme(vars);
365
+ }
366
+ }
367
+
368
+ const drop_down = gradioApp().querySelector("#themes_drop_down");
369
+ drop_down.addEventListener("click", function (e) {
370
+ if (intervalCheck != null) clearInterval(intervalCheck);
371
+ intervalCheck = setInterval(dropDownOnChange, 100);
372
+ //console.log("ok");
373
+ });
374
+
375
+ let hsloffset = [0, 0, 0];
376
+
377
+ const hue = gradioApp()
378
+ .querySelectorAll("#theme_hue input")
379
+ .forEach((elem) => {
380
+ elem.addEventListener("change", function (e) {
381
+ e.preventDefault();
382
+ e.stopPropagation();
383
+ hsloffset[0] = e.currentTarget.value;
384
+ offsetColorsHSV(hsloffset);
385
+ });
386
+ });
387
+
388
+ const sat = gradioApp()
389
+ .querySelectorAll("#theme_sat input")
390
+ .forEach((elem) => {
391
+ elem.addEventListener("change", function (e) {
392
+ e.preventDefault();
393
+ e.stopPropagation();
394
+ hsloffset[1] = e.currentTarget.value;
395
+ offsetColorsHSV(hsloffset);
396
+ });
397
+ });
398
+
399
+ const brt = gradioApp()
400
+ .querySelectorAll("#theme_brt input")
401
+ .forEach((elem) => {
402
+ elem.addEventListener("change", function (e) {
403
+ e.preventDefault();
404
+ e.stopPropagation();
405
+ hsloffset[2] = e.currentTarget.value;
406
+ offsetColorsHSV(hsloffset);
407
+ });
408
+ });
409
+
410
+ const inv_btn = gradioApp().getElementById("theme_invert_btn");
411
+ inv_btn.addEventListener("click", function (e) {
412
+ e.preventDefault();
413
+ e.stopPropagation();
414
+ isColorsInv = !isColorsInv;
415
+ offsetColorsHSV(hsloffset);
416
+ });
417
+ }
418
+
419
+ function observeGradioApp() {
420
+ const observer = new MutationObserver(() => {
421
+ const block = gradioApp().getElementById("tab_ui_theme");
422
+ if (block) {
423
+ observer.disconnect();
424
+
425
+ setTimeout(() => {
426
+ initTheme();
427
+ }, "500");
428
+ }
429
+ });
430
+ observer.observe(gradioApp(), { childList: true, subtree: true });
431
+ }
432
+
433
+ document.addEventListener("DOMContentLoaded", () => {
434
+ observeGradioApp();
435
+ });
extensions-builtin/sd_theme_editor/scripts/__pycache__/ui_theme.cpython-310.pyc ADDED
Binary file (6.32 kB)
 
extensions-builtin/sd_theme_editor/scripts/ui_theme.py ADDED
@@ -0,0 +1,177 @@
1
+ import os
2
+ import shutil
3
+ from pathlib import Path
4
+ import gradio as gr
5
+ import modules.scripts as scripts
6
+ from modules import script_callbacks, shared
7
+
8
+ basedir = scripts.basedir()
9
+ webui_dir = Path(basedir).parents[1]
10
+
11
+ themes_folder = os.path.join(basedir, "themes")
12
+ javascript_folder = os.path.join(basedir, "javascript")
13
+ webui_style_path = os.path.join(webui_dir, "style.css")
14
+
15
+ def get_files(folder, file_filter=[], file_list=[], split=False):
16
+ file_list = [file_name if not split else os.path.splitext(file_name)[0] for file_name in os.listdir(folder) if os.path.isfile(os.path.join(folder, file_name)) and file_name not in file_filter]
17
+ return file_list
18
+
19
+
20
+ def on_ui_tabs():
21
+
22
+ with gr.Blocks(analytics_enabled=False) as ui_theme:
23
+ with gr.Row():
24
+ with gr.Column():
25
+ with gr.Row():
26
+ themes_dropdown = gr.Dropdown(label="Themes", elem_id="themes_drop_down", interactive=True, choices=get_files(themes_folder,[".css, .txt"]), type="value")
27
+ save_as_filename = gr.Text(label="Save / Save as")
28
+ with gr.Row():
29
+ reset_button = gr.Button(elem_id="theme_reset_btn", value="Reset", variant="primary")
30
+ #apply_button = gr.Button(elem_id="theme_apply_btn", value="Apply", variant="primary")
31
+ save_button = gr.Button(value="Save", variant="primary")
32
+ #delete_button = gr.Button(value="Delete", variant="primary")
33
+
34
+ #with gr.Accordion(label="Debug View", open=True):
35
+ with gr.Row(elem_id="theme_hidden"):
36
+ vars_text = gr.Textbox(label="Vars", elem_id="theme_vars", show_label=True, lines=7, interactive=False, visible=True)
37
+ css_text = gr.Textbox(label="Css", elem_id="theme_css", show_label=True, lines=7, interactive=False, visible=True)
38
+ #result_text = gr.Text(elem_id="theme_result", interactive=False, visible=False)
39
+ with gr.Column(elem_id="theme_overflow_container"):
40
+ with gr.Accordion(label="Theme Color adjustments", open=True):
41
+ with gr.Row():
42
+ with gr.Column(scale=6, elem_id="ui_theme_hsv"):
43
+ gr.Slider(elem_id="theme_hue", label='Hue', minimum=0, maximum=360, step=1)
44
+ gr.Slider(elem_id="theme_sat", label='Saturation', minimum=-100, maximum=100, step=1, value=0, interactive=True)
45
+ gr.Slider(elem_id="theme_brt", label='Lightness', minimum=-50, maximum=50, step=1, value=0, interactive=True)
46
+
47
+ gr.Button(elem_id="theme_invert_btn", value="Invert", variant="primary")
48
+
49
+
50
+ with gr.Row(elem_id="ui_theme_settings"):
51
+ with gr.Column():
52
+ with gr.Column():
53
+ with gr.Accordion(label="Main", open=True):
54
+ gr.ColorPicker(elem_id="--ae-main-bg-color", interactive=True, label="Background color")
55
+ gr.ColorPicker(elem_id="--ae-primary-color", label="Primary color")
56
+
57
+ with gr.Accordion(label="Focus", open=True):
58
+ gr.ColorPicker(elem_id="--ae-textarea-focus-color", label="Textarea color")
59
+ gr.ColorPicker(elem_id="--ae-input-focus-color", label="Input color")
60
+
61
+ with gr.Accordion(label="Spacing", open=True):
62
+ gr.Slider(elem_id="--ae-outside-gap-size", label='Gap size', minimum=1, maximum=16, step=1, interactive=True)
63
+ gr.Slider(elem_id="--ae-inside-padding-size", label='Padding size', minimum=1, maximum=16, step=1, interactive=True)
64
+
65
+ with gr.Accordion(label="Spacing (Mobile)", open=True):
66
+ gr.Slider(elem_id="--ae-mobile-outside-gap-size", label='Mobile Gap size', minimum=1, maximum=16, step=1, interactive=True)
67
+ gr.Slider(elem_id="--ae-mobile-inside-padding-size", label='Mobile Padding size', minimum=1, maximum=16, step=1, interactive=True)
68
+
69
+ with gr.Accordion(label="Panel", open=True):
70
+ gr.ColorPicker(elem_id="--ae-label-color", label="Label color")
71
+ gr.ColorPicker(elem_id="--ae-frame-bg-color", label="Frame Background color")
72
+ gr.ColorPicker(elem_id="--ae-panel-bg-color", label="Background color")
73
+ gr.ColorPicker(elem_id="--ae-panel-border-color", label="Border color")
74
+ gr.Slider(elem_id="--ae-panel-border-radius", label='Border radius', minimum=0, maximum=16, step=1)
75
+
76
+ gr.ColorPicker(elem_id="--ae-input-color", label="Input text color")
77
+ gr.ColorPicker(elem_id="--ae-input-bg-color", label="Input background color")
78
+ gr.ColorPicker(elem_id="--ae-input-border-color", label="Input border color")
79
+ with gr.Column():
80
+ with gr.Row(elem_id="theme_sub-panel"):
81
+
82
+ with gr.Accordion(label="SubPanel", open=True):
83
+ gr.ColorPicker(elem_id="--ae-subgroup-bg-color", label="Subgoup background color")
84
+ #gr.ColorPicker(elem_id="--ae-subgroup-label-color", label="Label color", value="#000000")
85
+ gr.ColorPicker(elem_id="--ae-subpanel-bg-color", label="Background color")
86
+ gr.ColorPicker(elem_id="--ae-subpanel-border-color", label="Border color")
87
+ gr.Slider(elem_id="--ae-subpanel-border-radius", label='Border radius', minimum=0, maximum=16, step=1)
88
+
89
+ gr.ColorPicker(elem_id="--ae-subgroup-input-color", label="Input text color")
90
+ gr.ColorPicker(elem_id="--ae-subgroup-input-bg-color", label="Input background color")
91
+ gr.ColorPicker(elem_id="--ae-subgroup-input-border-color", label="Input border color")
92
+
93
+ with gr.Row():
94
+ with gr.Column():
95
+ with gr.Accordion(label="Navigation menu", open=True):
96
+ gr.ColorPicker(elem_id="--ae-nav-bg-color", label="Background color")
97
+ gr.ColorPicker(elem_id="--ae-nav-color", label="Text color")
98
+ gr.ColorPicker(elem_id="--ae-nav-hover-color", label="Hover color")
99
+
100
+ with gr.Accordion(label="Icon", open=True):
101
+ gr.ColorPicker(elem_id="--ae-icon-color", label="Color")
102
+ gr.ColorPicker(elem_id="--ae-icon-hover-color", label="Hover color")
103
+
104
+ with gr.Accordion(label="Other", open=True):
105
+ gr.ColorPicker(elem_id="--ae-text-color", label="Text color")
106
+ gr.ColorPicker(elem_id="--ae-placeholder-color", label="Placeholder color")
107
+ gr.ColorPicker(elem_id="--ae-cancel-color", label="Cancel/Interrupt color")
108
+
109
+ with gr.Accordion(label="Modal", open=True):
110
+ gr.ColorPicker(elem_id="--ae-modal-bg-color", label="Background color")
111
+ gr.ColorPicker(elem_id="--ae-modal-icon-color", label="Icon color")
112
+
113
+
114
+
115
+ def save_theme( vars_text, css_text, filename):
116
+ style_data= ":root{" + vars_text + "}" + css_text
117
+ with open(os.path.join(themes_folder, f"{filename}.css"), 'w', encoding="utf-8") as file:
118
+ file.write(vars_text)
119
+ file.close()
120
+ with open(webui_style_path, 'w', encoding="utf-8") as file:
121
+ file.write(style_data)
122
+ file.close()
123
+ themes_dropdown.choices=get_files(themes_folder,[".css, .txt"])
124
+ return gr.update(choices=themes_dropdown.choices, value=f"{filename}.css")
125
+
126
+ def open_theme(filename, css_text):
127
+ with open(os.path.join(themes_folder, f"{filename}"), 'r') as file:
128
+ vars_text=file.read()
129
+ no_ext=filename.rsplit('.', 1)[0]
130
+ #save_theme( vars_text, css_text, no_ext)
131
+ # shared.state.interrupt()
132
+ # shared.state.need_restart = True
133
+ return [vars_text, no_ext]
134
+
135
+ # def delete_theme(filename):
136
+ # try:
137
+ # os.remove(os.path.join(themes_folder, filename))
138
+ # except FileNotFoundError:
139
+ # pass
140
+
141
+ # delete_button.click(
142
+ # fn = lambda: delete_theme()
143
+ # )
144
+
145
+ save_button.click(
146
+ fn=save_theme,
147
+ inputs=[vars_text, css_text, save_as_filename],
148
+ outputs=themes_dropdown
149
+ )
150
+
151
+ themes_dropdown.change(
152
+ fn=open_theme,
153
+ #_js = "applyTheme",
154
+ inputs=[themes_dropdown, css_text],
155
+ outputs=[vars_text, save_as_filename]
156
+ )
157
+
158
+ # apply_button.click(
159
+ # fn=None,
160
+ # _js = "applyTheme"
161
+ # )
162
+
163
+ # vars_text.change(
164
+ # fn=None,
165
+ # _js = "applyTheme",
166
+ # inputs=[],
167
+ # outputs=[vars_text, css_text]
168
+ # )
169
+
170
+
171
+
172
+
173
+ return (ui_theme, 'Theme', 'ui_theme'),
174
+
175
+
176
+
177
+ script_callbacks.on_ui_tabs(on_ui_tabs)
extensions-builtin/sd_theme_editor/style.css ADDED
@@ -0,0 +1,113 @@
1
+ #theme_menu {
2
+ z-index: 9999;
3
+ background-color: var(--ae-input-bg-color);
4
+ position: relative;
5
+ width: 38px;
6
+ height: 38px;
7
+ border-radius: 100%;
8
+ cursor: pointer;
9
+ min-width: unset;
10
+ max-width: 38px;
11
+ align-self: center;
12
+ }
13
+
14
+ #theme_menu::before {
15
+ content: " ";
16
+ display: inline-block;
17
+ -webkit-mask-size: cover;
18
+ mask-size: cover;
19
+ background-color: var(--ae-icon-color);
20
+ width: var(--ae-icon-size);
21
+ height: var(--ae-icon-size);
22
+ -webkit-mask: url(./file=html/svg/contrast-drop-2-line.svg) no-repeat 50% 50%;
23
+ mask: url(./file=html/svg/contrast-drop-2-line.svg) no-repeat 50% 50%;
24
+ cursor: pointer;
25
+ position: relative;
26
+ left: 50%;
27
+ top: 50%;
28
+ transform: translate(-50%, -50%) scale(1);
29
+ }
30
+
31
+ #theme_menu.fixed,
32
+ #theme_menu:hover {
33
+ background-color: var(--ae-icon-color);
34
+ }
35
+
36
+ #theme_menu.fixed::before,
37
+ #theme_menu:hover::before {
38
+ background-color: var(--ae-icon-hover-color);
39
+ }
40
+
41
+ #theme_overflow_container {
42
+ overflow-y: auto;
43
+ height: calc(
44
+ 100vh - var(--ae-top-header-height) - (var(--ae-outside-gap-size) * 2) -
45
+ (var(--ae-inside-padding-size) * 4) - 96px
46
+ );
47
+ overflow-x: hidden;
48
+ }
49
+
50
+ #tab_ui_theme.open {
51
+ transform: translateX(0);
52
+ box-shadow: rgba(0, 0, 0, 0.4) -30px 0 30px -30px;
53
+ }
54
+
55
+ #tab_ui_theme.aside {
56
+ display: block !important;
57
+ }
58
+
59
+ #tab_ui_theme.aside {
60
+ position: fixed;
61
+ top: var(--ae-top-header-height);
62
+ width: 90%;
63
+ right: 0;
64
+ height: calc(100% - var(--ae-top-header-height));
65
+ max-width: 480px;
66
+ z-index: 9999;
67
+ transform: translateX(100%);
68
+ transition: all 0.25s ease 0s;
69
+ box-shadow: rgba(0, 0, 0, 0) -30px 0 30px -30px;
70
+ padding: calc(1rem - var(--ae-outside-gap-size));
71
+ background-color: var(--ae-main-bg-color) !important;
72
+ }
73
+ #tab_ui_theme.aside.open {
74
+ transform: translateX(0);
75
+ box-shadow: rgba(0, 0, 0, 0.4) -30px 0 30px -30px;
76
+ }
77
+
78
+ #theme_hidden,
79
+ #setting_ui_header_tabs .theme,
80
+ #setting_ui_hidden_tabs .theme {
81
+ display: none !important;
82
+ }
83
+
84
+ #tab_ui_theme [id*="color"] label {
85
+ display: flex;
86
+ align-items: center;
87
+ pointer-events: none;
88
+ }
89
+ #tab_ui_theme [id*="color"] label span {
90
+ min-width: 50% !important;
91
+ }
92
+ #tab_ui_theme [id*="color"] label input {
93
+ flex-grow: 1;
94
+ pointer-events: all;
95
+ cursor: pointer;
96
+ }
97
+
98
+ #settings_ui_theme > div > div {
99
+ flex-direction: row;
100
+ flex-wrap: wrap;
101
+ }
102
+ #settings_ui_theme > div > div > div {
103
+ max-width: 30%;
104
+ }
105
+
106
+ #tab_ui_theme > div {
107
+ padding: 16px !important;
108
+ padding-top: 0 !important;
109
+ }
110
+
111
+ #ui_theme_hsv + button {
112
+ min-width: unset;
113
+ }
extensions-builtin/sd_theme_editor/themes/Golde.css ADDED
@@ -0,0 +1 @@
+ --ae-main-bg-color:hsl(99deg 11% 8%);--ae-primary-color:hsl(44deg 63% 55%);--ae-input-bg-color:hsl(106deg 8% 12%);--ae-input-border-color:hsl(104deg 9% 32%);--ae-panel-bg-color:hsl(104deg 9% 20%);--ae-panel-border-color:hsl(104deg 9% 32%);--ae-panel-border-radius:4px;--ae-subgroup-bg-color:hsl(99deg 11% 8%);--ae-subgroup-input-bg-color:hsl(99deg 11% 8%);--ae-subgroup-input-border-color:hsl(104deg 9% 32%);--ae-subpanel-bg-color:hsl(106deg 8% 12%);--ae-subpanel-border-color:hsl(104deg 9% 32%);--ae-subpanel-border-radius:8px;--ae-textarea-focus-color:hsl(56deg 30% 36%);--ae-input-focus-color:hsl(44deg 63% 55%);--ae-outside-gap-size:8px;--ae-inside-padding-size:8px;--ae-tool-button-size:34px;--ae-tool-button-radius:16px;--ae-generate-button-height:70px;--ae-cancel-color:hsl(104deg 9% 32%);--ae-max-padding:max(var(--ae-outside-gap-size),var(--ae-inside-padding-size));--ae-icon-color:hsl(105deg 9% 77%);--ae-icon-hover-color:hsl(99deg 11% 8%);--ae-icon-size:22px;--ae-nav-bg-color:hsl(98deg 9% 4%);--ae-nav-color:hsl(105deg 9% 77%);--ae-nav-hover-color:hsl(98deg 9% 4%);--ae-input-color:hsl(44deg 63% 55%);--ae-label-color:hsl(105deg 9% 77%);--ae-subgroup-input-color:hsl(44deg 63% 55%);--ae-placeholder-color:hsl(104deg 9% 32%);--ae-text-color:hsl(105deg 9% 77%);--ae-mobile-outside-gap-size:2px;--ae-mobile-inside-padding-size:2px;--ae-frame-bg-color:hsl(108deg 8% 12%);--ae-modal-bg-color:hsl(96deg 12% 8%);--ae-modal-icon-color:hsl(44deg 63% 55%);
extensions-builtin/sd_theme_editor/themes/backup.css ADDED
@@ -0,0 +1 @@
+ --ae-main-bg-color:hsl(0deg 0% 10%);--ae-primary-color:hsl(168deg 96% 42%);--ae-input-bg-color:hsl(225deg 6% 13%);--ae-input-border-color:hsl(214deg 5% 30%);--ae-panel-bg-color:hsl(225deg 5% 17%);--ae-panel-border-color:hsl(214deg 5% 30%);--ae-panel-border-radius:0px;--ae-subgroup-bg-color:hsl(0deg 0% 10%);--ae-subgroup-input-bg-color:hsl(225deg 6% 13%);--ae-subgroup-input-border-color:hsl(214deg 5% 30%);--ae-subpanel-bg-color:hsl(220deg 4% 14%);--ae-subpanel-border-color:hsl(214deg 5% 30%);--ae-subpanel-border-radius:8px;--ae-textarea-focus-color:hsl(210deg 3% 36%);--ae-input-focus-color:hsl(168deg 97% 41%);--ae-outside-gap-size:8px;--ae-inside-padding-size:8px;--ae-tool-button-size:34px;--ae-tool-button-radius:16px;--ae-generate-button-height:70px;--ae-cancel-color:hsl(0deg 84% 60%);--ae-max-padding:max(var(--ae-outside-gap-size),var(--ae-inside-padding-size));--ae-icon-color:hsl(168deg 96% 42%);--ae-icon-hover-color:hsl(0deg 0% 10%);--ae-icon-size:22px;--ae-nav-bg-color:hsl(0deg 0% 4%);--ae-nav-color:hsl(210deg 4% 80%);--ae-nav-hover-color:hsl(0deg 0% 4%);--ae-input-color:hsl(210deg 4% 80%);--ae-label-color:hsl(210deg 4% 80%);--ae-subgroup-input-color:hsl(0deg 100% 100%);--ae-placeholder-color:hsl(214deg 5% 30%);--ae-text-color:hsl(210deg 4% 80%);--ae-mobile-outside-gap-size:3px;--ae-mobile-inside-padding-size:3px;--ae-frame-bg-color:hsl(225deg 6% 13%);--ae-modal-bg-color:hsl(0deg 0% 10%);--ae-modal-icon-color:hsl(168deg 97% 41%);
extensions-builtin/sd_theme_editor/themes/d-230-52-94.css ADDED
@@ -0,0 +1 @@
+ --ae-main-bg-color:hsl(230deg 52% 4%);--ae-primary-color:hsl(38deg 148% 36%);--ae-input-bg-color:hsl(95deg 58% 7%);--ae-input-border-color:hsl(84deg 57% 24%);--ae-panel-bg-color:hsl(95deg 57% 11%);--ae-panel-border-color:hsl(84deg 57% 24%);--ae-panel-border-radius:0px;--ae-subgroup-bg-color:hsl(230deg 52% 4%);--ae-subgroup-input-bg-color:hsl(95deg 58% 7%);--ae-subgroup-input-border-color:hsl(84deg 57% 24%);--ae-subpanel-bg-color:hsl(90deg 56% 8%);--ae-subpanel-border-color:hsl(84deg 57% 24%);--ae-subpanel-border-radius:8px;--ae-textarea-focus-color:hsl(80deg 55% 30%);--ae-input-focus-color:hsl(38deg 149% 35%);--ae-outside-gap-size:8px;--ae-inside-padding-size:8px;--ae-tool-button-size:34px;--ae-tool-button-radius:16px;--ae-generate-button-height:70px;--ae-cancel-color:hsl(230deg 136% 54%);--ae-max-padding:max(var(--ae-outside-gap-size),var(--ae-inside-padding-size));--ae-icon-color:hsl(80deg 56% 74%);--ae-icon-hover-color:hsl(230deg 52% 4%);--ae-icon-size:22px;--ae-nav-bg-color:hsl(230deg 52% 98%);--ae-nav-color:hsl(80deg 56% 74%);--ae-nav-hover-color:hsl(230deg 52% 98%);--ae-input-color:hsl(80deg 56% 74%);--ae-label-color:hsl(80deg 56% 74%);--ae-subgroup-input-color:hsl(230deg 152% 94%);--ae-placeholder-color:hsl(84deg 57% 24%);--ae-text-color:hsl(80deg 56% 74%);--ae-mobile-outside-gap-size:3px;--ae-mobile-inside-padding-size:3px;--ae-frame-bg-color:hsl(94deg 60% 7%);--ae-modal-bg-color:hsl(229deg 52% 4%);--ae-modal-icon-color:hsl(38deg 100% 36%);
extensions-builtin/sd_theme_editor/themes/default.css ADDED
@@ -0,0 +1 @@
+ --ae-main-bg-color:hsl(0deg 0% 10%);--ae-primary-color:hsl(168deg 97% 41%);--ae-input-bg-color:hsl(225deg 6% 13%);--ae-input-border-color:hsl(214deg 5% 30%);--ae-panel-bg-color:hsl(225deg 5% 17%);--ae-panel-border-color:hsl(214deg 5% 30%);--ae-panel-border-radius:0px;--ae-subgroup-bg-color:hsl(0deg 0% 10%);--ae-subgroup-input-bg-color:hsl(225deg 6% 13%);--ae-subgroup-input-border-color:hsl(214deg 5% 30%);--ae-subpanel-bg-color:hsl(220deg 4% 14%);--ae-subpanel-border-color:hsl(214deg 5% 30%);--ae-subpanel-border-radius:8px;--ae-textarea-focus-color:hsl(210deg 3% 36%);--ae-input-focus-color:hsl(168deg 97% 41%);--ae-outside-gap-size:8px;--ae-inside-padding-size:8px;--ae-tool-button-size:34px;--ae-tool-button-radius:16px;--ae-generate-button-height:70px;--ae-cancel-color:hsl(0deg 84% 60%);--ae-max-padding:max(var(--ae-outside-gap-size),var(--ae-inside-padding-size));--ae-icon-color:hsl(168deg 97% 41%);--ae-icon-hover-color:hsl(0deg 0% 10%);--ae-icon-size:22px;--ae-nav-bg-color:hsl(0deg 0% 4%);--ae-nav-color:hsl(210deg 4% 80%);--ae-nav-hover-color:hsl(0deg 0% 4%);--ae-input-color:hsl(210deg 4% 80%);--ae-label-color:hsl(210deg 4% 80%);--ae-subgroup-input-color:hsl(210deg 4% 80%);--ae-placeholder-color:hsl(214deg 5% 30%);--ae-text-color:hsl(210deg 4% 80%);--ae-mobile-outside-gap-size:2px;--ae-mobile-inside-padding-size:2px;--ae-frame-bg-color:hsl(225deg 6% 13%);--ae-modal-bg-color:hsl(0deg 0% 10%);--ae-modal-icon-color:hsl(168deg 97% 41%);
extensions-builtin/sd_theme_editor/themes/default_cyan.css ADDED
@@ -0,0 +1 @@
+ --ae-main-bg-color:hsl(0deg 0% 10%);--ae-primary-color:hsl(199deg 60% 60%);--ae-input-bg-color:hsl(225deg 6% 13%);--ae-input-border-color:hsl(214deg 5% 30%);--ae-panel-bg-color:hsl(225deg 5% 17%);--ae-panel-border-color:hsl(214deg 5% 30%);--ae-panel-border-radius:0px;--ae-subgroup-bg-color:hsl(0deg 0% 10%);--ae-subgroup-input-bg-color:hsl(225deg 6% 13%);--ae-subgroup-input-border-color:hsl(214deg 5% 30%);--ae-subpanel-bg-color:hsl(220deg 4% 14%);--ae-subpanel-border-color:hsl(214deg 5% 30%);--ae-subpanel-border-radius:8px;--ae-textarea-focus-color:hsl(210deg 3% 36%);--ae-input-focus-color:hsl(199deg 60% 60%);--ae-outside-gap-size:8px;--ae-inside-padding-size:8px;--ae-tool-button-size:34px;--ae-tool-button-radius:16px;--ae-generate-button-height:70px;--ae-cancel-color:hsl(357deg 50% 57%);--ae-max-padding:max(var(--ae-outside-gap-size),var(--ae-inside-padding-size));--ae-icon-color:hsl(210deg 4% 80%);--ae-icon-hover-color:hsl(0deg 0% 10%);--ae-icon-size:22px;--ae-nav-bg-color:hsl(0deg 0% 4%);--ae-nav-color:hsl(210deg 4% 80%);--ae-nav-hover-color:hsl(0deg 0% 4%);--ae-input-color:hsl(210deg 4% 80%);--ae-label-color:hsl(210deg 4% 80%);--ae-subgroup-input-color:hsl(210deg 4% 80%);--ae-placeholder-color:hsl(214deg 5% 30%);--ae-text-color:hsl(210deg 4% 80%);--ae-mobile-outside-gap-size:2px;--ae-mobile-inside-padding-size:2px;--ae-frame-bg-color:hsl(225deg 6% 13%);--ae-modal-bg-color:hsl(0deg 0% 10%);--ae-modal-icon-color:hsl(199deg 60% 60%);
extensions-builtin/sd_theme_editor/themes/default_orange.css ADDED
@@ -0,0 +1 @@
+ --ae-main-bg-color:hsl(0deg 0% 10%);--ae-primary-color:hsl(16deg 77% 60%);--ae-input-bg-color:hsl(225deg 6% 13%);--ae-input-border-color:hsl(214deg 5% 30%);--ae-panel-bg-color:hsl(225deg 5% 17%);--ae-panel-border-color:hsl(214deg 5% 30%);--ae-panel-border-radius:8px;--ae-subgroup-bg-color:hsl(0deg 0% 10%);--ae-subgroup-input-bg-color:hsl(225deg 6% 13%);--ae-subgroup-input-border-color:hsl(214deg 5% 30%);--ae-subpanel-bg-color:hsl(220deg 4% 14%);--ae-subpanel-border-color:hsl(214deg 5% 30%);--ae-subpanel-border-radius:8px;--ae-textarea-focus-color:hsl(210deg 3% 36%);--ae-input-focus-color:hsl(16deg 77% 60%);--ae-outside-gap-size:8px;--ae-inside-padding-size:8px;--ae-tool-button-size:34px;--ae-tool-button-radius:16px;--ae-generate-button-height:70px;--ae-cancel-color:hsl(193deg 54% 55%);--ae-max-padding:max(var(--ae-outside-gap-size),var(--ae-inside-padding-size));--ae-icon-color:hsl(210deg 4% 80%);--ae-icon-hover-color:hsl(0deg 0% 10%);--ae-icon-size:22px;--ae-nav-bg-color:hsl(0deg 0% 4%);--ae-nav-color:hsl(210deg 4% 80%);--ae-nav-hover-color:hsl(0deg 0% 4%);--ae-input-color:hsl(210deg 4% 80%);--ae-label-color:hsl(210deg 4% 80%);--ae-subgroup-input-color:hsl(210deg 4% 80%);--ae-placeholder-color:hsl(214deg 5% 30%);--ae-text-color:hsl(210deg 4% 80%);--ae-mobile-outside-gap-size:2px;--ae-mobile-inside-padding-size:2px;--ae-frame-bg-color:hsl(225deg 6% 13%);--ae-modal-bg-color:hsl(0deg 0% 10%);--ae-modal-icon-color:hsl(16deg 77% 60%);
extensions-builtin/sd_theme_editor/themes/fun.css ADDED
@@ -0,0 +1 @@
+ --ae-main-bg-color:hsl(253deg 22% 8%);--ae-primary-color:hsl(76deg 96% 55%);--ae-input-bg-color:hsl(260deg 25% 12%);--ae-input-border-color:hsl(258deg 24% 32%);--ae-panel-bg-color:hsl(258deg 24% 20%);--ae-panel-border-color:hsl(258deg 24% 32%);--ae-panel-border-radius:4px;--ae-subgroup-bg-color:hsl(253deg 22% 8%);--ae-subgroup-input-bg-color:hsl(258deg 24% 8%);--ae-subgroup-input-border-color:hsl(258deg 24% 32%);--ae-subpanel-bg-color:hsl(260deg 25% 12%);--ae-subpanel-border-color:hsl(258deg 24% 32%);--ae-subpanel-border-radius:8px;--ae-textarea-focus-color:hsl(210deg 3% 36%);--ae-input-focus-color:hsl(296deg 96% 55%);--ae-outside-gap-size:8px;--ae-inside-padding-size:8px;--ae-tool-button-size:34px;--ae-tool-button-radius:16px;--ae-generate-button-height:70px;--ae-cancel-color:hsl(258deg 24% 32%);--ae-max-padding:max(var(--ae-outside-gap-size),var(--ae-inside-padding-size));--ae-icon-color:hsl(259deg 24% 77%);--ae-icon-hover-color:hsl(253deg 22% 8%);--ae-icon-size:22px;--ae-nav-bg-color:hsl(252deg 24% 4%);--ae-nav-color:hsl(259deg 24% 77%);--ae-nav-hover-color:hsl(252deg 24% 4%);--ae-input-color:hsl(305deg 96% 55%);--ae-label-color:hsl(259deg 24% 77%);--ae-subgroup-input-color:hsl(76deg 96% 55%);--ae-placeholder-color:hsl(258deg 24% 32%);--ae-text-color:hsl(259deg 24% 77%);--ae-mobile-outside-gap-size:2px;--ae-mobile-inside-padding-size:2px;--ae-frame-bg-color:hsl(260deg 25% 12%);--ae-modal-bg-color:hsl(253deg 22% 8%);--ae-modal-icon-color:hsl(76deg 96% 55%);
extensions-builtin/sd_theme_editor/themes/minimal.css ADDED
@@ -0,0 +1 @@
+ --ae-main-bg-color:hsl(0deg 0% 8%);--ae-primary-color:hsl(168deg 96% 42%);--ae-input-bg-color:hsl(0deg 0% 10%);--ae-input-border-color:hsl(0deg 0% 10%);--ae-panel-bg-color:hsl(0deg 0% 17%);--ae-panel-border-color:hsl(0deg 0% 17%);--ae-panel-border-radius:4px;--ae-subgroup-bg-color:hsl(0deg 0% 10%);--ae-subgroup-input-bg-color:hsl(0deg 0% 10%);--ae-subgroup-input-border-color:hsl(0deg 0% 10%);--ae-subpanel-bg-color:hsl(0deg 0% 14%);--ae-subpanel-border-color:hsl(0deg 0% 15%);--ae-subpanel-border-radius:4px;--ae-textarea-focus-color:hsl(0deg 0% 36%);--ae-input-focus-color:hsl(168deg 97% 41%);--ae-outside-gap-size:1px;--ae-inside-padding-size:5px;--ae-tool-button-size:34px;--ae-tool-button-radius:16px;--ae-generate-button-height:70px;--ae-cancel-color:hsl(0deg 84% 60%);--ae-max-padding:max(var(--ae-outside-gap-size),var(--ae-inside-padding-size));--ae-icon-color:hsl(168deg 96% 42%);--ae-icon-hover-color:hsl(0deg 0% 10%);--ae-icon-size:22px;--ae-nav-bg-color:hsl(0deg 0% 4%);--ae-nav-color:hsl(0deg 0% 80%);--ae-nav-hover-color:hsl(0deg 0% 4%);--ae-input-color:hsl(210deg 4% 80%);--ae-label-color:hsl(0deg 0% 65%);--ae-subgroup-input-color:hsl(0deg 100% 100%);--ae-placeholder-color:hsl(0deg 0% 30%);--ae-text-color:hsl(0deg 0% 80%);--ae-mobile-outside-gap-size:3px;--ae-mobile-inside-padding-size:3px;--ae-frame-bg-color:hsl(0deg 0% 14%);--ae-modal-bg-color:hsl(0deg 0% 10%);--ae-modal-icon-color:hsl(168deg 97% 41%);
extensions-builtin/sd_theme_editor/themes/minimal_orange.css ADDED
@@ -0,0 +1 @@
+ --ae-main-bg-color:hsl(210deg 28% 8%);--ae-primary-color:hsl(18deg 124% 42%);--ae-input-bg-color:hsl(210deg 28% 10%);--ae-input-border-color:hsl(210deg 28% 10%);--ae-panel-bg-color:hsl(210deg 28% 17%);--ae-panel-border-color:hsl(210deg 28% 17%);--ae-panel-border-radius:4px;--ae-subgroup-bg-color:hsl(210deg 28% 10%);--ae-subgroup-input-bg-color:hsl(210deg 28% 10%);--ae-subgroup-input-border-color:hsl(210deg 28% 10%);--ae-subpanel-bg-color:hsl(210deg 28% 14%);--ae-subpanel-border-color:hsl(210deg 28% 15%);--ae-subpanel-border-radius:4px;--ae-textarea-focus-color:hsl(210deg 28% 36%);--ae-input-focus-color:hsl(18deg 125% 41%);--ae-outside-gap-size:8px;--ae-inside-padding-size:8px;--ae-tool-button-size:34px;--ae-tool-button-radius:16px;--ae-generate-button-height:70px;--ae-cancel-color:hsl(210deg 112% 60%);--ae-max-padding:max(var(--ae-outside-gap-size),var(--ae-inside-padding-size));--ae-icon-color:hsl(18deg 124% 42%);--ae-icon-hover-color:hsl(210deg 28% 10%);--ae-icon-size:22px;--ae-nav-bg-color:hsl(210deg 28% 4%);--ae-nav-color:hsl(210deg 28% 80%);--ae-nav-hover-color:hsl(210deg 28% 4%);--ae-input-color:hsl(60deg 32% 80%);--ae-label-color:hsl(210deg 28% 65%);--ae-subgroup-input-color:hsl(210deg 128% 100%);--ae-placeholder-color:hsl(210deg 28% 30%);--ae-text-color:hsl(210deg 28% 80%);--ae-mobile-outside-gap-size:3px;--ae-mobile-inside-padding-size:3px;--ae-frame-bg-color:hsl(210deg 28% 14%);--ae-modal-bg-color:hsl(210deg 28% 10%);--ae-modal-icon-color:hsl(18deg 125% 41%);
extensions-builtin/sd_theme_editor/themes/moonlight.css ADDED
@@ -0,0 +1 @@
+ --ae-main-bg-color:hsl(240deg 16% 6%);--ae-primary-color:hsl(222deg 75% 62%);--ae-input-bg-color:hsl(240deg 17% 8%);--ae-input-border-color:hsl(240deg 20% 16%);--ae-panel-bg-color:hsl(240deg 18% 12%);--ae-panel-border-color:hsl(240deg 16% 16%);--ae-panel-border-radius:0px;--ae-subgroup-bg-color:hsl(240deg 17% 8%);--ae-subgroup-input-bg-color:hsl(240deg 17% 8%);--ae-subgroup-input-border-color:hsl(240deg 20% 16%);--ae-subpanel-bg-color:hsl(240deg 18% 10%);--ae-subpanel-border-color:hsl(240deg 20% 16%);--ae-subpanel-border-radius:8px;--ae-textarea-focus-color:hsl(210deg 3% 36%);--ae-input-focus-color:hsl(222deg 75% 62%);--ae-outside-gap-size:8px;--ae-inside-padding-size:8px;--ae-tool-button-size:34px;--ae-tool-button-radius:16px;--ae-generate-button-height:70px;--ae-cancel-color:hsl(222deg 75% 62%);--ae-max-padding:max(var(--ae-outside-gap-size),var(--ae-inside-padding-size));--ae-icon-color:hsl(222deg 75% 62%);--ae-icon-hover-color:hsl(240deg 18% 12%);--ae-icon-size:22px;--ae-nav-bg-color:hsl(240deg 16% 6%);--ae-nav-color:hsl(185deg 66% 85%);--ae-nav-hover-color:hsl(0deg 0% 4%);--ae-input-color:hsl(185deg 66% 85%);--ae-label-color:hsl(185deg 66% 85%);--ae-subgroup-input-color:hsl(185deg 66% 85%);--ae-placeholder-color:hsl(240deg 20% 24%);--ae-text-color:hsl(185deg 66% 85%);--ae-mobile-outside-gap-size:2px;--ae-mobile-inside-padding-size:2px;--ae-frame-bg-color:hsl(240deg 18% 10%);--ae-modal-bg-color:hsl(240deg 16% 6%);--ae-modal-icon-color:hsl(222deg 75% 62%);
extensions-builtin/sd_theme_editor/themes/ogxBGreen.css ADDED
@@ -0,0 +1 @@
+ --ae-main-bg-color:hsl(195deg 22% 8%);--ae-primary-color:hsl(159deg 96% 55%);--ae-input-bg-color:hsl(202deg 25% 12%);--ae-input-border-color:hsl(200deg 24% 32%);--ae-panel-bg-color:hsl(200deg 24% 20%);--ae-panel-border-color:hsl(200deg 24% 32%);--ae-panel-border-radius:4px;--ae-subgroup-bg-color:hsl(195deg 22% 8%);--ae-subgroup-input-bg-color:hsl(200deg 24% 8%);--ae-subgroup-input-border-color:hsl(200deg 24% 32%);--ae-subpanel-bg-color:hsl(202deg 25% 12%);--ae-subpanel-border-color:hsl(200deg 24% 32%);--ae-subpanel-border-radius:8px;--ae-textarea-focus-color:hsl(152deg 3% 36%);--ae-input-focus-color:hsl(159deg 96% 55%);--ae-outside-gap-size:8px;--ae-inside-padding-size:8px;--ae-tool-button-size:34px;--ae-tool-button-radius:16px;--ae-generate-button-height:70px;--ae-cancel-color:hsl(200deg 24% 32%);--ae-max-padding:max(var(--ae-outside-gap-size),var(--ae-inside-padding-size));--ae-icon-color:hsl(159deg 96% 55%);--ae-icon-hover-color:hsl(195deg 22% 8%);--ae-icon-size:22px;--ae-nav-bg-color:hsl(194deg 24% 4%);--ae-nav-color:hsl(201deg 24% 77%);--ae-nav-hover-color:hsl(194deg 24% 4%);--ae-input-color:hsl(159deg 96% 55%);--ae-label-color:hsl(201deg 24% 77%);--ae-subgroup-input-color:hsl(159deg 96% 55%);--ae-placeholder-color:hsl(200deg 24% 32%);--ae-text-color:hsl(201deg 24% 77%);--ae-mobile-outside-gap-size:3px;--ae-mobile-inside-padding-size:3px;--ae-frame-bg-color:hsl(200deg 25% 12%);--ae-modal-bg-color:hsl(193deg 22% 8%);--ae-modal-icon-color:hsl(159deg 96% 55%);
extensions-builtin/sd_theme_editor/themes/ogxCyan.css ADDED
@@ -0,0 +1 @@
+ --ae-main-bg-color:hsl(253deg 22% 8%);--ae-primary-color:hsl(198deg 96% 55%);--ae-input-bg-color:hsl(260deg 25% 12%);--ae-input-border-color:hsl(258deg 24% 32%);--ae-panel-bg-color:hsl(258deg 24% 20%);--ae-panel-border-color:hsl(258deg 24% 32%);--ae-panel-border-radius:4px;--ae-subgroup-bg-color:hsl(253deg 22% 8%);--ae-subgroup-input-bg-color:hsl(258deg 24% 8%);--ae-subgroup-input-border-color:hsl(258deg 24% 32%);--ae-subpanel-bg-color:hsl(260deg 25% 12%);--ae-subpanel-border-color:hsl(258deg 24% 32%);--ae-subpanel-border-radius:8px;--ae-textarea-focus-color:hsl(210deg 3% 36%);--ae-input-focus-color:hsl(198deg 96% 55%);--ae-outside-gap-size:8px;--ae-inside-padding-size:8px;--ae-tool-button-size:34px;--ae-tool-button-radius:16px;--ae-generate-button-height:70px;--ae-cancel-color:hsl(258deg 24% 32%);--ae-max-padding:max(var(--ae-outside-gap-size),var(--ae-inside-padding-size));--ae-icon-color:hsl(259deg 24% 77%);--ae-icon-hover-color:hsl(253deg 22% 8%);--ae-icon-size:22px;--ae-nav-bg-color:hsl(252deg 24% 4%);--ae-nav-color:hsl(259deg 24% 77%);--ae-nav-hover-color:hsl(252deg 24% 4%);--ae-input-color:hsl(198deg 96% 55%);--ae-label-color:hsl(259deg 24% 77%);--ae-subgroup-input-color:hsl(198deg 96% 55%);--ae-placeholder-color:hsl(258deg 24% 32%);--ae-text-color:hsl(259deg 24% 77%);--ae-mobile-outside-gap-size:2px;--ae-mobile-inside-padding-size:2px;--ae-frame-bg-color:hsl(260deg 25% 12%);--ae-modal-bg-color:hsl(253deg 22% 8%);--ae-modal-icon-color:hsl(198deg 96% 55%);
extensions-builtin/sd_theme_editor/themes/ogxCyanInvert.css ADDED
@@ -0,0 +1 @@
+ --ae-main-bg-color:hsl(73deg 22% 92%);--ae-primary-color:hsl(18deg 96% 45%);--ae-input-bg-color:hsl(80deg 25% 88%);--ae-input-border-color:hsl(78deg 24% 68%);--ae-panel-bg-color:hsl(78deg 24% 80%);--ae-panel-border-color:hsl(78deg 24% 68%);--ae-panel-border-radius:4px;--ae-subgroup-bg-color:hsl(73deg 22% 92%);--ae-subgroup-input-bg-color:hsl(73deg 22% 92%);--ae-subgroup-input-border-color:hsl(78deg 24% 68%);--ae-subpanel-bg-color:hsl(80deg 25% 88%);--ae-subpanel-border-color:hsl(78deg 24% 68%);--ae-subpanel-border-radius:8px;--ae-textarea-focus-color:hsl(30deg 3% 64%);--ae-input-focus-color:hsl(18deg 96% 45%);--ae-outside-gap-size:8px;--ae-inside-padding-size:8px;--ae-tool-button-size:34px;--ae-tool-button-radius:16px;--ae-generate-button-height:70px;--ae-cancel-color:hsl(78deg 24% 68%);--ae-max-padding:max(var(--ae-outside-gap-size),var(--ae-inside-padding-size));--ae-icon-color:hsl(79deg 24% 23%);--ae-icon-hover-color:hsl(73deg 22% 92%);--ae-icon-size:22px;--ae-nav-bg-color:hsl(72deg 24% 96%);--ae-nav-color:hsl(79deg 24% 23%);--ae-nav-hover-color:hsl(72deg 24% 96%);--ae-input-color:hsl(18deg 96% 45%);--ae-label-color:hsl(79deg 24% 23%);--ae-subgroup-input-color:hsl(18deg 96% 45%);--ae-placeholder-color:hsl(78deg 24% 68%);--ae-text-color:hsl(79deg 24% 23%);--ae-mobile-outside-gap-size:2px;--ae-mobile-inside-padding-size:2px;--ae-frame-bg-color:hsl(80deg 25% 88%);--ae-modal-bg-color:hsl(73deg 22% 92%);--ae-modal-icon-color:hsl(18deg 96% 45%);
extensions-builtin/sd_theme_editor/themes/ogxGreen.css ADDED
@@ -0,0 +1 @@
+ --ae-main-bg-color:hsl(253deg 22% 8%);--ae-primary-color:hsl(149deg 96% 55%);--ae-input-bg-color:hsl(260deg 25% 12%);--ae-input-border-color:hsl(258deg 24% 32%);--ae-panel-bg-color:hsl(258deg 24% 20%);--ae-panel-border-color:hsl(258deg 24% 32%);--ae-panel-border-radius:4px;--ae-subgroup-bg-color:hsl(253deg 22% 8%);--ae-subgroup-input-bg-color:hsl(258deg 24% 8%);--ae-subgroup-input-border-color:hsl(258deg 24% 32%);--ae-subpanel-bg-color:hsl(260deg 25% 12%);--ae-subpanel-border-color:hsl(258deg 24% 32%);--ae-subpanel-border-radius:8px;--ae-textarea-focus-color:hsl(210deg 3% 36%);--ae-input-focus-color:hsl(149deg 96% 55%);--ae-outside-gap-size:8px;--ae-inside-padding-size:8px;--ae-tool-button-size:34px;--ae-tool-button-radius:16px;--ae-generate-button-height:70px;--ae-cancel-color:hsl(258deg 24% 32%);--ae-max-padding:max(var(--ae-outside-gap-size),var(--ae-inside-padding-size));--ae-icon-color:hsl(259deg 24% 77%);--ae-icon-hover-color:hsl(253deg 22% 8%);--ae-icon-size:22px;--ae-nav-bg-color:hsl(252deg 24% 4%);--ae-nav-color:hsl(259deg 24% 77%);--ae-nav-hover-color:hsl(252deg 24% 4%);--ae-input-color:hsl(149deg 96% 55%);--ae-label-color:hsl(259deg 24% 77%);--ae-subgroup-input-color:hsl(149deg 96% 55%);--ae-placeholder-color:hsl(258deg 24% 32%);--ae-text-color:hsl(259deg 24% 77%);--ae-mobile-outside-gap-size:3px;--ae-mobile-inside-padding-size:3px;--ae-frame-bg-color:hsl(260deg 25% 12%);--ae-modal-bg-color:hsl(253deg 22% 8%);--ae-modal-icon-color:hsl(253deg 22% 8%);
extensions-builtin/sd_theme_editor/themes/ogxRed.css ADDED
@@ -0,0 +1 @@
+ --ae-main-bg-color:hsl(253deg 22% 8%);--ae-primary-color:hsl(347deg 96% 55%);--ae-input-bg-color:hsl(260deg 25% 12%);--ae-input-border-color:hsl(258deg 24% 32%);--ae-panel-bg-color:hsl(258deg 24% 20%);--ae-panel-border-color:hsl(258deg 24% 32%);--ae-panel-border-radius:4px;--ae-subgroup-bg-color:hsl(253deg 22% 8%);--ae-subgroup-input-bg-color:hsl(258deg 24% 8%);--ae-subgroup-input-border-color:hsl(258deg 24% 32%);--ae-subpanel-bg-color:hsl(260deg 25% 12%);--ae-subpanel-border-color:hsl(258deg 24% 32%);--ae-subpanel-border-radius:8px;--ae-textarea-focus-color:hsl(210deg 3% 36%);--ae-input-focus-color:hsl(347deg 96% 55%);--ae-outside-gap-size:8px;--ae-inside-padding-size:8px;--ae-tool-button-size:34px;--ae-tool-button-radius:16px;--ae-generate-button-height:70px;--ae-cancel-color:hsl(258deg 24% 32%);--ae-max-padding:max(var(--ae-outside-gap-size),var(--ae-inside-padding-size));--ae-icon-color:hsl(259deg 24% 77%);--ae-icon-hover-color:hsl(253deg 22% 8%);--ae-icon-size:22px;--ae-nav-bg-color:hsl(252deg 24% 4%);--ae-nav-color:hsl(259deg 24% 77%);--ae-nav-hover-color:hsl(252deg 24% 4%);--ae-input-color:hsl(347deg 96% 55%);--ae-label-color:hsl(259deg 24% 77%);--ae-subgroup-input-color:hsl(347deg 96% 55%);--ae-placeholder-color:hsl(258deg 24% 32%);--ae-text-color:hsl(259deg 24% 77%);--ae-mobile-outside-gap-size:3px;--ae-mobile-inside-padding-size:3px;--ae-frame-bg-color:hsl(260deg 25% 12%);--ae-modal-bg-color:hsl(253deg 22% 8%);--ae-modal-icon-color:hsl(347deg 96% 55%);
extensions-builtin/sd_theme_editor/themes/retrog.css ADDED
@@ -0,0 +1 @@
+ --ae-main-bg-color:hsl(197deg 97% 14%);--ae-primary-color:hsl(27deg 99% 50%);--ae-input-bg-color:hsl(197deg 98% 16%);--ae-input-border-color:hsl(166deg 62% 33%);--ae-panel-bg-color:hsl(196deg 98% 18%);--ae-panel-border-color:hsl(166deg 62% 33%);--ae-panel-border-radius:0px;--ae-subgroup-bg-color:hsl(197deg 97% 14%);--ae-subgroup-input-bg-color:hsl(197deg 97% 14%);--ae-subgroup-input-border-color:hsl(166deg 62% 33%);--ae-subpanel-bg-color:hsl(197deg 98% 16%);--ae-subpanel-border-color:hsl(166deg 62% 33%);--ae-subpanel-border-radius:8px;--ae-textarea-focus-color:hsl(210deg 3% 36%);--ae-input-focus-color:hsl(222deg 75% 62%);--ae-outside-gap-size:6px;--ae-inside-padding-size:6px;--ae-tool-button-size:34px;--ae-tool-button-radius:16px;--ae-generate-button-height:70px;--ae-cancel-color:hsl(70deg 69% 54%);--ae-max-padding:max(var(--ae-outside-gap-size),var(--ae-inside-padding-size));--ae-icon-color:hsl(27deg 99% 50%);--ae-icon-hover-color:hsl(196deg 98% 18%);--ae-icon-size:22px;--ae-nav-bg-color:hsl(197deg 97% 14%);--ae-nav-color:hsl(70deg 69% 54%);--ae-nav-hover-color:hsl(197deg 97% 14%);--ae-input-color:hsl(70deg 69% 54%);--ae-label-color:hsl(70deg 69% 54%);--ae-subgroup-input-color:hsl(27deg 99% 50%);--ae-placeholder-color:hsl(166deg 62% 33%);--ae-text-color:hsl(185deg 66% 85%);--ae-mobile-outside-gap-size:3px;--ae-mobile-inside-padding-size:3px;--ae-frame-bg-color:hsl(197deg 98% 16%);--ae-modal-bg-color:hsl(197deg 97% 14%);--ae-modal-icon-color:hsl(27deg 99% 50%);
extensions-builtin/sd_theme_editor/themes/tron.css ADDED
@@ -0,0 +1 @@
+ --ae-main-bg-color:hsl(185deg 75% 3%);--ae-primary-color:hsl(182deg 95% 51%);--ae-input-bg-color:hsl(185deg 73% 3%);--ae-input-border-color:hsl(185deg 72% 25%);--ae-panel-bg-color:hsl(180deg 76% 5%);--ae-panel-border-color:hsl(185deg 72% 25%);--ae-panel-border-radius:0px;--ae-subgroup-bg-color:hsl(185deg 73% 3%);--ae-subgroup-input-bg-color:hsl(185deg 73% 3%);--ae-subgroup-input-border-color:hsl(185deg 72% 25%);--ae-subpanel-bg-color:hsl(185deg 73% 3%);--ae-subpanel-border-color:hsl(185deg 72% 25%);--ae-subpanel-border-radius:0px;--ae-textarea-focus-color:hsl(182deg 95% 51%);--ae-input-focus-color:hsl(182deg 95% 51%);--ae-outside-gap-size:2px;--ae-inside-padding-size:8px;--ae-tool-button-size:34px;--ae-tool-button-radius:16px;--ae-generate-button-height:70px;--ae-cancel-color:hsl(182deg 95% 51%);--ae-max-padding:max(var(--ae-outside-gap-size),var(--ae-inside-padding-size));--ae-icon-color:hsl(182deg 95% 51%);--ae-icon-hover-color:hsl(185deg 73% 3%);--ae-icon-size:22px;--ae-nav-bg-color:hsl(185deg 73% 3%);--ae-nav-color:hsl(182deg 95% 75%);--ae-nav-hover-color:hsl(0deg 100% 50%);--ae-input-color:hsl(182deg 95% 75%);--ae-label-color:hsl(182deg 95% 75%);--ae-subgroup-input-color:hsl(182deg 95% 51%);--ae-placeholder-color:hsl(185deg 72% 25%);--ae-text-color:hsl(182deg 95% 75%);--ae-mobile-outside-gap-size:2px;--ae-mobile-inside-padding-size:8px;--ae-frame-bg-color:hsl(180deg 76% 5%);--ae-modal-bg-color:hsl(185deg 73% 3%);--ae-modal-icon-color:hsl(182deg 95% 51%);
extensions-builtin/sd_theme_editor/themes/tron2.css ADDED
@@ -0,0 +1 @@
+ --ae-main-bg-color:hsl(185deg 75% 3%);--ae-primary-color:hsl(182deg 95% 51%);--ae-input-bg-color:hsl(184deg 89% 7%);--ae-input-border-color:hsl(185deg 72% 25%);--ae-panel-bg-color:hsl(185deg 73% 3%);--ae-panel-border-color:hsl(185deg 72% 25%);--ae-panel-border-radius:0px;--ae-subgroup-bg-color:hsl(185deg 73% 3%);--ae-subgroup-input-bg-color:hsl(185deg 73% 3%);--ae-subgroup-input-border-color:hsl(185deg 72% 25%);--ae-subpanel-bg-color:hsl(185deg 73% 3%);--ae-subpanel-border-color:hsl(185deg 72% 25%);--ae-subpanel-border-radius:0px;--ae-textarea-focus-color:hsl(182deg 95% 51%);--ae-input-focus-color:hsl(182deg 95% 51%);--ae-outside-gap-size:2px;--ae-inside-padding-size:8px;--ae-tool-button-size:34px;--ae-tool-button-radius:16px;--ae-generate-button-height:70px;--ae-cancel-color:hsl(182deg 95% 51%);--ae-max-padding:max(var(--ae-outside-gap-size),var(--ae-inside-padding-size));--ae-icon-color:hsl(182deg 95% 51%);--ae-icon-hover-color:hsl(185deg 73% 3%);--ae-icon-size:22px;--ae-nav-bg-color:hsl(185deg 73% 3%);--ae-nav-color:hsl(182deg 95% 75%);--ae-nav-hover-color:hsl(0deg 100% 50%);--ae-input-color:hsl(182deg 95% 75%);--ae-label-color:hsl(182deg 95% 75%);--ae-subgroup-input-color:hsl(182deg 95% 51%);--ae-placeholder-color:hsl(185deg 72% 25%);--ae-text-color:hsl(182deg 95% 75%);--ae-mobile-outside-gap-size:2px;--ae-mobile-inside-padding-size:8px;--ae-frame-bg-color:hsl(185deg 73% 3%);
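
Each theme file above is a single line of `--ae-*` CSS custom-property declarations. A small illustrative Python sketch (not the extension's actual loader) of how such a line can be parsed into a name-to-value mapping and serialized back:

```python
# Illustrative helper: parse a one-line theme file of "--name:value;"
# declarations into a dict, and write it back out in the same format.
from pathlib import Path

def parse_theme(path):
    text = Path(path).read_text(encoding="utf-8").strip()
    pairs = {}
    for decl in text.split(";"):
        decl = decl.strip()
        if not decl:
            continue
        # Split on the first colon only, so hsl()/calc() values stay intact.
        name, _, value = decl.partition(":")
        pairs[name.strip()] = value.strip()
    return pairs

def serialize_theme(pairs):
    return "".join(f"{name}:{value};" for name, value in pairs.items())

# Example usage with one of the shipped themes:
# theme = parse_theme("extensions-builtin/sd_theme_editor/themes/tron.css")
# print(theme["--ae-primary-color"])   # -> "hsl(182deg 95% 51%)"
```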
html/200w.webp ADDED
html/card-no-preview.png ADDED
html/extra-networks-card.html ADDED
@@ -0,0 +1,18 @@
+ <div class='card-container'>
+   <div class='card' onclick={card_clicked} data-name="{name}" {sort_keys}>
+     <img src={preview_image} loading="lazy">
+     {metadata_button}
+     {edit_button}
+     <div class='actions'>
+       <div class='additional'>
+         <ul>
+           <a href="#" title="replace preview image with currently selected in gallery" onclick={save_card_preview}>replace preview</a>
+         </ul>
+         <span style="display:none" class='search_term{search_only}'>{search_term}</span>
+       </div>
+       <span class='name'>{name}</span>
+     </div>
+     <span class='description'>{description}</span>
+   </div>
+   <i class="image-icon"></i>
+ </div>
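
The `{name}`, `{preview_image}`, `{sort_keys}` and similar tokens in this template are placeholders filled in on the Python side before the card is rendered. A hedged sketch of that substitution using `str.format`; the real webui code may assemble these arguments differently, and the values below are invented examples:

```python
# Hedged sketch: fill the card template's {placeholders} with str.format.
# The keys mirror the tokens in the HTML above; all values are illustrative.
from pathlib import Path

card_template = Path("html/extra-networks-card.html").read_text(encoding="utf-8")

card_html = card_template.format(
    card_clicked="cardClicked(event)",           # JS handler, illustrative
    save_card_preview="saveCardPreview(event)",  # JS handler, illustrative
    name="my-lora",
    sort_keys='data-sort-name="my-lora"',
    preview_image='"./file=models/Lora/my-lora.preview.png"',
    metadata_button="",
    edit_button="",
    search_only="",
    search_term="models/Lora/my-lora.safetensors",
    description="An example card",
)
```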
html/extra-networks-no-cards.html ADDED
@@ -0,0 +1,8 @@
+ <div class='nocards'>
+   <h1>Nothing here. Add some content to the following directories:</h1>
+ 
+   <ul>
+     {dirs}
+   </ul>
+ </div>
+ 
html/favicon.ico ADDED
html/footer.html ADDED
@@ -0,0 +1,55 @@
+ <div class="footer-wrapper">
+   <ul class="footer-links">
+     <li>
+       <a href="/docs" data-tooltip="Use via API" data-position="top" class="top">
+         <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" width="24" height="24"><path fill="none" d="M0 0h24v24H0z"/><path d="M6.2,9.6c0.5-0.7,1-1.4,1.6-2c3.1-3.1,7.5-4,11.1-2.5c1.4,3.6,0.6,8-2.5,11.1c-0.6,0.6-1.3,1.1-2,1.6l0.1,2.3c0,0.2-0.1,0.4-0.3,0.4l-4,1c-0.2,0.1-0.5-0.1-0.5-0.3c0,0,0-0.1,0-0.1v-2.7c0-0.2-0.1-0.4-0.3-0.6l-3.2-3.2c-0.2-0.2-0.4-0.3-0.6-0.3H2.9c-0.2,0-0.4-0.2-0.4-0.4c0,0,0-0.1,0-0.1l1-4c0.1-0.2,0.2-0.3,0.4-0.3C3.9,9.5,6.2,9.6,6.2,9.6zM12.1,11.9c0.7,0.7,1.8,0.7,2.4,0c0.7-0.7,0.7-1.8,0-2.4c-0.7-0.7-1.8-0.7-2.4,0C11.4,10.2,11.4,11.3,12.1,11.9z"/></svg>
+       </a>
+     </li>
+     <li>
+       <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui" data-tooltip="Github" data-position="top" class="top">
+         <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" width="24" height="24"><path fill="none" d="M0 0h24v24H0z"/><path d="M12 2C6.475 2 2 6.475 2 12a9.994 9.994 0 0 0 6.838 9.488c.5.087.687-.213.687-.476 0-.237-.013-1.024-.013-1.862-2.512.463-3.162-.612-3.362-1.175-.113-.288-.6-1.175-1.025-1.413-.35-.187-.85-.65-.013-.662.788-.013 1.35.725 1.538 1.025.9 1.512 2.338 1.087 2.912.825.088-.65.35-1.087.638-1.337-2.225-.25-4.55-1.113-4.55-4.938 0-1.088.387-1.987 1.025-2.688-.1-.25-.45-1.275.1-2.65 0 0 .837-.262 2.75 1.026a9.28 9.28 0 0 1 2.5-.338c.85 0 1.7.112 2.5.337 1.912-1.3 2.75-1.024 2.75-1.024.55 1.375.2 2.4.1 2.65.637.7 1.025 1.587 1.025 2.687 0 3.838-2.337 4.688-4.562 4.938.362.312.675.912.675 1.85 0 1.337-.013 2.412-.013 2.75 0 .262.188.574.688.474A10.016 10.016 0 0 0 22 12c0-5.525-4.475-10-10-10z"/></svg></a>
+     </li>
+     <li>
+       <a href="https://gradio.app" data-tooltip="Gradio" data-position="top" class="top">
+         <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" width="24" height="24"><path d="M0,0H24V24H0V0Z" style="fill: none;"/><path d="M21.14,10.92h0s0-2.91,0-2.91L12.04,3.02v2.91l6.45,3.54-1.95,1.07-4.5-2.47v2.91l1.85,1.01-1.88,1.03-1.85-1.02,1.89-1.03v-2.93l-4.56,2.49-1.94-1.06,6.5-3.55V3L2.86,8.01h0s0,0,0,0v2.93l1.94,1.06-1.94,1.06h0s0,0,0,0v2.93l9.14,5.01,9.14-5.01v-.02h0s0-2.91,0-2.91l-1.94-1.06,1.93-1.06v-.02Zm-2.65,3.6l-6.49,3.56-6.46-3.54,1.94-1.06,4.53,2.48h0s4.55-2.5,4.55-2.5l1.93,1.06Z"/></svg>
+       </a>
+     </li>
+     <li>
+       <a href="https://github.com/anapnoe/stable-diffusion-webui-ux" data-tooltip="UI/UX design by anapnoe" data-position="top" class="top">
+         <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" width="24" height="24">
+           <polygon points="19.5,13.7 21.6,9.9 11.8,9.9 16.7,18.5 17.7,16.7 17.1,15.7 16.7,16.5 13.5,10.9 19.9,10.9 18.9,12.7 "/>
+           <polygon points="17.3,11.7 15.8,11.7 20.2,19.3 3.8,19.3 12,5.2 14.2,9 15.8,9 12,2.5 1.5,20.7 22.5,20.7 "/>
+         </svg>
+       </a>
+     </li>
+     <li>
+       <a href="/" onclick="javascript:gradioApp().getElementById('settings_restart_gradio').click(); return false" data-tooltip="Reload UI" data-position="top" class="top">
+         <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" width="24" height="24"><path fill="none" d="M0 0h24v24H0z"/><path d="M12 22C6.477 22 2 17.523 2 12S6.477 2 12 2s10 4.477 10 10-4.477 10-10 10zm4.82-4.924a7 7 0 1 0-1.852 1.266l-.975-1.755A5 5 0 1 1 17 12h-3l2.82 5.076z"/></svg>
+       </a>
+     </li>
+     <li>
+       <div class="tooltip-html">
+         <i class="icon-info">
+           <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" width="24" height="24"><path fill="none" d="M0 0h24v24H0z"/><path d="M13 18v2h4v2H7v-2h4v-2H2.992A.998.998 0 0 1 2 16.993V4.007C2 3.451 2.455 3 2.992 3h18.016c.548 0 .992.449.992 1.007v12.986c0 .556-.455 1.007-.992 1.007H13z"/></svg>
+         </i>
+         <div class="top center">
+           {versions}
+           <i></i>
+         </div>
+       </div>
+     </li>
+     <li class="coffee-circle">
+       <div class="tooltip-html">
+         <i class="coffee">
+           <a href="https://buymeacoffee.com/dayanbayah">
+             <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" width="24" height="24"><path fill="red" d="M12.001 4.529c2.349-2.109 5.979-2.039 8.242.228 2.262 2.268 2.34 5.88.236 8.236l-8.48 8.492-8.478-8.492c-2.104-2.356-2.025-5.974.236-8.236 2.265-2.264 5.888-2.34 8.244-.228z"/></svg>
+           </a>
+         </i>
+         <div class="top">
+           <img src="./file=html/200w.webp" width="160" height="90" alt="thanks for your support">
+           <i></i>
+           <p style="color:var(--ae-primary-color);">Your donation can help fund the continued development of the Stable Diffusion web UI/UX project. Enjoy!</p>
+         </div>
+       </div>
+     </li>
+   </ul>
+ </div>