multimodalart committed on
Commit 90e5afe
1 Parent(s): c1a4c61

Delete scripts
scripts/.DS_Store DELETED
Binary file (6.15 kB)
 
scripts/__init__.py DELETED
File without changes
scripts/demo/__init__.py DELETED
File without changes
scripts/demo/detect.py DELETED
@@ -1,156 +0,0 @@
1
- import argparse
2
-
3
- import cv2
4
- import numpy as np
5
-
6
- try:
7
- from imwatermark import WatermarkDecoder
8
- except ImportError as e:
9
- try:
10
- # Assume some of the other dependencies such as torch are not fulfilled
11
- # import file without loading unnecessary libraries.
12
- import importlib.util
13
- import sys
14
-
15
- spec = importlib.util.find_spec("imwatermark.maxDct")
16
- assert spec is not None
17
- maxDct = importlib.util.module_from_spec(spec)
18
- sys.modules["maxDct"] = maxDct
19
- spec.loader.exec_module(maxDct)
20
-
21
- class WatermarkDecoder(object):
22
- """A minimal version of
23
- https://github.com/ShieldMnt/invisible-watermark/blob/main/imwatermark/watermark.py
24
- to only reconstruct bits using dwtDct"""
25
-
26
- def __init__(self, wm_type="bytes", length=0):
27
- assert wm_type == "bits", "Only bits defined in minimal import"
28
- self._wmType = wm_type
29
- self._wmLen = length
30
-
31
- def reconstruct(self, bits):
32
- if len(bits) != self._wmLen:
33
- raise RuntimeError("bits are not matched with watermark length")
34
-
35
- return bits
36
-
37
- def decode(self, cv2Image, method="dwtDct", **configs):
38
- (r, c, channels) = cv2Image.shape
39
- if r * c < 256 * 256:
40
- raise RuntimeError("image too small, should be larger than 256x256")
41
-
42
- bits = []
43
- assert method == "dwtDct"
44
- embed = maxDct.EmbedMaxDct(watermarks=[], wmLen=self._wmLen, **configs)
45
- bits = embed.decode(cv2Image)
46
- return self.reconstruct(bits)
47
-
48
- except:
49
- raise e
50
-
51
-
52
- # A fixed 48-bit message that was chosen at random
53
- # WATERMARK_MESSAGE = 0xB3EC907BB19E
54
- WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
55
- # bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
56
- WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
57
- MATCH_VALUES = [
58
- [27, "No watermark detected"],
59
- [33, "Partial watermark match. Cannot determine with certainty."],
60
- [
61
- 35,
62
- (
63
- "Likely watermarked. In our test 0.02% of real images were "
64
- 'falsely detected as "Likely watermarked"'
65
- ),
66
- ],
67
- [
68
- 49,
69
- (
70
- "Very likely watermarked. In our test no real images were "
71
- 'falsely detected as "Very likely watermarked"'
72
- ),
73
- ],
74
- ]
75
-
76
-
77
- class GetWatermarkMatch:
78
- def __init__(self, watermark):
79
- self.watermark = watermark
80
- self.num_bits = len(self.watermark)
81
- self.decoder = WatermarkDecoder("bits", self.num_bits)
82
-
83
- def __call__(self, x: np.ndarray) -> np.ndarray:
84
- """
85
- Detects the number of matching bits the predefined watermark with one
86
- or multiple images. Images should be in cv2 format, e.g. h x w x c BGR.
87
-
88
- Args:
89
- x: ([B], h w, c) in range [0, 255]
90
-
91
- Returns:
92
- number of matched bits ([B],)
93
- """
94
- squeeze = len(x.shape) == 3
95
- if squeeze:
96
- x = x[None, ...]
97
-
98
- bs = x.shape[0]
99
- detected = np.empty((bs, self.num_bits), dtype=bool)
100
- for k in range(bs):
101
- detected[k] = self.decoder.decode(x[k], "dwtDct")
102
- result = np.sum(detected == self.watermark, axis=-1)
103
- if squeeze:
104
- return result[0]
105
- else:
106
- return result
107
-
108
-
109
- get_watermark_match = GetWatermarkMatch(WATERMARK_BITS)
110
-
111
-
112
- if __name__ == "__main__":
113
- parser = argparse.ArgumentParser()
114
- parser.add_argument(
115
- "filename",
116
- nargs="+",
117
- type=str,
118
- help="Image files to check for watermarks",
119
- )
120
- opts = parser.parse_args()
121
-
122
- print(
123
- """
124
- This script tries to detect watermarked images. Please be aware of
125
- the following:
126
- - As the watermark is supposed to be invisible, there is the risk that
127
- watermarked images may not be detected.
128
- - To maximize the chance of detection make sure that the image has the same
129
- dimensions as when the watermark was applied (most likely 1024x1024
130
- or 512x512).
131
- - Specific image manipulation may drastically decrease the chance that
132
- watermarks can be detected.
133
- - There is also the chance that an image has the characteristics of the
134
- watermark by chance.
135
- - The watermark script is public, anybody may watermark any images, and
136
- could therefore claim it to be generated.
137
- - All numbers below are based on a test using 10,000 images without any
138
- modifications after applying the watermark.
139
- """
140
- )
141
-
142
- for fn in opts.filename:
143
- image = cv2.imread(fn)
144
- if image is None:
145
- print(f"Couldn't read {fn}. Skipping")
146
- continue
147
-
148
- num_bits = get_watermark_match(image)
149
- k = 0
150
- while num_bits > MATCH_VALUES[k][0]:
151
- k += 1
152
- print(
153
- f"{fn}: {MATCH_VALUES[k][1]}",
154
- f"Bits that matched the watermark {num_bits} from {len(WATERMARK_BITS)}\n",
155
- sep="\n\t",
156
- )
 
scripts/demo/discretization.py DELETED
@@ -1,59 +0,0 @@
1
- import torch
2
-
3
- from sgm.modules.diffusionmodules.discretizer import Discretization
4
-
5
-
6
- class Img2ImgDiscretizationWrapper:
7
- """
8
- wraps a discretizer, and prunes the sigmas
9
- params:
10
- strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
11
- """
12
-
13
- def __init__(self, discretization: Discretization, strength: float = 1.0):
14
- self.discretization = discretization
15
- self.strength = strength
16
- assert 0.0 <= self.strength <= 1.0
17
-
18
- def __call__(self, *args, **kwargs):
19
- # sigmas start large first, and decrease then
20
- sigmas = self.discretization(*args, **kwargs)
21
- print(f"sigmas after discretization, before pruning img2img: ", sigmas)
22
- sigmas = torch.flip(sigmas, (0,))
23
- sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
24
- print("prune index:", max(int(self.strength * len(sigmas)), 1))
25
- sigmas = torch.flip(sigmas, (0,))
26
- print(f"sigmas after pruning: ", sigmas)
27
- return sigmas
28
-
29
-
30
- class Txt2NoisyDiscretizationWrapper:
31
- """
32
- wraps a discretizer, and prunes the sigmas
33
- params:
34
- strength: float between 0.0 and 1.0. 0.0 means full sampling (all sigmas are returned)
35
- """
36
-
37
- def __init__(
38
- self, discretization: Discretization, strength: float = 0.0, original_steps=None
39
- ):
40
- self.discretization = discretization
41
- self.strength = strength
42
- self.original_steps = original_steps
43
- assert 0.0 <= self.strength <= 1.0
44
-
45
- def __call__(self, *args, **kwargs):
46
- # sigmas start large first, and decrease then
47
- sigmas = self.discretization(*args, **kwargs)
48
- print(f"sigmas after discretization, before pruning img2img: ", sigmas)
49
- sigmas = torch.flip(sigmas, (0,))
50
- if self.original_steps is None:
51
- steps = len(sigmas)
52
- else:
53
- steps = self.original_steps + 1
54
- prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0)
55
- sigmas = sigmas[prune_index:]
56
- print("prune index:", prune_index)
57
- sigmas = torch.flip(sigmas, (0,))
58
- print(f"sigmas after pruning: ", sigmas)
59
- return sigmas
 
scripts/demo/sampling.py DELETED
@@ -1,364 +0,0 @@
1
- from pytorch_lightning import seed_everything
2
-
3
- from scripts.demo.streamlit_helpers import *
4
-
5
- SAVE_PATH = "outputs/demo/txt2img/"
6
-
7
- SD_XL_BASE_RATIOS = {
8
- "0.5": (704, 1408),
9
- "0.52": (704, 1344),
10
- "0.57": (768, 1344),
11
- "0.6": (768, 1280),
12
- "0.68": (832, 1216),
13
- "0.72": (832, 1152),
14
- "0.78": (896, 1152),
15
- "0.82": (896, 1088),
16
- "0.88": (960, 1088),
17
- "0.94": (960, 1024),
18
- "1.0": (1024, 1024),
19
- "1.07": (1024, 960),
20
- "1.13": (1088, 960),
21
- "1.21": (1088, 896),
22
- "1.29": (1152, 896),
23
- "1.38": (1152, 832),
24
- "1.46": (1216, 832),
25
- "1.67": (1280, 768),
26
- "1.75": (1344, 768),
27
- "1.91": (1344, 704),
28
- "2.0": (1408, 704),
29
- "2.09": (1472, 704),
30
- "2.4": (1536, 640),
31
- "2.5": (1600, 640),
32
- "2.89": (1664, 576),
33
- "3.0": (1728, 576),
34
- }
35
-
36
- VERSION2SPECS = {
37
- "SDXL-base-1.0": {
38
- "H": 1024,
39
- "W": 1024,
40
- "C": 4,
41
- "f": 8,
42
- "is_legacy": False,
43
- "config": "configs/inference/sd_xl_base.yaml",
44
- "ckpt": "checkpoints/sd_xl_base_1.0.safetensors",
45
- },
46
- "SDXL-base-0.9": {
47
- "H": 1024,
48
- "W": 1024,
49
- "C": 4,
50
- "f": 8,
51
- "is_legacy": False,
52
- "config": "configs/inference/sd_xl_base.yaml",
53
- "ckpt": "checkpoints/sd_xl_base_0.9.safetensors",
54
- },
55
- "SD-2.1": {
56
- "H": 512,
57
- "W": 512,
58
- "C": 4,
59
- "f": 8,
60
- "is_legacy": True,
61
- "config": "configs/inference/sd_2_1.yaml",
62
- "ckpt": "checkpoints/v2-1_512-ema-pruned.safetensors",
63
- },
64
- "SD-2.1-768": {
65
- "H": 768,
66
- "W": 768,
67
- "C": 4,
68
- "f": 8,
69
- "is_legacy": True,
70
- "config": "configs/inference/sd_2_1_768.yaml",
71
- "ckpt": "checkpoints/v2-1_768-ema-pruned.safetensors",
72
- },
73
- "SDXL-refiner-0.9": {
74
- "H": 1024,
75
- "W": 1024,
76
- "C": 4,
77
- "f": 8,
78
- "is_legacy": True,
79
- "config": "configs/inference/sd_xl_refiner.yaml",
80
- "ckpt": "checkpoints/sd_xl_refiner_0.9.safetensors",
81
- },
82
- "SDXL-refiner-1.0": {
83
- "H": 1024,
84
- "W": 1024,
85
- "C": 4,
86
- "f": 8,
87
- "is_legacy": True,
88
- "config": "configs/inference/sd_xl_refiner.yaml",
89
- "ckpt": "checkpoints/sd_xl_refiner_1.0.safetensors",
90
- },
91
- }
92
-
93
-
94
- def load_img(display=True, key=None, device="cuda"):
95
- image = get_interactive_image(key=key)
96
- if image is None:
97
- return None
98
- if display:
99
- st.image(image)
100
- w, h = image.size
101
- print(f"loaded input image of size ({w}, {h})")
102
- width, height = map(
103
- lambda x: x - x % 64, (w, h)
104
- ) # resize to integer multiple of 64
105
- image = image.resize((width, height))
106
- image = np.array(image.convert("RGB"))
107
- image = image[None].transpose(0, 3, 1, 2)
108
- image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
109
- return image.to(device)
110
-
111
-
112
- def run_txt2img(
113
- state,
114
- version,
115
- version_dict,
116
- is_legacy=False,
117
- return_latents=False,
118
- filter=None,
119
- stage2strength=None,
120
- ):
121
- if version.startswith("SDXL-base"):
122
- W, H = st.selectbox("Resolution:", list(SD_XL_BASE_RATIOS.values()), 10)
123
- else:
124
- H = st.number_input("H", value=version_dict["H"], min_value=64, max_value=2048)
125
- W = st.number_input("W", value=version_dict["W"], min_value=64, max_value=2048)
126
- C = version_dict["C"]
127
- F = version_dict["f"]
128
-
129
- init_dict = {
130
- "orig_width": W,
131
- "orig_height": H,
132
- "target_width": W,
133
- "target_height": H,
134
- }
135
- value_dict = init_embedder_options(
136
- get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
137
- init_dict,
138
- prompt=prompt,
139
- negative_prompt=negative_prompt,
140
- )
141
- sampler, num_rows, num_cols = init_sampling(stage2strength=stage2strength)
142
- num_samples = num_rows * num_cols
143
-
144
- if st.button("Sample"):
145
- st.write(f"**Model I:** {version}")
146
- out = do_sample(
147
- state["model"],
148
- sampler,
149
- value_dict,
150
- num_samples,
151
- H,
152
- W,
153
- C,
154
- F,
155
- force_uc_zero_embeddings=["txt"] if not is_legacy else [],
156
- return_latents=return_latents,
157
- filter=filter,
158
- )
159
- return out
160
-
161
-
162
- def run_img2img(
163
- state,
164
- version_dict,
165
- is_legacy=False,
166
- return_latents=False,
167
- filter=None,
168
- stage2strength=None,
169
- ):
170
- img = load_img()
171
- if img is None:
172
- return None
173
- H, W = img.shape[2], img.shape[3]
174
-
175
- init_dict = {
176
- "orig_width": W,
177
- "orig_height": H,
178
- "target_width": W,
179
- "target_height": H,
180
- }
181
- value_dict = init_embedder_options(
182
- get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
183
- init_dict,
184
- prompt=prompt,
185
- negative_prompt=negative_prompt,
186
- )
187
- strength = st.number_input(
188
- "**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0
189
- )
190
- sampler, num_rows, num_cols = init_sampling(
191
- img2img_strength=strength,
192
- stage2strength=stage2strength,
193
- )
194
- num_samples = num_rows * num_cols
195
-
196
- if st.button("Sample"):
197
- out = do_img2img(
198
- repeat(img, "1 ... -> n ...", n=num_samples),
199
- state["model"],
200
- sampler,
201
- value_dict,
202
- num_samples,
203
- force_uc_zero_embeddings=["txt"] if not is_legacy else [],
204
- return_latents=return_latents,
205
- filter=filter,
206
- )
207
- return out
208
-
209
-
210
- def apply_refiner(
211
- input,
212
- state,
213
- sampler,
214
- num_samples,
215
- prompt,
216
- negative_prompt,
217
- filter=None,
218
- finish_denoising=False,
219
- ):
220
- init_dict = {
221
- "orig_width": input.shape[3] * 8,
222
- "orig_height": input.shape[2] * 8,
223
- "target_width": input.shape[3] * 8,
224
- "target_height": input.shape[2] * 8,
225
- }
226
-
227
- value_dict = init_dict
228
- value_dict["prompt"] = prompt
229
- value_dict["negative_prompt"] = negative_prompt
230
-
231
- value_dict["crop_coords_top"] = 0
232
- value_dict["crop_coords_left"] = 0
233
-
234
- value_dict["aesthetic_score"] = 6.0
235
- value_dict["negative_aesthetic_score"] = 2.5
236
-
237
- st.warning(f"refiner input shape: {input.shape}")
238
- samples = do_img2img(
239
- input,
240
- state["model"],
241
- sampler,
242
- value_dict,
243
- num_samples,
244
- skip_encode=True,
245
- filter=filter,
246
- add_noise=not finish_denoising,
247
- )
248
-
249
- return samples
250
-
251
-
252
- if __name__ == "__main__":
253
- st.title("Stable Diffusion")
254
- version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
255
- version_dict = VERSION2SPECS[version]
256
- if st.checkbox("Load Model"):
257
- mode = st.radio("Mode", ("txt2img", "img2img"), 0)
258
- else:
259
- mode = "skip"
260
- st.write("__________________________")
261
-
262
- set_lowvram_mode(st.checkbox("Low vram mode", True))
263
-
264
- if version.startswith("SDXL-base"):
265
- add_pipeline = st.checkbox("Load SDXL-refiner?", False)
266
- st.write("__________________________")
267
- else:
268
- add_pipeline = False
269
-
270
- seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
271
- seed_everything(seed)
272
-
273
- save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))
274
-
275
- if mode != "skip":
276
- state = init_st(version_dict, load_filter=True)
277
- if state["msg"]:
278
- st.info(state["msg"])
279
- model = state["model"]
280
-
281
- is_legacy = version_dict["is_legacy"]
282
-
283
- prompt = st.text_input(
284
- "prompt",
285
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
286
- )
287
- if is_legacy:
288
- negative_prompt = st.text_input("negative prompt", "")
289
- else:
290
- negative_prompt = "" # which is unused
291
-
292
- stage2strength = None
293
- finish_denoising = False
294
-
295
- if add_pipeline:
296
- st.write("__________________________")
297
- version2 = st.selectbox("Refiner:", ["SDXL-refiner-1.0", "SDXL-refiner-0.9"])
298
- st.warning(
299
- f"Running with {version2} as the second stage model. Make sure to provide (V)RAM :) "
300
- )
301
- st.write("**Refiner Options:**")
302
-
303
- version_dict2 = VERSION2SPECS[version2]
304
- state2 = init_st(version_dict2, load_filter=False)
305
- st.info(state2["msg"])
306
-
307
- stage2strength = st.number_input(
308
- "**Refinement strength**", value=0.15, min_value=0.0, max_value=1.0
309
- )
310
-
311
- sampler2, *_ = init_sampling(
312
- key=2,
313
- img2img_strength=stage2strength,
314
- specify_num_samples=False,
315
- )
316
- st.write("__________________________")
317
- finish_denoising = st.checkbox("Finish denoising with refiner.", True)
318
- if not finish_denoising:
319
- stage2strength = None
320
-
321
- if mode == "txt2img":
322
- out = run_txt2img(
323
- state,
324
- version,
325
- version_dict,
326
- is_legacy=is_legacy,
327
- return_latents=add_pipeline,
328
- filter=state.get("filter"),
329
- stage2strength=stage2strength,
330
- )
331
- elif mode == "img2img":
332
- out = run_img2img(
333
- state,
334
- version_dict,
335
- is_legacy=is_legacy,
336
- return_latents=add_pipeline,
337
- filter=state.get("filter"),
338
- stage2strength=stage2strength,
339
- )
340
- elif mode == "skip":
341
- out = None
342
- else:
343
- raise ValueError(f"unknown mode {mode}")
344
- if isinstance(out, (tuple, list)):
345
- samples, samples_z = out
346
- else:
347
- samples = out
348
- samples_z = None
349
-
350
- if add_pipeline and samples_z is not None:
351
- st.write("**Running Refinement Stage**")
352
- samples = apply_refiner(
353
- samples_z,
354
- state2,
355
- sampler2,
356
- samples_z.shape[0],
357
- prompt=prompt,
358
- negative_prompt=negative_prompt if is_legacy else "",
359
- filter=state.get("filter"),
360
- finish_denoising=finish_denoising,
361
- )
362
-
363
- if save_locally and samples is not None:
364
- perform_save_locally(save_path, samples)
 
scripts/demo/streamlit_helpers.py DELETED
@@ -1,928 +0,0 @@
1
- import copy
2
- import math
3
- import os
4
- from glob import glob
5
- from typing import Dict, List, Optional, Tuple, Union
6
-
7
- import cv2
8
- import numpy as np
9
- import streamlit as st
10
- import torch
11
- import torch.nn as nn
12
- import torchvision.transforms as TT
13
- from einops import rearrange, repeat
14
- from imwatermark import WatermarkEncoder
15
- from omegaconf import ListConfig, OmegaConf
16
- from PIL import Image
17
- from safetensors.torch import load_file as load_safetensors
18
- from torch import autocast
19
- from torchvision import transforms
20
- from torchvision.utils import make_grid, save_image
21
-
22
- from scripts.demo.discretization import (Img2ImgDiscretizationWrapper,
23
- Txt2NoisyDiscretizationWrapper)
24
- from scripts.util.detection.nsfw_and_watermark_dectection import \
25
- DeepFloydDataFiltering
26
- from sgm.inference.helpers import embed_watermark
27
- from sgm.modules.diffusionmodules.guiders import (LinearPredictionGuider,
28
- VanillaCFG)
29
- from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler,
30
- DPMPP2SAncestralSampler,
31
- EulerAncestralSampler,
32
- EulerEDMSampler,
33
- HeunEDMSampler,
34
- LinearMultistepSampler)
35
- from sgm.util import append_dims, default, instantiate_from_config
36
-
37
-
38
- @st.cache_resource()
39
- def init_st(version_dict, load_ckpt=True, load_filter=True):
40
- state = dict()
41
- if not "model" in state:
42
- config = version_dict["config"]
43
- ckpt = version_dict["ckpt"]
44
-
45
- config = OmegaConf.load(config)
46
- model, msg = load_model_from_config(config, ckpt if load_ckpt else None)
47
-
48
- state["msg"] = msg
49
- state["model"] = model
50
- state["ckpt"] = ckpt if load_ckpt else None
51
- state["config"] = config
52
- if load_filter:
53
- state["filter"] = DeepFloydDataFiltering(verbose=False)
54
- return state
55
-
56
-
57
- def load_model(model):
58
- model.cuda()
59
-
60
-
61
- lowvram_mode = False
62
-
63
-
64
- def set_lowvram_mode(mode):
65
- global lowvram_mode
66
- lowvram_mode = mode
67
-
68
-
69
- def initial_model_load(model):
70
- global lowvram_mode
71
- if lowvram_mode:
72
- model.model.half()
73
- else:
74
- model.cuda()
75
- return model
76
-
77
-
78
- def unload_model(model):
79
- global lowvram_mode
80
- if lowvram_mode:
81
- model.cpu()
82
- torch.cuda.empty_cache()
83
-
84
-
85
- def load_model_from_config(config, ckpt=None, verbose=True):
86
- model = instantiate_from_config(config.model)
87
-
88
- if ckpt is not None:
89
- print(f"Loading model from {ckpt}")
90
- if ckpt.endswith("ckpt"):
91
- pl_sd = torch.load(ckpt, map_location="cpu")
92
- if "global_step" in pl_sd:
93
- global_step = pl_sd["global_step"]
94
- st.info(f"loaded ckpt from global step {global_step}")
95
- print(f"Global Step: {pl_sd['global_step']}")
96
- sd = pl_sd["state_dict"]
97
- elif ckpt.endswith("safetensors"):
98
- sd = load_safetensors(ckpt)
99
- else:
100
- raise NotImplementedError
101
-
102
- msg = None
103
-
104
- m, u = model.load_state_dict(sd, strict=False)
105
-
106
- if len(m) > 0 and verbose:
107
- print("missing keys:")
108
- print(m)
109
- if len(u) > 0 and verbose:
110
- print("unexpected keys:")
111
- print(u)
112
- else:
113
- msg = None
114
-
115
- model = initial_model_load(model)
116
- model.eval()
117
- return model, msg
118
-
119
-
120
- def get_unique_embedder_keys_from_conditioner(conditioner):
121
- return list(set([x.input_key for x in conditioner.embedders]))
122
-
123
-
124
- def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
125
- # Hardcoded demo settings; might undergo some changes in the future
126
-
127
- value_dict = {}
128
- for key in keys:
129
- if key == "txt":
130
- if prompt is None:
131
- prompt = "A professional photograph of an astronaut riding a pig"
132
- if negative_prompt is None:
133
- negative_prompt = ""
134
-
135
- prompt = st.text_input("Prompt", prompt)
136
- negative_prompt = st.text_input("Negative prompt", negative_prompt)
137
-
138
- value_dict["prompt"] = prompt
139
- value_dict["negative_prompt"] = negative_prompt
140
-
141
- if key == "original_size_as_tuple":
142
- orig_width = st.number_input(
143
- "orig_width",
144
- value=init_dict["orig_width"],
145
- min_value=16,
146
- )
147
- orig_height = st.number_input(
148
- "orig_height",
149
- value=init_dict["orig_height"],
150
- min_value=16,
151
- )
152
-
153
- value_dict["orig_width"] = orig_width
154
- value_dict["orig_height"] = orig_height
155
-
156
- if key == "crop_coords_top_left":
157
- crop_coord_top = st.number_input("crop_coords_top", value=0, min_value=0)
158
- crop_coord_left = st.number_input("crop_coords_left", value=0, min_value=0)
159
-
160
- value_dict["crop_coords_top"] = crop_coord_top
161
- value_dict["crop_coords_left"] = crop_coord_left
162
-
163
- if key == "aesthetic_score":
164
- value_dict["aesthetic_score"] = 6.0
165
- value_dict["negative_aesthetic_score"] = 2.5
166
-
167
- if key == "target_size_as_tuple":
168
- value_dict["target_width"] = init_dict["target_width"]
169
- value_dict["target_height"] = init_dict["target_height"]
170
-
171
- if key in ["fps_id", "fps"]:
172
- fps = st.number_input("fps", value=6, min_value=1)
173
-
174
- value_dict["fps"] = fps
175
- value_dict["fps_id"] = fps - 1
176
-
177
- if key == "motion_bucket_id":
178
- mb_id = st.number_input("motion bucket id", 0, 511, value=127)
179
- value_dict["motion_bucket_id"] = mb_id
180
-
181
- if key == "pool_image":
182
- st.text("Image for pool conditioning")
183
- image = load_img(
184
- key="pool_image_input",
185
- size=224,
186
- center_crop=True,
187
- )
188
- if image is None:
189
- st.info("Need an image here")
190
- image = torch.zeros(1, 3, 224, 224)
191
- value_dict["pool_image"] = image
192
-
193
- return value_dict
194
-
195
-
196
- def perform_save_locally(save_path, samples):
197
- os.makedirs(os.path.join(save_path), exist_ok=True)
198
- base_count = len(os.listdir(os.path.join(save_path)))
199
- samples = embed_watermark(samples)
200
- for sample in samples:
201
- sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
202
- Image.fromarray(sample.astype(np.uint8)).save(
203
- os.path.join(save_path, f"{base_count:09}.png")
204
- )
205
- base_count += 1
206
-
207
-
208
- def init_save_locally(_dir, init_value: bool = False):
209
- save_locally = st.sidebar.checkbox("Save images locally", value=init_value)
210
- if save_locally:
211
- save_path = st.text_input("Save path", value=os.path.join(_dir, "samples"))
212
- else:
213
- save_path = None
214
-
215
- return save_locally, save_path
216
-
217
-
218
- def get_guider(options, key):
219
- guider = st.sidebar.selectbox(
220
- f"Discretization #{key}",
221
- [
222
- "VanillaCFG",
223
- "IdentityGuider",
224
- "LinearPredictionGuider",
225
- ],
226
- options.get("guider", 0),
227
- )
228
-
229
- additional_guider_kwargs = options.pop("additional_guider_kwargs", {})
230
-
231
- if guider == "IdentityGuider":
232
- guider_config = {
233
- "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
234
- }
235
- elif guider == "VanillaCFG":
236
- scale_schedule = st.sidebar.selectbox(
237
- f"Scale schedule #{key}",
238
- ["Identity", "Oscillating"],
239
- )
240
-
241
- if scale_schedule == "Identity":
242
- scale = st.number_input(
243
- f"cfg-scale #{key}",
244
- value=options.get("cfg", 5.0),
245
- min_value=0.0,
246
- )
247
-
248
- scale_schedule_config = {
249
- "target": "sgm.modules.diffusionmodules.guiders.IdentitySchedule",
250
- "params": {"scale": scale},
251
- }
252
-
253
- elif scale_schedule == "Oscillating":
254
- small_scale = st.number_input(
255
- f"small cfg-scale #{key}",
256
- value=4.0,
257
- min_value=0.0,
258
- )
259
-
260
- large_scale = st.number_input(
261
- f"large cfg-scale #{key}",
262
- value=16.0,
263
- min_value=0.0,
264
- )
265
-
266
- sigma_cutoff = st.number_input(
267
- f"sigma cutoff #{key}",
268
- value=1.0,
269
- min_value=0.0,
270
- )
271
-
272
- scale_schedule_config = {
273
- "target": "sgm.modules.diffusionmodules.guiders.OscillatingSchedule",
274
- "params": {
275
- "small_scale": small_scale,
276
- "large_scale": large_scale,
277
- "sigma_cutoff": sigma_cutoff,
278
- },
279
- }
280
- else:
281
- raise NotImplementedError
282
-
283
- guider_config = {
284
- "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
285
- "params": {
286
- "scale_schedule_config": scale_schedule_config,
287
- **additional_guider_kwargs,
288
- },
289
- }
290
- elif guider == "LinearPredictionGuider":
291
- max_scale = st.number_input(
292
- f"max-cfg-scale #{key}",
293
- value=options.get("cfg", 1.5),
294
- min_value=1.0,
295
- )
296
- min_scale = st.number_input(
297
- f"min guidance scale",
298
- value=options.get("min_cfg", 1.0),
299
- min_value=1.0,
300
- max_value=10.0,
301
- )
302
-
303
- guider_config = {
304
- "target": "sgm.modules.diffusionmodules.guiders.LinearPredictionGuider",
305
- "params": {
306
- "max_scale": max_scale,
307
- "min_scale": min_scale,
308
- "num_frames": options["num_frames"],
309
- **additional_guider_kwargs,
310
- },
311
- }
312
- else:
313
- raise NotImplementedError
314
- return guider_config
315
-
316
-
317
- def init_sampling(
318
- key=1,
319
- img2img_strength: Optional[float] = None,
320
- specify_num_samples: bool = True,
321
- stage2strength: Optional[float] = None,
322
- options: Optional[Dict[str, int]] = None,
323
- ):
324
- options = {} if options is None else options
325
-
326
- num_rows, num_cols = 1, 1
327
- if specify_num_samples:
328
- num_cols = st.number_input(
329
- f"num cols #{key}", value=num_cols, min_value=1, max_value=10
330
- )
331
-
332
- steps = st.sidebar.number_input(
333
- f"steps #{key}", value=options.get("num_steps", 40), min_value=1, max_value=1000
334
- )
335
- sampler = st.sidebar.selectbox(
336
- f"Sampler #{key}",
337
- [
338
- "EulerEDMSampler",
339
- "HeunEDMSampler",
340
- "EulerAncestralSampler",
341
- "DPMPP2SAncestralSampler",
342
- "DPMPP2MSampler",
343
- "LinearMultistepSampler",
344
- ],
345
- options.get("sampler", 0),
346
- )
347
- discretization = st.sidebar.selectbox(
348
- f"Discretization #{key}",
349
- [
350
- "LegacyDDPMDiscretization",
351
- "EDMDiscretization",
352
- ],
353
- options.get("discretization", 0),
354
- )
355
-
356
- discretization_config = get_discretization(discretization, options=options, key=key)
357
-
358
- guider_config = get_guider(options=options, key=key)
359
-
360
- sampler = get_sampler(sampler, steps, discretization_config, guider_config, key=key)
361
- if img2img_strength is not None:
362
- st.warning(
363
- f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper"
364
- )
365
- sampler.discretization = Img2ImgDiscretizationWrapper(
366
- sampler.discretization, strength=img2img_strength
367
- )
368
- if stage2strength is not None:
369
- sampler.discretization = Txt2NoisyDiscretizationWrapper(
370
- sampler.discretization, strength=stage2strength, original_steps=steps
371
- )
372
- return sampler, num_rows, num_cols
373
-
374
-
375
- def get_discretization(discretization, options, key=1):
376
- if discretization == "LegacyDDPMDiscretization":
377
- discretization_config = {
378
- "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
379
- }
380
- elif discretization == "EDMDiscretization":
381
- sigma_min = st.number_input(
382
- f"sigma_min #{key}", value=options.get("sigma_min", 0.03)
383
- ) # 0.0292
384
- sigma_max = st.number_input(
385
- f"sigma_max #{key}", value=options.get("sigma_max", 14.61)
386
- ) # 14.6146
387
- rho = st.number_input(f"rho #{key}", value=options.get("rho", 3.0))
388
- discretization_config = {
389
- "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
390
- "params": {
391
- "sigma_min": sigma_min,
392
- "sigma_max": sigma_max,
393
- "rho": rho,
394
- },
395
- }
396
-
397
- return discretization_config
398
-
399
-
400
- def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1):
401
- if sampler_name == "EulerEDMSampler" or sampler_name == "HeunEDMSampler":
402
- s_churn = st.sidebar.number_input(f"s_churn #{key}", value=0.0, min_value=0.0)
403
- s_tmin = st.sidebar.number_input(f"s_tmin #{key}", value=0.0, min_value=0.0)
404
- s_tmax = st.sidebar.number_input(f"s_tmax #{key}", value=999.0, min_value=0.0)
405
- s_noise = st.sidebar.number_input(f"s_noise #{key}", value=1.0, min_value=0.0)
406
-
407
- if sampler_name == "EulerEDMSampler":
408
- sampler = EulerEDMSampler(
409
- num_steps=steps,
410
- discretization_config=discretization_config,
411
- guider_config=guider_config,
412
- s_churn=s_churn,
413
- s_tmin=s_tmin,
414
- s_tmax=s_tmax,
415
- s_noise=s_noise,
416
- verbose=True,
417
- )
418
- elif sampler_name == "HeunEDMSampler":
419
- sampler = HeunEDMSampler(
420
- num_steps=steps,
421
- discretization_config=discretization_config,
422
- guider_config=guider_config,
423
- s_churn=s_churn,
424
- s_tmin=s_tmin,
425
- s_tmax=s_tmax,
426
- s_noise=s_noise,
427
- verbose=True,
428
- )
429
- elif (
430
- sampler_name == "EulerAncestralSampler"
431
- or sampler_name == "DPMPP2SAncestralSampler"
432
- ):
433
- s_noise = st.sidebar.number_input("s_noise", value=1.0, min_value=0.0)
434
- eta = st.sidebar.number_input("eta", value=1.0, min_value=0.0)
435
-
436
- if sampler_name == "EulerAncestralSampler":
437
- sampler = EulerAncestralSampler(
438
- num_steps=steps,
439
- discretization_config=discretization_config,
440
- guider_config=guider_config,
441
- eta=eta,
442
- s_noise=s_noise,
443
- verbose=True,
444
- )
445
- elif sampler_name == "DPMPP2SAncestralSampler":
446
- sampler = DPMPP2SAncestralSampler(
447
- num_steps=steps,
448
- discretization_config=discretization_config,
449
- guider_config=guider_config,
450
- eta=eta,
451
- s_noise=s_noise,
452
- verbose=True,
453
- )
454
- elif sampler_name == "DPMPP2MSampler":
455
- sampler = DPMPP2MSampler(
456
- num_steps=steps,
457
- discretization_config=discretization_config,
458
- guider_config=guider_config,
459
- verbose=True,
460
- )
461
- elif sampler_name == "LinearMultistepSampler":
462
- order = st.sidebar.number_input("order", value=4, min_value=1)
463
- sampler = LinearMultistepSampler(
464
- num_steps=steps,
465
- discretization_config=discretization_config,
466
- guider_config=guider_config,
467
- order=order,
468
- verbose=True,
469
- )
470
- else:
471
- raise ValueError(f"unknown sampler {sampler_name}!")
472
-
473
- return sampler
474
-
475
-
476
- def get_interactive_image() -> Image.Image:
477
- image = st.file_uploader("Input", type=["jpg", "JPEG", "png"])
478
- if image is not None:
479
- image = Image.open(image)
480
- if not image.mode == "RGB":
481
- image = image.convert("RGB")
482
- return image
483
-
484
-
485
- def load_img(
486
- display: bool = True,
487
- size: Union[None, int, Tuple[int, int]] = None,
488
- center_crop: bool = False,
489
- ):
490
- image = get_interactive_image()
491
- if image is None:
492
- return None
493
- if display:
494
- st.image(image)
495
- w, h = image.size
496
- print(f"loaded input image of size ({w}, {h})")
497
-
498
- transform = []
499
- if size is not None:
500
- transform.append(transforms.Resize(size))
501
- if center_crop:
502
- transform.append(transforms.CenterCrop(size))
503
- transform.append(transforms.ToTensor())
504
- transform.append(transforms.Lambda(lambda x: 2.0 * x - 1.0))
505
-
506
- transform = transforms.Compose(transform)
507
- img = transform(image)[None, ...]
508
- st.text(f"input min/max/mean: {img.min():.3f}/{img.max():.3f}/{img.mean():.3f}")
509
- return img
510
-
511
-
512
- def get_init_img(batch_size=1, key=None):
513
- init_image = load_img(key=key).cuda()
514
- init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
515
- return init_image
516
-
517
-
518
- def do_sample(
519
- model,
520
- sampler,
521
- value_dict,
522
- num_samples,
523
- H,
524
- W,
525
- C,
526
- F,
527
- force_uc_zero_embeddings: Optional[List] = None,
528
- force_cond_zero_embeddings: Optional[List] = None,
529
- batch2model_input: List = None,
530
- return_latents=False,
531
- filter=None,
532
- T=None,
533
- additional_batch_uc_fields=None,
534
- decoding_t=None,
535
- ):
536
- force_uc_zero_embeddings = default(force_uc_zero_embeddings, [])
537
- batch2model_input = default(batch2model_input, [])
538
- additional_batch_uc_fields = default(additional_batch_uc_fields, [])
539
-
540
- st.text("Sampling")
541
-
542
- outputs = st.empty()
543
- precision_scope = autocast
544
- with torch.no_grad():
545
- with precision_scope("cuda"):
546
- with model.ema_scope():
547
- if T is not None:
548
- num_samples = [num_samples, T]
549
- else:
550
- num_samples = [num_samples]
551
-
552
- load_model(model.conditioner)
553
- batch, batch_uc = get_batch(
554
- get_unique_embedder_keys_from_conditioner(model.conditioner),
555
- value_dict,
556
- num_samples,
557
- T=T,
558
- additional_batch_uc_fields=additional_batch_uc_fields,
559
- )
560
-
561
- c, uc = model.conditioner.get_unconditional_conditioning(
562
- batch,
563
- batch_uc=batch_uc,
564
- force_uc_zero_embeddings=force_uc_zero_embeddings,
565
- force_cond_zero_embeddings=force_cond_zero_embeddings,
566
- )
567
- unload_model(model.conditioner)
568
-
569
- for k in c:
570
- if not k == "crossattn":
571
- c[k], uc[k] = map(
572
- lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)
573
- )
574
- if k in ["crossattn", "concat"] and T is not None:
575
- uc[k] = repeat(uc[k], "b ... -> b t ...", t=T)
576
- uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=T)
577
- c[k] = repeat(c[k], "b ... -> b t ...", t=T)
578
- c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=T)
579
-
580
- additional_model_inputs = {}
581
- for k in batch2model_input:
582
- if k == "image_only_indicator":
583
- assert T is not None
584
-
585
- if isinstance(
586
- sampler.guider, (VanillaCFG, LinearPredictionGuider)
587
- ):
588
- additional_model_inputs[k] = torch.zeros(
589
- num_samples[0] * 2, num_samples[1]
590
- ).to("cuda")
591
- else:
592
- additional_model_inputs[k] = torch.zeros(num_samples).to(
593
- "cuda"
594
- )
595
- else:
596
- additional_model_inputs[k] = batch[k]
597
-
598
- shape = (math.prod(num_samples), C, H // F, W // F)
599
- randn = torch.randn(shape).to("cuda")
600
-
601
- def denoiser(input, sigma, c):
602
- return model.denoiser(
603
- model.model, input, sigma, c, **additional_model_inputs
604
- )
605
-
606
- load_model(model.denoiser)
607
- load_model(model.model)
608
- samples_z = sampler(denoiser, randn, cond=c, uc=uc)
609
- unload_model(model.model)
610
- unload_model(model.denoiser)
611
-
612
- load_model(model.first_stage_model)
613
- model.en_and_decode_n_samples_a_time = (
614
- decoding_t # Decode n frames at a time
615
- )
616
- samples_x = model.decode_first_stage(samples_z)
617
- samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
618
- unload_model(model.first_stage_model)
619
-
620
- if filter is not None:
621
- samples = filter(samples)
622
-
623
- if T is None:
624
- grid = torch.stack([samples])
625
- grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
626
- outputs.image(grid.cpu().numpy())
627
- else:
628
- as_vids = rearrange(samples, "(b t) c h w -> b t c h w", t=T)
629
- for i, vid in enumerate(as_vids):
630
- grid = rearrange(make_grid(vid, nrow=4), "c h w -> h w c")
631
- st.image(
632
- grid.cpu().numpy(),
633
- f"Sample #{i} as image",
634
- )
635
-
636
- if return_latents:
637
- return samples, samples_z
638
- return samples
639
-
640
-
641
- def get_batch(
642
- keys,
643
- value_dict: dict,
644
- N: Union[List, ListConfig],
645
- device: str = "cuda",
646
- T: int = None,
647
- additional_batch_uc_fields: List[str] = [],
648
- ):
649
- # Hardcoded demo setups; might undergo some changes in the future
650
-
651
- batch = {}
652
- batch_uc = {}
653
-
654
- for key in keys:
655
- if key == "txt":
656
- batch["txt"] = [value_dict["prompt"]] * math.prod(N)
657
-
658
- batch_uc["txt"] = [value_dict["negative_prompt"]] * math.prod(N)
659
-
660
- elif key == "original_size_as_tuple":
661
- batch["original_size_as_tuple"] = (
662
- torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
663
- .to(device)
664
- .repeat(math.prod(N), 1)
665
- )
666
- elif key == "crop_coords_top_left":
667
- batch["crop_coords_top_left"] = (
668
- torch.tensor(
669
- [value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
670
- )
671
- .to(device)
672
- .repeat(math.prod(N), 1)
673
- )
674
- elif key == "aesthetic_score":
675
- batch["aesthetic_score"] = (
676
- torch.tensor([value_dict["aesthetic_score"]])
677
- .to(device)
678
- .repeat(math.prod(N), 1)
679
- )
680
- batch_uc["aesthetic_score"] = (
681
- torch.tensor([value_dict["negative_aesthetic_score"]])
682
- .to(device)
683
- .repeat(math.prod(N), 1)
684
- )
685
-
686
- elif key == "target_size_as_tuple":
687
- batch["target_size_as_tuple"] = (
688
- torch.tensor([value_dict["target_height"], value_dict["target_width"]])
689
- .to(device)
690
- .repeat(math.prod(N), 1)
691
- )
692
- elif key == "fps":
693
- batch[key] = (
694
- torch.tensor([value_dict["fps"]]).to(device).repeat(math.prod(N))
695
- )
696
- elif key == "fps_id":
697
- batch[key] = (
698
- torch.tensor([value_dict["fps_id"]]).to(device).repeat(math.prod(N))
699
- )
700
- elif key == "motion_bucket_id":
701
- batch[key] = (
702
- torch.tensor([value_dict["motion_bucket_id"]])
703
- .to(device)
704
- .repeat(math.prod(N))
705
- )
706
- elif key == "pool_image":
707
- batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to(
708
- device, dtype=torch.half
709
- )
710
- elif key == "cond_aug":
711
- batch[key] = repeat(
712
- torch.tensor([value_dict["cond_aug"]]).to("cuda"),
713
- "1 -> b",
714
- b=math.prod(N),
715
- )
716
- elif key == "cond_frames":
717
- batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
718
- elif key == "cond_frames_without_noise":
719
- batch[key] = repeat(
720
- value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
721
- )
722
- else:
723
- batch[key] = value_dict[key]
724
-
725
- if T is not None:
726
- batch["num_video_frames"] = T
727
-
728
- for key in batch.keys():
729
- if key not in batch_uc and isinstance(batch[key], torch.Tensor):
730
- batch_uc[key] = torch.clone(batch[key])
731
- elif key in additional_batch_uc_fields and key not in batch_uc:
732
- batch_uc[key] = copy.copy(batch[key])
733
- return batch, batch_uc
734
-
735
-
736
- @torch.no_grad()
737
- def do_img2img(
738
- img,
739
- model,
740
- sampler,
741
- value_dict,
742
- num_samples,
743
- force_uc_zero_embeddings: Optional[List] = None,
744
- force_cond_zero_embeddings: Optional[List] = None,
745
- additional_kwargs={},
746
- offset_noise_level: int = 0.0,
747
- return_latents=False,
748
- skip_encode=False,
749
- filter=None,
750
- add_noise=True,
751
- ):
752
- st.text("Sampling")
753
-
754
- outputs = st.empty()
755
- precision_scope = autocast
756
- with torch.no_grad():
757
- with precision_scope("cuda"):
758
- with model.ema_scope():
759
- load_model(model.conditioner)
760
- batch, batch_uc = get_batch(
761
- get_unique_embedder_keys_from_conditioner(model.conditioner),
762
- value_dict,
763
- [num_samples],
764
- )
765
- c, uc = model.conditioner.get_unconditional_conditioning(
766
- batch,
767
- batch_uc=batch_uc,
768
- force_uc_zero_embeddings=force_uc_zero_embeddings,
769
- force_cond_zero_embeddings=force_cond_zero_embeddings,
770
- )
771
- unload_model(model.conditioner)
772
- for k in c:
773
- c[k], uc[k] = map(lambda y: y[k][:num_samples].to("cuda"), (c, uc))
774
-
775
- for k in additional_kwargs:
776
- c[k] = uc[k] = additional_kwargs[k]
777
- if skip_encode:
778
- z = img
779
- else:
780
- load_model(model.first_stage_model)
781
- z = model.encode_first_stage(img)
782
- unload_model(model.first_stage_model)
783
-
784
- noise = torch.randn_like(z)
785
-
786
- sigmas = sampler.discretization(sampler.num_steps).cuda()
787
- sigma = sigmas[0]
788
-
789
- st.info(f"all sigmas: {sigmas}")
790
- st.info(f"noising sigma: {sigma}")
791
- if offset_noise_level > 0.0:
792
- noise = noise + offset_noise_level * append_dims(
793
- torch.randn(z.shape[0], device=z.device), z.ndim
794
- )
795
- if add_noise:
796
- noised_z = z + noise * append_dims(sigma, z.ndim).cuda()
797
- noised_z = noised_z / torch.sqrt(
798
- 1.0 + sigmas[0] ** 2.0
799
- ) # Note: hardcoded to DDPM-like scaling. need to generalize later.
800
- else:
801
- noised_z = z / torch.sqrt(1.0 + sigmas[0] ** 2.0)
802
-
803
- def denoiser(x, sigma, c):
804
- return model.denoiser(model.model, x, sigma, c)
805
-
806
- load_model(model.denoiser)
807
- load_model(model.model)
808
- samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
809
- unload_model(model.model)
810
- unload_model(model.denoiser)
811
-
812
- load_model(model.first_stage_model)
813
- samples_x = model.decode_first_stage(samples_z)
814
- unload_model(model.first_stage_model)
815
- samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
816
-
817
- if filter is not None:
818
- samples = filter(samples)
819
-
820
- grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
821
- outputs.image(grid.cpu().numpy())
822
- if return_latents:
823
- return samples, samples_z
824
- return samples
825
-
826
-
827
- def get_resizing_factor(
828
- desired_shape: Tuple[int, int], current_shape: Tuple[int, int]
829
- ) -> float:
830
- r_bound = desired_shape[1] / desired_shape[0]
831
- aspect_r = current_shape[1] / current_shape[0]
832
- if r_bound >= 1.0:
833
- if aspect_r >= r_bound:
834
- factor = min(desired_shape) / min(current_shape)
835
- else:
836
- if aspect_r < 1.0:
837
- factor = max(desired_shape) / min(current_shape)
838
- else:
839
- factor = max(desired_shape) / max(current_shape)
840
- else:
841
- if aspect_r <= r_bound:
842
- factor = min(desired_shape) / min(current_shape)
843
- else:
844
- if aspect_r > 1:
845
- factor = max(desired_shape) / min(current_shape)
846
- else:
847
- factor = max(desired_shape) / max(current_shape)
848
-
849
- return factor
850
-
851
-
852
- def get_interactive_image(key=None) -> Image.Image:
853
- image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
854
- if image is not None:
855
- image = Image.open(image)
856
- if not image.mode == "RGB":
857
- image = image.convert("RGB")
858
- return image
859
-
860
-
861
- def load_img_for_prediction(
862
- W: int, H: int, display=True, key=None, device="cuda"
863
- ) -> torch.Tensor:
864
- image = get_interactive_image(key=key)
865
- if image is None:
866
- return None
867
- if display:
868
- st.image(image)
869
- w, h = image.size
870
-
871
- image = np.array(image).transpose(2, 0, 1)
872
- image = torch.from_numpy(image).to(dtype=torch.float32) / 255.0
873
- image = image.unsqueeze(0)
874
-
875
- rfs = get_resizing_factor((H, W), (h, w))
876
- resize_size = [int(np.ceil(rfs * s)) for s in (h, w)]
877
- top = (resize_size[0] - H) // 2
878
- left = (resize_size[1] - W) // 2
879
-
880
- image = torch.nn.functional.interpolate(
881
- image, resize_size, mode="area", antialias=False
882
- )
883
- image = TT.functional.crop(image, top=top, left=left, height=H, width=W)
884
-
885
- if display:
886
- numpy_img = np.transpose(image[0].numpy(), (1, 2, 0))
887
- pil_image = Image.fromarray((numpy_img * 255).astype(np.uint8))
888
- st.image(pil_image)
889
- return image.to(device) * 2.0 - 1.0
890
-
891
-
892
- def save_video_as_grid_and_mp4(
893
- video_batch: torch.Tensor, save_path: str, T: int, fps: int = 5
894
- ):
895
- os.makedirs(save_path, exist_ok=True)
896
- base_count = len(glob(os.path.join(save_path, "*.mp4")))
897
-
898
- video_batch = rearrange(video_batch, "(b t) c h w -> b t c h w", t=T)
899
- video_batch = embed_watermark(video_batch)
900
- for vid in video_batch:
901
- save_image(vid, fp=os.path.join(save_path, f"{base_count:06d}.png"), nrow=4)
902
-
903
- video_path = os.path.join(save_path, f"{base_count:06d}.mp4")
904
-
905
- writer = cv2.VideoWriter(
906
- video_path,
907
- cv2.VideoWriter_fourcc(*"MP4V"),
908
- fps,
909
- (vid.shape[-1], vid.shape[-2]),
910
- )
911
-
912
- vid = (
913
- (rearrange(vid, "t c h w -> t h w c") * 255).cpu().numpy().astype(np.uint8)
914
- )
915
- for frame in vid:
916
- frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
917
- writer.write(frame)
918
-
919
- writer.release()
920
-
921
- video_path_h264 = video_path[:-4] + "_h264.mp4"
922
- os.system(f"ffmpeg -i {video_path} -c:v libx264 {video_path_h264}")
923
-
924
- with open(video_path_h264, "rb") as f:
925
- video_bytes = f.read()
926
- st.video(video_bytes)
927
-
928
- base_count += 1
 
scripts/demo/video_sampling.py DELETED
@@ -1,200 +0,0 @@
1
- import os
2
-
3
- from pytorch_lightning import seed_everything
4
-
5
- from scripts.demo.streamlit_helpers import *
6
-
7
- SAVE_PATH = "outputs/demo/vid/"
8
-
9
- VERSION2SPECS = {
10
- "svd": {
11
- "T": 14,
12
- "H": 576,
13
- "W": 1024,
14
- "C": 4,
15
- "f": 8,
16
- "config": "configs/inference/svd.yaml",
17
- "ckpt": "checkpoints/svd.safetensors",
18
- "options": {
19
- "discretization": 1,
20
- "cfg": 2.5,
21
- "sigma_min": 0.002,
22
- "sigma_max": 700.0,
23
- "rho": 7.0,
24
- "guider": 2,
25
- "force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
26
- "num_steps": 25,
27
- },
28
- },
29
- "svd_image_decoder": {
30
- "T": 14,
31
- "H": 576,
32
- "W": 1024,
33
- "C": 4,
34
- "f": 8,
35
- "config": "configs/inference/svd_image_decoder.yaml",
36
- "ckpt": "checkpoints/svd_image_decoder.safetensors",
37
- "options": {
38
- "discretization": 1,
39
- "cfg": 2.5,
40
- "sigma_min": 0.002,
41
- "sigma_max": 700.0,
42
- "rho": 7.0,
43
- "guider": 2,
44
- "force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
45
- "num_steps": 25,
46
- },
47
- },
48
- "svd_xt": {
49
- "T": 25,
50
- "H": 576,
51
- "W": 1024,
52
- "C": 4,
53
- "f": 8,
54
- "config": "configs/inference/svd.yaml",
55
- "ckpt": "checkpoints/svd_xt.safetensors",
56
- "options": {
57
- "discretization": 1,
58
- "cfg": 3.0,
59
- "min_cfg": 1.5,
60
- "sigma_min": 0.002,
61
- "sigma_max": 700.0,
62
- "rho": 7.0,
63
- "guider": 2,
64
- "force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
65
- "num_steps": 30,
66
- "decoding_t": 14,
67
- },
68
- },
69
- "svd_xt_image_decoder": {
70
- "T": 25,
71
- "H": 576,
72
- "W": 1024,
73
- "C": 4,
74
- "f": 8,
75
- "config": "configs/inference/svd_image_decoder.yaml",
76
- "ckpt": "checkpoints/svd_xt_image_decoder.safetensors",
77
- "options": {
78
- "discretization": 1,
79
- "cfg": 3.0,
80
- "min_cfg": 1.5,
81
- "sigma_min": 0.002,
82
- "sigma_max": 700.0,
83
- "rho": 7.0,
84
- "guider": 2,
85
- "force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
86
- "num_steps": 30,
87
- "decoding_t": 14,
88
- },
89
- },
90
- }
91
-
92
-
93
- if __name__ == "__main__":
94
- st.title("Stable Video Diffusion")
95
- version = st.selectbox(
96
- "Model Version",
97
- [k for k in VERSION2SPECS.keys()],
98
- 0,
99
- )
100
- version_dict = VERSION2SPECS[version]
101
- if st.checkbox("Load Model"):
102
- mode = "img2vid"
103
- else:
104
- mode = "skip"
105
-
106
- H = st.sidebar.number_input(
107
- "H", value=version_dict["H"], min_value=64, max_value=2048
108
- )
109
- W = st.sidebar.number_input(
110
- "W", value=version_dict["W"], min_value=64, max_value=2048
111
- )
112
- T = st.sidebar.number_input(
113
- "T", value=version_dict["T"], min_value=0, max_value=128
114
- )
115
- C = version_dict["C"]
116
- F = version_dict["f"]
117
- options = version_dict["options"]
118
-
119
- if mode != "skip":
120
- state = init_st(version_dict, load_filter=True)
121
- if state["msg"]:
122
- st.info(state["msg"])
123
- model = state["model"]
124
-
125
- ukeys = set(
126
- get_unique_embedder_keys_from_conditioner(state["model"].conditioner)
127
- )
128
-
129
- value_dict = init_embedder_options(
130
- ukeys,
131
- {},
132
- )
133
-
134
- value_dict["image_only_indicator"] = 0
135
-
136
- if mode == "img2vid":
137
- img = load_img_for_prediction(W, H)
138
- cond_aug = st.number_input(
139
- "Conditioning augmentation:", value=0.02, min_value=0.0
140
- )
141
- value_dict["cond_frames_without_noise"] = img
142
- value_dict["cond_frames"] = img + cond_aug * torch.randn_like(img)
143
- value_dict["cond_aug"] = cond_aug
144
-
145
- seed = st.sidebar.number_input(
146
- "seed", value=23, min_value=0, max_value=int(1e9)
147
- )
148
- seed_everything(seed)
149
-
150
- save_locally, save_path = init_save_locally(
151
- os.path.join(SAVE_PATH, version), init_value=True
152
- )
153
-
154
- options["num_frames"] = T
155
-
156
- sampler, num_rows, num_cols = init_sampling(options=options)
157
- num_samples = num_rows * num_cols
158
-
159
- decoding_t = st.number_input(
160
- "Decode t frames at a time (set small if you are low on VRAM)",
161
- value=options.get("decoding_t", T),
162
- min_value=1,
163
- max_value=int(1e9),
164
- )
165
-
166
- if st.checkbox("Overwrite fps in mp4 generator", False):
167
- saving_fps = st.number_input(
168
- f"saving video at fps:", value=value_dict["fps"], min_value=1
169
- )
170
- else:
171
- saving_fps = value_dict["fps"]
172
-
173
- if st.button("Sample"):
174
- out = do_sample(
175
- model,
176
- sampler,
177
- value_dict,
178
- num_samples,
179
- H,
180
- W,
181
- C,
182
- F,
183
- T=T,
184
- batch2model_input=["num_video_frames", "image_only_indicator"],
185
- force_uc_zero_embeddings=options.get("force_uc_zero_embeddings", None),
186
- force_cond_zero_embeddings=options.get(
187
- "force_cond_zero_embeddings", None
188
- ),
189
- return_latents=False,
190
- decoding_t=decoding_t,
191
- )
192
-
193
- if isinstance(out, (tuple, list)):
194
- samples, samples_z = out
195
- else:
196
- samples = out
197
- samples_z = None
198
-
199
- if save_locally:
200
- save_video_as_grid_and_mp4(samples, save_path, T, fps=saving_fps)
 
scripts/sampling/configs/svd.yaml DELETED
@@ -1,146 +0,0 @@
1
- model:
2
- target: sgm.models.diffusion.DiffusionEngine
3
- params:
4
- scale_factor: 0.18215
5
- disable_first_stage_autocast: True
6
- ckpt_path: checkpoints/svd.safetensors
7
-
8
- denoiser_config:
9
- target: sgm.modules.diffusionmodules.denoiser.Denoiser
10
- params:
11
- scaling_config:
12
- target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
13
-
14
- network_config:
15
- target: sgm.modules.diffusionmodules.video_model.VideoUNet
16
- params:
17
- adm_in_channels: 768
18
- num_classes: sequential
19
- use_checkpoint: True
20
- in_channels: 8
21
- out_channels: 4
22
- model_channels: 320
23
- attention_resolutions: [4, 2, 1]
24
- num_res_blocks: 2
25
- channel_mult: [1, 2, 4, 4]
26
- num_head_channels: 64
27
- use_linear_in_transformer: True
28
- transformer_depth: 1
29
- context_dim: 1024
30
- spatial_transformer_attn_type: softmax-xformers
31
- extra_ff_mix_layer: True
32
- use_spatial_context: True
33
- merge_strategy: learned_with_images
34
- video_kernel_size: [3, 1, 1]
35
-
36
- conditioner_config:
37
- target: sgm.modules.GeneralConditioner
38
- params:
39
- emb_models:
40
- - is_trainable: False
41
- input_key: cond_frames_without_noise
42
- target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
43
- params:
44
- n_cond_frames: 1
45
- n_copies: 1
46
- open_clip_embedding_config:
47
- target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
48
- params:
49
- freeze: True
50
-
51
- - input_key: fps_id
52
- is_trainable: False
53
- target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
54
- params:
55
- outdim: 256
56
-
57
- - input_key: motion_bucket_id
58
- is_trainable: False
59
- target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
60
- params:
61
- outdim: 256
62
-
63
- - input_key: cond_frames
64
- is_trainable: False
65
- target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
66
- params:
67
- disable_encoder_autocast: True
68
- n_cond_frames: 1
69
- n_copies: 1
70
- is_ae: True
71
- encoder_config:
72
- target: sgm.models.autoencoder.AutoencoderKLModeOnly
73
- params:
74
- embed_dim: 4
75
- monitor: val/rec_loss
76
- ddconfig:
77
- attn_type: vanilla-xformers
78
- double_z: True
79
- z_channels: 4
80
- resolution: 256
81
- in_channels: 3
82
- out_ch: 3
83
- ch: 128
84
- ch_mult: [1, 2, 4, 4]
85
- num_res_blocks: 2
86
- attn_resolutions: []
87
- dropout: 0.0
88
- lossconfig:
89
- target: torch.nn.Identity
90
-
91
- - input_key: cond_aug
92
- is_trainable: False
93
- target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
94
- params:
95
- outdim: 256
96
-
97
- first_stage_config:
98
- target: sgm.models.autoencoder.AutoencodingEngine
99
- params:
100
- loss_config:
101
- target: torch.nn.Identity
102
- regularizer_config:
103
- target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
104
- encoder_config:
105
- target: sgm.modules.diffusionmodules.model.Encoder
106
- params:
107
- attn_type: vanilla
108
- double_z: True
109
- z_channels: 4
110
- resolution: 256
111
- in_channels: 3
112
- out_ch: 3
113
- ch: 128
114
- ch_mult: [1, 2, 4, 4]
115
- num_res_blocks: 2
116
- attn_resolutions: []
117
- dropout: 0.0
118
- decoder_config:
119
- target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
120
- params:
121
- attn_type: vanilla
122
- double_z: True
123
- z_channels: 4
124
- resolution: 256
125
- in_channels: 3
126
- out_ch: 3
127
- ch: 128
128
- ch_mult: [1, 2, 4, 4]
129
- num_res_blocks: 2
130
- attn_resolutions: []
131
- dropout: 0.0
132
- video_kernel_size: [3, 1, 1]
133
-
134
- sampler_config:
135
- target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
136
- params:
137
- discretization_config:
138
- target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
139
- params:
140
- sigma_max: 700.0
141
-
142
- guider_config:
143
- target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
144
- params:
145
- max_scale: 2.5
146
- min_scale: 1.0
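For orientation: this config is consumed by the deleted scripts/sampling/simple_video_sample.py further down in this commit. A condensed sketch of its load_model path (hypothetical wrapper name load_svd; assumes sgm is importable and checkpoints/svd.safetensors is present):

    from omegaconf import OmegaConf
    from sgm.util import instantiate_from_config

    def load_svd(config_path: str, num_frames: int, num_steps: int, device: str = "cuda"):
        config = OmegaConf.load(config_path)
        # runtime overrides applied on top of the YAML defaults
        config.model.params.sampler_config.params.num_steps = num_steps
        config.model.params.sampler_config.params.guider_config.params.num_frames = num_frames
        return instantiate_from_config(config.model).to(device).eval()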
 
scripts/sampling/configs/svd_image_decoder.yaml DELETED
@@ -1,129 +0,0 @@
1
- model:
2
- target: sgm.models.diffusion.DiffusionEngine
3
- params:
4
- scale_factor: 0.18215
5
- disable_first_stage_autocast: True
6
- ckpt_path: checkpoints/svd_image_decoder.safetensors
7
-
8
- denoiser_config:
9
- target: sgm.modules.diffusionmodules.denoiser.Denoiser
10
- params:
11
- scaling_config:
12
- target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
13
-
14
- network_config:
15
- target: sgm.modules.diffusionmodules.video_model.VideoUNet
16
- params:
17
- adm_in_channels: 768
18
- num_classes: sequential
19
- use_checkpoint: True
20
- in_channels: 8
21
- out_channels: 4
22
- model_channels: 320
23
- attention_resolutions: [4, 2, 1]
24
- num_res_blocks: 2
25
- channel_mult: [1, 2, 4, 4]
26
- num_head_channels: 64
27
- use_linear_in_transformer: True
28
- transformer_depth: 1
29
- context_dim: 1024
30
- spatial_transformer_attn_type: softmax-xformers
31
- extra_ff_mix_layer: True
32
- use_spatial_context: True
33
- merge_strategy: learned_with_images
34
- video_kernel_size: [3, 1, 1]
35
-
36
- conditioner_config:
37
- target: sgm.modules.GeneralConditioner
38
- params:
39
- emb_models:
40
- - is_trainable: False
41
- input_key: cond_frames_without_noise
42
- target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
43
- params:
44
- n_cond_frames: 1
45
- n_copies: 1
46
- open_clip_embedding_config:
47
- target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
48
- params:
49
- freeze: True
50
-
51
- - input_key: fps_id
52
- is_trainable: False
53
- target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
54
- params:
55
- outdim: 256
56
-
57
- - input_key: motion_bucket_id
58
- is_trainable: False
59
- target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
60
- params:
61
- outdim: 256
62
-
63
- - input_key: cond_frames
64
- is_trainable: False
65
- target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
66
- params:
67
- disable_encoder_autocast: True
68
- n_cond_frames: 1
69
- n_copies: 1
70
- is_ae: True
71
- encoder_config:
72
- target: sgm.models.autoencoder.AutoencoderKLModeOnly
73
- params:
74
- embed_dim: 4
75
- monitor: val/rec_loss
76
- ddconfig:
77
- attn_type: vanilla-xformers
78
- double_z: True
79
- z_channels: 4
80
- resolution: 256
81
- in_channels: 3
82
- out_ch: 3
83
- ch: 128
84
- ch_mult: [1, 2, 4, 4]
85
- num_res_blocks: 2
86
- attn_resolutions: []
87
- dropout: 0.0
88
- lossconfig:
89
- target: torch.nn.Identity
90
-
91
- - input_key: cond_aug
92
- is_trainable: False
93
- target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
94
- params:
95
- outdim: 256
96
-
97
- first_stage_config:
98
- target: sgm.models.autoencoder.AutoencoderKL
99
- params:
100
- embed_dim: 4
101
- monitor: val/rec_loss
102
- ddconfig:
103
- attn_type: vanilla-xformers
104
- double_z: True
105
- z_channels: 4
106
- resolution: 256
107
- in_channels: 3
108
- out_ch: 3
109
- ch: 128
110
- ch_mult: [1, 2, 4, 4]
111
- num_res_blocks: 2
112
- attn_resolutions: []
113
- dropout: 0.0
114
- lossconfig:
115
- target: torch.nn.Identity
116
-
117
- sampler_config:
118
- target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
119
- params:
120
- discretization_config:
121
- target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
122
- params:
123
- sigma_max: 700.0
124
-
125
- guider_config:
126
- target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
127
- params:
128
- max_scale: 2.5
129
- min_scale: 1.0
 
scripts/sampling/configs/svd_xt.yaml DELETED
@@ -1,146 +0,0 @@
1
- model:
2
- target: sgm.models.diffusion.DiffusionEngine
3
- params:
4
- scale_factor: 0.18215
5
- disable_first_stage_autocast: True
6
- ckpt_path: checkpoints/svd_xt.safetensors
7
-
8
- denoiser_config:
9
- target: sgm.modules.diffusionmodules.denoiser.Denoiser
10
- params:
11
- scaling_config:
12
- target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
13
-
14
- network_config:
15
- target: sgm.modules.diffusionmodules.video_model.VideoUNet
16
- params:
17
- adm_in_channels: 768
18
- num_classes: sequential
19
- use_checkpoint: True
20
- in_channels: 8
21
- out_channels: 4
22
- model_channels: 320
23
- attention_resolutions: [4, 2, 1]
24
- num_res_blocks: 2
25
- channel_mult: [1, 2, 4, 4]
26
- num_head_channels: 64
27
- use_linear_in_transformer: True
28
- transformer_depth: 1
29
- context_dim: 1024
30
- spatial_transformer_attn_type: softmax-xformers
31
- extra_ff_mix_layer: True
32
- use_spatial_context: True
33
- merge_strategy: learned_with_images
34
- video_kernel_size: [3, 1, 1]
35
-
36
- conditioner_config:
37
- target: sgm.modules.GeneralConditioner
38
- params:
39
- emb_models:
40
- - is_trainable: False
41
- input_key: cond_frames_without_noise
42
- target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
43
- params:
44
- n_cond_frames: 1
45
- n_copies: 1
46
- open_clip_embedding_config:
47
- target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
48
- params:
49
- freeze: True
50
-
51
- - input_key: fps_id
52
- is_trainable: False
53
- target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
54
- params:
55
- outdim: 256
56
-
57
- - input_key: motion_bucket_id
58
- is_trainable: False
59
- target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
60
- params:
61
- outdim: 256
62
-
63
- - input_key: cond_frames
64
- is_trainable: False
65
- target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
66
- params:
67
- disable_encoder_autocast: True
68
- n_cond_frames: 1
69
- n_copies: 1
70
- is_ae: True
71
- encoder_config:
72
- target: sgm.models.autoencoder.AutoencoderKLModeOnly
73
- params:
74
- embed_dim: 4
75
- monitor: val/rec_loss
76
- ddconfig:
77
- attn_type: vanilla-xformers
78
- double_z: True
79
- z_channels: 4
80
- resolution: 256
81
- in_channels: 3
82
- out_ch: 3
83
- ch: 128
84
- ch_mult: [1, 2, 4, 4]
85
- num_res_blocks: 2
86
- attn_resolutions: []
87
- dropout: 0.0
88
- lossconfig:
89
- target: torch.nn.Identity
90
-
91
- - input_key: cond_aug
92
- is_trainable: False
93
- target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
94
- params:
95
- outdim: 256
96
-
97
- first_stage_config:
98
- target: sgm.models.autoencoder.AutoencodingEngine
99
- params:
100
- loss_config:
101
- target: torch.nn.Identity
102
- regularizer_config:
103
- target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
104
- encoder_config:
105
- target: sgm.modules.diffusionmodules.model.Encoder
106
- params:
107
- attn_type: vanilla
108
- double_z: True
109
- z_channels: 4
110
- resolution: 256
111
- in_channels: 3
112
- out_ch: 3
113
- ch: 128
114
- ch_mult: [1, 2, 4, 4]
115
- num_res_blocks: 2
116
- attn_resolutions: []
117
- dropout: 0.0
118
- decoder_config:
119
- target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
120
- params:
121
- attn_type: vanilla
122
- double_z: True
123
- z_channels: 4
124
- resolution: 256
125
- in_channels: 3
126
- out_ch: 3
127
- ch: 128
128
- ch_mult: [1, 2, 4, 4]
129
- num_res_blocks: 2
130
- attn_resolutions: []
131
- dropout: 0.0
132
- video_kernel_size: [3, 1, 1]
133
-
134
- sampler_config:
135
- target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
136
- params:
137
- discretization_config:
138
- target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
139
- params:
140
- sigma_max: 700.0
141
-
142
- guider_config:
143
- target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
144
- params:
145
- max_scale: 3.0
146
- min_scale: 1.5
 
scripts/sampling/configs/svd_xt_image_decoder.yaml DELETED
@@ -1,129 +0,0 @@
1
- model:
2
- target: sgm.models.diffusion.DiffusionEngine
3
- params:
4
- scale_factor: 0.18215
5
- disable_first_stage_autocast: True
6
- ckpt_path: checkpoints/svd_xt_image_decoder.safetensors
7
-
8
- denoiser_config:
9
- target: sgm.modules.diffusionmodules.denoiser.Denoiser
10
- params:
11
- scaling_config:
12
- target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
13
-
14
- network_config:
15
- target: sgm.modules.diffusionmodules.video_model.VideoUNet
16
- params:
17
- adm_in_channels: 768
18
- num_classes: sequential
19
- use_checkpoint: True
20
- in_channels: 8
21
- out_channels: 4
22
- model_channels: 320
23
- attention_resolutions: [4, 2, 1]
24
- num_res_blocks: 2
25
- channel_mult: [1, 2, 4, 4]
26
- num_head_channels: 64
27
- use_linear_in_transformer: True
28
- transformer_depth: 1
29
- context_dim: 1024
30
- spatial_transformer_attn_type: softmax-xformers
31
- extra_ff_mix_layer: True
32
- use_spatial_context: True
33
- merge_strategy: learned_with_images
34
- video_kernel_size: [3, 1, 1]
35
-
36
- conditioner_config:
37
- target: sgm.modules.GeneralConditioner
38
- params:
39
- emb_models:
40
- - is_trainable: False
41
- input_key: cond_frames_without_noise
42
- target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
43
- params:
44
- n_cond_frames: 1
45
- n_copies: 1
46
- open_clip_embedding_config:
47
- target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
48
- params:
49
- freeze: True
50
-
51
- - input_key: fps_id
52
- is_trainable: False
53
- target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
54
- params:
55
- outdim: 256
56
-
57
- - input_key: motion_bucket_id
58
- is_trainable: False
59
- target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
60
- params:
61
- outdim: 256
62
-
63
- - input_key: cond_frames
64
- is_trainable: False
65
- target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
66
- params:
67
- disable_encoder_autocast: True
68
- n_cond_frames: 1
69
- n_copies: 1
70
- is_ae: True
71
- encoder_config:
72
- target: sgm.models.autoencoder.AutoencoderKLModeOnly
73
- params:
74
- embed_dim: 4
75
- monitor: val/rec_loss
76
- ddconfig:
77
- attn_type: vanilla-xformers
78
- double_z: True
79
- z_channels: 4
80
- resolution: 256
81
- in_channels: 3
82
- out_ch: 3
83
- ch: 128
84
- ch_mult: [1, 2, 4, 4]
85
- num_res_blocks: 2
86
- attn_resolutions: []
87
- dropout: 0.0
88
- lossconfig:
89
- target: torch.nn.Identity
90
-
91
- - input_key: cond_aug
92
- is_trainable: False
93
- target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
94
- params:
95
- outdim: 256
96
-
97
- first_stage_config:
98
- target: sgm.models.autoencoder.AutoencoderKL
99
- params:
100
- embed_dim: 4
101
- monitor: val/rec_loss
102
- ddconfig:
103
- attn_type: vanilla-xformers
104
- double_z: True
105
- z_channels: 4
106
- resolution: 256
107
- in_channels: 3
108
- out_ch: 3
109
- ch: 128
110
- ch_mult: [1, 2, 4, 4]
111
- num_res_blocks: 2
112
- attn_resolutions: []
113
- dropout: 0.0
114
- lossconfig:
115
- target: torch.nn.Identity
116
-
117
- sampler_config:
118
- target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
119
- params:
120
- discretization_config:
121
- target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
122
- params:
123
- sigma_max: 700.0
124
-
125
- guider_config:
126
- target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
127
- params:
128
- max_scale: 3.0
129
- min_scale: 1.5
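Taken together, the four sampling configs removed in this commit differ only in checkpoint path, first-stage decoder and guider scales; combined with the per-version defaults hard-coded in the deleted simple_video_sample.py below, the differences can be summarised as follows (illustrative dict, values copied from the deleted files):

    # (checkpoint, first-stage decoder, (guider max_scale, min_scale), (default num_frames, num_steps))
    SVD_VARIANTS = {
        "svd": ("svd.safetensors", "VideoDecoder", (2.5, 1.0), (14, 25)),
        "svd_image_decoder": ("svd_image_decoder.safetensors", "AutoencoderKL", (2.5, 1.0), (14, 25)),
        "svd_xt": ("svd_xt.safetensors", "VideoDecoder", (3.0, 1.5), (25, 30)),
        "svd_xt_image_decoder": ("svd_xt_image_decoder.safetensors", "AutoencoderKL", (3.0, 1.5), (25, 30)),
    }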
 
scripts/sampling/simple_video_sample.py DELETED
@@ -1,278 +0,0 @@
1
- import math
2
- import os
3
- from glob import glob
4
- from pathlib import Path
5
- from typing import Optional
6
-
7
- import cv2
8
- import numpy as np
9
- import torch
10
- from einops import rearrange, repeat
11
- from fire import Fire
12
- from omegaconf import OmegaConf
13
- from PIL import Image
14
- from torchvision.transforms import ToTensor
15
-
16
- from scripts.util.detection.nsfw_and_watermark_dectection import \
17
- DeepFloydDataFiltering
18
- from sgm.inference.helpers import embed_watermark
19
- from sgm.util import default, instantiate_from_config
20
-
21
-
22
- def sample(
23
- input_path: str = "assets/test_image.png", # Can either be image file or folder with image files
24
- num_frames: Optional[int] = None,
25
- num_steps: Optional[int] = None,
26
- version: str = "svd",
27
- fps_id: int = 6,
28
- motion_bucket_id: int = 127,
29
- cond_aug: float = 0.02,
30
- seed: int = 23,
31
- decoding_t: int = 14, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
32
- device: str = "cuda",
33
- output_folder: Optional[str] = None,
34
- ):
35
- """
36
- Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each
37
- image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.
38
- """
39
-
40
- if version == "svd":
41
- num_frames = default(num_frames, 14)
42
- num_steps = default(num_steps, 25)
43
- output_folder = default(output_folder, "outputs/simple_video_sample/svd/")
44
- model_config = "scripts/sampling/configs/svd.yaml"
45
- elif version == "svd_xt":
46
- num_frames = default(num_frames, 25)
47
- num_steps = default(num_steps, 30)
48
- output_folder = default(output_folder, "outputs/simple_video_sample/svd_xt/")
49
- model_config = "scripts/sampling/configs/svd_xt.yaml"
50
- elif version == "svd_image_decoder":
51
- num_frames = default(num_frames, 14)
52
- num_steps = default(num_steps, 25)
53
- output_folder = default(
54
- output_folder, "outputs/simple_video_sample/svd_image_decoder/"
55
- )
56
- model_config = "scripts/sampling/configs/svd_image_decoder.yaml"
57
- elif version == "svd_xt_image_decoder":
58
- num_frames = default(num_frames, 25)
59
- num_steps = default(num_steps, 30)
60
- output_folder = default(
61
- output_folder, "outputs/simple_video_sample/svd_xt_image_decoder/"
62
- )
63
- model_config = "scripts/sampling/configs/svd_xt_image_decoder.yaml"
64
- else:
65
- raise ValueError(f"Version {version} does not exist.")
66
-
67
- model, filter = load_model(
68
- model_config,
69
- device,
70
- num_frames,
71
- num_steps,
72
- )
73
- torch.manual_seed(seed)
74
-
75
- path = Path(input_path)
76
- all_img_paths = []
77
- if path.is_file():
78
- if any([input_path.endswith(x) for x in ["jpg", "jpeg", "png"]]):
79
- all_img_paths = [input_path]
80
- else:
81
- raise ValueError("Path is not valid image file.")
82
- elif path.is_dir():
83
- all_img_paths = sorted(
84
- [
85
- f
86
- for f in path.iterdir()
87
- if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"]
88
- ]
89
- )
90
- if len(all_img_paths) == 0:
91
- raise ValueError("Folder does not contain any images.")
92
- else:
93
- raise ValueError
94
-
95
- for input_img_path in all_img_paths:
96
- with Image.open(input_img_path) as image:
97
- if image.mode == "RGBA":
98
- image = image.convert("RGB")
99
- w, h = image.size
100
-
101
- if h % 64 != 0 or w % 64 != 0:
102
- width, height = map(lambda x: x - x % 64, (w, h))
103
- image = image.resize((width, height))
104
- print(
105
- f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
106
- )
107
-
108
- image = ToTensor()(image)
109
- image = image * 2.0 - 1.0
110
-
111
- image = image.unsqueeze(0).to(device)
112
- H, W = image.shape[2:]
113
- assert image.shape[1] == 3
114
- F = 8
115
- C = 4
116
- shape = (num_frames, C, H // F, W // F)
117
- if (H, W) != (576, 1024):
118
- print(
119
- "WARNING: The conditioning frame you provided is not 576x1024. This leads to suboptimal performance as model was only trained on 576x1024. Consider increasing `cond_aug`."
120
- )
121
- if motion_bucket_id > 255:
122
- print(
123
- "WARNING: High motion bucket! This may lead to suboptimal performance."
124
- )
125
-
126
- if fps_id < 5:
127
- print("WARNING: Small fps value! This may lead to suboptimal performance.")
128
-
129
- if fps_id > 30:
130
- print("WARNING: Large fps value! This may lead to suboptimal performance.")
131
-
132
- value_dict = {}
133
- value_dict["motion_bucket_id"] = motion_bucket_id
134
- value_dict["fps_id"] = fps_id
135
- value_dict["cond_aug"] = cond_aug
136
- value_dict["cond_frames_without_noise"] = image
137
- value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)
138
- value_dict["cond_aug"] = cond_aug
139
-
140
- with torch.no_grad():
141
- with torch.autocast(device):
142
- batch, batch_uc = get_batch(
143
- get_unique_embedder_keys_from_conditioner(model.conditioner),
144
- value_dict,
145
- [1, num_frames],
146
- T=num_frames,
147
- device=device,
148
- )
149
- c, uc = model.conditioner.get_unconditional_conditioning(
150
- batch,
151
- batch_uc=batch_uc,
152
- force_uc_zero_embeddings=[
153
- "cond_frames",
154
- "cond_frames_without_noise",
155
- ],
156
- )
157
-
158
- for k in ["crossattn", "concat"]:
159
- uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
160
- uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
161
- c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
162
- c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)
163
-
164
- randn = torch.randn(shape, device=device)
165
-
166
- additional_model_inputs = {}
167
- additional_model_inputs["image_only_indicator"] = torch.zeros(
168
- 2, num_frames
169
- ).to(device)
170
- additional_model_inputs["num_video_frames"] = batch["num_video_frames"]
171
-
172
- def denoiser(input, sigma, c):
173
- return model.denoiser(
174
- model.model, input, sigma, c, **additional_model_inputs
175
- )
176
-
177
- samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
178
- model.en_and_decode_n_samples_a_time = decoding_t
179
- samples_x = model.decode_first_stage(samples_z)
180
- samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
181
-
182
- os.makedirs(output_folder, exist_ok=True)
183
- base_count = len(glob(os.path.join(output_folder, "*.mp4")))
184
- video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
185
- writer = cv2.VideoWriter(
186
- video_path,
187
- cv2.VideoWriter_fourcc(*"MP4V"),
188
- fps_id + 1,
189
- (samples.shape[-1], samples.shape[-2]),
190
- )
191
-
192
- samples = embed_watermark(samples)
193
- samples = filter(samples)
194
- vid = (
195
- (rearrange(samples, "t c h w -> t h w c") * 255)
196
- .cpu()
197
- .numpy()
198
- .astype(np.uint8)
199
- )
200
- for frame in vid:
201
- frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
202
- writer.write(frame)
203
- writer.release()
204
-
205
-
206
- def get_unique_embedder_keys_from_conditioner(conditioner):
207
- return list(set([x.input_key for x in conditioner.embedders]))
208
-
209
-
210
- def get_batch(keys, value_dict, N, T, device):
211
- batch = {}
212
- batch_uc = {}
213
-
214
- for key in keys:
215
- if key == "fps_id":
216
- batch[key] = (
217
- torch.tensor([value_dict["fps_id"]])
218
- .to(device)
219
- .repeat(int(math.prod(N)))
220
- )
221
- elif key == "motion_bucket_id":
222
- batch[key] = (
223
- torch.tensor([value_dict["motion_bucket_id"]])
224
- .to(device)
225
- .repeat(int(math.prod(N)))
226
- )
227
- elif key == "cond_aug":
228
- batch[key] = repeat(
229
- torch.tensor([value_dict["cond_aug"]]).to(device),
230
- "1 -> b",
231
- b=math.prod(N),
232
- )
233
- elif key == "cond_frames":
234
- batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
235
- elif key == "cond_frames_without_noise":
236
- batch[key] = repeat(
237
- value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
238
- )
239
- else:
240
- batch[key] = value_dict[key]
241
-
242
- if T is not None:
243
- batch["num_video_frames"] = T
244
-
245
- for key in batch.keys():
246
- if key not in batch_uc and isinstance(batch[key], torch.Tensor):
247
- batch_uc[key] = torch.clone(batch[key])
248
- return batch, batch_uc
249
-
250
-
251
- def load_model(
252
- config: str,
253
- device: str,
254
- num_frames: int,
255
- num_steps: int,
256
- ):
257
- config = OmegaConf.load(config)
258
- if device == "cuda":
259
- config.model.params.conditioner_config.params.emb_models[
260
- 0
261
- ].params.open_clip_embedding_config.params.init_device = device
262
-
263
- config.model.params.sampler_config.params.num_steps = num_steps
264
- config.model.params.sampler_config.params.guider_config.params.num_frames = (
265
- num_frames
266
- )
267
- if device == "cuda":
268
- with torch.device(device):
269
- model = instantiate_from_config(config.model).to(device).eval()
270
- else:
271
- model = instantiate_from_config(config.model).to(device).eval()
272
-
273
- filter = DeepFloydDataFiltering(verbose=False, device=device)
274
- return model, filter
275
-
276
-
277
- if __name__ == "__main__":
278
- Fire(sample)
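Since the entry point is wrapped with fire.Fire, the deleted script was meant to be driven from the command line; the equivalent direct call from Python, using only the argument names defined in sample() above, was:

    from scripts.sampling.simple_video_sample import sample

    sample(
        input_path="assets/test_image.png",  # single image or a folder of images
        version="svd_xt",                    # selects scripts/sampling/configs/svd_xt.yaml
        decoding_t=5,                        # frames decoded per VAE pass; lower to reduce peak VRAM
    )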
 
scripts/tests/attention.py DELETED
@@ -1,319 +0,0 @@
1
- import einops
2
- import torch
3
- import torch.nn.functional as F
4
- import torch.utils.benchmark as benchmark
5
- from torch.backends.cuda import SDPBackend
6
-
7
- from sgm.modules.attention import BasicTransformerBlock, SpatialTransformer
8
-
9
-
10
- def benchmark_attn():
11
- # Lets define a helpful benchmarking function:
12
- # https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
13
- device = "cuda" if torch.cuda.is_available() else "cpu"
14
-
15
- def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
16
- t0 = benchmark.Timer(
17
- stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
18
- )
19
- return t0.blocked_autorange().mean * 1e6
20
-
21
- # Lets define the hyper-parameters of our input
22
- batch_size = 32
23
- max_sequence_len = 1024
24
- num_heads = 32
25
- embed_dimension = 32
26
-
27
- dtype = torch.float16
28
-
29
- query = torch.rand(
30
- batch_size,
31
- num_heads,
32
- max_sequence_len,
33
- embed_dimension,
34
- device=device,
35
- dtype=dtype,
36
- )
37
- key = torch.rand(
38
- batch_size,
39
- num_heads,
40
- max_sequence_len,
41
- embed_dimension,
42
- device=device,
43
- dtype=dtype,
44
- )
45
- value = torch.rand(
46
- batch_size,
47
- num_heads,
48
- max_sequence_len,
49
- embed_dimension,
50
- device=device,
51
- dtype=dtype,
52
- )
53
-
54
- print(f"q/k/v shape:", query.shape, key.shape, value.shape)
55
-
56
- # Lets explore the speed of each of the 3 implementations
57
- from torch.backends.cuda import SDPBackend, sdp_kernel
58
-
59
- # Helpful arguments mapper
60
- backend_map = {
61
- SDPBackend.MATH: {
62
- "enable_math": True,
63
- "enable_flash": False,
64
- "enable_mem_efficient": False,
65
- },
66
- SDPBackend.FLASH_ATTENTION: {
67
- "enable_math": False,
68
- "enable_flash": True,
69
- "enable_mem_efficient": False,
70
- },
71
- SDPBackend.EFFICIENT_ATTENTION: {
72
- "enable_math": False,
73
- "enable_flash": False,
74
- "enable_mem_efficient": True,
75
- },
76
- }
77
-
78
- from torch.profiler import ProfilerActivity, profile, record_function
79
-
80
- activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
81
-
82
- print(
83
- f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
84
- )
85
- with profile(
86
- activities=activities, record_shapes=False, profile_memory=True
87
- ) as prof:
88
- with record_function("Default detailed stats"):
89
- for _ in range(25):
90
- o = F.scaled_dot_product_attention(query, key, value)
91
- print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
92
-
93
- print(
94
- f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
95
- )
96
- with sdp_kernel(**backend_map[SDPBackend.MATH]):
97
- with profile(
98
- activities=activities, record_shapes=False, profile_memory=True
99
- ) as prof:
100
- with record_function("Math implmentation stats"):
101
- for _ in range(25):
102
- o = F.scaled_dot_product_attention(query, key, value)
103
- print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
104
-
105
- with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
106
- try:
107
- print(
108
- f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
109
- )
110
- except RuntimeError:
111
- print("FlashAttention is not supported. See warnings for reasons.")
112
- with profile(
113
- activities=activities, record_shapes=False, profile_memory=True
114
- ) as prof:
115
- with record_function("FlashAttention stats"):
116
- for _ in range(25):
117
- o = F.scaled_dot_product_attention(query, key, value)
118
- print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
119
-
120
- with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
121
- try:
122
- print(
123
- f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
124
- )
125
- except RuntimeError:
126
- print("EfficientAttention is not supported. See warnings for reasons.")
127
- with profile(
128
- activities=activities, record_shapes=False, profile_memory=True
129
- ) as prof:
130
- with record_function("EfficientAttention stats"):
131
- for _ in range(25):
132
- o = F.scaled_dot_product_attention(query, key, value)
133
- print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
134
-
135
-
136
- def run_model(model, x, context):
137
- return model(x, context)
138
-
139
-
140
- def benchmark_transformer_blocks():
141
- device = "cuda" if torch.cuda.is_available() else "cpu"
142
- import torch.utils.benchmark as benchmark
143
-
144
- def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
145
- t0 = benchmark.Timer(
146
- stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
147
- )
148
- return t0.blocked_autorange().mean * 1e6
149
-
150
- checkpoint = True
151
- compile = False
152
-
153
- batch_size = 32
154
- h, w = 64, 64
155
- context_len = 77
156
- embed_dimension = 1024
157
- context_dim = 1024
158
- d_head = 64
159
-
160
- transformer_depth = 4
161
-
162
- n_heads = embed_dimension // d_head
163
-
164
- dtype = torch.float16
165
-
166
- model_native = SpatialTransformer(
167
- embed_dimension,
168
- n_heads,
169
- d_head,
170
- context_dim=context_dim,
171
- use_linear=True,
172
- use_checkpoint=checkpoint,
173
- attn_type="softmax",
174
- depth=transformer_depth,
175
- sdp_backend=SDPBackend.FLASH_ATTENTION,
176
- ).to(device)
177
- model_efficient_attn = SpatialTransformer(
178
- embed_dimension,
179
- n_heads,
180
- d_head,
181
- context_dim=context_dim,
182
- use_linear=True,
183
- depth=transformer_depth,
184
- use_checkpoint=checkpoint,
185
- attn_type="softmax-xformers",
186
- ).to(device)
187
- if not checkpoint and compile:
188
- print("compiling models")
189
- model_native = torch.compile(model_native)
190
- model_efficient_attn = torch.compile(model_efficient_attn)
191
-
192
- x = torch.rand(batch_size, embed_dimension, h, w, device=device, dtype=dtype)
193
- c = torch.rand(batch_size, context_len, context_dim, device=device, dtype=dtype)
194
-
195
- from torch.profiler import ProfilerActivity, profile, record_function
196
-
197
- activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
198
-
199
- with torch.autocast("cuda"):
200
- print(
201
- f"The native model runs in {benchmark_torch_function_in_microseconds(model_native.forward, x, c):.3f} microseconds"
202
- )
203
- print(
204
- f"The efficientattn model runs in {benchmark_torch_function_in_microseconds(model_efficient_attn.forward, x, c):.3f} microseconds"
205
- )
206
-
207
- print(75 * "+")
208
- print("NATIVE")
209
- print(75 * "+")
210
- torch.cuda.reset_peak_memory_stats()
211
- with profile(
212
- activities=activities, record_shapes=False, profile_memory=True
213
- ) as prof:
214
- with record_function("NativeAttention stats"):
215
- for _ in range(25):
216
- model_native(x, c)
217
- print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
218
- print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by native block")
219
-
220
- print(75 * "+")
221
- print("Xformers")
222
- print(75 * "+")
223
- torch.cuda.reset_peak_memory_stats()
224
- with profile(
225
- activities=activities, record_shapes=False, profile_memory=True
226
- ) as prof:
227
- with record_function("xformers stats"):
228
- for _ in range(25):
229
- model_efficient_attn(x, c)
230
- print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
231
- print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by xformers block")
232
-
233
-
234
- def test01():
235
- # conv1x1 vs linear
236
- from sgm.util import count_params
237
-
238
- conv = torch.nn.Conv2d(3, 32, kernel_size=1).cuda()
239
- print(count_params(conv))
240
- linear = torch.nn.Linear(3, 32).cuda()
241
- print(count_params(linear))
242
-
243
- print(conv.weight.shape)
244
-
245
- # use same initialization
246
- linear.weight = torch.nn.Parameter(conv.weight.squeeze(-1).squeeze(-1))
247
- linear.bias = torch.nn.Parameter(conv.bias)
248
-
249
- print(linear.weight.shape)
250
-
251
- x = torch.randn(11, 3, 64, 64).cuda()
252
-
253
- xr = einops.rearrange(x, "b c h w -> b (h w) c").contiguous()
254
- print(xr.shape)
255
- out_linear = linear(xr)
256
- print(out_linear.mean(), out_linear.shape)
257
-
258
- out_conv = conv(x)
259
- print(out_conv.mean(), out_conv.shape)
260
- print("done with test01.\n")
261
-
262
-
263
- def test02():
264
- # try cosine flash attention
265
- import time
266
-
267
- torch.backends.cuda.matmul.allow_tf32 = True
268
- torch.backends.cudnn.allow_tf32 = True
269
- torch.backends.cudnn.benchmark = True
270
- print("testing cosine flash attention...")
271
- DIM = 1024
272
- SEQLEN = 4096
273
- BS = 16
274
-
275
- print(" softmax (vanilla) first...")
276
- model = BasicTransformerBlock(
277
- dim=DIM,
278
- n_heads=16,
279
- d_head=64,
280
- dropout=0.0,
281
- context_dim=None,
282
- attn_mode="softmax",
283
- ).cuda()
284
- try:
285
- x = torch.randn(BS, SEQLEN, DIM).cuda()
286
- tic = time.time()
287
- y = model(x)
288
- toc = time.time()
289
- print(y.shape, toc - tic)
290
- except RuntimeError as e:
291
- # likely oom
292
- print(str(e))
293
-
294
- print("\n now flash-cosine...")
295
- model = BasicTransformerBlock(
296
- dim=DIM,
297
- n_heads=16,
298
- d_head=64,
299
- dropout=0.0,
300
- context_dim=None,
301
- attn_mode="flash-cosine",
302
- ).cuda()
303
- x = torch.randn(BS, SEQLEN, DIM).cuda()
304
- tic = time.time()
305
- y = model(x)
306
- toc = time.time()
307
- print(y.shape, toc - tic)
308
- print("done with test02.\n")
309
-
310
-
311
- if __name__ == "__main__":
312
- # test01()
313
- # test02()
314
- # test03()
315
-
316
- # benchmark_attn()
317
- benchmark_transformer_blocks()
318
-
319
- print("done.")
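Note that the backend selection in this deleted test goes through torch.backends.cuda.sdp_kernel, which newer PyTorch releases deprecate in favour of torch.nn.attention.sdpa_kernel (available from roughly PyTorch 2.3 onwards). On such versions, restricting scaled_dot_product_attention to FlashAttention would look roughly like:

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    # same shapes as the deleted benchmark: (batch, heads, seq_len, head_dim)
    q = k = v = torch.rand(32, 32, 1024, 32, device="cuda", dtype=torch.float16)
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        out = F.scaled_dot_product_attention(q, k, v)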
 
scripts/util/__init__.py DELETED
File without changes
scripts/util/detection/__init__.py DELETED
File without changes
scripts/util/detection/nsfw_and_watermark_dectection.py DELETED
@@ -1,110 +0,0 @@
1
- import os
2
-
3
- import clip
4
- import numpy as np
5
- import torch
6
- import torchvision.transforms as T
7
- from PIL import Image
8
-
9
- RESOURCES_ROOT = "scripts/util/detection/"
10
-
11
-
12
- def predict_proba(X, weights, biases):
13
- logits = X @ weights.T + biases
14
- proba = np.where(
15
- logits >= 0, 1 / (1 + np.exp(-logits)), np.exp(logits) / (1 + np.exp(logits))
16
- )
17
- return proba.T
18
-
19
-
20
- def load_model_weights(path: str):
21
- model_weights = np.load(path)
22
- return model_weights["weights"], model_weights["biases"]
23
-
24
-
25
- def clip_process_images(images: torch.Tensor) -> torch.Tensor:
26
- min_size = min(images.shape[-2:])
27
- return T.Compose(
28
- [
29
- T.CenterCrop(min_size), # TODO: this might affect the watermark, check this
30
- T.Resize(224, interpolation=T.InterpolationMode.BICUBIC, antialias=True),
31
- T.Normalize(
32
- (0.48145466, 0.4578275, 0.40821073),
33
- (0.26862954, 0.26130258, 0.27577711),
34
- ),
35
- ]
36
- )(images)
37
-
38
-
39
- class DeepFloydDataFiltering(object):
40
- def __init__(
41
- self, verbose: bool = False, device: torch.device = torch.device("cpu")
42
- ):
43
- super().__init__()
44
- self.verbose = verbose
45
- self._device = None
46
- self.clip_model, _ = clip.load("ViT-L/14", device=device)
47
- self.clip_model.eval()
48
-
49
- self.cpu_w_weights, self.cpu_w_biases = load_model_weights(
50
- os.path.join(RESOURCES_ROOT, "w_head_v1.npz")
51
- )
52
- self.cpu_p_weights, self.cpu_p_biases = load_model_weights(
53
- os.path.join(RESOURCES_ROOT, "p_head_v1.npz")
54
- )
55
- self.w_threshold, self.p_threshold = 0.5, 0.5
56
-
57
- @torch.inference_mode()
58
- def __call__(self, images: torch.Tensor) -> torch.Tensor:
59
- imgs = clip_process_images(images)
60
- if self._device is None:
61
- self._device = next(p for p in self.clip_model.parameters()).device
62
- image_features = self.clip_model.encode_image(imgs.to(self._device))
63
- image_features = image_features.detach().cpu().numpy().astype(np.float16)
64
- p_pred = predict_proba(image_features, self.cpu_p_weights, self.cpu_p_biases)
65
- w_pred = predict_proba(image_features, self.cpu_w_weights, self.cpu_w_biases)
66
- print(f"p_pred = {p_pred}, w_pred = {w_pred}") if self.verbose else None
67
- query = p_pred > self.p_threshold
68
- if query.sum() > 0:
69
- print(f"Hit for p_threshold: {p_pred}") if self.verbose else None
70
- images[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(images[query])
71
- query = w_pred > self.w_threshold
72
- if query.sum() > 0:
73
- print(f"Hit for w_threshold: {w_pred}") if self.verbose else None
74
- images[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(images[query])
75
- return images
76
-
77
-
78
- def load_img(path: str) -> torch.Tensor:
79
- image = Image.open(path)
80
- if not image.mode == "RGB":
81
- image = image.convert("RGB")
82
- image_transforms = T.Compose(
83
- [
84
- T.ToTensor(),
85
- ]
86
- )
87
- return image_transforms(image)[None, ...]
88
-
89
-
90
- def test(root):
91
- from einops import rearrange
92
-
93
- filter = DeepFloydDataFiltering(verbose=True)
94
- for p in os.listdir((root)):
95
- print(f"running on {p}...")
96
- img = load_img(os.path.join(root, p))
97
- filtered_img = filter(img)
98
- filtered_img = rearrange(
99
- 255.0 * (filtered_img.numpy())[0], "c h w -> h w c"
100
- ).astype(np.uint8)
101
- Image.fromarray(filtered_img).save(
102
- os.path.join(root, f"{os.path.splitext(p)[0]}-filtered.jpg")
103
- )
104
-
105
-
106
- if __name__ == "__main__":
107
- import fire
108
-
109
- fire.Fire(test)
110
- print("done.")
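For reference, this is the filter applied to decoded frames in the deleted simple_video_sample.py. Used standalone it follows the same pattern as the test() helper above; the image path below is a placeholder, clip.load downloads ViT-L/14, and the two .npz heads listed next must sit in scripts/util/detection/:

    import torch
    from scripts.util.detection.nsfw_and_watermark_dectection import (
        DeepFloydDataFiltering,
        load_img,
    )

    filter = DeepFloydDataFiltering(verbose=True, device=torch.device("cuda"))
    images = load_img("some_frame.png")  # (1, 3, H, W) float tensor in [0, 1]
    filtered = filter(images)            # flagged frames come back heavily blurred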
 
scripts/util/detection/p_head_v1.npz DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:b4653a64d5f85d8d4c5f6c5ec175f1c5c5e37db8f38d39b2ed8b5979da7fdc76
3
- size 3588
 
scripts/util/detection/w_head_v1.npz DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:b6af23687aa347073e692025f405ccc48c14aadc5dbe775b3312041006d496d1
3
- size 3588