toto10 committed
Commit 2462654
1 Parent(s): 2c93839

7bd1acf8301721733c956a62ee480dc281b7213caae348c6d8690ec8224ac24f

Files changed (50)
  1. repositories/stable-diffusion-stability-ai/ldm/modules/karlo/kakao/template.py +141 -0
  2. repositories/stable-diffusion-stability-ai/ldm/modules/midas/__init__.py +0 -0
  3. repositories/stable-diffusion-stability-ai/ldm/modules/midas/__pycache__/__init__.cpython-310.pyc +0 -0
  4. repositories/stable-diffusion-stability-ai/ldm/modules/midas/__pycache__/api.cpython-310.pyc +0 -0
  5. repositories/stable-diffusion-stability-ai/ldm/modules/midas/api.py +170 -0
  6. repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/__init__.py +0 -0
  7. repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/__pycache__/__init__.cpython-310.pyc +0 -0
  8. repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/__pycache__/base_model.cpython-310.pyc +0 -0
  9. repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/__pycache__/blocks.cpython-310.pyc +0 -0
  10. repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/__pycache__/dpt_depth.cpython-310.pyc +0 -0
  11. repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/__pycache__/midas_net.cpython-310.pyc +0 -0
  12. repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/__pycache__/midas_net_custom.cpython-310.pyc +0 -0
  13. repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/__pycache__/transforms.cpython-310.pyc +0 -0
  14. repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/__pycache__/vit.cpython-310.pyc +0 -0
  15. repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/base_model.py +16 -0
  16. repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/blocks.py +342 -0
  17. repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/dpt_depth.py +109 -0
  18. repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/midas_net.py +76 -0
  19. repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/midas_net_custom.py +128 -0
  20. repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/transforms.py +234 -0
  21. repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/vit.py +491 -0
  22. repositories/stable-diffusion-stability-ai/ldm/modules/midas/utils.py +189 -0
  23. repositories/stable-diffusion-stability-ai/ldm/util.py +207 -0
  24. repositories/stable-diffusion-stability-ai/modelcard.md +153 -0
  25. repositories/stable-diffusion-stability-ai/requirements.txt +19 -0
  26. repositories/stable-diffusion-stability-ai/scripts/gradio/depth2img.py +184 -0
  27. repositories/stable-diffusion-stability-ai/scripts/gradio/inpainting.py +195 -0
  28. repositories/stable-diffusion-stability-ai/scripts/gradio/superresolution.py +197 -0
  29. repositories/stable-diffusion-stability-ai/scripts/img2img.py +279 -0
  30. repositories/stable-diffusion-stability-ai/scripts/streamlit/depth2img.py +157 -0
  31. repositories/stable-diffusion-stability-ai/scripts/streamlit/inpainting.py +195 -0
  32. repositories/stable-diffusion-stability-ai/scripts/streamlit/stableunclip.py +416 -0
  33. repositories/stable-diffusion-stability-ai/scripts/streamlit/superresolution.py +170 -0
  34. repositories/stable-diffusion-stability-ai/scripts/tests/test_watermark.py +18 -0
  35. repositories/stable-diffusion-stability-ai/scripts/txt2img.py +388 -0
  36. repositories/stable-diffusion-stability-ai/setup.py +13 -0
  37. requirements-test.txt +3 -0
  38. requirements.txt +33 -0
  39. requirements_versions.txt +31 -0
  40. screenshot.png +0 -0
  41. script.js +163 -0
  42. scripts/__pycache__/custom_code.cpython-310.pyc +0 -0
  43. scripts/__pycache__/img2imgalt.cpython-310.pyc +0 -0
  44. scripts/__pycache__/loopback.cpython-310.pyc +0 -0
  45. scripts/__pycache__/outpainting_mk_2.cpython-310.pyc +0 -0
  46. scripts/__pycache__/poor_mans_outpainting.cpython-310.pyc +0 -0
  47. scripts/__pycache__/postprocessing_codeformer.cpython-310.pyc +0 -0
  48. scripts/__pycache__/postprocessing_gfpgan.cpython-310.pyc +0 -0
  49. scripts/__pycache__/postprocessing_upscale.cpython-310.pyc +0 -0
  50. scripts/__pycache__/prompt_matrix.cpython-310.pyc +0 -0
repositories/stable-diffusion-stability-ai/ldm/modules/karlo/kakao/template.py ADDED
@@ -0,0 +1,141 @@
1
+ # ------------------------------------------------------------------------------------
2
+ # Karlo-v1.0.alpha
3
+ # Copyright (c) 2022 KakaoBrain. All Rights Reserved.
4
+ # ------------------------------------------------------------------------------------
5
+
6
+ import os
7
+ import logging
8
+ import torch
9
+
10
+ from omegaconf import OmegaConf
11
+
12
+ from ldm.modules.karlo.kakao.models.clip import CustomizedCLIP, CustomizedTokenizer
13
+ from ldm.modules.karlo.kakao.models.prior_model import PriorDiffusionModel
14
+ from ldm.modules.karlo.kakao.models.decoder_model import Text2ImProgressiveModel
15
+ from ldm.modules.karlo.kakao.models.sr_64_256 import ImprovedSupRes64to256ProgressiveModel
16
+
17
+
18
+ SAMPLING_CONF = {
19
+ "default": {
20
+ "prior_sm": "25",
21
+ "prior_n_samples": 1,
22
+ "prior_cf_scale": 4.0,
23
+ "decoder_sm": "50",
24
+ "decoder_cf_scale": 8.0,
25
+ "sr_sm": "7",
26
+ },
27
+ "fast": {
28
+ "prior_sm": "25",
29
+ "prior_n_samples": 1,
30
+ "prior_cf_scale": 4.0,
31
+ "decoder_sm": "25",
32
+ "decoder_cf_scale": 8.0,
33
+ "sr_sm": "7",
34
+ },
35
+ }
36
+
37
+ CKPT_PATH = {
38
+ "prior": "prior-ckpt-step=01000000-of-01000000.ckpt",
39
+ "decoder": "decoder-ckpt-step=01000000-of-01000000.ckpt",
40
+ "sr_256": "improved-sr-ckpt-step=1.2M.ckpt",
41
+ }
42
+
43
+
44
+ class BaseSampler:
45
+ _PRIOR_CLASS = PriorDiffusionModel
46
+ _DECODER_CLASS = Text2ImProgressiveModel
47
+ _SR256_CLASS = ImprovedSupRes64to256ProgressiveModel
48
+
49
+ def __init__(
50
+ self,
51
+ root_dir: str,
52
+ sampling_type: str = "fast",
53
+ ):
54
+ self._root_dir = root_dir
55
+
56
+ sampling_type = SAMPLING_CONF[sampling_type]
57
+ self._prior_sm = sampling_type["prior_sm"]
58
+ self._prior_n_samples = sampling_type["prior_n_samples"]
59
+ self._prior_cf_scale = sampling_type["prior_cf_scale"]
60
+
61
+ assert self._prior_n_samples == 1
62
+
63
+ self._decoder_sm = sampling_type["decoder_sm"]
64
+ self._decoder_cf_scale = sampling_type["decoder_cf_scale"]
65
+
66
+ self._sr_sm = sampling_type["sr_sm"]
67
+
68
+ def __repr__(self):
69
+ line = ""
70
+ line += f"Prior, sampling method: {self._prior_sm}, cf_scale: {self._prior_cf_scale}\n"
71
+ line += f"Decoder, sampling method: {self._decoder_sm}, cf_scale: {self._decoder_cf_scale}\n"
72
+ line += f"SR(64->256), sampling method: {self._sr_sm}"
73
+
74
+ return line
75
+
76
+ def load_clip(self, clip_path: str):
77
+ clip = CustomizedCLIP.load_from_checkpoint(
78
+ os.path.join(self._root_dir, clip_path)
79
+ )
80
+ clip = torch.jit.script(clip)
81
+ clip.cuda()
82
+ clip.eval()
83
+
84
+ self._clip = clip
85
+ self._tokenizer = CustomizedTokenizer()
86
+
87
+ def load_prior(
88
+ self,
89
+ ckpt_path: str,
90
+ clip_stat_path: str,
91
+ prior_config: str = "configs/prior_1B_vit_l.yaml"
92
+ ):
93
+ logging.info(f"Loading prior: {ckpt_path}")
94
+
95
+ config = OmegaConf.load(prior_config)
96
+ clip_mean, clip_std = torch.load(
97
+ os.path.join(self._root_dir, clip_stat_path), map_location="cpu"
98
+ )
99
+
100
+ prior = self._PRIOR_CLASS.load_from_checkpoint(
101
+ config,
102
+ self._tokenizer,
103
+ clip_mean,
104
+ clip_std,
105
+ os.path.join(self._root_dir, ckpt_path),
106
+ strict=True,
107
+ )
108
+ prior.cuda()
109
+ prior.eval()
110
+ logging.info("done.")
111
+
112
+ self._prior = prior
113
+
114
+ def load_decoder(self, ckpt_path: str, decoder_config: str = "configs/decoder_900M_vit_l.yaml"):
115
+ logging.info(f"Loading decoder: {ckpt_path}")
116
+
117
+ config = OmegaConf.load(decoder_config)
118
+ decoder = self._DECODER_CLASS.load_from_checkpoint(
119
+ config,
120
+ self._tokenizer,
121
+ os.path.join(self._root_dir, ckpt_path),
122
+ strict=True,
123
+ )
124
+ decoder.cuda()
125
+ decoder.eval()
126
+ logging.info("done.")
127
+
128
+ self._decoder = decoder
129
+
130
+ def load_sr_64_256(self, ckpt_path: str, sr_config: str = "configs/improved_sr_64_256_1.4B.yaml"):
131
+ logging.info(f"Loading SR(64->256): {ckpt_path}")
132
+
133
+ config = OmegaConf.load(sr_config)
134
+ sr = self._SR256_CLASS.load_from_checkpoint(
135
+ config, os.path.join(self._root_dir, ckpt_path), strict=True
136
+ )
137
+ sr.cuda()
138
+ sr.eval()
139
+ logging.info("done.")
140
+
141
+ self._sr_64_256 = sr
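For orientation, a rough driver sketch for the BaseSampler class added above. The call order is taken from the code itself (load_clip must run before load_prior/load_decoder because they reuse self._tokenizer); the root_dir layout and the CLIP/CLIP-stat filenames are assumptions, and nothing below runs without the actual Karlo checkpoints.

# Hypothetical usage of BaseSampler; only the CKPT_PATH entries come from the code above,
# the other filenames are placeholders.
from ldm.modules.karlo.kakao.template import BaseSampler, CKPT_PATH

sampler = BaseSampler(root_dir="checkpoints/karlo", sampling_type="fast")
sampler.load_clip("clip.ckpt")                                          # also builds the tokenizer
sampler.load_prior(CKPT_PATH["prior"], clip_stat_path="clip_stats.th")  # stat filename assumed
sampler.load_decoder(CKPT_PATH["decoder"])
sampler.load_sr_64_256(CKPT_PATH["sr_256"])
print(sampler)                                                          # __repr__ summarizes the sampling setup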
repositories/stable-diffusion-stability-ai/ldm/modules/midas/__init__.py ADDED
File without changes
repositories/stable-diffusion-stability-ai/ldm/modules/midas/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (189 Bytes).
repositories/stable-diffusion-stability-ai/ldm/modules/midas/__pycache__/api.cpython-310.pyc ADDED
Binary file (3.63 kB).
repositories/stable-diffusion-stability-ai/ldm/modules/midas/api.py ADDED
@@ -0,0 +1,170 @@
1
+ # based on https://github.com/isl-org/MiDaS
2
+
3
+ import cv2
4
+ import torch
5
+ import torch.nn as nn
6
+ from torchvision.transforms import Compose
7
+
8
+ from ldm.modules.midas.midas.dpt_depth import DPTDepthModel
9
+ from ldm.modules.midas.midas.midas_net import MidasNet
10
+ from ldm.modules.midas.midas.midas_net_custom import MidasNet_small
11
+ from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet
12
+
13
+
14
+ ISL_PATHS = {
15
+ "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt",
16
+ "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt",
17
+ "midas_v21": "",
18
+ "midas_v21_small": "",
19
+ }
20
+
21
+
22
+ def disabled_train(self, mode=True):
23
+ """Overwrite model.train with this function to make sure train/eval mode
24
+ does not change anymore."""
25
+ return self
26
+
27
+
28
+ def load_midas_transform(model_type):
29
+ # https://github.com/isl-org/MiDaS/blob/master/run.py
30
+ # load transform only
31
+ if model_type == "dpt_large": # DPT-Large
32
+ net_w, net_h = 384, 384
33
+ resize_mode = "minimal"
34
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
35
+
36
+ elif model_type == "dpt_hybrid": # DPT-Hybrid
37
+ net_w, net_h = 384, 384
38
+ resize_mode = "minimal"
39
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
40
+
41
+ elif model_type == "midas_v21":
42
+ net_w, net_h = 384, 384
43
+ resize_mode = "upper_bound"
44
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
45
+
46
+ elif model_type == "midas_v21_small":
47
+ net_w, net_h = 256, 256
48
+ resize_mode = "upper_bound"
49
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
50
+
51
+ else:
52
+ assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
53
+
54
+ transform = Compose(
55
+ [
56
+ Resize(
57
+ net_w,
58
+ net_h,
59
+ resize_target=None,
60
+ keep_aspect_ratio=True,
61
+ ensure_multiple_of=32,
62
+ resize_method=resize_mode,
63
+ image_interpolation_method=cv2.INTER_CUBIC,
64
+ ),
65
+ normalization,
66
+ PrepareForNet(),
67
+ ]
68
+ )
69
+
70
+ return transform
71
+
72
+
73
+ def load_model(model_type):
74
+ # https://github.com/isl-org/MiDaS/blob/master/run.py
75
+ # load network
76
+ model_path = ISL_PATHS[model_type]
77
+ if model_type == "dpt_large": # DPT-Large
78
+ model = DPTDepthModel(
79
+ path=model_path,
80
+ backbone="vitl16_384",
81
+ non_negative=True,
82
+ )
83
+ net_w, net_h = 384, 384
84
+ resize_mode = "minimal"
85
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
86
+
87
+ elif model_type == "dpt_hybrid": # DPT-Hybrid
88
+ model = DPTDepthModel(
89
+ path=model_path,
90
+ backbone="vitb_rn50_384",
91
+ non_negative=True,
92
+ )
93
+ net_w, net_h = 384, 384
94
+ resize_mode = "minimal"
95
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
96
+
97
+ elif model_type == "midas_v21":
98
+ model = MidasNet(model_path, non_negative=True)
99
+ net_w, net_h = 384, 384
100
+ resize_mode = "upper_bound"
101
+ normalization = NormalizeImage(
102
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
103
+ )
104
+
105
+ elif model_type == "midas_v21_small":
106
+ model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
107
+ non_negative=True, blocks={'expand': True})
108
+ net_w, net_h = 256, 256
109
+ resize_mode = "upper_bound"
110
+ normalization = NormalizeImage(
111
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
112
+ )
113
+
114
+ else:
115
+ print(f"model_type '{model_type}' not implemented, use: --model_type large")
116
+ assert False
117
+
118
+ transform = Compose(
119
+ [
120
+ Resize(
121
+ net_w,
122
+ net_h,
123
+ resize_target=None,
124
+ keep_aspect_ratio=True,
125
+ ensure_multiple_of=32,
126
+ resize_method=resize_mode,
127
+ image_interpolation_method=cv2.INTER_CUBIC,
128
+ ),
129
+ normalization,
130
+ PrepareForNet(),
131
+ ]
132
+ )
133
+
134
+ return model.eval(), transform
135
+
136
+
137
+ class MiDaSInference(nn.Module):
138
+ MODEL_TYPES_TORCH_HUB = [
139
+ "DPT_Large",
140
+ "DPT_Hybrid",
141
+ "MiDaS_small"
142
+ ]
143
+ MODEL_TYPES_ISL = [
144
+ "dpt_large",
145
+ "dpt_hybrid",
146
+ "midas_v21",
147
+ "midas_v21_small",
148
+ ]
149
+
150
+ def __init__(self, model_type):
151
+ super().__init__()
152
+ assert (model_type in self.MODEL_TYPES_ISL)
153
+ model, _ = load_model(model_type)
154
+ self.model = model
155
+ self.model.train = disabled_train
156
+
157
+ def forward(self, x):
158
+ # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array
159
+ # NOTE: we expect that the correct transform has been called during dataloading.
160
+ with torch.no_grad():
161
+ prediction = self.model(x)
162
+ prediction = torch.nn.functional.interpolate(
163
+ prediction.unsqueeze(1),
164
+ size=x.shape[2:],
165
+ mode="bicubic",
166
+ align_corners=False,
167
+ )
168
+ assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3])
169
+ return prediction
170
+
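As a light sanity check of the API above, the preprocessing pipeline can be built and applied without loading any depth weights; a sketch assuming the repository root is on sys.path and timm, opencv-python and numpy are installed (MiDaSInference itself additionally needs the checkpoint listed in ISL_PATHS):

# Build only the dpt_hybrid transform added above and run it on random data.
import numpy as np
from ldm.modules.midas.api import load_midas_transform

transform = load_midas_transform("dpt_hybrid")
img = np.random.rand(480, 640, 3)            # HWC image scaled to 0..1, as the comment in forward() expects
sample = transform({"image": img})
print(sample["image"].shape)                 # CHW float32, each side resized to a multiple of 32, e.g. (3, 384, 512)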
repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/__init__.py ADDED
File without changes
repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (195 Bytes).
repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/__pycache__/base_model.cpython-310.pyc ADDED
Binary file (723 Bytes).
repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/__pycache__/blocks.cpython-310.pyc ADDED
Binary file (7.24 kB).
repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/__pycache__/dpt_depth.cpython-310.pyc ADDED
Binary file (2.95 kB).
repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/__pycache__/midas_net.cpython-310.pyc ADDED
Binary file (2.63 kB).
repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/__pycache__/midas_net_custom.cpython-310.pyc ADDED
Binary file (3.75 kB).
repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/__pycache__/transforms.cpython-310.pyc ADDED
Binary file (5.71 kB).
repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/__pycache__/vit.cpython-310.pyc ADDED
Binary file (9.4 kB).
repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/base_model.py ADDED
@@ -0,0 +1,16 @@
1
+ import torch
2
+
3
+
4
+ class BaseModel(torch.nn.Module):
5
+ def load(self, path):
6
+ """Load model from file.
7
+
8
+ Args:
9
+ path (str): file path
10
+ """
11
+ parameters = torch.load(path, map_location=torch.device('cpu'))
12
+
13
+ if "optimizer" in parameters:
14
+ parameters = parameters["model"]
15
+
16
+ self.load_state_dict(parameters)
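load() above accepts either a bare state_dict or a training checkpoint that nests the weights under "model" next to an "optimizer" entry. A small self-contained sketch (TinyModel and the filenames are made up for illustration; only torch and the repository root on sys.path are needed):

import torch
from ldm.modules.midas.midas.base_model import BaseModel

class TinyModel(BaseModel):                       # illustrative subclass, not part of the repository
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

m = TinyModel()
torch.save({"model": m.state_dict(), "optimizer": {}}, "tiny_train.ckpt")
m.load("tiny_train.ckpt")                         # "optimizer" present -> unwraps parameters["model"]
torch.save(m.state_dict(), "tiny_plain.ckpt")
m.load("tiny_plain.ckpt")                         # a plain state_dict is loaded directly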
repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/blocks.py ADDED
@@ -0,0 +1,342 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .vit import (
5
+ _make_pretrained_vitb_rn50_384,
6
+ _make_pretrained_vitl16_384,
7
+ _make_pretrained_vitb16_384,
8
+ forward_vit,
9
+ )
10
+
11
+ def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
12
+ if backbone == "vitl16_384":
13
+ pretrained = _make_pretrained_vitl16_384(
14
+ use_pretrained, hooks=hooks, use_readout=use_readout
15
+ )
16
+ scratch = _make_scratch(
17
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
18
+ ) # ViT-L/16 - 85.0% Top1 (backbone)
19
+ elif backbone == "vitb_rn50_384":
20
+ pretrained = _make_pretrained_vitb_rn50_384(
21
+ use_pretrained,
22
+ hooks=hooks,
23
+ use_vit_only=use_vit_only,
24
+ use_readout=use_readout,
25
+ )
26
+ scratch = _make_scratch(
27
+ [256, 512, 768, 768], features, groups=groups, expand=expand
28
+ ) # ViT-Hybrid (ResNet-50 + ViT-B/16) backbone
29
+ elif backbone == "vitb16_384":
30
+ pretrained = _make_pretrained_vitb16_384(
31
+ use_pretrained, hooks=hooks, use_readout=use_readout
32
+ )
33
+ scratch = _make_scratch(
34
+ [96, 192, 384, 768], features, groups=groups, expand=expand
35
+ ) # ViT-B/16 - 84.6% Top1 (backbone)
36
+ elif backbone == "resnext101_wsl":
37
+ pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
38
+ scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # resnext101_wsl
39
+ elif backbone == "efficientnet_lite3":
40
+ pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
41
+ scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
42
+ else:
43
+ print(f"Backbone '{backbone}' not implemented")
44
+ assert False
45
+
46
+ return pretrained, scratch
47
+
48
+
49
+ def _make_scratch(in_shape, out_shape, groups=1, expand=False):
50
+ scratch = nn.Module()
51
+
52
+ out_shape1 = out_shape
53
+ out_shape2 = out_shape
54
+ out_shape3 = out_shape
55
+ out_shape4 = out_shape
56
+ if expand==True:
57
+ out_shape1 = out_shape
58
+ out_shape2 = out_shape*2
59
+ out_shape3 = out_shape*4
60
+ out_shape4 = out_shape*8
61
+
62
+ scratch.layer1_rn = nn.Conv2d(
63
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
64
+ )
65
+ scratch.layer2_rn = nn.Conv2d(
66
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
67
+ )
68
+ scratch.layer3_rn = nn.Conv2d(
69
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
70
+ )
71
+ scratch.layer4_rn = nn.Conv2d(
72
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
73
+ )
74
+
75
+ return scratch
76
+
77
+
78
+ def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
79
+ efficientnet = torch.hub.load(
80
+ "rwightman/gen-efficientnet-pytorch",
81
+ "tf_efficientnet_lite3",
82
+ pretrained=use_pretrained,
83
+ exportable=exportable
84
+ )
85
+ return _make_efficientnet_backbone(efficientnet)
86
+
87
+
88
+ def _make_efficientnet_backbone(effnet):
89
+ pretrained = nn.Module()
90
+
91
+ pretrained.layer1 = nn.Sequential(
92
+ effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
93
+ )
94
+ pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
95
+ pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
96
+ pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
97
+
98
+ return pretrained
99
+
100
+
101
+ def _make_resnet_backbone(resnet):
102
+ pretrained = nn.Module()
103
+ pretrained.layer1 = nn.Sequential(
104
+ resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
105
+ )
106
+
107
+ pretrained.layer2 = resnet.layer2
108
+ pretrained.layer3 = resnet.layer3
109
+ pretrained.layer4 = resnet.layer4
110
+
111
+ return pretrained
112
+
113
+
114
+ def _make_pretrained_resnext101_wsl(use_pretrained):
115
+ resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
116
+ return _make_resnet_backbone(resnet)
117
+
118
+
119
+
120
+ class Interpolate(nn.Module):
121
+ """Interpolation module.
122
+ """
123
+
124
+ def __init__(self, scale_factor, mode, align_corners=False):
125
+ """Init.
126
+
127
+ Args:
128
+ scale_factor (float): scaling
129
+ mode (str): interpolation mode
130
+ """
131
+ super(Interpolate, self).__init__()
132
+
133
+ self.interp = nn.functional.interpolate
134
+ self.scale_factor = scale_factor
135
+ self.mode = mode
136
+ self.align_corners = align_corners
137
+
138
+ def forward(self, x):
139
+ """Forward pass.
140
+
141
+ Args:
142
+ x (tensor): input
143
+
144
+ Returns:
145
+ tensor: interpolated data
146
+ """
147
+
148
+ x = self.interp(
149
+ x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
150
+ )
151
+
152
+ return x
153
+
154
+
155
+ class ResidualConvUnit(nn.Module):
156
+ """Residual convolution module.
157
+ """
158
+
159
+ def __init__(self, features):
160
+ """Init.
161
+
162
+ Args:
163
+ features (int): number of features
164
+ """
165
+ super().__init__()
166
+
167
+ self.conv1 = nn.Conv2d(
168
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
169
+ )
170
+
171
+ self.conv2 = nn.Conv2d(
172
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
173
+ )
174
+
175
+ self.relu = nn.ReLU(inplace=True)
176
+
177
+ def forward(self, x):
178
+ """Forward pass.
179
+
180
+ Args:
181
+ x (tensor): input
182
+
183
+ Returns:
184
+ tensor: output
185
+ """
186
+ out = self.relu(x)
187
+ out = self.conv1(out)
188
+ out = self.relu(out)
189
+ out = self.conv2(out)
190
+
191
+ return out + x
192
+
193
+
194
+ class FeatureFusionBlock(nn.Module):
195
+ """Feature fusion block.
196
+ """
197
+
198
+ def __init__(self, features):
199
+ """Init.
200
+
201
+ Args:
202
+ features (int): number of features
203
+ """
204
+ super(FeatureFusionBlock, self).__init__()
205
+
206
+ self.resConfUnit1 = ResidualConvUnit(features)
207
+ self.resConfUnit2 = ResidualConvUnit(features)
208
+
209
+ def forward(self, *xs):
210
+ """Forward pass.
211
+
212
+ Returns:
213
+ tensor: output
214
+ """
215
+ output = xs[0]
216
+
217
+ if len(xs) == 2:
218
+ output += self.resConfUnit1(xs[1])
219
+
220
+ output = self.resConfUnit2(output)
221
+
222
+ output = nn.functional.interpolate(
223
+ output, scale_factor=2, mode="bilinear", align_corners=True
224
+ )
225
+
226
+ return output
227
+
228
+
229
+
230
+
231
+ class ResidualConvUnit_custom(nn.Module):
232
+ """Residual convolution module.
233
+ """
234
+
235
+ def __init__(self, features, activation, bn):
236
+ """Init.
237
+
238
+ Args:
239
+ features (int): number of features
240
+ """
241
+ super().__init__()
242
+
243
+ self.bn = bn
244
+
245
+ self.groups=1
246
+
247
+ self.conv1 = nn.Conv2d(
248
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
249
+ )
250
+
251
+ self.conv2 = nn.Conv2d(
252
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
253
+ )
254
+
255
+ if self.bn==True:
256
+ self.bn1 = nn.BatchNorm2d(features)
257
+ self.bn2 = nn.BatchNorm2d(features)
258
+
259
+ self.activation = activation
260
+
261
+ self.skip_add = nn.quantized.FloatFunctional()
262
+
263
+ def forward(self, x):
264
+ """Forward pass.
265
+
266
+ Args:
267
+ x (tensor): input
268
+
269
+ Returns:
270
+ tensor: output
271
+ """
272
+
273
+ out = self.activation(x)
274
+ out = self.conv1(out)
275
+ if self.bn==True:
276
+ out = self.bn1(out)
277
+
278
+ out = self.activation(out)
279
+ out = self.conv2(out)
280
+ if self.bn==True:
281
+ out = self.bn2(out)
282
+
283
+ if self.groups > 1:
284
+ out = self.conv_merge(out)
285
+
286
+ return self.skip_add.add(out, x)
287
+
288
+ # return out + x
289
+
290
+
291
+ class FeatureFusionBlock_custom(nn.Module):
292
+ """Feature fusion block.
293
+ """
294
+
295
+ def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
296
+ """Init.
297
+
298
+ Args:
299
+ features (int): number of features
300
+ """
301
+ super(FeatureFusionBlock_custom, self).__init__()
302
+
303
+ self.deconv = deconv
304
+ self.align_corners = align_corners
305
+
306
+ self.groups=1
307
+
308
+ self.expand = expand
309
+ out_features = features
310
+ if self.expand==True:
311
+ out_features = features//2
312
+
313
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
314
+
315
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
316
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
317
+
318
+ self.skip_add = nn.quantized.FloatFunctional()
319
+
320
+ def forward(self, *xs):
321
+ """Forward pass.
322
+
323
+ Returns:
324
+ tensor: output
325
+ """
326
+ output = xs[0]
327
+
328
+ if len(xs) == 2:
329
+ res = self.resConfUnit1(xs[1])
330
+ output = self.skip_add.add(output, res)
331
+ # output += res
332
+
333
+ output = self.resConfUnit2(output)
334
+
335
+ output = nn.functional.interpolate(
336
+ output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
337
+ )
338
+
339
+ output = self.out_conv(output)
340
+
341
+ return output
342
+
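To illustrate the expand flag of _make_scratch above: with expand=True the four reassembly convolutions emit out_shape, 2x, 4x and 8x channels instead of a constant width. A quick shape check, assuming the repository root is on sys.path and timm is installed (blocks.py imports .vit, which imports timm):

import torch
from ldm.modules.midas.midas.blocks import _make_scratch

scratch = _make_scratch([32, 48, 136, 384], 64, expand=True)   # in_shape matching efficientnet_lite3
x = torch.randn(1, 136, 24, 24)
print(scratch.layer3_rn(x).shape)       # torch.Size([1, 256, 24, 24]) -> 64 * 4 channels
print(scratch.layer4_rn.out_channels)   # 512 -> 64 * 8 channels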
repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/dpt_depth.py ADDED
@@ -0,0 +1,109 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .base_model import BaseModel
6
+ from .blocks import (
7
+ FeatureFusionBlock,
8
+ FeatureFusionBlock_custom,
9
+ Interpolate,
10
+ _make_encoder,
11
+ forward_vit,
12
+ )
13
+
14
+
15
+ def _make_fusion_block(features, use_bn):
16
+ return FeatureFusionBlock_custom(
17
+ features,
18
+ nn.ReLU(False),
19
+ deconv=False,
20
+ bn=use_bn,
21
+ expand=False,
22
+ align_corners=True,
23
+ )
24
+
25
+
26
+ class DPT(BaseModel):
27
+ def __init__(
28
+ self,
29
+ head,
30
+ features=256,
31
+ backbone="vitb_rn50_384",
32
+ readout="project",
33
+ channels_last=False,
34
+ use_bn=False,
35
+ ):
36
+
37
+ super(DPT, self).__init__()
38
+
39
+ self.channels_last = channels_last
40
+
41
+ hooks = {
42
+ "vitb_rn50_384": [0, 1, 8, 11],
43
+ "vitb16_384": [2, 5, 8, 11],
44
+ "vitl16_384": [5, 11, 17, 23],
45
+ }
46
+
47
+ # Instantiate backbone and reassemble blocks
48
+ self.pretrained, self.scratch = _make_encoder(
49
+ backbone,
50
+ features,
51
+ False, # Set to True if you want to train from scratch, uses ImageNet weights
52
+ groups=1,
53
+ expand=False,
54
+ exportable=False,
55
+ hooks=hooks[backbone],
56
+ use_readout=readout,
57
+ )
58
+
59
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
60
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
61
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
62
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
63
+
64
+ self.scratch.output_conv = head
65
+
66
+
67
+ def forward(self, x):
68
+ if self.channels_last == True:
69
+ x.contiguous(memory_format=torch.channels_last)
70
+
71
+ layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
72
+
73
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
74
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
75
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
76
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
77
+
78
+ path_4 = self.scratch.refinenet4(layer_4_rn)
79
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
80
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
81
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
82
+
83
+ out = self.scratch.output_conv(path_1)
84
+
85
+ return out
86
+
87
+
88
+ class DPTDepthModel(DPT):
89
+ def __init__(self, path=None, non_negative=True, **kwargs):
90
+ features = kwargs["features"] if "features" in kwargs else 256
91
+
92
+ head = nn.Sequential(
93
+ nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
94
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
95
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
96
+ nn.ReLU(True),
97
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
98
+ nn.ReLU(True) if non_negative else nn.Identity(),
99
+ nn.Identity(),
100
+ )
101
+
102
+ super().__init__(head, **kwargs)
103
+
104
+ if path is not None:
105
+ self.load(path)
106
+
107
+ def forward(self, x):
108
+ return super().forward(x).squeeze(dim=1)
109
+
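The fusion blocks created by _make_fusion_block above take one or two feature maps of equal width, refine them with residual conv units, and upsample by 2x, which is how refinenet4..refinenet1 progressively rebuild resolution in DPT.forward. A minimal shape check (assumes the repository root is on sys.path and timm is installed, since the import chain reaches .vit):

import torch
from ldm.modules.midas.midas.dpt_depth import _make_fusion_block

fuse = _make_fusion_block(features=256, use_bn=False)
deep = torch.randn(1, 256, 24, 24)      # output of the previous, deeper fusion block
skip = torch.randn(1, 256, 24, 24)      # lateral feature from the encoder
print(fuse(deep, skip).shape)           # torch.Size([1, 256, 48, 48]) -- refined and upsampled x2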
repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/midas_net.py ADDED
@@ -0,0 +1,76 @@
1
+ """MidasNet: Network for monocular depth estimation trained by mixing several datasets.
2
+ This file contains code that is adapted from
3
+ https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from .base_model import BaseModel
9
+ from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
10
+
11
+
12
+ class MidasNet(BaseModel):
13
+ """Network for monocular depth estimation.
14
+ """
15
+
16
+ def __init__(self, path=None, features=256, non_negative=True):
17
+ """Init.
18
+
19
+ Args:
20
+ path (str, optional): Path to saved model. Defaults to None.
21
+ features (int, optional): Number of features. Defaults to 256.
22
+ backbone: the encoder backbone is fixed to resnext101_wsl in this implementation.
23
+ """
24
+ print("Loading weights: ", path)
25
+
26
+ super(MidasNet, self).__init__()
27
+
28
+ use_pretrained = False if path is None else True
29
+
30
+ self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
31
+
32
+ self.scratch.refinenet4 = FeatureFusionBlock(features)
33
+ self.scratch.refinenet3 = FeatureFusionBlock(features)
34
+ self.scratch.refinenet2 = FeatureFusionBlock(features)
35
+ self.scratch.refinenet1 = FeatureFusionBlock(features)
36
+
37
+ self.scratch.output_conv = nn.Sequential(
38
+ nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
39
+ Interpolate(scale_factor=2, mode="bilinear"),
40
+ nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
41
+ nn.ReLU(True),
42
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
43
+ nn.ReLU(True) if non_negative else nn.Identity(),
44
+ )
45
+
46
+ if path:
47
+ self.load(path)
48
+
49
+ def forward(self, x):
50
+ """Forward pass.
51
+
52
+ Args:
53
+ x (tensor): input data (image)
54
+
55
+ Returns:
56
+ tensor: depth
57
+ """
58
+
59
+ layer_1 = self.pretrained.layer1(x)
60
+ layer_2 = self.pretrained.layer2(layer_1)
61
+ layer_3 = self.pretrained.layer3(layer_2)
62
+ layer_4 = self.pretrained.layer4(layer_3)
63
+
64
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
65
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
66
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
67
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
68
+
69
+ path_4 = self.scratch.refinenet4(layer_4_rn)
70
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
71
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
72
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
73
+
74
+ out = self.scratch.output_conv(path_1)
75
+
76
+ return torch.squeeze(out, dim=1)
repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/midas_net_custom.py ADDED
@@ -0,0 +1,128 @@
1
+ """MidasNet: Network for monocular depth estimation trained by mixing several datasets.
2
+ This file contains code that is adapted from
3
+ https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from .base_model import BaseModel
9
+ from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
10
+
11
+
12
+ class MidasNet_small(BaseModel):
13
+ """Network for monocular depth estimation.
14
+ """
15
+
16
+ def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
17
+ blocks={'expand': True}):
18
+ """Init.
19
+
20
+ Args:
21
+ path (str, optional): Path to saved model. Defaults to None.
22
+ features (int, optional): Number of features. Defaults to 64.
23
+ backbone (str, optional): Backbone network for encoder. Defaults to efficientnet_lite3
24
+ """
25
+ print("Loading weights: ", path)
26
+
27
+ super(MidasNet_small, self).__init__()
28
+
29
+ use_pretrained = False if path else True
30
+
31
+ self.channels_last = channels_last
32
+ self.blocks = blocks
33
+ self.backbone = backbone
34
+
35
+ self.groups = 1
36
+
37
+ features1=features
38
+ features2=features
39
+ features3=features
40
+ features4=features
41
+ self.expand = False
42
+ if "expand" in self.blocks and self.blocks['expand'] == True:
43
+ self.expand = True
44
+ features1=features
45
+ features2=features*2
46
+ features3=features*4
47
+ features4=features*8
48
+
49
+ self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
50
+
51
+ self.scratch.activation = nn.ReLU(False)
52
+
53
+ self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
54
+ self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
55
+ self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
56
+ self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
57
+
58
+
59
+ self.scratch.output_conv = nn.Sequential(
60
+ nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
61
+ Interpolate(scale_factor=2, mode="bilinear"),
62
+ nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
63
+ self.scratch.activation,
64
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
65
+ nn.ReLU(True) if non_negative else nn.Identity(),
66
+ nn.Identity(),
67
+ )
68
+
69
+ if path:
70
+ self.load(path)
71
+
72
+
73
+ def forward(self, x):
74
+ """Forward pass.
75
+
76
+ Args:
77
+ x (tensor): input data (image)
78
+
79
+ Returns:
80
+ tensor: depth
81
+ """
82
+ if self.channels_last==True:
83
+ print("self.channels_last = ", self.channels_last)
84
+ x.contiguous(memory_format=torch.channels_last)
85
+
86
+
87
+ layer_1 = self.pretrained.layer1(x)
88
+ layer_2 = self.pretrained.layer2(layer_1)
89
+ layer_3 = self.pretrained.layer3(layer_2)
90
+ layer_4 = self.pretrained.layer4(layer_3)
91
+
92
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
93
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
94
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
95
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
96
+
97
+
98
+ path_4 = self.scratch.refinenet4(layer_4_rn)
99
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
100
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
101
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
102
+
103
+ out = self.scratch.output_conv(path_1)
104
+
105
+ return torch.squeeze(out, dim=1)
106
+
107
+
108
+
109
+ def fuse_model(m):
110
+ prev_previous_type = nn.Identity()
111
+ prev_previous_name = ''
112
+ previous_type = nn.Identity()
113
+ previous_name = ''
114
+ for name, module in m.named_modules():
115
+ if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
116
+ # print("FUSED ", prev_previous_name, previous_name, name)
117
+ torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
118
+ elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
119
+ # print("FUSED ", prev_previous_name, previous_name)
120
+ torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
121
+ # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
122
+ # print("FUSED ", previous_name, name)
123
+ # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
124
+
125
+ prev_previous_type = previous_type
126
+ prev_previous_name = previous_name
127
+ previous_type = type(module)
128
+ previous_name = name
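The fuse_model helper above walks named_modules() and folds Conv2d+BatchNorm2d(+ReLU) runs with torch.quantization.fuse_modules; it can be exercised on any module in eval mode. The toy Sequential below is illustrative, and assumes the repository root is on sys.path with timm installed (the import chain reaches .vit):

import torch.nn as nn
from ldm.modules.midas.midas.midas_net_custom import fuse_model

m = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Conv2d(8, 1, 1),
)
m.eval()            # Conv+BN fusion requires eval mode
fuse_model(m)
print(m)            # the Conv/BN/ReLU run collapses into one fused module; BN and ReLU become Identity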
repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/transforms.py ADDED
@@ -0,0 +1,234 @@
1
+ import numpy as np
2
+ import cv2
3
+ import math
4
+
5
+
6
+ def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
7
+ """Resize the sample to ensure the given size. Keeps aspect ratio.
8
+
9
+ Args:
10
+ sample (dict): sample
11
+ size (tuple): image size
12
+
13
+ Returns:
14
+ tuple: new size
15
+ """
16
+ shape = list(sample["disparity"].shape)
17
+
18
+ if shape[0] >= size[0] and shape[1] >= size[1]:
19
+ return sample
20
+
21
+ scale = [0, 0]
22
+ scale[0] = size[0] / shape[0]
23
+ scale[1] = size[1] / shape[1]
24
+
25
+ scale = max(scale)
26
+
27
+ shape[0] = math.ceil(scale * shape[0])
28
+ shape[1] = math.ceil(scale * shape[1])
29
+
30
+ # resize
31
+ sample["image"] = cv2.resize(
32
+ sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
33
+ )
34
+
35
+ sample["disparity"] = cv2.resize(
36
+ sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
37
+ )
38
+ sample["mask"] = cv2.resize(
39
+ sample["mask"].astype(np.float32),
40
+ tuple(shape[::-1]),
41
+ interpolation=cv2.INTER_NEAREST,
42
+ )
43
+ sample["mask"] = sample["mask"].astype(bool)
44
+
45
+ return tuple(shape)
46
+
47
+
48
+ class Resize(object):
49
+ """Resize sample to given size (width, height).
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ width,
55
+ height,
56
+ resize_target=True,
57
+ keep_aspect_ratio=False,
58
+ ensure_multiple_of=1,
59
+ resize_method="lower_bound",
60
+ image_interpolation_method=cv2.INTER_AREA,
61
+ ):
62
+ """Init.
63
+
64
+ Args:
65
+ width (int): desired output width
66
+ height (int): desired output height
67
+ resize_target (bool, optional):
68
+ True: Resize the full sample (image, mask, target).
69
+ False: Resize image only.
70
+ Defaults to True.
71
+ keep_aspect_ratio (bool, optional):
72
+ True: Keep the aspect ratio of the input sample.
73
+ Output sample might not have the given width and height, and
74
+ resize behaviour depends on the parameter 'resize_method'.
75
+ Defaults to False.
76
+ ensure_multiple_of (int, optional):
77
+ Output width and height is constrained to be multiple of this parameter.
78
+ Defaults to 1.
79
+ resize_method (str, optional):
80
+ "lower_bound": Output will be at least as large as the given size.
81
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
82
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
83
+ Defaults to "lower_bound".
84
+ """
85
+ self.__width = width
86
+ self.__height = height
87
+
88
+ self.__resize_target = resize_target
89
+ self.__keep_aspect_ratio = keep_aspect_ratio
90
+ self.__multiple_of = ensure_multiple_of
91
+ self.__resize_method = resize_method
92
+ self.__image_interpolation_method = image_interpolation_method
93
+
94
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
95
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
96
+
97
+ if max_val is not None and y > max_val:
98
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
99
+
100
+ if y < min_val:
101
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
102
+
103
+ return y
104
+
105
+ def get_size(self, width, height):
106
+ # determine new height and width
107
+ scale_height = self.__height / height
108
+ scale_width = self.__width / width
109
+
110
+ if self.__keep_aspect_ratio:
111
+ if self.__resize_method == "lower_bound":
112
+ # scale such that output size is lower bound
113
+ if scale_width > scale_height:
114
+ # fit width
115
+ scale_height = scale_width
116
+ else:
117
+ # fit height
118
+ scale_width = scale_height
119
+ elif self.__resize_method == "upper_bound":
120
+ # scale such that output size is upper bound
121
+ if scale_width < scale_height:
122
+ # fit width
123
+ scale_height = scale_width
124
+ else:
125
+ # fit height
126
+ scale_width = scale_height
127
+ elif self.__resize_method == "minimal":
128
+ # scale as little as possible
129
+ if abs(1 - scale_width) < abs(1 - scale_height):
130
+ # fit width
131
+ scale_height = scale_width
132
+ else:
133
+ # fit height
134
+ scale_width = scale_height
135
+ else:
136
+ raise ValueError(
137
+ f"resize_method {self.__resize_method} not implemented"
138
+ )
139
+
140
+ if self.__resize_method == "lower_bound":
141
+ new_height = self.constrain_to_multiple_of(
142
+ scale_height * height, min_val=self.__height
143
+ )
144
+ new_width = self.constrain_to_multiple_of(
145
+ scale_width * width, min_val=self.__width
146
+ )
147
+ elif self.__resize_method == "upper_bound":
148
+ new_height = self.constrain_to_multiple_of(
149
+ scale_height * height, max_val=self.__height
150
+ )
151
+ new_width = self.constrain_to_multiple_of(
152
+ scale_width * width, max_val=self.__width
153
+ )
154
+ elif self.__resize_method == "minimal":
155
+ new_height = self.constrain_to_multiple_of(scale_height * height)
156
+ new_width = self.constrain_to_multiple_of(scale_width * width)
157
+ else:
158
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
159
+
160
+ return (new_width, new_height)
161
+
162
+ def __call__(self, sample):
163
+ width, height = self.get_size(
164
+ sample["image"].shape[1], sample["image"].shape[0]
165
+ )
166
+
167
+ # resize sample
168
+ sample["image"] = cv2.resize(
169
+ sample["image"],
170
+ (width, height),
171
+ interpolation=self.__image_interpolation_method,
172
+ )
173
+
174
+ if self.__resize_target:
175
+ if "disparity" in sample:
176
+ sample["disparity"] = cv2.resize(
177
+ sample["disparity"],
178
+ (width, height),
179
+ interpolation=cv2.INTER_NEAREST,
180
+ )
181
+
182
+ if "depth" in sample:
183
+ sample["depth"] = cv2.resize(
184
+ sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
185
+ )
186
+
187
+ sample["mask"] = cv2.resize(
188
+ sample["mask"].astype(np.float32),
189
+ (width, height),
190
+ interpolation=cv2.INTER_NEAREST,
191
+ )
192
+ sample["mask"] = sample["mask"].astype(bool)
193
+
194
+ return sample
195
+
196
+
197
+ class NormalizeImage(object):
198
+ """Normalize image by the given mean and std.
199
+ """
200
+
201
+ def __init__(self, mean, std):
202
+ self.__mean = mean
203
+ self.__std = std
204
+
205
+ def __call__(self, sample):
206
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
207
+
208
+ return sample
209
+
210
+
211
+ class PrepareForNet(object):
212
+ """Prepare sample for usage as network input.
213
+ """
214
+
215
+ def __init__(self):
216
+ pass
217
+
218
+ def __call__(self, sample):
219
+ image = np.transpose(sample["image"], (2, 0, 1))
220
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
221
+
222
+ if "mask" in sample:
223
+ sample["mask"] = sample["mask"].astype(np.float32)
224
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
225
+
226
+ if "disparity" in sample:
227
+ disparity = sample["disparity"].astype(np.float32)
228
+ sample["disparity"] = np.ascontiguousarray(disparity)
229
+
230
+ if "depth" in sample:
231
+ depth = sample["depth"].astype(np.float32)
232
+ sample["depth"] = np.ascontiguousarray(depth)
233
+
234
+ return sample
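The three transforms above are meant to be chained, as load_midas_transform in api.py does. A standalone sketch with "upper_bound" sizing on random data (needs only numpy, opencv-python and torchvision for Compose; the repository root is assumed to be on sys.path):

import cv2
import numpy as np
from torchvision.transforms import Compose
from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet

pipeline = Compose([
    Resize(384, 384, resize_target=None, keep_aspect_ratio=True,
           ensure_multiple_of=32, resize_method="upper_bound",
           image_interpolation_method=cv2.INTER_CUBIC),
    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    PrepareForNet(),
])

out = pipeline({"image": np.random.rand(480, 640, 3)})   # HWC image in 0..1
print(out["image"].shape, out["image"].dtype)            # (3, 288, 384) float32 under upper_bound sizing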
repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/vit.py ADDED
@@ -0,0 +1,491 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import timm
4
+ import types
5
+ import math
6
+ import torch.nn.functional as F
7
+
8
+
9
+ class Slice(nn.Module):
10
+ def __init__(self, start_index=1):
11
+ super(Slice, self).__init__()
12
+ self.start_index = start_index
13
+
14
+ def forward(self, x):
15
+ return x[:, self.start_index :]
16
+
17
+
18
+ class AddReadout(nn.Module):
19
+ def __init__(self, start_index=1):
20
+ super(AddReadout, self).__init__()
21
+ self.start_index = start_index
22
+
23
+ def forward(self, x):
24
+ if self.start_index == 2:
25
+ readout = (x[:, 0] + x[:, 1]) / 2
26
+ else:
27
+ readout = x[:, 0]
28
+ return x[:, self.start_index :] + readout.unsqueeze(1)
29
+
30
+
31
+ class ProjectReadout(nn.Module):
32
+ def __init__(self, in_features, start_index=1):
33
+ super(ProjectReadout, self).__init__()
34
+ self.start_index = start_index
35
+
36
+ self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
37
+
38
+ def forward(self, x):
39
+ readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
40
+ features = torch.cat((x[:, self.start_index :], readout), -1)
41
+
42
+ return self.project(features)
43
+
44
+
45
+ class Transpose(nn.Module):
46
+ def __init__(self, dim0, dim1):
47
+ super(Transpose, self).__init__()
48
+ self.dim0 = dim0
49
+ self.dim1 = dim1
50
+
51
+ def forward(self, x):
52
+ x = x.transpose(self.dim0, self.dim1)
53
+ return x
54
+
55
+
56
+ def forward_vit(pretrained, x):
57
+ b, c, h, w = x.shape
58
+
59
+ glob = pretrained.model.forward_flex(x)
60
+
61
+ layer_1 = pretrained.activations["1"]
62
+ layer_2 = pretrained.activations["2"]
63
+ layer_3 = pretrained.activations["3"]
64
+ layer_4 = pretrained.activations["4"]
65
+
66
+ layer_1 = pretrained.act_postprocess1[0:2](layer_1)
67
+ layer_2 = pretrained.act_postprocess2[0:2](layer_2)
68
+ layer_3 = pretrained.act_postprocess3[0:2](layer_3)
69
+ layer_4 = pretrained.act_postprocess4[0:2](layer_4)
70
+
71
+ unflatten = nn.Sequential(
72
+ nn.Unflatten(
73
+ 2,
74
+ torch.Size(
75
+ [
76
+ h // pretrained.model.patch_size[1],
77
+ w // pretrained.model.patch_size[0],
78
+ ]
79
+ ),
80
+ )
81
+ )
82
+
83
+ if layer_1.ndim == 3:
84
+ layer_1 = unflatten(layer_1)
85
+ if layer_2.ndim == 3:
86
+ layer_2 = unflatten(layer_2)
87
+ if layer_3.ndim == 3:
88
+ layer_3 = unflatten(layer_3)
89
+ if layer_4.ndim == 3:
90
+ layer_4 = unflatten(layer_4)
91
+
92
+ layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
93
+ layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
94
+ layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
95
+ layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
96
+
97
+ return layer_1, layer_2, layer_3, layer_4
98
+
99
+
100
+ def _resize_pos_embed(self, posemb, gs_h, gs_w):
101
+ posemb_tok, posemb_grid = (
102
+ posemb[:, : self.start_index],
103
+ posemb[0, self.start_index :],
104
+ )
105
+
106
+ gs_old = int(math.sqrt(len(posemb_grid)))
107
+
108
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
109
+ posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
110
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
111
+
112
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
113
+
114
+ return posemb
115
+
116
+
117
+ def forward_flex(self, x):
118
+ b, c, h, w = x.shape
119
+
120
+ pos_embed = self._resize_pos_embed(
121
+ self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
122
+ )
123
+
124
+ B = x.shape[0]
125
+
126
+ if hasattr(self.patch_embed, "backbone"):
127
+ x = self.patch_embed.backbone(x)
128
+ if isinstance(x, (list, tuple)):
129
+ x = x[-1] # last feature if backbone outputs list/tuple of features
130
+
131
+ x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
132
+
133
+ if getattr(self, "dist_token", None) is not None:
134
+ cls_tokens = self.cls_token.expand(
135
+ B, -1, -1
136
+ ) # stole cls_tokens impl from Phil Wang, thanks
137
+ dist_token = self.dist_token.expand(B, -1, -1)
138
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
139
+ else:
140
+ cls_tokens = self.cls_token.expand(
141
+ B, -1, -1
142
+ ) # stole cls_tokens impl from Phil Wang, thanks
143
+ x = torch.cat((cls_tokens, x), dim=1)
144
+
145
+ x = x + pos_embed
146
+ x = self.pos_drop(x)
147
+
148
+ for blk in self.blocks:
149
+ x = blk(x)
150
+
151
+ x = self.norm(x)
152
+
153
+ return x
154
+
155
+
156
+ activations = {}
157
+
158
+
159
+ def get_activation(name):
160
+ def hook(model, input, output):
161
+ activations[name] = output
162
+
163
+ return hook
164
+
165
+
166
+ def get_readout_oper(vit_features, features, use_readout, start_index=1):
167
+ if use_readout == "ignore":
168
+ readout_oper = [Slice(start_index)] * len(features)
169
+ elif use_readout == "add":
170
+ readout_oper = [AddReadout(start_index)] * len(features)
171
+ elif use_readout == "project":
172
+ readout_oper = [
173
+ ProjectReadout(vit_features, start_index) for out_feat in features
174
+ ]
175
+ else:
176
+ assert (
177
+ False
178
+ ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
179
+
180
+ return readout_oper
181
+
182
+
183
+ def _make_vit_b16_backbone(
184
+ model,
185
+ features=[96, 192, 384, 768],
186
+ size=[384, 384],
187
+ hooks=[2, 5, 8, 11],
188
+ vit_features=768,
189
+ use_readout="ignore",
190
+ start_index=1,
191
+ ):
192
+ pretrained = nn.Module()
193
+
194
+ pretrained.model = model
195
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
196
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
197
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
198
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
199
+
200
+ pretrained.activations = activations
201
+
202
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
203
+
204
+ # 32, 48, 136, 384
205
+ pretrained.act_postprocess1 = nn.Sequential(
206
+ readout_oper[0],
207
+ Transpose(1, 2),
208
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
209
+ nn.Conv2d(
210
+ in_channels=vit_features,
211
+ out_channels=features[0],
212
+ kernel_size=1,
213
+ stride=1,
214
+ padding=0,
215
+ ),
216
+ nn.ConvTranspose2d(
217
+ in_channels=features[0],
218
+ out_channels=features[0],
219
+ kernel_size=4,
220
+ stride=4,
221
+ padding=0,
222
+ bias=True,
223
+ dilation=1,
224
+ groups=1,
225
+ ),
226
+ )
227
+
228
+ pretrained.act_postprocess2 = nn.Sequential(
229
+ readout_oper[1],
230
+ Transpose(1, 2),
231
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
232
+ nn.Conv2d(
233
+ in_channels=vit_features,
234
+ out_channels=features[1],
235
+ kernel_size=1,
236
+ stride=1,
237
+ padding=0,
238
+ ),
239
+ nn.ConvTranspose2d(
240
+ in_channels=features[1],
241
+ out_channels=features[1],
242
+ kernel_size=2,
243
+ stride=2,
244
+ padding=0,
245
+ bias=True,
246
+ dilation=1,
247
+ groups=1,
248
+ ),
249
+ )
250
+
251
+ pretrained.act_postprocess3 = nn.Sequential(
252
+ readout_oper[2],
253
+ Transpose(1, 2),
254
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
255
+ nn.Conv2d(
256
+ in_channels=vit_features,
257
+ out_channels=features[2],
258
+ kernel_size=1,
259
+ stride=1,
260
+ padding=0,
261
+ ),
262
+ )
263
+
264
+ pretrained.act_postprocess4 = nn.Sequential(
265
+ readout_oper[3],
266
+ Transpose(1, 2),
267
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
268
+ nn.Conv2d(
269
+ in_channels=vit_features,
270
+ out_channels=features[3],
271
+ kernel_size=1,
272
+ stride=1,
273
+ padding=0,
274
+ ),
275
+ nn.Conv2d(
276
+ in_channels=features[3],
277
+ out_channels=features[3],
278
+ kernel_size=3,
279
+ stride=2,
280
+ padding=1,
281
+ ),
282
+ )
283
+
284
+ pretrained.model.start_index = start_index
285
+ pretrained.model.patch_size = [16, 16]
286
+
287
+ # We inject this function into the VisionTransformer instances so that
288
+ # we can use it with interpolated position embeddings without modifying the library source.
289
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
290
+ pretrained.model._resize_pos_embed = types.MethodType(
291
+ _resize_pos_embed, pretrained.model
292
+ )
293
+
294
+ return pretrained
295
+
296
+
297
+ def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
298
+ model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
299
+
300
+ hooks = [5, 11, 17, 23] if hooks == None else hooks
301
+ return _make_vit_b16_backbone(
302
+ model,
303
+ features=[256, 512, 1024, 1024],
304
+ hooks=hooks,
305
+ vit_features=1024,
306
+ use_readout=use_readout,
307
+ )
308
+
309
+
310
+ def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
311
+ model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
312
+
313
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
314
+ return _make_vit_b16_backbone(
315
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
316
+ )
317
+
318
+
319
+ def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
320
+ model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
321
+
322
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
323
+ return _make_vit_b16_backbone(
324
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
325
+ )
326
+
327
+
328
+ def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
329
+ model = timm.create_model(
330
+ "vit_deit_base_distilled_patch16_384", pretrained=pretrained
331
+ )
332
+
333
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
334
+ return _make_vit_b16_backbone(
335
+ model,
336
+ features=[96, 192, 384, 768],
337
+ hooks=hooks,
338
+ use_readout=use_readout,
339
+ start_index=2,
340
+ )
341
+
342
+
343
+ def _make_vit_b_rn50_backbone(
344
+ model,
345
+ features=[256, 512, 768, 768],
346
+ size=[384, 384],
347
+ hooks=[0, 1, 8, 11],
348
+ vit_features=768,
349
+ use_vit_only=False,
350
+ use_readout="ignore",
351
+ start_index=1,
352
+ ):
353
+ pretrained = nn.Module()
354
+
355
+ pretrained.model = model
356
+
357
+ if use_vit_only:
358
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
359
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
360
+ else:
361
+ pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
362
+ get_activation("1")
363
+ )
364
+ pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
365
+ get_activation("2")
366
+ )
367
+
368
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
369
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
370
+
371
+ pretrained.activations = activations
372
+
373
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
374
+
375
+ if use_vit_only:
376
+ pretrained.act_postprocess1 = nn.Sequential(
377
+ readout_oper[0],
378
+ Transpose(1, 2),
379
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
380
+ nn.Conv2d(
381
+ in_channels=vit_features,
382
+ out_channels=features[0],
383
+ kernel_size=1,
384
+ stride=1,
385
+ padding=0,
386
+ ),
387
+ nn.ConvTranspose2d(
388
+ in_channels=features[0],
389
+ out_channels=features[0],
390
+ kernel_size=4,
391
+ stride=4,
392
+ padding=0,
393
+ bias=True,
394
+ dilation=1,
395
+ groups=1,
396
+ ),
397
+ )
398
+
399
+ pretrained.act_postprocess2 = nn.Sequential(
400
+ readout_oper[1],
401
+ Transpose(1, 2),
402
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
403
+ nn.Conv2d(
404
+ in_channels=vit_features,
405
+ out_channels=features[1],
406
+ kernel_size=1,
407
+ stride=1,
408
+ padding=0,
409
+ ),
410
+ nn.ConvTranspose2d(
411
+ in_channels=features[1],
412
+ out_channels=features[1],
413
+ kernel_size=2,
414
+ stride=2,
415
+ padding=0,
416
+ bias=True,
417
+ dilation=1,
418
+ groups=1,
419
+ ),
420
+ )
421
+ else:
422
+ pretrained.act_postprocess1 = nn.Sequential(
423
+ nn.Identity(), nn.Identity(), nn.Identity()
424
+ )
425
+ pretrained.act_postprocess2 = nn.Sequential(
426
+ nn.Identity(), nn.Identity(), nn.Identity()
427
+ )
428
+
429
+ pretrained.act_postprocess3 = nn.Sequential(
430
+ readout_oper[2],
431
+ Transpose(1, 2),
432
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
433
+ nn.Conv2d(
434
+ in_channels=vit_features,
435
+ out_channels=features[2],
436
+ kernel_size=1,
437
+ stride=1,
438
+ padding=0,
439
+ ),
440
+ )
441
+
442
+ pretrained.act_postprocess4 = nn.Sequential(
443
+ readout_oper[3],
444
+ Transpose(1, 2),
445
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
446
+ nn.Conv2d(
447
+ in_channels=vit_features,
448
+ out_channels=features[3],
449
+ kernel_size=1,
450
+ stride=1,
451
+ padding=0,
452
+ ),
453
+ nn.Conv2d(
454
+ in_channels=features[3],
455
+ out_channels=features[3],
456
+ kernel_size=3,
457
+ stride=2,
458
+ padding=1,
459
+ ),
460
+ )
461
+
462
+ pretrained.model.start_index = start_index
463
+ pretrained.model.patch_size = [16, 16]
464
+
465
+ # We inject this function into the VisionTransformer instances so that
466
+ # we can use it with interpolated position embeddings without modifying the library source.
467
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
468
+
469
+ # We inject this function into the VisionTransformer instances so that
470
+ # we can use it with interpolated position embeddings without modifying the library source.
471
+ pretrained.model._resize_pos_embed = types.MethodType(
472
+ _resize_pos_embed, pretrained.model
473
+ )
474
+
475
+ return pretrained
476
+
477
+
478
+ def _make_pretrained_vitb_rn50_384(
479
+ pretrained, use_readout="ignore", hooks=None, use_vit_only=False
480
+ ):
481
+ model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
482
+
483
+ hooks = [0, 1, 8, 11] if hooks is None else hooks
484
+ return _make_vit_b_rn50_backbone(
485
+ model,
486
+ features=[256, 512, 768, 768],
487
+ size=[384, 384],
488
+ hooks=hooks,
489
+ use_vit_only=use_vit_only,
490
+ use_readout=use_readout,
491
+ )
repositories/stable-diffusion-stability-ai/ldm/modules/midas/utils.py ADDED
@@ -0,0 +1,189 @@
1
+ """Utils for monoDepth."""
2
+ import sys
3
+ import re
4
+ import numpy as np
5
+ import cv2
6
+ import torch
7
+
8
+
9
+ def read_pfm(path):
10
+ """Read pfm file.
11
+
12
+ Args:
13
+ path (str): path to file
14
+
15
+ Returns:
16
+ tuple: (data, scale)
17
+ """
18
+ with open(path, "rb") as file:
19
+
20
+ color = None
21
+ width = None
22
+ height = None
23
+ scale = None
24
+ endian = None
25
+
26
+ header = file.readline().rstrip()
27
+ if header.decode("ascii") == "PF":
28
+ color = True
29
+ elif header.decode("ascii") == "Pf":
30
+ color = False
31
+ else:
32
+ raise Exception("Not a PFM file: " + path)
33
+
34
+ dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
35
+ if dim_match:
36
+ width, height = list(map(int, dim_match.groups()))
37
+ else:
38
+ raise Exception("Malformed PFM header.")
39
+
40
+ scale = float(file.readline().decode("ascii").rstrip())
41
+ if scale < 0:
42
+ # little-endian
43
+ endian = "<"
44
+ scale = -scale
45
+ else:
46
+ # big-endian
47
+ endian = ">"
48
+
49
+ data = np.fromfile(file, endian + "f")
50
+ shape = (height, width, 3) if color else (height, width)
51
+
52
+ data = np.reshape(data, shape)
53
+ data = np.flipud(data)
54
+
55
+ return data, scale
56
+
57
+
58
+ def write_pfm(path, image, scale=1):
59
+ """Write pfm file.
60
+
61
+ Args:
62
+ path (str): path to file
63
+ image (array): data
64
+ scale (int, optional): Scale. Defaults to 1.
65
+ """
66
+
67
+ with open(path, "wb") as file:
68
+ color = None
69
+
70
+ if image.dtype.name != "float32":
71
+ raise Exception("Image dtype must be float32.")
72
+
73
+ image = np.flipud(image)
74
+
75
+ if len(image.shape) == 3 and image.shape[2] == 3: # color image
76
+ color = True
77
+ elif (
78
+ len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
79
+ ): # greyscale
80
+ color = False
81
+ else:
82
+ raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
83
+
84
+ file.write("PF\n" if color else "Pf\n".encode())
85
+ file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
86
+
87
+ endian = image.dtype.byteorder
88
+
89
+ if endian == "<" or endian == "=" and sys.byteorder == "little":
90
+ scale = -scale
91
+
92
+ file.write("%f\n".encode() % scale)
93
+
94
+ image.tofile(file)
95
+
96
+
97
+ def read_image(path):
98
+ """Read image and output RGB image (0-1).
99
+
100
+ Args:
101
+ path (str): path to file
102
+
103
+ Returns:
104
+ array: RGB image (0-1)
105
+ """
106
+ img = cv2.imread(path)
107
+
108
+ if img.ndim == 2:
109
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
110
+
111
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
112
+
113
+ return img
114
+
115
+
116
+ def resize_image(img):
117
+ """Resize image and make it fit for network.
118
+
119
+ Args:
120
+ img (array): image
121
+
122
+ Returns:
123
+ tensor: data ready for network
124
+ """
125
+ height_orig = img.shape[0]
126
+ width_orig = img.shape[1]
127
+
128
+ if width_orig > height_orig:
129
+ scale = width_orig / 384
130
+ else:
131
+ scale = height_orig / 384
132
+
133
+ height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
134
+ width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
135
+
136
+ img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
137
+
138
+ img_resized = (
139
+ torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
140
+ )
141
+ img_resized = img_resized.unsqueeze(0)
142
+
143
+ return img_resized
144
+
145
+
146
+ def resize_depth(depth, width, height):
147
+ """Resize depth map and bring to CPU (numpy).
148
+
149
+ Args:
150
+ depth (tensor): depth
151
+ width (int): image width
152
+ height (int): image height
153
+
154
+ Returns:
155
+ array: processed depth
156
+ """
157
+ depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
158
+
159
+ depth_resized = cv2.resize(
160
+ depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
161
+ )
162
+
163
+ return depth_resized
164
+
165
+ def write_depth(path, depth, bits=1):
166
+ """Write depth map to pfm and png file.
167
+
168
+ Args:
169
+ path (str): filepath without extension
170
+ depth (array): depth
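+ bits (int, optional): bytes per png sample, 1 (uint8) or 2 (uint16). Defaults to 1.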
171
+ """
172
+ write_pfm(path + ".pfm", depth.astype(np.float32))
173
+
174
+ depth_min = depth.min()
175
+ depth_max = depth.max()
176
+
177
+ max_val = (2**(8*bits))-1
178
+
179
+ if depth_max - depth_min > np.finfo("float").eps:
180
+ out = max_val * (depth - depth_min) / (depth_max - depth_min)
181
+ else:
182
+ out = np.zeros(depth.shape, dtype=depth.dtype)
183
+
184
+ if bits == 1:
185
+ cv2.imwrite(path + ".png", out.astype("uint8"))
186
+ elif bits == 2:
187
+ cv2.imwrite(path + ".png", out.astype("uint16"))
188
+
189
+ return
repositories/stable-diffusion-stability-ai/ldm/util.py ADDED
@@ -0,0 +1,207 @@
1
+ import importlib
2
+
3
+ import torch
4
+ from torch import optim
5
+ import numpy as np
6
+
7
+ from inspect import isfunction
8
+ from PIL import Image, ImageDraw, ImageFont
9
+
10
+
11
+ def autocast(f):
12
+ def do_autocast(*args, **kwargs):
13
+ with torch.cuda.amp.autocast(enabled=True,
14
+ dtype=torch.get_autocast_gpu_dtype(),
15
+ cache_enabled=torch.is_autocast_cache_enabled()):
16
+ return f(*args, **kwargs)
17
+
18
+ return do_autocast
19
+
20
+
21
+ def log_txt_as_img(wh, xc, size=10):
22
+ # wh a tuple of (width, height)
23
+ # xc a list of captions to plot
24
+ b = len(xc)
25
+ txts = list()
26
+ for bi in range(b):
27
+ txt = Image.new("RGB", wh, color="white")
28
+ draw = ImageDraw.Draw(txt)
29
+ font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
30
+ nc = int(40 * (wh[0] / 256))
31
+ lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
32
+
33
+ try:
34
+ draw.text((0, 0), lines, fill="black", font=font)
35
+ except UnicodeEncodeError:
36
+ print("Cant encode string for logging. Skipping.")
37
+
38
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
39
+ txts.append(txt)
40
+ txts = np.stack(txts)
41
+ txts = torch.tensor(txts)
42
+ return txts
43
+
44
+
45
+ def ismap(x):
46
+ if not isinstance(x, torch.Tensor):
47
+ return False
48
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
49
+
50
+
51
+ def isimage(x):
52
+ if not isinstance(x,torch.Tensor):
53
+ return False
54
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
55
+
56
+
57
+ def exists(x):
58
+ return x is not None
59
+
60
+
61
+ def default(val, d):
62
+ if exists(val):
63
+ return val
64
+ return d() if isfunction(d) else d
65
+
66
+
67
+ def mean_flat(tensor):
68
+ """
69
+ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
70
+ Take the mean over all non-batch dimensions.
71
+ """
72
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
73
+
74
+
75
+ def count_params(model, verbose=False):
76
+ total_params = sum(p.numel() for p in model.parameters())
77
+ if verbose:
78
+ print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
79
+ return total_params
80
+
81
+
82
+ def instantiate_from_config(config):
83
+ if not "target" in config:
84
+ if config == '__is_first_stage__':
85
+ return None
86
+ elif config == "__is_unconditional__":
87
+ return None
88
+ raise KeyError("Expected key `target` to instantiate.")
89
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
90
+
91
+
92
+ def get_obj_from_str(string, reload=False):
93
+ module, cls = string.rsplit(".", 1)
94
+ if reload:
95
+ module_imp = importlib.import_module(module)
96
+ importlib.reload(module_imp)
97
+ return getattr(importlib.import_module(module, package=None), cls)
98
+
99
+
100
+ class AdamWwithEMAandWings(optim.Optimizer):
101
+ # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
102
+ def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
103
+ weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
104
+ ema_power=1., param_names=()):
105
+ """AdamW that saves EMA versions of the parameters."""
106
+ if not 0.0 <= lr:
107
+ raise ValueError("Invalid learning rate: {}".format(lr))
108
+ if not 0.0 <= eps:
109
+ raise ValueError("Invalid epsilon value: {}".format(eps))
110
+ if not 0.0 <= betas[0] < 1.0:
111
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
112
+ if not 0.0 <= betas[1] < 1.0:
113
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
114
+ if not 0.0 <= weight_decay:
115
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
116
+ if not 0.0 <= ema_decay <= 1.0:
117
+ raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
118
+ defaults = dict(lr=lr, betas=betas, eps=eps,
119
+ weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
120
+ ema_power=ema_power, param_names=param_names)
121
+ super().__init__(params, defaults)
122
+
123
+ def __setstate__(self, state):
124
+ super().__setstate__(state)
125
+ for group in self.param_groups:
126
+ group.setdefault('amsgrad', False)
127
+
128
+ @torch.no_grad()
129
+ def step(self, closure=None):
130
+ """Performs a single optimization step.
131
+ Args:
132
+ closure (callable, optional): A closure that reevaluates the model
133
+ and returns the loss.
134
+ """
135
+ loss = None
136
+ if closure is not None:
137
+ with torch.enable_grad():
138
+ loss = closure()
139
+
140
+ for group in self.param_groups:
141
+ params_with_grad = []
142
+ grads = []
143
+ exp_avgs = []
144
+ exp_avg_sqs = []
145
+ ema_params_with_grad = []
146
+ state_sums = []
147
+ max_exp_avg_sqs = []
148
+ state_steps = []
149
+ amsgrad = group['amsgrad']
150
+ beta1, beta2 = group['betas']
151
+ ema_decay = group['ema_decay']
152
+ ema_power = group['ema_power']
153
+
154
+ for p in group['params']:
155
+ if p.grad is None:
156
+ continue
157
+ params_with_grad.append(p)
158
+ if p.grad.is_sparse:
159
+ raise RuntimeError('AdamW does not support sparse gradients')
160
+ grads.append(p.grad)
161
+
162
+ state = self.state[p]
163
+
164
+ # State initialization
165
+ if len(state) == 0:
166
+ state['step'] = 0
167
+ # Exponential moving average of gradient values
168
+ state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
169
+ # Exponential moving average of squared gradient values
170
+ state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
171
+ if amsgrad:
172
+ # Maintains max of all exp. moving avg. of sq. grad. values
173
+ state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
174
+ # Exponential moving average of parameter values
175
+ state['param_exp_avg'] = p.detach().float().clone()
176
+
177
+ exp_avgs.append(state['exp_avg'])
178
+ exp_avg_sqs.append(state['exp_avg_sq'])
179
+ ema_params_with_grad.append(state['param_exp_avg'])
180
+
181
+ if amsgrad:
182
+ max_exp_avg_sqs.append(state['max_exp_avg_sq'])
183
+
184
+ # update the steps for each param group update
185
+ state['step'] += 1
186
+ # record the step after step update
187
+ state_steps.append(state['step'])
188
+
189
+ optim._functional.adamw(params_with_grad,
190
+ grads,
191
+ exp_avgs,
192
+ exp_avg_sqs,
193
+ max_exp_avg_sqs,
194
+ state_steps,
195
+ amsgrad=amsgrad,
196
+ beta1=beta1,
197
+ beta2=beta2,
198
+ lr=group['lr'],
199
+ weight_decay=group['weight_decay'],
200
+ eps=group['eps'],
201
+ maximize=False)
202
+
203
+ cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
204
+ for param, ema_param in zip(params_with_grad, ema_params_with_grad):
205
+ ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
206
+
207
+ return loss
repositories/stable-diffusion-stability-ai/modelcard.md ADDED
@@ -0,0 +1,153 @@
1
+ # Stable Diffusion v2 Model Card
2
+ This model card focuses on the models associated with the Stable Diffusion v2, available [here](https://github.com/Stability-AI/stablediffusion/).
3
+
4
+ ## Model Details
5
+ - **Developed by:** Robin Rombach, Patrick Esser
6
+ - **Model type:** Diffusion-based text-to-image generation model
7
+ - **Language(s):** English
8
+ - **License:** CreativeML Open RAIL++-M License
9
+ - **Model Description:** This is a model that can be used to generate and modify images based on text prompts. It is a [Latent Diffusion Model](https://arxiv.org/abs/2112.10752) that uses a fixed, pretrained text encoder ([OpenCLIP-ViT/H](https://github.com/mlfoundations/open_clip)).
10
+ - **Resources for more information:** [GitHub Repository](https://github.com/Stability-AI/).
11
+ - **Cite as:**
12
+
13
+ @InProceedings{Rombach_2022_CVPR,
14
+ author = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn},
15
+ title = {High-Resolution Image Synthesis With Latent Diffusion Models},
16
+ booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
17
+ month = {June},
18
+ year = {2022},
19
+ pages = {10684-10695}
20
+ }
21
+
22
+ # Uses
23
+
24
+ ## Direct Use
25
+ The model is intended for research purposes only. Possible research areas and tasks include
26
+
27
+ - Safe deployment of models which have the potential to generate harmful content.
28
+ - Probing and understanding the limitations and biases of generative models.
29
+ - Generation of artworks and use in design and other artistic processes.
30
+ - Applications in educational or creative tools.
31
+ - Research on generative models.
32
+
33
+ Excluded uses are described below.
34
+
35
+ ### Misuse, Malicious Use, and Out-of-Scope Use
36
+ _Note: This section is originally taken from the [DALLE-MINI model card](https://huggingface.co/dalle-mini/dalle-mini), was used for Stable Diffusion v1, but applies in the same way to Stable Diffusion v2_.
37
+
38
+ The model should not be used to intentionally create or disseminate images that create hostile or alienating environments for people. This includes generating images that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes.
39
+
40
+ #### Out-of-Scope Use
41
+ The model was not trained to be factual or true representations of people or events, and therefore using the model to generate such content is out-of-scope for the abilities of this model.
42
+
43
+ #### Misuse and Malicious Use
44
+ Using the model to generate content that is cruel to individuals is a misuse of this model. This includes, but is not limited to:
45
+
46
+ - Generating demeaning, dehumanizing, or otherwise harmful representations of people or their environments, cultures, religions, etc.
47
+ - Intentionally promoting or propagating discriminatory content or harmful stereotypes.
48
+ - Impersonating individuals without their consent.
49
+ - Sexual content without consent of the people who might see it.
50
+ - Mis- and disinformation
51
+ - Representations of egregious violence and gore
52
+ - Sharing of copyrighted or licensed material in violation of its terms of use.
53
+ - Sharing content that is an alteration of copyrighted or licensed material in violation of its terms of use.
54
+
55
+ ## Limitations and Bias
56
+
57
+ ### Limitations
58
+
59
+ - The model does not achieve perfect photorealism
60
+ - The model cannot render legible text
61
+ - The model does not perform well on more difficult tasks which involve compositionality, such as rendering an image corresponding to “A red cube on top of a blue sphere”
62
+ - Faces and people in general may not be generated properly.
63
+ - The model was trained mainly with English captions and will not work as well in other languages.
64
+ - The autoencoding part of the model is lossy
65
+ - The model was trained on a subset of the large-scale dataset
66
+ [LAION-5B](https://laion.ai/blog/laion-5b/), which contains adult, violent and sexual content. To partially mitigate this, we have filtered the dataset using LAION's NSFW detector (see Training section).
67
+
68
+ ### Bias
69
+ While the capabilities of image generation models are impressive, they can also reinforce or exacerbate social biases.
70
+ Stable Diffusion v2 was primarily trained on subsets of [LAION-2B(en)](https://laion.ai/blog/laion-5b/),
71
+ which consists of images that are limited to English descriptions.
72
+ Texts and images from communities and cultures that use other languages are likely to be insufficiently accounted for.
73
+ This affects the overall output of the model, as white and western cultures are often set as the default. Further, the
74
+ ability of the model to generate content with non-English prompts is significantly worse than with English-language prompts.
75
+ Stable Diffusion v2 mirrors and exacerbates biases to such a degree that viewer discretion must be advised irrespective of the input or its intent.
76
+
77
+
78
+ ## Training
79
+
80
+ **Training Data**
81
+ The model developers used the following dataset for training the model:
82
+
83
+ - LAION-5B and subsets (details below). The training data is further filtered using LAION's NSFW detector. For more details, please refer to LAION-5B's [NeurIPS 2022](https://openreview.net/forum?id=M3Y74vmsMcY) paper and reviewer discussions on the topic.
84
+
85
+ **Training Procedure**
86
+ Stable Diffusion v2 is a latent diffusion model which combines an autoencoder with a diffusion model that is trained in the latent space of the autoencoder. During training,
87
+
88
+ - Images are encoded through an encoder, which turns images into latent representations. The autoencoder uses a relative downsampling factor of 8 and maps images of shape H x W x 3 to latents of shape H/f x W/f x 4
89
+ - Text prompts are encoded through the OpenCLIP-ViT/H text-encoder.
90
+ - The output of the text encoder is fed into the UNet backbone of the latent diffusion model via cross-attention.
91
+ - The loss is a reconstruction objective between the noise that was added to the latent and the prediction made by the UNet. We also use the so-called _v-objective_, see https://arxiv.org/abs/2202.00512. A schematic sketch of this training step follows below.
92
+
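+ For illustration, the training step described above can be sketched roughly as follows; the module and schedule names are placeholders for the configured first-stage model, text encoder and UNet, not the actual classes in this repository:
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ def training_step(encoder, text_encoder, unet, schedule, images, captions):
+     # images: (B, 3, H, W) in [-1, 1]; latents: (B, 4, H/8, W/8) for f=8
+     z = encoder(images)
+     c = text_encoder(captions)  # context fed to the UNet via cross-attention
+
+     t = torch.randint(0, schedule.num_timesteps, (z.shape[0],), device=z.device)
+     noise = torch.randn_like(z)
+     a = schedule.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
+     s = schedule.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)
+     z_noisy = a * z + s * noise  # forward diffusion in latent space
+
+     # v-objective (https://arxiv.org/abs/2202.00512): predict v = a * noise - s * z
+     v_pred = unet(z_noisy, t, context=c)
+     return F.mse_loss(v_pred, a * noise - s * z)
+ ```
+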
93
+ We currently provide the following checkpoints, for various versions:
94
+
95
+ ### Version 2.1
96
+
97
+ - `512-base-ema.ckpt`: Fine-tuned on `512-base-ema.ckpt` 2.0 with 220k extra steps taken, with `punsafe=0.98` on the same dataset.
98
+ - `768-v-ema.ckpt`: Resumed from `768-v-ema.ckpt` 2.0 with an additional 55k steps on the same dataset (`punsafe=0.1`), and then fine-tuned for another 155k extra steps with `punsafe=0.98`.
99
+
100
+ **SD-unCLIP 2.1** is a finetuned version of Stable Diffusion 2.1, modified to accept (noisy) CLIP image embedding in addition to the text prompt, and can be used to create image variations ([Examples](https://github.com/Stability-AI/stablediffusion/blob/main/doc/UNCLIP.MD)) or can be chained with text-to-image CLIP priors. The amount of noise added to the image embedding can be specified via the `noise_level` (0 means no noise, 1000 full noise).
101
+
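+ Conceptually, `noise_level` acts like a diffusion timestep applied to the CLIP image embedding. A rough sketch of that idea (the helper and schedule tensor below are illustrative placeholders, not the repository's actual noise-augmentation module):
+
+ ```python
+ import torch
+
+ def noise_clip_embedding(emb, noise_level, alphas_cumprod):
+     # emb: (B, D) CLIP image embedding; alphas_cumprod: cumulative schedule tensor.
+     t = torch.full((emb.shape[0],), noise_level, dtype=torch.long, device=emb.device)
+     a = alphas_cumprod[t].sqrt().view(-1, 1)
+     s = (1.0 - alphas_cumprod[t]).sqrt().view(-1, 1)
+     # small noise_level: embedding nearly unchanged; large: mostly Gaussian noise
+     return a * emb + s * torch.randn_like(emb)
+ ```
+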
102
+ If you plan on building applications on top of the model that the general public may use, you are responsible for adding the guardrails to minimize or prevent misuse of the application, especially for use-cases highlighted in the earlier section, Misuse, Malicious Use, and Out-of-Scope Use.
103
+
104
+ A public demo of SD-unCLIP is already available at [clipdrop.co/stable-diffusion-reimagine](https://clipdrop.co/stable-diffusion-reimagine)
105
+
106
+ ### Version 2.0
107
+
108
+ - `512-base-ema.ckpt`: 550k steps at resolution `256x256` on a subset of [LAION-5B](https://laion.ai/blog/laion-5b/) filtered for explicit pornographic material, using the [LAION-NSFW classifier](https://github.com/LAION-AI/CLIP-based-NSFW-Detector) with `punsafe=0.1` and an [aesthetic score](https://github.com/christophschuhmann/improved-aesthetic-predictor) >= `4.5`.
109
+ 850k steps at resolution `512x512` on the same dataset with resolution `>= 512x512`.
110
+ - `768-v-ema.ckpt`: Resumed from `512-base-ema.ckpt` and trained for 150k steps using a [v-objective](https://arxiv.org/abs/2202.00512) on the same dataset. Resumed for another 140k steps on a `768x768` subset of our dataset.
111
+ - `512-depth-ema.ckpt`: Resumed from `512-base-ema.ckpt` and finetuned for 200k steps. Added an extra input channel to process the (relative) depth prediction produced by [MiDaS](https://github.com/isl-org/MiDaS) (`dpt_hybrid`) which is used as an additional conditioning.
112
+ The additional input channels of the U-Net which process this extra information were zero-initialized (see the sketch after this list).
113
+ - `512-inpainting-ema.ckpt`: Resumed from `512-base-ema.ckpt` and trained for another 200k steps. Follows the mask-generation strategy presented in [LAMA](https://github.com/saic-mdal/lama) which, in combination with the latent VAE representations of the masked image, are used as an additional conditioning.
114
+ The additional input channels of the U-Net which process this extra information were zero-initialized. The same strategy was used to train the [1.5-inpainting checkpoint](https://github.com/saic-mdal/lama).
115
+ - `x4-upscaling-ema.ckpt`: Trained for 1.25M steps on a 10M subset of LAION containing images `>2048x2048`. The model was trained on crops of size `512x512` and is a text-guided [latent upscaling diffusion model](https://arxiv.org/abs/2112.10752).
116
+ In addition to the textual input, it receives a `noise_level` as an input parameter, which can be used to add noise to the low-resolution input according to a [predefined diffusion schedule](configs/stable-diffusion/x4-upscaling.yaml).
117
+
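+ The zero-initialization used for the depth- and inpainting-conditioned checkpoints can be sketched generically as follows (a placeholder helper, not the repository's actual code): widen the first U-Net convolution, copy the pretrained weights, and zero the new input channels so the extended model initially behaves like the pretrained one.
+
+ ```python
+ import torch
+ import torch.nn as nn
+
+ def widen_input_conv(old_conv: nn.Conv2d, extra_channels: int) -> nn.Conv2d:
+     # Assumes groups=1, as in the U-Net's first convolution.
+     new_conv = nn.Conv2d(
+         old_conv.in_channels + extra_channels,
+         old_conv.out_channels,
+         kernel_size=old_conv.kernel_size,
+         stride=old_conv.stride,
+         padding=old_conv.padding,
+         bias=old_conv.bias is not None,
+     )
+     with torch.no_grad():
+         new_conv.weight.zero_()  # new conditioning channels start with no effect
+         new_conv.weight[:, : old_conv.in_channels] = old_conv.weight
+         if old_conv.bias is not None:
+             new_conv.bias.copy_(old_conv.bias)
+     return new_conv
+ ```
+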
118
+ - **Hardware:** 32 x 8 x A100 GPUs
119
+ - **Optimizer:** AdamW
120
+ - **Gradient Accumulations**: 1
121
+ - **Batch:** 32 x 8 x 2 x 4 = 2048
122
+ - **Learning rate:** warmup to 0.0001 for 10,000 steps and then kept constant
123
+
124
+ ## Evaluation Results
125
+ Evaluations with different classifier-free guidance scales (1.5, 2.0, 3.0, 4.0,
126
+ 5.0, 6.0, 7.0, 8.0) and 50 DDIM sampling steps show the relative improvements of the checkpoints:
127
+
128
+ ![pareto](assets/model-variants.jpg)
129
+
130
+ Evaluated using 50 DDIM steps and 10000 random prompts from the COCO2017 validation set, evaluated at 512x512 resolution. Not optimized for FID scores.
131
+
132
+ ## Environmental Impact
133
+
134
+ **Stable Diffusion v1** **Estimated Emissions**
135
+ Based on that information, we estimate the following CO2 emissions using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700). The hardware, runtime, cloud provider, and compute region were utilized to estimate the carbon impact.
136
+
137
+ - **Hardware Type:** A100 PCIe 40GB
138
+ - **Hours used:** 200000
139
+ - **Cloud Provider:** AWS
140
+ - **Compute Region:** US-east
141
+ - **Carbon Emitted (Power consumption x Time x Carbon produced based on location of power grid):** 15000 kg CO2 eq.
142
+
143
+ ## Citation
144
+ @InProceedings{Rombach_2022_CVPR,
145
+ author = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn},
146
+ title = {High-Resolution Image Synthesis With Latent Diffusion Models},
147
+ booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
148
+ month = {June},
149
+ year = {2022},
150
+ pages = {10684-10695}
151
+ }
152
+
153
+ *This model card was written by: Robin Rombach, Patrick Esser and David Ha and is based on the [Stable Diffusion v1](https://github.com/CompVis/stable-diffusion/blob/main/Stable_Diffusion_v1_Model_Card.md) and [DALL-E Mini model card](https://huggingface.co/dalle-mini/dalle-mini).*
repositories/stable-diffusion-stability-ai/requirements.txt ADDED
@@ -0,0 +1,19 @@
1
+ albumentations==0.4.3
2
+ opencv-python
3
+ pudb==2019.2
4
+ imageio==2.9.0
5
+ imageio-ffmpeg==0.4.2
6
+ pytorch-lightning==1.4.2
7
+ torchmetrics==0.6
8
+ omegaconf==2.1.1
9
+ test-tube>=0.7.5
10
+ streamlit>=0.73.1
11
+ einops==0.3.0
12
+ transformers==4.19.2
13
+ webdataset==0.2.5
14
+ open-clip-torch==2.7.0
15
+ gradio==3.13.2
16
+ kornia==0.6
17
+ invisible-watermark>=0.1.5
18
+ streamlit-drawable-canvas==0.8.0
19
+ -e .
repositories/stable-diffusion-stability-ai/scripts/gradio/depth2img.py ADDED
@@ -0,0 +1,184 @@
1
+ import sys
2
+ import torch
3
+ import numpy as np
4
+ import gradio as gr
5
+ from PIL import Image
6
+ from omegaconf import OmegaConf
7
+ from einops import repeat, rearrange
8
+ from pytorch_lightning import seed_everything
9
+ from imwatermark import WatermarkEncoder
10
+
11
+ from scripts.txt2img import put_watermark
12
+ from ldm.util import instantiate_from_config
13
+ from ldm.models.diffusion.ddim import DDIMSampler
14
+ from ldm.data.util import AddMiDaS
15
+
16
+ torch.set_grad_enabled(False)
17
+
18
+
19
+ def initialize_model(config, ckpt):
20
+ config = OmegaConf.load(config)
21
+ model = instantiate_from_config(config.model)
22
+ model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
23
+
24
+ device = torch.device(
25
+ "cuda") if torch.cuda.is_available() else torch.device("cpu")
26
+ model = model.to(device)
27
+ sampler = DDIMSampler(model)
28
+ return sampler
29
+
30
+
31
+ def make_batch_sd(
32
+ image,
33
+ txt,
34
+ device,
35
+ num_samples=1,
36
+ model_type="dpt_hybrid"
37
+ ):
38
+ image = np.array(image.convert("RGB"))
39
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
40
+ # sample['jpg'] is tensor hwc in [-1, 1] at this point
41
+ midas_trafo = AddMiDaS(model_type=model_type)
42
+ batch = {
43
+ "jpg": image,
44
+ "txt": num_samples * [txt],
45
+ }
46
+ batch = midas_trafo(batch)
47
+ batch["jpg"] = rearrange(batch["jpg"], 'h w c -> 1 c h w')
48
+ batch["jpg"] = repeat(batch["jpg"].to(device=device),
49
+ "1 ... -> n ...", n=num_samples)
50
+ batch["midas_in"] = repeat(torch.from_numpy(batch["midas_in"][None, ...]).to(
51
+ device=device), "1 ... -> n ...", n=num_samples)
52
+ return batch
53
+
54
+
55
+ def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=None,
56
+ do_full_sample=False):
57
+ device = torch.device(
58
+ "cuda") if torch.cuda.is_available() else torch.device("cpu")
59
+ model = sampler.model
60
+ seed_everything(seed)
61
+
62
+ print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
63
+ wm = "SDV2"
64
+ wm_encoder = WatermarkEncoder()
65
+ wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
66
+
67
+ with torch.no_grad(),\
68
+ torch.autocast("cuda"):
69
+ batch = make_batch_sd(
70
+ image, txt=prompt, device=device, num_samples=num_samples)
71
+ z = model.get_first_stage_encoding(model.encode_first_stage(
72
+ batch[model.first_stage_key])) # move to latent space
73
+ c = model.cond_stage_model.encode(batch["txt"])
74
+ c_cat = list()
75
+ for ck in model.concat_keys:
76
+ cc = batch[ck]
77
+ cc = model.depth_model(cc)
78
+ depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3],
79
+ keepdim=True)
80
+ display_depth = (cc - depth_min) / (depth_max - depth_min)
81
+ depth_image = Image.fromarray(
82
+ (display_depth[0, 0, ...].cpu().numpy() * 255.).astype(np.uint8))
83
+ cc = torch.nn.functional.interpolate(
84
+ cc,
85
+ size=z.shape[2:],
86
+ mode="bicubic",
87
+ align_corners=False,
88
+ )
89
+ depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3],
90
+ keepdim=True)
91
+ cc = 2. * (cc - depth_min) / (depth_max - depth_min) - 1.
92
+ c_cat.append(cc)
93
+ c_cat = torch.cat(c_cat, dim=1)
94
+ # cond
95
+ cond = {"c_concat": [c_cat], "c_crossattn": [c]}
96
+
97
+ # uncond cond
98
+ uc_cross = model.get_unconditional_conditioning(num_samples, "")
99
+ uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
100
+ if not do_full_sample:
101
+ # encode (scaled latent)
102
+ z_enc = sampler.stochastic_encode(
103
+ z, torch.tensor([t_enc] * num_samples).to(model.device))
104
+ else:
105
+ z_enc = torch.randn_like(z)
106
+ # decode it
107
+ samples = sampler.decode(z_enc, cond, t_enc, unconditional_guidance_scale=scale,
108
+ unconditional_conditioning=uc_full, callback=callback)
109
+ x_samples_ddim = model.decode_first_stage(samples)
110
+ result = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
111
+ result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255
112
+ return [depth_image] + [put_watermark(Image.fromarray(img.astype(np.uint8)), wm_encoder) for img in result]
113
+
114
+
115
+ def pad_image(input_image):
116
+ pad_w, pad_h = np.max(((2, 2), np.ceil(
117
+ np.array(input_image.size) / 64).astype(int)), axis=0) * 64 - input_image.size
118
+ im_padded = Image.fromarray(
119
+ np.pad(np.array(input_image), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
120
+ return im_padded
121
+
122
+
123
+ def predict(input_image, prompt, steps, num_samples, scale, seed, eta, strength):
124
+ init_image = input_image.convert("RGB")
125
+ image = pad_image(init_image) # pad to an integer multiple of 64
126
+
127
+ sampler.make_schedule(steps, ddim_eta=eta, verbose=True)
128
+ assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
129
+ do_full_sample = strength == 1.
130
+ t_enc = min(int(strength * steps), steps-1)
131
+ result = paint(
132
+ sampler=sampler,
133
+ image=image,
134
+ prompt=prompt,
135
+ t_enc=t_enc,
136
+ seed=seed,
137
+ scale=scale,
138
+ num_samples=num_samples,
139
+ callback=None,
140
+ do_full_sample=do_full_sample
141
+ )
142
+ return result
143
+
144
+
145
+ sampler = initialize_model(sys.argv[1], sys.argv[2])
146
+
147
+ block = gr.Blocks().queue()
148
+ with block:
149
+ with gr.Row():
150
+ gr.Markdown("## Stable Diffusion Depth2Img")
151
+
152
+ with gr.Row():
153
+ with gr.Column():
154
+ input_image = gr.Image(source='upload', type="pil")
155
+ prompt = gr.Textbox(label="Prompt")
156
+ run_button = gr.Button(label="Run")
157
+ with gr.Accordion("Advanced options", open=False):
158
+ num_samples = gr.Slider(
159
+ label="Images", minimum=1, maximum=4, value=1, step=1)
160
+ ddim_steps = gr.Slider(label="Steps", minimum=1,
161
+ maximum=50, value=50, step=1)
162
+ scale = gr.Slider(
163
+ label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1
164
+ )
165
+ strength = gr.Slider(
166
+ label="Strength", minimum=0.0, maximum=1.0, value=0.9, step=0.01
167
+ )
168
+ seed = gr.Slider(
169
+ label="Seed",
170
+ minimum=0,
171
+ maximum=2147483647,
172
+ step=1,
173
+ randomize=True,
174
+ )
175
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
176
+ with gr.Column():
177
+ gallery = gr.Gallery(label="Generated images", show_label=False).style(
178
+ grid=[2], height="auto")
179
+
180
+ run_button.click(fn=predict, inputs=[
181
+ input_image, prompt, ddim_steps, num_samples, scale, seed, eta, strength], outputs=[gallery])
182
+
183
+
184
+ block.launch()
repositories/stable-diffusion-stability-ai/scripts/gradio/inpainting.py ADDED
@@ -0,0 +1,195 @@
1
+ import sys
2
+ import cv2
3
+ import torch
4
+ import numpy as np
5
+ import gradio as gr
6
+ from PIL import Image
7
+ from omegaconf import OmegaConf
8
+ from einops import repeat
9
+ from imwatermark import WatermarkEncoder
10
+ from pathlib import Path
11
+
12
+ from ldm.models.diffusion.ddim import DDIMSampler
13
+ from ldm.util import instantiate_from_config
14
+
15
+
16
+ torch.set_grad_enabled(False)
17
+
18
+
19
+ def put_watermark(img, wm_encoder=None):
20
+ if wm_encoder is not None:
21
+ img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
22
+ img = wm_encoder.encode(img, 'dwtDct')
23
+ img = Image.fromarray(img[:, :, ::-1])
24
+ return img
25
+
26
+
27
+ def initialize_model(config, ckpt):
28
+ config = OmegaConf.load(config)
29
+ model = instantiate_from_config(config.model)
30
+
31
+ model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
32
+
33
+ device = torch.device(
34
+ "cuda") if torch.cuda.is_available() else torch.device("cpu")
35
+ model = model.to(device)
36
+ sampler = DDIMSampler(model)
37
+
38
+ return sampler
39
+
40
+
41
+ def make_batch_sd(
42
+ image,
43
+ mask,
44
+ txt,
45
+ device,
46
+ num_samples=1):
47
+ image = np.array(image.convert("RGB"))
48
+ image = image[None].transpose(0, 3, 1, 2)
49
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
50
+
51
+ mask = np.array(mask.convert("L"))
52
+ mask = mask.astype(np.float32) / 255.0
53
+ mask = mask[None, None]
54
+ mask[mask < 0.5] = 0
55
+ mask[mask >= 0.5] = 1
56
+ mask = torch.from_numpy(mask)
57
+
58
+ masked_image = image * (mask < 0.5)
59
+
60
+ batch = {
61
+ "image": repeat(image.to(device=device), "1 ... -> n ...", n=num_samples),
62
+ "txt": num_samples * [txt],
63
+ "mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
64
+ "masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples),
65
+ }
66
+ return batch
67
+
68
+
69
+ def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1, w=512, h=512):
70
+ device = torch.device(
71
+ "cuda") if torch.cuda.is_available() else torch.device("cpu")
72
+ model = sampler.model
73
+
74
+ print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
75
+ wm = "SDV2"
76
+ wm_encoder = WatermarkEncoder()
77
+ wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
78
+
79
+ prng = np.random.RandomState(seed)
80
+ start_code = prng.randn(num_samples, 4, h // 8, w // 8)
81
+ start_code = torch.from_numpy(start_code).to(
82
+ device=device, dtype=torch.float32)
83
+
84
+ with torch.no_grad(), \
85
+ torch.autocast("cuda"):
86
+ batch = make_batch_sd(image, mask, txt=prompt,
87
+ device=device, num_samples=num_samples)
88
+
89
+ c = model.cond_stage_model.encode(batch["txt"])
90
+
91
+ c_cat = list()
92
+ for ck in model.concat_keys:
93
+ cc = batch[ck].float()
94
+ if ck != model.masked_image_key:
95
+ bchw = [num_samples, 4, h // 8, w // 8]
96
+ cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
97
+ else:
98
+ cc = model.get_first_stage_encoding(
99
+ model.encode_first_stage(cc))
100
+ c_cat.append(cc)
101
+ c_cat = torch.cat(c_cat, dim=1)
102
+
103
+ # cond
104
+ cond = {"c_concat": [c_cat], "c_crossattn": [c]}
105
+
106
+ # uncond cond
107
+ uc_cross = model.get_unconditional_conditioning(num_samples, "")
108
+ uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
109
+
110
+ shape = [model.channels, h // 8, w // 8]
111
+ samples_cfg, intermediates = sampler.sample(
112
+ ddim_steps,
113
+ num_samples,
114
+ shape,
115
+ cond,
116
+ verbose=False,
117
+ eta=1.0,
118
+ unconditional_guidance_scale=scale,
119
+ unconditional_conditioning=uc_full,
120
+ x_T=start_code,
121
+ )
122
+ x_samples_ddim = model.decode_first_stage(samples_cfg)
123
+
124
+ result = torch.clamp((x_samples_ddim + 1.0) / 2.0,
125
+ min=0.0, max=1.0)
126
+
127
+ result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255
128
+ return [put_watermark(Image.fromarray(img.astype(np.uint8)), wm_encoder) for img in result]
129
+
130
+ def pad_image(input_image):
131
+ pad_w, pad_h = np.max(((2, 2), np.ceil(
132
+ np.array(input_image.size) / 64).astype(int)), axis=0) * 64 - input_image.size
133
+ im_padded = Image.fromarray(
134
+ np.pad(np.array(input_image), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
135
+ return im_padded
136
+
137
+ def predict(input_image, prompt, ddim_steps, num_samples, scale, seed):
138
+ init_image = input_image["image"].convert("RGB")
139
+ init_mask = input_image["mask"].convert("RGB")
140
+ image = pad_image(init_image) # pad to an integer multiple of 64
141
+ mask = pad_image(init_mask) # pad to an integer multiple of 64
142
+ width, height = image.size
143
+ print("Inpainting...", width, height)
144
+
145
+ result = inpaint(
146
+ sampler=sampler,
147
+ image=image,
148
+ mask=mask,
149
+ prompt=prompt,
150
+ seed=seed,
151
+ scale=scale,
152
+ ddim_steps=ddim_steps,
153
+ num_samples=num_samples,
154
+ h=height, w=width
155
+ )
156
+
157
+ return result
158
+
159
+
160
+ sampler = initialize_model(sys.argv[1], sys.argv[2])
161
+
162
+ block = gr.Blocks().queue()
163
+ with block:
164
+ with gr.Row():
165
+ gr.Markdown("## Stable Diffusion Inpainting")
166
+
167
+ with gr.Row():
168
+ with gr.Column():
169
+ input_image = gr.Image(source='upload', tool='sketch', type="pil")
170
+ prompt = gr.Textbox(label="Prompt")
171
+ run_button = gr.Button(label="Run")
172
+ with gr.Accordion("Advanced options", open=False):
173
+ num_samples = gr.Slider(
174
+ label="Images", minimum=1, maximum=4, value=4, step=1)
175
+ ddim_steps = gr.Slider(label="Steps", minimum=1,
176
+ maximum=50, value=45, step=1)
177
+ scale = gr.Slider(
178
+ label="Guidance Scale", minimum=0.1, maximum=30.0, value=10, step=0.1
179
+ )
180
+ seed = gr.Slider(
181
+ label="Seed",
182
+ minimum=0,
183
+ maximum=2147483647,
184
+ step=1,
185
+ randomize=True,
186
+ )
187
+ with gr.Column():
188
+ gallery = gr.Gallery(label="Generated images", show_label=False).style(
189
+ grid=[2], height="auto")
190
+
191
+ run_button.click(fn=predict, inputs=[
192
+ input_image, prompt, ddim_steps, num_samples, scale, seed], outputs=[gallery])
193
+
194
+
195
+ block.launch()
repositories/stable-diffusion-stability-ai/scripts/gradio/superresolution.py ADDED
@@ -0,0 +1,197 @@
1
+ import sys
2
+ import torch
3
+ import numpy as np
4
+ import gradio as gr
5
+ from PIL import Image
6
+ from omegaconf import OmegaConf
7
+ from einops import repeat, rearrange
8
+ from pytorch_lightning import seed_everything
9
+ from imwatermark import WatermarkEncoder
10
+
11
+ from scripts.txt2img import put_watermark
12
+ from ldm.models.diffusion.ddim import DDIMSampler
13
+ from ldm.models.diffusion.ddpm import LatentUpscaleDiffusion, LatentUpscaleFinetuneDiffusion
14
+ from ldm.util import exists, instantiate_from_config
15
+
16
+
17
+ torch.set_grad_enabled(False)
18
+
19
+
20
+ def initialize_model(config, ckpt):
21
+ config = OmegaConf.load(config)
22
+ model = instantiate_from_config(config.model)
23
+ model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
24
+
25
+ device = torch.device(
26
+ "cuda") if torch.cuda.is_available() else torch.device("cpu")
27
+ model = model.to(device)
28
+ sampler = DDIMSampler(model)
29
+ return sampler
30
+
31
+
32
+ def make_batch_sd(
33
+ image,
34
+ txt,
35
+ device,
36
+ num_samples=1,
37
+ ):
38
+ image = np.array(image.convert("RGB"))
39
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
40
+ batch = {
41
+ "lr": rearrange(image, 'h w c -> 1 c h w'),
42
+ "txt": num_samples * [txt],
43
+ }
44
+ batch["lr"] = repeat(batch["lr"].to(device=device),
45
+ "1 ... -> n ...", n=num_samples)
46
+ return batch
47
+
48
+
49
+ def make_noise_augmentation(model, batch, noise_level=None):
50
+ x_low = batch[model.low_scale_key]
51
+ x_low = x_low.to(memory_format=torch.contiguous_format).float()
52
+ x_aug, noise_level = model.low_scale_model(x_low, noise_level)
53
+ return x_aug, noise_level
54
+
55
+
56
+ def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callback=None, eta=0., noise_level=None):
57
+ device = torch.device(
58
+ "cuda") if torch.cuda.is_available() else torch.device("cpu")
59
+ model = sampler.model
60
+ seed_everything(seed)
61
+ prng = np.random.RandomState(seed)
62
+ start_code = prng.randn(num_samples, model.channels, h, w)
63
+ start_code = torch.from_numpy(start_code).to(
64
+ device=device, dtype=torch.float32)
65
+
66
+ print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
67
+ wm = "SDV2"
68
+ wm_encoder = WatermarkEncoder()
69
+ wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
70
+ with torch.no_grad(),\
71
+ torch.autocast("cuda"):
72
+ batch = make_batch_sd(
73
+ image, txt=prompt, device=device, num_samples=num_samples)
74
+ c = model.cond_stage_model.encode(batch["txt"])
75
+ c_cat = list()
76
+ if isinstance(model, LatentUpscaleFinetuneDiffusion):
77
+ for ck in model.concat_keys:
78
+ cc = batch[ck]
79
+ if exists(model.reshuffle_patch_size):
80
+ assert isinstance(model.reshuffle_patch_size, int)
81
+ cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w',
82
+ p1=model.reshuffle_patch_size, p2=model.reshuffle_patch_size)
83
+ c_cat.append(cc)
84
+ c_cat = torch.cat(c_cat, dim=1)
85
+ # cond
86
+ cond = {"c_concat": [c_cat], "c_crossattn": [c]}
87
+ # uncond cond
88
+ uc_cross = model.get_unconditional_conditioning(num_samples, "")
89
+ uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
90
+ elif isinstance(model, LatentUpscaleDiffusion):
91
+ x_augment, noise_level = make_noise_augmentation(
92
+ model, batch, noise_level)
93
+ cond = {"c_concat": [x_augment],
94
+ "c_crossattn": [c], "c_adm": noise_level}
95
+ # uncond cond
96
+ uc_cross = model.get_unconditional_conditioning(num_samples, "")
97
+ uc_full = {"c_concat": [x_augment], "c_crossattn": [
98
+ uc_cross], "c_adm": noise_level}
99
+ else:
100
+ raise NotImplementedError()
101
+
102
+ shape = [model.channels, h, w]
103
+ samples, intermediates = sampler.sample(
104
+ steps,
105
+ num_samples,
106
+ shape,
107
+ cond,
108
+ verbose=False,
109
+ eta=eta,
110
+ unconditional_guidance_scale=scale,
111
+ unconditional_conditioning=uc_full,
112
+ x_T=start_code,
113
+ callback=callback
114
+ )
115
+ with torch.no_grad():
116
+ x_samples_ddim = model.decode_first_stage(samples)
117
+ result = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
118
+ result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255
119
+ return [put_watermark(Image.fromarray(img.astype(np.uint8)), wm_encoder) for img in result]
120
+
121
+
122
+ def pad_image(input_image):
123
+ pad_w, pad_h = np.max(((2, 2), np.ceil(
124
+ np.array(input_image.size) / 64).astype(int)), axis=0) * 64 - input_image.size
125
+ im_padded = Image.fromarray(
126
+ np.pad(np.array(input_image), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
127
+ return im_padded
128
+
129
+
130
+ def predict(input_image, prompt, steps, num_samples, scale, seed, eta, noise_level):
131
+ init_image = input_image.convert("RGB")
132
+ image = pad_image(init_image) # pad to an integer multiple of 64
133
+ width, height = image.size
134
+
135
+ noise_level = torch.Tensor(
136
+ num_samples * [noise_level]).to(sampler.model.device).long()
137
+ sampler.make_schedule(steps, ddim_eta=eta, verbose=True)
138
+ result = paint(
139
+ sampler=sampler,
140
+ image=image,
141
+ prompt=prompt,
142
+ seed=seed,
143
+ scale=scale,
144
+ h=height, w=width, steps=steps,
145
+ num_samples=num_samples,
146
+ callback=None,
147
+ noise_level=noise_level
148
+ )
149
+ return result
150
+
151
+
152
+ sampler = initialize_model(sys.argv[1], sys.argv[2])
153
+
154
+ block = gr.Blocks().queue()
155
+ with block:
156
+ with gr.Row():
157
+ gr.Markdown("## Stable Diffusion Upscaling")
158
+
159
+ with gr.Row():
160
+ with gr.Column():
161
+ input_image = gr.Image(source='upload', type="pil")
162
+ gr.Markdown(
163
+ "Tip: Add a description of the object that should be upscaled, e.g.: 'a professional photograph of a cat")
164
+ prompt = gr.Textbox(label="Prompt")
165
+ run_button = gr.Button(label="Run")
166
+ with gr.Accordion("Advanced options", open=False):
167
+ num_samples = gr.Slider(
168
+ label="Number of Samples", minimum=1, maximum=4, value=1, step=1)
169
+ steps = gr.Slider(label="DDIM Steps", minimum=2,
170
+ maximum=200, value=75, step=1)
171
+ scale = gr.Slider(
172
+ label="Scale", minimum=0.1, maximum=30.0, value=10, step=0.1
173
+ )
174
+ seed = gr.Slider(
175
+ label="Seed",
176
+ minimum=0,
177
+ maximum=2147483647,
178
+ step=1,
179
+ randomize=True,
180
+ )
181
+ eta = gr.Number(label="eta (DDIM)",
182
+ value=0.0, min=0.0, max=1.0)
183
+ noise_level = None
184
+ if isinstance(sampler.model, LatentUpscaleDiffusion):
185
+ # TODO: make this work for all models
186
+ noise_level = gr.Number(
187
+ label="Noise Augmentation", min=0, max=350, value=20, step=1)
188
+
189
+ with gr.Column():
190
+ gallery = gr.Gallery(label="Generated images", show_label=False).style(
191
+ grid=[2], height="auto")
192
+
193
+ run_button.click(fn=predict, inputs=[
194
+ input_image, prompt, steps, num_samples, scale, seed, eta, noise_level], outputs=[gallery])
195
+
196
+
197
+ block.launch()
repositories/stable-diffusion-stability-ai/scripts/img2img.py ADDED
@@ -0,0 +1,279 @@
1
+ """make variations of input image"""
2
+
3
+ import argparse, os
4
+ import PIL
5
+ import torch
6
+ import numpy as np
7
+ from omegaconf import OmegaConf
8
+ from PIL import Image
9
+ from tqdm import tqdm, trange
10
+ from itertools import islice
11
+ from einops import rearrange, repeat
12
+ from torchvision.utils import make_grid
13
+ from torch import autocast
14
+ from contextlib import nullcontext
15
+ from pytorch_lightning import seed_everything
16
+ from imwatermark import WatermarkEncoder
17
+
18
+
19
+ from scripts.txt2img import put_watermark
20
+ from ldm.util import instantiate_from_config
21
+ from ldm.models.diffusion.ddim import DDIMSampler
22
+
23
+
24
+ def chunk(it, size):
25
+ it = iter(it)
26
+ return iter(lambda: tuple(islice(it, size)), ())
27
+
28
+
29
+ def load_model_from_config(config, ckpt, verbose=False):
30
+ print(f"Loading model from {ckpt}")
31
+ pl_sd = torch.load(ckpt, map_location="cpu")
32
+ if "global_step" in pl_sd:
33
+ print(f"Global Step: {pl_sd['global_step']}")
34
+ sd = pl_sd["state_dict"]
35
+ model = instantiate_from_config(config.model)
36
+ m, u = model.load_state_dict(sd, strict=False)
37
+ if len(m) > 0 and verbose:
38
+ print("missing keys:")
39
+ print(m)
40
+ if len(u) > 0 and verbose:
41
+ print("unexpected keys:")
42
+ print(u)
43
+
44
+ model.cuda()
45
+ model.eval()
46
+ return model
47
+
48
+
49
+ def load_img(path):
50
+ image = Image.open(path).convert("RGB")
51
+ w, h = image.size
52
+ print(f"loaded input image of size ({w}, {h}) from {path}")
53
+ w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
54
+ image = image.resize((w, h), resample=PIL.Image.LANCZOS)
55
+ image = np.array(image).astype(np.float32) / 255.0
56
+ image = image[None].transpose(0, 3, 1, 2)
57
+ image = torch.from_numpy(image)
58
+ return 2. * image - 1.
59
+
60
+
61
+ def main():
62
+ parser = argparse.ArgumentParser()
63
+
64
+ parser.add_argument(
65
+ "--prompt",
66
+ type=str,
67
+ nargs="?",
68
+ default="a painting of a virus monster playing guitar",
69
+ help="the prompt to render"
70
+ )
71
+
72
+ parser.add_argument(
73
+ "--init-img",
74
+ type=str,
75
+ nargs="?",
76
+ help="path to the input image"
77
+ )
78
+
79
+ parser.add_argument(
80
+ "--outdir",
81
+ type=str,
82
+ nargs="?",
83
+ help="dir to write results to",
84
+ default="outputs/img2img-samples"
85
+ )
86
+
87
+ parser.add_argument(
88
+ "--ddim_steps",
89
+ type=int,
90
+ default=50,
91
+ help="number of ddim sampling steps",
92
+ )
93
+
94
+ parser.add_argument(
95
+ "--fixed_code",
96
+ action='store_true',
97
+ help="if enabled, uses the same starting code across all samples ",
98
+ )
99
+
100
+ parser.add_argument(
101
+ "--ddim_eta",
102
+ type=float,
103
+ default=0.0,
104
+ help="ddim eta (eta=0.0 corresponds to deterministic sampling",
105
+ )
106
+ parser.add_argument(
107
+ "--n_iter",
108
+ type=int,
109
+ default=1,
110
+ help="sample this often",
111
+ )
112
+
113
+ parser.add_argument(
114
+ "--C",
115
+ type=int,
116
+ default=4,
117
+ help="latent channels",
118
+ )
119
+ parser.add_argument(
120
+ "--f",
121
+ type=int,
122
+ default=8,
123
+ help="downsampling factor, most often 8 or 16",
124
+ )
125
+
126
+ parser.add_argument(
127
+ "--n_samples",
128
+ type=int,
129
+ default=2,
130
+ help="how many samples to produce for each given prompt. A.k.a batch size",
131
+ )
132
+
133
+ parser.add_argument(
134
+ "--n_rows",
135
+ type=int,
136
+ default=0,
137
+ help="rows in the grid (default: n_samples)",
138
+ )
139
+
140
+ parser.add_argument(
141
+ "--scale",
142
+ type=float,
143
+ default=9.0,
144
+ help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
145
+ )
146
+
147
+ parser.add_argument(
148
+ "--strength",
149
+ type=float,
150
+ default=0.8,
151
+ help="strength for noising/unnoising. 1.0 corresponds to full destruction of information in init image",
152
+ )
153
+
154
+ parser.add_argument(
155
+ "--from-file",
156
+ type=str,
157
+ help="if specified, load prompts from this file",
158
+ )
159
+ parser.add_argument(
160
+ "--config",
161
+ type=str,
162
+ default="configs/stable-diffusion/v2-inference.yaml",
163
+ help="path to config which constructs model",
164
+ )
165
+ parser.add_argument(
166
+ "--ckpt",
167
+ type=str,
168
+ help="path to checkpoint of model",
169
+ )
170
+ parser.add_argument(
171
+ "--seed",
172
+ type=int,
173
+ default=42,
174
+ help="the seed (for reproducible sampling)",
175
+ )
176
+ parser.add_argument(
177
+ "--precision",
178
+ type=str,
179
+ help="evaluate at this precision",
180
+ choices=["full", "autocast"],
181
+ default="autocast"
182
+ )
183
+
184
+ opt = parser.parse_args()
185
+ seed_everything(opt.seed)
186
+
187
+ config = OmegaConf.load(f"{opt.config}")
188
+ model = load_model_from_config(config, f"{opt.ckpt}")
189
+
190
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
191
+ model = model.to(device)
192
+
193
+ sampler = DDIMSampler(model)
194
+
195
+ os.makedirs(opt.outdir, exist_ok=True)
196
+ outpath = opt.outdir
197
+
198
+ print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
199
+ wm = "SDV2"
200
+ wm_encoder = WatermarkEncoder()
201
+ wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
202
+
203
+ batch_size = opt.n_samples
204
+ n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
205
+ if not opt.from_file:
206
+ prompt = opt.prompt
207
+ assert prompt is not None
208
+ data = [batch_size * [prompt]]
209
+
210
+ else:
211
+ print(f"reading prompts from {opt.from_file}")
212
+ with open(opt.from_file, "r") as f:
213
+ data = f.read().splitlines()
214
+ data = list(chunk(data, batch_size))
215
+
216
+ sample_path = os.path.join(outpath, "samples")
217
+ os.makedirs(sample_path, exist_ok=True)
218
+ base_count = len(os.listdir(sample_path))
219
+ grid_count = len(os.listdir(outpath)) - 1
220
+
221
+ assert os.path.isfile(opt.init_img)
222
+ init_image = load_img(opt.init_img).to(device)
223
+ init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
224
+ init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
225
+
226
+ sampler.make_schedule(ddim_num_steps=opt.ddim_steps, ddim_eta=opt.ddim_eta, verbose=False)
227
+
228
+ assert 0. <= opt.strength <= 1., 'can only work with strength in [0.0, 1.0]'
229
+ t_enc = int(opt.strength * opt.ddim_steps)
230
+ print(f"target t_enc is {t_enc} steps")
231
+
232
+ precision_scope = autocast if opt.precision == "autocast" else nullcontext
233
+ with torch.no_grad():
234
+ with precision_scope("cuda"):
235
+ with model.ema_scope():
236
+ all_samples = list()
237
+ for n in trange(opt.n_iter, desc="Sampling"):
238
+ for prompts in tqdm(data, desc="data"):
239
+ uc = None
240
+ if opt.scale != 1.0:
241
+ uc = model.get_learned_conditioning(batch_size * [""])
242
+ if isinstance(prompts, tuple):
243
+ prompts = list(prompts)
244
+ c = model.get_learned_conditioning(prompts)
245
+
246
+ # encode (scaled latent)
247
+ z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(device))
248
+ # decode it
249
+ samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale,
250
+ unconditional_conditioning=uc, )
251
+
252
+ x_samples = model.decode_first_stage(samples)
253
+ x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
254
+
255
+ for x_sample in x_samples:
256
+ x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
257
+ img = Image.fromarray(x_sample.astype(np.uint8))
258
+ img = put_watermark(img, wm_encoder)
259
+ img.save(os.path.join(sample_path, f"{base_count:05}.png"))
260
+ base_count += 1
261
+ all_samples.append(x_samples)
262
+
263
+ # additionally, save as grid
264
+ grid = torch.stack(all_samples, 0)
265
+ grid = rearrange(grid, 'n b c h w -> (n b) c h w')
266
+ grid = make_grid(grid, nrow=n_rows)
267
+
268
+ # to image
269
+ grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
270
+ grid = Image.fromarray(grid.astype(np.uint8))
271
+ grid = put_watermark(grid, wm_encoder)
272
+ grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
273
+ grid_count += 1
274
+
275
+ print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")
276
+
277
+
278
+ if __name__ == "__main__":
279
+ main()
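The script above noises the encoded init image for t_enc = int(strength * ddim_steps) steps and then denoises it under the prompt. The pixel-range bookkeeping it relies on is easy to get wrong, so here is a minimal, self-contained sketch (the image paths are placeholders, not repository files) of the same round trip that load_img() and the save loop perform: uint8 [0, 255] in, [-1, 1] for the model, back to [0, 1] for saving.

import numpy as np
import PIL.Image as Image
import torch

img = Image.open("example.png").convert("RGB")                 # placeholder input path
w, h = map(lambda x: x - x % 64, img.size)                      # snap to integer multiple of 64
img = img.resize((w, h), resample=Image.LANCZOS)

x = torch.from_numpy(np.array(img, dtype=np.float32) / 255.0)   # h w c in [0, 1]
x = x.permute(2, 0, 1)[None]                                    # 1 c h w
x = 2.0 * x - 1.0                                               # [-1, 1], the model's input range

y = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)              # back to [0, 1], as in the save loop
out = (255.0 * y[0].permute(1, 2, 0).numpy()).astype(np.uint8)
Image.fromarray(out).save("roundtrip.png")                      # placeholder output path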
repositories/stable-diffusion-stability-ai/scripts/streamlit/depth2img.py ADDED
@@ -0,0 +1,157 @@
1
+ import sys
2
+ import torch
3
+ import numpy as np
4
+ import streamlit as st
5
+ from PIL import Image
6
+ from omegaconf import OmegaConf
7
+ from einops import repeat, rearrange
8
+ from pytorch_lightning import seed_everything
9
+ from imwatermark import WatermarkEncoder
10
+
11
+ from scripts.txt2img import put_watermark
12
+ from ldm.util import instantiate_from_config
13
+ from ldm.models.diffusion.ddim import DDIMSampler
14
+ from ldm.data.util import AddMiDaS
15
+
16
+ torch.set_grad_enabled(False)
17
+
18
+
19
+ @st.cache(allow_output_mutation=True)
20
+ def initialize_model(config, ckpt):
21
+ config = OmegaConf.load(config)
22
+ model = instantiate_from_config(config.model)
23
+ model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
24
+
25
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
26
+ model = model.to(device)
27
+ sampler = DDIMSampler(model)
28
+ return sampler
29
+
30
+
31
+ def make_batch_sd(
32
+ image,
33
+ txt,
34
+ device,
35
+ num_samples=1,
36
+ model_type="dpt_hybrid"
37
+ ):
38
+ image = np.array(image.convert("RGB"))
39
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
40
+ # sample['jpg'] is tensor hwc in [-1, 1] at this point
41
+ midas_trafo = AddMiDaS(model_type=model_type)
42
+ batch = {
43
+ "jpg": image,
44
+ "txt": num_samples * [txt],
45
+ }
46
+ batch = midas_trafo(batch)
47
+ batch["jpg"] = rearrange(batch["jpg"], 'h w c -> 1 c h w')
48
+ batch["jpg"] = repeat(batch["jpg"].to(device=device), "1 ... -> n ...", n=num_samples)
49
+ batch["midas_in"] = repeat(torch.from_numpy(batch["midas_in"][None, ...]).to(device=device), "1 ... -> n ...", n=num_samples)
50
+ return batch
51
+
52
+
53
+ def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=None,
54
+ do_full_sample=False):
55
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
56
+ model = sampler.model
57
+ seed_everything(seed)
58
+
59
+ print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
60
+ wm = "SDV2"
61
+ wm_encoder = WatermarkEncoder()
62
+ wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
63
+
64
+ with torch.no_grad(),\
65
+ torch.autocast("cuda"):
66
+ batch = make_batch_sd(image, txt=prompt, device=device, num_samples=num_samples)
67
+ z = model.get_first_stage_encoding(model.encode_first_stage(batch[model.first_stage_key])) # move to latent space
68
+ c = model.cond_stage_model.encode(batch["txt"])
69
+ c_cat = list()
70
+ for ck in model.concat_keys:
71
+ cc = batch[ck]
72
+ cc = model.depth_model(cc)
73
+ depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3],
74
+ keepdim=True)
75
+ display_depth = (cc - depth_min) / (depth_max - depth_min)
76
+ st.image(Image.fromarray((display_depth[0, 0, ...].cpu().numpy() * 255.).astype(np.uint8)))
77
+ cc = torch.nn.functional.interpolate(
78
+ cc,
79
+ size=z.shape[2:],
80
+ mode="bicubic",
81
+ align_corners=False,
82
+ )
83
+ depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3],
84
+ keepdim=True)
85
+ cc = 2. * (cc - depth_min) / (depth_max - depth_min) - 1.
86
+ c_cat.append(cc)
87
+ c_cat = torch.cat(c_cat, dim=1)
88
+ # cond
89
+ cond = {"c_concat": [c_cat], "c_crossattn": [c]}
90
+
91
+ # uncond cond
92
+ uc_cross = model.get_unconditional_conditioning(num_samples, "")
93
+ uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
94
+ if not do_full_sample:
95
+ # encode (scaled latent)
96
+ z_enc = sampler.stochastic_encode(z, torch.tensor([t_enc] * num_samples).to(model.device))
97
+ else:
98
+ z_enc = torch.randn_like(z)
99
+ # decode it
100
+ samples = sampler.decode(z_enc, cond, t_enc, unconditional_guidance_scale=scale,
101
+ unconditional_conditioning=uc_full, callback=callback)
102
+ x_samples_ddim = model.decode_first_stage(samples)
103
+ result = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
104
+ result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255
105
+ return [put_watermark(Image.fromarray(img.astype(np.uint8)), wm_encoder) for img in result]
106
+
107
+
108
+ def run():
109
+ st.title("Stable Diffusion Depth2Img")
110
+ # run via streamlit run scripts/streamlit/depth2img.py <path-to-config> <path-to-ckpt>
111
+ sampler = initialize_model(sys.argv[1], sys.argv[2])
112
+
113
+ image = st.file_uploader("Image", ["jpg", "png"])
114
+ if image:
115
+ image = Image.open(image)
116
+ w, h = image.size
117
+ st.text(f"loaded input image of size ({w}, {h})")
118
+ width, height = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
119
+ image = image.resize((width, height))
120
+ st.text(f"resized input image to size ({width}, {height} (w, h))")
121
+ st.image(image)
122
+
123
+ prompt = st.text_input("Prompt")
124
+
125
+ seed = st.number_input("Seed", min_value=0, max_value=1000000, value=0)
126
+ num_samples = st.number_input("Number of Samples", min_value=1, max_value=64, value=1)
127
+ scale = st.slider("Scale", min_value=0.1, max_value=30.0, value=9.0, step=0.1)
128
+ steps = st.slider("DDIM Steps", min_value=0, max_value=50, value=50, step=1)
129
+ strength = st.slider("Strength", min_value=0., max_value=1., value=0.9)
130
+
131
+ t_progress = st.progress(0)
132
+ def t_callback(t):
133
+ t_progress.progress(min((t + 1) / t_enc, 1.))
134
+
135
+ assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
136
+ do_full_sample = strength == 1.
137
+ t_enc = min(int(strength * steps), steps-1)
138
+ sampler.make_schedule(steps, ddim_eta=0., verbose=True)
139
+ if st.button("Sample"):
140
+ result = paint(
141
+ sampler=sampler,
142
+ image=image,
143
+ prompt=prompt,
144
+ t_enc=t_enc,
145
+ seed=seed,
146
+ scale=scale,
147
+ num_samples=num_samples,
148
+ callback=t_callback,
149
+ do_full_sample=do_full_sample,
150
+ )
151
+ st.write("Result")
152
+ for image in result:
153
+ st.image(image, output_format='PNG')
154
+
155
+
156
+ if __name__ == "__main__":
157
+ run()
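paint() above conditions the UNet on a MiDaS depth map that is rescaled per sample to [-1, 1] before being concatenated into c_concat. A minimal sketch of just that normalization step, using a random tensor as a stand-in for the depth_model output (the tensor and its shape are illustrative, not repository values):

import torch

depth = torch.rand(2, 1, 64, 64) * 10.0                     # stand-in for model.depth_model output
d_min = torch.amin(depth, dim=[1, 2, 3], keepdim=True)      # per-sample minimum
d_max = torch.amax(depth, dim=[1, 2, 3], keepdim=True)      # per-sample maximum
depth_norm = 2.0 * (depth - d_min) / (d_max - d_min) - 1.0  # rescaled to [-1, 1] per sample
print(depth_norm.min().item(), depth_norm.max().item())     # -1.0 1.0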
repositories/stable-diffusion-stability-ai/scripts/streamlit/inpainting.py ADDED
@@ -0,0 +1,195 @@
1
+ import sys
2
+ import cv2
3
+ import torch
4
+ import numpy as np
5
+ import streamlit as st
6
+ from PIL import Image
7
+ from omegaconf import OmegaConf
8
+ from einops import repeat
9
+ from streamlit_drawable_canvas import st_canvas
10
+ from imwatermark import WatermarkEncoder
11
+
12
+ from ldm.models.diffusion.ddim import DDIMSampler
13
+ from ldm.util import instantiate_from_config
14
+
15
+
16
+ torch.set_grad_enabled(False)
17
+
18
+
19
+ def put_watermark(img, wm_encoder=None):
20
+ if wm_encoder is not None:
21
+ img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
22
+ img = wm_encoder.encode(img, 'dwtDct')
23
+ img = Image.fromarray(img[:, :, ::-1])
24
+ return img
25
+
26
+
27
+ @st.cache(allow_output_mutation=True)
28
+ def initialize_model(config, ckpt):
29
+ config = OmegaConf.load(config)
30
+ model = instantiate_from_config(config.model)
31
+
32
+ model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
33
+
34
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
35
+ model = model.to(device)
36
+ sampler = DDIMSampler(model)
37
+
38
+ return sampler
39
+
40
+
41
+ def make_batch_sd(
42
+ image,
43
+ mask,
44
+ txt,
45
+ device,
46
+ num_samples=1):
47
+ image = np.array(image.convert("RGB"))
48
+ image = image[None].transpose(0, 3, 1, 2)
49
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
50
+
51
+ mask = np.array(mask.convert("L"))
52
+ mask = mask.astype(np.float32) / 255.0
53
+ mask = mask[None, None]
54
+ mask[mask < 0.5] = 0
55
+ mask[mask >= 0.5] = 1
56
+ mask = torch.from_numpy(mask)
57
+
58
+ masked_image = image * (mask < 0.5)
59
+
60
+ batch = {
61
+ "image": repeat(image.to(device=device), "1 ... -> n ...", n=num_samples),
62
+ "txt": num_samples * [txt],
63
+ "mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
64
+ "masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples),
65
+ }
66
+ return batch
67
+
68
+
69
+ def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1, w=512, h=512, eta=1.):
70
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
71
+ model = sampler.model
72
+
73
+ print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
74
+ wm = "SDV2"
75
+ wm_encoder = WatermarkEncoder()
76
+ wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
77
+
78
+ prng = np.random.RandomState(seed)
79
+ start_code = prng.randn(num_samples, 4, h // 8, w // 8)
80
+ start_code = torch.from_numpy(start_code).to(device=device, dtype=torch.float32)
81
+
82
+ with torch.no_grad(), \
83
+ torch.autocast("cuda"):
84
+ batch = make_batch_sd(image, mask, txt=prompt, device=device, num_samples=num_samples)
85
+
86
+ c = model.cond_stage_model.encode(batch["txt"])
87
+
88
+ c_cat = list()
89
+ for ck in model.concat_keys:
90
+ cc = batch[ck].float()
91
+ if ck != model.masked_image_key:
92
+ bchw = [num_samples, 4, h // 8, w // 8]
93
+ cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
94
+ else:
95
+ cc = model.get_first_stage_encoding(model.encode_first_stage(cc))
96
+ c_cat.append(cc)
97
+ c_cat = torch.cat(c_cat, dim=1)
98
+
99
+ # cond
100
+ cond = {"c_concat": [c_cat], "c_crossattn": [c]}
101
+
102
+ # uncond cond
103
+ uc_cross = model.get_unconditional_conditioning(num_samples, "")
104
+ uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
105
+
106
+ shape = [model.channels, h // 8, w // 8]
107
+ samples_cfg, intermediates = sampler.sample(
108
+ ddim_steps,
109
+ num_samples,
110
+ shape,
111
+ cond,
112
+ verbose=False,
113
+ eta=eta,
114
+ unconditional_guidance_scale=scale,
115
+ unconditional_conditioning=uc_full,
116
+ x_T=start_code,
117
+ )
118
+ x_samples_ddim = model.decode_first_stage(samples_cfg)
119
+
120
+ result = torch.clamp((x_samples_ddim + 1.0) / 2.0,
121
+ min=0.0, max=1.0)
122
+
123
+ result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255
124
+ return [put_watermark(Image.fromarray(img.astype(np.uint8)), wm_encoder) for img in result]
125
+
126
+
127
+ def run():
128
+ st.title("Stable Diffusion Inpainting")
129
+
130
+ sampler = initialize_model(sys.argv[1], sys.argv[2])
131
+
132
+ image = st.file_uploader("Image", ["jpg", "png"])
133
+ if image:
134
+ image = Image.open(image)
135
+ w, h = image.size
136
+ print(f"loaded input image of size ({w}, {h})")
137
+ width, height = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
138
+ image = image.resize((width, height))
139
+
140
+ prompt = st.text_input("Prompt")
141
+
142
+ seed = st.number_input("Seed", min_value=0, max_value=1000000, value=0)
143
+ num_samples = st.number_input("Number of Samples", min_value=1, max_value=64, value=1)
144
+ scale = st.slider("Scale", min_value=0.1, max_value=30.0, value=10., step=0.1)
145
+ ddim_steps = st.slider("DDIM Steps", min_value=0, max_value=50, value=50, step=1)
146
+ eta = st.sidebar.number_input("eta (DDIM)", value=0., min_value=0., max_value=1.)
147
+
148
+ fill_color = "rgba(255, 255, 255, 0.0)"
149
+ stroke_width = st.number_input("Brush Size",
150
+ value=64,
151
+ min_value=1,
152
+ max_value=100)
153
+ stroke_color = "rgba(255, 255, 255, 1.0)"
154
+ bg_color = "rgba(0, 0, 0, 1.0)"
155
+ drawing_mode = "freedraw"
156
+
157
+ st.write("Canvas")
158
+ st.caption(
159
+ "Draw a mask to inpaint, then click the 'Send to Streamlit' button (bottom left, with an arrow on it).")
160
+ canvas_result = st_canvas(
161
+ fill_color=fill_color,
162
+ stroke_width=stroke_width,
163
+ stroke_color=stroke_color,
164
+ background_color=bg_color,
165
+ background_image=image,
166
+ update_streamlit=False,
167
+ height=height,
168
+ width=width,
169
+ drawing_mode=drawing_mode,
170
+ key="canvas",
171
+ )
172
+ if canvas_result:
173
+ mask = canvas_result.image_data
174
+ mask = mask[:, :, -1] > 0
175
+ if mask.sum() > 0:
176
+ mask = Image.fromarray(mask)
177
+
178
+ result = inpaint(
179
+ sampler=sampler,
180
+ image=image,
181
+ mask=mask,
182
+ prompt=prompt,
183
+ seed=seed,
184
+ scale=scale,
185
+ ddim_steps=ddim_steps,
186
+ num_samples=num_samples,
187
+ h=height, w=width, eta=eta
188
+ )
189
+ st.write("Inpainted")
190
+ for image in result:
191
+ st.image(image, output_format='PNG')
192
+
193
+
194
+ if __name__ == "__main__":
195
+ run()
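make_batch_sd() above binarizes the drawn mask at 0.5 and builds the masked image that the concat conditioning consumes. A short, self-contained sketch of that preparation with synthetic data (the random image and mask are placeholders):

import numpy as np
import torch

image = torch.rand(1, 3, 128, 128) * 2.0 - 1.0              # stand-in image already in [-1, 1]
mask = np.random.rand(128, 128).astype(np.float32)          # stand-in for the drawn canvas mask

mask = mask[None, None]                                      # 1 x 1 x h x w
mask[mask < 0.5] = 0                                         # binarize at 0.5, as in make_batch_sd
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)

masked_image = image * (mask < 0.5)                          # keep only the pixels outside the mask
print(mask.shape, masked_image.shape)                        # torch.Size([1, 1, 128, 128]) torch.Size([1, 3, 128, 128])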
repositories/stable-diffusion-stability-ai/scripts/streamlit/stableunclip.py ADDED
@@ -0,0 +1,416 @@
1
+ import importlib
2
+ import streamlit as st
3
+ import torch
4
+ import cv2
5
+ import numpy as np
6
+ import PIL
7
+ from omegaconf import OmegaConf
8
+ from PIL import Image
9
+ from tqdm import trange
10
+ import io, os
11
+ from torch import autocast
12
+ from einops import rearrange, repeat
13
+ from torchvision.utils import make_grid
14
+ from pytorch_lightning import seed_everything
15
+ from contextlib import nullcontext
16
+
17
+ from ldm.models.diffusion.ddim import DDIMSampler
18
+ from ldm.models.diffusion.plms import PLMSSampler
19
+ from ldm.models.diffusion.dpm_solver import DPMSolverSampler
20
+
21
+ torch.set_grad_enabled(False)
22
+
23
+ PROMPTS_ROOT = "scripts/prompts/"
24
+ SAVE_PATH = "outputs/demo/stable-unclip/"
25
+
26
+ VERSION2SPECS = {
27
+ "Stable unCLIP-L": {"H": 768, "W": 768, "C": 4, "f": 8},
28
+ "Stable unOpenCLIP-H": {"H": 768, "W": 768, "C": 4, "f": 8},
29
+ "Full Karlo": {}
30
+ }
31
+
32
+
33
+ def get_obj_from_str(string, reload=False):
34
+ module, cls = string.rsplit(".", 1)
35
+ importlib.invalidate_caches()
36
+ if reload:
37
+ module_imp = importlib.import_module(module)
38
+ importlib.reload(module_imp)
39
+ return getattr(importlib.import_module(module, package=None), cls)
40
+
41
+
42
+ def instantiate_from_config(config):
43
+ if not "target" in config:
44
+ raise KeyError("Expected key `target` to instantiate.")
45
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
46
+
47
+
48
+ def get_interactive_image(key=None):
49
+ image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
50
+ if image is not None:
51
+ image = Image.open(image)
52
+ if not image.mode == "RGB":
53
+ image = image.convert("RGB")
54
+ return image
55
+
56
+
57
+ def load_img(display=True, key=None):
58
+ image = get_interactive_image(key=key)
59
+ if display:
60
+ st.image(image)
61
+ w, h = image.size
62
+ print(f"loaded input image of size ({w}, {h})")
63
+ w, h = map(lambda x: x - x % 64, (w, h))
64
+ image = image.resize((w, h), resample=PIL.Image.LANCZOS)
65
+ image = np.array(image).astype(np.float32) / 255.0
66
+ image = image[None].transpose(0, 3, 1, 2)
67
+ image = torch.from_numpy(image)
68
+ return 2. * image - 1.
69
+
70
+
71
+ def get_init_img(batch_size=1, key=None):
72
+ init_image = load_img(key=key).cuda()
73
+ init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
74
+ return init_image
75
+
76
+
77
+ def sample(
78
+ model,
79
+ prompt,
80
+ n_runs=3,
81
+ n_samples=2,
82
+ H=512,
83
+ W=512,
84
+ C=4,
85
+ f=8,
86
+ scale=10.0,
87
+ ddim_steps=50,
88
+ ddim_eta=0.0,
89
+ callback=None,
90
+ skip_single_save=False,
91
+ save_grid=True,
92
+ ucg_schedule=None,
93
+ negative_prompt="",
94
+ adm_cond=None,
95
+ adm_uc=None,
96
+ use_full_precision=False,
97
+ only_adm_cond=False
98
+ ):
99
+ batch_size = n_samples
100
+ precision_scope = autocast if not use_full_precision else nullcontext
101
+ # decoderscope = autocast if not use_full_precision else nullcontext
102
+ if use_full_precision: st.warning(f"Running {model.__class__.__name__} at full precision.")
103
+ if isinstance(prompt, str):
104
+ prompt = [prompt]
105
+ prompts = batch_size * prompt
106
+
107
+ outputs = st.empty()
108
+
109
+ with precision_scope("cuda"):
110
+ with model.ema_scope():
111
+ all_samples = list()
112
+ for n in trange(n_runs, desc="Sampling"):
113
+ shape = [C, H // f, W // f]
114
+ if not only_adm_cond:
115
+ uc = None
116
+ if scale != 1.0:
117
+ uc = model.get_learned_conditioning(batch_size * [negative_prompt])
118
+ if isinstance(prompts, tuple):
119
+ prompts = list(prompts)
120
+ c = model.get_learned_conditioning(prompts)
121
+
122
+ if adm_cond is not None:
123
+ if adm_cond.shape[0] == 1:
124
+ adm_cond = repeat(adm_cond, '1 ... -> b ...', b=batch_size)
125
+ if adm_uc is None:
126
+ st.warning("Not guiding via c_adm")
127
+ adm_uc = adm_cond
128
+ else:
129
+ if adm_uc.shape[0] == 1:
130
+ adm_uc = repeat(adm_uc, '1 ... -> b ...', b=batch_size)
131
+ if not only_adm_cond:
132
+ c = {"c_crossattn": [c], "c_adm": adm_cond}
133
+ uc = {"c_crossattn": [uc], "c_adm": adm_uc}
134
+ else:
135
+ c = adm_cond
136
+ uc = adm_uc
137
+ samples_ddim, _ = sampler.sample(S=ddim_steps,
138
+ conditioning=c,
139
+ batch_size=batch_size,
140
+ shape=shape,
141
+ verbose=False,
142
+ unconditional_guidance_scale=scale,
143
+ unconditional_conditioning=uc,
144
+ eta=ddim_eta,
145
+ x_T=None,
146
+ callback=callback,
147
+ ucg_schedule=ucg_schedule
148
+ )
149
+ x_samples = model.decode_first_stage(samples_ddim)
150
+ x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
151
+
152
+ if not skip_single_save:
153
+ base_count = len(os.listdir(os.path.join(SAVE_PATH, "samples")))
154
+ for x_sample in x_samples:
155
+ x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
156
+ Image.fromarray(x_sample.astype(np.uint8)).save(
157
+ os.path.join(SAVE_PATH, "samples", f"{base_count:09}.png"))
158
+ base_count += 1
159
+
160
+ all_samples.append(x_samples)
161
+
162
+ # get grid of all samples
163
+ grid = torch.stack(all_samples, 0)
164
+ grid = rearrange(grid, 'n b c h w -> (n h) (b w) c')
165
+ outputs.image(grid.cpu().numpy())
166
+
167
+ # additionally, save grid
168
+ grid = Image.fromarray((255. * grid.cpu().numpy()).astype(np.uint8))
169
+ if save_grid:
170
+ grid_count = len(os.listdir(SAVE_PATH)) - 1
171
+ grid.save(os.path.join(SAVE_PATH, f'grid-{grid_count:06}.png'))
172
+
173
+ return x_samples
174
+
175
+
176
+ def make_oscillating_guidance_schedule(num_steps, max_weight=15., min_weight=1.):
177
+ schedule = list()
178
+ for i in range(num_steps):
179
+ if float(i / num_steps) < 0.1:
180
+ schedule.append(max_weight)
181
+ elif i % 2 == 0:
182
+ schedule.append(min_weight)
183
+ else:
184
+ schedule.append(max_weight)
185
+ print(f"OSCILLATING GUIDANCE SCHEDULE: \n {schedule}")
186
+ return schedule
187
+
188
+
189
+ def torch2np(x):
190
+ x = ((x + 1.0) * 127.5).clamp(0, 255).to(dtype=torch.uint8)
191
+ x = x.permute(0, 2, 3, 1).detach().cpu().numpy()
192
+ return x
193
+
194
+
195
+ @st.cache(allow_output_mutation=True, suppress_st_warning=True)
196
+ def init(version="Stable unCLIP-L", load_karlo_prior=False):
197
+ state = dict()
198
+ if not "model" in state:
199
+ if version == "Stable unCLIP-L":
200
+ config = "configs/stable-diffusion/v2-1-stable-unclip-l-inference.yaml"
201
+ ckpt = "checkpoints/sd21-unclip-l.ckpt"
202
+
203
+ elif version == "Stable unOpenCLIP-H":
204
+ config = "configs/stable-diffusion/v2-1-stable-unclip-h-inference.yaml"
205
+ ckpt = "checkpoints/sd21-unclip-h.ckpt"
206
+
207
+ elif version == "Full Karlo":
208
+ from ldm.modules.karlo.kakao.sampler import T2ISampler
209
+ st.info("Loading full KARLO..")
210
+ karlo = T2ISampler.from_pretrained(
211
+ root_dir="checkpoints/karlo_models",
212
+ clip_model_path="ViT-L-14.pt",
213
+ clip_stat_path="ViT-L-14_stats.th",
214
+ sampling_type="default",
215
+ )
216
+ state["karlo_prior"] = karlo
217
+ state["msg"] = "loaded full Karlo"
218
+ return state
219
+ else:
220
+ raise ValueError(f"version {version} unknown!")
221
+
222
+ config = OmegaConf.load(config)
223
+ model, msg = load_model_from_config(config, ckpt, vae_sd=None)
224
+ state["msg"] = msg
225
+
226
+ if load_karlo_prior:
227
+ from ldm.modules.karlo.kakao.sampler import PriorSampler
228
+ st.info("Loading KARLO CLIP prior...")
229
+ karlo_prior = PriorSampler.from_pretrained(
230
+ root_dir="checkpoints/karlo_models",
231
+ clip_model_path="ViT-L-14.pt",
232
+ clip_stat_path="ViT-L-14_stats.th",
233
+ sampling_type="default",
234
+ )
235
+ state["karlo_prior"] = karlo_prior
236
+ state["model"] = model
237
+ state["ckpt"] = ckpt
238
+ state["config"] = config
239
+ return state
240
+
241
+
242
+ def load_model_from_config(config, ckpt, verbose=False, vae_sd=None):
243
+ print(f"Loading model from {ckpt}")
244
+ pl_sd = torch.load(ckpt, map_location="cpu")
245
+ msg = None
246
+ if "global_step" in pl_sd:
247
+ msg = f"This is global step {pl_sd['global_step']}. "
248
+ if "model_ema.num_updates" in pl_sd["state_dict"]:
249
+ msg += f"And we got {pl_sd['state_dict']['model_ema.num_updates']} EMA updates."
250
+ global_step = pl_sd.get("global_step", "?")
251
+ sd = pl_sd["state_dict"]
252
+ if vae_sd is not None:
253
+ for k in sd.keys():
254
+ if "first_stage" in k:
255
+ sd[k] = vae_sd[k[len("first_stage_model."):]]
256
+
257
+ model = instantiate_from_config(config.model)
258
+ m, u = model.load_state_dict(sd, strict=False)
259
+ if len(m) > 0 and verbose:
260
+ print("missing keys:")
261
+ print(m)
262
+ if len(u) > 0 and verbose:
263
+ print("unexpected keys:")
264
+ print(u)
265
+
266
+ model.cuda()
267
+ model.eval()
268
+ print(f"Loaded global step {global_step}")
269
+ return model, msg
270
+
271
+
272
+ if __name__ == "__main__":
273
+ st.title("Stable unCLIP")
274
+ mode = "txt2img"
275
+ version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
276
+ use_karlo_prior = version in ["Stable unCLIP-L"] and st.checkbox("Use KARLO prior", False)
277
+ state = init(version=version, load_karlo_prior=use_karlo_prior)
278
+ prompt = st.text_input("Prompt", "a professional photograph")
279
+ negative_prompt = st.text_input("Negative Prompt", "")
280
+ scale = st.number_input("cfg-scale", value=10., min_value=-100., max_value=100.)
281
+ number_rows = st.number_input("num rows", value=2, min_value=1, max_value=10)
282
+ number_cols = st.number_input("num cols", value=2, min_value=1, max_value=10)
283
+ steps = st.sidebar.number_input("steps", value=20, min_value=1, max_value=1000)
284
+ eta = st.sidebar.number_input("eta (DDIM)", value=0., min_value=0., max_value=1.)
285
+ force_full_precision = st.sidebar.checkbox("Force FP32", False) # TODO: check if/where things break.
286
+ if version != "Full Karlo":
287
+ H = st.sidebar.number_input("H", value=VERSION2SPECS[version]["H"], min_value=64, max_value=2048)
288
+ W = st.sidebar.number_input("W", value=VERSION2SPECS[version]["W"], min_value=64, max_value=2048)
289
+ C = VERSION2SPECS[version]["C"]
290
+ f = VERSION2SPECS[version]["f"]
291
+
292
+ SAVE_PATH = os.path.join(SAVE_PATH, version)
293
+ os.makedirs(os.path.join(SAVE_PATH, "samples"), exist_ok=True)
294
+
295
+ seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
296
+ seed_everything(seed)
297
+
298
+ ucg_schedule = None
299
+ sampler = st.sidebar.selectbox("Sampler", ["DDIM", "DPM"], 0)
300
+ if version == "Full Karlo":
301
+ pass
302
+ else:
303
+ if sampler == "DPM":
304
+ sampler = DPMSolverSampler(state["model"])
305
+ elif sampler == "DDIM":
306
+ sampler = DDIMSampler(state["model"])
307
+ else:
308
+ raise ValueError(f"unknown sampler {sampler}!")
309
+
310
+ adm_cond, adm_uc = None, None
311
+ if use_karlo_prior:
312
+ # uses the prior
313
+ karlo_sampler = state["karlo_prior"]
314
+ noise_level = None
315
+ if state["model"].noise_augmentor is not None:
316
+ noise_level = st.number_input("Noise Augmentation for CLIP embeddings", min_value=0,
317
+ max_value=state["model"].noise_augmentor.max_noise_level - 1, value=0)
318
+ with torch.no_grad():
319
+ karlo_prediction = iter(
320
+ karlo_sampler(
321
+ prompt=prompt,
322
+ bsz=number_cols,
323
+ progressive_mode="final",
324
+ )
325
+ ).__next__()
326
+ adm_cond = karlo_prediction
327
+ if noise_level is not None:
328
+ c_adm, noise_level_emb = state["model"].noise_augmentor(adm_cond, noise_level=repeat(
329
+ torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
330
+ adm_cond = torch.cat((c_adm, noise_level_emb), 1)
331
+ adm_uc = torch.zeros_like(adm_cond)
332
+ elif version == "Full Karlo":
333
+ pass
334
+ else:
335
+ num_inputs = st.number_input("Number of Input Images", 1)
336
+
337
+
338
+ def make_conditionings_from_input(num=1, key=None):
339
+ init_img = get_init_img(batch_size=number_cols, key=key)
340
+ with torch.no_grad():
341
+ adm_cond = state["model"].embedder(init_img)
342
+ weight = st.slider(f"Weight for Input {num}", min_value=-10., max_value=10., value=1.)
343
+ if state["model"].noise_augmentor is not None:
344
+ noise_level = st.number_input(f"Noise Augmentation for CLIP embedding of input #{num}", min_value=0,
345
+ max_value=state["model"].noise_augmentor.max_noise_level - 1,
346
+ value=0, )
347
+ c_adm, noise_level_emb = state["model"].noise_augmentor(adm_cond, noise_level=repeat(
348
+ torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
349
+ adm_cond = torch.cat((c_adm, noise_level_emb), 1) * weight
350
+ adm_uc = torch.zeros_like(adm_cond)
351
+ return adm_cond, adm_uc, weight
352
+
353
+
354
+ adm_inputs = list()
355
+ weights = list()
356
+ for n in range(num_inputs):
357
+ adm_cond, adm_uc, w = make_conditionings_from_input(num=n + 1, key=n)
358
+ weights.append(w)
359
+ adm_inputs.append(adm_cond)
360
+ adm_cond = torch.stack(adm_inputs).sum(0) / sum(weights)
361
+ if num_inputs > 1:
362
+ if st.checkbox("Apply Noise to Embedding Mix", True):
363
+ noise_level = st.number_input(f"Noise Augmentation for averaged CLIP embeddings", min_value=0,
364
+ max_value=state["model"].noise_augmentor.max_noise_level - 1, value=50, )
365
+ c_adm, noise_level_emb = state["model"].noise_augmentor(
366
+ adm_cond[:, :state["model"].noise_augmentor.time_embed.dim],
367
+ noise_level=repeat(
368
+ torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
369
+ adm_cond = torch.cat((c_adm, noise_level_emb), 1)
370
+
371
+ if st.button("Sample"):
372
+ print("running prompt:", prompt)
373
+ st.text("Sampling")
374
+ t_progress = st.progress(0)
375
+ result = st.empty()
376
+
377
+
378
+ def t_callback(t):
379
+ t_progress.progress(min((t + 1) / steps, 1.))
380
+
381
+
382
+ if version == "Full Karlo":
383
+ outputs = st.empty()
384
+ karlo_sampler = state["karlo_prior"]
385
+ all_samples = list()
386
+ with torch.no_grad():
387
+ for _ in range(number_rows):
388
+ karlo_prediction = iter(
389
+ karlo_sampler(
390
+ prompt=prompt,
391
+ bsz=number_cols,
392
+ progressive_mode="final",
393
+ )
394
+ ).__next__()
395
+ all_samples.append(karlo_prediction)
396
+ grid = torch.stack(all_samples, 0)
397
+ grid = rearrange(grid, 'n b c h w -> (n h) (b w) c')
398
+ outputs.image(grid.cpu().numpy())
399
+
400
+ else:
401
+ samples = sample(
402
+ state["model"],
403
+ prompt,
404
+ n_runs=number_rows,
405
+ n_samples=number_cols,
406
+ H=H, W=W, C=C, f=f,
407
+ scale=scale,
408
+ ddim_steps=steps,
409
+ ddim_eta=eta,
410
+ callback=t_callback,
411
+ ucg_schedule=ucg_schedule,
412
+ negative_prompt=negative_prompt,
413
+ adm_cond=adm_cond, adm_uc=adm_uc,
414
+ use_full_precision=force_full_precision,
415
+ only_adm_cond=False
416
+ )
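For reference, make_oscillating_guidance_schedule() defined above (only used when a ucg_schedule is supplied) front-loads the maximum guidance weight for the first 10% of steps and then alternates between the minimum and maximum weight. A quick standalone illustration that mirrors its logic with the default weights:

num_steps, max_w, min_w = 20, 15.0, 1.0
schedule = [
    max_w if (i / num_steps) < 0.1 else (min_w if i % 2 == 0 else max_w)
    for i in range(num_steps)
]
print(schedule)  # [15.0, 15.0, 1.0, 15.0, 1.0, 15.0, ...]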
repositories/stable-diffusion-stability-ai/scripts/streamlit/superresolution.py ADDED
@@ -0,0 +1,170 @@
1
+ import sys
2
+ import torch
3
+ import numpy as np
4
+ import streamlit as st
5
+ from PIL import Image
6
+ from omegaconf import OmegaConf
7
+ from einops import repeat, rearrange
8
+ from pytorch_lightning import seed_everything
9
+ from imwatermark import WatermarkEncoder
10
+
11
+ from scripts.txt2img import put_watermark
12
+ from ldm.models.diffusion.ddim import DDIMSampler
13
+ from ldm.models.diffusion.ddpm import LatentUpscaleDiffusion, LatentUpscaleFinetuneDiffusion
14
+ from ldm.util import exists, instantiate_from_config
15
+
16
+
17
+ torch.set_grad_enabled(False)
18
+
19
+
20
+ @st.cache(allow_output_mutation=True)
21
+ def initialize_model(config, ckpt):
22
+ config = OmegaConf.load(config)
23
+ model = instantiate_from_config(config.model)
24
+ model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
25
+
26
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
27
+ model = model.to(device)
28
+ sampler = DDIMSampler(model)
29
+ return sampler
30
+
31
+
32
+ def make_batch_sd(
33
+ image,
34
+ txt,
35
+ device,
36
+ num_samples=1,
37
+ ):
38
+ image = np.array(image.convert("RGB"))
39
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
40
+ batch = {
41
+ "lr": rearrange(image, 'h w c -> 1 c h w'),
42
+ "txt": num_samples * [txt],
43
+ }
44
+ batch["lr"] = repeat(batch["lr"].to(device=device), "1 ... -> n ...", n=num_samples)
45
+ return batch
46
+
47
+
48
+ def make_noise_augmentation(model, batch, noise_level=None):
49
+ x_low = batch[model.low_scale_key]
50
+ x_low = x_low.to(memory_format=torch.contiguous_format).float()
51
+ x_aug, noise_level = model.low_scale_model(x_low, noise_level)
52
+ return x_aug, noise_level
53
+
54
+
55
+ def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callback=None, eta=0., noise_level=None):
56
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
57
+ model = sampler.model
58
+ seed_everything(seed)
59
+ prng = np.random.RandomState(seed)
60
+ start_code = prng.randn(num_samples, model.channels, h, w)
61
+ start_code = torch.from_numpy(start_code).to(device=device, dtype=torch.float32)
62
+
63
+ print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
64
+ wm = "SDV2"
65
+ wm_encoder = WatermarkEncoder()
66
+ wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
67
+ with torch.no_grad(),\
68
+ torch.autocast("cuda"):
69
+ batch = make_batch_sd(image, txt=prompt, device=device, num_samples=num_samples)
70
+ c = model.cond_stage_model.encode(batch["txt"])
71
+ c_cat = list()
72
+ if isinstance(model, LatentUpscaleFinetuneDiffusion):
73
+ for ck in model.concat_keys:
74
+ cc = batch[ck]
75
+ if exists(model.reshuffle_patch_size):
76
+ assert isinstance(model.reshuffle_patch_size, int)
77
+ cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w',
78
+ p1=model.reshuffle_patch_size, p2=model.reshuffle_patch_size)
79
+ c_cat.append(cc)
80
+ c_cat = torch.cat(c_cat, dim=1)
81
+ # cond
82
+ cond = {"c_concat": [c_cat], "c_crossattn": [c]}
83
+ # uncond cond
84
+ uc_cross = model.get_unconditional_conditioning(num_samples, "")
85
+ uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
86
+ elif isinstance(model, LatentUpscaleDiffusion):
87
+ x_augment, noise_level = make_noise_augmentation(model, batch, noise_level)
88
+ cond = {"c_concat": [x_augment], "c_crossattn": [c], "c_adm": noise_level}
89
+ # uncond cond
90
+ uc_cross = model.get_unconditional_conditioning(num_samples, "")
91
+ uc_full = {"c_concat": [x_augment], "c_crossattn": [uc_cross], "c_adm": noise_level}
92
+ else:
93
+ raise NotImplementedError()
94
+
95
+ shape = [model.channels, h, w]
96
+ samples, intermediates = sampler.sample(
97
+ steps,
98
+ num_samples,
99
+ shape,
100
+ cond,
101
+ verbose=False,
102
+ eta=eta,
103
+ unconditional_guidance_scale=scale,
104
+ unconditional_conditioning=uc_full,
105
+ x_T=start_code,
106
+ callback=callback
107
+ )
108
+ with torch.no_grad():
109
+ x_samples_ddim = model.decode_first_stage(samples)
110
+ result = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
111
+ result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255
112
+ st.text(f"upscaled image shape: {result.shape}")
113
+ return [put_watermark(Image.fromarray(img.astype(np.uint8)), wm_encoder) for img in result]
114
+
115
+
116
+ def run():
117
+ st.title("Stable Diffusion Upscaling")
118
+ # run via streamlit run scripts/streamlit/superresolution.py <path-to-config> <path-to-ckpt>
119
+ sampler = initialize_model(sys.argv[1], sys.argv[2])
120
+
121
+ image = st.file_uploader("Image", ["jpg", "png"])
122
+ if image:
123
+ image = Image.open(image)
124
+ w, h = image.size
125
+ st.text(f"loaded input image of size ({w}, {h})")
126
+ width, height = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
127
+ image = image.resize((width, height))
128
+ st.text(f"resized input image to size ({width}, {height} (w, h))")
129
+ st.image(image)
130
+
131
+ st.write(f"\n Tip: Add a description of the object that should be upscaled, e.g.: 'a professional photograph of a cat'")
132
+ prompt = st.text_input("Prompt", "a high quality professional photograph")
133
+
134
+ seed = st.number_input("Seed", min_value=0, max_value=1000000, value=0)
135
+ num_samples = st.number_input("Number of Samples", min_value=1, max_value=64, value=1)
136
+ scale = st.slider("Scale", min_value=0.1, max_value=30.0, value=9.0, step=0.1)
137
+ steps = st.slider("DDIM Steps", min_value=2, max_value=250, value=50, step=1)
138
+ eta = st.sidebar.number_input("eta (DDIM)", value=0., min_value=0., max_value=1.)
139
+
140
+ noise_level = None
141
+ if isinstance(sampler.model, LatentUpscaleDiffusion):
142
+ # TODO: make this work for all models
143
+ noise_level = st.sidebar.number_input("Noise Augmentation", min_value=0, max_value=350, value=20)
144
+ noise_level = torch.Tensor(num_samples * [noise_level]).to(sampler.model.device).long()
145
+
146
+ t_progress = st.progress(0)
147
+ def t_callback(t):
148
+ t_progress.progress(min((t + 1) / steps, 1.))
149
+
150
+ sampler.make_schedule(steps, ddim_eta=eta, verbose=True)
151
+ if st.button("Sample"):
152
+ result = paint(
153
+ sampler=sampler,
154
+ image=image,
155
+ prompt=prompt,
156
+ seed=seed,
157
+ scale=scale,
158
+ h=height, w=width, steps=steps,
159
+ num_samples=num_samples,
160
+ callback=t_callback,
161
+ noise_level=noise_level,
162
+ eta=eta
163
+ )
164
+ st.write("Result")
165
+ for image in result:
166
+ st.image(image, output_format='PNG')
167
+
168
+
169
+ if __name__ == "__main__":
170
+ run()
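make_batch_sd() above feeds the low-resolution image to the upscaler as the 'lr' conditioning: mapped to [-1, 1], rearranged to channel-first, and repeated across the sample batch. A small standalone sketch of that preparation with a synthetic image (sizes and data are placeholders):

import numpy as np
import torch
from einops import rearrange, repeat

num_samples = 2
rgb = np.random.randint(0, 255, (128, 128, 3), dtype=np.uint8)    # stand-in for the uploaded image
lr = torch.from_numpy(rgb).to(dtype=torch.float32) / 127.5 - 1.0  # h w c in [-1, 1]
lr = rearrange(lr, 'h w c -> 1 c h w')                            # channel-first with batch dim
lr = repeat(lr, "1 ... -> n ...", n=num_samples)                  # one copy per sample
print(lr.shape)                                                   # torch.Size([2, 3, 128, 128])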
repositories/stable-diffusion-stability-ai/scripts/tests/test_watermark.py ADDED
@@ -0,0 +1,18 @@
1
+ import cv2
2
+ import fire
3
+ from imwatermark import WatermarkDecoder
4
+
5
+
6
+ def testit(img_path):
7
+ bgr = cv2.imread(img_path)
8
+ decoder = WatermarkDecoder('bytes', 136)
9
+ watermark = decoder.decode(bgr, 'dwtDct')
10
+ try:
11
+ dec = watermark.decode('utf-8')
12
+ except:
13
+ dec = "null"
14
+ print(dec)
15
+
16
+
17
+ if __name__ == "__main__":
18
+ fire.Fire(testit)
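The decoder test above reads back the watermark that put_watermark() embeds. As a rough, hedged round-trip sketch (not part of the repository; the random image is a stand-in, and recovery from arbitrary content is not guaranteed), the same imwatermark calls can be exercised end to end with the "SDV2" payload, where the bit length passed to WatermarkDecoder is simply len(payload) * 8:

import numpy as np
from imwatermark import WatermarkDecoder, WatermarkEncoder

wm = "SDV2"
bgr = np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8)   # stand-in BGR image

encoder = WatermarkEncoder()
encoder.set_watermark('bytes', wm.encode('utf-8'))
bgr_wm = encoder.encode(bgr, 'dwtDct')                            # watermarked BGR array

decoder = WatermarkDecoder('bytes', len(wm) * 8)                  # 32 bits for "SDV2"
recovered = decoder.decode(bgr_wm, 'dwtDct')                      # raw bytes
print(recovered.decode('utf-8', errors='replace'))                # expected: SDV2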
repositories/stable-diffusion-stability-ai/scripts/txt2img.py ADDED
@@ -0,0 +1,388 @@
1
+ import argparse, os
2
+ import cv2
3
+ import torch
4
+ import numpy as np
5
+ from omegaconf import OmegaConf
6
+ from PIL import Image
7
+ from tqdm import tqdm, trange
8
+ from itertools import islice
9
+ from einops import rearrange
10
+ from torchvision.utils import make_grid
11
+ from pytorch_lightning import seed_everything
12
+ from torch import autocast
13
+ from contextlib import nullcontext
14
+ from imwatermark import WatermarkEncoder
15
+
16
+ from ldm.util import instantiate_from_config
17
+ from ldm.models.diffusion.ddim import DDIMSampler
18
+ from ldm.models.diffusion.plms import PLMSSampler
19
+ from ldm.models.diffusion.dpm_solver import DPMSolverSampler
20
+
21
+ torch.set_grad_enabled(False)
22
+
23
+ def chunk(it, size):
24
+ it = iter(it)
25
+ return iter(lambda: tuple(islice(it, size)), ())
26
+
27
+
28
+ def load_model_from_config(config, ckpt, device=torch.device("cuda"), verbose=False):
29
+ print(f"Loading model from {ckpt}")
30
+ pl_sd = torch.load(ckpt, map_location="cpu")
31
+ if "global_step" in pl_sd:
32
+ print(f"Global Step: {pl_sd['global_step']}")
33
+ sd = pl_sd["state_dict"]
34
+ model = instantiate_from_config(config.model)
35
+ m, u = model.load_state_dict(sd, strict=False)
36
+ if len(m) > 0 and verbose:
37
+ print("missing keys:")
38
+ print(m)
39
+ if len(u) > 0 and verbose:
40
+ print("unexpected keys:")
41
+ print(u)
42
+
43
+ if device == torch.device("cuda"):
44
+ model.cuda()
45
+ elif device == torch.device("cpu"):
46
+ model.cpu()
47
+ model.cond_stage_model.device = "cpu"
48
+ else:
49
+ raise ValueError(f"Incorrect device name. Received: {device}")
50
+ model.eval()
51
+ return model
52
+
53
+
54
+ def parse_args():
55
+ parser = argparse.ArgumentParser()
56
+ parser.add_argument(
57
+ "--prompt",
58
+ type=str,
59
+ nargs="?",
60
+ default="a professional photograph of an astronaut riding a triceratops",
61
+ help="the prompt to render"
62
+ )
63
+ parser.add_argument(
64
+ "--outdir",
65
+ type=str,
66
+ nargs="?",
67
+ help="dir to write results to",
68
+ default="outputs/txt2img-samples"
69
+ )
70
+ parser.add_argument(
71
+ "--steps",
72
+ type=int,
73
+ default=50,
74
+ help="number of ddim sampling steps",
75
+ )
76
+ parser.add_argument(
77
+ "--plms",
78
+ action='store_true',
79
+ help="use plms sampling",
80
+ )
81
+ parser.add_argument(
82
+ "--dpm",
83
+ action='store_true',
84
+ help="use DPM (2) sampler",
85
+ )
86
+ parser.add_argument(
87
+ "--fixed_code",
88
+ action='store_true',
89
+ help="if enabled, uses the same starting code across all samples ",
90
+ )
91
+ parser.add_argument(
92
+ "--ddim_eta",
93
+ type=float,
94
+ default=0.0,
95
+ help="ddim eta (eta=0.0 corresponds to deterministic sampling",
96
+ )
97
+ parser.add_argument(
98
+ "--n_iter",
99
+ type=int,
100
+ default=3,
101
+ help="sample this often",
102
+ )
103
+ parser.add_argument(
104
+ "--H",
105
+ type=int,
106
+ default=512,
107
+ help="image height, in pixel space",
108
+ )
109
+ parser.add_argument(
110
+ "--W",
111
+ type=int,
112
+ default=512,
113
+ help="image width, in pixel space",
114
+ )
115
+ parser.add_argument(
116
+ "--C",
117
+ type=int,
118
+ default=4,
119
+ help="latent channels",
120
+ )
121
+ parser.add_argument(
122
+ "--f",
123
+ type=int,
124
+ default=8,
125
+ help="downsampling factor, most often 8 or 16",
126
+ )
127
+ parser.add_argument(
128
+ "--n_samples",
129
+ type=int,
130
+ default=3,
131
+ help="how many samples to produce for each given prompt. A.k.a batch size",
132
+ )
133
+ parser.add_argument(
134
+ "--n_rows",
135
+ type=int,
136
+ default=0,
137
+ help="rows in the grid (default: n_samples)",
138
+ )
139
+ parser.add_argument(
140
+ "--scale",
141
+ type=float,
142
+ default=9.0,
143
+ help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
144
+ )
145
+ parser.add_argument(
146
+ "--from-file",
147
+ type=str,
148
+ help="if specified, load prompts from this file, separated by newlines",
149
+ )
150
+ parser.add_argument(
151
+ "--config",
152
+ type=str,
153
+ default="configs/stable-diffusion/v2-inference.yaml",
154
+ help="path to config which constructs model",
155
+ )
156
+ parser.add_argument(
157
+ "--ckpt",
158
+ type=str,
159
+ help="path to checkpoint of model",
160
+ )
161
+ parser.add_argument(
162
+ "--seed",
163
+ type=int,
164
+ default=42,
165
+ help="the seed (for reproducible sampling)",
166
+ )
167
+ parser.add_argument(
168
+ "--precision",
169
+ type=str,
170
+ help="evaluate at this precision",
171
+ choices=["full", "autocast"],
172
+ default="autocast"
173
+ )
174
+ parser.add_argument(
175
+ "--repeat",
176
+ type=int,
177
+ default=1,
178
+ help="repeat each prompt in file this often",
179
+ )
180
+ parser.add_argument(
181
+ "--device",
182
+ type=str,
183
+ help="Device on which Stable Diffusion will be run",
184
+ choices=["cpu", "cuda"],
185
+ default="cpu"
186
+ )
187
+ parser.add_argument(
188
+ "--torchscript",
189
+ action='store_true',
190
+ help="Use TorchScript",
191
+ )
192
+ parser.add_argument(
193
+ "--ipex",
194
+ action='store_true',
195
+ help="Use Intel® Extension for PyTorch*",
196
+ )
197
+ parser.add_argument(
198
+ "--bf16",
199
+ action='store_true',
200
+ help="Use bfloat16",
201
+ )
202
+ opt = parser.parse_args()
203
+ return opt
204
+
205
+
206
+ def put_watermark(img, wm_encoder=None):
207
+ if wm_encoder is not None:
208
+ img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
209
+ img = wm_encoder.encode(img, 'dwtDct')
210
+ img = Image.fromarray(img[:, :, ::-1])
211
+ return img
212
+
213
+
214
+ def main(opt):
215
+ seed_everything(opt.seed)
216
+
217
+ config = OmegaConf.load(f"{opt.config}")
218
+ device = torch.device("cuda") if opt.device == "cuda" else torch.device("cpu")
219
+ model = load_model_from_config(config, f"{opt.ckpt}", device)
220
+
221
+ if opt.plms:
222
+ sampler = PLMSSampler(model, device=device)
223
+ elif opt.dpm:
224
+ sampler = DPMSolverSampler(model, device=device)
225
+ else:
226
+ sampler = DDIMSampler(model, device=device)
227
+
228
+ os.makedirs(opt.outdir, exist_ok=True)
229
+ outpath = opt.outdir
230
+
231
+ print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
232
+ wm = "SDV2"
233
+ wm_encoder = WatermarkEncoder()
234
+ wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
235
+
236
+ batch_size = opt.n_samples
237
+ n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
238
+ if not opt.from_file:
239
+ prompt = opt.prompt
240
+ assert prompt is not None
241
+ data = [batch_size * [prompt]]
242
+
243
+ else:
244
+ print(f"reading prompts from {opt.from_file}")
245
+ with open(opt.from_file, "r") as f:
246
+ data = f.read().splitlines()
247
+ data = [p for p in data for i in range(opt.repeat)]
248
+ data = list(chunk(data, batch_size))
249
+
250
+ sample_path = os.path.join(outpath, "samples")
251
+ os.makedirs(sample_path, exist_ok=True)
252
+ sample_count = 0
253
+ base_count = len(os.listdir(sample_path))
254
+ grid_count = len(os.listdir(outpath)) - 1
255
+
256
+ start_code = None
257
+ if opt.fixed_code:
258
+ start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
259
+
260
+ if opt.torchscript or opt.ipex:
261
+ transformer = model.cond_stage_model.model
262
+ unet = model.model.diffusion_model
263
+ decoder = model.first_stage_model.decoder
264
+ additional_context = torch.cpu.amp.autocast() if opt.bf16 else nullcontext()
265
+ shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
266
+
267
+ if opt.bf16 and not opt.torchscript and not opt.ipex:
268
+ raise ValueError('Bfloat16 is supported only for torchscript+ipex')
269
+ if opt.bf16 and unet.dtype != torch.bfloat16:
270
+ raise ValueError("Use configs/stable-diffusion/intel/ configs with bf16 enabled if " +
271
+ "you'd like to use bfloat16 with CPU.")
272
+ if unet.dtype == torch.float16 and device == torch.device("cpu"):
273
+ raise ValueError("Use configs/stable-diffusion/intel/ configs for your model if you'd like to run it on CPU.")
274
+
275
+ if opt.ipex:
276
+ import intel_extension_for_pytorch as ipex
277
+ bf16_dtype = torch.bfloat16 if opt.bf16 else None
278
+ transformer = transformer.to(memory_format=torch.channels_last)
279
+ transformer = ipex.optimize(transformer, level="O1", inplace=True)
280
+
281
+ unet = unet.to(memory_format=torch.channels_last)
282
+ unet = ipex.optimize(unet, level="O1", auto_kernel_selection=True, inplace=True, dtype=bf16_dtype)
283
+
284
+ decoder = decoder.to(memory_format=torch.channels_last)
285
+ decoder = ipex.optimize(decoder, level="O1", auto_kernel_selection=True, inplace=True, dtype=bf16_dtype)
286
+
287
+ if opt.torchscript:
288
+ with torch.no_grad(), additional_context:
289
+ # get UNET scripted
290
+ if unet.use_checkpoint:
291
+ raise ValueError("Gradient checkpoint won't work with tracing. " +
292
+ "Use configs/stable-diffusion/intel/ configs for your model or disable checkpoint in your config.")
293
+
294
+ img_in = torch.ones(2, 4, 96, 96, dtype=torch.float32)
295
+ t_in = torch.ones(2, dtype=torch.int64)
296
+ context = torch.ones(2, 77, 1024, dtype=torch.float32)
297
+ scripted_unet = torch.jit.trace(unet, (img_in, t_in, context))
298
+ scripted_unet = torch.jit.optimize_for_inference(scripted_unet)
299
+ print(type(scripted_unet))
300
+ model.model.scripted_diffusion_model = scripted_unet
301
+
302
+ # get Decoder for first stage model scripted
303
+ samples_ddim = torch.ones(1, 4, 96, 96, dtype=torch.float32)
304
+ scripted_decoder = torch.jit.trace(decoder, (samples_ddim))
305
+ scripted_decoder = torch.jit.optimize_for_inference(scripted_decoder)
306
+ print(type(scripted_decoder))
307
+ model.first_stage_model.decoder = scripted_decoder
308
+
309
+ prompts = data[0]
310
+ print("Running a forward pass to initialize optimizations")
311
+ uc = None
312
+ if opt.scale != 1.0:
313
+ uc = model.get_learned_conditioning(batch_size * [""])
314
+ if isinstance(prompts, tuple):
315
+ prompts = list(prompts)
316
+
317
+ with torch.no_grad(), additional_context:
318
+ for _ in range(3):
319
+ c = model.get_learned_conditioning(prompts)
320
+ samples_ddim, _ = sampler.sample(S=5,
321
+ conditioning=c,
322
+ batch_size=batch_size,
323
+ shape=shape,
324
+ verbose=False,
325
+ unconditional_guidance_scale=opt.scale,
326
+ unconditional_conditioning=uc,
327
+ eta=opt.ddim_eta,
328
+ x_T=start_code)
329
+ print("Running a forward pass for decoder")
330
+ for _ in range(3):
331
+ x_samples_ddim = model.decode_first_stage(samples_ddim)
332
+
333
+ precision_scope = autocast if opt.precision=="autocast" or opt.bf16 else nullcontext
334
+ with torch.no_grad(), \
335
+ precision_scope(opt.device), \
336
+ model.ema_scope():
337
+ all_samples = list()
338
+ for n in trange(opt.n_iter, desc="Sampling"):
339
+ for prompts in tqdm(data, desc="data"):
340
+ uc = None
341
+ if opt.scale != 1.0:
342
+ uc = model.get_learned_conditioning(batch_size * [""])
343
+ if isinstance(prompts, tuple):
344
+ prompts = list(prompts)
345
+ c = model.get_learned_conditioning(prompts)
346
+ shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
347
+ samples, _ = sampler.sample(S=opt.steps,
348
+ conditioning=c,
349
+ batch_size=opt.n_samples,
350
+ shape=shape,
351
+ verbose=False,
352
+ unconditional_guidance_scale=opt.scale,
353
+ unconditional_conditioning=uc,
354
+ eta=opt.ddim_eta,
355
+ x_T=start_code)
356
+
357
+ x_samples = model.decode_first_stage(samples)
358
+ x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
359
+
360
+ for x_sample in x_samples:
361
+ x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
362
+ img = Image.fromarray(x_sample.astype(np.uint8))
363
+ img = put_watermark(img, wm_encoder)
364
+ img.save(os.path.join(sample_path, f"{base_count:05}.png"))
365
+ base_count += 1
366
+ sample_count += 1
367
+
368
+ all_samples.append(x_samples)
369
+
370
+ # additionally, save as grid
371
+ grid = torch.stack(all_samples, 0)
372
+ grid = rearrange(grid, 'n b c h w -> (n b) c h w')
373
+ grid = make_grid(grid, nrow=n_rows)
374
+
375
+ # to image
376
+ grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
377
+ grid = Image.fromarray(grid.astype(np.uint8))
378
+ grid = put_watermark(grid, wm_encoder)
379
+ grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
380
+ grid_count += 1
381
+
382
+ print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
383
+ f" \nEnjoy.")
384
+
385
+
386
+ if __name__ == "__main__":
387
+ opt = parse_args()
388
+ main(opt)
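The chunk() helper at the top of txt2img.py batches prompts read via --from-file. A tiny standalone demonstration (the function body is copied from the script so the snippet runs on its own; the prompt list is illustrative):

from itertools import islice

def chunk(it, size):                                   # copied from txt2img.py above
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())

prompts = ["a cat", "a dog", "a fox", "a bird", "a fish"]   # illustrative prompt list
print(list(chunk(prompts, 2)))
# [('a cat', 'a dog'), ('a fox', 'a bird'), ('a fish',)]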
repositories/stable-diffusion-stability-ai/setup.py ADDED
@@ -0,0 +1,13 @@
1
+ from setuptools import setup, find_packages
2
+
3
+ setup(
4
+ name='stable-diffusion',
5
+ version='0.0.1',
6
+ description='',
7
+ packages=find_packages(),
8
+ install_requires=[
9
+ 'torch',
10
+ 'numpy',
11
+ 'tqdm',
12
+ ],
13
+ )
requirements-test.txt ADDED
@@ -0,0 +1,3 @@
1
+ pytest-base-url~=2.0
2
+ pytest-cov~=4.0
3
+ pytest~=7.3
requirements.txt ADDED
@@ -0,0 +1,33 @@
1
+ GitPython
2
+ Pillow
3
+ accelerate
4
+
5
+ basicsr
6
+ blendmodes
7
+ clean-fid
8
+ einops
9
+ gfpgan
10
+ gradio==3.32.0
11
+ inflection
12
+ jsonmerge
13
+ kornia
14
+ lark
15
+ numpy
16
+ omegaconf
17
+ open-clip-torch
18
+
19
+ piexif
20
+ psutil
21
+ pytorch_lightning
22
+ realesrgan
23
+ requests
24
+ resize-right
25
+
26
+ safetensors
27
+ scikit-image>=0.19
28
+ timm
29
+ tomesd
30
+ torch
31
+ torchdiffeq
32
+ torchsde
33
+ transformers==4.25.1
requirements_versions.txt ADDED
@@ -0,0 +1,31 @@
1
+ GitPython==3.1.30
2
+ Pillow==9.5.0
3
+ accelerate==0.18.0
4
+ basicsr==1.4.2
5
+ blendmodes==2022
6
+ clean-fid==0.1.35
7
+ einops==0.4.1
8
+ fastapi==0.94.0
9
+ gfpgan==1.3.8
10
+ gradio==3.32.0
11
+ httpcore==0.15
12
+ inflection==0.5.1
13
+ jsonmerge==1.8.0
14
+ kornia==0.6.7
15
+ lark==1.1.2
16
+ numpy==1.23.5
17
+ omegaconf==2.2.3
18
+ open-clip-torch==2.20.0
19
+ piexif==1.1.3
20
+ psutil==5.9.5
21
+ pytorch_lightning==1.9.4
22
+ realesrgan==0.3.0
23
+ resize-right==0.0.2
24
+ safetensors==0.3.1
25
+ scikit-image==0.20.0
26
+ timm==0.6.7
27
+ tomesd==0.1.2
28
+ torch
29
+ torchdiffeq==0.2.3
30
+ torchsde==0.2.5
31
+ transformers==4.25.1
screenshot.png ADDED
script.js ADDED
@@ -0,0 +1,163 @@
1
+ function gradioApp() {
2
+ const elems = document.getElementsByTagName('gradio-app');
3
+ const elem = elems.length == 0 ? document : elems[0];
4
+
5
+ if (elem !== document) {
6
+ elem.getElementById = function(id) {
7
+ return document.getElementById(id);
8
+ };
9
+ }
10
+ return elem.shadowRoot ? elem.shadowRoot : elem;
11
+ }
12
+
13
+ /**
14
+ * Get the currently selected top-level UI tab button (e.g. the button that says "Extras").
15
+ */
16
+ function get_uiCurrentTab() {
17
+ return gradioApp().querySelector('#tabs > .tab-nav > button.selected');
18
+ }
19
+
20
+ /**
21
+ * Get the first currently visible top-level UI tab content (e.g. the div hosting the "txt2img" UI).
22
+ */
23
+ function get_uiCurrentTabContent() {
24
+ return gradioApp().querySelector('#tabs > .tabitem[id^=tab_]:not([style*="display: none"])');
25
+ }
26
+
27
+ var uiUpdateCallbacks = [];
28
+ var uiAfterUpdateCallbacks = [];
29
+ var uiLoadedCallbacks = [];
30
+ var uiTabChangeCallbacks = [];
31
+ var optionsChangedCallbacks = [];
32
+ var uiAfterUpdateTimeout = null;
33
+ var uiCurrentTab = null;
34
+
35
+ /**
36
+ * Register callback to be called at each UI update.
37
+ * The callback receives an array of MutationRecords as an argument.
38
+ */
39
+ function onUiUpdate(callback) {
40
+ uiUpdateCallbacks.push(callback);
41
+ }
42
+
43
+ /**
44
+ * Register callback to be called soon after UI updates.
45
+ * The callback receives no arguments.
46
+ *
47
+ * This is preferred over `onUiUpdate` if you don't need
48
+ * access to the MutationRecords, as your function will
49
+ * not be called quite as often.
50
+ */
51
+ function onAfterUiUpdate(callback) {
52
+ uiAfterUpdateCallbacks.push(callback);
53
+ }
54
+
55
+ /**
56
+ * Register callback to be called when the UI is loaded.
57
+ * The callback receives no arguments.
58
+ */
59
+ function onUiLoaded(callback) {
60
+ uiLoadedCallbacks.push(callback);
61
+ }
62
+
63
+ /**
64
+ * Register callback to be called when the UI tab is changed.
65
+ * The callback receives no arguments.
66
+ */
67
+ function onUiTabChange(callback) {
68
+ uiTabChangeCallbacks.push(callback);
69
+ }
70
+
71
+ /**
72
+ * Register callback to be called when the options are changed.
73
+ * The callback receives no arguments.
74
+ * @param callback
75
+ */
76
+ function onOptionsChanged(callback) {
77
+ optionsChangedCallbacks.push(callback);
78
+ }
79
+
80
+ function executeCallbacks(queue, arg) {
81
+ for (const callback of queue) {
82
+ try {
83
+ callback(arg);
84
+ } catch (e) {
85
+ console.error("error running callback", callback, ":", e);
86
+ }
87
+ }
88
+ }
89
+
90
+ /**
91
+ * Schedule the execution of the callbacks registered with onAfterUiUpdate.
92
+ * The callbacks are executed after a short while, unless another call to this function
93
+ * is made before that time. IOW, the callbacks are executed only once, even
94
+ * when there are multiple mutations observed.
95
+ */
96
+ function scheduleAfterUiUpdateCallbacks() {
97
+ clearTimeout(uiAfterUpdateTimeout);
98
+ uiAfterUpdateTimeout = setTimeout(function() {
99
+ executeCallbacks(uiAfterUpdateCallbacks);
100
+ }, 200);
101
+ }
102
+
103
+ var executedOnLoaded = false;
104
+
105
+ document.addEventListener("DOMContentLoaded", function() {
106
+ var mutationObserver = new MutationObserver(function(m) {
107
+ if (!executedOnLoaded && gradioApp().querySelector('#txt2img_prompt')) {
108
+ executedOnLoaded = true;
109
+ executeCallbacks(uiLoadedCallbacks);
110
+ }
111
+
112
+ executeCallbacks(uiUpdateCallbacks, m);
113
+ scheduleAfterUiUpdateCallbacks();
114
+ const newTab = get_uiCurrentTab();
115
+ if (newTab && (newTab !== uiCurrentTab)) {
116
+ uiCurrentTab = newTab;
117
+ executeCallbacks(uiTabChangeCallbacks);
118
+ }
119
+ });
120
+ mutationObserver.observe(gradioApp(), {childList: true, subtree: true});
121
+ });
122
+
123
+ /**
124
+ * Add a ctrl+enter as a shortcut to start a generation
125
+ */
126
+ document.addEventListener('keydown', function(e) {
127
+ var handled = false;
128
+ if (e.key !== undefined) {
129
+ if ((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
130
+ } else if (e.keyCode !== undefined) {
131
+ if ((e.keyCode == 13 && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
132
+ }
133
+ if (handled) {
134
+ var button = get_uiCurrentTabContent().querySelector('button[id$=_generate]');
135
+ if (button) {
136
+ button.click();
137
+ }
138
+ e.preventDefault();
139
+ }
140
+ });
141
+
142
+ /**
143
+ * checks that a UI element is not in another hidden element or tab content
144
+ */
145
+ function uiElementIsVisible(el) {
146
+ if (el === document) {
147
+ return true;
148
+ }
149
+
150
+ const computedStyle = getComputedStyle(el);
151
+ const isVisible = computedStyle.display !== 'none';
152
+
153
+ if (!isVisible) return false;
154
+ return uiElementIsVisible(el.parentNode);
155
+ }
156
+
157
+ function uiElementInSight(el) {
158
+ const clRect = el.getBoundingClientRect();
159
+ const windowHeight = window.innerHeight;
160
+ const isOnScreen = clRect.bottom > 0 && clRect.top < windowHeight;
161
+
162
+ return isOnScreen;
163
+ }
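For context, script.js is the glue that lets other frontend scripts react to Gradio re-rendering the DOM: they register callbacks through onUiUpdate / onAfterUiUpdate / onUiLoaded / onUiTabChange instead of attaching their own MutationObserver. The sketch below is a minimal, hypothetical consumer of that API and is not part of this commit; it assumes script.js has already been loaded and uses only identifiers defined above (gradioApp, onUiLoaded, onAfterUiUpdate, onUiTabChange, get_uiCurrentTab, uiElementIsVisible) plus the #txt2img_prompt id referenced in the observer.

// Hypothetical example script, for illustration only -- relies on the globals
// defined in script.js above being loaded first.

onUiLoaded(function() {
    // Fires once, the first time the MutationObserver sees #txt2img_prompt.
    console.log('UI ready:', gradioApp().querySelector('#txt2img_prompt') !== null);
});

onAfterUiUpdate(function() {
    // Debounced: runs 200 ms after the last observed DOM mutation, so it is
    // cheaper than onUiUpdate when the MutationRecords are not needed.
    const prompt = gradioApp().querySelector('#txt2img_prompt');
    if (prompt && uiElementIsVisible(prompt)) {
        prompt.dataset.checked = 'true'; // illustrative side effect only
    }
});

onUiTabChange(function() {
    const tab = get_uiCurrentTab();
    console.log('active tab:', tab ? tab.textContent.trim() : '(none)');
});

onUiUpdate would be used instead of onAfterUiUpdate only when the callback actually needs the array of MutationRecords passed by the observer.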
scripts/__pycache__/custom_code.cpython-310.pyc ADDED
Binary file (2.73 kB).
scripts/__pycache__/img2imgalt.cpython-310.pyc ADDED
Binary file (6.37 kB).
scripts/__pycache__/loopback.cpython-310.pyc ADDED
Binary file (3.53 kB).
scripts/__pycache__/outpainting_mk_2.cpython-310.pyc ADDED
Binary file (8.31 kB).
scripts/__pycache__/poor_mans_outpainting.cpython-310.pyc ADDED
Binary file (4.12 kB).
scripts/__pycache__/postprocessing_codeformer.cpython-310.pyc ADDED
Binary file (1.61 kB).
scripts/__pycache__/postprocessing_gfpgan.cpython-310.pyc ADDED
Binary file (1.39 kB).
scripts/__pycache__/postprocessing_upscale.cpython-310.pyc ADDED
Binary file (6.44 kB).
scripts/__pycache__/prompt_matrix.cpython-310.pyc ADDED
Binary file (4.2 kB).