multimodalart HF staff committed on
Commit
9ac31b8
•
1 Parent(s): 52a8e24

diffusers-backend (#6)

Browse files

- Delete assets (537ebb9bdae4dc86cc0b00b854f1eae9b32cd7bb)
- Delete configs (adcff078e9123ce669443f2e1d28516c05b7c455)
- Delete data (79751902793a42fb8f0329a80fd132d7d6f85ec8)
- Delete model_licenses (67be9bc7732e83a8b4dde21fa0311b8c2f787185)
- Delete requirements (c1a4c611db4c97cd499e036fa3111c785cd9f4ba)
- Delete scripts (90e5afe600048480e16ac285d298378b9c3ac9b1)
- Delete sgm (097567a9522fbd31df1e8f4ec7a360c02947f0d3)
- Delete tests (fa8cc685cd27c368611079175724d895c067742c)
- Delete .gitattributes (4d403cfb820b8f14ef17e6b7e97d16a8a4fa850d)
- Delete CODEOWNERS (5ff14e288cc2336b802e82a12f2f5a9b42bb4acf)
- Delete LICENSE-CODE (ef217859b8e2bdc2bcd8c757020356e51a1e1254)
- Delete main.py (422da78c99f88a7edbeb6e277a3c8cb0bb46e691)
- Delete pyproject.toml (9cb3ce47320995317454fb907a5a3ef9e47e8049)
- Delete pytest.ini (f3f18e3224b316f8894e505b90f05f05aeaad782)
- Delete simple_video_sample.py (a7729a11a558073e33666989f79521bcb4c59c5a)
- Update app.py (9ca6c3055abf1a8603c1a274aa9c0fdab2910f61)
- Update requirements.txt (0b4d79b2785491369787f19453b24b0d25bcd283)
- Update app.py (6e3b09a7cfa77adde712e954cebf68001facc1b4)
- Update app.py (e2530d1596d372176fe7707e1ff50d7c002dcda4)
- Update app.py (f7d4d47c674580f25d5f99f584dc415d81428cc7)
- Upload 10 files (0629c99de8869dc16cb9bfff11a6a93bb9fddb6f)
- Update app.py (2beda72ac79f28c2400a8e81dd9d7606d3c95ef9)

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50) hide show
  1. .gitattributes +0 -38
  2. CODEOWNERS +0 -1
  3. LICENSE-CODE +0 -21
  4. app.py +45 -211
  5. assets/000.jpg +0 -0
  6. assets/001_with_eval.png +0 -3
  7. assets/test_image.png +0 -0
  8. assets/tile.gif +0 -3
  9. configs/.DS_Store +0 -0
  10. configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml +0 -104
  11. configs/example_training/autoencoder/kl-f4/imagenet-kl_f8_8chn.yaml +0 -105
  12. configs/example_training/imagenet-f8_cond.yaml +0 -185
  13. configs/example_training/toy/cifar10_cond.yaml +0 -98
  14. configs/example_training/toy/mnist.yaml +0 -79
  15. configs/example_training/toy/mnist_cond.yaml +0 -98
  16. configs/example_training/toy/mnist_cond_discrete_eps.yaml +0 -103
  17. configs/example_training/toy/mnist_cond_l1_loss.yaml +0 -99
  18. configs/example_training/toy/mnist_cond_with_ema.yaml +0 -100
  19. configs/example_training/txt2img-clipl-legacy-ucg-training.yaml +0 -182
  20. configs/example_training/txt2img-clipl.yaml +0 -184
  21. configs/inference/sd_2_1.yaml +0 -60
  22. configs/inference/sd_2_1_768.yaml +0 -60
  23. configs/inference/sd_xl_base.yaml +0 -93
  24. configs/inference/sd_xl_refiner.yaml +0 -86
  25. configs/inference/svd.yaml +0 -131
  26. configs/inference/svd_image_decoder.yaml +0 -114
  27. data/DejaVuSans.ttf +0 -0
  28. images/blink_meme.png +0 -0
  29. images/confused2_meme.png +0 -0
  30. images/confused_meme.png +0 -0
  31. images/disaster_meme.png +0 -0
  32. images/distracted_meme.png +0 -0
  33. images/hide_meme.png +0 -0
  34. images/nazare_meme.png +0 -0
  35. images/success_meme.png +0 -0
  36. images/willy_meme.png +0 -0
  37. images/wink_meme.png +0 -0
  38. main.py +0 -943
  39. model_licenses/LICENSE-SDV +0 -31
  40. model_licenses/LICENSE-SDXL0.9 +0 -75
  41. model_licenses/LICENSE-SDXL1.0 +0 -175
  42. pyproject.toml +0 -48
  43. pytest.ini +0 -3
  44. requirements.txt +5 -40
  45. requirements/pt13.txt +0 -40
  46. requirements/pt2.txt +0 -39
  47. scripts/.DS_Store +0 -0
  48. scripts/__init__.py +0 -0
  49. scripts/demo/__init__.py +0 -0
  50. scripts/demo/detect.py +0 -156
.gitattributes DELETED
@@ -1,38 +0,0 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
36
- assets/001_with_eval.png filter=lfs diff=lfs merge=lfs -text
37
- assets/tile.gif filter=lfs diff=lfs merge=lfs -text
38
- outputs/000004.mp4 filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
CODEOWNERS DELETED
@@ -1 +0,0 @@
1
- .github @Stability-AI/infrastructure
 
 
LICENSE-CODE DELETED
@@ -1,21 +0,0 @@
1
- MIT License
2
-
3
- Copyright (c) 2023 Stability AI
4
-
5
- Permission is hereby granted, free of charge, to any person obtaining a copy
6
- of this software and associated documentation files (the "Software"), to deal
7
- in the Software without restriction, including without limitation the rights
8
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- copies of the Software, and to permit persons to whom the Software is
10
- furnished to do so, subject to the following conditions:
11
-
12
- The above copyright notice and this permission notice shall be included in all
13
- copies or substantial portions of the Software.
14
-
15
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- SOFTWARE.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -1,243 +1,59 @@
1
- import math
 
 
2
  import os
3
  from glob import glob
4
  from pathlib import Path
5
  from typing import Optional
6
 
7
- import cv2
8
- import numpy as np
9
- import torch
10
- from einops import rearrange, repeat
11
- from fire import Fire
12
- from omegaconf import OmegaConf
13
  from PIL import Image
14
- from torchvision.transforms import ToTensor
15
 
16
- from scripts.util.detection.nsfw_and_watermark_dectection import \
17
- DeepFloydDataFiltering
18
- from sgm.inference.helpers import embed_watermark
19
- from sgm.util import default, instantiate_from_config
20
-
21
- import gradio as gr
22
  import uuid
23
  import random
24
  from huggingface_hub import hf_hub_download
25
 
26
- hf_hub_download(repo_id="stabilityai/stable-video-diffusion-img2vid-xt", filename="svd_xt.safetensors", local_dir="checkpoints")
27
-
28
- version = "svd_xt"
29
- device = "cuda"
30
- max_64_bit_int = 2**63 - 1
31
-
32
- def load_model(
33
- config: str,
34
- device: str,
35
- num_frames: int,
36
- num_steps: int,
37
- ):
38
- config = OmegaConf.load(config)
39
- if device == "cuda":
40
- config.model.params.conditioner_config.params.emb_models[
41
- 0
42
- ].params.open_clip_embedding_config.params.init_device = device
43
-
44
- config.model.params.sampler_config.params.num_steps = num_steps
45
- config.model.params.sampler_config.params.guider_config.params.num_frames = (
46
- num_frames
47
- )
48
- if device == "cuda":
49
- with torch.device(device):
50
- model = instantiate_from_config(config.model).to(device).eval()
51
- else:
52
- model = instantiate_from_config(config.model).to(device).eval()
53
-
54
- filter = DeepFloydDataFiltering(verbose=False, device=device)
55
- return model, filter
56
-
57
- if version == "svd_xt":
58
- num_frames = 25
59
- num_steps = 30
60
- model_config = "scripts/sampling/configs/svd_xt.yaml"
61
- else:
62
- raise ValueError(f"Version {version} does not exist.")
63
 
64
- model, filter = load_model(
65
- model_config,
66
- device,
67
- num_frames,
68
- num_steps,
69
  )
 
 
 
 
 
70
 
71
  def sample(
72
  image: Image,
73
- seed: Optional[int] = None,
74
  randomize_seed: bool = True,
75
  motion_bucket_id: int = 127,
76
  fps_id: int = 6,
77
  version: str = "svd_xt",
78
  cond_aug: float = 0.02,
79
- decoding_t: int = 5, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
80
  device: str = "cuda",
81
  output_folder: str = "outputs",
82
- progress=gr.Progress(track_tqdm=True)
83
  ):
 
 
 
84
  if(randomize_seed):
85
  seed = random.randint(0, max_64_bit_int)
86
-
87
- torch.manual_seed(seed)
88
 
89
- if image.mode == "RGBA":
90
- image = image.convert("RGB")
91
- w, h = image.size
92
 
93
- if h % 64 != 0 or w % 64 != 0:
94
- width, height = map(lambda x: x - x % 64, (w, h))
95
- image = image.resize((width, height))
96
- print(
97
- f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
98
- )
99
-
100
- image = ToTensor()(image)
101
- image = image * 2.0 - 1.0
102
- image = image.unsqueeze(0).to(device)
103
- H, W = image.shape[2:]
104
- assert image.shape[1] == 3
105
- F = 8
106
- C = 4
107
- shape = (num_frames, C, H // F, W // F)
108
- if (H, W) != (576, 1024):
109
- print(
110
- "WARNING: The conditioning frame you provided is not 576x1024. This leads to suboptimal performance as model was only trained on 576x1024. Consider increasing `cond_aug`."
111
- )
112
- if motion_bucket_id > 255:
113
- print(
114
- "WARNING: High motion bucket! This may lead to suboptimal performance."
115
- )
116
-
117
- if fps_id < 5:
118
- print("WARNING: Small fps value! This may lead to suboptimal performance.")
119
-
120
- if fps_id > 30:
121
- print("WARNING: Large fps value! This may lead to suboptimal performance.")
122
-
123
- value_dict = {}
124
- value_dict["motion_bucket_id"] = motion_bucket_id
125
- value_dict["fps_id"] = fps_id
126
- value_dict["cond_aug"] = cond_aug
127
- value_dict["cond_frames_without_noise"] = image
128
- value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)
129
- value_dict["cond_aug"] = cond_aug
130
-
131
- with torch.no_grad():
132
- with torch.autocast(device):
133
- batch, batch_uc = get_batch(
134
- get_unique_embedder_keys_from_conditioner(model.conditioner),
135
- value_dict,
136
- [1, num_frames],
137
- T=num_frames,
138
- device=device,
139
- )
140
- c, uc = model.conditioner.get_unconditional_conditioning(
141
- batch,
142
- batch_uc=batch_uc,
143
- force_uc_zero_embeddings=[
144
- "cond_frames",
145
- "cond_frames_without_noise",
146
- ],
147
- )
148
-
149
- for k in ["crossattn", "concat"]:
150
- uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
151
- uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
152
- c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
153
- c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)
154
-
155
- randn = torch.randn(shape, device=device)
156
-
157
- additional_model_inputs = {}
158
- additional_model_inputs["image_only_indicator"] = torch.zeros(
159
- 2, num_frames
160
- ).to(device)
161
- additional_model_inputs["num_video_frames"] = batch["num_video_frames"]
162
-
163
- def denoiser(input, sigma, c):
164
- return model.denoiser(
165
- model.model, input, sigma, c, **additional_model_inputs
166
- )
167
-
168
- samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
169
- model.en_and_decode_n_samples_a_time = decoding_t
170
- samples_x = model.decode_first_stage(samples_z)
171
- samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
172
-
173
- os.makedirs(output_folder, exist_ok=True)
174
- base_count = len(glob(os.path.join(output_folder, "*.mp4")))
175
- video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
176
- writer = cv2.VideoWriter(
177
- video_path,
178
- cv2.VideoWriter_fourcc(*"mp4v"),
179
- fps_id + 1,
180
- (samples.shape[-1], samples.shape[-2]),
181
- )
182
-
183
- samples = embed_watermark(samples)
184
- samples = filter(samples)
185
- vid = (
186
- (rearrange(samples, "t c h w -> t h w c") * 255)
187
- .cpu()
188
- .numpy()
189
- .astype(np.uint8)
190
- )
191
- for frame in vid:
192
- frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
193
- writer.write(frame)
194
- writer.release()
195
  return video_path, seed
196
 
197
- def get_unique_embedder_keys_from_conditioner(conditioner):
198
- return list(set([x.input_key for x in conditioner.embedders]))
199
-
200
-
201
- def get_batch(keys, value_dict, N, T, device):
202
- batch = {}
203
- batch_uc = {}
204
-
205
- for key in keys:
206
- if key == "fps_id":
207
- batch[key] = (
208
- torch.tensor([value_dict["fps_id"]])
209
- .to(device)
210
- .repeat(int(math.prod(N)))
211
- )
212
- elif key == "motion_bucket_id":
213
- batch[key] = (
214
- torch.tensor([value_dict["motion_bucket_id"]])
215
- .to(device)
216
- .repeat(int(math.prod(N)))
217
- )
218
- elif key == "cond_aug":
219
- batch[key] = repeat(
220
- torch.tensor([value_dict["cond_aug"]]).to(device),
221
- "1 -> b",
222
- b=math.prod(N),
223
- )
224
- elif key == "cond_frames":
225
- batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
226
- elif key == "cond_frames_without_noise":
227
- batch[key] = repeat(
228
- value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
229
- )
230
- else:
231
- batch[key] = value_dict[key]
232
-
233
- if T is not None:
234
- batch["num_video_frames"] = T
235
-
236
- for key in batch.keys():
237
- if key not in batch_uc and isinstance(batch[key], torch.Tensor):
238
- batch_uc[key] = torch.clone(batch[key])
239
- return batch, batch_uc
240
-
241
  def resize_image(image, output_size=(1024, 576)):
242
  # Calculate aspect ratios
243
  target_aspect = output_size[0] / output_size[1] # Aspect ratio of the desired size
@@ -286,7 +102,25 @@ with gr.Blocks() as demo:
286
 
287
  image.upload(fn=resize_image, inputs=image, outputs=image, queue=False)
288
  generate_btn.click(fn=sample, inputs=[image, seed, randomize_seed, motion_bucket_id, fps_id], outputs=[video, seed], api_name="video")
289
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
  if __name__ == "__main__":
291
  demo.queue(max_size=20)
292
  demo.launch(share=True)
 
1
+ import gradio as gr
2
+ import gradio.helpers
3
+ import torch
4
  import os
5
  from glob import glob
6
  from pathlib import Path
7
  from typing import Optional
8
 
9
+ from diffusers import StableVideoDiffusionPipeline
10
+ from diffusers.utils import load_image, export_to_video
 
 
 
 
11
  from PIL import Image
 
12
 
 
 
 
 
 
 
13
  import uuid
14
  import random
15
  from huggingface_hub import hf_hub_download
16
 
17
+ gradio.helpers.CACHED_FOLDER = '/data/cache'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
+ pipe = StableVideoDiffusionPipeline.from_pretrained(
20
+ "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
 
 
 
21
  )
22
+ pipe.to("cuda")
23
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
24
+ pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=True)
25
+
26
+ max_64_bit_int = 2**63 - 1
27
 
28
  def sample(
29
  image: Image,
30
+ seed: Optional[int] = 42,
31
  randomize_seed: bool = True,
32
  motion_bucket_id: int = 127,
33
  fps_id: int = 6,
34
  version: str = "svd_xt",
35
  cond_aug: float = 0.02,
36
+ decoding_t: int = 3, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
37
  device: str = "cuda",
38
  output_folder: str = "outputs",
 
39
  ):
40
+ if image.mode == "RGBA":
41
+ image = image.convert("RGB")
42
+
43
  if(randomize_seed):
44
  seed = random.randint(0, max_64_bit_int)
45
+ generator = torch.manual_seed(seed)
 
46
 
47
+ os.makedirs(output_folder, exist_ok=True)
48
+ base_count = len(glob(os.path.join(output_folder, "*.mp4")))
49
+ video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
50
 
51
+ frames = pipe(image, decode_chunk_size=decoding_t, generator=generator, motion_bucket_id=motion_bucket_id, noise_aug_strength=0.1).frames[0]
52
+ export_to_video(frames, video_path, fps=fps_id)
53
+ torch.manual_seed(seed)
54
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  return video_path, seed
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  def resize_image(image, output_size=(1024, 576)):
58
  # Calculate aspect ratios
59
  target_aspect = output_size[0] / output_size[1] # Aspect ratio of the desired size
 
102
 
103
  image.upload(fn=resize_image, inputs=image, outputs=image, queue=False)
104
  generate_btn.click(fn=sample, inputs=[image, seed, randomize_seed, motion_bucket_id, fps_id], outputs=[video, seed], api_name="video")
105
+ gr.Examples(
106
+ examples=[
107
+ "images/blink_meme.png",
108
+ "images/confused2_meme.png",
109
+ "images/confused_meme.png",
110
+ "images/disaster_meme.png",
111
+ "images/distracted_meme.png",
112
+ "images/hide_meme.png",
113
+ "images/nazare_meme.png",
114
+ "images/success_meme.png",
115
+ "images/willy_meme.png",
116
+ "images/wink_meme.png"
117
+ ],
118
+ inputs=image,
119
+ outputs=[video, seed],
120
+ fn=sample,
121
+ cache_examples=True,
122
+ )
123
+
124
  if __name__ == "__main__":
125
  demo.queue(max_size=20)
126
  demo.launch(share=True)
assets/000.jpg DELETED
Binary file (728 kB)
 
assets/001_with_eval.png DELETED

Git LFS Details

  • SHA256: 026fa14e30098729064a00fb7fcec41bb57dcddb33b36b548d553f601bc53634
  • Pointer size: 132 Bytes
  • Size of remote file: 4.19 MB
assets/test_image.png DELETED
Binary file (494 kB)
 
assets/tile.gif DELETED

Git LFS Details

  • SHA256: 2340a9809e36fa9634633c7cc5fd256737c620ba47151726c85173512dc5c8ff
  • Pointer size: 133 Bytes
  • Size of remote file: 18.6 MB
configs/.DS_Store DELETED
Binary file (6.15 kB)
 
configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml DELETED
@@ -1,104 +0,0 @@
1
- model:
2
- base_learning_rate: 4.5e-6
3
- target: sgm.models.autoencoder.AutoencodingEngine
4
- params:
5
- input_key: jpg
6
- monitor: val/rec_loss
7
-
8
- loss_config:
9
- target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator
10
- params:
11
- perceptual_weight: 0.25
12
- disc_start: 20001
13
- disc_weight: 0.5
14
- learn_logvar: True
15
-
16
- regularization_weights:
17
- kl_loss: 1.0
18
-
19
- regularizer_config:
20
- target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
21
-
22
- encoder_config:
23
- target: sgm.modules.diffusionmodules.model.Encoder
24
- params:
25
- attn_type: none
26
- double_z: True
27
- z_channels: 4
28
- resolution: 256
29
- in_channels: 3
30
- out_ch: 3
31
- ch: 128
32
- ch_mult: [1, 2, 4]
33
- num_res_blocks: 4
34
- attn_resolutions: []
35
- dropout: 0.0
36
-
37
- decoder_config:
38
- target: sgm.modules.diffusionmodules.model.Decoder
39
- params: ${model.params.encoder_config.params}
40
-
41
- data:
42
- target: sgm.data.dataset.StableDataModuleFromConfig
43
- params:
44
- train:
45
- datapipeline:
46
- urls:
47
- - DATA-PATH
48
- pipeline_config:
49
- shardshuffle: 10000
50
- sample_shuffle: 10000
51
-
52
- decoders:
53
- - pil
54
-
55
- postprocessors:
56
- - target: sdata.mappers.TorchVisionImageTransforms
57
- params:
58
- key: jpg
59
- transforms:
60
- - target: torchvision.transforms.Resize
61
- params:
62
- size: 256
63
- interpolation: 3
64
- - target: torchvision.transforms.ToTensor
65
- - target: sdata.mappers.Rescaler
66
- - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
67
- params:
68
- h_key: height
69
- w_key: width
70
-
71
- loader:
72
- batch_size: 8
73
- num_workers: 4
74
-
75
-
76
- lightning:
77
- strategy:
78
- target: pytorch_lightning.strategies.DDPStrategy
79
- params:
80
- find_unused_parameters: True
81
-
82
- modelcheckpoint:
83
- params:
84
- every_n_train_steps: 5000
85
-
86
- callbacks:
87
- metrics_over_trainsteps_checkpoint:
88
- params:
89
- every_n_train_steps: 50000
90
-
91
- image_logger:
92
- target: main.ImageLogger
93
- params:
94
- enable_autocast: False
95
- batch_frequency: 1000
96
- max_images: 8
97
- increase_log_steps: True
98
-
99
- trainer:
100
- devices: 0,
101
- limit_val_batches: 50
102
- benchmark: True
103
- accumulate_grad_batches: 1
104
- val_check_interval: 10000
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configs/example_training/autoencoder/kl-f4/imagenet-kl_f8_8chn.yaml DELETED
@@ -1,105 +0,0 @@
1
- model:
2
- base_learning_rate: 4.5e-6
3
- target: sgm.models.autoencoder.AutoencodingEngine
4
- params:
5
- input_key: jpg
6
- monitor: val/loss/rec
7
- disc_start_iter: 0
8
-
9
- encoder_config:
10
- target: sgm.modules.diffusionmodules.model.Encoder
11
- params:
12
- attn_type: vanilla-xformers
13
- double_z: true
14
- z_channels: 8
15
- resolution: 256
16
- in_channels: 3
17
- out_ch: 3
18
- ch: 128
19
- ch_mult: [1, 2, 4, 4]
20
- num_res_blocks: 2
21
- attn_resolutions: []
22
- dropout: 0.0
23
-
24
- decoder_config:
25
- target: sgm.modules.diffusionmodules.model.Decoder
26
- params: ${model.params.encoder_config.params}
27
-
28
- regularizer_config:
29
- target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
30
-
31
- loss_config:
32
- target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator
33
- params:
34
- perceptual_weight: 0.25
35
- disc_start: 20001
36
- disc_weight: 0.5
37
- learn_logvar: True
38
-
39
- regularization_weights:
40
- kl_loss: 1.0
41
-
42
- data:
43
- target: sgm.data.dataset.StableDataModuleFromConfig
44
- params:
45
- train:
46
- datapipeline:
47
- urls:
48
- - DATA-PATH
49
- pipeline_config:
50
- shardshuffle: 10000
51
- sample_shuffle: 10000
52
-
53
- decoders:
54
- - pil
55
-
56
- postprocessors:
57
- - target: sdata.mappers.TorchVisionImageTransforms
58
- params:
59
- key: jpg
60
- transforms:
61
- - target: torchvision.transforms.Resize
62
- params:
63
- size: 256
64
- interpolation: 3
65
- - target: torchvision.transforms.ToTensor
66
- - target: sdata.mappers.Rescaler
67
- - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
68
- params:
69
- h_key: height
70
- w_key: width
71
-
72
- loader:
73
- batch_size: 8
74
- num_workers: 4
75
-
76
-
77
- lightning:
78
- strategy:
79
- target: pytorch_lightning.strategies.DDPStrategy
80
- params:
81
- find_unused_parameters: True
82
-
83
- modelcheckpoint:
84
- params:
85
- every_n_train_steps: 5000
86
-
87
- callbacks:
88
- metrics_over_trainsteps_checkpoint:
89
- params:
90
- every_n_train_steps: 50000
91
-
92
- image_logger:
93
- target: main.ImageLogger
94
- params:
95
- enable_autocast: False
96
- batch_frequency: 1000
97
- max_images: 8
98
- increase_log_steps: True
99
-
100
- trainer:
101
- devices: 0,
102
- limit_val_batches: 50
103
- benchmark: True
104
- accumulate_grad_batches: 1
105
- val_check_interval: 10000
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configs/example_training/imagenet-f8_cond.yaml DELETED
@@ -1,185 +0,0 @@
1
- model:
2
- base_learning_rate: 1.0e-4
3
- target: sgm.models.diffusion.DiffusionEngine
4
- params:
5
- scale_factor: 0.13025
6
- disable_first_stage_autocast: True
7
- log_keys:
8
- - cls
9
-
10
- scheduler_config:
11
- target: sgm.lr_scheduler.LambdaLinearScheduler
12
- params:
13
- warm_up_steps: [10000]
14
- cycle_lengths: [10000000000000]
15
- f_start: [1.e-6]
16
- f_max: [1.]
17
- f_min: [1.]
18
-
19
- denoiser_config:
20
- target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
21
- params:
22
- num_idx: 1000
23
-
24
- scaling_config:
25
- target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
26
- discretization_config:
27
- target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
28
-
29
- network_config:
30
- target: sgm.modules.diffusionmodules.openaimodel.UNetModel
31
- params:
32
- use_checkpoint: True
33
- in_channels: 4
34
- out_channels: 4
35
- model_channels: 256
36
- attention_resolutions: [1, 2, 4]
37
- num_res_blocks: 2
38
- channel_mult: [1, 2, 4]
39
- num_head_channels: 64
40
- num_classes: sequential
41
- adm_in_channels: 1024
42
- transformer_depth: 1
43
- context_dim: 1024
44
- spatial_transformer_attn_type: softmax-xformers
45
-
46
- conditioner_config:
47
- target: sgm.modules.GeneralConditioner
48
- params:
49
- emb_models:
50
- - is_trainable: True
51
- input_key: cls
52
- ucg_rate: 0.2
53
- target: sgm.modules.encoders.modules.ClassEmbedder
54
- params:
55
- add_sequence_dim: True
56
- embed_dim: 1024
57
- n_classes: 1000
58
-
59
- - is_trainable: False
60
- ucg_rate: 0.2
61
- input_key: original_size_as_tuple
62
- target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
63
- params:
64
- outdim: 256
65
-
66
- - is_trainable: False
67
- input_key: crop_coords_top_left
68
- ucg_rate: 0.2
69
- target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
70
- params:
71
- outdim: 256
72
-
73
- first_stage_config:
74
- target: sgm.models.autoencoder.AutoencoderKL
75
- params:
76
- ckpt_path: CKPT_PATH
77
- embed_dim: 4
78
- monitor: val/rec_loss
79
- ddconfig:
80
- attn_type: vanilla-xformers
81
- double_z: true
82
- z_channels: 4
83
- resolution: 256
84
- in_channels: 3
85
- out_ch: 3
86
- ch: 128
87
- ch_mult: [1, 2, 4, 4]
88
- num_res_blocks: 2
89
- attn_resolutions: []
90
- dropout: 0.0
91
- lossconfig:
92
- target: torch.nn.Identity
93
-
94
- loss_fn_config:
95
- target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
96
- params:
97
- loss_weighting_config:
98
- target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting
99
- sigma_sampler_config:
100
- target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
101
- params:
102
- num_idx: 1000
103
-
104
- discretization_config:
105
- target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
106
-
107
- sampler_config:
108
- target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
109
- params:
110
- num_steps: 50
111
-
112
- discretization_config:
113
- target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
114
-
115
- guider_config:
116
- target: sgm.modules.diffusionmodules.guiders.VanillaCFG
117
- params:
118
- scale: 5.0
119
-
120
- data:
121
- target: sgm.data.dataset.StableDataModuleFromConfig
122
- params:
123
- train:
124
- datapipeline:
125
- urls:
126
- # USER: adapt this path the root of your custom dataset
127
- - DATA_PATH
128
- pipeline_config:
129
- shardshuffle: 10000
130
- sample_shuffle: 10000 # USER: you might wanna adapt depending on your available RAM
131
-
132
- decoders:
133
- - pil
134
-
135
- postprocessors:
136
- - target: sdata.mappers.TorchVisionImageTransforms
137
- params:
138
- key: jpg # USER: you might wanna adapt this for your custom dataset
139
- transforms:
140
- - target: torchvision.transforms.Resize
141
- params:
142
- size: 256
143
- interpolation: 3
144
- - target: torchvision.transforms.ToTensor
145
- - target: sdata.mappers.Rescaler
146
-
147
- - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
148
- params:
149
- h_key: height # USER: you might wanna adapt this for your custom dataset
150
- w_key: width # USER: you might wanna adapt this for your custom dataset
151
-
152
- loader:
153
- batch_size: 64
154
- num_workers: 6
155
-
156
- lightning:
157
- modelcheckpoint:
158
- params:
159
- every_n_train_steps: 5000
160
-
161
- callbacks:
162
- metrics_over_trainsteps_checkpoint:
163
- params:
164
- every_n_train_steps: 25000
165
-
166
- image_logger:
167
- target: main.ImageLogger
168
- params:
169
- disabled: False
170
- enable_autocast: False
171
- batch_frequency: 1000
172
- max_images: 8
173
- increase_log_steps: True
174
- log_first_step: False
175
- log_images_kwargs:
176
- use_ema_scope: False
177
- N: 8
178
- n_rows: 2
179
-
180
- trainer:
181
- devices: 0,
182
- benchmark: True
183
- num_sanity_val_steps: 0
184
- accumulate_grad_batches: 1
185
- max_epochs: 1000
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configs/example_training/toy/cifar10_cond.yaml DELETED
@@ -1,98 +0,0 @@
1
- model:
2
- base_learning_rate: 1.0e-4
3
- target: sgm.models.diffusion.DiffusionEngine
4
- params:
5
- denoiser_config:
6
- target: sgm.modules.diffusionmodules.denoiser.Denoiser
7
- params:
8
- scaling_config:
9
- target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
10
- params:
11
- sigma_data: 1.0
12
-
13
- network_config:
14
- target: sgm.modules.diffusionmodules.openaimodel.UNetModel
15
- params:
16
- in_channels: 3
17
- out_channels: 3
18
- model_channels: 32
19
- attention_resolutions: []
20
- num_res_blocks: 4
21
- channel_mult: [1, 2, 2]
22
- num_head_channels: 32
23
- num_classes: sequential
24
- adm_in_channels: 128
25
-
26
- conditioner_config:
27
- target: sgm.modules.GeneralConditioner
28
- params:
29
- emb_models:
30
- - is_trainable: True
31
- input_key: cls
32
- ucg_rate: 0.2
33
- target: sgm.modules.encoders.modules.ClassEmbedder
34
- params:
35
- embed_dim: 128
36
- n_classes: 10
37
-
38
- first_stage_config:
39
- target: sgm.models.autoencoder.IdentityFirstStage
40
-
41
- loss_fn_config:
42
- target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
43
- params:
44
- loss_weighting_config:
45
- target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
46
- params:
47
- sigma_data: 1.0
48
- sigma_sampler_config:
49
- target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
50
-
51
- sampler_config:
52
- target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
53
- params:
54
- num_steps: 50
55
-
56
- discretization_config:
57
- target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
58
-
59
- guider_config:
60
- target: sgm.modules.diffusionmodules.guiders.VanillaCFG
61
- params:
62
- scale: 3.0
63
-
64
- data:
65
- target: sgm.data.cifar10.CIFAR10Loader
66
- params:
67
- batch_size: 512
68
- num_workers: 1
69
-
70
- lightning:
71
- modelcheckpoint:
72
- params:
73
- every_n_train_steps: 5000
74
-
75
- callbacks:
76
- metrics_over_trainsteps_checkpoint:
77
- params:
78
- every_n_train_steps: 25000
79
-
80
- image_logger:
81
- target: main.ImageLogger
82
- params:
83
- disabled: False
84
- batch_frequency: 1000
85
- max_images: 64
86
- increase_log_steps: True
87
- log_first_step: False
88
- log_images_kwargs:
89
- use_ema_scope: False
90
- N: 64
91
- n_rows: 8
92
-
93
- trainer:
94
- devices: 0,
95
- benchmark: True
96
- num_sanity_val_steps: 0
97
- accumulate_grad_batches: 1
98
- max_epochs: 20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configs/example_training/toy/mnist.yaml DELETED
@@ -1,79 +0,0 @@
1
- model:
2
- base_learning_rate: 1.0e-4
3
- target: sgm.models.diffusion.DiffusionEngine
4
- params:
5
- denoiser_config:
6
- target: sgm.modules.diffusionmodules.denoiser.Denoiser
7
- params:
8
- scaling_config:
9
- target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
10
- params:
11
- sigma_data: 1.0
12
-
13
- network_config:
14
- target: sgm.modules.diffusionmodules.openaimodel.UNetModel
15
- params:
16
- in_channels: 1
17
- out_channels: 1
18
- model_channels: 32
19
- attention_resolutions: []
20
- num_res_blocks: 4
21
- channel_mult: [1, 2, 2]
22
- num_head_channels: 32
23
-
24
- first_stage_config:
25
- target: sgm.models.autoencoder.IdentityFirstStage
26
-
27
- loss_fn_config:
28
- target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
29
- params:
30
- loss_weighting_config:
31
- target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
32
- params:
33
- sigma_data: 1.0
34
- sigma_sampler_config:
35
- target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
36
-
37
- sampler_config:
38
- target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
39
- params:
40
- num_steps: 50
41
-
42
- discretization_config:
43
- target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
44
-
45
- data:
46
- target: sgm.data.mnist.MNISTLoader
47
- params:
48
- batch_size: 512
49
- num_workers: 1
50
-
51
- lightning:
52
- modelcheckpoint:
53
- params:
54
- every_n_train_steps: 5000
55
-
56
- callbacks:
57
- metrics_over_trainsteps_checkpoint:
58
- params:
59
- every_n_train_steps: 25000
60
-
61
- image_logger:
62
- target: main.ImageLogger
63
- params:
64
- disabled: False
65
- batch_frequency: 1000
66
- max_images: 64
67
- increase_log_steps: False
68
- log_first_step: False
69
- log_images_kwargs:
70
- use_ema_scope: False
71
- N: 64
72
- n_rows: 8
73
-
74
- trainer:
75
- devices: 0,
76
- benchmark: True
77
- num_sanity_val_steps: 0
78
- accumulate_grad_batches: 1
79
- max_epochs: 10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configs/example_training/toy/mnist_cond.yaml DELETED
@@ -1,98 +0,0 @@
1
- model:
2
- base_learning_rate: 1.0e-4
3
- target: sgm.models.diffusion.DiffusionEngine
4
- params:
5
- denoiser_config:
6
- target: sgm.modules.diffusionmodules.denoiser.Denoiser
7
- params:
8
- scaling_config:
9
- target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
10
- params:
11
- sigma_data: 1.0
12
-
13
- network_config:
14
- target: sgm.modules.diffusionmodules.openaimodel.UNetModel
15
- params:
16
- in_channels: 1
17
- out_channels: 1
18
- model_channels: 32
19
- attention_resolutions: []
20
- num_res_blocks: 4
21
- channel_mult: [1, 2, 2]
22
- num_head_channels: 32
23
- num_classes: sequential
24
- adm_in_channels: 128
25
-
26
- conditioner_config:
27
- target: sgm.modules.GeneralConditioner
28
- params:
29
- emb_models:
30
- - is_trainable: True
31
- input_key: cls
32
- ucg_rate: 0.2
33
- target: sgm.modules.encoders.modules.ClassEmbedder
34
- params:
35
- embed_dim: 128
36
- n_classes: 10
37
-
38
- first_stage_config:
39
- target: sgm.models.autoencoder.IdentityFirstStage
40
-
41
- loss_fn_config:
42
- target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
43
- params:
44
- loss_weighting_config:
45
- target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
46
- params:
47
- sigma_data: 1.0
48
- sigma_sampler_config:
49
- target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
50
-
51
- sampler_config:
52
- target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
53
- params:
54
- num_steps: 50
55
-
56
- discretization_config:
57
- target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
58
-
59
- guider_config:
60
- target: sgm.modules.diffusionmodules.guiders.VanillaCFG
61
- params:
62
- scale: 3.0
63
-
64
- data:
65
- target: sgm.data.mnist.MNISTLoader
66
- params:
67
- batch_size: 512
68
- num_workers: 1
69
-
70
- lightning:
71
- modelcheckpoint:
72
- params:
73
- every_n_train_steps: 5000
74
-
75
- callbacks:
76
- metrics_over_trainsteps_checkpoint:
77
- params:
78
- every_n_train_steps: 25000
79
-
80
- image_logger:
81
- target: main.ImageLogger
82
- params:
83
- disabled: False
84
- batch_frequency: 1000
85
- max_images: 16
86
- increase_log_steps: True
87
- log_first_step: False
88
- log_images_kwargs:
89
- use_ema_scope: False
90
- N: 16
91
- n_rows: 4
92
-
93
- trainer:
94
- devices: 0,
95
- benchmark: True
96
- num_sanity_val_steps: 0
97
- accumulate_grad_batches: 1
98
- max_epochs: 20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configs/example_training/toy/mnist_cond_discrete_eps.yaml DELETED
@@ -1,103 +0,0 @@
1
- model:
2
- base_learning_rate: 1.0e-4
3
- target: sgm.models.diffusion.DiffusionEngine
4
- params:
5
- denoiser_config:
6
- target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
7
- params:
8
- num_idx: 1000
9
-
10
- scaling_config:
11
- target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
12
- discretization_config:
13
- target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
14
-
15
- network_config:
16
- target: sgm.modules.diffusionmodules.openaimodel.UNetModel
17
- params:
18
- in_channels: 1
19
- out_channels: 1
20
- model_channels: 32
21
- attention_resolutions: []
22
- num_res_blocks: 4
23
- channel_mult: [1, 2, 2]
24
- num_head_channels: 32
25
- num_classes: sequential
26
- adm_in_channels: 128
27
-
28
- conditioner_config:
29
- target: sgm.modules.GeneralConditioner
30
- params:
31
- emb_models:
32
- - is_trainable: True
33
- input_key: cls
34
- ucg_rate: 0.2
35
- target: sgm.modules.encoders.modules.ClassEmbedder
36
- params:
37
- embed_dim: 128
38
- n_classes: 10
39
-
40
- first_stage_config:
41
- target: sgm.models.autoencoder.IdentityFirstStage
42
-
43
- loss_fn_config:
44
- target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
45
- params:
46
- loss_weighting_config:
47
- target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
48
- sigma_sampler_config:
49
- target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
50
- params:
51
- num_idx: 1000
52
-
53
- discretization_config:
54
- target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
55
-
56
- sampler_config:
57
- target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
58
- params:
59
- num_steps: 50
60
-
61
- discretization_config:
62
- target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
63
-
64
- guider_config:
65
- target: sgm.modules.diffusionmodules.guiders.VanillaCFG
66
- params:
67
- scale: 5.0
68
-
69
- data:
70
- target: sgm.data.mnist.MNISTLoader
71
- params:
72
- batch_size: 512
73
- num_workers: 1
74
-
75
- lightning:
76
- modelcheckpoint:
77
- params:
78
- every_n_train_steps: 5000
79
-
80
- callbacks:
81
- metrics_over_trainsteps_checkpoint:
82
- params:
83
- every_n_train_steps: 25000
84
-
85
- image_logger:
86
- target: main.ImageLogger
87
- params:
88
- disabled: False
89
- batch_frequency: 1000
90
- max_images: 16
91
- increase_log_steps: True
92
- log_first_step: False
93
- log_images_kwargs:
94
- use_ema_scope: False
95
- N: 16
96
- n_rows: 4
97
-
98
- trainer:
99
- devices: 0,
100
- benchmark: True
101
- num_sanity_val_steps: 0
102
- accumulate_grad_batches: 1
103
- max_epochs: 20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configs/example_training/toy/mnist_cond_l1_loss.yaml DELETED
@@ -1,99 +0,0 @@
1
- model:
2
- base_learning_rate: 1.0e-4
3
- target: sgm.models.diffusion.DiffusionEngine
4
- params:
5
- denoiser_config:
6
- target: sgm.modules.diffusionmodules.denoiser.Denoiser
7
- params:
8
- scaling_config:
9
- target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
10
- params:
11
- sigma_data: 1.0
12
-
13
- network_config:
14
- target: sgm.modules.diffusionmodules.openaimodel.UNetModel
15
- params:
16
- in_channels: 1
17
- out_channels: 1
18
- model_channels: 32
19
- attention_resolutions: []
20
- num_res_blocks: 4
21
- channel_mult: [1, 2, 2]
22
- num_head_channels: 32
23
- num_classes: sequential
24
- adm_in_channels: 128
25
-
26
- conditioner_config:
27
- target: sgm.modules.GeneralConditioner
28
- params:
29
- emb_models:
30
- - is_trainable: True
31
- input_key: cls
32
- ucg_rate: 0.2
33
- target: sgm.modules.encoders.modules.ClassEmbedder
34
- params:
35
- embed_dim: 128
36
- n_classes: 10
37
-
38
- first_stage_config:
39
- target: sgm.models.autoencoder.IdentityFirstStage
40
-
41
- loss_fn_config:
42
- target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
43
- params:
44
- loss_type: l1
45
- loss_weighting_config:
46
- target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
47
- params:
48
- sigma_data: 1.0
49
- sigma_sampler_config:
50
- target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
51
-
52
- sampler_config:
53
- target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
54
- params:
55
- num_steps: 50
56
-
57
- discretization_config:
58
- target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
59
-
60
- guider_config:
61
- target: sgm.modules.diffusionmodules.guiders.VanillaCFG
62
- params:
63
- scale: 3.0
64
-
65
- data:
66
- target: sgm.data.mnist.MNISTLoader
67
- params:
68
- batch_size: 512
69
- num_workers: 1
70
-
71
- lightning:
72
- modelcheckpoint:
73
- params:
74
- every_n_train_steps: 5000
75
-
76
- callbacks:
77
- metrics_over_trainsteps_checkpoint:
78
- params:
79
- every_n_train_steps: 25000
80
-
81
- image_logger:
82
- target: main.ImageLogger
83
- params:
84
- disabled: False
85
- batch_frequency: 1000
86
- max_images: 64
87
- increase_log_steps: True
88
- log_first_step: False
89
- log_images_kwargs:
90
- use_ema_scope: False
91
- N: 64
92
- n_rows: 8
93
-
94
- trainer:
95
- devices: 0,
96
- benchmark: True
97
- num_sanity_val_steps: 0
98
- accumulate_grad_batches: 1
99
- max_epochs: 20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configs/example_training/toy/mnist_cond_with_ema.yaml DELETED
@@ -1,100 +0,0 @@
1
- model:
2
- base_learning_rate: 1.0e-4
3
- target: sgm.models.diffusion.DiffusionEngine
4
- params:
5
- use_ema: True
6
-
7
- denoiser_config:
8
- target: sgm.modules.diffusionmodules.denoiser.Denoiser
9
- params:
10
- scaling_config:
11
- target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
12
- params:
13
- sigma_data: 1.0
14
-
15
- network_config:
16
- target: sgm.modules.diffusionmodules.openaimodel.UNetModel
17
- params:
18
- in_channels: 1
19
- out_channels: 1
20
- model_channels: 32
21
- attention_resolutions: []
22
- num_res_blocks: 4
23
- channel_mult: [1, 2, 2]
24
- num_head_channels: 32
25
- num_classes: sequential
26
- adm_in_channels: 128
27
-
28
- conditioner_config:
29
- target: sgm.modules.GeneralConditioner
30
- params:
31
- emb_models:
32
- - is_trainable: True
33
- input_key: cls
34
- ucg_rate: 0.2
35
- target: sgm.modules.encoders.modules.ClassEmbedder
36
- params:
37
- embed_dim: 128
38
- n_classes: 10
39
-
40
- first_stage_config:
41
- target: sgm.models.autoencoder.IdentityFirstStage
42
-
43
- loss_fn_config:
44
- target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
45
- params:
46
- loss_weighting_config:
47
- target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
48
- params:
49
- sigma_data: 1.0
50
- sigma_sampler_config:
51
- target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
52
-
53
- sampler_config:
54
- target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
55
- params:
56
- num_steps: 50
57
-
58
- discretization_config:
59
- target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
60
-
61
- guider_config:
62
- target: sgm.modules.diffusionmodules.guiders.VanillaCFG
63
- params:
64
- scale: 3.0
65
-
66
- data:
67
- target: sgm.data.mnist.MNISTLoader
68
- params:
69
- batch_size: 512
70
- num_workers: 1
71
-
72
- lightning:
73
- modelcheckpoint:
74
- params:
75
- every_n_train_steps: 5000
76
-
77
- callbacks:
78
- metrics_over_trainsteps_checkpoint:
79
- params:
80
- every_n_train_steps: 25000
81
-
82
- image_logger:
83
- target: main.ImageLogger
84
- params:
85
- disabled: False
86
- batch_frequency: 1000
87
- max_images: 64
88
- increase_log_steps: True
89
- log_first_step: False
90
- log_images_kwargs:
91
- use_ema_scope: False
92
- N: 64
93
- n_rows: 8
94
-
95
- trainer:
96
- devices: 0,
97
- benchmark: True
98
- num_sanity_val_steps: 0
99
- accumulate_grad_batches: 1
100
- max_epochs: 20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configs/example_training/txt2img-clipl-legacy-ucg-training.yaml DELETED
@@ -1,182 +0,0 @@
1
- model:
2
- base_learning_rate: 1.0e-4
3
- target: sgm.models.diffusion.DiffusionEngine
4
- params:
5
- scale_factor: 0.13025
6
- disable_first_stage_autocast: True
7
- log_keys:
8
- - txt
9
-
10
- scheduler_config:
11
- target: sgm.lr_scheduler.LambdaLinearScheduler
12
- params:
13
- warm_up_steps: [10000]
14
- cycle_lengths: [10000000000000]
15
- f_start: [1.e-6]
16
- f_max: [1.]
17
- f_min: [1.]
18
-
19
- denoiser_config:
20
- target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
21
- params:
22
- num_idx: 1000
23
-
24
- scaling_config:
25
- target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
26
- discretization_config:
27
- target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
28
-
29
- network_config:
30
- target: sgm.modules.diffusionmodules.openaimodel.UNetModel
31
- params:
32
- use_checkpoint: True
33
- in_channels: 4
34
- out_channels: 4
35
- model_channels: 320
36
- attention_resolutions: [1, 2, 4]
37
- num_res_blocks: 2
38
- channel_mult: [1, 2, 4, 4]
39
- num_head_channels: 64
40
- num_classes: sequential
41
- adm_in_channels: 1792
42
- num_heads: 1
43
- transformer_depth: 1
44
- context_dim: 768
45
- spatial_transformer_attn_type: softmax-xformers
46
-
47
- conditioner_config:
48
- target: sgm.modules.GeneralConditioner
49
- params:
50
- emb_models:
51
- - is_trainable: True
52
- input_key: txt
53
- ucg_rate: 0.1
54
- legacy_ucg_value: ""
55
- target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
56
- params:
57
- always_return_pooled: True
58
-
59
- - is_trainable: False
60
- ucg_rate: 0.1
61
- input_key: original_size_as_tuple
62
- target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
63
- params:
64
- outdim: 256
65
-
66
- - is_trainable: False
67
- input_key: crop_coords_top_left
68
- ucg_rate: 0.1
69
- target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
70
- params:
71
- outdim: 256
72
-
73
- first_stage_config:
74
- target: sgm.models.autoencoder.AutoencoderKL
75
- params:
76
- ckpt_path: CKPT_PATH
77
- embed_dim: 4
78
- monitor: val/rec_loss
79
- ddconfig:
80
- attn_type: vanilla-xformers
81
- double_z: true
82
- z_channels: 4
83
- resolution: 256
84
- in_channels: 3
85
- out_ch: 3
86
- ch: 128
87
- ch_mult: [ 1, 2, 4, 4 ]
88
- num_res_blocks: 2
89
- attn_resolutions: [ ]
90
- dropout: 0.0
91
- lossconfig:
92
- target: torch.nn.Identity
93
-
94
- loss_fn_config:
95
- target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
96
- params:
97
-