Spaces:
Running
on
Zero
diffusers-backend (#6)
Browse files- Delete assets (537ebb9bdae4dc86cc0b00b854f1eae9b32cd7bb)
- Delete configs (adcff078e9123ce669443f2e1d28516c05b7c455)
- Delete data (79751902793a42fb8f0329a80fd132d7d6f85ec8)
- Delete model_licenses (67be9bc7732e83a8b4dde21fa0311b8c2f787185)
- Delete requirements (c1a4c611db4c97cd499e036fa3111c785cd9f4ba)
- Delete scripts (90e5afe600048480e16ac285d298378b9c3ac9b1)
- Delete sgm (097567a9522fbd31df1e8f4ec7a360c02947f0d3)
- Delete tests (fa8cc685cd27c368611079175724d895c067742c)
- Delete .gitattributes (4d403cfb820b8f14ef17e6b7e97d16a8a4fa850d)
- Delete CODEOWNERS (5ff14e288cc2336b802e82a12f2f5a9b42bb4acf)
- Delete LICENSE-CODE (ef217859b8e2bdc2bcd8c757020356e51a1e1254)
- Delete main.py (422da78c99f88a7edbeb6e277a3c8cb0bb46e691)
- Delete pyproject.toml (9cb3ce47320995317454fb907a5a3ef9e47e8049)
- Delete pytest.ini (f3f18e3224b316f8894e505b90f05f05aeaad782)
- Delete simple_video_sample.py (a7729a11a558073e33666989f79521bcb4c59c5a)
- Update app.py (9ca6c3055abf1a8603c1a274aa9c0fdab2910f61)
- Update requirements.txt (0b4d79b2785491369787f19453b24b0d25bcd283)
- Update app.py (6e3b09a7cfa77adde712e954cebf68001facc1b4)
- Update app.py (e2530d1596d372176fe7707e1ff50d7c002dcda4)
- Update app.py (f7d4d47c674580f25d5f99f584dc415d81428cc7)
- Upload 10 files (0629c99de8869dc16cb9bfff11a6a93bb9fddb6f)
- Update app.py (2beda72ac79f28c2400a8e81dd9d7606d3c95ef9)
- .gitattributes +0 -38
- CODEOWNERS +0 -1
- LICENSE-CODE +0 -21
- app.py +45 -211
- assets/000.jpg +0 -0
- assets/001_with_eval.png +0 -3
- assets/test_image.png +0 -0
- assets/tile.gif +0 -3
- configs/.DS_Store +0 -0
- configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml +0 -104
- configs/example_training/autoencoder/kl-f4/imagenet-kl_f8_8chn.yaml +0 -105
- configs/example_training/imagenet-f8_cond.yaml +0 -185
- configs/example_training/toy/cifar10_cond.yaml +0 -98
- configs/example_training/toy/mnist.yaml +0 -79
- configs/example_training/toy/mnist_cond.yaml +0 -98
- configs/example_training/toy/mnist_cond_discrete_eps.yaml +0 -103
- configs/example_training/toy/mnist_cond_l1_loss.yaml +0 -99
- configs/example_training/toy/mnist_cond_with_ema.yaml +0 -100
- configs/example_training/txt2img-clipl-legacy-ucg-training.yaml +0 -182
- configs/example_training/txt2img-clipl.yaml +0 -184
- configs/inference/sd_2_1.yaml +0 -60
- configs/inference/sd_2_1_768.yaml +0 -60
- configs/inference/sd_xl_base.yaml +0 -93
- configs/inference/sd_xl_refiner.yaml +0 -86
- configs/inference/svd.yaml +0 -131
- configs/inference/svd_image_decoder.yaml +0 -114
- data/DejaVuSans.ttf +0 -0
- images/blink_meme.png +0 -0
- images/confused2_meme.png +0 -0
- images/confused_meme.png +0 -0
- images/disaster_meme.png +0 -0
- images/distracted_meme.png +0 -0
- images/hide_meme.png +0 -0
- images/nazare_meme.png +0 -0
- images/success_meme.png +0 -0
- images/willy_meme.png +0 -0
- images/wink_meme.png +0 -0
- main.py +0 -943
- model_licenses/LICENSE-SDV +0 -31
- model_licenses/LICENSE-SDXL0.9 +0 -75
- model_licenses/LICENSE-SDXL1.0 +0 -175
- pyproject.toml +0 -48
- pytest.ini +0 -3
- requirements.txt +5 -40
- requirements/pt13.txt +0 -40
- requirements/pt2.txt +0 -39
- scripts/.DS_Store +0 -0
- scripts/__init__.py +0 -0
- scripts/demo/__init__.py +0 -0
- scripts/demo/detect.py +0 -156
@@ -1,38 +0,0 @@
|
|
1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
-
assets/001_with_eval.png filter=lfs diff=lfs merge=lfs -text
|
37 |
-
assets/tile.gif filter=lfs diff=lfs merge=lfs -text
|
38 |
-
outputs/000004.mp4 filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1 +0,0 @@
|
|
1 |
-
.github @Stability-AI/infrastructure
|
|
|
|
@@ -1,21 +0,0 @@
|
|
1 |
-
MIT License
|
2 |
-
|
3 |
-
Copyright (c) 2023 Stability AI
|
4 |
-
|
5 |
-
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
-
of this software and associated documentation files (the "Software"), to deal
|
7 |
-
in the Software without restriction, including without limitation the rights
|
8 |
-
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
-
copies of the Software, and to permit persons to whom the Software is
|
10 |
-
furnished to do so, subject to the following conditions:
|
11 |
-
|
12 |
-
The above copyright notice and this permission notice shall be included in all
|
13 |
-
copies or substantial portions of the Software.
|
14 |
-
|
15 |
-
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
-
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
-
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
-
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
-
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
-
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
-
SOFTWARE.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,243 +1,59 @@
|
|
1 |
-
import
|
|
|
|
|
2 |
import os
|
3 |
from glob import glob
|
4 |
from pathlib import Path
|
5 |
from typing import Optional
|
6 |
|
7 |
-
import
|
8 |
-
import
|
9 |
-
import torch
|
10 |
-
from einops import rearrange, repeat
|
11 |
-
from fire import Fire
|
12 |
-
from omegaconf import OmegaConf
|
13 |
from PIL import Image
|
14 |
-
from torchvision.transforms import ToTensor
|
15 |
|
16 |
-
from scripts.util.detection.nsfw_and_watermark_dectection import \
|
17 |
-
DeepFloydDataFiltering
|
18 |
-
from sgm.inference.helpers import embed_watermark
|
19 |
-
from sgm.util import default, instantiate_from_config
|
20 |
-
|
21 |
-
import gradio as gr
|
22 |
import uuid
|
23 |
import random
|
24 |
from huggingface_hub import hf_hub_download
|
25 |
|
26 |
-
|
27 |
-
|
28 |
-
version = "svd_xt"
|
29 |
-
device = "cuda"
|
30 |
-
max_64_bit_int = 2**63 - 1
|
31 |
-
|
32 |
-
def load_model(
|
33 |
-
config: str,
|
34 |
-
device: str,
|
35 |
-
num_frames: int,
|
36 |
-
num_steps: int,
|
37 |
-
):
|
38 |
-
config = OmegaConf.load(config)
|
39 |
-
if device == "cuda":
|
40 |
-
config.model.params.conditioner_config.params.emb_models[
|
41 |
-
0
|
42 |
-
].params.open_clip_embedding_config.params.init_device = device
|
43 |
-
|
44 |
-
config.model.params.sampler_config.params.num_steps = num_steps
|
45 |
-
config.model.params.sampler_config.params.guider_config.params.num_frames = (
|
46 |
-
num_frames
|
47 |
-
)
|
48 |
-
if device == "cuda":
|
49 |
-
with torch.device(device):
|
50 |
-
model = instantiate_from_config(config.model).to(device).eval()
|
51 |
-
else:
|
52 |
-
model = instantiate_from_config(config.model).to(device).eval()
|
53 |
-
|
54 |
-
filter = DeepFloydDataFiltering(verbose=False, device=device)
|
55 |
-
return model, filter
|
56 |
-
|
57 |
-
if version == "svd_xt":
|
58 |
-
num_frames = 25
|
59 |
-
num_steps = 30
|
60 |
-
model_config = "scripts/sampling/configs/svd_xt.yaml"
|
61 |
-
else:
|
62 |
-
raise ValueError(f"Version {version} does not exist.")
|
63 |
|
64 |
-
|
65 |
-
|
66 |
-
device,
|
67 |
-
num_frames,
|
68 |
-
num_steps,
|
69 |
)
|
|
|
|
|
|
|
|
|
|
|
70 |
|
71 |
def sample(
|
72 |
image: Image,
|
73 |
-
seed: Optional[int] =
|
74 |
randomize_seed: bool = True,
|
75 |
motion_bucket_id: int = 127,
|
76 |
fps_id: int = 6,
|
77 |
version: str = "svd_xt",
|
78 |
cond_aug: float = 0.02,
|
79 |
-
decoding_t: int =
|
80 |
device: str = "cuda",
|
81 |
output_folder: str = "outputs",
|
82 |
-
progress=gr.Progress(track_tqdm=True)
|
83 |
):
|
|
|
|
|
|
|
84 |
if(randomize_seed):
|
85 |
seed = random.randint(0, max_64_bit_int)
|
86 |
-
|
87 |
-
torch.manual_seed(seed)
|
88 |
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
|
98 |
-
)
|
99 |
-
|
100 |
-
image = ToTensor()(image)
|
101 |
-
image = image * 2.0 - 1.0
|
102 |
-
image = image.unsqueeze(0).to(device)
|
103 |
-
H, W = image.shape[2:]
|
104 |
-
assert image.shape[1] == 3
|
105 |
-
F = 8
|
106 |
-
C = 4
|
107 |
-
shape = (num_frames, C, H // F, W // F)
|
108 |
-
if (H, W) != (576, 1024):
|
109 |
-
print(
|
110 |
-
"WARNING: The conditioning frame you provided is not 576x1024. This leads to suboptimal performance as model was only trained on 576x1024. Consider increasing `cond_aug`."
|
111 |
-
)
|
112 |
-
if motion_bucket_id > 255:
|
113 |
-
print(
|
114 |
-
"WARNING: High motion bucket! This may lead to suboptimal performance."
|
115 |
-
)
|
116 |
-
|
117 |
-
if fps_id < 5:
|
118 |
-
print("WARNING: Small fps value! This may lead to suboptimal performance.")
|
119 |
-
|
120 |
-
if fps_id > 30:
|
121 |
-
print("WARNING: Large fps value! This may lead to suboptimal performance.")
|
122 |
-
|
123 |
-
value_dict = {}
|
124 |
-
value_dict["motion_bucket_id"] = motion_bucket_id
|
125 |
-
value_dict["fps_id"] = fps_id
|
126 |
-
value_dict["cond_aug"] = cond_aug
|
127 |
-
value_dict["cond_frames_without_noise"] = image
|
128 |
-
value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)
|
129 |
-
value_dict["cond_aug"] = cond_aug
|
130 |
-
|
131 |
-
with torch.no_grad():
|
132 |
-
with torch.autocast(device):
|
133 |
-
batch, batch_uc = get_batch(
|
134 |
-
get_unique_embedder_keys_from_conditioner(model.conditioner),
|
135 |
-
value_dict,
|
136 |
-
[1, num_frames],
|
137 |
-
T=num_frames,
|
138 |
-
device=device,
|
139 |
-
)
|
140 |
-
c, uc = model.conditioner.get_unconditional_conditioning(
|
141 |
-
batch,
|
142 |
-
batch_uc=batch_uc,
|
143 |
-
force_uc_zero_embeddings=[
|
144 |
-
"cond_frames",
|
145 |
-
"cond_frames_without_noise",
|
146 |
-
],
|
147 |
-
)
|
148 |
-
|
149 |
-
for k in ["crossattn", "concat"]:
|
150 |
-
uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
|
151 |
-
uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
|
152 |
-
c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
|
153 |
-
c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)
|
154 |
-
|
155 |
-
randn = torch.randn(shape, device=device)
|
156 |
-
|
157 |
-
additional_model_inputs = {}
|
158 |
-
additional_model_inputs["image_only_indicator"] = torch.zeros(
|
159 |
-
2, num_frames
|
160 |
-
).to(device)
|
161 |
-
additional_model_inputs["num_video_frames"] = batch["num_video_frames"]
|
162 |
-
|
163 |
-
def denoiser(input, sigma, c):
|
164 |
-
return model.denoiser(
|
165 |
-
model.model, input, sigma, c, **additional_model_inputs
|
166 |
-
)
|
167 |
-
|
168 |
-
samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
|
169 |
-
model.en_and_decode_n_samples_a_time = decoding_t
|
170 |
-
samples_x = model.decode_first_stage(samples_z)
|
171 |
-
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
|
172 |
-
|
173 |
-
os.makedirs(output_folder, exist_ok=True)
|
174 |
-
base_count = len(glob(os.path.join(output_folder, "*.mp4")))
|
175 |
-
video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
|
176 |
-
writer = cv2.VideoWriter(
|
177 |
-
video_path,
|
178 |
-
cv2.VideoWriter_fourcc(*"mp4v"),
|
179 |
-
fps_id + 1,
|
180 |
-
(samples.shape[-1], samples.shape[-2]),
|
181 |
-
)
|
182 |
-
|
183 |
-
samples = embed_watermark(samples)
|
184 |
-
samples = filter(samples)
|
185 |
-
vid = (
|
186 |
-
(rearrange(samples, "t c h w -> t h w c") * 255)
|
187 |
-
.cpu()
|
188 |
-
.numpy()
|
189 |
-
.astype(np.uint8)
|
190 |
-
)
|
191 |
-
for frame in vid:
|
192 |
-
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
193 |
-
writer.write(frame)
|
194 |
-
writer.release()
|
195 |
return video_path, seed
|
196 |
|
197 |
-
def get_unique_embedder_keys_from_conditioner(conditioner):
|
198 |
-
return list(set([x.input_key for x in conditioner.embedders]))
|
199 |
-
|
200 |
-
|
201 |
-
def get_batch(keys, value_dict, N, T, device):
|
202 |
-
batch = {}
|
203 |
-
batch_uc = {}
|
204 |
-
|
205 |
-
for key in keys:
|
206 |
-
if key == "fps_id":
|
207 |
-
batch[key] = (
|
208 |
-
torch.tensor([value_dict["fps_id"]])
|
209 |
-
.to(device)
|
210 |
-
.repeat(int(math.prod(N)))
|
211 |
-
)
|
212 |
-
elif key == "motion_bucket_id":
|
213 |
-
batch[key] = (
|
214 |
-
torch.tensor([value_dict["motion_bucket_id"]])
|
215 |
-
.to(device)
|
216 |
-
.repeat(int(math.prod(N)))
|
217 |
-
)
|
218 |
-
elif key == "cond_aug":
|
219 |
-
batch[key] = repeat(
|
220 |
-
torch.tensor([value_dict["cond_aug"]]).to(device),
|
221 |
-
"1 -> b",
|
222 |
-
b=math.prod(N),
|
223 |
-
)
|
224 |
-
elif key == "cond_frames":
|
225 |
-
batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
|
226 |
-
elif key == "cond_frames_without_noise":
|
227 |
-
batch[key] = repeat(
|
228 |
-
value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
|
229 |
-
)
|
230 |
-
else:
|
231 |
-
batch[key] = value_dict[key]
|
232 |
-
|
233 |
-
if T is not None:
|
234 |
-
batch["num_video_frames"] = T
|
235 |
-
|
236 |
-
for key in batch.keys():
|
237 |
-
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
|
238 |
-
batch_uc[key] = torch.clone(batch[key])
|
239 |
-
return batch, batch_uc
|
240 |
-
|
241 |
def resize_image(image, output_size=(1024, 576)):
|
242 |
# Calculate aspect ratios
|
243 |
target_aspect = output_size[0] / output_size[1] # Aspect ratio of the desired size
|
@@ -286,7 +102,25 @@ with gr.Blocks() as demo:
|
|
286 |
|
287 |
image.upload(fn=resize_image, inputs=image, outputs=image, queue=False)
|
288 |
generate_btn.click(fn=sample, inputs=[image, seed, randomize_seed, motion_bucket_id, fps_id], outputs=[video, seed], api_name="video")
|
289 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
290 |
if __name__ == "__main__":
|
291 |
demo.queue(max_size=20)
|
292 |
demo.launch(share=True)
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import gradio.helpers
|
3 |
+
import torch
|
4 |
import os
|
5 |
from glob import glob
|
6 |
from pathlib import Path
|
7 |
from typing import Optional
|
8 |
|
9 |
+
from diffusers import StableVideoDiffusionPipeline
|
10 |
+
from diffusers.utils import load_image, export_to_video
|
|
|
|
|
|
|
|
|
11 |
from PIL import Image
|
|
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
import uuid
|
14 |
import random
|
15 |
from huggingface_hub import hf_hub_download
|
16 |
|
17 |
+
gradio.helpers.CACHED_FOLDER = '/data/cache'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
+
pipe = StableVideoDiffusionPipeline.from_pretrained(
|
20 |
+
"stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
|
|
|
|
|
|
|
21 |
)
|
22 |
+
pipe.to("cuda")
|
23 |
+
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
24 |
+
pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=True)
|
25 |
+
|
26 |
+
max_64_bit_int = 2**63 - 1
|
27 |
|
28 |
def sample(
|
29 |
image: Image,
|
30 |
+
seed: Optional[int] = 42,
|
31 |
randomize_seed: bool = True,
|
32 |
motion_bucket_id: int = 127,
|
33 |
fps_id: int = 6,
|
34 |
version: str = "svd_xt",
|
35 |
cond_aug: float = 0.02,
|
36 |
+
decoding_t: int = 3, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
|
37 |
device: str = "cuda",
|
38 |
output_folder: str = "outputs",
|
|
|
39 |
):
|
40 |
+
if image.mode == "RGBA":
|
41 |
+
image = image.convert("RGB")
|
42 |
+
|
43 |
if(randomize_seed):
|
44 |
seed = random.randint(0, max_64_bit_int)
|
45 |
+
generator = torch.manual_seed(seed)
|
|
|
46 |
|
47 |
+
os.makedirs(output_folder, exist_ok=True)
|
48 |
+
base_count = len(glob(os.path.join(output_folder, "*.mp4")))
|
49 |
+
video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
|
50 |
|
51 |
+
frames = pipe(image, decode_chunk_size=decoding_t, generator=generator, motion_bucket_id=motion_bucket_id, noise_aug_strength=0.1).frames[0]
|
52 |
+
export_to_video(frames, video_path, fps=fps_id)
|
53 |
+
torch.manual_seed(seed)
|
54 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
return video_path, seed
|
56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
def resize_image(image, output_size=(1024, 576)):
|
58 |
# Calculate aspect ratios
|
59 |
target_aspect = output_size[0] / output_size[1] # Aspect ratio of the desired size
|
|
|
102 |
|
103 |
image.upload(fn=resize_image, inputs=image, outputs=image, queue=False)
|
104 |
generate_btn.click(fn=sample, inputs=[image, seed, randomize_seed, motion_bucket_id, fps_id], outputs=[video, seed], api_name="video")
|
105 |
+
gr.Examples(
|
106 |
+
examples=[
|
107 |
+
"images/blink_meme.png",
|
108 |
+
"images/confused2_meme.png",
|
109 |
+
"images/confused_meme.png",
|
110 |
+
"images/disaster_meme.png",
|
111 |
+
"images/distracted_meme.png",
|
112 |
+
"images/hide_meme.png",
|
113 |
+
"images/nazare_meme.png",
|
114 |
+
"images/success_meme.png",
|
115 |
+
"images/willy_meme.png",
|
116 |
+
"images/wink_meme.png"
|
117 |
+
],
|
118 |
+
inputs=image,
|
119 |
+
outputs=[video, seed],
|
120 |
+
fn=sample,
|
121 |
+
cache_examples=True,
|
122 |
+
)
|
123 |
+
|
124 |
if __name__ == "__main__":
|
125 |
demo.queue(max_size=20)
|
126 |
demo.launch(share=True)
|
Binary file (728 kB)
|
|
Git LFS Details
|
Binary file (494 kB)
|
|
Git LFS Details
|
Binary file (6.15 kB)
|
|
@@ -1,104 +0,0 @@
|
|
1 |
-
model:
|
2 |
-
base_learning_rate: 4.5e-6
|
3 |
-
target: sgm.models.autoencoder.AutoencodingEngine
|
4 |
-
params:
|
5 |
-
input_key: jpg
|
6 |
-
monitor: val/rec_loss
|
7 |
-
|
8 |
-
loss_config:
|
9 |
-
target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator
|
10 |
-
params:
|
11 |
-
perceptual_weight: 0.25
|
12 |
-
disc_start: 20001
|
13 |
-
disc_weight: 0.5
|
14 |
-
learn_logvar: True
|
15 |
-
|
16 |
-
regularization_weights:
|
17 |
-
kl_loss: 1.0
|
18 |
-
|
19 |
-
regularizer_config:
|
20 |
-
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
|
21 |
-
|
22 |
-
encoder_config:
|
23 |
-
target: sgm.modules.diffusionmodules.model.Encoder
|
24 |
-
params:
|
25 |
-
attn_type: none
|
26 |
-
double_z: True
|
27 |
-
z_channels: 4
|
28 |
-
resolution: 256
|
29 |
-
in_channels: 3
|
30 |
-
out_ch: 3
|
31 |
-
ch: 128
|
32 |
-
ch_mult: [1, 2, 4]
|
33 |
-
num_res_blocks: 4
|
34 |
-
attn_resolutions: []
|
35 |
-
dropout: 0.0
|
36 |
-
|
37 |
-
decoder_config:
|
38 |
-
target: sgm.modules.diffusionmodules.model.Decoder
|
39 |
-
params: ${model.params.encoder_config.params}
|
40 |
-
|
41 |
-
data:
|
42 |
-
target: sgm.data.dataset.StableDataModuleFromConfig
|
43 |
-
params:
|
44 |
-
train:
|
45 |
-
datapipeline:
|
46 |
-
urls:
|
47 |
-
- DATA-PATH
|
48 |
-
pipeline_config:
|
49 |
-
shardshuffle: 10000
|
50 |
-
sample_shuffle: 10000
|
51 |
-
|
52 |
-
decoders:
|
53 |
-
- pil
|
54 |
-
|
55 |
-
postprocessors:
|
56 |
-
- target: sdata.mappers.TorchVisionImageTransforms
|
57 |
-
params:
|
58 |
-
key: jpg
|
59 |
-
transforms:
|
60 |
-
- target: torchvision.transforms.Resize
|
61 |
-
params:
|
62 |
-
size: 256
|
63 |
-
interpolation: 3
|
64 |
-
- target: torchvision.transforms.ToTensor
|
65 |
-
- target: sdata.mappers.Rescaler
|
66 |
-
- target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
|
67 |
-
params:
|
68 |
-
h_key: height
|
69 |
-
w_key: width
|
70 |
-
|
71 |
-
loader:
|
72 |
-
batch_size: 8
|
73 |
-
num_workers: 4
|
74 |
-
|
75 |
-
|
76 |
-
lightning:
|
77 |
-
strategy:
|
78 |
-
target: pytorch_lightning.strategies.DDPStrategy
|
79 |
-
params:
|
80 |
-
find_unused_parameters: True
|
81 |
-
|
82 |
-
modelcheckpoint:
|
83 |
-
params:
|
84 |
-
every_n_train_steps: 5000
|
85 |
-
|
86 |
-
callbacks:
|
87 |
-
metrics_over_trainsteps_checkpoint:
|
88 |
-
params:
|
89 |
-
every_n_train_steps: 50000
|
90 |
-
|
91 |
-
image_logger:
|
92 |
-
target: main.ImageLogger
|
93 |
-
params:
|
94 |
-
enable_autocast: False
|
95 |
-
batch_frequency: 1000
|
96 |
-
max_images: 8
|
97 |
-
increase_log_steps: True
|
98 |
-
|
99 |
-
trainer:
|
100 |
-
devices: 0,
|
101 |
-
limit_val_batches: 50
|
102 |
-
benchmark: True
|
103 |
-
accumulate_grad_batches: 1
|
104 |
-
val_check_interval: 10000
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,105 +0,0 @@
|
|
1 |
-
model:
|
2 |
-
base_learning_rate: 4.5e-6
|
3 |
-
target: sgm.models.autoencoder.AutoencodingEngine
|
4 |
-
params:
|
5 |
-
input_key: jpg
|
6 |
-
monitor: val/loss/rec
|
7 |
-
disc_start_iter: 0
|
8 |
-
|
9 |
-
encoder_config:
|
10 |
-
target: sgm.modules.diffusionmodules.model.Encoder
|
11 |
-
params:
|
12 |
-
attn_type: vanilla-xformers
|
13 |
-
double_z: true
|
14 |
-
z_channels: 8
|
15 |
-
resolution: 256
|
16 |
-
in_channels: 3
|
17 |
-
out_ch: 3
|
18 |
-
ch: 128
|
19 |
-
ch_mult: [1, 2, 4, 4]
|
20 |
-
num_res_blocks: 2
|
21 |
-
attn_resolutions: []
|
22 |
-
dropout: 0.0
|
23 |
-
|
24 |
-
decoder_config:
|
25 |
-
target: sgm.modules.diffusionmodules.model.Decoder
|
26 |
-
params: ${model.params.encoder_config.params}
|
27 |
-
|
28 |
-
regularizer_config:
|
29 |
-
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
|
30 |
-
|
31 |
-
loss_config:
|
32 |
-
target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator
|
33 |
-
params:
|
34 |
-
perceptual_weight: 0.25
|
35 |
-
disc_start: 20001
|
36 |
-
disc_weight: 0.5
|
37 |
-
learn_logvar: True
|
38 |
-
|
39 |
-
regularization_weights:
|
40 |
-
kl_loss: 1.0
|
41 |
-
|
42 |
-
data:
|
43 |
-
target: sgm.data.dataset.StableDataModuleFromConfig
|
44 |
-
params:
|
45 |
-
train:
|
46 |
-
datapipeline:
|
47 |
-
urls:
|
48 |
-
- DATA-PATH
|
49 |
-
pipeline_config:
|
50 |
-
shardshuffle: 10000
|
51 |
-
sample_shuffle: 10000
|
52 |
-
|
53 |
-
decoders:
|
54 |
-
- pil
|
55 |
-
|
56 |
-
postprocessors:
|
57 |
-
- target: sdata.mappers.TorchVisionImageTransforms
|
58 |
-
params:
|
59 |
-
key: jpg
|
60 |
-
transforms:
|
61 |
-
- target: torchvision.transforms.Resize
|
62 |
-
params:
|
63 |
-
size: 256
|
64 |
-
interpolation: 3
|
65 |
-
- target: torchvision.transforms.ToTensor
|
66 |
-
- target: sdata.mappers.Rescaler
|
67 |
-
- target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
|
68 |
-
params:
|
69 |
-
h_key: height
|
70 |
-
w_key: width
|
71 |
-
|
72 |
-
loader:
|
73 |
-
batch_size: 8
|
74 |
-
num_workers: 4
|
75 |
-
|
76 |
-
|
77 |
-
lightning:
|
78 |
-
strategy:
|
79 |
-
target: pytorch_lightning.strategies.DDPStrategy
|
80 |
-
params:
|
81 |
-
find_unused_parameters: True
|
82 |
-
|
83 |
-
modelcheckpoint:
|
84 |
-
params:
|
85 |
-
every_n_train_steps: 5000
|
86 |
-
|
87 |
-
callbacks:
|
88 |
-
metrics_over_trainsteps_checkpoint:
|
89 |
-
params:
|
90 |
-
every_n_train_steps: 50000
|
91 |
-
|
92 |
-
image_logger:
|
93 |
-
target: main.ImageLogger
|
94 |
-
params:
|
95 |
-
enable_autocast: False
|
96 |
-
batch_frequency: 1000
|
97 |
-
max_images: 8
|
98 |
-
increase_log_steps: True
|
99 |
-
|
100 |
-
trainer:
|
101 |
-
devices: 0,
|
102 |
-
limit_val_batches: 50
|
103 |
-
benchmark: True
|
104 |
-
accumulate_grad_batches: 1
|
105 |
-
val_check_interval: 10000
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,185 +0,0 @@
|
|
1 |
-
model:
|
2 |
-
base_learning_rate: 1.0e-4
|
3 |
-
target: sgm.models.diffusion.DiffusionEngine
|
4 |
-
params:
|
5 |
-
scale_factor: 0.13025
|
6 |
-
disable_first_stage_autocast: True
|
7 |
-
log_keys:
|
8 |
-
- cls
|
9 |
-
|
10 |
-
scheduler_config:
|
11 |
-
target: sgm.lr_scheduler.LambdaLinearScheduler
|
12 |
-
params:
|
13 |
-
warm_up_steps: [10000]
|
14 |
-
cycle_lengths: [10000000000000]
|
15 |
-
f_start: [1.e-6]
|
16 |
-
f_max: [1.]
|
17 |
-
f_min: [1.]
|
18 |
-
|
19 |
-
denoiser_config:
|
20 |
-
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
|
21 |
-
params:
|
22 |
-
num_idx: 1000
|
23 |
-
|
24 |
-
scaling_config:
|
25 |
-
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
|
26 |
-
discretization_config:
|
27 |
-
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
|
28 |
-
|
29 |
-
network_config:
|
30 |
-
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
|
31 |
-
params:
|
32 |
-
use_checkpoint: True
|
33 |
-
in_channels: 4
|
34 |
-
out_channels: 4
|
35 |
-
model_channels: 256
|
36 |
-
attention_resolutions: [1, 2, 4]
|
37 |
-
num_res_blocks: 2
|
38 |
-
channel_mult: [1, 2, 4]
|
39 |
-
num_head_channels: 64
|
40 |
-
num_classes: sequential
|
41 |
-
adm_in_channels: 1024
|
42 |
-
transformer_depth: 1
|
43 |
-
context_dim: 1024
|
44 |
-
spatial_transformer_attn_type: softmax-xformers
|
45 |
-
|
46 |
-
conditioner_config:
|
47 |
-
target: sgm.modules.GeneralConditioner
|
48 |
-
params:
|
49 |
-
emb_models:
|
50 |
-
- is_trainable: True
|
51 |
-
input_key: cls
|
52 |
-
ucg_rate: 0.2
|
53 |
-
target: sgm.modules.encoders.modules.ClassEmbedder
|
54 |
-
params:
|
55 |
-
add_sequence_dim: True
|
56 |
-
embed_dim: 1024
|
57 |
-
n_classes: 1000
|
58 |
-
|
59 |
-
- is_trainable: False
|
60 |
-
ucg_rate: 0.2
|
61 |
-
input_key: original_size_as_tuple
|
62 |
-
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
63 |
-
params:
|
64 |
-
outdim: 256
|
65 |
-
|
66 |
-
- is_trainable: False
|
67 |
-
input_key: crop_coords_top_left
|
68 |
-
ucg_rate: 0.2
|
69 |
-
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
70 |
-
params:
|
71 |
-
outdim: 256
|
72 |
-
|
73 |
-
first_stage_config:
|
74 |
-
target: sgm.models.autoencoder.AutoencoderKL
|
75 |
-
params:
|
76 |
-
ckpt_path: CKPT_PATH
|
77 |
-
embed_dim: 4
|
78 |
-
monitor: val/rec_loss
|
79 |
-
ddconfig:
|
80 |
-
attn_type: vanilla-xformers
|
81 |
-
double_z: true
|
82 |
-
z_channels: 4
|
83 |
-
resolution: 256
|
84 |
-
in_channels: 3
|
85 |
-
out_ch: 3
|
86 |
-
ch: 128
|
87 |
-
ch_mult: [1, 2, 4, 4]
|
88 |
-
num_res_blocks: 2
|
89 |
-
attn_resolutions: []
|
90 |
-
dropout: 0.0
|
91 |
-
lossconfig:
|
92 |
-
target: torch.nn.Identity
|
93 |
-
|
94 |
-
loss_fn_config:
|
95 |
-
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
|
96 |
-
params:
|
97 |
-
loss_weighting_config:
|
98 |
-
target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting
|
99 |
-
sigma_sampler_config:
|
100 |
-
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
|
101 |
-
params:
|
102 |
-
num_idx: 1000
|
103 |
-
|
104 |
-
discretization_config:
|
105 |
-
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
|
106 |
-
|
107 |
-
sampler_config:
|
108 |
-
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
|
109 |
-
params:
|
110 |
-
num_steps: 50
|
111 |
-
|
112 |
-
discretization_config:
|
113 |
-
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
|
114 |
-
|
115 |
-
guider_config:
|
116 |
-
target: sgm.modules.diffusionmodules.guiders.VanillaCFG
|
117 |
-
params:
|
118 |
-
scale: 5.0
|
119 |
-
|
120 |
-
data:
|
121 |
-
target: sgm.data.dataset.StableDataModuleFromConfig
|
122 |
-
params:
|
123 |
-
train:
|
124 |
-
datapipeline:
|
125 |
-
urls:
|
126 |
-
# USER: adapt this path the root of your custom dataset
|
127 |
-
- DATA_PATH
|
128 |
-
pipeline_config:
|
129 |
-
shardshuffle: 10000
|
130 |
-
sample_shuffle: 10000 # USER: you might wanna adapt depending on your available RAM
|
131 |
-
|
132 |
-
decoders:
|
133 |
-
- pil
|
134 |
-
|
135 |
-
postprocessors:
|
136 |
-
- target: sdata.mappers.TorchVisionImageTransforms
|
137 |
-
params:
|
138 |
-
key: jpg # USER: you might wanna adapt this for your custom dataset
|
139 |
-
transforms:
|
140 |
-
- target: torchvision.transforms.Resize
|
141 |
-
params:
|
142 |
-
size: 256
|
143 |
-
interpolation: 3
|
144 |
-
- target: torchvision.transforms.ToTensor
|
145 |
-
- target: sdata.mappers.Rescaler
|
146 |
-
|
147 |
-
- target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
|
148 |
-
params:
|
149 |
-
h_key: height # USER: you might wanna adapt this for your custom dataset
|
150 |
-
w_key: width # USER: you might wanna adapt this for your custom dataset
|
151 |
-
|
152 |
-
loader:
|
153 |
-
batch_size: 64
|
154 |
-
num_workers: 6
|
155 |
-
|
156 |
-
lightning:
|
157 |
-
modelcheckpoint:
|
158 |
-
params:
|
159 |
-
every_n_train_steps: 5000
|
160 |
-
|
161 |
-
callbacks:
|
162 |
-
metrics_over_trainsteps_checkpoint:
|
163 |
-
params:
|
164 |
-
every_n_train_steps: 25000
|
165 |
-
|
166 |
-
image_logger:
|
167 |
-
target: main.ImageLogger
|
168 |
-
params:
|
169 |
-
disabled: False
|
170 |
-
enable_autocast: False
|
171 |
-
batch_frequency: 1000
|
172 |
-
max_images: 8
|
173 |
-
increase_log_steps: True
|
174 |
-
log_first_step: False
|
175 |
-
log_images_kwargs:
|
176 |
-
use_ema_scope: False
|
177 |
-
N: 8
|
178 |
-
n_rows: 2
|
179 |
-
|
180 |
-
trainer:
|
181 |
-
devices: 0,
|
182 |
-
benchmark: True
|
183 |
-
num_sanity_val_steps: 0
|
184 |
-
accumulate_grad_batches: 1
|
185 |
-
max_epochs: 1000
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,98 +0,0 @@
|
|
1 |
-
model:
|
2 |
-
base_learning_rate: 1.0e-4
|
3 |
-
target: sgm.models.diffusion.DiffusionEngine
|
4 |
-
params:
|
5 |
-
denoiser_config:
|
6 |
-
target: sgm.modules.diffusionmodules.denoiser.Denoiser
|
7 |
-
params:
|
8 |
-
scaling_config:
|
9 |
-
target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
|
10 |
-
params:
|
11 |
-
sigma_data: 1.0
|
12 |
-
|
13 |
-
network_config:
|
14 |
-
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
|
15 |
-
params:
|
16 |
-
in_channels: 3
|
17 |
-
out_channels: 3
|
18 |
-
model_channels: 32
|
19 |
-
attention_resolutions: []
|
20 |
-
num_res_blocks: 4
|
21 |
-
channel_mult: [1, 2, 2]
|
22 |
-
num_head_channels: 32
|
23 |
-
num_classes: sequential
|
24 |
-
adm_in_channels: 128
|
25 |
-
|
26 |
-
conditioner_config:
|
27 |
-
target: sgm.modules.GeneralConditioner
|
28 |
-
params:
|
29 |
-
emb_models:
|
30 |
-
- is_trainable: True
|
31 |
-
input_key: cls
|
32 |
-
ucg_rate: 0.2
|
33 |
-
target: sgm.modules.encoders.modules.ClassEmbedder
|
34 |
-
params:
|
35 |
-
embed_dim: 128
|
36 |
-
n_classes: 10
|
37 |
-
|
38 |
-
first_stage_config:
|
39 |
-
target: sgm.models.autoencoder.IdentityFirstStage
|
40 |
-
|
41 |
-
loss_fn_config:
|
42 |
-
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
|
43 |
-
params:
|
44 |
-
loss_weighting_config:
|
45 |
-
target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
|
46 |
-
params:
|
47 |
-
sigma_data: 1.0
|
48 |
-
sigma_sampler_config:
|
49 |
-
target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
|
50 |
-
|
51 |
-
sampler_config:
|
52 |
-
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
|
53 |
-
params:
|
54 |
-
num_steps: 50
|
55 |
-
|
56 |
-
discretization_config:
|
57 |
-
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
|
58 |
-
|
59 |
-
guider_config:
|
60 |
-
target: sgm.modules.diffusionmodules.guiders.VanillaCFG
|
61 |
-
params:
|
62 |
-
scale: 3.0
|
63 |
-
|
64 |
-
data:
|
65 |
-
target: sgm.data.cifar10.CIFAR10Loader
|
66 |
-
params:
|
67 |
-
batch_size: 512
|
68 |
-
num_workers: 1
|
69 |
-
|
70 |
-
lightning:
|
71 |
-
modelcheckpoint:
|
72 |
-
params:
|
73 |
-
every_n_train_steps: 5000
|
74 |
-
|
75 |
-
callbacks:
|
76 |
-
metrics_over_trainsteps_checkpoint:
|
77 |
-
params:
|
78 |
-
every_n_train_steps: 25000
|
79 |
-
|
80 |
-
image_logger:
|
81 |
-
target: main.ImageLogger
|
82 |
-
params:
|
83 |
-
disabled: False
|
84 |
-
batch_frequency: 1000
|
85 |
-
max_images: 64
|
86 |
-
increase_log_steps: True
|
87 |
-
log_first_step: False
|
88 |
-
log_images_kwargs:
|
89 |
-
use_ema_scope: False
|
90 |
-
N: 64
|
91 |
-
n_rows: 8
|
92 |
-
|
93 |
-
trainer:
|
94 |
-
devices: 0,
|
95 |
-
benchmark: True
|
96 |
-
num_sanity_val_steps: 0
|
97 |
-
accumulate_grad_batches: 1
|
98 |
-
max_epochs: 20
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,79 +0,0 @@
|
|
1 |
-
model:
|
2 |
-
base_learning_rate: 1.0e-4
|
3 |
-
target: sgm.models.diffusion.DiffusionEngine
|
4 |
-
params:
|
5 |
-
denoiser_config:
|
6 |
-
target: sgm.modules.diffusionmodules.denoiser.Denoiser
|
7 |
-
params:
|
8 |
-
scaling_config:
|
9 |
-
target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
|
10 |
-
params:
|
11 |
-
sigma_data: 1.0
|
12 |
-
|
13 |
-
network_config:
|
14 |
-
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
|
15 |
-
params:
|
16 |
-
in_channels: 1
|
17 |
-
out_channels: 1
|
18 |
-
model_channels: 32
|
19 |
-
attention_resolutions: []
|
20 |
-
num_res_blocks: 4
|
21 |
-
channel_mult: [1, 2, 2]
|
22 |
-
num_head_channels: 32
|
23 |
-
|
24 |
-
first_stage_config:
|
25 |
-
target: sgm.models.autoencoder.IdentityFirstStage
|
26 |
-
|
27 |
-
loss_fn_config:
|
28 |
-
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
|
29 |
-
params:
|
30 |
-
loss_weighting_config:
|
31 |
-
target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
|
32 |
-
params:
|
33 |
-
sigma_data: 1.0
|
34 |
-
sigma_sampler_config:
|
35 |
-
target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
|
36 |
-
|
37 |
-
sampler_config:
|
38 |
-
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
|
39 |
-
params:
|
40 |
-
num_steps: 50
|
41 |
-
|
42 |
-
discretization_config:
|
43 |
-
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
|
44 |
-
|
45 |
-
data:
|
46 |
-
target: sgm.data.mnist.MNISTLoader
|
47 |
-
params:
|
48 |
-
batch_size: 512
|
49 |
-
num_workers: 1
|
50 |
-
|
51 |
-
lightning:
|
52 |
-
modelcheckpoint:
|
53 |
-
params:
|
54 |
-
every_n_train_steps: 5000
|
55 |
-
|
56 |
-
callbacks:
|
57 |
-
metrics_over_trainsteps_checkpoint:
|
58 |
-
params:
|
59 |
-
every_n_train_steps: 25000
|
60 |
-
|
61 |
-
image_logger:
|
62 |
-
target: main.ImageLogger
|
63 |
-
params:
|
64 |
-
disabled: False
|
65 |
-
batch_frequency: 1000
|
66 |
-
max_images: 64
|
67 |
-
increase_log_steps: False
|
68 |
-
log_first_step: False
|
69 |
-
log_images_kwargs:
|
70 |
-
use_ema_scope: False
|
71 |
-
N: 64
|
72 |
-
n_rows: 8
|
73 |
-
|
74 |
-
trainer:
|
75 |
-
devices: 0,
|
76 |
-
benchmark: True
|
77 |
-
num_sanity_val_steps: 0
|
78 |
-
accumulate_grad_batches: 1
|
79 |
-
max_epochs: 10
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,98 +0,0 @@
|
|
1 |
-
model:
|
2 |
-
base_learning_rate: 1.0e-4
|
3 |
-
target: sgm.models.diffusion.DiffusionEngine
|
4 |
-
params:
|
5 |
-
denoiser_config:
|
6 |
-
target: sgm.modules.diffusionmodules.denoiser.Denoiser
|
7 |
-
params:
|
8 |
-
scaling_config:
|
9 |
-
target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
|
10 |
-
params:
|
11 |
-
sigma_data: 1.0
|
12 |
-
|
13 |
-
network_config:
|
14 |
-
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
|
15 |
-
params:
|
16 |
-
in_channels: 1
|
17 |
-
out_channels: 1
|
18 |
-
model_channels: 32
|
19 |
-
attention_resolutions: []
|
20 |
-
num_res_blocks: 4
|
21 |
-
channel_mult: [1, 2, 2]
|
22 |
-
num_head_channels: 32
|
23 |
-
num_classes: sequential
|
24 |
-
adm_in_channels: 128
|
25 |
-
|
26 |
-
conditioner_config:
|
27 |
-
target: sgm.modules.GeneralConditioner
|
28 |
-
params:
|
29 |
-
emb_models:
|
30 |
-
- is_trainable: True
|
31 |
-
input_key: cls
|
32 |
-
ucg_rate: 0.2
|
33 |
-
target: sgm.modules.encoders.modules.ClassEmbedder
|
34 |
-
params:
|
35 |
-
embed_dim: 128
|
36 |
-
n_classes: 10
|
37 |
-
|
38 |
-
first_stage_config:
|
39 |
-
target: sgm.models.autoencoder.IdentityFirstStage
|
40 |
-
|
41 |
-
loss_fn_config:
|
42 |
-
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
|
43 |
-
params:
|
44 |
-
loss_weighting_config:
|
45 |
-
target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
|
46 |
-
params:
|
47 |
-
sigma_data: 1.0
|
48 |
-
sigma_sampler_config:
|
49 |
-
target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
|
50 |
-
|
51 |
-
sampler_config:
|
52 |
-
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
|
53 |
-
params:
|
54 |
-
num_steps: 50
|
55 |
-
|
56 |
-
discretization_config:
|
57 |
-
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
|
58 |
-
|
59 |
-
guider_config:
|
60 |
-
target: sgm.modules.diffusionmodules.guiders.VanillaCFG
|
61 |
-
params:
|
62 |
-
scale: 3.0
|
63 |
-
|
64 |
-
data:
|
65 |
-
target: sgm.data.mnist.MNISTLoader
|
66 |
-
params:
|
67 |
-
batch_size: 512
|
68 |
-
num_workers: 1
|
69 |
-
|
70 |
-
lightning:
|
71 |
-
modelcheckpoint:
|
72 |
-
params:
|
73 |
-
every_n_train_steps: 5000
|
74 |
-
|
75 |
-
callbacks:
|
76 |
-
metrics_over_trainsteps_checkpoint:
|
77 |
-
params:
|
78 |
-
every_n_train_steps: 25000
|
79 |
-
|
80 |
-
image_logger:
|
81 |
-
target: main.ImageLogger
|
82 |
-
params:
|
83 |
-
disabled: False
|
84 |
-
batch_frequency: 1000
|
85 |
-
max_images: 16
|
86 |
-
increase_log_steps: True
|
87 |
-
log_first_step: False
|
88 |
-
log_images_kwargs:
|
89 |
-
use_ema_scope: False
|
90 |
-
N: 16
|
91 |
-
n_rows: 4
|
92 |
-
|
93 |
-
trainer:
|
94 |
-
devices: 0,
|
95 |
-
benchmark: True
|
96 |
-
num_sanity_val_steps: 0
|
97 |
-
accumulate_grad_batches: 1
|
98 |
-
max_epochs: 20
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,103 +0,0 @@
|
|
1 |
-
model:
|
2 |
-
base_learning_rate: 1.0e-4
|
3 |
-
target: sgm.models.diffusion.DiffusionEngine
|
4 |
-
params:
|
5 |
-
denoiser_config:
|
6 |
-
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
|
7 |
-
params:
|
8 |
-
num_idx: 1000
|
9 |
-
|
10 |
-
scaling_config:
|
11 |
-
target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
|
12 |
-
discretization_config:
|
13 |
-
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
|
14 |
-
|
15 |
-
network_config:
|
16 |
-
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
|
17 |
-
params:
|
18 |
-
in_channels: 1
|
19 |
-
out_channels: 1
|
20 |
-
model_channels: 32
|
21 |
-
attention_resolutions: []
|
22 |
-
num_res_blocks: 4
|
23 |
-
channel_mult: [1, 2, 2]
|
24 |
-
num_head_channels: 32
|
25 |
-
num_classes: sequential
|
26 |
-
adm_in_channels: 128
|
27 |
-
|
28 |
-
conditioner_config:
|
29 |
-
target: sgm.modules.GeneralConditioner
|
30 |
-
params:
|
31 |
-
emb_models:
|
32 |
-
- is_trainable: True
|
33 |
-
input_key: cls
|
34 |
-
ucg_rate: 0.2
|
35 |
-
target: sgm.modules.encoders.modules.ClassEmbedder
|
36 |
-
params:
|
37 |
-
embed_dim: 128
|
38 |
-
n_classes: 10
|
39 |
-
|
40 |
-
first_stage_config:
|
41 |
-
target: sgm.models.autoencoder.IdentityFirstStage
|
42 |
-
|
43 |
-
loss_fn_config:
|
44 |
-
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
|
45 |
-
params:
|
46 |
-
loss_weighting_config:
|
47 |
-
target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
|
48 |
-
sigma_sampler_config:
|
49 |
-
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
|
50 |
-
params:
|
51 |
-
num_idx: 1000
|
52 |
-
|
53 |
-
discretization_config:
|
54 |
-
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
|
55 |
-
|
56 |
-
sampler_config:
|
57 |
-
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
|
58 |
-
params:
|
59 |
-
num_steps: 50
|
60 |
-
|
61 |
-
discretization_config:
|
62 |
-
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
|
63 |
-
|
64 |
-
guider_config:
|
65 |
-
target: sgm.modules.diffusionmodules.guiders.VanillaCFG
|
66 |
-
params:
|
67 |
-
scale: 5.0
|
68 |
-
|
69 |
-
data:
|
70 |
-
target: sgm.data.mnist.MNISTLoader
|
71 |
-
params:
|
72 |
-
batch_size: 512
|
73 |
-
num_workers: 1
|
74 |
-
|
75 |
-
lightning:
|
76 |
-
modelcheckpoint:
|
77 |
-
params:
|
78 |
-
every_n_train_steps: 5000
|
79 |
-
|
80 |
-
callbacks:
|
81 |
-
metrics_over_trainsteps_checkpoint:
|
82 |
-
params:
|
83 |
-
every_n_train_steps: 25000
|
84 |
-
|
85 |
-
image_logger:
|
86 |
-
target: main.ImageLogger
|
87 |
-
params:
|
88 |
-
disabled: False
|
89 |
-
batch_frequency: 1000
|
90 |
-
max_images: 16
|
91 |
-
increase_log_steps: True
|
92 |
-
log_first_step: False
|
93 |
-
log_images_kwargs:
|
94 |
-
use_ema_scope: False
|
95 |
-
N: 16
|
96 |
-
n_rows: 4
|
97 |
-
|
98 |
-
trainer:
|
99 |
-
devices: 0,
|
100 |
-
benchmark: True
|
101 |
-
num_sanity_val_steps: 0
|
102 |
-
accumulate_grad_batches: 1
|
103 |
-
max_epochs: 20
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,99 +0,0 @@
|
|
1 |
-
model:
|
2 |
-
base_learning_rate: 1.0e-4
|
3 |
-
target: sgm.models.diffusion.DiffusionEngine
|
4 |
-
params:
|
5 |
-
denoiser_config:
|
6 |
-
target: sgm.modules.diffusionmodules.denoiser.Denoiser
|
7 |
-
params:
|
8 |
-
scaling_config:
|
9 |
-
target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
|
10 |
-
params:
|
11 |
-
sigma_data: 1.0
|
12 |
-
|
13 |
-
network_config:
|
14 |
-
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
|
15 |
-
params:
|
16 |
-
in_channels: 1
|
17 |
-
out_channels: 1
|
18 |
-
model_channels: 32
|
19 |
-
attention_resolutions: []
|
20 |
-
num_res_blocks: 4
|
21 |
-
channel_mult: [1, 2, 2]
|
22 |
-
num_head_channels: 32
|
23 |
-
num_classes: sequential
|
24 |
-
adm_in_channels: 128
|
25 |
-
|
26 |
-
conditioner_config:
|
27 |
-
target: sgm.modules.GeneralConditioner
|
28 |
-
params:
|
29 |
-
emb_models:
|
30 |
-
- is_trainable: True
|
31 |
-
input_key: cls
|
32 |
-
ucg_rate: 0.2
|
33 |
-
target: sgm.modules.encoders.modules.ClassEmbedder
|
34 |
-
params:
|
35 |
-
embed_dim: 128
|
36 |
-
n_classes: 10
|
37 |
-
|
38 |
-
first_stage_config:
|
39 |
-
target: sgm.models.autoencoder.IdentityFirstStage
|
40 |
-
|
41 |
-
loss_fn_config:
|
42 |
-
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
|
43 |
-
params:
|
44 |
-
loss_type: l1
|
45 |
-
loss_weighting_config:
|
46 |
-
target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
|
47 |
-
params:
|
48 |
-
sigma_data: 1.0
|
49 |
-
sigma_sampler_config:
|
50 |
-
target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
|
51 |
-
|
52 |
-
sampler_config:
|
53 |
-
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
|
54 |
-
params:
|
55 |
-
num_steps: 50
|
56 |
-
|
57 |
-
discretization_config:
|
58 |
-
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
|
59 |
-
|
60 |
-
guider_config:
|
61 |
-
target: sgm.modules.diffusionmodules.guiders.VanillaCFG
|
62 |
-
params:
|
63 |
-
scale: 3.0
|
64 |
-
|
65 |
-
data:
|
66 |
-
target: sgm.data.mnist.MNISTLoader
|
67 |
-
params:
|
68 |
-
batch_size: 512
|
69 |
-
num_workers: 1
|
70 |
-
|
71 |
-
lightning:
|
72 |
-
modelcheckpoint:
|
73 |
-
params:
|
74 |
-
every_n_train_steps: 5000
|
75 |
-
|
76 |
-
callbacks:
|
77 |
-
metrics_over_trainsteps_checkpoint:
|
78 |
-
params:
|
79 |
-
every_n_train_steps: 25000
|
80 |
-
|
81 |
-
image_logger:
|
82 |
-
target: main.ImageLogger
|
83 |
-
params:
|
84 |
-
disabled: False
|
85 |
-
batch_frequency: 1000
|
86 |
-
max_images: 64
|
87 |
-
increase_log_steps: True
|
88 |
-
log_first_step: False
|
89 |
-
log_images_kwargs:
|
90 |
-
use_ema_scope: False
|
91 |
-
N: 64
|
92 |
-
n_rows: 8
|
93 |
-
|
94 |
-
trainer:
|
95 |
-
devices: 0,
|
96 |
-
benchmark: True
|
97 |
-
num_sanity_val_steps: 0
|
98 |
-
accumulate_grad_batches: 1
|
99 |
-
max_epochs: 20
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,100 +0,0 @@
|
|
1 |
-
model:
|
2 |
-
base_learning_rate: 1.0e-4
|
3 |
-
target: sgm.models.diffusion.DiffusionEngine
|
4 |
-
params:
|
5 |
-
use_ema: True
|
6 |
-
|
7 |
-
denoiser_config:
|
8 |
-
target: sgm.modules.diffusionmodules.denoiser.Denoiser
|
9 |
-
params:
|
10 |
-
scaling_config:
|
11 |
-
target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
|
12 |
-
params:
|
13 |
-
sigma_data: 1.0
|
14 |
-
|
15 |
-
network_config:
|
16 |
-
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
|
17 |
-
params:
|
18 |
-
in_channels: 1
|
19 |
-
out_channels: 1
|
20 |
-
model_channels: 32
|
21 |
-
attention_resolutions: []
|
22 |
-
num_res_blocks: 4
|
23 |
-
channel_mult: [1, 2, 2]
|
24 |
-
num_head_channels: 32
|
25 |
-
num_classes: sequential
|
26 |
-
adm_in_channels: 128
|
27 |
-
|
28 |
-
conditioner_config:
|
29 |
-
target: sgm.modules.GeneralConditioner
|
30 |
-
params:
|
31 |
-
emb_models:
|
32 |
-
- is_trainable: True
|
33 |
-
input_key: cls
|
34 |
-
ucg_rate: 0.2
|
35 |
-
target: sgm.modules.encoders.modules.ClassEmbedder
|
36 |
-
params:
|
37 |
-
embed_dim: 128
|
38 |
-
n_classes: 10
|
39 |
-
|
40 |
-
first_stage_config:
|
41 |
-
target: sgm.models.autoencoder.IdentityFirstStage
|
42 |
-
|
43 |
-
loss_fn_config:
|
44 |
-
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
|
45 |
-
params:
|
46 |
-
loss_weighting_config:
|
47 |
-
target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
|
48 |
-
params:
|
49 |
-
sigma_data: 1.0
|
50 |
-
sigma_sampler_config:
|
51 |
-
target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
|
52 |
-
|
53 |
-
sampler_config:
|
54 |
-
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
|
55 |
-
params:
|
56 |
-
num_steps: 50
|
57 |
-
|
58 |
-
discretization_config:
|
59 |
-
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
|
60 |
-
|
61 |
-
guider_config:
|
62 |
-
target: sgm.modules.diffusionmodules.guiders.VanillaCFG
|
63 |
-
params:
|
64 |
-
scale: 3.0
|
65 |
-
|
66 |
-
data:
|
67 |
-
target: sgm.data.mnist.MNISTLoader
|
68 |
-
params:
|
69 |
-
batch_size: 512
|
70 |
-
num_workers: 1
|
71 |
-
|
72 |
-
lightning:
|
73 |
-
modelcheckpoint:
|
74 |
-
params:
|
75 |
-
every_n_train_steps: 5000
|
76 |
-
|
77 |
-
callbacks:
|
78 |
-
metrics_over_trainsteps_checkpoint:
|
79 |
-
params:
|
80 |
-
every_n_train_steps: 25000
|
81 |
-
|
82 |
-
image_logger:
|
83 |
-
target: main.ImageLogger
|
84 |
-
params:
|
85 |
-
disabled: False
|
86 |
-
batch_frequency: 1000
|
87 |
-
max_images: 64
|
88 |
-
increase_log_steps: True
|
89 |
-
log_first_step: False
|
90 |
-
log_images_kwargs:
|
91 |
-
use_ema_scope: False
|
92 |
-
N: 64
|
93 |
-
n_rows: 8
|
94 |
-
|
95 |
-
trainer:
|
96 |
-
devices: 0,
|
97 |
-
benchmark: True
|
98 |
-
num_sanity_val_steps: 0
|
99 |
-
accumulate_grad_batches: 1
|
100 |
-
max_epochs: 20
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,182 +0,0 @@
|
|
1 |
-
model:
|
2 |
-
base_learning_rate: 1.0e-4
|
3 |
-
target: sgm.models.diffusion.DiffusionEngine
|
4 |
-
params:
|
5 |
-
scale_factor: 0.13025
|
6 |
-
disable_first_stage_autocast: True
|
7 |
-
log_keys:
|
8 |
-
- txt
|
9 |
-
|
10 |
-
scheduler_config:
|
11 |
-
target: sgm.lr_scheduler.LambdaLinearScheduler
|
12 |
-
params:
|
13 |
-
warm_up_steps: [10000]
|
14 |
-
cycle_lengths: [10000000000000]
|
15 |
-
f_start: [1.e-6]
|
16 |
-
f_max: [1.]
|
17 |
-
f_min: [1.]
|
18 |
-
|
19 |
-
denoiser_config:
|
20 |
-
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
|
21 |
-
params:
|
22 |
-
num_idx: 1000
|
23 |
-
|
24 |
-
scaling_config:
|
25 |
-
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
|
26 |
-
discretization_config:
|
27 |
-
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
|
28 |
-
|
29 |
-
network_config:
|
30 |
-
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
|
31 |
-
params:
|
32 |
-
use_checkpoint: True
|
33 |
-
in_channels: 4
|
34 |
-
out_channels: 4
|
35 |
-
model_channels: 320
|
36 |
-
attention_resolutions: [1, 2, 4]
|
37 |
-
num_res_blocks: 2
|
38 |
-
channel_mult: [1, 2, 4, 4]
|
39 |
-
num_head_channels: 64
|
40 |
-
num_classes: sequential
|
41 |
-
adm_in_channels: 1792
|
42 |
-
num_heads: 1
|
43 |
-
transformer_depth: 1
|
44 |
-
context_dim: 768
|
45 |
-
spatial_transformer_attn_type: softmax-xformers
|
46 |
-
|
47 |
-
conditioner_config:
|
48 |
-
target: sgm.modules.GeneralConditioner
|
49 |
-
params:
|
50 |
-
emb_models:
|
51 |
-
- is_trainable: True
|
52 |
-
input_key: txt
|
53 |
-
ucg_rate: 0.1
|
54 |
-
legacy_ucg_value: ""
|
55 |
-
target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
|
56 |
-
params:
|
57 |
-
always_return_pooled: True
|
58 |
-
|
59 |
-
- is_trainable: False
|
60 |
-
ucg_rate: 0.1
|
61 |
-
input_key: original_size_as_tuple
|
62 |
-
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
63 |
-
params:
|
64 |
-
outdim: 256
|
65 |
-
|
66 |
-
- is_trainable: False
|
67 |
-
input_key: crop_coords_top_left
|
68 |
-
ucg_rate: 0.1
|
69 |
-
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
70 |
-
params:
|
71 |
-
outdim: 256
|
72 |
-
|
73 |
-
first_stage_config:
|
74 |
-
target: sgm.models.autoencoder.AutoencoderKL
|
75 |
-
params:
|
76 |
-
ckpt_path: CKPT_PATH
|
77 |
-
embed_dim: 4
|
78 |
-
monitor: val/rec_loss
|
79 |
-
ddconfig:
|
80 |
-
attn_type: vanilla-xformers
|
81 |
-
double_z: true
|
82 |
-
z_channels: 4
|
83 |
-
resolution: 256
|
84 |
-
in_channels: 3
|
85 |
-
out_ch: 3
|
86 |
-
ch: 128
|
87 |
-
ch_mult: [ 1, 2, 4, 4 ]
|
88 |
-
num_res_blocks: 2
|
89 |
-
attn_resolutions: [ ]
|
90 |
-
dropout: 0.0
|
91 |
-
lossconfig:
|
92 |
-
target: torch.nn.Identity
|
93 |
-
|
94 |
-
loss_fn_config:
|
95 |
-
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
|
96 |
-
params:
|
97 |
-
loss_weighting_config:
|
98 |
-
target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting
|
99 |
-
sigma_sampler_config:
|
100 |
-
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
|
101 |
-
params:
|
102 |
-
num_idx: 1000
|
103 |
-
|
104 |
-
discretization_config:
|
105 |
-
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
|
106 |
-
|
107 |
-
sampler_config:
|
108 |
-
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
|
109 |
-
params:
|
110 |
-
num_steps: 50
|
111 |
-
|
112 |
-
discretization_config:
|
113 |
-
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
|
114 |
-
|
115 |
-
guider_config:
|
116 |
-
target: sgm.modules.diffusionmodules.guiders.VanillaCFG
|
117 |
-
params:
|
118 |
-
scale: 7.5
|
119 |
-
|
120 |
-
data:
|
121 |
-
target: sgm.data.dataset.StableDataModuleFromConfig
|
122 |
-
params:
|
123 |
-
train:
|
124 |
-
datapipeline:
|
125 |
-
urls:
|
126 |
-
# USER: adapt this path the root of your custom dataset
|
127 |
-
- DATA_PATH
|
128 |
-
pipeline_config:
|
129 |
-
shardshuffle: 10000
|
130 |
-
sample_shuffle: 10000 # USER: you might wanna adapt depending on your available RAM
|
131 |
-
|
132 |
-
decoders:
|
133 |
-
- pil
|
134 |
-
|
135 |
-
postprocessors:
|
136 |
-
- target: sdata.mappers.TorchVisionImageTransforms
|
137 |
-
params:
|
138 |
-
key: jpg # USER: you might wanna adapt this for your custom dataset
|
139 |
-
transforms:
|
140 |
-
- target: torchvision.transforms.Resize
|
141 |
-
params:
|
142 |
-
size: 256
|
143 |
-
interpolation: 3
|
144 |
-
- target: torchvision.transforms.ToTensor
|
145 |
-
- target: sdata.mappers.Rescaler
|
146 |
-
- target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
|
147 |
-
# USER: you might wanna use non-default parameters due to your custom dataset
|
148 |
-
|
149 |
-
loader:
|
150 |
-
batch_size: 64
|
151 |
-
num_workers: 6
|
152 |
-
|
153 |
-
lightning:
|
154 |
-
modelcheckpoint:
|
155 |
-
params:
|
156 |
-
every_n_train_steps: 5000
|
157 |
-
|
158 |
-
callbacks:
|
159 |
-
metrics_over_trainsteps_checkpoint:
|
160 |
-
params:
|
161 |
-
every_n_train_steps: 25000
|
162 |
-
|
163 |
-
image_logger:
|
164 |
-
target: main.ImageLogger
|
165 |
-
params:
|
166 |
-
disabled: False
|
167 |
-
enable_autocast: False
|
168 |
-
batch_frequency: 1000
|
169 |
-
max_images: 8
|
170 |
-
increase_log_steps: True
|
171 |
-
log_first_step: False
|
172 |
-
log_images_kwargs:
|
173 |
-
use_ema_scope: False
|
174 |
-
N: 8
|
175 |
-
n_rows: 2
|
176 |
-
|
177 |
-
trainer:
|
178 |
-
devices: 0,
|
179 |
-
benchmark: True
|
180 |
-
num_sanity_val_steps: 0
|
181 |
-
accumulate_grad_batches: 1
|
182 |
-
max_epochs: 1000
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,184 +0,0 @@
|
|
1 |
-
model:
|
2 |
-
base_learning_rate: 1.0e-4
|
3 |
-
target: sgm.models.diffusion.DiffusionEngine
|
4 |
-
params:
|
5 |
-
scale_factor: 0.13025
|
6 |
-
disable_first_stage_autocast: True
|
7 |
-
log_keys:
|
8 |
-
- txt
|
9 |
-
|
10 |
-
scheduler_config:
|
11 |
-
target: sgm.lr_scheduler.LambdaLinearScheduler
|
12 |
-
params:
|
13 |
-
warm_up_steps: [10000]
|
14 |
-
cycle_lengths: [10000000000000]
|
15 |
-
f_start: [1.e-6]
|
16 |
-
f_max: [1.]
|
17 |
-
f_min: [1.]
|
18 |
-
|
19 |
-
denoiser_config:
|
20 |
-
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
|
21 |
-
params:
|
22 |
-
num_idx: 1000
|
23 |
-
|
24 |
-
scaling_config:
|
25 |
-
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
|
26 |
-
discretization_config:
|
27 |
-
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
|
28 |
-
|
29 |
-
network_config:
|
30 |
-
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
|
31 |
-
params:
|
32 |
-
use_checkpoint: True
|
33 |
-
in_channels: 4
|
34 |
-
out_channels: 4
|
35 |
-
model_channels: 320
|
36 |
-
attention_resolutions: [1, 2, 4]
|
37 |
-
num_res_blocks: 2
|
38 |
-
channel_mult: [1, 2, 4, 4]
|
39 |
-
num_head_channels: 64
|
40 |
-
num_classes: sequential
|
41 |
-
adm_in_channels: 1792
|
42 |
-
num_heads: 1
|
43 |
-
transformer_depth: 1
|
44 |
-
context_dim: 768
|
45 |
-
spatial_transformer_attn_type: softmax-xformers
|
46 |
-
|
47 |
-
conditioner_config:
|
48 |
-
target: sgm.modules.GeneralConditioner
|
49 |
-
params:
|
50 |
-
emb_models:
|
51 |
-
- is_trainable: True
|
52 |
-
input_key: txt
|
53 |
-
ucg_rate: 0.1
|
54 |
-
legacy_ucg_value: ""
|
55 |
-
target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
|
56 |
-
params:
|
57 |
-
always_return_pooled: True
|
58 |
-
|
59 |
-
- is_trainable: False
|
60 |
-
ucg_rate: 0.1
|
61 |
-
input_key: original_size_as_tuple
|
62 |
-
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
63 |
-
params:
|
64 |
-
outdim: 256
|
65 |
-
|
66 |
-
- is_trainable: False
|
67 |
-
input_key: crop_coords_top_left
|
68 |
-
ucg_rate: 0.1
|
69 |
-
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
70 |
-
params:
|
71 |
-
outdim: 256
|
72 |
-
|
73 |
-
first_stage_config:
|
74 |
-
target: sgm.models.autoencoder.AutoencoderKL
|
75 |
-
params:
|
76 |
-
ckpt_path: CKPT_PATH
|
77 |
-
embed_dim: 4
|
78 |
-
monitor: val/rec_loss
|
79 |
-
ddconfig:
|
80 |
-
attn_type: vanilla-xformers
|
81 |
-
double_z: true
|
82 |
-
z_channels: 4
|
83 |
-
resolution: 256
|
84 |
-
in_channels: 3
|
85 |
-
out_ch: 3
|
86 |
-
ch: 128
|
87 |
-
ch_mult: [1, 2, 4, 4]
|
88 |
-
num_res_blocks: 2
|
89 |
-
attn_resolutions: []
|
90 |
-
dropout: 0.0
|
91 |
-
lossconfig:
|
92 |
-
target: torch.nn.Identity
|
93 |
-
|
94 |
-
loss_fn_config:
|
95 |
-
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
|
96 |
-
params:
|
97 |
-
loss_weighting_config:
|
98 |
-
target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting
|
99 |
-
sigma_sampler_config:
|
100 |
-
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
|
101 |
-
params:
|
102 |
-
num_idx: 1000
|
103 |
-
|
104 |
-
discretization_config:
|
105 |
-
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
|
106 |
-
|
107 |
-
sampler_config:
|
108 |
-
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
|
109 |
-
params:
|
110 |
-
num_steps: 50
|
111 |
-
|
112 |
-
discretization_config:
|
113 |
-
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
|
114 |
-
|
115 |
-
guider_config:
|
116 |
-
target: sgm.modules.diffusionmodules.guiders.VanillaCFG
|
117 |
-
params:
|
118 |
-
scale: 7.5
|
119 |
-
|
120 |
-
data:
|
121 |
-
target: sgm.data.dataset.StableDataModuleFromConfig
|
122 |
-
params:
|
123 |
-
train:
|
124 |
-
datapipeline:
|
125 |
-
urls:
|
126 |
-
# USER: adapt this path the root of your custom dataset
|
127 |
-
- DATA_PATH
|
128 |
-
pipeline_config:
|
129 |
-
shardshuffle: 10000
|
130 |
-
sample_shuffle: 10000
|
131 |
-
|
132 |
-
|
133 |
-
decoders:
|
134 |
-
- pil
|
135 |
-
|
136 |
-
postprocessors:
|
137 |
-
- target: sdata.mappers.TorchVisionImageTransforms
|
138 |
-
params:
|
139 |
-
key: jpg # USER: you might wanna adapt this for your custom dataset
|
140 |
-
transforms:
|
141 |
-
- target: torchvision.transforms.Resize
|
142 |
-
params:
|
143 |
-
size: 256
|
144 |
-
interpolation: 3
|
145 |
-
- target: torchvision.transforms.ToTensor
|
146 |
-
- target: sdata.mappers.Rescaler
|
147 |
-
# USER: you might wanna use non-default parameters due to your custom dataset
|
148 |
-
- target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
|
149 |
-
# USER: you might wanna use non-default parameters due to your custom dataset
|
150 |
-
|
151 |
-
loader:
|
152 |
-
batch_size: 64
|
153 |
-
num_workers: 6
|
154 |
-
|
155 |
-
lightning:
|
156 |
-
modelcheckpoint:
|
157 |
-
params:
|
158 |
-
every_n_train_steps: 5000
|
159 |
-
|
160 |
-
callbacks:
|
161 |
-
metrics_over_trainsteps_checkpoint:
|
162 |
-
params:
|
163 |
-
every_n_train_steps: 25000
|
164 |
-
|
165 |
-
image_logger:
|
166 |
-
target: main.ImageLogger
|
167 |
-
params:
|
168 |
-
disabled: False
|
169 |
-
enable_autocast: False
|
170 |
-
batch_frequency: 1000
|
171 |
-
max_images: 8
|
172 |
-
increase_log_steps: True
|
173 |
-
log_first_step: False
|
174 |
-
log_images_kwargs:
|
175 |
-
use_ema_scope: False
|
176 |
-
N: 8
|
177 |
-
n_rows: 2
|
178 |
-
|
179 |
-
trainer:
|
180 |
-
devices: 0,
|
181 |
-
benchmark: True
|
182 |
-
num_sanity_val_steps: 0
|
183 |
-
accumulate_grad_batches: 1
|
184 |
-
max_epochs: 1000
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,60 +0,0 @@
|
|
1 |
-
model:
|
2 |
-
target: sgm.models.diffusion.DiffusionEngine
|
3 |
-
params:
|
4 |
-
scale_factor: 0.18215
|
5 |
-
disable_first_stage_autocast: True
|
6 |
-
|
7 |
-
denoiser_config:
|
8 |
-
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
|
9 |
-
params:
|
10 |
-
num_idx: 1000
|
11 |
-
|
12 |
-
scaling_config:
|
13 |
-
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
|
14 |
-
discretization_config:
|
15 |
-
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
|
16 |
-
|
17 |
-
network_config:
|
18 |
-
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
|
19 |
-
params:
|
20 |
-
use_checkpoint: True
|
21 |
-
in_channels: 4
|
22 |
-
out_channels: 4
|
23 |
-
model_channels: 320
|
24 |
-
attention_resolutions: [4, 2, 1]
|
25 |
-
num_res_blocks: 2
|
26 |
-
channel_mult: [1, 2, 4, 4]
|
27 |
-
num_head_channels: 64
|
28 |
-
use_linear_in_transformer: True
|
29 |
-
transformer_depth: 1
|
30 |
-
context_dim: 1024
|
31 |
-
|
32 |
-
conditioner_config:
|
33 |
-
target: sgm.modules.GeneralConditioner
|
34 |
-
params:
|
35 |
-
emb_models:
|
36 |
-
- is_trainable: False
|
37 |
-
input_key: txt
|
38 |
-
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
|
39 |
-
params:
|
40 |
-
freeze: true
|
41 |
-
layer: penultimate
|
42 |
-
|
43 |
-
first_stage_config:
|
44 |
-
target: sgm.models.autoencoder.AutoencoderKL
|
45 |
-
params:
|
46 |
-
embed_dim: 4
|
47 |
-
monitor: val/rec_loss
|
48 |
-
ddconfig:
|
49 |
-
double_z: true
|
50 |
-
z_channels: 4
|
51 |
-
resolution: 256
|
52 |
-
in_channels: 3
|
53 |
-
out_ch: 3
|
54 |
-
ch: 128
|
55 |
-
ch_mult: [1, 2, 4, 4]
|
56 |
-
num_res_blocks: 2
|
57 |
-
attn_resolutions: []
|
58 |
-
dropout: 0.0
|
59 |
-
lossconfig:
|
60 |
-
target: torch.nn.Identity
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,60 +0,0 @@
|
|
1 |
-
model:
|
2 |
-
target: sgm.models.diffusion.DiffusionEngine
|
3 |
-
params:
|
4 |
-
scale_factor: 0.18215
|
5 |
-
disable_first_stage_autocast: True
|
6 |
-
|
7 |
-
denoiser_config:
|
8 |
-
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
|
9 |
-
params:
|
10 |
-
num_idx: 1000
|
11 |
-
|
12 |
-
scaling_config:
|
13 |
-
target: sgm.modules.diffusionmodules.denoiser_scaling.VScaling
|
14 |
-
discretization_config:
|
15 |
-
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
|
16 |
-
|
17 |
-
network_config:
|
18 |
-
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
|
19 |
-
params:
|
20 |
-
use_checkpoint: True
|
21 |
-
in_channels: 4
|
22 |
-
out_channels: 4
|
23 |
-
model_channels: 320
|
24 |
-
attention_resolutions: [4, 2, 1]
|
25 |
-
num_res_blocks: 2
|
26 |
-
channel_mult: [1, 2, 4, 4]
|
27 |
-
num_head_channels: 64
|
28 |
-
use_linear_in_transformer: True
|
29 |
-
transformer_depth: 1
|
30 |
-
context_dim: 1024
|
31 |
-
|
32 |
-
conditioner_config:
|
33 |
-
target: sgm.modules.GeneralConditioner
|
34 |
-
params:
|
35 |
-
emb_models:
|
36 |
-
- is_trainable: False
|
37 |
-
input_key: txt
|
38 |
-
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
|
39 |
-
params:
|
40 |
-
freeze: true
|
41 |
-
layer: penultimate
|
42 |
-
|
43 |
-
first_stage_config:
|
44 |
-
target: sgm.models.autoencoder.AutoencoderKL
|
45 |
-
params:
|
46 |
-
embed_dim: 4
|
47 |
-
monitor: val/rec_loss
|
48 |
-
ddconfig:
|
49 |
-
double_z: true
|
50 |
-
z_channels: 4
|
51 |
-
resolution: 256
|
52 |
-
in_channels: 3
|
53 |
-
out_ch: 3
|
54 |
-
ch: 128
|
55 |
-
ch_mult: [1, 2, 4, 4]
|
56 |
-
num_res_blocks: 2
|
57 |
-
attn_resolutions: []
|
58 |
-
dropout: 0.0
|
59 |
-
lossconfig:
|
60 |
-
target: torch.nn.Identity
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,93 +0,0 @@
|
|
1 |
-
model:
|
2 |
-
target: sgm.models.diffusion.DiffusionEngine
|
3 |
-
params:
|
4 |
-
scale_factor: 0.13025
|
5 |
-
disable_first_stage_autocast: True
|
6 |
-
|
7 |
-
denoiser_config:
|
8 |
-
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
|
9 |
-
params:
|
10 |
-
num_idx: 1000
|
11 |
-
|
12 |
-
scaling_config:
|
13 |
-
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
|
14 |
-
discretization_config:
|
15 |
-
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
|
16 |
-
|
17 |
-
network_config:
|
18 |
-
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
|
19 |
-
params:
|
20 |
-
adm_in_channels: 2816
|
21 |
-
num_classes: sequential
|
22 |
-
use_checkpoint: True
|
23 |
-
in_channels: 4
|
24 |
-
out_channels: 4
|
25 |
-
model_channels: 320
|
26 |
-
attention_resolutions: [4, 2]
|
27 |
-
num_res_blocks: 2
|
28 |
-
channel_mult: [1, 2, 4]
|
29 |
-
num_head_channels: 64
|
30 |
-
use_linear_in_transformer: True
|
31 |
-
transformer_depth: [1, 2, 10]
|
32 |
-
context_dim: 2048
|
33 |
-
spatial_transformer_attn_type: softmax-xformers
|
34 |
-
|
35 |
-
conditioner_config:
|
36 |
-
target: sgm.modules.GeneralConditioner
|
37 |
-
params:
|
38 |
-
emb_models:
|
39 |
-
- is_trainable: False
|
40 |
-
input_key: txt
|
41 |
-
target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
|
42 |
-
params:
|
43 |
-
layer: hidden
|
44 |
-
layer_idx: 11
|
45 |
-
|
46 |
-
- is_trainable: False
|
47 |
-
input_key: txt
|
48 |
-
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
|
49 |
-
params:
|
50 |
-
arch: ViT-bigG-14
|
51 |
-
version: laion2b_s39b_b160k
|
52 |
-
freeze: True
|
53 |
-
layer: penultimate
|
54 |
-
always_return_pooled: True
|
55 |
-
legacy: False
|
56 |
-
|
57 |
-
- is_trainable: False
|
58 |
-
input_key: original_size_as_tuple
|
59 |
-
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
60 |
-
params:
|
61 |
-
outdim: 256
|
62 |
-
|
63 |
-
- is_trainable: False
|
64 |
-
input_key: crop_coords_top_left
|
65 |
-
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
66 |
-
params:
|
67 |
-
outdim: 256
|
68 |
-
|
69 |
-
- is_trainable: False
|
70 |
-
input_key: target_size_as_tuple
|
71 |
-
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
72 |
-
params:
|
73 |
-
outdim: 256
|
74 |
-
|
75 |
-
first_stage_config:
|
76 |
-
target: sgm.models.autoencoder.AutoencoderKL
|
77 |
-
params:
|
78 |
-
embed_dim: 4
|
79 |
-
monitor: val/rec_loss
|
80 |
-
ddconfig:
|
81 |
-
attn_type: vanilla-xformers
|
82 |
-
double_z: true
|
83 |
-
z_channels: 4
|
84 |
-
resolution: 256
|
85 |
-
in_channels: 3
|
86 |
-
out_ch: 3
|
87 |
-
ch: 128
|
88 |
-
ch_mult: [1, 2, 4, 4]
|
89 |
-
num_res_blocks: 2
|
90 |
-
attn_resolutions: []
|
91 |
-
dropout: 0.0
|
92 |
-
lossconfig:
|
93 |
-
target: torch.nn.Identity
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,86 +0,0 @@
|
|
1 |
-
model:
|
2 |
-
target: sgm.models.diffusion.DiffusionEngine
|
3 |
-
params:
|
4 |
-
scale_factor: 0.13025
|
5 |
-
disable_first_stage_autocast: True
|
6 |
-
|
7 |
-
denoiser_config:
|
8 |
-
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
|
9 |
-
params:
|
10 |
-
num_idx: 1000
|
11 |
-
|
12 |
-
scaling_config:
|
13 |
-
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
|
14 |
-
discretization_config:
|
15 |
-
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
|
16 |
-
|
17 |
-
network_config:
|
18 |
-
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
|
19 |
-
params:
|
20 |
-
adm_in_channels: 2560
|
21 |
-
num_classes: sequential
|
22 |
-
use_checkpoint: True
|
23 |
-
in_channels: 4
|
24 |
-
out_channels: 4
|
25 |
-
model_channels: 384
|
26 |
-
attention_resolutions: [4, 2]
|
27 |
-
num_res_blocks: 2
|
28 |
-
channel_mult: [1, 2, 4, 4]
|
29 |
-
num_head_channels: 64
|
30 |
-
use_linear_in_transformer: True
|
31 |
-
transformer_depth: 4
|
32 |
-
context_dim: [1280, 1280, 1280, 1280]
|
33 |
-
spatial_transformer_attn_type: softmax-xformers
|
34 |
-
|
35 |
-
conditioner_config:
|
36 |
-
target: sgm.modules.GeneralConditioner
|
37 |
-
params:
|
38 |
-
emb_models:
|
39 |
-
- is_trainable: False
|
40 |
-
input_key: txt
|
41 |
-
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
|
42 |
-
params:
|
43 |
-
arch: ViT-bigG-14
|
44 |
-
version: laion2b_s39b_b160k
|
45 |
-
legacy: False
|
46 |
-
freeze: True
|
47 |
-
layer: penultimate
|
48 |
-
always_return_pooled: True
|
49 |
-
|
50 |
-
- is_trainable: False
|
51 |
-
input_key: original_size_as_tuple
|
52 |
-
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
53 |
-
params:
|
54 |
-
outdim: 256
|
55 |
-
|
56 |
-
- is_trainable: False
|
57 |
-
input_key: crop_coords_top_left
|
58 |
-
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
59 |
-
params:
|
60 |
-
outdim: 256
|
61 |
-
|
62 |
-
- is_trainable: False
|
63 |
-
input_key: aesthetic_score
|
64 |
-
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
65 |
-
params:
|
66 |
-
outdim: 256
|
67 |
-
|
68 |
-
first_stage_config:
|
69 |
-
target: sgm.models.autoencoder.AutoencoderKL
|
70 |
-
params:
|
71 |
-
embed_dim: 4
|
72 |
-
monitor: val/rec_loss
|
73 |
-
ddconfig:
|
74 |
-
attn_type: vanilla-xformers
|
75 |
-
double_z: true
|
76 |
-
z_channels: 4
|
77 |
-
resolution: 256
|
78 |
-
in_channels: 3
|
79 |
-
out_ch: 3
|
80 |
-
ch: 128
|
81 |
-
ch_mult: [1, 2, 4, 4]
|
82 |
-
num_res_blocks: 2
|
83 |
-
attn_resolutions: []
|
84 |
-
dropout: 0.0
|
85 |
-
lossconfig:
|
86 |
-
target: torch.nn.Identity
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,131 +0,0 @@
|
|
1 |
-
model:
|
2 |
-
target: sgm.models.diffusion.DiffusionEngine
|
3 |
-
params:
|
4 |
-
scale_factor: 0.18215
|
5 |
-
disable_first_stage_autocast: True
|
6 |
-
|
7 |
-
denoiser_config:
|
8 |
-
target: sgm.modules.diffusionmodules.denoiser.Denoiser
|
9 |
-
params:
|
10 |
-
scaling_config:
|
11 |
-
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
|
12 |
-
|
13 |
-
network_config:
|
14 |
-
target: sgm.modules.diffusionmodules.video_model.VideoUNet
|
15 |
-
params:
|
16 |
-
adm_in_channels: 768
|
17 |
-
num_classes: sequential
|
18 |
-
use_checkpoint: True
|
19 |
-
in_channels: 8
|
20 |
-
out_channels: 4
|
21 |
-
model_channels: 320
|
22 |
-
attention_resolutions: [4, 2, 1]
|
23 |
-
num_res_blocks: 2
|
24 |
-
channel_mult: [1, 2, 4, 4]
|
25 |
-
num_head_channels: 64
|
26 |
-
use_linear_in_transformer: True
|
27 |
-
transformer_depth: 1
|
28 |
-
context_dim: 1024
|
29 |
-
spatial_transformer_attn_type: softmax-xformers
|
30 |
-
extra_ff_mix_layer: True
|
31 |
-
use_spatial_context: True
|
32 |
-
merge_strategy: learned_with_images
|
33 |
-
video_kernel_size: [3, 1, 1]
|
34 |
-
|
35 |
-
conditioner_config:
|
36 |
-
target: sgm.modules.GeneralConditioner
|
37 |
-
params:
|
38 |
-
emb_models:
|
39 |
-
- is_trainable: False
|
40 |
-
input_key: cond_frames_without_noise
|
41 |
-
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
|
42 |
-
params:
|
43 |
-
n_cond_frames: 1
|
44 |
-
n_copies: 1
|
45 |
-
open_clip_embedding_config:
|
46 |
-
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
|
47 |
-
params:
|
48 |
-
freeze: True
|
49 |
-
|
50 |
-
- input_key: fps_id
|
51 |
-
is_trainable: False
|
52 |
-
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
53 |
-
params:
|
54 |
-
outdim: 256
|
55 |
-
|
56 |
-
- input_key: motion_bucket_id
|
57 |
-
is_trainable: False
|
58 |
-
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
59 |
-
params:
|
60 |
-
outdim: 256
|
61 |
-
|
62 |
-
- input_key: cond_frames
|
63 |
-
is_trainable: False
|
64 |
-
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
|
65 |
-
params:
|
66 |
-
disable_encoder_autocast: True
|
67 |
-
n_cond_frames: 1
|
68 |
-
n_copies: 1
|
69 |
-
is_ae: True
|
70 |
-
encoder_config:
|
71 |
-
target: sgm.models.autoencoder.AutoencoderKLModeOnly
|
72 |
-
params:
|
73 |
-
embed_dim: 4
|
74 |
-
monitor: val/rec_loss
|
75 |
-
ddconfig:
|
76 |
-
attn_type: vanilla-xformers
|
77 |
-
double_z: True
|
78 |
-
z_channels: 4
|
79 |
-
resolution: 256
|
80 |
-
in_channels: 3
|
81 |
-
out_ch: 3
|
82 |
-
ch: 128
|
83 |
-
ch_mult: [1, 2, 4, 4]
|
84 |
-
num_res_blocks: 2
|
85 |
-
attn_resolutions: []
|
86 |
-
dropout: 0.0
|
87 |
-
lossconfig:
|
88 |
-
target: torch.nn.Identity
|
89 |
-
|
90 |
-
- input_key: cond_aug
|
91 |
-
is_trainable: False
|
92 |
-
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
93 |
-
params:
|
94 |
-
outdim: 256
|
95 |
-
|
96 |
-
first_stage_config:
|
97 |
-
target: sgm.models.autoencoder.AutoencodingEngine
|
98 |
-
params:
|
99 |
-
loss_config:
|
100 |
-
target: torch.nn.Identity
|
101 |
-
regularizer_config:
|
102 |
-
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
|
103 |
-
encoder_config:
|
104 |
-
target: sgm.modules.diffusionmodules.model.Encoder
|
105 |
-
params:
|
106 |
-
attn_type: vanilla
|
107 |
-
double_z: True
|
108 |
-
z_channels: 4
|
109 |
-
resolution: 256
|
110 |
-
in_channels: 3
|
111 |
-
out_ch: 3
|
112 |
-
ch: 128
|
113 |
-
ch_mult: [1, 2, 4, 4]
|
114 |
-
num_res_blocks: 2
|
115 |
-
attn_resolutions: []
|
116 |
-
dropout: 0.0
|
117 |
-
decoder_config:
|
118 |
-
target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
|
119 |
-
params:
|
120 |
-
attn_type: vanilla
|
121 |
-
double_z: True
|
122 |
-
z_channels: 4
|
123 |
-
resolution: 256
|
124 |
-
in_channels: 3
|
125 |
-
out_ch: 3
|
126 |
-
ch: 128
|
127 |
-
ch_mult: [1, 2, 4, 4]
|
128 |
-
num_res_blocks: 2
|
129 |
-
attn_resolutions: []
|
130 |
-
dropout: 0.0
|
131 |
-
video_kernel_size: [3, 1, 1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,114 +0,0 @@
|
|
1 |
-
model:
|
2 |
-
target: sgm.models.diffusion.DiffusionEngine
|
3 |
-
params:
|
4 |
-
scale_factor: 0.18215
|
5 |
-
disable_first_stage_autocast: True
|
6 |
-
|
7 |
-
denoiser_config:
|
8 |
-
target: sgm.modules.diffusionmodules.denoiser.Denoiser
|
9 |
-
params:
|
10 |
-
scaling_config:
|
11 |
-
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
|
12 |
-
|
13 |
-
network_config:
|
14 |
-
target: sgm.modules.diffusionmodules.video_model.VideoUNet
|
15 |
-
params:
|
16 |
-
adm_in_channels: 768
|
17 |
-
num_classes: sequential
|
18 |
-
use_checkpoint: True
|
19 |
-
in_channels: 8
|
20 |
-
out_channels: 4
|
21 |
-
model_channels: 320
|
22 |
-
attention_resolutions: [4, 2, 1]
|
23 |
-
num_res_blocks: 2
|
24 |
-
channel_mult: [1, 2, 4, 4]
|
25 |
-
num_head_channels: 64
|
26 |
-
use_linear_in_transformer: True
|
27 |
-
transformer_depth: 1
|
28 |
-
context_dim: 1024
|
29 |
-
spatial_transformer_attn_type: softmax-xformers
|
30 |
-
extra_ff_mix_layer: True
|
31 |
-
use_spatial_context: True
|
32 |
-
merge_strategy: learned_with_images
|
33 |
-
video_kernel_size: [3, 1, 1]
|
34 |
-
|
35 |
-
conditioner_config:
|
36 |
-
target: sgm.modules.GeneralConditioner
|
37 |
-
params:
|
38 |
-
emb_models:
|
39 |
-
- is_trainable: False
|
40 |
-
input_key: cond_frames_without_noise
|
41 |
-
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
|
42 |
-
params:
|
43 |
-
n_cond_frames: 1
|
44 |
-
n_copies: 1
|
45 |
-
open_clip_embedding_config:
|
46 |
-
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
|
47 |
-
params:
|
48 |
-
freeze: True
|
49 |
-
|
50 |
-
- input_key: fps_id
|
51 |
-
is_trainable: False
|
52 |
-
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
53 |
-
params:
|
54 |
-
outdim: 256
|
55 |
-
|
56 |
-
- input_key: motion_bucket_id
|
57 |
-
is_trainable: False
|
58 |
-
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
59 |
-
params:
|
60 |
-
outdim: 256
|
61 |
-
|
62 |
-
- input_key: cond_frames
|
63 |
-
is_trainable: False
|
64 |
-
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
|
65 |
-
params:
|
66 |
-
disable_encoder_autocast: True
|
67 |
-
n_cond_frames: 1
|
68 |
-
n_copies: 1
|
69 |
-
is_ae: True
|
70 |
-
encoder_config:
|
71 |
-
target: sgm.models.autoencoder.AutoencoderKLModeOnly
|
72 |
-
params:
|
73 |
-
embed_dim: 4
|
74 |
-
monitor: val/rec_loss
|
75 |
-
ddconfig:
|
76 |
-
attn_type: vanilla-xformers
|
77 |
-
double_z: True
|
78 |
-
z_channels: 4
|
79 |
-
resolution: 256
|
80 |
-
in_channels: 3
|
81 |
-
out_ch: 3
|
82 |
-
ch: 128
|
83 |
-
ch_mult: [1, 2, 4, 4]
|
84 |
-
num_res_blocks: 2
|
85 |
-
attn_resolutions: []
|
86 |
-
dropout: 0.0
|
87 |
-
lossconfig:
|
88 |
-
target: torch.nn.Identity
|
89 |
-
|
90 |
-
- input_key: cond_aug
|
91 |
-
is_trainable: False
|
92 |
-
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
93 |
-
params:
|
94 |
-
outdim: 256
|
95 |
-
|
96 |
-
first_stage_config:
|
97 |
-
target: sgm.models.autoencoder.AutoencoderKL
|
98 |
-
params:
|
99 |
-
embed_dim: 4
|
100 |
-
monitor: val/rec_loss
|
101 |
-
ddconfig:
|
102 |
-
attn_type: vanilla-xformers
|
103 |
-
double_z: True
|
104 |
-
z_channels: 4
|
105 |
-
resolution: 256
|
106 |
-
in_channels: 3
|
107 |
-
out_ch: 3
|
108 |
-
ch: 128
|
109 |
-
ch_mult: [1, 2, 4, 4]
|
110 |
-
num_res_blocks: 2
|
111 |
-
attn_resolutions: []
|
112 |
-
dropout: 0.0
|
113 |
-
lossconfig:
|
114 |
-
target: torch.nn.Identity
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Binary file (757 kB)
|
|
@@ -1,943 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
import datetime
|
3 |
-
import glob
|
4 |
-
import inspect
|
5 |
-
import os
|
6 |
-
import sys
|
7 |
-
from inspect import Parameter
|
8 |
-
from typing import Union
|
9 |
-
|
10 |
-
import numpy as np
|
11 |
-
import pytorch_lightning as pl
|
12 |
-
import torch
|
13 |
-
import torchvision
|
14 |
-
import wandb
|
15 |
-
from matplotlib import pyplot as plt
|
16 |
-
from natsort import natsorted
|
17 |
-
from omegaconf import OmegaConf
|
18 |
-
from packaging import version
|
19 |
-
from PIL import Image
|
20 |
-
from pytorch_lightning import seed_everything
|
21 |
-
from pytorch_lightning.callbacks import Callback
|
22 |
-
from pytorch_lightning.loggers import WandbLogger
|
23 |
-
from pytorch_lightning.trainer import Trainer
|
24 |
-
from pytorch_lightning.utilities import rank_zero_only
|
25 |
-
|
26 |
-
from sgm.util import exists, instantiate_from_config, isheatmap
|
27 |
-
|
28 |
-
MULTINODE_HACKS = True
|
29 |
-
|
30 |
-
|
31 |
-
def default_trainer_args():
|
32 |
-
argspec = dict(inspect.signature(Trainer.__init__).parameters)
|
33 |
-
argspec.pop("self")
|
34 |
-
default_args = {
|
35 |
-
param: argspec[param].default
|
36 |
-
for param in argspec
|
37 |
-
if argspec[param] != Parameter.empty
|
38 |
-
}
|
39 |
-
return default_args
|
40 |
-
|
41 |
-
|
42 |
-
def get_parser(**parser_kwargs):
|
43 |
-
def str2bool(v):
|
44 |
-
if isinstance(v, bool):
|
45 |
-
return v
|
46 |
-
if v.lower() in ("yes", "true", "t", "y", "1"):
|
47 |
-
return True
|
48 |
-
elif v.lower() in ("no", "false", "f", "n", "0"):
|
49 |
-
return False
|
50 |
-
else:
|
51 |
-
raise argparse.ArgumentTypeError("Boolean value expected.")
|
52 |
-
|
53 |
-
parser = argparse.ArgumentParser(**parser_kwargs)
|
54 |
-
parser.add_argument(
|
55 |
-
"-n",
|
56 |
-
"--name",
|
57 |
-
type=str,
|
58 |
-
const=True,
|
59 |
-
default="",
|
60 |
-
nargs="?",
|
61 |
-
help="postfix for logdir",
|
62 |
-
)
|
63 |
-
parser.add_argument(
|
64 |
-
"--no_date",
|
65 |
-
type=str2bool,
|
66 |
-
nargs="?",
|
67 |
-
const=True,
|
68 |
-
default=False,
|
69 |
-
help="if True, skip date generation for logdir and only use naming via opt.base or opt.name (+ opt.postfix, optionally)",
|
70 |
-
)
|
71 |
-
parser.add_argument(
|
72 |
-
"-r",
|
73 |
-
"--resume",
|
74 |
-
type=str,
|
75 |
-
const=True,
|
76 |
-
default="",
|
77 |
-
nargs="?",
|
78 |
-
help="resume from logdir or checkpoint in logdir",
|
79 |
-
)
|
80 |
-
parser.add_argument(
|
81 |
-
"-b",
|
82 |
-
"--base",
|
83 |
-
nargs="*",
|
84 |
-
metavar="base_config.yaml",
|
85 |
-
help="paths to base configs. Loaded from left-to-right. "
|
86 |
-
"Parameters can be overwritten or added with command-line options of the form `--key value`.",
|
87 |
-
default=list(),
|
88 |
-
)
|
89 |
-
parser.add_argument(
|
90 |
-
"-t",
|
91 |
-
"--train",
|
92 |
-
type=str2bool,
|
93 |
-
const=True,
|
94 |
-
default=True,
|
95 |
-
nargs="?",
|
96 |
-
help="train",
|
97 |
-
)
|
98 |
-
parser.add_argument(
|
99 |
-
"--no-test",
|
100 |
-
type=str2bool,
|
101 |
-
const=True,
|
102 |
-
default=False,
|
103 |
-
nargs="?",
|
104 |
-
help="disable test",
|
105 |
-
)
|
106 |
-
parser.add_argument(
|
107 |
-
"-p", "--project", help="name of new or path to existing project"
|
108 |
-
)
|
109 |
-
parser.add_argument(
|
110 |
-
"-d",
|
111 |
-
"--debug",
|
112 |
-
type=str2bool,
|
113 |
-
nargs="?",
|
114 |
-
const=True,
|
115 |
-
default=False,
|
116 |
-
help="enable post-mortem debugging",
|
117 |
-
)
|
118 |
-
parser.add_argument(
|
119 |
-
"-s",
|
120 |
-
"--seed",
|
121 |
-
type=int,
|
122 |
-
default=23,
|
123 |
-
help="seed for seed_everything",
|
124 |
-
)
|
125 |
-
parser.add_argument(
|
126 |
-
"-f",
|
127 |
-
"--postfix",
|
128 |
-
type=str,
|
129 |
-
default="",
|
130 |
-
help="post-postfix for default name",
|
131 |
-
)
|
132 |
-
parser.add_argument(
|
133 |
-
"--projectname",
|
134 |
-
type=str,
|
135 |
-
default="stablediffusion",
|
136 |
-
)
|
137 |
-
parser.add_argument(
|
138 |
-
"-l",
|
139 |
-
"--logdir",
|
140 |
-
type=str,
|
141 |
-
default="logs",
|
142 |
-
help="directory for logging dat shit",
|
143 |
-
)
|
144 |
-
parser.add_argument(
|
145 |
-
"--scale_lr",
|
146 |
-
type=str2bool,
|
147 |
-
nargs="?",
|
148 |
-
const=True,
|
149 |
-
default=False,
|
150 |
-
help="scale base-lr by ngpu * batch_size * n_accumulate",
|
151 |
-
)
|
152 |
-
parser.add_argument(
|
153 |
-
"--legacy_naming",
|
154 |
-
type=str2bool,
|
155 |
-
nargs="?",
|
156 |
-
const=True,
|
157 |
-
default=False,
|
158 |
-
help="name run based on config file name if true, else by whole path",
|
159 |
-
)
|
160 |
-
parser.add_argument(
|
161 |
-
"--enable_tf32",
|
162 |
-
type=str2bool,
|
163 |
-
nargs="?",
|
164 |
-
const=True,
|
165 |
-
default=False,
|
166 |
-
help="enables the TensorFloat32 format both for matmuls and cuDNN for pytorch 1.12",
|
167 |
-
)
|
168 |
-
parser.add_argument(
|
169 |
-
"--startup",
|
170 |
-
type=str,
|
171 |
-
default=None,
|
172 |
-
help="Startuptime from distributed script",
|
173 |
-
)
|
174 |
-
parser.add_argument(
|
175 |
-
"--wandb",
|
176 |
-
type=str2bool,
|
177 |
-
nargs="?",
|
178 |
-
const=True,
|
179 |
-
default=False, # TODO: later default to True
|
180 |
-
help="log to wandb",
|
181 |
-
)
|
182 |
-
parser.add_argument(
|
183 |
-
"--no_base_name",
|
184 |
-
type=str2bool,
|
185 |
-
nargs="?",
|
186 |
-
const=True,
|
187 |
-
default=False, # TODO: later default to True
|
188 |
-
help="log to wandb",
|
189 |
-
)
|
190 |
-
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
191 |
-
parser.add_argument(
|
192 |
-
"--resume_from_checkpoint",
|
193 |
-
type=str,
|
194 |
-
default=None,
|
195 |
-
help="single checkpoint file to resume from",
|
196 |
-
)
|
197 |
-
default_args = default_trainer_args()
|
198 |
-
for key in default_args:
|
199 |
-
parser.add_argument("--" + key, default=default_args[key])
|
200 |
-
return parser
|
201 |
-
|
202 |
-
|
203 |
-
def get_checkpoint_name(logdir):
|
204 |
-
ckpt = os.path.join(logdir, "checkpoints", "last**.ckpt")
|
205 |
-
ckpt = natsorted(glob.glob(ckpt))
|
206 |
-
print('available "last" checkpoints:')
|
207 |
-
print(ckpt)
|
208 |
-
if len(ckpt) > 1:
|
209 |
-
print("got most recent checkpoint")
|
210 |
-
ckpt = sorted(ckpt, key=lambda x: os.path.getmtime(x))[-1]
|
211 |
-
print(f"Most recent ckpt is {ckpt}")
|
212 |
-
with open(os.path.join(logdir, "most_recent_ckpt.txt"), "w") as f:
|
213 |
-
f.write(ckpt + "\n")
|
214 |
-
try:
|
215 |
-
version = int(ckpt.split("/")[-1].split("-v")[-1].split(".")[0])
|
216 |
-
except Exception as e:
|
217 |
-
print("version confusion but not bad")
|
218 |
-
print(e)
|
219 |
-
version = 1
|
220 |
-
# version = last_version + 1
|
221 |
-
else:
|
222 |
-
# in this case, we only have one "last.ckpt"
|
223 |
-
ckpt = ckpt[0]
|
224 |
-
version = 1
|
225 |
-
melk_ckpt_name = f"last-v{version}.ckpt"
|
226 |
-
print(f"Current melk ckpt name: {melk_ckpt_name}")
|
227 |
-
return ckpt, melk_ckpt_name
|
228 |
-
|
229 |
-
|
230 |
-
class SetupCallback(Callback):
|
231 |
-
def __init__(
|
232 |
-
self,
|
233 |
-
resume,
|
234 |
-
now,
|
235 |
-
logdir,
|
236 |
-
ckptdir,
|
237 |
-
cfgdir,
|
238 |
-
config,
|
239 |
-
lightning_config,
|
240 |
-
debug,
|
241 |
-
ckpt_name=None,
|
242 |
-
):
|
243 |
-
super().__init__()
|
244 |
-
self.resume = resume
|
245 |
-
self.now = now
|
246 |
-
self.logdir = logdir
|
247 |
-
self.ckptdir = ckptdir
|
248 |
-
self.cfgdir = cfgdir
|
249 |
-
self.config = config
|
250 |
-
self.lightning_config = lightning_config
|
251 |
-
self.debug = debug
|
252 |
-
self.ckpt_name = ckpt_name
|
253 |
-
|
254 |
-
def on_exception(self, trainer: pl.Trainer, pl_module, exception):
|
255 |
-
if not self.debug and trainer.global_rank == 0:
|
256 |
-
print("Summoning checkpoint.")
|
257 |
-
if self.ckpt_name is None:
|
258 |
-
ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
|
259 |
-
else:
|
260 |
-
ckpt_path = os.path.join(self.ckptdir, self.ckpt_name)
|
261 |
-
trainer.save_checkpoint(ckpt_path)
|
262 |
-
|
263 |
-
def on_fit_start(self, trainer, pl_module):
|
264 |
-
if trainer.global_rank == 0:
|
265 |
-
# Create logdirs and save configs
|
266 |
-
os.makedirs(self.logdir, exist_ok=True)
|
267 |
-
os.makedirs(self.ckptdir, exist_ok=True)
|
268 |
-
os.makedirs(self.cfgdir, exist_ok=True)
|
269 |
-
|
270 |
-
if "callbacks" in self.lightning_config:
|
271 |
-
if (
|
272 |
-
"metrics_over_trainsteps_checkpoint"
|
273 |
-
in self.lightning_config["callbacks"]
|
274 |
-
):
|
275 |
-
os.makedirs(
|
276 |
-
os.path.join(self.ckptdir, "trainstep_checkpoints"),
|
277 |
-
exist_ok=True,
|
278 |
-
)
|
279 |
-
print("Project config")
|
280 |
-
print(OmegaConf.to_yaml(self.config))
|
281 |
-
if MULTINODE_HACKS:
|
282 |
-
import time
|
283 |
-
|
284 |
-
time.sleep(5)
|
285 |
-
OmegaConf.save(
|
286 |
-
self.config,
|
287 |
-
os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)),
|
288 |
-
)
|
289 |
-
|
290 |
-
print("Lightning config")
|
291 |
-
print(OmegaConf.to_yaml(self.lightning_config))
|
292 |
-
OmegaConf.save(
|
293 |
-
OmegaConf.create({"lightning": self.lightning_config}),
|
294 |
-
os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)),
|
295 |
-
)
|
296 |
-
|
297 |
-
else:
|
298 |
-
# ModelCheckpoint callback created log directory --- remove it
|
299 |
-
if not MULTINODE_HACKS and not self.resume and os.path.exists(self.logdir):
|
300 |
-
dst, name = os.path.split(self.logdir)
|
301 |
-
dst = os.path.join(dst, "child_runs", name)
|
302 |
-
os.makedirs(os.path.split(dst)[0], exist_ok=True)
|
303 |
-
try:
|
304 |
-
os.rename(self.logdir, dst)
|
305 |
-
except FileNotFoundError:
|
306 |
-
pass
|
307 |
-
|
308 |
-
|
309 |
-
class ImageLogger(Callback):
|
310 |
-
def __init__(
|
311 |
-
self,
|
312 |
-
batch_frequency,
|
313 |
-
max_images,
|
314 |
-
clamp=True,
|
315 |
-
increase_log_steps=True,
|
316 |
-
rescale=True,
|
317 |
-
disabled=False,
|
318 |
-
log_on_batch_idx=False,
|
319 |
-
log_first_step=False,
|
320 |
-
log_images_kwargs=None,
|
321 |
-
log_before_first_step=False,
|
322 |
-
enable_autocast=True,
|
323 |
-
):
|
324 |
-
super().__init__()
|
325 |
-
self.enable_autocast = enable_autocast
|
326 |
-
self.rescale = rescale
|
327 |
-
self.batch_freq = batch_frequency
|
328 |
-
self.max_images = max_images
|
329 |
-
self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
|
330 |
-
if not increase_log_steps:
|
331 |
-
self.log_steps = [self.batch_freq]
|
332 |
-
self.clamp = clamp
|
333 |
-
self.disabled = disabled
|
334 |
-
self.log_on_batch_idx = log_on_batch_idx
|
335 |
-
self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
|
336 |
-
self.log_first_step = log_first_step
|
337 |
-
self.log_before_first_step = log_before_first_step
|
338 |
-
|
339 |
-
@rank_zero_only
|
340 |
-
def log_local(
|
341 |
-
self,
|
342 |
-
save_dir,
|
343 |
-
split,
|
344 |
-
images,
|
345 |
-
global_step,
|
346 |
-
current_epoch,
|
347 |
-
batch_idx,
|
348 |
-
pl_module: Union[None, pl.LightningModule] = None,
|
349 |
-
):
|
350 |
-
root = os.path.join(save_dir, "images", split)
|
351 |
-
for k in images:
|
352 |
-
if isheatmap(images[k]):
|
353 |
-
fig, ax = plt.subplots()
|
354 |
-
ax = ax.matshow(
|
355 |
-
images[k].cpu().numpy(), cmap="hot", interpolation="lanczos"
|
356 |
-
)
|
357 |
-
plt.colorbar(ax)
|
358 |
-
plt.axis("off")
|
359 |
-
|
360 |
-
filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
|
361 |
-
k, global_step, current_epoch, batch_idx
|
362 |
-
)
|
363 |
-
os.makedirs(root, exist_ok=True)
|
364 |
-
path = os.path.join(root, filename)
|
365 |
-
plt.savefig(path)
|
366 |
-
plt.close()
|
367 |
-
# TODO: support wandb
|
368 |
-
else:
|
369 |
-
grid = torchvision.utils.make_grid(images[k], nrow=4)
|
370 |
-
if self.rescale:
|
371 |
-
grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
|
372 |
-
grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
|
373 |
-
grid = grid.numpy()
|
374 |
-
grid = (grid * 255).astype(np.uint8)
|
375 |
-
filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
|
376 |
-
k, global_step, current_epoch, batch_idx
|
377 |
-
)
|
378 |
-
path = os.path.join(root, filename)
|
379 |
-
os.makedirs(os.path.split(path)[0], exist_ok=True)
|
380 |
-
img = Image.fromarray(grid)
|
381 |
-
img.save(path)
|
382 |
-
if exists(pl_module):
|
383 |
-
assert isinstance(
|
384 |
-
pl_module.logger, WandbLogger
|
385 |
-
), "logger_log_image only supports WandbLogger currently"
|
386 |
-
pl_module.logger.log_image(
|
387 |
-
key=f"{split}/{k}",
|
388 |
-
images=[
|
389 |
-
img,
|
390 |
-
],
|
391 |
-
step=pl_module.global_step,
|
392 |
-
)
|
393 |
-
|
394 |
-
@rank_zero_only
|
395 |
-
def log_img(self, pl_module, batch, batch_idx, split="train"):
|
396 |
-
check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
|
397 |
-
if (
|
398 |
-
self.check_frequency(check_idx)
|
399 |
-
and hasattr(pl_module, "log_images") # batch_idx % self.batch_freq == 0
|
400 |
-
and callable(pl_module.log_images)
|
401 |
-
and
|
402 |
-
# batch_idx > 5 and
|
403 |
-
self.max_images > 0
|
404 |
-
):
|
405 |
-
logger = type(pl_module.logger)
|
406 |
-
is_train = pl_module.training
|
407 |
-
if is_train:
|
408 |
-
pl_module.eval()
|
409 |
-
|
410 |
-
gpu_autocast_kwargs = {
|
411 |
-
"enabled": self.enable_autocast, # torch.is_autocast_enabled(),
|
412 |
-
"dtype": torch.get_autocast_gpu_dtype(),
|
413 |
-
"cache_enabled": torch.is_autocast_cache_enabled(),
|
414 |
-
}
|
415 |
-
with torch.no_grad(), torch.cuda.amp.autocast(**gpu_autocast_kwargs):
|
416 |
-
images = pl_module.log_images(
|
417 |
-
batch, split=split, **self.log_images_kwargs
|
418 |
-
)
|
419 |
-
|
420 |
-
for k in images:
|
421 |
-
N = min(images[k].shape[0], self.max_images)
|
422 |
-
if not isheatmap(images[k]):
|
423 |
-
images[k] = images[k][:N]
|
424 |
-
if isinstance(images[k], torch.Tensor):
|
425 |
-
images[k] = images[k].detach().float().cpu()
|
426 |
-
if self.clamp and not isheatmap(images[k]):
|
427 |
-
images[k] = torch.clamp(images[k], -1.0, 1.0)
|
428 |
-
|
429 |
-
self.log_local(
|
430 |
-
pl_module.logger.save_dir,
|
431 |
-
split,
|
432 |
-
images,
|
433 |
-
pl_module.global_step,
|
434 |
-
pl_module.current_epoch,
|
435 |
-
batch_idx,
|
436 |
-
pl_module=pl_module
|
437 |
-
if isinstance(pl_module.logger, WandbLogger)
|
438 |
-
else None,
|
439 |
-
)
|
440 |
-
|
441 |
-
if is_train:
|
442 |
-
pl_module.train()
|
443 |
-
|
444 |
-
def check_frequency(self, check_idx):
|
445 |
-
if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
|
446 |
-
check_idx > 0 or self.log_first_step
|
447 |
-
):
|
448 |
-
try:
|
449 |
-
self.log_steps.pop(0)
|
450 |
-
except IndexError as e:
|
451 |
-
print(e)
|
452 |
-
pass
|
453 |
-
return True
|
454 |
-
return False
|
455 |
-
|
456 |
-
@rank_zero_only
|
457 |
-
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
458 |
-
if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
|
459 |
-
self.log_img(pl_module, batch, batch_idx, split="train")
|
460 |
-
|
461 |
-
@rank_zero_only
|
462 |
-
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
|
463 |
-
if self.log_before_first_step and pl_module.global_step == 0:
|
464 |
-
print(f"{self.__class__.__name__}: logging before training")
|
465 |
-
self.log_img(pl_module, batch, batch_idx, split="train")
|
466 |
-
|
467 |
-
@rank_zero_only
|
468 |
-
def on_validation_batch_end(
|
469 |
-
self, trainer, pl_module, outputs, batch, batch_idx, *args, **kwargs
|
470 |
-
):
|
471 |
-
if not self.disabled and pl_module.global_step > 0:
|
472 |
-
self.log_img(pl_module, batch, batch_idx, split="val")
|
473 |
-
if hasattr(pl_module, "calibrate_grad_norm"):
|
474 |
-
if (
|
475 |
-
pl_module.calibrate_grad_norm and batch_idx % 25 == 0
|
476 |
-
) and batch_idx > 0:
|
477 |
-
self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
|
478 |
-
|
479 |
-
|
480 |
-
@rank_zero_only
|
481 |
-
def init_wandb(save_dir, opt, config, group_name, name_str):
|
482 |
-
print(f"setting WANDB_DIR to {save_dir}")
|
483 |
-
os.makedirs(save_dir, exist_ok=True)
|
484 |
-
|
485 |
-
os.environ["WANDB_DIR"] = save_dir
|
486 |
-
if opt.debug:
|
487 |
-
wandb.init(project=opt.projectname, mode="offline", group=group_name)
|
488 |
-
else:
|
489 |
-
wandb.init(
|
490 |
-
project=opt.projectname,
|
491 |
-
config=config,
|
492 |
-
settings=wandb.Settings(code_dir="./sgm"),
|
493 |
-
group=group_name,
|
494 |
-
name=name_str,
|
495 |
-
)
|
496 |
-
|
497 |
-
|
498 |
-
if __name__ == "__main__":
|
499 |
-
# custom parser to specify config files, train, test and debug mode,
|
500 |
-
# postfix, resume.
|
501 |
-
# `--key value` arguments are interpreted as arguments to the trainer.
|
502 |
-
# `nested.key=value` arguments are interpreted as config parameters.
|
503 |
-
# configs are merged from left-to-right followed by command line parameters.
|
504 |
-
|
505 |
-
# model:
|
506 |
-
# base_learning_rate: float
|
507 |
-
# target: path to lightning module
|
508 |
-
# params:
|
509 |
-
# key: value
|
510 |
-
# data:
|
511 |
-
# target: main.DataModuleFromConfig
|
512 |
-
# params:
|
513 |
-
# batch_size: int
|
514 |
-
# wrap: bool
|
515 |
-
# train:
|
516 |
-
# target: path to train dataset
|
517 |
-
# params:
|
518 |
-
# key: value
|
519 |
-
# validation:
|
520 |
-
# target: path to validation dataset
|
521 |
-
# params:
|
522 |
-
# key: value
|
523 |
-
# test:
|
524 |
-
# target: path to test dataset
|
525 |
-
# params:
|
526 |
-
# key: value
|
527 |
-
# lightning: (optional, has sane defaults and can be specified on cmdline)
|
528 |
-
# trainer:
|
529 |
-
# additional arguments to trainer
|
530 |
-
# logger:
|
531 |
-
# logger to instantiate
|
532 |
-
# modelcheckpoint:
|
533 |
-
# modelcheckpoint to instantiate
|
534 |
-
# callbacks:
|
535 |
-
# callback1:
|
536 |
-
# target: importpath
|
537 |
-
# params:
|
538 |
-
# key: value
|
539 |
-
|
540 |
-
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
|
541 |
-
|
542 |
-
# add cwd for convenience and to make classes in this file available when
|
543 |
-
# running as `python main.py`
|
544 |
-
# (in particular `main.DataModuleFromConfig`)
|
545 |
-
sys.path.append(os.getcwd())
|
546 |
-
|
547 |
-
parser = get_parser()
|
548 |
-
|
549 |
-
opt, unknown = parser.parse_known_args()
|
550 |
-
|
551 |
-
if opt.name and opt.resume:
|
552 |
-
raise ValueError(
|
553 |
-
"-n/--name and -r/--resume cannot be specified both."
|
554 |
-
"If you want to resume training in a new log folder, "
|
555 |
-
"use -n/--name in combination with --resume_from_checkpoint"
|
556 |
-
)
|
557 |
-
melk_ckpt_name = None
|
558 |
-
name = None
|
559 |
-
if opt.resume:
|
560 |
-
if not os.path.exists(opt.resume):
|
561 |
-
raise ValueError("Cannot find {}".format(opt.resume))
|
562 |
-
if os.path.isfile(opt.resume):
|
563 |
-
paths = opt.resume.split("/")
|
564 |
-
# idx = len(paths)-paths[::-1].index("logs")+1
|
565 |
-
# logdir = "/".join(paths[:idx])
|
566 |
-
logdir = "/".join(paths[:-2])
|
567 |
-
ckpt = opt.resume
|
568 |
-
_, melk_ckpt_name = get_checkpoint_name(logdir)
|
569 |
-
else:
|
570 |
-
assert os.path.isdir(opt.resume), opt.resume
|
571 |
-
logdir = opt.resume.rstrip("/")
|
572 |
-
ckpt, melk_ckpt_name = get_checkpoint_name(logdir)
|
573 |
-
|
574 |
-
print("#" * 100)
|
575 |
-
print(f'Resuming from checkpoint "{ckpt}"')
|
576 |
-
print("#" * 100)
|
577 |
-
|
578 |
-
opt.resume_from_checkpoint = ckpt
|
579 |
-
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
|
580 |
-
opt.base = base_configs + opt.base
|
581 |
-
_tmp = logdir.split("/")
|
582 |
-
nowname = _tmp[-1]
|
583 |
-
else:
|
584 |
-
if opt.name:
|
585 |
-
name = "_" + opt.name
|
586 |
-
elif opt.base:
|
587 |
-
if opt.no_base_name:
|
588 |
-
name = ""
|
589 |
-
else:
|
590 |
-
if opt.legacy_naming:
|
591 |
-
cfg_fname = os.path.split(opt.base[0])[-1]
|
592 |
-
cfg_name = os.path.splitext(cfg_fname)[0]
|
593 |
-
else:
|
594 |
-
assert "configs" in os.path.split(opt.base[0])[0], os.path.split(
|
595 |
-
opt.base[0]
|
596 |
-
)[0]
|
597 |
-
cfg_path = os.path.split(opt.base[0])[0].split(os.sep)[
|
598 |
-
os.path.split(opt.base[0])[0].split(os.sep).index("configs")
|
599 |
-
+ 1 :
|
600 |
-
] # cut away the first one (we assert all configs are in "configs")
|
601 |
-
cfg_name = os.path.splitext(os.path.split(opt.base[0])[-1])[0]
|
602 |
-
cfg_name = "-".join(cfg_path) + f"-{cfg_name}"
|
603 |
-
name = "_" + cfg_name
|
604 |
-
else:
|
605 |
-
name = ""
|
606 |
-
if not opt.no_date:
|
607 |
-
nowname = now + name + opt.postfix
|
608 |
-
else:
|
609 |
-
nowname = name + opt.postfix
|
610 |
-
if nowname.startswith("_"):
|
611 |
-
nowname = nowname[1:]
|
612 |
-
logdir = os.path.join(opt.logdir, nowname)
|
613 |
-
print(f"LOGDIR: {logdir}")
|
614 |
-
|
615 |
-
ckptdir = os.path.join(logdir, "checkpoints")
|
616 |
-
cfgdir = os.path.join(logdir, "configs")
|
617 |
-
seed_everything(opt.seed, workers=True)
|
618 |
-
|
619 |
-
# move before model init, in case a torch.compile(...) is called somewhere
|
620 |
-
if opt.enable_tf32:
|
621 |
-
# pt_version = version.parse(torch.__version__)
|
622 |
-
torch.backends.cuda.matmul.allow_tf32 = True
|
623 |
-
torch.backends.cudnn.allow_tf32 = True
|
624 |
-
print(f"Enabling TF32 for PyTorch {torch.__version__}")
|
625 |
-
else:
|
626 |
-
print(f"Using default TF32 settings for PyTorch {torch.__version__}:")
|
627 |
-
print(
|
628 |
-
f"torch.backends.cuda.matmul.allow_tf32={torch.backends.cuda.matmul.allow_tf32}"
|
629 |
-
)
|
630 |
-
print(f"torch.backends.cudnn.allow_tf32={torch.backends.cudnn.allow_tf32}")
|
631 |
-
|
632 |
-
try:
|
633 |
-
# init and save configs
|
634 |
-
configs = [OmegaConf.load(cfg) for cfg in opt.base]
|
635 |
-
cli = OmegaConf.from_dotlist(unknown)
|
636 |
-
config = OmegaConf.merge(*configs, cli)
|
637 |
-
lightning_config = config.pop("lightning", OmegaConf.create())
|
638 |
-
# merge trainer cli with config
|
639 |
-
trainer_config = lightning_config.get("trainer", OmegaConf.create())
|
640 |
-
|
641 |
-
# default to gpu
|
642 |
-
trainer_config["accelerator"] = "gpu"
|
643 |
-
#
|
644 |
-
standard_args = default_trainer_args()
|
645 |
-
for k in standard_args:
|
646 |
-
if getattr(opt, k) != standard_args[k]:
|
647 |
-
trainer_config[k] = getattr(opt, k)
|
648 |
-
|
649 |
-
ckpt_resume_path = opt.resume_from_checkpoint
|
650 |
-
|
651 |
-
if not "devices" in trainer_config and trainer_config["accelerator"] != "gpu":
|
652 |
-
del trainer_config["accelerator"]
|
653 |
-
cpu = True
|
654 |
-
else:
|
655 |
-
gpuinfo = trainer_config["devices"]
|
656 |
-
print(f"Running on GPUs {gpuinfo}")
|
657 |
-
cpu = False
|
658 |
-
trainer_opt = argparse.Namespace(**trainer_config)
|
659 |
-
lightning_config.trainer = trainer_config
|
660 |
-
|
661 |
-
# model
|
662 |
-
model = instantiate_from_config(config.model)
|
663 |
-
|
664 |
-
# trainer and callbacks
|
665 |
-
trainer_kwargs = dict()
|
666 |
-
|
667 |
-
# default logger configs
|
668 |
-
default_logger_cfgs = {
|
669 |
-
"wandb": {
|
670 |
-
"target": "pytorch_lightning.loggers.WandbLogger",
|
671 |
-
"params": {
|
672 |
-
"name": nowname,
|
673 |
-
# "save_dir": logdir,
|
674 |
-
"offline": opt.debug,
|
675 |
-
"id": nowname,
|
676 |
-
"project": opt.projectname,
|
677 |
-
"log_model": False,
|
678 |
-
# "dir": logdir,
|
679 |
-
},
|
680 |
-
},
|
681 |
-
"csv": {
|
682 |
-
"target": "pytorch_lightning.loggers.CSVLogger",
|
683 |
-
"params": {
|
684 |
-
"name": "testtube", # hack for sbord fanatics
|
685 |
-
"save_dir": logdir,
|
686 |
-
},
|
687 |
-
},
|
688 |
-
}
|
689 |
-
default_logger_cfg = default_logger_cfgs["wandb" if opt.wandb else "csv"]
|
690 |
-
if opt.wandb:
|
691 |
-
# TODO change once leaving "swiffer" config directory
|
692 |
-
try:
|
693 |
-
group_name = nowname.split(now)[-1].split("-")[1]
|
694 |
-
except:
|
695 |
-
group_name = nowname
|
696 |
-
default_logger_cfg["params"]["group"] = group_name
|
697 |
-
init_wandb(
|
698 |
-
os.path.join(os.getcwd(), logdir),
|
699 |
-
opt=opt,
|
700 |
-
group_name=group_name,
|
701 |
-
config=config,
|
702 |
-
name_str=nowname,
|
703 |
-
)
|
704 |
-
if "logger" in lightning_config:
|
705 |
-
logger_cfg = lightning_config.logger
|
706 |
-
else:
|
707 |
-
logger_cfg = OmegaConf.create()
|
708 |
-
logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
|
709 |
-
trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
|
710 |
-
|
711 |
-
# modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
|
712 |
-
# specify which metric is used to determine best models
|
713 |
-
default_modelckpt_cfg = {
|
714 |
-
"target": "pytorch_lightning.callbacks.ModelCheckpoint",
|
715 |
-
"params": {
|
716 |
-
"dirpath": ckptdir,
|
717 |
-
"filename": "{epoch:06}",
|
718 |
-
"verbose": True,
|
719 |
-
"save_last": True,
|
720 |
-
},
|
721 |
-
}
|
722 |
-
if hasattr(model, "monitor"):
|
723 |
-
print(f"Monitoring {model.monitor} as checkpoint metric.")
|
724 |
-
default_modelckpt_cfg["params"]["monitor"] = model.monitor
|
725 |
-
default_modelckpt_cfg["params"]["save_top_k"] = 3
|
726 |
-
|
727 |
-
if "modelcheckpoint" in lightning_config:
|
728 |
-
modelckpt_cfg = lightning_config.modelcheckpoint
|
729 |
-
else:
|
730 |
-
modelckpt_cfg = OmegaConf.create()
|
731 |
-
modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
|
732 |
-
print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")
|
733 |
-
|
734 |
-
# https://pytorch-lightning.readthedocs.io/en/stable/extensions/strategy.html
|
735 |
-
# default to ddp if not further specified
|
736 |
-
default_strategy_config = {"target": "pytorch_lightning.strategies.DDPStrategy"}
|
737 |
-
|
738 |
-
if "strategy" in lightning_config:
|
739 |
-
strategy_cfg = lightning_config.strategy
|
740 |
-
else:
|
741 |
-
strategy_cfg = OmegaConf.create()
|
742 |
-
default_strategy_config["params"] = {
|
743 |
-
"find_unused_parameters": False,
|
744 |
-
# "static_graph": True,
|
745 |
-
# "ddp_comm_hook": default.fp16_compress_hook # TODO: experiment with this, also for DDPSharded
|
746 |
-
}
|
747 |
-
strategy_cfg = OmegaConf.merge(default_strategy_config, strategy_cfg)
|
748 |
-
print(
|
749 |
-
f"strategy config: \n ++++++++++++++ \n {strategy_cfg} \n ++++++++++++++ "
|
750 |
-
)
|
751 |
-
trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg)
|
752 |
-
|
753 |
-
# add callback which sets up log directory
|
754 |
-
default_callbacks_cfg = {
|
755 |
-
"setup_callback": {
|
756 |
-
"target": "main.SetupCallback",
|
757 |
-
"params": {
|
758 |
-
"resume": opt.resume,
|
759 |
-
"now": now,
|
760 |
-
"logdir": logdir,
|
761 |
-
"ckptdir": ckptdir,
|
762 |
-
"cfgdir": cfgdir,
|
763 |
-
"config": config,
|
764 |
-
"lightning_config": lightning_config,
|
765 |
-
"debug": opt.debug,
|
766 |
-
"ckpt_name": melk_ckpt_name,
|
767 |
-
},
|
768 |
-
},
|
769 |
-
"image_logger": {
|
770 |
-
"target": "main.ImageLogger",
|
771 |
-
"params": {"batch_frequency": 1000, "max_images": 4, "clamp": True},
|
772 |
-
},
|
773 |
-
"learning_rate_logger": {
|
774 |
-
"target": "pytorch_lightning.callbacks.LearningRateMonitor",
|
775 |
-
"params": {
|
776 |
-
"logging_interval": "step",
|
777 |
-
# "log_momentum": True
|
778 |
-
},
|
779 |
-
},
|
780 |
-
}
|
781 |
-
if version.parse(pl.__version__) >= version.parse("1.4.0"):
|
782 |
-
default_callbacks_cfg.update({"checkpoint_callback": modelckpt_cfg})
|
783 |
-
|
784 |
-
if "callbacks" in lightning_config:
|
785 |
-
callbacks_cfg = lightning_config.callbacks
|
786 |
-
else:
|
787 |
-
callbacks_cfg = OmegaConf.create()
|
788 |
-
|
789 |
-
if "metrics_over_trainsteps_checkpoint" in callbacks_cfg:
|
790 |
-
print(
|
791 |
-
"Caution: Saving checkpoints every n train steps without deleting. This might require some free space."
|
792 |
-
)
|
793 |
-
default_metrics_over_trainsteps_ckpt_dict = {
|
794 |
-
"metrics_over_trainsteps_checkpoint": {
|
795 |
-
"target": "pytorch_lightning.callbacks.ModelCheckpoint",
|
796 |
-
"params": {
|
797 |
-
"dirpath": os.path.join(ckptdir, "trainstep_checkpoints"),
|
798 |
-
"filename": "{epoch:06}-{step:09}",
|
799 |
-
"verbose": True,
|
800 |
-
"save_top_k": -1,
|
801 |
-
"every_n_train_steps": 10000,
|
802 |
-
"save_weights_only": True,
|
803 |
-
},
|
804 |
-
}
|
805 |
-
}
|
806 |
-
default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
|
807 |
-
|
808 |
-
callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
|
809 |
-
if "ignore_keys_callback" in callbacks_cfg and ckpt_resume_path is not None:
|
810 |
-
callbacks_cfg.ignore_keys_callback.params["ckpt_path"] = ckpt_resume_path
|
811 |
-
elif "ignore_keys_callback" in callbacks_cfg:
|
812 |
-
del callbacks_cfg["ignore_keys_callback"]
|
813 |
-
|
814 |
-
trainer_kwargs["callbacks"] = [
|
815 |
-
instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
|
816 |
-
]
|
817 |
-
if not "plugins" in trainer_kwargs:
|
818 |
-
trainer_kwargs["plugins"] = list()
|
819 |
-
|
820 |
-
# cmd line trainer args (which are in trainer_opt) have always priority over config-trainer-args (which are in trainer_kwargs)
|
821 |
-
trainer_opt = vars(trainer_opt)
|
822 |
-
trainer_kwargs = {
|
823 |
-
key: val for key, val in trainer_kwargs.items() if key not in trainer_opt
|
824 |
-
}
|
825 |
-
trainer = Trainer(**trainer_opt, **trainer_kwargs)
|
826 |
-
|
827 |
-
trainer.logdir = logdir ###
|
828 |
-
|
829 |
-
# data
|
830 |
-
data = instantiate_from_config(config.data)
|
831 |
-
# NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
|
832 |
-
# calling these ourselves should not be necessary but it is.
|
833 |
-
# lightning still takes care of proper multiprocessing though
|
834 |
-
data.prepare_data()
|
835 |
-
# data.setup()
|
836 |
-
print("#### Data #####")
|
837 |
-
try:
|
838 |
-
for k in data.datasets:
|
839 |
-
print(
|
840 |
-
f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}"
|
841 |
-
)
|
842 |
-
except:
|
843 |
-
print("datasets not yet initialized.")
|
844 |
-
|
845 |
-
# configure learning rate
|
846 |
-
if "batch_size" in config.data.params:
|
847 |
-
bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
|
848 |
-
else:
|
849 |
-
bs, base_lr = (
|
850 |
-
config.data.params.train.loader.batch_size,
|
851 |
-
config.model.base_learning_rate,
|
852 |
-
)
|
853 |
-
if not cpu:
|
854 |
-
ngpu = len(lightning_config.trainer.devices.strip(",").split(","))
|
855 |
-
else:
|
856 |
-
ngpu = 1
|
857 |
-
if "accumulate_grad_batches" in lightning_config.trainer:
|
858 |
-
accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
|
859 |
-
else:
|
860 |
-
accumulate_grad_batches = 1
|
861 |
-
print(f"accumulate_grad_batches = {accumulate_grad_batches}")
|
862 |
-
lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
|
863 |
-
if opt.scale_lr:
|
864 |
-
model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
|
865 |
-
print(
|
866 |
-
"Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
|
867 |
-
model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr
|
868 |
-
)
|
869 |
-
)
|
870 |
-
else:
|
871 |
-
model.learning_rate = base_lr
|
872 |
-
print("++++ NOT USING LR SCALING ++++")
|
873 |
-
print(f"Setting learning rate to {model.learning_rate:.2e}")
|
874 |
-
|
875 |
-
# allow checkpointing via USR1
|
876 |
-
def melk(*args, **kwargs):
|
877 |
-
# run all checkpoint hooks
|
878 |
-
if trainer.global_rank == 0:
|
879 |
-
print("Summoning checkpoint.")
|
880 |
-
if melk_ckpt_name is None:
|
881 |
-
ckpt_path = os.path.join(ckptdir, "last.ckpt")
|
882 |
-
else:
|
883 |
-
ckpt_path = os.path.join(ckptdir, melk_ckpt_name)
|
884 |
-
trainer.save_checkpoint(ckpt_path)
|
885 |
-
|
886 |
-
def divein(*args, **kwargs):
|
887 |
-
if trainer.global_rank == 0:
|
888 |
-
import pudb
|
889 |
-
|
890 |
-
pudb.set_trace()
|
891 |
-
|
892 |
-
import signal
|
893 |
-
|
894 |
-
signal.signal(signal.SIGUSR1, melk)
|
895 |
-
signal.signal(signal.SIGUSR2, divein)
|
896 |
-
|
897 |
-
# run
|
898 |
-
if opt.train:
|
899 |
-
try:
|
900 |
-
trainer.fit(model, data, ckpt_path=ckpt_resume_path)
|
901 |
-
except Exception:
|
902 |
-
if not opt.debug:
|
903 |
-
melk()
|
904 |
-
raise
|
905 |
-
if not opt.no_test and not trainer.interrupted:
|
906 |
-
trainer.test(model, data)
|
907 |
-
except RuntimeError as err:
|
908 |
-
if MULTINODE_HACKS:
|
909 |
-
import datetime
|
910 |
-
import os
|
911 |
-
import socket
|
912 |
-
|
913 |
-
import requests
|
914 |
-
|
915 |
-
device = os.environ.get("CUDA_VISIBLE_DEVICES", "?")
|
916 |
-
hostname = socket.gethostname()
|
917 |
-
ts = datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")
|
918 |
-
resp = requests.get("http://169.254.169.254/latest/meta-data/instance-id")
|
919 |
-
print(
|
920 |
-
f"ERROR at {ts} on {hostname}/{resp.text} (CUDA_VISIBLE_DEVICES={device}): {type(err).__name__}: {err}",
|
921 |
-
flush=True,
|
922 |
-
)
|
923 |
-
raise err
|
924 |
-
except Exception:
|
925 |
-
if opt.debug and trainer.global_rank == 0:
|
926 |
-
try:
|
927 |
-
import pudb as debugger
|
928 |
-
except ImportError:
|
929 |
-
import pdb as debugger
|
930 |
-
debugger.post_mortem()
|
931 |
-
raise
|
932 |
-
finally:
|
933 |
-
# move newly created debug project to debug_runs
|
934 |
-
if opt.debug and not opt.resume and trainer.global_rank == 0:
|
935 |
-
dst, name = os.path.split(logdir)
|
936 |
-
dst = os.path.join(dst, "debug_runs", name)
|
937 |
-
os.makedirs(os.path.split(dst)[0], exist_ok=True)
|
938 |
-
os.rename(logdir, dst)
|
939 |
-
|
940 |
-
if opt.wandb:
|
941 |
-
wandb.finish()
|
942 |
-
# if trainer.global_rank == 0:
|
943 |
-
# print(trainer.profiler.summary())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,31 +0,0 @@
|
|
1 |
-
STABLE VIDEO DIFFUSION NON-COMMERCIAL COMMUNITY LICENSE AGREEMENT
|
2 |
-
Dated: November 21, 2023
|
3 |
-
|
4 |
-
“AUP” means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time.
|
5 |
-
|
6 |
-
"Agreement" means the terms and conditions for use, reproduction, distribution and modification of the Software Products set forth herein.
|
7 |
-
"Derivative Work(s)” means (a) any derivative work of the Software Products as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model’s output. For clarity, Derivative Works do not include the output of any Model.
|
8 |
-
“Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software.
|
9 |
-
|
10 |
-
"Licensee" or "you" means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity's behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
|
11 |
-
|
12 |
-
"Stability AI" or "we" means Stability AI Ltd.
|
13 |
-
|
14 |
-
"Software" means, collectively, Stability AI’s proprietary StableCode made available under this Agreement.
|
15 |
-
|
16 |
-
“Software Products” means Software and Documentation.
|
17 |
-
|
18 |
-
By using or distributing any portion or element of the Software Products, you agree to be bound by this Agreement.
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
License Rights and Redistribution.
|
23 |
-
Subject to your compliance with this Agreement, the AUP (which is hereby incorporated herein by reference), and the Documentation, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s intellectual property or other rights owned by Stability AI embodied in the Software Products to reproduce, distribute, and create Derivative Works of the Software Products for purposes other than commercial or production use.
|
24 |
-
b. If you distribute or make the Software Products, or any Derivative Works thereof, available to a third party, the Software Products, Derivative Works, or any portion thereof, respectively, will remain subject to this Agreement and you must (i) provide a copy of this Agreement to such third party, and (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "Stable Video Diffusion is licensed under the Stable Video Diffusion Research License, Copyright (c) Stability AI Ltd. All Rights Reserved.” If you create a Derivative Work of a Software Product, you may add your own attribution notices to the Notice file included with the Software Product, provided that you clearly indicate which attributions apply to the Software Product and you must state in the NOTICE file that you changed the Software Product and how it was modified.
|
25 |
-
2. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SOFTWARE PRODUCTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS.
|
26 |
-
3. Limitation of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
|
27 |
-
3. Intellectual Property.
|
28 |
-
a. No trademark licenses are granted under this Agreement, and in connection with the Software Products, neither Stability AI nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Software Products.
|
29 |
-
Subject to Stability AI’s ownership of the Software Products and Derivative Works made by or for Stability AI, with respect to any Derivative Works that are made by you, as between you and Stability AI, you are and will be the owner of such Derivative Works.
|
30 |
-
If you institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Software Products or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to your use or distribution of the Software Products in violation of this Agreement.
|
31 |
-
4. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Software Products and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Software Products. Sections 2-4 shall survive the termination of this Agreement.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,75 +0,0 @@
|
|
1 |
-
SDXL 0.9 RESEARCH LICENSE AGREEMENT
|
2 |
-
Copyright (c) Stability AI Ltd.
|
3 |
-
This License Agreement (as may be amended in accordance with this License Agreement, “License”), between you, or your employer or other entity (if you are entering into this agreement on behalf of your employer or other entity) (“Licensee” or “you”) and Stability AI Ltd. (“Stability AI” or “we”) applies to your use of any computer program, algorithm, source code, object code, or software that is made available by Stability AI under this License (“Software”) and any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software (“Documentation”).
|
4 |
-
By clicking “I Accept” below or by using the Software, you agree to the terms of this License. If you do not agree to this License, then you do not have any rights to use the Software or Documentation (collectively, the “Software Products”), and you must immediately cease using the Software Products. If you are agreeing to be bound by the terms of this License on behalf of your employer or other entity, you represent and warrant to Stability AI that you have full legal authority to bind your employer or such entity to this License. If you do not have the requisite authority, you may not accept the License or access the Software Products on behalf of your employer or other entity.
|
5 |
-
1. LICENSE GRANT
|
6 |
-
|
7 |
-
a. Subject to your compliance with the Documentation and Sections 2, 3, and 5, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s copyright interests to reproduce, distribute, and create derivative works of the Software solely for your non-commercial research purposes. The foregoing license is personal to you, and you may not assign or sublicense this License or any other rights or obligations under this License without Stability AI’s prior written consent; any such assignment or sublicense will be void and will automatically and immediately terminate this License.
|
8 |
-
|
9 |
-
b. You may make a reasonable number of copies of the Documentation solely for use in connection with the license to the Software granted above.
|
10 |
-
|
11 |
-
c. The grant of rights expressly set forth in this Section 1 (License Grant) are the complete grant of rights to you in the Software Products, and no other licenses are granted, whether by waiver, estoppel, implication, equity or otherwise. Stability AI and its licensors reserve all rights not expressly granted by this License.
|
12 |
-
|
13 |
-
|
14 |
-
2. RESTRICTIONS
|
15 |
-
|
16 |
-
You will not, and will not permit, assist or cause any third party to:
|
17 |
-
|
18 |
-
a. use, modify, copy, reproduce, create derivative works of, or distribute the Software Products (or any derivative works thereof, works incorporating the Software Products, or any data produced by the Software), in whole or in part, for (i) any commercial or production purposes, (ii) military purposes or in the service of nuclear technology, (iii) purposes of surveillance, including any research or development relating to surveillance, (iv) biometric processing, (v) in any manner that infringes, misappropriates, or otherwise violates any third-party rights, or (vi) in any manner that violates any applicable law and violating any privacy or security laws, rules, regulations, directives, or governmental requirements (including the General Data Privacy Regulation (Regulation (EU) 2016/679), the California Consumer Privacy Act, and any and all laws governing the processing of biometric information), as well as all amendments and successor laws to any of the foregoing;
|
19 |
-
|
20 |
-
b. alter or remove copyright and other proprietary notices which appear on or in the Software Products;
|
21 |
-
|
22 |
-
c. utilize any equipment, device, software, or other means to circumvent or remove any security or protection used by Stability AI in connection with the Software, or to circumvent or remove any usage restrictions, or to enable functionality disabled by Stability AI; or
|
23 |
-
|
24 |
-
d. offer or impose any terms on the Software Products that alter, restrict, or are inconsistent with the terms of this License.
|
25 |
-
|
26 |
-
e. 1) violate any applicable U.S. and non-U.S. export control and trade sanctions laws (“Export Laws”); 2) directly or indirectly export, re-export, provide, or otherwise transfer Software Products: (a) to any individual, entity, or country prohibited by Export Laws; (b) to anyone on U.S. or non-U.S. government restricted parties lists; or (c) for any purpose prohibited by Export Laws, including nuclear, chemical or biological weapons, or missile technology applications; 3) use or download Software Products if you or they are: (a) located in a comprehensively sanctioned jurisdiction, (b) currently listed on any U.S. or non-U.S. restricted parties list, or (c) for any purpose prohibited by Export Laws; and (4) will not disguise your location through IP proxying or other methods.
|
27 |
-
|
28 |
-
|
29 |
-
3. ATTRIBUTION
|
30 |
-
|
31 |
-
Together with any copies of the Software Products (as well as derivative works thereof or works incorporating the Software Products) that you distribute, you must provide (i) a copy of this License, and (ii) the following attribution notice: “SDXL 0.9 is licensed under the SDXL Research License, Copyright (c) Stability AI Ltd. All Rights Reserved.”
|
32 |
-
|
33 |
-
|
34 |
-
4. DISCLAIMERS
|
35 |
-
|
36 |
-
THE SOFTWARE PRODUCTS ARE PROVIDED “AS IS” AND “WITH ALL FAULTS” WITH NO WARRANTY OF ANY KIND, EXPRESS OR IMPLIED. STABILITY AIEXPRESSLY DISCLAIMS ALL REPRESENTATIONS AND WARRANTIES, EXPRESS OR IMPLIED, WHETHER BY STATUTE, CUSTOM, USAGE OR OTHERWISE AS TO ANY MATTERS RELATED TO THE SOFTWARE PRODUCTS, INCLUDING BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE, SATISFACTORY QUALITY, OR NON-INFRINGEMENT. STABILITY AI MAKES NO WARRANTIES OR REPRESENTATIONS THAT THE SOFTWARE PRODUCTS WILL BE ERROR FREE OR FREE OF VIRUSES OR OTHER HARMFUL COMPONENTS, OR PRODUCE ANY PARTICULAR RESULTS.
|
37 |
-
|
38 |
-
|
39 |
-
5. LIMITATION OF LIABILITY
|
40 |
-
|
41 |
-
TO THE FULLEST EXTENT PERMITTED BY LAW, IN NO EVENT WILL STABILITY AI BE LIABLE TO YOU (A) UNDER ANY THEORY OF LIABILITY, WHETHER BASED IN CONTRACT, TORT, NEGLIGENCE, STRICT LIABILITY, WARRANTY, OR OTHERWISE UNDER THIS LICENSE, OR (B) FOR ANY INDIRECT, CONSEQUENTIAL, EXEMPLARY, INCIDENTAL, PUNITIVE OR SPECIAL DAMAGES OR LOST PROFITS, EVEN IF STABILITY AI HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. THE SOFTWARE PRODUCTS, THEIR CONSTITUENT COMPONENTS, AND ANY OUTPUT (COLLECTIVELY, “SOFTWARE MATERIALS”) ARE NOT DESIGNED OR INTENDED FOR USE IN ANY APPLICATION OR SITUATION WHERE FAILURE OR FAULT OF THE SOFTWARE MATERIALS COULD REASONABLY BE ANTICIPATED TO LEAD TO SERIOUS INJURY OF ANY PERSON, INCLUDING POTENTIAL DISCRIMINATION OR VIOLATION OF AN INDIVIDUAL’S PRIVACY RIGHTS, OR TO SEVERE PHYSICAL, PROPERTY, OR ENVIRONMENTAL DAMAGE (EACH, A “HIGH-RISK USE”). IF YOU ELECT TO USE ANY OF THE SOFTWARE MATERIALS FOR A HIGH-RISK USE, YOU DO SO AT YOUR OWN RISK. YOU AGREE TO DESIGN AND IMPLEMENT APPROPRIATE DECISION-MAKING AND RISK-MITIGATION PROCEDURES AND POLICIES IN CONNECTION WITH A HIGH-RISK USE SUCH THAT EVEN IF THERE IS A FAILURE OR FAULT IN ANY OF THE SOFTWARE MATERIALS, THE SAFETY OF PERSONS OR PROPERTY AFFECTED BY THE ACTIVITY STAYS AT A LEVEL THAT IS REASONABLE, APPROPRIATE, AND LAWFUL FOR THE FIELD OF THE HIGH-RISK USE.
|
42 |
-
|
43 |
-
|
44 |
-
6. INDEMNIFICATION
|
45 |
-
|
46 |
-
You will indemnify, defend and hold harmless Stability AI and our subsidiaries and affiliates, and each of our respective shareholders, directors, officers, employees, agents, successors, and assigns (collectively, the “Stability AI Parties”) from and against any losses, liabilities, damages, fines, penalties, and expenses (including reasonable attorneys’ fees) incurred by any Stability AI Party in connection with any claim, demand, allegation, lawsuit, proceeding, or investigation (collectively, “Claims”) arising out of or related to: (a) your access to or use of the Software Products (as well as any results or data generated from such access or use), including any High-Risk Use (defined below); (b) your violation of this License; or (c) your violation, misappropriation or infringement of any rights of another (including intellectual property or other proprietary rights and privacy rights). You will promptly notify the Stability AI Parties of any such Claims, and cooperate with Stability AI Parties in defending such Claims. You will also grant the Stability AI Parties sole control of the defense or settlement, at Stability AI’s sole option, of any Claims. This indemnity is in addition to, and not in lieu of, any other indemnities or remedies set forth in a written agreement between you and Stability AI or the other Stability AI Parties.
|
47 |
-
|
48 |
-
|
49 |
-
7. TERMINATION; SURVIVAL
|
50 |
-
|
51 |
-
a. This License will automatically terminate upon any breach by you of the terms of this License.
|
52 |
-
|
53 |
-
b. We may terminate this License, in whole or in part, at any time upon notice (including electronic) to you.
|
54 |
-
|
55 |
-
c. The following sections survive termination of this License: 2 (Restrictions), 3 (Attribution), 4 (Disclaimers), 5 (Limitation on Liability), 6 (Indemnification) 7 (Termination; Survival), 8 (Third Party Materials), 9 (Trademarks), 10 (Applicable Law; Dispute Resolution), and 11 (Miscellaneous).
|
56 |
-
|
57 |
-
|
58 |
-
8. THIRD PARTY MATERIALS
|
59 |
-
|
60 |
-
The Software Products may contain third-party software or other components (including free and open source software) (all of the foregoing, “Third Party Materials”), which are subject to the license terms of the respective third-party licensors. Your dealings or correspondence with third parties and your use of or interaction with any Third Party Materials are solely between you and the third party. Stability AI does not control or endorse, and makes no representations or warranties regarding, any Third Party Materials, and your access to and use of such Third Party Materials are at your own risk.
|
61 |
-
|
62 |
-
|
63 |
-
9. TRADEMARKS
|
64 |
-
|
65 |
-
Licensee has not been granted any trademark license as part of this License and may not use any name or mark associated with Stability AI without the prior written permission of Stability AI, except to the extent necessary to make the reference required by the “ATTRIBUTION” section of this Agreement.
|
66 |
-
|
67 |
-
|
68 |
-
10. APPLICABLE LAW; DISPUTE RESOLUTION
|
69 |
-
|
70 |
-
This License will be governed and construed under the laws of the State of California without regard to conflicts of law provisions. Any suit or proceeding arising out of or relating to this License will be brought in the federal or state courts, as applicable, in San Mateo County, California, and each party irrevocably submits to the jurisdiction and venue of such courts.
|
71 |
-
|
72 |
-
|
73 |
-
11. MISCELLANEOUS
|
74 |
-
|
75 |
-
If any provision or part of a provision of this License is unlawful, void or unenforceable, that provision or part of the provision is deemed severed from this License, and will not affect the validity and enforceability of any remaining provisions. The failure of Stability AI to exercise or enforce any right or provision of this License will not operate as a waiver of such right or provision. This License does not confer any third-party beneficiary rights upon any other person or entity. This License, together with the Documentation, contains the entire understanding between you and Stability AI regarding the subject matter of this License, and supersedes all other written or oral agreements and understandings between you and Stability AI regarding such subject matter. No change or addition to any provision of this License will be binding unless it is in writing and signed by an authorized representative of both you and Stability AI.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,175 +0,0 @@
|
|
1 |
-
Copyright (c) 2023 Stability AI CreativeML Open RAIL++-M License dated July 26, 2023
|
2 |
-
|
3 |
-
Section I: PREAMBLE Multimodal generative models are being widely adopted and used, and
|
4 |
-
have the potential to transform the way artists, among other individuals, conceive and
|
5 |
-
benefit from AI or ML technologies as a tool for content creation. Notwithstanding the
|
6 |
-
current and potential benefits that these artifacts can bring to society at large, there
|
7 |
-
are also concerns about potential misuses of them, either due to their technical
|
8 |
-
limitations or ethical considerations. In short, this license strives for both the open
|
9 |
-
and responsible downstream use of the accompanying model. When it comes to the open
|
10 |
-
character, we took inspiration from open source permissive licenses regarding the grant
|
11 |
-
of IP rights. Referring to the downstream responsible use, we added use-based
|
12 |
-
restrictions not permitting the use of the model in very specific scenarios, in order
|
13 |
-
for the licensor to be able to enforce the license in case potential misuses of the
|
14 |
-
Model may occur. At the same time, we strive to promote open and responsible research on
|
15 |
-
generative models for art and content generation. Even though downstream derivative
|
16 |
-
versions of the model could be released under different licensing terms, the latter will
|
17 |
-
always have to include - at minimum - the same use-based restrictions as the ones in the
|
18 |
-
original license (this license). We believe in the intersection between open and
|
19 |
-
responsible AI development; thus, this agreement aims to strike a balance between both
|
20 |
-
in order to enable responsible open-science in the field of AI. This CreativeML Open
|
21 |
-
RAIL++-M License governs the use of the model (and its derivatives) and is informed by
|
22 |
-
the model card associated with the model. NOW THEREFORE, You and Licensor agree as
|
23 |
-
follows: Definitions "License" means the terms and conditions for use, reproduction, and
|
24 |
-
Distribution as defined in this document. "Data" means a collection of information
|
25 |
-
and/or content extracted from the dataset used with the Model, including to train,
|
26 |
-
pretrain, or otherwise evaluate the Model. The Data is not licensed under this License.
|
27 |
-
"Output" means the results of operating a Model as embodied in informational content
|
28 |
-
resulting therefrom. "Model" means any accompanying machine-learning based assemblies
|
29 |
-
(including checkpoints), consisting of learnt weights, parameters (including optimizer
|
30 |
-
states), corresponding to the model architecture as embodied in the Complementary
|
31 |
-
Material, that have been trained or tuned, in whole or in part on the Data, using the
|
32 |
-
Complementary Material. "Derivatives of the Model" means all modifications to the Model,
|
33 |
-
works based on the Model, or any other model which is created or initialized by transfer
|
34 |
-
of patterns of the weights, parameters, activations or output of the Model, to the other
|
35 |
-
model, in order to cause the other model to perform similarly to the Model, including -
|
36 |
-
but not limited to - distillation methods entailing the use of intermediate data
|
37 |
-
representations or methods based on the generation of synthetic data by the Model for
|
38 |
-
training the other model. "Complementary Material" means the accompanying source code
|
39 |
-
and scripts used to define, run, load, benchmark or evaluate the Model, and used to
|
40 |
-
prepare data for training or evaluation, if any. This includes any accompanying
|
41 |
-
documentation, tutorials, examples, etc, if any. "Distribution" means any transmission,
|
42 |
-
reproduction, publication or other sharing of the Model or Derivatives of the Model to a
|
43 |
-
third party, including providing the Model as a hosted service made available by
|
44 |
-
electronic or other remote means - e.g. API-based or web access. "Licensor" means the
|
45 |
-
copyright owner or entity authorized by the copyright owner that is granting the
|
46 |
-
License, including the persons or entities that may have rights in the Model and/or
|
47 |
-
distributing the Model. "You" (or "Your") means an individual or Legal Entity exercising
|
48 |
-
permissions granted by this License and/or making use of the Model for whichever purpose
|
49 |
-
and in any field of use, including usage of the Model in an end-use application - e.g.
|
50 |
-
chatbot, translator, image generator. "Third Parties" means individuals or legal
|
51 |
-
entities that are not under common control with Licensor or You. "Contribution" means
|
52 |
-
any work of authorship, including the original version of the Model and any
|
53 |
-
modifications or additions to that Model or Derivatives of the Model thereof, that is
|
54 |
-
intentionally submitted to Licensor for inclusion in the Model by the copyright owner or
|
55 |
-
by an individual or Legal Entity authorized to submit on behalf of the copyright owner.
|
56 |
-
For the purposes of this definition, "submitted" means any form of electronic, verbal,
|
57 |
-
or written communication sent to the Licensor or its representatives, including but not
|
58 |
-
limited to communication on electronic mailing lists, source code control systems, and
|
59 |
-
issue tracking systems that are managed by, or on behalf of, the Licensor for the
|
60 |
-
purpose of discussing and improving the Model, but excluding communication that is
|
61 |
-
conspicuously marked or otherwise designated in writing by the copyright owner as "Not a
|
62 |
-
Contribution." "Contributor" means Licensor and any individual or Legal Entity on behalf
|
63 |
-
of whom a Contribution has been received by Licensor and subsequently incorporated
|
64 |
-
within the Model.
|
65 |
-
|
66 |
-
Section II: INTELLECTUAL PROPERTY RIGHTS Both copyright and patent grants apply to the
|
67 |
-
Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of
|
68 |
-
the Model are subject to additional terms as described in
|
69 |
-
|
70 |
-
Section III. Grant of Copyright License. Subject to the terms and conditions of this
|
71 |
-
License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive,
|
72 |
-
no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly
|
73 |
-
display, publicly perform, sublicense, and distribute the Complementary Material, the
|
74 |
-
Model, and Derivatives of the Model. Grant of Patent License. Subject to the terms and
|
75 |
-
conditions of this License and where and as applicable, each Contributor hereby grants
|
76 |
-
to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
77 |
-
(except as stated in this paragraph) patent license to make, have made, use, offer to
|
78 |
-
sell, sell, import, and otherwise transfer the Model and the Complementary Material,
|
79 |
-
where such license applies only to those patent claims licensable by such Contributor
|
80 |
-
that are necessarily infringed by their Contribution(s) alone or by combination of their
|
81 |
-
Contribution(s) with the Model to which such Contribution(s) was submitted. If You
|
82 |
-
institute patent litigation against any entity (including a cross-claim or counterclaim
|
83 |
-
in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution
|
84 |
-
incorporated within the Model and/or Complementary Material constitutes direct or
|
85 |
-
contributory patent infringement, then any patent licenses granted to You under this
|
86 |
-
License for the Model and/or Work shall terminate as of the date such litigation is
|
87 |
-
asserted or filed. Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION
|
88 |
-
Distribution and Redistribution. You may host for Third Party remote access purposes
|
89 |
-
(e.g. software-as-a-service), reproduce and distribute copies of the Model or
|
90 |
-
Derivatives of the Model thereof in any medium, with or without modifications, provided
|
91 |
-
that You meet the following conditions: Use-based restrictions as referenced in
|
92 |
-
paragraph 5 MUST be included as an enforceable provision by You in any type of legal
|
93 |
-
agreement (e.g. a license) governing the use and/or distribution of the Model or
|
94 |
-
Derivatives of the Model, and You shall give notice to subsequent users You Distribute
|
95 |
-
to, that the Model or Derivatives of the Model are subject to paragraph 5. This
|
96 |
-
provision does not apply to the use of Complementary Material. You must give any Third
|
97 |
-
Party recipients of the Model or Derivatives of the Model a copy of this License; You
|
98 |
-
must cause any modified files to carry prominent notices stating that You changed the
|
99 |
-
files; You must retain all copyright, patent, trademark, and attribution notices
|
100 |
-
excluding those notices that do not pertain to any part of the Model, Derivatives of the
|
101 |
-
Model. You may add Your own copyright statement to Your modifications and may provide
|
102 |
-
additional or different license terms and conditions - respecting paragraph 4.a. - for
|
103 |
-
use, reproduction, or Distribution of Your modifications, or for any such Derivatives of
|
104 |
-
the Model as a whole, provided Your use, reproduction, and Distribution of the Model
|
105 |
-
otherwise complies with the conditions stated in this License. Use-based restrictions.
|
106 |
-
The restrictions set forth in Attachment A are considered Use-based restrictions.
|
107 |
-
Therefore You cannot use the Model and the Derivatives of the Model for the specified
|
108 |
-
restricted uses. You may use the Model subject to this License, including only for
|
109 |
-
lawful purposes and in accordance with the License. Use may include creating any content
|
110 |
-
with, finetuning, updating, running, training, evaluating and/or reparametrizing the
|
111 |
-
Model. You shall require all of Your users who use the Model or a Derivative of the
|
112 |
-
Model to comply with the terms of this paragraph (paragraph 5). The Output You Generate.
|
113 |
-
Except as set forth herein, Licensor claims no rights in the Output You generate using
|
114 |
-
the Model. You are accountable for the Output you generate and its subsequent uses. No
|
115 |
-
use of the output can contravene any provision as stated in the License.
|
116 |
-
|
117 |
-
Section IV: OTHER PROVISIONS Updates and Runtime Restrictions. To the maximum extent
|
118 |
-
permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage
|
119 |
-
of the Model in violation of this License. Trademarks and related. Nothing in this
|
120 |
-
License permits You to make use of Licensors’ trademarks, trade names, logos or to
|
121 |
-
otherwise suggest endorsement or misrepresent the relationship between the parties; and
|
122 |
-
any rights not expressly granted herein are reserved by the Licensors. Disclaimer of
|
123 |
-
Warranty. Unless required by applicable law or agreed to in writing, Licensor provides
|
124 |
-
the Model and the Complementary Material (and each Contributor provides its
|
125 |
-
Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
126 |
-
express or implied, including, without limitation, any warranties or conditions of
|
127 |
-
TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are
|
128 |
-
solely responsible for determining the appropriateness of using or redistributing the
|
129 |
-
Model, Derivatives of the Model, and the Complementary Material and assume any risks
|
130 |
-
associated with Your exercise of permissions under this License. Limitation of
|
131 |
-
Liability. In no event and under no legal theory, whether in tort (including
|
132 |
-
negligence), contract, or otherwise, unless required by applicable law (such as
|
133 |
-
deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be
|
134 |
-
liable to You for damages, including any direct, indirect, special, incidental, or
|
135 |
-
consequential damages of any character arising as a result of this License or out of the
|
136 |
-
use or inability to use the Model and the Complementary Material (including but not
|
137 |
-
limited to damages for loss of goodwill, work stoppage, computer failure or malfunction,
|
138 |
-
or any and all other commercial damages or losses), even if such Contributor has been
|
139 |
-
advised of the possibility of such damages. Accepting Warranty or Additional Liability.
|
140 |
-
While redistributing the Model, Derivatives of the Model and the Complementary Material
|
141 |
-
thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty,
|
142 |
-
indemnity, or other liability obligations and/or rights consistent with this License.
|
143 |
-
However, in accepting such obligations, You may act only on Your own behalf and on Your
|
144 |
-
sole responsibility, not on behalf of any other Contributor, and only if You agree to
|
145 |
-
indemnify, defend, and hold each Contributor harmless for any liability incurred by, or
|
146 |
-
claims asserted against, such Contributor by reason of your accepting any such warranty
|
147 |
-
or additional liability. If any provision of this License is held to be invalid, illegal
|
148 |
-
or unenforceable, the remaining provisions shall be unaffected thereby and remain valid
|
149 |
-
as if such provision had not been set forth herein.
|
150 |
-
|
151 |
-
END OF TERMS AND CONDITIONS
|
152 |
-
|
153 |
-
Attachment A Use Restrictions
|
154 |
-
You agree not to use the Model or Derivatives of the Model:
|
155 |
-
In any way that violates any applicable national, federal, state, local or
|
156 |
-
international law or regulation; For the purpose of exploiting, harming or attempting to
|
157 |
-
exploit or harm minors in any way; To generate or disseminate verifiably false
|
158 |
-
information and/or content with the purpose of harming others; To generate or
|
159 |
-
disseminate personal identifiable information that can be used to harm an individual; To
|
160 |
-
defame, disparage or otherwise harass others; For fully automated decision making that
|
161 |
-
adversely impacts an individual’s legal rights or otherwise creates or modifies a
|
162 |
-
binding, enforceable obligation; For any use intended to or which has the effect of
|
163 |
-
discriminating against or harming individuals or groups based on online or offline
|
164 |
-
social behavior or known or predicted personal or personality characteristics; To
|
165 |
-
exploit any of the vulnerabilities of a specific group of persons based on their age,
|
166 |
-
social, physical or mental characteristics, in order to materially distort the behavior
|
167 |
-
of a person pertaining to that group in a manner that causes or is likely to cause that
|
168 |
-
person or another person physical or psychological harm; For any use intended to or
|
169 |
-
which has the effect of discriminating against individuals or groups based on legally
|
170 |
-
protected characteristics or categories; To provide medical advice and medical results
|
171 |
-
interpretation; To generate or disseminate information for the purpose to be used for
|
172 |
-
administration of justice, law enforcement, immigration or asylum processes, such as
|
173 |
-
predicting an individual will commit fraud/crime commitment (e.g. by text profiling,
|
174 |
-
drawing causal relationships between assertions made in documents, indiscriminate and
|
175 |
-
arbitrarily-targeted use).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,48 +0,0 @@
|
|
1 |
-
[build-system]
|
2 |
-
requires = ["hatchling"]
|
3 |
-
build-backend = "hatchling.build"
|
4 |
-
|
5 |
-
[project]
|
6 |
-
name = "sgm"
|
7 |
-
dynamic = ["version"]
|
8 |
-
description = "Stability Generative Models"
|
9 |
-
readme = "README.md"
|
10 |
-
license-files = { paths = ["LICENSE-CODE"] }
|
11 |
-
requires-python = ">=3.8"
|
12 |
-
|
13 |
-
[project.urls]
|
14 |
-
Homepage = "https://github.com/Stability-AI/generative-models"
|
15 |
-
|
16 |
-
[tool.hatch.version]
|
17 |
-
path = "sgm/__init__.py"
|
18 |
-
|
19 |
-
[tool.hatch.build]
|
20 |
-
# This needs to be explicitly set so the configuration files
|
21 |
-
# grafted into the `sgm` directory get included in the wheel's
|
22 |
-
# RECORD file.
|
23 |
-
include = [
|
24 |
-
"sgm",
|
25 |
-
]
|
26 |
-
# The force-include configurations below make Hatch copy
|
27 |
-
# the configs/ directory (containing the various YAML files required
|
28 |
-
# to generatively model) into the source distribution and the wheel.
|
29 |
-
|
30 |
-
[tool.hatch.build.targets.sdist.force-include]
|
31 |
-
"./configs" = "sgm/configs"
|
32 |
-
|
33 |
-
[tool.hatch.build.targets.wheel.force-include]
|
34 |
-
"./configs" = "sgm/configs"
|
35 |
-
|
36 |
-
[tool.hatch.envs.ci]
|
37 |
-
skip-install = false
|
38 |
-
|
39 |
-
dependencies = [
|
40 |
-
"pytest"
|
41 |
-
]
|
42 |
-
|
43 |
-
[tool.hatch.envs.ci.scripts]
|
44 |
-
test-inference = [
|
45 |
-
"pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 --index-url https://download.pytorch.org/whl/cu118",
|
46 |
-
"pip install -r requirements/pt2.txt",
|
47 |
-
"pytest -v tests/inference/test_inference.py {args}",
|
48 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,3 +0,0 @@
|
|
1 |
-
[pytest]
|
2 |
-
markers =
|
3 |
-
inference: mark as inference test (deselect with '-m "not inference"')
|
|
|
|
|
|
|
|
@@ -1,42 +1,7 @@
|
|
1 |
https://gradio-builds.s3.amazonaws.com/756e3431d65172df986a7e335dce8136206a293a/gradio-4.7.1-py3-none-any.whl
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
fire>=0.5.0
|
8 |
-
fsspec>=2023.6.0
|
9 |
-
invisible-watermark>=0.2.0
|
10 |
-
kornia==0.6.9
|
11 |
-
matplotlib>=3.7.2
|
12 |
-
natsort>=8.4.0
|
13 |
-
ninja>=1.11.1
|
14 |
-
numpy>=1.24.4
|
15 |
-
omegaconf>=2.3.0
|
16 |
-
open-clip-torch>=2.20.0
|
17 |
-
opencv-python==4.6.0.66
|
18 |
-
pandas>=2.0.3
|
19 |
-
pillow>=9.5.0
|
20 |
-
pudb>=2022.1.3
|
21 |
-
pytorch-lightning==2.0.1
|
22 |
-
pyyaml>=6.0.1
|
23 |
-
scipy>=1.10.1
|
24 |
-
streamlit>=0.73.1
|
25 |
-
tensorboardx==2.6
|
26 |
-
timm>=0.9.2
|
27 |
-
tokenizers==0.12.1
|
28 |
-
torch>=2.0.1
|
29 |
-
torchaudio>=2.0.2
|
30 |
-
torchdata==0.6.1
|
31 |
-
torchmetrics>=1.0.1
|
32 |
-
torchvision>=0.15.2
|
33 |
-
tqdm>=4.65.0
|
34 |
-
transformers==4.19.1
|
35 |
-
triton==2.0.0
|
36 |
-
urllib3<1.27,>=1.25.4
|
37 |
-
wandb>=0.15.6
|
38 |
-
webdataset>=0.2.33
|
39 |
-
wheel>=0.41.0
|
40 |
-
xformers>=0.0.20
|
41 |
-
fire
|
42 |
uuid
|
|
|
1 |
https://gradio-builds.s3.amazonaws.com/756e3431d65172df986a7e335dce8136206a293a/gradio-4.7.1-py3-none-any.whl
|
2 |
+
git+https://github.com/huggingface/diffusers.git@refs/pull/5895/head
|
3 |
+
transformers
|
4 |
+
accelerate
|
5 |
+
safetensors
|
6 |
+
opencv-python
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
uuid
|
@@ -1,40 +0,0 @@
|
|
1 |
-
black==23.7.0
|
2 |
-
chardet>=5.1.0
|
3 |
-
clip @ git+https://github.com/openai/CLIP.git
|
4 |
-
einops>=0.6.1
|
5 |
-
fairscale>=0.4.13
|
6 |
-
fire>=0.5.0
|
7 |
-
fsspec>=2023.6.0
|
8 |
-
invisible-watermark>=0.2.0
|
9 |
-
kornia==0.6.9
|
10 |
-
matplotlib>=3.7.2
|
11 |
-
natsort>=8.4.0
|
12 |
-
numpy>=1.24.4
|
13 |
-
omegaconf>=2.3.0
|
14 |
-
onnx<=1.12.0
|
15 |
-
open-clip-torch>=2.20.0
|
16 |
-
opencv-python==4.6.0.66
|
17 |
-
pandas>=2.0.3
|
18 |
-
pillow>=9.5.0
|
19 |
-
pudb>=2022.1.3
|
20 |
-
pytorch-lightning==1.8.5
|
21 |
-
pyyaml>=6.0.1
|
22 |
-
scipy>=1.10.1
|
23 |
-
streamlit>=1.25.0
|
24 |
-
tensorboardx==2.5.1
|
25 |
-
timm>=0.9.2
|
26 |
-
tokenizers==0.12.1
|
27 |
-
--extra-index-url https://download.pytorch.org/whl/cu117
|
28 |
-
torch==1.13.1+cu117
|
29 |
-
torchaudio==0.13.1
|
30 |
-
torchdata==0.5.1
|
31 |
-
torchmetrics>=1.0.1
|
32 |
-
torchvision==0.14.1+cu117
|
33 |
-
tqdm>=4.65.0
|
34 |
-
transformers==4.19.1
|
35 |
-
triton==2.0.0.post1
|
36 |
-
urllib3<1.27,>=1.25.4
|
37 |
-
wandb>=0.15.6
|
38 |
-
webdataset>=0.2.33
|
39 |
-
wheel>=0.41.0
|
40 |
-
xformers==0.0.16
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,39 +0,0 @@
|
|
1 |
-
black==23.7.0
|
2 |
-
chardet==5.1.0
|
3 |
-
clip @ git+https://github.com/openai/CLIP.git
|
4 |
-
einops>=0.6.1
|
5 |
-
fairscale>=0.4.13
|
6 |
-
fire>=0.5.0
|
7 |
-
fsspec>=2023.6.0
|
8 |
-
invisible-watermark>=0.2.0
|
9 |
-
kornia==0.6.9
|
10 |
-
matplotlib>=3.7.2
|
11 |
-
natsort>=8.4.0
|
12 |
-
ninja>=1.11.1
|
13 |
-
numpy>=1.24.4
|
14 |
-
omegaconf>=2.3.0
|
15 |
-
open-clip-torch>=2.20.0
|
16 |
-
opencv-python==4.6.0.66
|
17 |
-
pandas>=2.0.3
|
18 |
-
pillow>=9.5.0
|
19 |
-
pudb>=2022.1.3
|
20 |
-
pytorch-lightning==2.0.1
|
21 |
-
pyyaml>=6.0.1
|
22 |
-
scipy>=1.10.1
|
23 |
-
streamlit>=0.73.1
|
24 |
-
tensorboardx==2.6
|
25 |
-
timm>=0.9.2
|
26 |
-
tokenizers==0.12.1
|
27 |
-
torch>=2.0.1
|
28 |
-
torchaudio>=2.0.2
|
29 |
-
torchdata==0.6.1
|
30 |
-
torchmetrics>=1.0.1
|
31 |
-
torchvision>=0.15.2
|
32 |
-
tqdm>=4.65.0
|
33 |
-
transformers==4.19.1
|
34 |
-
triton==2.0.0
|
35 |
-
urllib3<1.27,>=1.25.4
|
36 |
-
wandb>=0.15.6
|
37 |
-
webdataset>=0.2.33
|
38 |
-
wheel>=0.41.0
|
39 |
-
xformers>=0.0.20
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Binary file (6.15 kB)
|
|
File without changes
|
File without changes
|
@@ -1,156 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
|
3 |
-
import cv2
|
4 |
-
import numpy as np
|
5 |
-
|
6 |
-
try:
|
7 |
-
from imwatermark import WatermarkDecoder
|
8 |
-
except ImportError as e:
|
9 |
-
try:
|
10 |
-
# Assume some of the other dependencies such as torch are not fulfilled
|
11 |
-
# import file without loading unnecessary libraries.
|
12 |
-
import importlib.util
|
13 |
-
import sys
|
14 |
-
|
15 |
-
spec = importlib.util.find_spec("imwatermark.maxDct")
|
16 |
-
assert spec is not None
|
17 |
-
maxDct = importlib.util.module_from_spec(spec)
|
18 |
-
sys.modules["maxDct"] = maxDct
|
19 |
-
spec.loader.exec_module(maxDct)
|
20 |
-
|
21 |
-
class WatermarkDecoder(object):
|
22 |
-
"""A minimal version of
|
23 |
-
https://github.com/ShieldMnt/invisible-watermark/blob/main/imwatermark/watermark.py
|
24 |
-
to only reconstruct bits using dwtDct"""
|
25 |
-
|
26 |
-
def __init__(self, wm_type="bytes", length=0):
|
27 |
-
assert wm_type == "bits", "Only bits defined in minimal import"
|
28 |
-
self._wmType = wm_type
|
29 |
-
self._wmLen = length
|
30 |
-
|
31 |
-
def reconstruct(self, bits):
|
32 |
-
if len(bits) != self._wmLen:
|
33 |
-
raise RuntimeError("bits are not matched with watermark length")
|
34 |
-
|
35 |
-
return bits
|
36 |
-
|
37 |
-
def decode(self, cv2Image, method="dwtDct", **configs):
|
38 |
-
(r, c, channels) = cv2Image.shape
|
39 |
-
if r * c < 256 * 256:
|
40 |
-
raise RuntimeError("image too small, should be larger than 256x256")
|
41 |
-
|
42 |
-
bits = []
|
43 |
-
assert method == "dwtDct"
|
44 |
-
embed = maxDct.EmbedMaxDct(watermarks=[], wmLen=self._wmLen, **configs)
|
45 |
-
bits = embed.decode(cv2Image)
|
46 |
-
return self.reconstruct(bits)
|
47 |
-
|
48 |
-
except:
|
49 |
-
raise e
|
50 |
-
|
51 |
-
|
52 |
-
# A fixed 48-bit message that was choosen at random
|
53 |
-
# WATERMARK_MESSAGE = 0xB3EC907BB19E
|
54 |
-
WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
|
55 |
-
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
|
56 |
-
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
|
57 |
-
MATCH_VALUES = [
|
58 |
-
[27, "No watermark detected"],
|
59 |
-
[33, "Partial watermark match. Cannot determine with certainty."],
|
60 |
-
[
|
61 |
-
35,
|
62 |
-
(
|
63 |
-
"Likely watermarked. In our test 0.02% of real images were "
|
64 |
-
'falsely detected as "Likely watermarked"'
|
65 |
-
),
|
66 |
-
],
|
67 |
-
[
|
68 |
-
49,
|
69 |
-
(
|
70 |
-
"Very likely watermarked. In our test no real images were "
|
71 |
-
'falsely detected as "Very likely watermarked"'
|
72 |
-
),
|
73 |
-
],
|
74 |
-
]
|
75 |
-
|
76 |
-
|
77 |
-
class GetWatermarkMatch:
|
78 |
-
def __init__(self, watermark):
|
79 |
-
self.watermark = watermark
|
80 |
-
self.num_bits = len(self.watermark)
|
81 |
-
self.decoder = WatermarkDecoder("bits", self.num_bits)
|
82 |
-
|
83 |
-
def __call__(self, x: np.ndarray) -> np.ndarray:
|
84 |
-
"""
|
85 |
-
Detects the number of matching bits the predefined watermark with one
|
86 |
-
or multiple images. Images should be in cv2 format, e.g. h x w x c BGR.
|
87 |
-
|
88 |
-
Args:
|
89 |
-
x: ([B], h w, c) in range [0, 255]
|
90 |
-
|
91 |
-
Returns:
|
92 |
-
number of matched bits ([B],)
|
93 |
-
"""
|
94 |
-
squeeze = len(x.shape) == 3
|
95 |
-
if squeeze:
|
96 |
-
x = x[None, ...]
|
97 |
-
|
98 |
-
bs = x.shape[0]
|
99 |
-
detected = np.empty((bs, self.num_bits), dtype=bool)
|
100 |
-
for k in range(bs):
|
101 |
-
detected[k] = self.decoder.decode(x[k], "dwtDct")
|
102 |
-
result = np.sum(detected == self.watermark, axis=-1)
|
103 |
-
if squeeze:
|
104 |
-
return result[0]
|
105 |
-
else:
|
106 |
-
return result
|
107 |
-
|
108 |
-
|
109 |
-
get_watermark_match = GetWatermarkMatch(WATERMARK_BITS)
|
110 |
-
|
111 |
-
|
112 |
-
if __name__ == "__main__":
|
113 |
-
parser = argparse.ArgumentParser()
|
114 |
-
parser.add_argument(
|
115 |
-
"filename",
|
116 |
-
nargs="+",
|
117 |
-
type=str,
|
118 |
-
help="Image files to check for watermarks",
|
119 |
-
)
|
120 |
-
opts = parser.parse_args()
|
121 |
-
|
122 |
-
print(
|
123 |
-
"""
|
124 |
-
This script tries to detect watermarked images. Please be aware of
|
125 |
-
the following:
|
126 |
-
- As the watermark is supposed to be invisible, there is the risk that
|
127 |
-
watermarked images may not be detected.
|
128 |
-
- To maximize the chance of detection make sure that the image has the same
|
129 |
-
dimensions as when the watermark was applied (most likely 1024x1024
|
130 |
-
or 512x512).
|
131 |
-
- Specific image manipulation may drastically decrease the chance that
|
132 |
-
watermarks can be detected.
|
133 |
-
- There is also the chance that an image has the characteristics of the
|
134 |
-
watermark by chance.
|
135 |
-
- The watermark script is public, anybody may watermark any images, and
|
136 |
-
could therefore claim it to be generated.
|
137 |
-
- All numbers below are based on a test using 10,000 images without any
|
138 |
-
modifications after applying the watermark.
|
139 |
-
"""
|
140 |
-
)
|
141 |
-
|
142 |
-
for fn in opts.filename:
|
143 |
-
image = cv2.imread(fn)
|
144 |
-
if image is None:
|
145 |
-
print(f"Couldn't read {fn}. Skipping")
|
146 |
-
continue
|
147 |
-
|
148 |
-
num_bits = get_watermark_match(image)
|
149 |
-
k = 0
|
150 |
-
while num_bits > MATCH_VALUES[k][0]:
|
151 |
-
k += 1
|
152 |
-
print(
|
153 |
-
f"{fn}: {MATCH_VALUES[k][1]}",
|
154 |
-
f"Bits that matched the watermark {num_bits} from {len(WATERMARK_BITS)}\n",
|
155 |
-
sep="\n\t",
|
156 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|