multimodalart (HF staff) committed
Commit a3f8f46
1 Parent(s): e3d310b

Upload 147 files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +3 -0
  2. CODEOWNERS +1 -0
  3. LICENSE-CODE +21 -0
  4. README.md +292 -11
  5. assets/000.jpg +0 -0
  6. assets/001_with_eval.png +3 -0
  7. assets/test_image.png +0 -0
  8. assets/tile.gif +3 -0
  9. data/DejaVuSans.ttf +0 -0
  10. main.py +943 -0
  11. model_licenses/LICENSE-SDV +31 -0
  12. model_licenses/LICENSE-SDXL0.9 +75 -0
  13. model_licenses/LICENSE-SDXL1.0 +175 -0
  14. outputs/000000.mp4 +0 -0
  15. outputs/000001.mp4 +0 -0
  16. outputs/000002.mp4 +0 -0
  17. outputs/000003.mp4 +0 -0
  18. outputs/000004.mp4 +3 -0
  19. outputs/000005.mp4 +0 -0
  20. outputs/simple_video_sample/svd_xt/000000.mp4 +0 -0
  21. pyproject.toml +48 -0
  22. pytest.ini +3 -0
  23. requirements/pt13.txt +40 -0
  24. scripts/__pycache__/__init__.cpython-310.pyc +0 -0
  25. scripts/util/__pycache__/__init__.cpython-310.pyc +0 -0
  26. scripts/util/detection/__pycache__/__init__.cpython-310.pyc +0 -0
  27. scripts/util/detection/__pycache__/nsfw_and_watermark_dectection.cpython-310.pyc +0 -0
  28. sgm/__pycache__/__init__.cpython-310.pyc +0 -0
  29. sgm/__pycache__/util.cpython-310.pyc +0 -0
  30. sgm/inference/__pycache__/helpers.cpython-310.pyc +0 -0
  31. sgm/inference/api.py +8 -9
  32. sgm/models/__pycache__/__init__.cpython-310.pyc +0 -0
  33. sgm/models/__pycache__/autoencoder.cpython-310.pyc +0 -0
  34. sgm/models/__pycache__/diffusion.cpython-310.pyc +0 -0
  35. sgm/models/autoencoder.py +2 -6
  36. sgm/models/diffusion.py +2 -7
  37. sgm/modules/__pycache__/__init__.cpython-310.pyc +0 -0
  38. sgm/modules/__pycache__/attention.cpython-310.pyc +0 -0
  39. sgm/modules/__pycache__/ema.cpython-310.pyc +0 -0
  40. sgm/modules/__pycache__/video_attention.cpython-310.pyc +0 -0
  41. sgm/modules/autoencoding/__pycache__/__init__.cpython-310.pyc +0 -0
  42. sgm/modules/autoencoding/__pycache__/temporal_ae.cpython-310.pyc +0 -0
  43. sgm/modules/autoencoding/regularizers/__init__.py +2 -1
  44. sgm/modules/autoencoding/regularizers/__pycache__/__init__.cpython-310.pyc +0 -0
  45. sgm/modules/autoencoding/regularizers/__pycache__/base.cpython-310.pyc +0 -0
  46. sgm/modules/autoencoding/temporal_ae.py +4 -6
  47. sgm/modules/diffusionmodules/__pycache__/__init__.cpython-310.pyc +0 -0
  48. sgm/modules/diffusionmodules/__pycache__/denoiser.cpython-310.pyc +0 -0
  49. sgm/modules/diffusionmodules/__pycache__/denoiser_scaling.cpython-310.pyc +0 -0
  50. sgm/modules/diffusionmodules/__pycache__/discretizer.cpython-310.pyc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/001_with_eval.png filter=lfs diff=lfs merge=lfs -text
+assets/tile.gif filter=lfs diff=lfs merge=lfs -text
+outputs/000004.mp4 filter=lfs diff=lfs merge=lfs -text
CODEOWNERS ADDED
@@ -0,0 +1 @@
.github @Stability-AI/infrastructure
LICENSE-CODE ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 Stability AI

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md CHANGED
@@ -1,11 +1,292 @@
- ---
- title: Stable Video Diffusion
- emoji: 📺
- colorFrom: purple
- colorTo: purple
- sdk: gradio
- sdk_version: 4.4.0
- app_file: app.py
- pinned: false
- license: other
- ---
# Generative Models by Stability AI

![sample1](assets/000.jpg)

## News

**November 21, 2023**

- We are releasing Stable Video Diffusion, an image-to-video model, for research purposes:
  - [SVD](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid): This model was trained to generate 14
    frames at resolution 576x1024 given a context frame of the same size.
    We use the standard image encoder from SD 2.1, but replace the decoder with a temporally-aware `deflickering decoder`.
  - [SVD-XT](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt): Same architecture as `SVD` but finetuned
    for 25-frame generation.
  - We provide a streamlit demo `scripts/demo/video_sampling.py` and a standalone python script `scripts/sampling/simple_video_sample.py` for inference of both models.
  - Alongside the model, we release a [technical report](https://stability.ai/research/stable-video-diffusion-scaling-latent-video-diffusion-models-to-large-datasets).

![tile](assets/tile.gif)

**July 26, 2023**

- We are releasing two new open models with a
  permissive [`CreativeML Open RAIL++-M` license](model_licenses/LICENSE-SDXL1.0) (see [Inference](#inference) for file
  hashes):
  - [SDXL-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0): An improved version
    over `SDXL-base-0.9`.
  - [SDXL-refiner-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0): An improved version
    over `SDXL-refiner-0.9`.

![sample2](assets/001_with_eval.png)

**July 4, 2023**

- A technical report on SDXL is now available [here](https://arxiv.org/abs/2307.01952).

**June 22, 2023**

- We are releasing two new diffusion models for research purposes:
  - `SDXL-base-0.9`: The base model was trained on a variety of aspect ratios on images with resolution 1024^2. The
    base model uses [OpenCLIP-ViT/G](https://github.com/mlfoundations/open_clip)
    and [CLIP-ViT/L](https://github.com/openai/CLIP/tree/main) for text encoding, whereas the refiner model only uses
    the OpenCLIP model.
  - `SDXL-refiner-0.9`: The refiner has been trained to denoise small noise levels of high-quality data and as such is
    not expected to work as a text-to-image model; instead, it should only be used as an image-to-image model.

If you would like to access these models for your research, please apply using one of the following links:
[SDXL-0.9-Base model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9),
and [SDXL-0.9-Refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
This means that you can apply for either of the two links - and if you are granted access - you can use both.
Please log in to your Hugging Face Account with your organization email to request access.
**We plan to do a full release soon (July).**

## The codebase

### General Philosophy

Modularity is king. This repo implements a config-driven approach where we build and combine submodules by
calling `instantiate_from_config()` on objects defined in yaml configs. See `configs/` for many examples.
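For illustration, a minimal sketch of what such config-driven instantiation typically looks like (the real helper lives in `sgm/util.py`; the `target`/`params` key names follow the convention used throughout the configs):

```python
import importlib

def instantiate_from_config(config: dict):
    # "target" is the full import path of a class,
    # e.g. "sgm.models.diffusion.DiffusionEngine".
    module_name, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    # "params" is forwarded as keyword arguments to the constructor.
    return cls(**config.get("params", {}))
```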

### Changelog from the old `ldm` codebase

For training, we use [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/), but it should be easy to use other
training wrappers around the base modules. The core diffusion model class (formerly `LatentDiffusion`,
now `DiffusionEngine`) has been cleaned up:

- No more extensive subclassing! We now handle all types of conditioning inputs (vectors, sequences and spatial
  conditionings, and all combinations thereof) in a single class: `GeneralConditioner`,
  see `sgm/modules/encoders/modules.py`.
- We separate guiders (such as classifier-free guidance, see `sgm/modules/diffusionmodules/guiders.py`) from the
  samplers (`sgm/modules/diffusionmodules/sampling.py`), and the samplers are independent of the model.
- We adopt the ["denoiser framework"](https://arxiv.org/abs/2206.00364) for both training and inference (the most
  notable change is probably the option to train continuous-time models); see the sketch after this list:
  * Discrete-time models (denoisers) are simply a special case of continuous-time models (denoisers);
    see `sgm/modules/diffusionmodules/denoiser.py`.
  * The following features are now independent: weighting of the diffusion loss
    function (`sgm/modules/diffusionmodules/denoiser_weighting.py`), preconditioning of the
    network (`sgm/modules/diffusionmodules/denoiser_scaling.py`), and sampling of noise levels during
    training (`sgm/modules/diffusionmodules/sigma_sampling.py`).
- Autoencoding models have also been cleaned up.
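To make the denoiser framework concrete, here is a small, hypothetical sketch of EDM-style preconditioning following Karras et al. (2022); the `sigma_data = 0.5` constant and the exact scalings are illustrative assumptions, the repository's own versions live in `denoiser_scaling.py`:

```python
import torch

def edm_denoiser(network, x: torch.Tensor, sigma: torch.Tensor, cond) -> torch.Tensor:
    # Preconditioned denoiser D(x; sigma) = c_skip * x + c_out * F(c_in * x; c_noise),
    # defined for any continuous sigma; discrete-time models simply evaluate it
    # on a fixed grid of sigmas.
    sigma_data = 0.5  # assumed data standard deviation, as in the EDM paper
    c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
    c_out = sigma * sigma_data / (sigma**2 + sigma_data**2).sqrt()
    c_in = 1 / (sigma**2 + sigma_data**2).sqrt()
    c_noise = sigma.log() / 4
    # "network" is any conditional noise-prediction backbone taking
    # (scaled input, noise embedding, conditioning).
    return c_skip * x + c_out * network(c_in * x, c_noise, cond)
```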

## Installation:

<a name="installation"></a>

#### 1. Clone the repo

```shell
git clone git@github.com:Stability-AI/generative-models.git
cd generative-models
```

#### 2. Setting up the virtualenv

This assumes you have navigated to the `generative-models` root after cloning it.

**NOTE:** This is tested under `python3.10`. For other python versions, you might encounter version conflicts.

**PyTorch 2.0**

```shell
# install required packages from pypi
python3 -m venv .pt2
source .pt2/bin/activate
pip3 install -r requirements/pt2.txt
```

#### 3. Install `sgm`

```shell
pip3 install .
```

#### 4. Install `sdata` for training

```shell
pip3 install -e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
```

## Packaging

This repository uses PEP 517 compliant packaging with [Hatch](https://hatch.pypa.io/latest/).

To build a distributable wheel, install `hatch` and run `hatch build`
(specifying `-t wheel` will skip building an sdist, which is not necessary).

```
pip install hatch
hatch build -t wheel
```

You will find the built package in `dist/`. You can install the wheel with `pip install dist/*.whl`.

Note that the package does **not** currently specify dependencies; you will need to install the required packages
manually, depending on your use case and PyTorch version.

## Inference

We provide a [streamlit](https://streamlit.io/) demo for text-to-image and image-to-image sampling
in `scripts/demo/sampling.py`.
We provide file hashes for the complete file as well as for only the saved tensors in the file
(see [Model Spec](https://github.com/Stability-AI/ModelSpec) for a script that evaluates them).
The following models are currently supported (a short hash-checking snippet follows the list):

- [SDXL-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
  ```
  File Hash (sha256): 31e35c80fc4829d14f90153f4c74cd59c90b779f6afe05a74cd6120b893f7e5b
  Tensordata Hash (sha256): 0xd7a9105a900fd52748f20725fe52fe52b507fd36bee4fc107b1550a26e6ee1d7
  ```
- [SDXL-refiner-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0)
  ```
  File Hash (sha256): 7440042bbdc8a24813002c09b6b69b64dc90fded4472613437b7f55f9b7d9c5f
  Tensordata Hash (sha256): 0x1a77d21bebc4b4de78c474a90cb74dc0d2217caf4061971dbfa75ad406b75d81
  ```
- [SDXL-base-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9)
- [SDXL-refiner-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9)
- [SD-2.1-512](https://huggingface.co/stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned.safetensors)
- [SD-2.1-768](https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors)
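To verify the full-file hash locally (the tensor-data hash needs the Model Spec tooling linked above), a short example; the checkpoint path below is a placeholder:

```python
import hashlib

def file_sha256(path: str, chunk_size: int = 1 << 20) -> str:
    # Stream the file in 1 MiB chunks so multi-GB checkpoints
    # don't need to fit into RAM.
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

# file_sha256("checkpoints/sd_xl_base_1.0.safetensors") should match the
# SDXL-base-1.0 file hash listed above.
```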

**Weights for SDXL**:

**SDXL-1.0:**
The weights of SDXL-1.0 are available (subject to
a [`CreativeML Open RAIL++-M` license](model_licenses/LICENSE-SDXL1.0)) here:

- base model: https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/
- refiner model: https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/

**SDXL-0.9:**
The weights of SDXL-0.9 are available and subject to a [research license](model_licenses/LICENSE-SDXL0.9).
If you would like to access these models for your research, please apply using one of the following links:
[SDXL-base-0.9 model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9),
and [SDXL-refiner-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
This means that you can apply for either of the two links - and if you are granted access - you can use both.
Please log in to your Hugging Face Account with your organization email to request access.

After obtaining the weights, place them into `checkpoints/`.
Next, start the demo using

```
streamlit run scripts/demo/sampling.py --server.port <your_port>
```

### Invisible Watermark Detection

Images generated with our code use the
[invisible-watermark](https://github.com/ShieldMnt/invisible-watermark/)
library to embed an invisible watermark into the model output. We also provide
a script to easily detect that watermark. Please note that this watermark is
not the same as in previous Stable Diffusion 1.x/2.x versions.

To run the script you need either a working installation as above or an
_experimental_ import using only a minimal amount of packages:

```bash
python -m venv .detect
source .detect/bin/activate

pip install "numpy>=1.17" "PyWavelets>=1.1.1" "opencv-python>=4.1.0.25"
pip install --no-deps invisible-watermark
```

The script is then usable in the following ways (don't forget to activate your
virtual environment beforehand, e.g. `source .pt2/bin/activate`):

```bash
# test a single file
python scripts/demo/detect.py <your filename here>
# test multiple files at once
python scripts/demo/detect.py <filename 1> <filename 2> ... <filename n>
# test all files in a specific folder
python scripts/demo/detect.py <your folder name here>/*
```

## Training:

We provide example training configs in `configs/example_training`. To launch a training, run

```
python main.py --base configs/<config1.yaml> configs/<config2.yaml>
```

where configs are merged from left to right (later configs overwrite the same values).
This can be used to combine model, training and data configs. However, all of them can also be
defined in a single config. For example, to run a class-conditional pixel-based diffusion model training on MNIST,
run

```bash
python main.py --base configs/example_training/toy/mnist_cond.yaml
```

**NOTE 1:** Using the non-toy-dataset
configs `configs/example_training/imagenet-f8_cond.yaml`, `configs/example_training/txt2img-clipl.yaml`
and `configs/example_training/txt2img-clipl-legacy-ucg-training.yaml` for training will require edits depending on the
dataset used (which is expected to be stored as tar files in
the [webdataset format](https://github.com/webdataset/webdataset)). To find the parts which have to be adapted, search
for comments containing `USER:` in the respective config.

**NOTE 2:** This repository supports both `pytorch1.13` and `pytorch2` for training generative models. However, for
autoencoder training, as e.g. in `configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml`,
only `pytorch1.13` is supported.

**NOTE 3:** Training latent generative models (as e.g. in `configs/example_training/imagenet-f8_cond.yaml`) requires
retrieving the checkpoint from [Hugging Face](https://huggingface.co/stabilityai/sdxl-vae/tree/main) and replacing
the `CKPT_PATH` placeholder in [this line](configs/example_training/imagenet-f8_cond.yaml#81). The same is to be done
for the provided text-to-image configs.

### Building New Diffusion Models

#### Conditioner

The `GeneralConditioner` is configured through the `conditioner_config`. Its only attribute is `emb_models`, a list of
different embedders (all inheriting from `AbstractEmbModel`) that are used to condition the generative model.
All embedders should define whether or not they are trainable (`is_trainable`, default `False`), a classifier-free
guidance dropout rate (`ucg_rate`, default `0`), and an input key (`input_key`), for example `txt` for
text-conditioning or `cls` for class-conditioning.
When computing conditionings, the embedder will get `batch[input_key]` as input.
We currently support two- to four-dimensional conditionings, and conditionings of different embedders are concatenated
appropriately.
Note that the order of the embedders in the `conditioner_config` is important.
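As an illustration only, a hypothetical `conditioner_config` with one frozen text embedder and one trainable class embedder might look as follows when written as a Python dict (the embedder class names and their parameters are assumptions; see `sgm/modules/encoders/modules.py` and the shipped configs for the real ones):

```python
conditioner_config = {
    "target": "sgm.modules.GeneralConditioner",
    "params": {
        "emb_models": [
            {
                # frozen text embedder, reads batch["txt"];
                # dropped 10% of the time for classifier-free guidance
                "is_trainable": False,
                "ucg_rate": 0.1,
                "input_key": "txt",
                "target": "sgm.modules.encoders.modules.FrozenCLIPEmbedder",
            },
            {
                # trainable class embedder, reads batch["cls"]
                "is_trainable": True,
                "ucg_rate": 0.2,
                "input_key": "cls",
                "target": "sgm.modules.encoders.modules.ClassEmbedder",
                "params": {"embed_dim": 512, "n_classes": 1000},
            },
        ]
    },
}
```

The same structure is normally written in yaml inside the training configs; the order of `emb_models` determines the concatenation order of the resulting conditionings.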

#### Network

The neural network is set through the `network_config`. This used to be called `unet_config`, which is not general
enough as we plan to experiment with transformer-based diffusion backbones.

#### Loss

The loss is configured through `loss_config`. For standard diffusion model training, you will have to
set `sigma_sampler_config`.
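A hedged sketch of what this can look like; the class paths and parameter values are assumptions based on the module layout described above:

```python
loss_config = {
    "target": "sgm.modules.diffusionmodules.loss.StandardDiffusionLoss",
    "params": {
        # log-normal sampling of noise levels during training,
        # in the spirit of the EDM paper
        "sigma_sampler_config": {
            "target": "sgm.modules.diffusionmodules.sigma_sampling.EDMSampling",
            "params": {"p_mean": -1.2, "p_std": 1.2},
        }
    },
}
```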

#### Sampler config

As discussed above, the sampler is independent of the model. In the `sampler_config`, we set the type of numerical
solver, number of steps, type of discretization, as well as, for example, guidance wrappers for classifier-free
guidance.
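For example, a hypothetical Euler-based sampler with classifier-free guidance could be configured like this (the class paths mirror the module layout above but are illustrative assumptions):

```python
sampler_config = {
    "target": "sgm.modules.diffusionmodules.sampling.EulerEDMSampler",
    "params": {
        "num_steps": 40,  # number of solver steps
        "discretization_config": {
            "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization"
        },
        # classifier-free guidance wrapper with guidance scale 7.5
        "guider_config": {
            "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
            "params": {"scale": 7.5},
        },
    },
}
```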

### Dataset Handling

For large-scale training we recommend using the data pipelines from
our [data pipelines](https://github.com/Stability-AI/datapipelines) project. The project is included in the
requirements and installed automatically when following the steps from the [Installation section](#installation).
Small map-style datasets should be defined here in the repository (e.g., MNIST, CIFAR-10, ...), and return a dict of
data keys/values, e.g.,

```python
example = {"jpg": x,  # this is a tensor -1...1 chw
           "txt": "a beautiful image"}
```

where we expect images in -1...1, channel-first format.
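A minimal, hypothetical map-style dataset implementing this interface:

```python
from torch.utils.data import Dataset

class ToyImageDataset(Dataset):
    # Illustrative only: any map-style dataset just needs to return the
    # dict interface described above.
    def __init__(self, images, captions):
        self.images = images      # list of chw tensors scaled to [-1, 1]
        self.captions = captions  # list of strings

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return {"jpg": self.images[idx], "txt": self.captions[idx]}
```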
assets/000.jpg ADDED
assets/001_with_eval.png ADDED

Git LFS Details

  • SHA256: 026fa14e30098729064a00fb7fcec41bb57dcddb33b36b548d553f601bc53634
  • Pointer size: 132 Bytes
  • Size of remote file: 4.19 MB
assets/test_image.png ADDED
assets/tile.gif ADDED

Git LFS Details

  • SHA256: 2340a9809e36fa9634633c7cc5fd256737c620ba47151726c85173512dc5c8ff
  • Pointer size: 133 Bytes
  • Size of remote file: 18.6 MB
data/DejaVuSans.ttf ADDED
Binary file (757 kB).
 
main.py ADDED
@@ -0,0 +1,943 @@
import argparse
import datetime
import glob
import inspect
import os
import sys
from inspect import Parameter
from typing import Union

import numpy as np
import pytorch_lightning as pl
import torch
import torchvision
import wandb
from matplotlib import pyplot as plt
from natsort import natsorted
from omegaconf import OmegaConf
from packaging import version
from PIL import Image
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.utilities import rank_zero_only

from sgm.util import exists, instantiate_from_config, isheatmap

MULTINODE_HACKS = True


def default_trainer_args():
    argspec = dict(inspect.signature(Trainer.__init__).parameters)
    argspec.pop("self")
    default_args = {
        param: argspec[param].default
        for param in argspec
        if argspec[param].default != Parameter.empty
    }
    return default_args


def get_parser(**parser_kwargs):
    def str2bool(v):
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        elif v.lower() in ("no", "false", "f", "n", "0"):
            return False
        else:
            raise argparse.ArgumentTypeError("Boolean value expected.")

    parser = argparse.ArgumentParser(**parser_kwargs)
    parser.add_argument(
        "-n",
        "--name",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="postfix for logdir",
    )
    parser.add_argument(
        "--no_date",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="if True, skip date generation for logdir and only use naming via opt.base or opt.name (+ opt.postfix, optionally)",
    )
    parser.add_argument(
        "-r",
        "--resume",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="resume from logdir or checkpoint in logdir",
    )
    parser.add_argument(
        "-b",
        "--base",
        nargs="*",
        metavar="base_config.yaml",
        help="paths to base configs. Loaded from left-to-right. "
        "Parameters can be overwritten or added with command-line options of the form `--key value`.",
        default=list(),
    )
    parser.add_argument(
        "-t",
        "--train",
        type=str2bool,
        const=True,
        default=True,
        nargs="?",
        help="train",
    )
    parser.add_argument(
        "--no-test",
        type=str2bool,
        const=True,
        default=False,
        nargs="?",
        help="disable test",
    )
    parser.add_argument(
        "-p", "--project", help="name of new or path to existing project"
    )
    parser.add_argument(
        "-d",
        "--debug",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="enable post-mortem debugging",
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=23,
        help="seed for seed_everything",
    )
    parser.add_argument(
        "-f",
        "--postfix",
        type=str,
        default="",
        help="post-postfix for default name",
    )
    parser.add_argument(
        "--projectname",
        type=str,
        default="stablediffusion",
    )
    parser.add_argument(
        "-l",
        "--logdir",
        type=str,
        default="logs",
        help="directory for logging data",
    )
    parser.add_argument(
        "--scale_lr",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="scale base-lr by ngpu * batch_size * n_accumulate",
    )
    parser.add_argument(
        "--legacy_naming",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="name run based on config file name if true, else by whole path",
    )
    parser.add_argument(
        "--enable_tf32",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="enables the TensorFloat32 format both for matmuls and cuDNN for pytorch 1.12",
    )
    parser.add_argument(
        "--startup",
        type=str,
        default=None,
        help="Startuptime from distributed script",
    )
    parser.add_argument(
        "--wandb",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,  # TODO: later default to True
        help="log to wandb",
    )
    parser.add_argument(
        "--no_base_name",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,  # TODO: later default to True
        help="do not include the base config name in the run name",
    )
    if version.parse(torch.__version__) >= version.parse("2.0.0"):
        parser.add_argument(
            "--resume_from_checkpoint",
            type=str,
            default=None,
            help="single checkpoint file to resume from",
        )
    default_args = default_trainer_args()
    for key in default_args:
        parser.add_argument("--" + key, default=default_args[key])
    return parser


def get_checkpoint_name(logdir):
    ckpt = os.path.join(logdir, "checkpoints", "last**.ckpt")
    ckpt = natsorted(glob.glob(ckpt))
    print('available "last" checkpoints:')
    print(ckpt)
    if len(ckpt) > 1:
        print("got most recent checkpoint")
        ckpt = sorted(ckpt, key=lambda x: os.path.getmtime(x))[-1]
        print(f"Most recent ckpt is {ckpt}")
        with open(os.path.join(logdir, "most_recent_ckpt.txt"), "w") as f:
            f.write(ckpt + "\n")
        try:
            version = int(ckpt.split("/")[-1].split("-v")[-1].split(".")[0])
        except Exception as e:
            print("version confusion but not bad")
            print(e)
            version = 1
        # version = last_version + 1
    else:
        # in this case, we only have one "last.ckpt"
        ckpt = ckpt[0]
        version = 1
    melk_ckpt_name = f"last-v{version}.ckpt"
    print(f"Current melk ckpt name: {melk_ckpt_name}")
    return ckpt, melk_ckpt_name


class SetupCallback(Callback):
    def __init__(
        self,
        resume,
        now,
        logdir,
        ckptdir,
        cfgdir,
        config,
        lightning_config,
        debug,
        ckpt_name=None,
    ):
        super().__init__()
        self.resume = resume
        self.now = now
        self.logdir = logdir
        self.ckptdir = ckptdir
        self.cfgdir = cfgdir
        self.config = config
        self.lightning_config = lightning_config
        self.debug = debug
        self.ckpt_name = ckpt_name

    def on_exception(self, trainer: pl.Trainer, pl_module, exception):
        if not self.debug and trainer.global_rank == 0:
            print("Summoning checkpoint.")
            if self.ckpt_name is None:
                ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
            else:
                ckpt_path = os.path.join(self.ckptdir, self.ckpt_name)
            trainer.save_checkpoint(ckpt_path)

    def on_fit_start(self, trainer, pl_module):
        if trainer.global_rank == 0:
            # Create logdirs and save configs
            os.makedirs(self.logdir, exist_ok=True)
            os.makedirs(self.ckptdir, exist_ok=True)
            os.makedirs(self.cfgdir, exist_ok=True)

            if "callbacks" in self.lightning_config:
                if (
                    "metrics_over_trainsteps_checkpoint"
                    in self.lightning_config["callbacks"]
                ):
                    os.makedirs(
                        os.path.join(self.ckptdir, "trainstep_checkpoints"),
                        exist_ok=True,
                    )
            print("Project config")
            print(OmegaConf.to_yaml(self.config))
            if MULTINODE_HACKS:
                import time

                time.sleep(5)
            OmegaConf.save(
                self.config,
                os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)),
            )

            print("Lightning config")
            print(OmegaConf.to_yaml(self.lightning_config))
            OmegaConf.save(
                OmegaConf.create({"lightning": self.lightning_config}),
                os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)),
            )

        else:
            # ModelCheckpoint callback created log directory --- remove it
            if not MULTINODE_HACKS and not self.resume and os.path.exists(self.logdir):
                dst, name = os.path.split(self.logdir)
                dst = os.path.join(dst, "child_runs", name)
                os.makedirs(os.path.split(dst)[0], exist_ok=True)
                try:
                    os.rename(self.logdir, dst)
                except FileNotFoundError:
                    pass


class ImageLogger(Callback):
    def __init__(
        self,
        batch_frequency,
        max_images,
        clamp=True,
        increase_log_steps=True,
        rescale=True,
        disabled=False,
        log_on_batch_idx=False,
        log_first_step=False,
        log_images_kwargs=None,
        log_before_first_step=False,
        enable_autocast=True,
    ):
        super().__init__()
        self.enable_autocast = enable_autocast
        self.rescale = rescale
        self.batch_freq = batch_frequency
        self.max_images = max_images
        self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
        if not increase_log_steps:
            self.log_steps = [self.batch_freq]
        self.clamp = clamp
        self.disabled = disabled
        self.log_on_batch_idx = log_on_batch_idx
        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
        self.log_first_step = log_first_step
        self.log_before_first_step = log_before_first_step

    @rank_zero_only
    def log_local(
        self,
        save_dir,
        split,
        images,
        global_step,
        current_epoch,
        batch_idx,
        pl_module: Union[None, pl.LightningModule] = None,
    ):
        root = os.path.join(save_dir, "images", split)
        for k in images:
            if isheatmap(images[k]):
                fig, ax = plt.subplots()
                ax = ax.matshow(
                    images[k].cpu().numpy(), cmap="hot", interpolation="lanczos"
                )
                plt.colorbar(ax)
                plt.axis("off")

                filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
                    k, global_step, current_epoch, batch_idx
                )
                os.makedirs(root, exist_ok=True)
                path = os.path.join(root, filename)
                plt.savefig(path)
                plt.close()
                # TODO: support wandb
            else:
                grid = torchvision.utils.make_grid(images[k], nrow=4)
                if self.rescale:
                    grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
                grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
                grid = grid.numpy()
                grid = (grid * 255).astype(np.uint8)
                filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
                    k, global_step, current_epoch, batch_idx
                )
                path = os.path.join(root, filename)
                os.makedirs(os.path.split(path)[0], exist_ok=True)
                img = Image.fromarray(grid)
                img.save(path)
                if exists(pl_module):
                    assert isinstance(
                        pl_module.logger, WandbLogger
                    ), "logger_log_image only supports WandbLogger currently"
                    pl_module.logger.log_image(
                        key=f"{split}/{k}",
                        images=[
                            img,
                        ],
                        step=pl_module.global_step,
                    )

    @rank_zero_only
    def log_img(self, pl_module, batch, batch_idx, split="train"):
        check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
        if (
            self.check_frequency(check_idx)
            and hasattr(pl_module, "log_images")  # batch_idx % self.batch_freq == 0
            and callable(pl_module.log_images)
            and
            # batch_idx > 5 and
            self.max_images > 0
        ):
            logger = type(pl_module.logger)
            is_train = pl_module.training
            if is_train:
                pl_module.eval()

            gpu_autocast_kwargs = {
                "enabled": self.enable_autocast,  # torch.is_autocast_enabled(),
                "dtype": torch.get_autocast_gpu_dtype(),
                "cache_enabled": torch.is_autocast_cache_enabled(),
            }
            with torch.no_grad(), torch.cuda.amp.autocast(**gpu_autocast_kwargs):
                images = pl_module.log_images(
                    batch, split=split, **self.log_images_kwargs
                )

            for k in images:
                N = min(images[k].shape[0], self.max_images)
                if not isheatmap(images[k]):
                    images[k] = images[k][:N]
                if isinstance(images[k], torch.Tensor):
                    images[k] = images[k].detach().float().cpu()
                    if self.clamp and not isheatmap(images[k]):
                        images[k] = torch.clamp(images[k], -1.0, 1.0)

            self.log_local(
                pl_module.logger.save_dir,
                split,
                images,
                pl_module.global_step,
                pl_module.current_epoch,
                batch_idx,
                pl_module=pl_module
                if isinstance(pl_module.logger, WandbLogger)
                else None,
            )

            if is_train:
                pl_module.train()

    def check_frequency(self, check_idx):
        if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
            check_idx > 0 or self.log_first_step
        ):
            try:
                self.log_steps.pop(0)
            except IndexError as e:
                print(e)
                pass
            return True
        return False

    @rank_zero_only
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
            self.log_img(pl_module, batch, batch_idx, split="train")

    @rank_zero_only
    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        if self.log_before_first_step and pl_module.global_step == 0:
            print(f"{self.__class__.__name__}: logging before training")
            self.log_img(pl_module, batch, batch_idx, split="train")

    @rank_zero_only
    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, *args, **kwargs
    ):
        if not self.disabled and pl_module.global_step > 0:
            self.log_img(pl_module, batch, batch_idx, split="val")
        if hasattr(pl_module, "calibrate_grad_norm"):
            if (
                pl_module.calibrate_grad_norm and batch_idx % 25 == 0
            ) and batch_idx > 0:
                self.log_gradients(trainer, pl_module, batch_idx=batch_idx)


@rank_zero_only
def init_wandb(save_dir, opt, config, group_name, name_str):
    print(f"setting WANDB_DIR to {save_dir}")
    os.makedirs(save_dir, exist_ok=True)

    os.environ["WANDB_DIR"] = save_dir
    if opt.debug:
        wandb.init(project=opt.projectname, mode="offline", group=group_name)
    else:
        wandb.init(
            project=opt.projectname,
            config=config,
            settings=wandb.Settings(code_dir="./sgm"),
            group=group_name,
            name=name_str,
        )


if __name__ == "__main__":
    # custom parser to specify config files, train, test and debug mode,
    # postfix, resume.
    # `--key value` arguments are interpreted as arguments to the trainer.
    # `nested.key=value` arguments are interpreted as config parameters.
    # configs are merged from left-to-right followed by command line parameters.

    # model:
    #   base_learning_rate: float
    #   target: path to lightning module
    #   params:
    #       key: value
    # data:
    #   target: main.DataModuleFromConfig
    #   params:
    #      batch_size: int
    #      wrap: bool
    #      train:
    #          target: path to train dataset
    #          params:
    #              key: value
    #      validation:
    #          target: path to validation dataset
    #          params:
    #              key: value
    #      test:
    #          target: path to test dataset
    #          params:
    #              key: value
    # lightning: (optional, has sane defaults and can be specified on cmdline)
    #   trainer:
    #       additional arguments to trainer
    #   logger:
    #       logger to instantiate
    #   modelcheckpoint:
    #       modelcheckpoint to instantiate
    #   callbacks:
    #       callback1:
    #           target: importpath
    #           params:
    #               key: value

    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")

    # add cwd for convenience and to make classes in this file available when
    # running as `python main.py`
    # (in particular `main.DataModuleFromConfig`)
    sys.path.append(os.getcwd())

    parser = get_parser()

    opt, unknown = parser.parse_known_args()

    if opt.name and opt.resume:
        raise ValueError(
            "-n/--name and -r/--resume cannot both be specified. "
            "If you want to resume training in a new log folder, "
            "use -n/--name in combination with --resume_from_checkpoint"
        )
    melk_ckpt_name = None
    name = None
    if opt.resume:
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            paths = opt.resume.split("/")
            # idx = len(paths)-paths[::-1].index("logs")+1
            # logdir = "/".join(paths[:idx])
            logdir = "/".join(paths[:-2])
            ckpt = opt.resume
            _, melk_ckpt_name = get_checkpoint_name(logdir)
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt, melk_ckpt_name = get_checkpoint_name(logdir)

        print("#" * 100)
        print(f'Resuming from checkpoint "{ckpt}"')
        print("#" * 100)

        opt.resume_from_checkpoint = ckpt
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
        opt.base = base_configs + opt.base
        _tmp = logdir.split("/")
        nowname = _tmp[-1]
    else:
        if opt.name:
            name = "_" + opt.name
        elif opt.base:
            if opt.no_base_name:
                name = ""
            else:
                if opt.legacy_naming:
                    cfg_fname = os.path.split(opt.base[0])[-1]
                    cfg_name = os.path.splitext(cfg_fname)[0]
                else:
                    assert "configs" in os.path.split(opt.base[0])[0], os.path.split(
                        opt.base[0]
                    )[0]
                    cfg_path = os.path.split(opt.base[0])[0].split(os.sep)[
                        os.path.split(opt.base[0])[0].split(os.sep).index("configs")
                        + 1 :
                    ]  # cut away the first one (we assert all configs are in "configs")
                    cfg_name = os.path.splitext(os.path.split(opt.base[0])[-1])[0]
                    cfg_name = "-".join(cfg_path) + f"-{cfg_name}"
                name = "_" + cfg_name
        else:
            name = ""
        if not opt.no_date:
            nowname = now + name + opt.postfix
        else:
            nowname = name + opt.postfix
            if nowname.startswith("_"):
                nowname = nowname[1:]
        logdir = os.path.join(opt.logdir, nowname)
        print(f"LOGDIR: {logdir}")

    ckptdir = os.path.join(logdir, "checkpoints")
    cfgdir = os.path.join(logdir, "configs")
    seed_everything(opt.seed, workers=True)

    # move before model init, in case a torch.compile(...) is called somewhere
    if opt.enable_tf32:
        # pt_version = version.parse(torch.__version__)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        print(f"Enabling TF32 for PyTorch {torch.__version__}")
    else:
        print(f"Using default TF32 settings for PyTorch {torch.__version__}:")
        print(
            f"torch.backends.cuda.matmul.allow_tf32={torch.backends.cuda.matmul.allow_tf32}"
        )
        print(f"torch.backends.cudnn.allow_tf32={torch.backends.cudnn.allow_tf32}")

    try:
        # init and save configs
        configs = [OmegaConf.load(cfg) for cfg in opt.base]
        cli = OmegaConf.from_dotlist(unknown)
        config = OmegaConf.merge(*configs, cli)
        lightning_config = config.pop("lightning", OmegaConf.create())
        # merge trainer cli with config
        trainer_config = lightning_config.get("trainer", OmegaConf.create())

        # default to gpu
        trainer_config["accelerator"] = "gpu"
        #
        standard_args = default_trainer_args()
        for k in standard_args:
            if getattr(opt, k) != standard_args[k]:
                trainer_config[k] = getattr(opt, k)

        ckpt_resume_path = opt.resume_from_checkpoint

        if "devices" not in trainer_config and trainer_config["accelerator"] != "gpu":
            del trainer_config["accelerator"]
            cpu = True
        else:
            gpuinfo = trainer_config["devices"]
            print(f"Running on GPUs {gpuinfo}")
            cpu = False
        trainer_opt = argparse.Namespace(**trainer_config)
        lightning_config.trainer = trainer_config

        # model
        model = instantiate_from_config(config.model)

        # trainer and callbacks
        trainer_kwargs = dict()

        # default logger configs
        default_logger_cfgs = {
            "wandb": {
                "target": "pytorch_lightning.loggers.WandbLogger",
                "params": {
                    "name": nowname,
                    # "save_dir": logdir,
                    "offline": opt.debug,
                    "id": nowname,
                    "project": opt.projectname,
                    "log_model": False,
                    # "dir": logdir,
                },
            },
            "csv": {
                "target": "pytorch_lightning.loggers.CSVLogger",
                "params": {
                    "name": "testtube",  # hack for sbord fanatics
                    "save_dir": logdir,
                },
            },
        }
        default_logger_cfg = default_logger_cfgs["wandb" if opt.wandb else "csv"]
        if opt.wandb:
            # TODO change once leaving "swiffer" config directory
            try:
                group_name = nowname.split(now)[-1].split("-")[1]
            except Exception:
                group_name = nowname
            default_logger_cfg["params"]["group"] = group_name
            init_wandb(
                os.path.join(os.getcwd(), logdir),
                opt=opt,
                group_name=group_name,
                config=config,
                name_str=nowname,
            )
        if "logger" in lightning_config:
            logger_cfg = lightning_config.logger
        else:
            logger_cfg = OmegaConf.create()
        logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
        trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)

        # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
        # specify which metric is used to determine best models
        default_modelckpt_cfg = {
            "target": "pytorch_lightning.callbacks.ModelCheckpoint",
            "params": {
                "dirpath": ckptdir,
                "filename": "{epoch:06}",
                "verbose": True,
                "save_last": True,
            },
        }
        if hasattr(model, "monitor"):
            print(f"Monitoring {model.monitor} as checkpoint metric.")
            default_modelckpt_cfg["params"]["monitor"] = model.monitor
            default_modelckpt_cfg["params"]["save_top_k"] = 3

        if "modelcheckpoint" in lightning_config:
            modelckpt_cfg = lightning_config.modelcheckpoint
        else:
            modelckpt_cfg = OmegaConf.create()
        modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
        print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")

        # https://pytorch-lightning.readthedocs.io/en/stable/extensions/strategy.html
        # default to ddp if not further specified
        default_strategy_config = {"target": "pytorch_lightning.strategies.DDPStrategy"}

        if "strategy" in lightning_config:
            strategy_cfg = lightning_config.strategy
        else:
            strategy_cfg = OmegaConf.create()
            default_strategy_config["params"] = {
                "find_unused_parameters": False,
                # "static_graph": True,
                # "ddp_comm_hook": default.fp16_compress_hook  # TODO: experiment with this, also for DDPSharded
            }
        strategy_cfg = OmegaConf.merge(default_strategy_config, strategy_cfg)
        print(
            f"strategy config: \n ++++++++++++++ \n {strategy_cfg} \n ++++++++++++++ "
        )
        trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg)

        # add callback which sets up log directory
        default_callbacks_cfg = {
            "setup_callback": {
                "target": "main.SetupCallback",
                "params": {
                    "resume": opt.resume,
                    "now": now,
                    "logdir": logdir,
                    "ckptdir": ckptdir,
                    "cfgdir": cfgdir,
                    "config": config,
                    "lightning_config": lightning_config,
                    "debug": opt.debug,
                    "ckpt_name": melk_ckpt_name,
                },
            },
            "image_logger": {
                "target": "main.ImageLogger",
                "params": {"batch_frequency": 1000, "max_images": 4, "clamp": True},
            },
            "learning_rate_logger": {
                "target": "pytorch_lightning.callbacks.LearningRateMonitor",
                "params": {
                    "logging_interval": "step",
                    # "log_momentum": True
                },
            },
        }
        if version.parse(pl.__version__) >= version.parse("1.4.0"):
            default_callbacks_cfg.update({"checkpoint_callback": modelckpt_cfg})

        if "callbacks" in lightning_config:
            callbacks_cfg = lightning_config.callbacks
        else:
            callbacks_cfg = OmegaConf.create()

        if "metrics_over_trainsteps_checkpoint" in callbacks_cfg:
            print(
                "Caution: Saving checkpoints every n train steps without deleting. This might require some free space."
            )
            default_metrics_over_trainsteps_ckpt_dict = {
                "metrics_over_trainsteps_checkpoint": {
                    "target": "pytorch_lightning.callbacks.ModelCheckpoint",
                    "params": {
                        "dirpath": os.path.join(ckptdir, "trainstep_checkpoints"),
                        "filename": "{epoch:06}-{step:09}",
                        "verbose": True,
                        "save_top_k": -1,
                        "every_n_train_steps": 10000,
                        "save_weights_only": True,
                    },
                }
            }
            default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)

        callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
        if "ignore_keys_callback" in callbacks_cfg and ckpt_resume_path is not None:
            callbacks_cfg.ignore_keys_callback.params["ckpt_path"] = ckpt_resume_path
        elif "ignore_keys_callback" in callbacks_cfg:
            del callbacks_cfg["ignore_keys_callback"]

        trainer_kwargs["callbacks"] = [
            instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
        ]
        if "plugins" not in trainer_kwargs:
            trainer_kwargs["plugins"] = list()

        # cmd line trainer args (which are in trainer_opt) have always priority over config-trainer-args (which are in trainer_kwargs)
        trainer_opt = vars(trainer_opt)
        trainer_kwargs = {
            key: val for key, val in trainer_kwargs.items() if key not in trainer_opt
        }
        trainer = Trainer(**trainer_opt, **trainer_kwargs)

        trainer.logdir = logdir  ###

        # data
        data = instantiate_from_config(config.data)
        # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
        # calling these ourselves should not be necessary but it is.
        # lightning still takes care of proper multiprocessing though
        data.prepare_data()
        # data.setup()
        print("#### Data #####")
        try:
            for k in data.datasets:
                print(
                    f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}"
                )
        except Exception:
            print("datasets not yet initialized.")

        # configure learning rate
        if "batch_size" in config.data.params:
            bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
        else:
            bs, base_lr = (
                config.data.params.train.loader.batch_size,
                config.model.base_learning_rate,
            )
        if not cpu:
            ngpu = len(lightning_config.trainer.devices.strip(",").split(","))
        else:
            ngpu = 1
        if "accumulate_grad_batches" in lightning_config.trainer:
            accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
        else:
            accumulate_grad_batches = 1
        print(f"accumulate_grad_batches = {accumulate_grad_batches}")
        lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
        if opt.scale_lr:
            model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
            print(
                "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
                    model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr
                )
            )
        else:
            model.learning_rate = base_lr
            print("++++ NOT USING LR SCALING ++++")
            print(f"Setting learning rate to {model.learning_rate:.2e}")

        # allow checkpointing via USR1
        def melk(*args, **kwargs):
            # run all checkpoint hooks
            if trainer.global_rank == 0:
                print("Summoning checkpoint.")
                if melk_ckpt_name is None:
                    ckpt_path = os.path.join(ckptdir, "last.ckpt")
                else:
                    ckpt_path = os.path.join(ckptdir, melk_ckpt_name)
                trainer.save_checkpoint(ckpt_path)

        def divein(*args, **kwargs):
            if trainer.global_rank == 0:
                import pudb

                pudb.set_trace()

        import signal

        signal.signal(signal.SIGUSR1, melk)
        signal.signal(signal.SIGUSR2, divein)

        # run
        if opt.train:
            try:
                trainer.fit(model, data, ckpt_path=ckpt_resume_path)
            except Exception:
                if not opt.debug:
                    melk()
                raise
        if not opt.no_test and not trainer.interrupted:
            trainer.test(model, data)
    except RuntimeError as err:
        if MULTINODE_HACKS:
            import datetime
            import os
            import socket

            import requests

            device = os.environ.get("CUDA_VISIBLE_DEVICES", "?")
            hostname = socket.gethostname()
            ts = datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")
            resp = requests.get("http://169.254.169.254/latest/meta-data/instance-id")
            print(
                f"ERROR at {ts} on {hostname}/{resp.text} (CUDA_VISIBLE_DEVICES={device}): {type(err).__name__}: {err}",
                flush=True,
            )
        raise err
    except Exception:
        if opt.debug and trainer.global_rank == 0:
            try:
                import pudb as debugger
            except ImportError:
                import pdb as debugger
            debugger.post_mortem()
        raise
    finally:
        # move newly created debug project to debug_runs
        if opt.debug and not opt.resume and trainer.global_rank == 0:
            dst, name = os.path.split(logdir)
            dst = os.path.join(dst, "debug_runs", name)
            os.makedirs(os.path.split(dst)[0], exist_ok=True)
            os.rename(logdir, dst)

        if opt.wandb:
            wandb.finish()
        # if trainer.global_rank == 0:
        #     print(trainer.profiler.summary())
model_licenses/LICENSE-SDV ADDED
@@ -0,0 +1,31 @@
STABLE VIDEO DIFFUSION NON-COMMERCIAL COMMUNITY LICENSE AGREEMENT
Dated: November 21, 2023

"AUP" means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time.

"Agreement" means the terms and conditions for use, reproduction, distribution and modification of the Software Products set forth herein.
"Derivative Work(s)" means (a) any derivative work of the Software Products as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model's output. For clarity, Derivative Works do not include the output of any Model.
"Documentation" means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software.

"Licensee" or "you" means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity's behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.

"Stability AI" or "we" means Stability AI Ltd.

"Software" means, collectively, Stability AI's proprietary StableCode made available under this Agreement.

"Software Products" means Software and Documentation.

By using or distributing any portion or element of the Software Products, you agree to be bound by this Agreement.

1. License Rights and Redistribution.
a. Subject to your compliance with this Agreement, the AUP (which is hereby incorporated herein by reference), and the Documentation, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI's intellectual property or other rights owned by Stability AI embodied in the Software Products to reproduce, distribute, and create Derivative Works of the Software Products for purposes other than commercial or production use.
b. If you distribute or make the Software Products, or any Derivative Works thereof, available to a third party, the Software Products, Derivative Works, or any portion thereof, respectively, will remain subject to this Agreement and you must (i) provide a copy of this Agreement to such third party, and (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "Stable Video Diffusion is licensed under the Stable Video Diffusion Research License, Copyright (c) Stability AI Ltd. All Rights Reserved." If you create a Derivative Work of a Software Product, you may add your own attribution notices to the Notice file included with the Software Product, provided that you clearly indicate which attributions apply to the Software Product and you must state in the NOTICE file that you changed the Software Product and how it was modified.
2. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SOFTWARE PRODUCTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS.
3. Limitation of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
4. Intellectual Property.
a. No trademark licenses are granted under this Agreement, and in connection with the Software Products, neither Stability AI nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Software Products.
b. Subject to Stability AI's ownership of the Software Products and Derivative Works made by or for Stability AI, with respect to any Derivative Works that are made by you, as between you and Stability AI, you are and will be the owner of such Derivative Works.
c. If you institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Software Products or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to your use or distribution of the Software Products in violation of this Agreement.
5. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Software Products and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Software Products. Sections 2-4 shall survive the termination of this Agreement.
model_licenses/LICENSE-SDXL0.9 ADDED
@@ -0,0 +1,75 @@