m10813108 committed
Commit d9f3603 · 1 Parent(s): 9c166e5

Upload 7 files

Files changed (7)
  1. .gitignore +14 -0
  2. CODEOWNERS +1 -0
  3. LICENSE-CODE +21 -0
  4. README.md +253 -3
  5. main.py +943 -0
  6. pyproject.toml +48 -0
  7. pytest.ini +3 -0
.gitignore ADDED
@@ -0,0 +1,14 @@
+ # extensions
+ *.egg-info
+ *.py[cod]
+
+ # envs
+ .pt13
+ .pt2
+
+ # directories
+ /checkpoints
+ /dist
+ /outputs
+ /build
+ /src
CODEOWNERS ADDED
@@ -0,0 +1 @@
+ .github @Stability-AI/infrastructure
LICENSE-CODE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 Stability AI
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,3 +1,253 @@
- ---
- license: unknown
- ---
+ # Generative Models by Stability AI
+
+ ![sample1](assets/000.jpg)
+
+ ## News
+
+ **July 26, 2023**
+ - We are releasing two new open models with a permissive [`CreativeML Open RAIL++-M` license](model_licenses/LICENSE-SDXL1.0) (see [Inference](#inference) for file hashes):
+   - [SDXL-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0): An improved version over `SDXL-base-0.9`.
+   - [SDXL-refiner-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0): An improved version over `SDXL-refiner-0.9`.
+
+ ![sample2](assets/001_with_eval.png)
+
+
+ **July 4, 2023**
+ - A technical report on SDXL is now available [here](https://arxiv.org/abs/2307.01952).
+
+ **June 22, 2023**
+
+
+ - We are releasing two new diffusion models for research purposes:
+   - `SDXL-base-0.9`: The base model was trained on a variety of aspect ratios on images with resolution 1024^2. The base model uses [OpenCLIP-ViT/G](https://github.com/mlfoundations/open_clip) and [CLIP-ViT/L](https://github.com/openai/CLIP/tree/main) for text encoding, whereas the refiner model only uses the OpenCLIP model.
+   - `SDXL-refiner-0.9`: The refiner has been trained to denoise small noise levels of high-quality data and as such is not expected to work as a text-to-image model; instead, it should only be used as an image-to-image model.
+
+ If you would like to access these models for your research, please apply using one of the following links:
+ [SDXL-0.9-Base model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9) and [SDXL-0.9-Refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
+ You can apply through either link; once your request is granted, you can access both models.
+ Please log in to your Hugging Face account with your organization email to request access.
+ **We plan to do a full release soon (July).**
+
+ ## The codebase
+
+ ### General Philosophy
+
+ Modularity is king. This repo implements a config-driven approach where we build and combine submodules by calling `instantiate_from_config()` on objects defined in YAML configs. See `configs/` for many examples.
+
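
As a rough illustration of the config-driven approach, the sketch below shows how a mapping with a `target` import path and optional `params` could be turned into an object. It illustrates the general pattern only and is not the repository's `instantiate_from_config()` implementation; `instantiate_from_config_sketch`, `get_obj_from_str`, and the `fractions.Fraction` target are stand-ins chosen so that the example runs with the standard library alone.

```python
import importlib


def get_obj_from_str(path):
    """Resolve a dotted path such as 'fractions.Fraction' to the object it names."""
    module_name, _, attr = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)


def instantiate_from_config_sketch(config):
    """Build the object named by config['target'], passing config['params'] as kwargs."""
    if "target" not in config:
        raise KeyError("expected a 'target' key naming the class to instantiate")
    cls = get_obj_from_str(config["target"])
    return cls(**config.get("params", {}))


# Hypothetical config; in the repository this would be read from a YAML file.
config = {
    "target": "fractions.Fraction",
    "params": {"numerator": 3, "denominator": 4},
}
obj = instantiate_from_config_sketch(config)
print(obj)
```

Because the class is looked up by name at runtime, swapping one submodule for another only requires editing the config, not the calling code.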
+ ### Changelog from the old `ldm` codebase
+
+ For training, we use [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/), but it should be easy to use other training wrappers around the base modules. The core diffusion model class (formerly `LatentDiffusion`, now `DiffusionEngine`) has been cleaned up:
+
+ - No more extensive subclassing! We now handle all types of conditioning inputs (vectors, sequences and spatial conditionings, and all combinations thereof) in a single class: `GeneralConditioner`, see `sgm/modules/encoders/modules.py`.
+ - We separate guiders (such as classifier-free guidance, see `sgm/modules/diffusionmodules/guiders.py`) from the
+   samplers (`sgm/modules/diffusionmodules/sampling.py`), and the samplers are independent of the model.
+ - We adopt the ["denoiser framework"](https://arxiv.org/abs/2206.00364) for both training and inference (the most notable change is probably the option to train continuous-time models):
+     * Discrete-time models (denoisers) are simply a special case of continuous-time models (denoisers); see `sgm/modules/diffusionmodules/denoiser.py`.
+     * The following features are now independent: weighting of the diffusion loss function (`sgm/modules/diffusionmodules/denoiser_weighting.py`), preconditioning of the network (`sgm/modules/diffusionmodules/denoiser_scaling.py`), and sampling of noise levels during training (`sgm/modules/diffusionmodules/sigma_sampling.py`).
+ - Autoencoding models have also been cleaned up.
+
+ ## Installation:
+ <a name="installation"></a>
+
+ #### 1. Clone the repo
+
+ ```shell
+ git clone git@github.com:Stability-AI/generative-models.git
+ cd generative-models
+ ```
+
+ #### 2. Setting up the virtualenv
+
+ This assumes you have navigated to the `generative-models` root after cloning it.
+
+ **NOTE:** This is tested under `python3.8` and `python3.10`. For other Python versions, you might encounter version conflicts.
+
+
+ **PyTorch 1.13**
+
+ ```shell
+ # install required packages from PyPI
+ python3 -m venv .pt13
+ source .pt13/bin/activate
+ pip3 install -r requirements/pt13.txt
+ ```
+
+ **PyTorch 2.0**
+
+
+ ```shell
+ # install required packages from PyPI
+ python3 -m venv .pt2
+ source .pt2/bin/activate
+ pip3 install -r requirements/pt2.txt
+ ```
+
+
+ #### 3. Install `sgm`
+
+ ```shell
+ pip3 install .
+ ```
+
+ #### 4. Install `sdata` for training
+
+ ```shell
+ pip3 install -e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
+ ```
+
+ ## Packaging
+
+ This repository uses PEP 517 compliant packaging using [Hatch](https://hatch.pypa.io/latest/).
+
+ To build a distributable wheel, install `hatch` and run `hatch build`
+ (specifying `-t wheel` will skip building an sdist, which is not necessary).
+
+ ```
+ pip install hatch
+ hatch build -t wheel
+ ```
+
+ You will find the built package in `dist/`. You can install the wheel with `pip install dist/*.whl`.
+
+ Note that the package does **not** currently specify dependencies; you will need to manually install the required packages,
+ depending on your use case and PyTorch version.
+
+ ## Inference
+
+ We provide a [streamlit](https://streamlit.io/) demo for text-to-image and image-to-image sampling in `scripts/demo/sampling.py`.
+ We provide file hashes for the complete file as well as for only the saved tensors in the file (see [Model Spec](https://github.com/Stability-AI/ModelSpec) for a script to evaluate that).
+ The following models are currently supported:
+
+ - [SDXL-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+   ```
+   File Hash (sha256): 31e35c80fc4829d14f90153f4c74cd59c90b779f6afe05a74cd6120b893f7e5b
+   Tensordata Hash (sha256): 0xd7a9105a900fd52748f20725fe52fe52b507fd36bee4fc107b1550a26e6ee1d7
+   ```
+ - [SDXL-refiner-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0)
+   ```
+   File Hash (sha256): 7440042bbdc8a24813002c09b6b69b64dc90fded4472613437b7f55f9b7d9c5f
+   Tensordata Hash (sha256): 0x1a77d21bebc4b4de78c474a90cb74dc0d2217caf4061971dbfa75ad406b75d81
+   ```
+ - [SDXL-base-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9)
+ - [SDXL-refiner-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9)
+ - [SD-2.1-512](https://huggingface.co/stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned.safetensors)
+ - [SD-2.1-768](https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors)
+
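
To compare a downloaded checkpoint against the whole-file SHA-256 values listed above, something like the following sketch should work. It does not cover the tensor-data hash, which needs the Model Spec script. The helper name `sha256_of_file` and the checkpoint path in the trailing comment are hypothetical.

```python
import hashlib
import os
import tempfile


def sha256_of_file(path, chunk_size=1 << 20):
    """Return the hex SHA-256 digest of a file, read in chunks to bound memory use."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# Self-contained demonstration on a small temporary file.
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(b"hello")
    demo_path = tmp.name
demo_digest = sha256_of_file(demo_path)
os.remove(demo_path)
print(demo_digest)

# For a real checkpoint (hypothetical local path), compare against the listed value:
# expected = "31e35c80fc4829d14f90153f4c74cd59c90b779f6afe05a74cd6120b893f7e5b"
# assert sha256_of_file("checkpoints/sd_xl_base_1.0.safetensors") == expected
```

Reading in chunks matters here because the checkpoints are several gigabytes and should not be loaded into memory at once.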
+ **Weights for SDXL**:
+
+ **SDXL-1.0:**
+ The weights of SDXL-1.0 are available (subject to a [`CreativeML Open RAIL++-M` license](model_licenses/LICENSE-SDXL1.0)) here:
+ - base model: https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/
+ - refiner model: https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/
+
+
+ **SDXL-0.9:**
+ The weights of SDXL-0.9 are available and subject to a [research license](model_licenses/LICENSE-SDXL0.9).
+ If you would like to access these models for your research, please apply using one of the following links:
+ [SDXL-base-0.9 model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9) and [SDXL-refiner-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
+ You can apply through either link; once your request is granted, you can access both models.
+ Please log in to your Hugging Face account with your organization email to request access.
+
+
+ After obtaining the weights, place them into `checkpoints/`.
+ Next, start the demo using
+
+ ```
+ streamlit run scripts/demo/sampling.py --server.port <your_port>
+ ```
+
+ ### Invisible Watermark Detection
+
+ Images generated with our code use the
+ [invisible-watermark](https://github.com/ShieldMnt/invisible-watermark/)
+ library to embed an invisible watermark into the model output. We also provide
+ a script to easily detect that watermark. Please note that this watermark is
+ not the same as in previous Stable Diffusion 1.x/2.x versions.
+
+ To run the script you need to either have a working installation as above or
+ try an _experimental_ import using only a minimal set of packages:
+ ```bash
+ python -m venv .detect
+ source .detect/bin/activate
+
+ pip install "numpy>=1.17" "PyWavelets>=1.1.1" "opencv-python>=4.1.0.25"
+ pip install --no-deps invisible-watermark
+ ```
+
+ Once one of these installations is in place, the script
+ is usable in the following ways (don't forget to activate your
+ virtual environment beforehand, e.g. `source .pt13/bin/activate`):
+ ```bash
+ # test a single file
+ python scripts/demo/detect.py <your filename here>
+ # test multiple files at once
+ python scripts/demo/detect.py <filename 1> <filename 2> ... <filename n>
+ # test all files in a specific folder
+ python scripts/demo/detect.py <your folder name here>/*
+ ```
+
+ ## Training:
+
+ We are providing example training configs in `configs/example_training`. To launch training, run
+
+ ```
+ python main.py --base configs/<config1.yaml> configs/<config2.yaml>
+ ```
+
+ where configs are merged from left to right (later configs overwrite the same values).
+ This can be used to combine model, training and data configs. However, all of them can also be
+ defined in a single config. For example, to run a class-conditional pixel-based diffusion model training on MNIST,
+ run
+
+ ```bash
+ python main.py --base configs/example_training/toy/mnist_cond.yaml
+ ```
+
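
The left-to-right merge described above can be pictured with plain dictionaries. This is a simplified stand-in, not the repository's code, which works with YAML files and OmegaConf objects; `merge_two`, `merge_configs`, and the sample keys and values are made up for illustration.

```python
import copy


def merge_two(base, override):
    """Recursively merge `override` into a copy of `base`; override wins on conflicts."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_two(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


def merge_configs(*configs):
    """Merge configs from left to right, so later configs overwrite earlier ones."""
    result = {}
    for cfg in configs:
        result = merge_two(result, cfg)
    return result


# Hypothetical model and training configs with one overlapping key.
model_cfg = {"model": {"base_learning_rate": 1.0e-4, "params": {"channels": 64}}}
train_cfg = {"model": {"base_learning_rate": 5.0e-5}, "data": {"params": {"batch_size": 8}}}
merged = merge_configs(model_cfg, train_cfg)
print(merged)
```

Only the overlapping leaf (`base_learning_rate`) is replaced; keys that appear in just one config survive the merge unchanged.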
+ **NOTE 1:** Using the non-toy-dataset configs `configs/example_training/imagenet-f8_cond.yaml`, `configs/example_training/txt2img-clipl.yaml` and `configs/example_training/txt2img-clipl-legacy-ucg-training.yaml` for training will require edits depending on the dataset used (which is expected to be stored in tar files in the [webdataset format](https://github.com/webdataset/webdataset)). To find the parts that have to be adapted, search for comments containing `USER:` in the respective config.
+
+ **NOTE 2:** This repository supports both `pytorch1.13` and `pytorch2` for training generative models. However, for autoencoder training, as e.g. in `configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml`, only `pytorch1.13` is supported.
+
+ **NOTE 3:** Training latent generative models (as e.g. in `configs/example_training/imagenet-f8_cond.yaml`) requires retrieving the checkpoint from [Hugging Face](https://huggingface.co/stabilityai/sdxl-vae/tree/main) and replacing the `CKPT_PATH` placeholder in [this line](configs/example_training/imagenet-f8_cond.yaml#81). The same must be done for the provided text-to-image configs.
+
+ ### Building New Diffusion Models
+
+ #### Conditioner
+
+ The `GeneralConditioner` is configured through the `conditioner_config`. Its only attribute is `emb_models`, a list of
+ different embedders (all inherited from `AbstractEmbModel`) that are used to condition the generative model.
+ All embedders should define whether or not they are trainable (`is_trainable`, default `False`), which classifier-free
+ guidance dropout rate is used (`ucg_rate`, default `0`), and an input key (`input_key`), for example, `txt` for text-conditioning or `cls` for class-conditioning.
+ When computing conditionings, the embedder will get `batch[input_key]` as input.
+ We currently support two- to four-dimensional conditionings, and conditionings of different embedders are concatenated
+ appropriately.
+ Note that the order of the embedders in the `conditioner_config` is important.
+
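
A toy version of this contract, in plain Python, might look as follows. Everything here is hypothetical (`ToyEmbedder`, `toy_condition`, list-based "embeddings"); the real `GeneralConditioner` operates on tensors and is more involved. The sketch only shows that each embedder reads `batch[input_key]`, that `ucg_rate` acts as a dropout probability (assumed here to replace the embedding with zeros), and that outputs are combined in the order the embedders are listed.

```python
import random


class ToyEmbedder:
    """Hypothetical embedder: maps batch[input_key] to a list of floats."""

    def __init__(self, input_key, embed_fn, ucg_rate=0.0, is_trainable=False):
        self.input_key = input_key
        self.embed_fn = embed_fn
        self.ucg_rate = ucg_rate
        self.is_trainable = is_trainable

    def __call__(self, batch, rng):
        emb = self.embed_fn(batch[self.input_key])
        # Assumption: dropout replaces the embedding with zeros with probability ucg_rate.
        if rng.random() < self.ucg_rate:
            emb = [0.0] * len(emb)
        return emb


def toy_condition(embedders, batch, rng):
    """Concatenate the embedder outputs in the order the embedders are listed."""
    out = []
    for embedder in embedders:
        out.extend(embedder(batch, rng))
    return out


embedders = [
    ToyEmbedder("cls", lambda c: [float(c)]),                          # never dropped
    ToyEmbedder("txt", lambda t: [float(len(t)), 1.0], ucg_rate=1.0),  # always dropped
]
batch = {"cls": 7, "txt": "a beautiful image"}
cond = toy_condition(embedders, batch, random.Random(0))
print(cond)
```

Reversing the list of embedders changes the layout of the concatenated result, which is why the order in the config matters.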
+ #### Network
+
+ The neural network is set through the `network_config`. This used to be called `unet_config`, which is not general
+ enough as we plan to experiment with transformer-based diffusion backbones.
+
+ #### Loss
+
+ The loss is configured through `loss_config`. For standard diffusion model training, you will have to set `sigma_sampler_config`.
+
+ #### Sampler config
+
+ As discussed above, the sampler is independent of the model. In the `sampler_config`, we set the type of numerical
+ solver, number of steps, type of discretization, as well as, for example, guidance wrappers for classifier-free
+ guidance.
+
+ ### Dataset Handling
+
+
+ For large scale training we recommend using the data pipelines from our [data pipelines](https://github.com/Stability-AI/datapipelines) project. The project is listed in the requirements and is installed automatically when following the steps from the [Installation section](#installation).
+ Small map-style datasets should be defined here in the repository (e.g., MNIST, CIFAR-10, ...), and return a dict of
+ data keys/values,
+ e.g.,
+
+ ```python
+ example = {"jpg": x,  # this is a tensor -1...1 chw
+            "txt": "a beautiful image"}
+ ```
+
+ where we expect images in -1...1, channel-first format.
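
A minimal sketch of such a map-style dataset is shown below, using nested lists instead of tensors so that it runs without extra dependencies. `ToyImageTextDataset` and `to_chw_minus1_1` are hypothetical names; a real dataset would return a tensor under `"jpg"`.

```python
def to_chw_minus1_1(image_hwc):
    """Convert an H x W x C image with values 0..255 to C x H x W with values in -1..1."""
    height, width = len(image_hwc), len(image_hwc[0])
    channels = len(image_hwc[0][0])
    return [
        [[image_hwc[y][x][c] / 127.5 - 1.0 for x in range(width)] for y in range(height)]
        for c in range(channels)
    ]


class ToyImageTextDataset:
    """Map-style dataset: supports len() and integer indexing, returns a dict per item."""

    def __init__(self, images, captions):
        assert len(images) == len(captions)
        self.images = images
        self.captions = captions

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        return {"jpg": to_chw_minus1_1(self.images[index]), "txt": self.captions[index]}


# One 1x2 RGB image: a black pixel next to a white pixel.
dataset = ToyImageTextDataset([[[[0, 0, 0], [255, 255, 255]]]], ["a beautiful image"])
example = dataset[0]
print(len(example["jpg"]), example["txt"])
```

The dict keys (`"jpg"`, `"txt"`) are what an embedder's `input_key` would refer to when it reads from the batch.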
main.py ADDED
@@ -0,0 +1,943 @@
+ import argparse
+ import datetime
+ import glob
+ import inspect
+ import os
+ import sys
+ from inspect import Parameter
+ from typing import Union
+
+ import numpy as np
+ import pytorch_lightning as pl
+ import torch
+ import torchvision
+ import wandb
+ from matplotlib import pyplot as plt
+ from natsort import natsorted
+ from omegaconf import OmegaConf
+ from packaging import version
+ from PIL import Image
+ from pytorch_lightning import seed_everything
+ from pytorch_lightning.callbacks import Callback
+ from pytorch_lightning.loggers import WandbLogger
+ from pytorch_lightning.trainer import Trainer
+ from pytorch_lightning.utilities import rank_zero_only
+
+ from sgm.util import exists, instantiate_from_config, isheatmap
+
+ MULTINODE_HACKS = True
+
+
+ def default_trainer_args():
+     argspec = dict(inspect.signature(Trainer.__init__).parameters)
+     argspec.pop("self")
+     default_args = {
+         param: argspec[param].default
+         for param in argspec
+         if argspec[param].default != Parameter.empty  # skip parameters without a default
+     }
+     return default_args
+
+
+ def get_parser(**parser_kwargs):
+     def str2bool(v):
+         if isinstance(v, bool):
+             return v
+         if v.lower() in ("yes", "true", "t", "y", "1"):
+             return True
+         elif v.lower() in ("no", "false", "f", "n", "0"):
+             return False
+         else:
+             raise argparse.ArgumentTypeError("Boolean value expected.")
+
+     parser = argparse.ArgumentParser(**parser_kwargs)
+     parser.add_argument(
+         "-n",
+         "--name",
+         type=str,
+         const=True,
+         default="",
+         nargs="?",
+         help="postfix for logdir",
+     )
+     parser.add_argument(
+         "--no_date",
+         type=str2bool,
+         nargs="?",
+         const=True,
+         default=False,
+         help="if True, skip date generation for logdir and only use naming via opt.base or opt.name (+ opt.postfix, optionally)",
+     )
+     parser.add_argument(
+         "-r",
+         "--resume",
+         type=str,
+         const=True,
+         default="",
+         nargs="?",
+         help="resume from logdir or checkpoint in logdir",
+     )
+     parser.add_argument(
+         "-b",
+         "--base",
+         nargs="*",
+         metavar="base_config.yaml",
+         help="paths to base configs. Loaded from left-to-right. "
+         "Parameters can be overwritten or added with command-line options of the form `--key value`.",
+         default=list(),
+     )
+     parser.add_argument(
+         "-t",
+         "--train",
+         type=str2bool,
+         const=True,
+         default=True,
+         nargs="?",
+         help="train",
+     )
+     parser.add_argument(
+         "--no-test",
+         type=str2bool,
+         const=True,
+         default=False,
+         nargs="?",
+         help="disable test",
+     )
+     parser.add_argument(
+         "-p", "--project", help="name of new or path to existing project"
+     )
+     parser.add_argument(
+         "-d",
+         "--debug",
+         type=str2bool,
+         nargs="?",
+         const=True,
+         default=False,
+         help="enable post-mortem debugging",
+     )
+     parser.add_argument(
+         "-s",
+         "--seed",
+         type=int,
+         default=23,
+         help="seed for seed_everything",
+     )
+     parser.add_argument(
+         "-f",
+         "--postfix",
+         type=str,
+         default="",
+         help="post-postfix for default name",
+     )
+     parser.add_argument(
+         "--projectname",
+         type=str,
+         default="stablediffusion",
+     )
+     parser.add_argument(
+         "-l",
+         "--logdir",
+         type=str,
+         default="logs",
+         help="directory for logging data",
+     )
+     parser.add_argument(
+         "--scale_lr",
+         type=str2bool,
+         nargs="?",
+         const=True,
+         default=False,
+         help="scale base-lr by ngpu * batch_size * n_accumulate",
+     )
+     parser.add_argument(
+         "--legacy_naming",
+         type=str2bool,
+         nargs="?",
+         const=True,
+         default=False,
+         help="name run based on config file name if true, else by whole path",
+     )
+     parser.add_argument(
+         "--enable_tf32",
+         type=str2bool,
+         nargs="?",
+         const=True,
+         default=False,
+         help="enables the TensorFloat32 format both for matmuls and cuDNN for pytorch 1.12",
+     )
+     parser.add_argument(
+         "--startup",
+         type=str,
+         default=None,
+         help="Startuptime from distributed script",
+     )
+     parser.add_argument(
+         "--wandb",
+         type=str2bool,
+         nargs="?",
+         const=True,
+         default=False,  # TODO: later default to True
+         help="log to wandb",
+     )
+     parser.add_argument(
+         "--no_base_name",
+         type=str2bool,
+         nargs="?",
+         const=True,
+         default=False,  # TODO: later default to True
+         help="if True, do not use the base config name in the run name",
+     )
+     if version.parse(torch.__version__) >= version.parse("2.0.0"):
+         parser.add_argument(
+             "--resume_from_checkpoint",
+             type=str,
+             default=None,
+             help="single checkpoint file to resume from",
+         )
+     default_args = default_trainer_args()
+     for key in default_args:
+         parser.add_argument("--" + key, default=default_args[key])
+     return parser
+
+
+ def get_checkpoint_name(logdir):
+     ckpt = os.path.join(logdir, "checkpoints", "last**.ckpt")
+     ckpt = natsorted(glob.glob(ckpt))
+     print('available "last" checkpoints:')
+     print(ckpt)
+     if len(ckpt) > 1:
+         print("got most recent checkpoint")
+         ckpt = sorted(ckpt, key=lambda x: os.path.getmtime(x))[-1]
+         print(f"Most recent ckpt is {ckpt}")
+         with open(os.path.join(logdir, "most_recent_ckpt.txt"), "w") as f:
+             f.write(ckpt + "\n")
+         try:
+             version = int(ckpt.split("/")[-1].split("-v")[-1].split(".")[0])
+         except Exception as e:
+             print("version confusion but not bad")
+             print(e)
+             version = 1
+         # version = last_version + 1
+     else:
+         # in this case, we only have one "last.ckpt"
+         ckpt = ckpt[0]
+         version = 1
+     melk_ckpt_name = f"last-v{version}.ckpt"
+     print(f"Current melk ckpt name: {melk_ckpt_name}")
+     return ckpt, melk_ckpt_name
+
+
+ class SetupCallback(Callback):
+     def __init__(
+         self,
+         resume,
+         now,
+         logdir,
+         ckptdir,
+         cfgdir,
+         config,
+         lightning_config,
+         debug,
+         ckpt_name=None,
+     ):
+         super().__init__()
+         self.resume = resume
+         self.now = now
+         self.logdir = logdir
+         self.ckptdir = ckptdir
+         self.cfgdir = cfgdir
+         self.config = config
+         self.lightning_config = lightning_config
+         self.debug = debug
+         self.ckpt_name = ckpt_name
+
+     def on_exception(self, trainer: pl.Trainer, pl_module, exception):
+         if not self.debug and trainer.global_rank == 0:
+             print("Summoning checkpoint.")
+             if self.ckpt_name is None:
+                 ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
+             else:
+                 ckpt_path = os.path.join(self.ckptdir, self.ckpt_name)
+             trainer.save_checkpoint(ckpt_path)
+
+     def on_fit_start(self, trainer, pl_module):
+         if trainer.global_rank == 0:
+             # Create logdirs and save configs
+             os.makedirs(self.logdir, exist_ok=True)
+             os.makedirs(self.ckptdir, exist_ok=True)
+             os.makedirs(self.cfgdir, exist_ok=True)
+
+             if "callbacks" in self.lightning_config:
+                 if (
+                     "metrics_over_trainsteps_checkpoint"
+                     in self.lightning_config["callbacks"]
+                 ):
+                     os.makedirs(
+                         os.path.join(self.ckptdir, "trainstep_checkpoints"),
+                         exist_ok=True,
+                     )
+             print("Project config")
+             print(OmegaConf.to_yaml(self.config))
+             if MULTINODE_HACKS:
+                 import time
+
+                 time.sleep(5)
+             OmegaConf.save(
+                 self.config,
+                 os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)),
+             )
+
+             print("Lightning config")
+             print(OmegaConf.to_yaml(self.lightning_config))
+             OmegaConf.save(
+                 OmegaConf.create({"lightning": self.lightning_config}),
+                 os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)),
+             )
+
+         else:
+             # ModelCheckpoint callback created log directory --- remove it
+             if not MULTINODE_HACKS and not self.resume and os.path.exists(self.logdir):
+                 dst, name = os.path.split(self.logdir)
+                 dst = os.path.join(dst, "child_runs", name)
+                 os.makedirs(os.path.split(dst)[0], exist_ok=True)
+                 try:
+                     os.rename(self.logdir, dst)
+                 except FileNotFoundError:
+                     pass
+
+
309
+ class ImageLogger(Callback):
310
+ def __init__(
311
+ self,
312
+ batch_frequency,
313
+ max_images,
314
+ clamp=True,
315
+ increase_log_steps=True,
316
+ rescale=True,
317
+ disabled=False,
318
+ log_on_batch_idx=False,
319
+ log_first_step=False,
320
+ log_images_kwargs=None,
321
+ log_before_first_step=False,
322
+ enable_autocast=True,
323
+ ):
324
+ super().__init__()
325
+ self.enable_autocast = enable_autocast
326
+ self.rescale = rescale
327
+ self.batch_freq = batch_frequency
328
+ self.max_images = max_images
329
+ self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
330
+ if not increase_log_steps:
331
+ self.log_steps = [self.batch_freq]
332
+ self.clamp = clamp
333
+ self.disabled = disabled
334
+ self.log_on_batch_idx = log_on_batch_idx
335
+ self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
336
+ self.log_first_step = log_first_step
337
+ self.log_before_first_step = log_before_first_step
338
+
339
+ @rank_zero_only
340
+ def log_local(
341
+ self,
342
+ save_dir,
343
+ split,
344
+ images,
345
+ global_step,
346
+ current_epoch,
347
+ batch_idx,
348
+ pl_module: Union[None, pl.LightningModule] = None,
349
+ ):
350
+ root = os.path.join(save_dir, "images", split)
351
+ for k in images:
352
+ if isheatmap(images[k]):
353
+ fig, ax = plt.subplots()
354
+ ax = ax.matshow(
355
+ images[k].cpu().numpy(), cmap="hot", interpolation="lanczos"
356
+ )
357
+ plt.colorbar(ax)
358
+ plt.axis("off")
359
+
360
+ filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
361
+ k, global_step, current_epoch, batch_idx
362
+ )
363
+ os.makedirs(root, exist_ok=True)
364
+ path = os.path.join(root, filename)
365
+ plt.savefig(path)
366
+ plt.close()
367
+ # TODO: support wandb
368
+ else:
369
+ grid = torchvision.utils.make_grid(images[k], nrow=4)
370
+ if self.rescale:
371
+ grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
372
+ grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
373
+ grid = grid.numpy()
374
+ grid = (grid * 255).astype(np.uint8)
375
+ filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
376
+ k, global_step, current_epoch, batch_idx
377
+ )
378
+ path = os.path.join(root, filename)
379
+ os.makedirs(os.path.split(path)[0], exist_ok=True)
380
+ img = Image.fromarray(grid)
381
+ img.save(path)
382
+ if exists(pl_module):
383
+ assert isinstance(
384
+ pl_module.logger, WandbLogger
385
+ ), "logger_log_image only supports WandbLogger currently"
386
+ pl_module.logger.log_image(
387
+ key=f"{split}/{k}",
388
+ images=[
389
+ img,
390
+ ],
391
+ step=pl_module.global_step,
392
+ )
393
+
394
+ @rank_zero_only
395
+ def log_img(self, pl_module, batch, batch_idx, split="train"):
396
+ check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
397
+ if (
398
+ self.check_frequency(check_idx)
399
+ and hasattr(pl_module, "log_images") # batch_idx % self.batch_freq == 0
400
+ and callable(pl_module.log_images)
401
+ and
402
+ # batch_idx > 5 and
403
+ self.max_images > 0
404
+ ):
405
+ logger = type(pl_module.logger)
406
+ is_train = pl_module.training
407
+ if is_train:
408
+ pl_module.eval()
409
+
410
+ gpu_autocast_kwargs = {
411
+ "enabled": self.enable_autocast, # torch.is_autocast_enabled(),
412
+ "dtype": torch.get_autocast_gpu_dtype(),
413
+ "cache_enabled": torch.is_autocast_cache_enabled(),
414
+ }
415
+ with torch.no_grad(), torch.cuda.amp.autocast(**gpu_autocast_kwargs):
416
+ images = pl_module.log_images(
417
+ batch, split=split, **self.log_images_kwargs
418
+ )
419
+
420
+ for k in images:
421
+ N = min(images[k].shape[0], self.max_images)
422
+ if not isheatmap(images[k]):
423
+ images[k] = images[k][:N]
424
+ if isinstance(images[k], torch.Tensor):
425
+ images[k] = images[k].detach().float().cpu()
426
+ if self.clamp and not isheatmap(images[k]):
427
+ images[k] = torch.clamp(images[k], -1.0, 1.0)
428
+
429
+ self.log_local(
430
+ pl_module.logger.save_dir,
431
+ split,
432
+ images,
433
+ pl_module.global_step,
434
+ pl_module.current_epoch,
435
+ batch_idx,
436
+ pl_module=pl_module
437
+ if isinstance(pl_module.logger, WandbLogger)
438
+ else None,
439
+ )
440
+
441
+ if is_train:
442
+ pl_module.train()
443
+
444
+ def check_frequency(self, check_idx):
445
+ if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
446
+ check_idx > 0 or self.log_first_step
447
+ ):
448
+ try:
449
+ self.log_steps.pop(0)
450
+ except IndexError as e:
451
+ print(e)
452
+ pass
453
+ return True
454
+ return False
455
+
456
+ @rank_zero_only
457
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
458
+ if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
459
+ self.log_img(pl_module, batch, batch_idx, split="train")
460
+
461
+ @rank_zero_only
462
+ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
463
+ if self.log_before_first_step and pl_module.global_step == 0:
464
+ print(f"{self.__class__.__name__}: logging before training")
465
+ self.log_img(pl_module, batch, batch_idx, split="train")
466
+
467
+ @rank_zero_only
468
+ def on_validation_batch_end(
469
+ self, trainer, pl_module, outputs, batch, batch_idx, *args, **kwargs
470
+ ):
471
+ if not self.disabled and pl_module.global_step > 0:
472
+ self.log_img(pl_module, batch, batch_idx, split="val")
473
+ if hasattr(pl_module, "calibrate_grad_norm"):
474
+ if (
475
+ pl_module.calibrate_grad_norm and batch_idx % 25 == 0
476
+ ) and batch_idx > 0:
477
+ self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
478
+
479
+
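The logging schedule in `check_frequency` is easy to misread from the condition alone. Below is a minimal standalone sketch of the same logic. It assumes `log_steps` is a short list of early step indices; how that list is built is not shown in this hunk, so the values used here are hypothetical, as is the function name `should_log`.

```python
def should_log(check_idx, batch_freq, log_steps, log_first_step=False):
    """Standalone mirror of the check_frequency condition.

    Like the original, it consumes one entry of log_steps on every hit.
    """
    hit = (check_idx % batch_freq == 0) or (check_idx in log_steps)
    if hit and (check_idx > 0 or log_first_step):
        if log_steps:  # the original catches IndexError instead
            log_steps.pop(0)
        return True
    return False


# Hypothetical schedule: a few early steps, then every 8th step.
log_steps = [1, 2, 4, 8]
logged = [i for i in range(20) if should_log(i, 8, log_steps)]
print(logged)
```

Step 0 is skipped unless `log_first_step` is set, which matches the `check_idx > 0 or self.log_first_step` guard above.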
480
+ @rank_zero_only
481
+ def init_wandb(save_dir, opt, config, group_name, name_str):
482
+ print(f"setting WANDB_DIR to {save_dir}")
483
+ os.makedirs(save_dir, exist_ok=True)
484
+
485
+ os.environ["WANDB_DIR"] = save_dir
486
+ if opt.debug:
487
+ wandb.init(project=opt.projectname, mode="offline", group=group_name)
488
+ else:
489
+ wandb.init(
490
+ project=opt.projectname,
491
+ config=config,
492
+ settings=wandb.Settings(code_dir="./sgm"),
493
+ group=group_name,
494
+ name=name_str,
495
+ )
496
+
497
+
498
+ if __name__ == "__main__":
499
+ # custom parser to specify config files, train, test and debug mode,
500
+ # postfix, resume.
501
+ # `--key value` arguments are interpreted as arguments to the trainer.
502
+ # `nested.key=value` arguments are interpreted as config parameters.
503
+ # configs are merged from left-to-right followed by command line parameters.
504
+
505
+ # model:
506
+ # base_learning_rate: float
507
+ # target: path to lightning module
508
+ # params:
509
+ # key: value
510
+ # data:
511
+ # target: main.DataModuleFromConfig
512
+ # params:
513
+ # batch_size: int
514
+ # wrap: bool
515
+ # train:
516
+ # target: path to train dataset
517
+ # params:
518
+ # key: value
519
+ # validation:
520
+ # target: path to validation dataset
521
+ # params:
522
+ # key: value
523
+ # test:
524
+ # target: path to test dataset
525
+ # params:
526
+ # key: value
527
+ # lightning: (optional, has sane defaults and can be specified on cmdline)
528
+ # trainer:
529
+ # additional arguments to trainer
530
+ # logger:
531
+ # logger to instantiate
532
+ # modelcheckpoint:
533
+ # modelcheckpoint to instantiate
534
+ # callbacks:
535
+ # callback1:
536
+ # target: importpath
537
+ # params:
538
+ # key: value
539
+
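The merge order described in the comment above (configs left to right, then command-line `nested.key=value` overrides) can be sketched with plain dictionaries. This is only a stand-in for what OmegaConf does in the script: `deep_merge` and `dotlist_to_dict` are hypothetical helpers, and unlike OmegaConf this sketch leaves command-line values as strings.

```python
from functools import reduce


def deep_merge(base, override):
    """Return a new dict where values from override win, merging nested dicts."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged


def dotlist_to_dict(items):
    """Turn ["a.b=1", "c=2"] into {"a": {"b": "1"}, "c": "2"} (values stay strings)."""
    cfg = {}
    for item in items:
        dotted_key, value = item.split("=", 1)
        *parents, leaf = dotted_key.split(".")
        node = cfg
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return cfg


# Two hypothetical config files followed by one command-line override.
first = {"model": {"base_learning_rate": 1e-4, "params": {"key": "a"}}}
second = {"model": {"params": {"key": "b"}}, "data": {"params": {"batch_size": 4}}}
cli = ["data.params.batch_size=8"]

config = reduce(deep_merge, [first, second, dotlist_to_dict(cli)], {})
print(config)
```

Later sources only replace the keys they mention, so `base_learning_rate` from the first config survives both overrides.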
540
+ now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
541
+
542
+ # add cwd for convenience and to make classes in this file available when
543
+ # running as `python main.py`
544
+ # (in particular `main.DataModuleFromConfig`)
545
+ sys.path.append(os.getcwd())
546
+
547
+ parser = get_parser()
548
+
549
+ opt, unknown = parser.parse_known_args()
550
+
551
+ if opt.name and opt.resume:
552
+ raise ValueError(
553
+ "-n/--name and -r/--resume cannot both be specified. "
554
+ "If you want to resume training in a new log folder, "
555
+ "use -n/--name in combination with --resume_from_checkpoint"
556
+ )
557
+ melk_ckpt_name = None
558
+ name = None
559
+ if opt.resume:
560
+ if not os.path.exists(opt.resume):
561
+ raise ValueError("Cannot find {}".format(opt.resume))
562
+ if os.path.isfile(opt.resume):
563
+ paths = opt.resume.split("/")
564
+ # idx = len(paths)-paths[::-1].index("logs")+1
565
+ # logdir = "/".join(paths[:idx])
566
+ logdir = "/".join(paths[:-2])
567
+ ckpt = opt.resume
568
+ _, melk_ckpt_name = get_checkpoint_name(logdir)
569
+ else:
570
+ assert os.path.isdir(opt.resume), opt.resume
571
+ logdir = opt.resume.rstrip("/")
572
+ ckpt, melk_ckpt_name = get_checkpoint_name(logdir)
573
+
574
+ print("#" * 100)
575
+ print(f'Resuming from checkpoint "{ckpt}"')
576
+ print("#" * 100)
577
+
578
+ opt.resume_from_checkpoint = ckpt
579
+ base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
580
+ opt.base = base_configs + opt.base
581
+ _tmp = logdir.split("/")
582
+ nowname = _tmp[-1]
583
+ else:
584
+ if opt.name:
585
+ name = "_" + opt.name
586
+ elif opt.base:
587
+ if opt.no_base_name:
588
+ name = ""
589
+ else:
590
+ if opt.legacy_naming:
591
+ cfg_fname = os.path.split(opt.base[0])[-1]
592
+ cfg_name = os.path.splitext(cfg_fname)[0]
593
+ else:
594
+ assert "configs" in os.path.split(opt.base[0])[0], os.path.split(
595
+ opt.base[0]
596
+ )[0]
597
+ cfg_path = os.path.split(opt.base[0])[0].split(os.sep)[
598
+ os.path.split(opt.base[0])[0].split(os.sep).index("configs")
599
+ + 1 :
600
+ ] # cut away the first one (we assert all configs are in "configs")
601
+ cfg_name = os.path.splitext(os.path.split(opt.base[0])[-1])[0]
602
+ cfg_name = "-".join(cfg_path) + f"-{cfg_name}"
603
+ name = "_" + cfg_name
604
+ else:
605
+ name = ""
606
+ if not opt.no_date:
607
+ nowname = now + name + opt.postfix
608
+ else:
609
+ nowname = name + opt.postfix
610
+ if nowname.startswith("_"):
611
+ nowname = nowname[1:]
612
+ logdir = os.path.join(opt.logdir, nowname)
613
+ print(f"LOGDIR: {logdir}")
614
+
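The non-legacy run naming above packs several path operations into a few lines. A minimal sketch of the same derivation follows; the function name `run_name_from_config` and the example path are hypothetical, and the `--name`, `--no_base_name` and `--legacy_naming` branches are left out.

```python
import os


def run_name_from_config(cfg_file, now, postfix="", no_date=False):
    """Derive a log-folder name from a config path located under 'configs'."""
    directory, fname = os.path.split(cfg_file)
    parts = directory.split(os.sep)
    # Keep only the sub-directories below "configs" (ValueError if it is absent).
    sub_dirs = parts[parts.index("configs") + 1:]
    stem = os.path.splitext(fname)[0]
    name = "_" + "-".join(sub_dirs) + f"-{stem}"
    if not no_date:
        return now + name + postfix
    nowname = name + postfix
    return nowname[1:] if nowname.startswith("_") else nowname


cfg = os.path.join("configs", "example_training", "toy", "mnist.yaml")
print(run_name_from_config(cfg, now="2023-07-01T12-00-00"))
print(run_name_from_config(cfg, now="2023-07-01T12-00-00", no_date=True))
```

A config placed directly in `configs/` yields an empty `sub_dirs`, so the name starts with `-`; the original code behaves the same way.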
615
+ ckptdir = os.path.join(logdir, "checkpoints")
616
+ cfgdir = os.path.join(logdir, "configs")
617
+ seed_everything(opt.seed, workers=True)
618
+
619
+ # move before model init, in case a torch.compile(...) is called somewhere
620
+ if opt.enable_tf32:
621
+ # pt_version = version.parse(torch.__version__)
622
+ torch.backends.cuda.matmul.allow_tf32 = True
623
+ torch.backends.cudnn.allow_tf32 = True
624
+ print(f"Enabling TF32 for PyTorch {torch.__version__}")
625
+ else:
626
+ print(f"Using default TF32 settings for PyTorch {torch.__version__}:")
627
+ print(
628
+ f"torch.backends.cuda.matmul.allow_tf32={torch.backends.cuda.matmul.allow_tf32}"
629
+ )
630
+ print(f"torch.backends.cudnn.allow_tf32={torch.backends.cudnn.allow_tf32}")
631
+
632
+ try:
633
+ # init and save configs
634
+ configs = [OmegaConf.load(cfg) for cfg in opt.base]
635
+ cli = OmegaConf.from_dotlist(unknown)
636
+ config = OmegaConf.merge(*configs, cli)
637
+ lightning_config = config.pop("lightning", OmegaConf.create())
638
+ # merge trainer cli with config
639
+ trainer_config = lightning_config.get("trainer", OmegaConf.create())
640
+
641
+ # default to gpu
642
+ trainer_config["accelerator"] = "gpu"
643
+ #
644
+ standard_args = default_trainer_args()
645
+ for k in standard_args:
646
+ if getattr(opt, k) != standard_args[k]:
647
+ trainer_config[k] = getattr(opt, k)
648
+
649
+ ckpt_resume_path = opt.resume_from_checkpoint
650
+
651
+ if "devices" not in trainer_config and trainer_config["accelerator"] != "gpu":
652
+ del trainer_config["accelerator"]
653
+ cpu = True
654
+ else:
655
+ gpuinfo = trainer_config["devices"]
656
+ print(f"Running on GPUs {gpuinfo}")
657
+ cpu = False
658
+ trainer_opt = argparse.Namespace(**trainer_config)
659
+ lightning_config.trainer = trainer_config
660
+
661
+ # model
662
+ model = instantiate_from_config(config.model)
663
+
664
+ # trainer and callbacks
665
+ trainer_kwargs = dict()
666
+
667
+ # default logger configs
668
+ default_logger_cfgs = {
669
+ "wandb": {
670
+ "target": "pytorch_lightning.loggers.WandbLogger",
671
+ "params": {
672
+ "name": nowname,
673
+ # "save_dir": logdir,
674
+ "offline": opt.debug,
675
+ "id": nowname,
676
+ "project": opt.projectname,
677
+ "log_model": False,
678
+ # "dir": logdir,
679
+ },
680
+ },
681
+ "csv": {
682
+ "target": "pytorch_lightning.loggers.CSVLogger",
683
+ "params": {
684
+ "name": "testtube", # hack for sbord fanatics
685
+ "save_dir": logdir,
686
+ },
687
+ },
688
+ }
689
+ default_logger_cfg = default_logger_cfgs["wandb" if opt.wandb else "csv"]
690
+ if opt.wandb:
691
+ # TODO change once leaving "swiffer" config directory
692
+ try:
693
+ group_name = nowname.split(now)[-1].split("-")[1]
694
+ except IndexError:
695
+ group_name = nowname
696
+ default_logger_cfg["params"]["group"] = group_name
697
+ init_wandb(
698
+ os.path.join(os.getcwd(), logdir),
699
+ opt=opt,
700
+ group_name=group_name,
701
+ config=config,
702
+ name_str=nowname,
703
+ )
704
+ if "logger" in lightning_config:
705
+ logger_cfg = lightning_config.logger
706
+ else:
707
+ logger_cfg = OmegaConf.create()
708
+ logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
709
+ trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
710
+
711
+ # modelcheckpoint - if the model defines a `monitor` attribute, that metric
+ # determines which checkpoints are kept as the best models
713
+ default_modelckpt_cfg = {
714
+ "target": "pytorch_lightning.callbacks.ModelCheckpoint",
715
+ "params": {
716
+ "dirpath": ckptdir,
717
+ "filename": "{epoch:06}",
718
+ "verbose": True,
719
+ "save_last": True,
720
+ },
721
+ }
722
+ if hasattr(model, "monitor"):
723
+ print(f"Monitoring {model.monitor} as checkpoint metric.")
724
+ default_modelckpt_cfg["params"]["monitor"] = model.monitor
725
+ default_modelckpt_cfg["params"]["save_top_k"] = 3
726
+
727
+ if "modelcheckpoint" in lightning_config:
728
+ modelckpt_cfg = lightning_config.modelcheckpoint
729
+ else:
730
+ modelckpt_cfg = OmegaConf.create()
731
+ modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
732
+ print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")
733
+
734
+ # https://pytorch-lightning.readthedocs.io/en/stable/extensions/strategy.html
735
+ # default to ddp if not further specified
736
+ default_strategy_config = {"target": "pytorch_lightning.strategies.DDPStrategy"}
737
+
738
+ if "strategy" in lightning_config:
739
+ strategy_cfg = lightning_config.strategy
740
+ else:
741
+ strategy_cfg = OmegaConf.create()
742
+ default_strategy_config["params"] = {
743
+ "find_unused_parameters": False,
744
+ # "static_graph": True,
745
+ # "ddp_comm_hook": default.fp16_compress_hook # TODO: experiment with this, also for DDPSharded
746
+ }
747
+ strategy_cfg = OmegaConf.merge(default_strategy_config, strategy_cfg)
748
+ print(
749
+ f"strategy config: \n ++++++++++++++ \n {strategy_cfg} \n ++++++++++++++ "
750
+ )
751
+ trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg)
752
+
753
+ # add callback which sets up log directory
754
+ default_callbacks_cfg = {
755
+ "setup_callback": {
756
+ "target": "main.SetupCallback",
757
+ "params": {
758
+ "resume": opt.resume,
759
+ "now": now,
760
+ "logdir": logdir,
761
+ "ckptdir": ckptdir,
762
+ "cfgdir": cfgdir,
763
+ "config": config,
764
+ "lightning_config": lightning_config,
765
+ "debug": opt.debug,
766
+ "ckpt_name": melk_ckpt_name,
767
+ },
768
+ },
769
+ "image_logger": {
770
+ "target": "main.ImageLogger",
771
+ "params": {"batch_frequency": 1000, "max_images": 4, "clamp": True},
772
+ },
773
+ "learning_rate_logger": {
774
+ "target": "pytorch_lightning.callbacks.LearningRateMonitor",
775
+ "params": {
776
+ "logging_interval": "step",
777
+ # "log_momentum": True
778
+ },
779
+ },
780
+ }
781
+ if version.parse(pl.__version__) >= version.parse("1.4.0"):
782
+ default_callbacks_cfg.update({"checkpoint_callback": modelckpt_cfg})
783
+
784
+ if "callbacks" in lightning_config:
785
+ callbacks_cfg = lightning_config.callbacks
786
+ else:
787
+ callbacks_cfg = OmegaConf.create()
788
+
789
+ if "metrics_over_trainsteps_checkpoint" in callbacks_cfg:
790
+ print(
791
+ "Caution: Saving checkpoints every n train steps without deleting. This might require some free space."
792
+ )
793
+ default_metrics_over_trainsteps_ckpt_dict = {
794
+ "metrics_over_trainsteps_checkpoint": {
795
+ "target": "pytorch_lightning.callbacks.ModelCheckpoint",
796
+ "params": {
797
+ "dirpath": os.path.join(ckptdir, "trainstep_checkpoints"),
798
+ "filename": "{epoch:06}-{step:09}",
799
+ "verbose": True,
800
+ "save_top_k": -1,
801
+ "every_n_train_steps": 10000,
802
+ "save_weights_only": True,
803
+ },
804
+ }
805
+ }
806
+ default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
807
+
808
+ callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
809
+ if "ignore_keys_callback" in callbacks_cfg and ckpt_resume_path is not None:
810
+ callbacks_cfg.ignore_keys_callback.params["ckpt_path"] = ckpt_resume_path
811
+ elif "ignore_keys_callback" in callbacks_cfg:
812
+ del callbacks_cfg["ignore_keys_callback"]
813
+
814
+ trainer_kwargs["callbacks"] = [
815
+ instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
816
+ ]
817
+ if "plugins" not in trainer_kwargs:
818
+ trainer_kwargs["plugins"] = list()
819
+
820
+ # command-line trainer args (in trainer_opt) always take priority over config trainer args (in trainer_kwargs)
821
+ trainer_opt = vars(trainer_opt)
822
+ trainer_kwargs = {
823
+ key: val for key, val in trainer_kwargs.items() if key not in trainer_opt
824
+ }
825
+ trainer = Trainer(**trainer_opt, **trainer_kwargs)
826
+
827
+ trainer.logdir = logdir ###
828
+
829
+ # data
830
+ data = instantiate_from_config(config.data)
831
+ # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
832
+ # calling these ourselves should not be necessary but it is.
833
+ # lightning still takes care of proper multiprocessing though
834
+ data.prepare_data()
835
+ # data.setup()
836
+ print("#### Data #####")
837
+ try:
838
+ for k in data.datasets:
839
+ print(
840
+ f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}"
841
+ )
842
+ except Exception:
843
+ print("datasets not yet initialized.")
844
+
845
+ # configure learning rate
846
+ if "batch_size" in config.data.params:
847
+ bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
848
+ else:
849
+ bs, base_lr = (
850
+ config.data.params.train.loader.batch_size,
851
+ config.model.base_learning_rate,
852
+ )
853
+ if not cpu:
854
+ ngpu = len(lightning_config.trainer.devices.strip(",").split(","))
855
+ else:
856
+ ngpu = 1
857
+ if "accumulate_grad_batches" in lightning_config.trainer:
858
+ accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
859
+ else:
860
+ accumulate_grad_batches = 1
861
+ print(f"accumulate_grad_batches = {accumulate_grad_batches}")
862
+ lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
863
+ if opt.scale_lr:
864
+ model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
865
+ print(
866
+ "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
867
+ model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr
868
+ )
869
+ )
870
+ else:
871
+ model.learning_rate = base_lr
872
+ print("++++ NOT USING LR SCALING ++++")
873
+ print(f"Setting learning rate to {model.learning_rate:.2e}")
874
+
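The learning-rate scaling above reduces to one product. Here is a small sketch under the same assumptions the script makes, namely that `devices` is a comma-separated string such as `"0,1,"`; the helper names are hypothetical.

```python
def count_devices(devices):
    """Count GPUs in a comma-separated devices string, e.g. "0,1," -> 2."""
    return len(devices.strip(",").split(","))


def effective_learning_rate(base_lr, batch_size, ngpu=1,
                            accumulate_grad_batches=1, scale_lr=True):
    """Scale base_lr by the effective batch size, or return it unchanged."""
    if not scale_lr:
        return base_lr
    return accumulate_grad_batches * ngpu * batch_size * base_lr


ngpu = count_devices("0,1,")
lr = effective_learning_rate(1e-4, batch_size=4, ngpu=ngpu, accumulate_grad_batches=2)
print(f"Setting learning rate to {lr:.2e}")
```

Note that an integer `devices` value (for example `devices: 2`) has no `strip` method, so the string form is required for the count to work.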
875
+ # allow checkpointing via USR1
876
+ def melk(*args, **kwargs):
877
+ # run all checkpoint hooks
878
+ if trainer.global_rank == 0:
879
+ print("Summoning checkpoint.")
880
+ if melk_ckpt_name is None:
881
+ ckpt_path = os.path.join(ckptdir, "last.ckpt")
882
+ else:
883
+ ckpt_path = os.path.join(ckptdir, melk_ckpt_name)
884
+ trainer.save_checkpoint(ckpt_path)
885
+
886
+ def divein(*args, **kwargs):
887
+ if trainer.global_rank == 0:
888
+ import pudb
889
+
890
+ pudb.set_trace()
891
+
892
+ import signal
893
+
894
+ signal.signal(signal.SIGUSR1, melk)
895
+ signal.signal(signal.SIGUSR2, divein)
896
+
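The `SIGUSR1` hook above lets an operator request a checkpoint from outside the process. A minimal sketch of the pattern using only the standard library is shown below. It is POSIX only, since `SIGUSR1` does not exist on Windows; `make_checkpoint_handler` is a hypothetical helper, and `saved_paths.append` stands in for `trainer.save_checkpoint`.

```python
import signal


def make_checkpoint_handler(save_fn, ckpt_path):
    """Build a signal handler that calls save_fn(ckpt_path) when triggered."""
    def handler(signum, frame):
        print(f"Received signal {signum}, saving checkpoint to {ckpt_path}")
        save_fn(ckpt_path)
    return handler


saved_paths = []  # stand-in for trainer.save_checkpoint
signal.signal(
    signal.SIGUSR1,
    make_checkpoint_handler(saved_paths.append, "checkpoints/last.ckpt"),
)

# Simulate an external `kill -USR1 <pid>` by raising the signal in-process.
signal.raise_signal(signal.SIGUSR1)
```

From a shell, the same handler would be triggered with `kill -USR1 <pid>` against the training process.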
897
+ # run
898
+ if opt.train:
899
+ try:
900
+ trainer.fit(model, data, ckpt_path=ckpt_resume_path)
901
+ except Exception:
902
+ if not opt.debug:
903
+ melk()
904
+ raise
905
+ if not opt.no_test and not trainer.interrupted:
906
+ trainer.test(model, data)
907
+ except RuntimeError as err:
908
+ if MULTINODE_HACKS:
909
+ import datetime
910
+ import os
911
+ import socket
912
+
913
+ import requests
914
+
915
+ device = os.environ.get("CUDA_VISIBLE_DEVICES", "?")
916
+ hostname = socket.gethostname()
917
+ ts = datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")
918
+ resp = requests.get("http://169.254.169.254/latest/meta-data/instance-id")
919
+ print(
920
+ f"ERROR at {ts} on {hostname}/{resp.text} (CUDA_VISIBLE_DEVICES={device}): {type(err).__name__}: {err}",
921
+ flush=True,
922
+ )
923
+ raise err
924
+ except Exception:
925
+ if opt.debug and trainer.global_rank == 0:
926
+ try:
927
+ import pudb as debugger
928
+ except ImportError:
929
+ import pdb as debugger
930
+ debugger.post_mortem()
931
+ raise
932
+ finally:
933
+ # move newly created debug project to debug_runs
934
+ if opt.debug and not opt.resume and trainer.global_rank == 0:
935
+ dst, name = os.path.split(logdir)
936
+ dst = os.path.join(dst, "debug_runs", name)
937
+ os.makedirs(os.path.split(dst)[0], exist_ok=True)
938
+ os.rename(logdir, dst)
939
+
940
+ if opt.wandb:
941
+ wandb.finish()
942
+ # if trainer.global_rank == 0:
943
+ # print(trainer.profiler.summary())
pyproject.toml ADDED
@@ -0,0 +1,48 @@
+ [build-system]
+ requires = ["hatchling"]
+ build-backend = "hatchling.build"
+
+ [project]
+ name = "sgm"
+ dynamic = ["version"]
+ description = "Stability Generative Models"
+ readme = "README.md"
+ license-files = { paths = ["LICENSE-CODE"] }
+ requires-python = ">=3.8"
+
+ [project.urls]
+ Homepage = "https://github.com/Stability-AI/generative-models"
+
+ [tool.hatch.version]
+ path = "sgm/__init__.py"
+
+ [tool.hatch.build]
+ # This needs to be explicitly set so the configuration files
+ # grafted into the `sgm` directory get included in the wheel's
+ # RECORD file.
+ include = [
+     "sgm",
+ ]
+ # The force-include configurations below make Hatch copy
+ # the configs/ directory (containing the various YAML files required
+ # to generatively model) into the source distribution and the wheel.
+
+ [tool.hatch.build.targets.sdist.force-include]
+ "./configs" = "sgm/configs"
+
+ [tool.hatch.build.targets.wheel.force-include]
+ "./configs" = "sgm/configs"
+
+ [tool.hatch.envs.ci]
+ skip-install = false
+
+ dependencies = [
+     "pytest"
+ ]
+
+ [tool.hatch.envs.ci.scripts]
+ test-inference = [
+     "pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 --index-url https://download.pytorch.org/whl/cu118",
+     "pip install -r requirements/pt2.txt",
+     "pytest -v tests/inference/test_inference.py {args}",
+ ]
pytest.ini ADDED
@@ -0,0 +1,3 @@
+ [pytest]
+ markers =
+     inference: mark as inference test (deselect with '-m "not inference"')
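To check which markers this file registers, the ini content can be parsed with the standard library. This is only an illustrative sketch: pytest reads the file through its own configuration machinery, and the embedded text assumes the marker line is indented as a continuation of `markers =`.

```python
import configparser

PYTEST_INI = """\
[pytest]
markers =
    inference: mark as inference test (deselect with '-m "not inference"')
"""

parser = configparser.ConfigParser(interpolation=None)
parser.read_string(PYTEST_INI)

# Each non-empty line of the multi-line value has the form "name: description".
raw = parser["pytest"]["markers"]
markers = [line.split(":", 1)[0].strip() for line in raw.splitlines() if line.strip()]
print(markers)
```

Tests carrying this marker can then be excluded with `pytest -m "not inference"`, as the marker description itself notes.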