Fabrice-TIERCELIN committed • Commit 87c1b97 • 1 parent: f616433

Upload 2 files

Files changed:
- sgm/models/autoencoder.py +335 -0
- sgm/models/diffusion.py +320 -0
sgm/models/autoencoder.py
ADDED
@@ -0,0 +1,335 @@
import re
from abc import abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, Tuple, Union

import pytorch_lightning as pl
import torch
from omegaconf import ListConfig
from packaging import version
from safetensors.torch import load_file as load_safetensors

from ..modules.diffusionmodules.model import Decoder, Encoder
from ..modules.distributions.distributions import DiagonalGaussianDistribution
from ..modules.ema import LitEma
from ..util import default, get_obj_from_str, instantiate_from_config


class AbstractAutoencoder(pl.LightningModule):
    """
    This is the base class for all autoencoders, including image autoencoders,
    image autoencoders with discriminators, unCLIP models, etc. Hence, it is
    fairly general, and specific features (e.g. discriminator training,
    encoding, decoding) must be implemented in subclasses.
    """

    def __init__(
        self,
        ema_decay: Union[None, float] = None,
        monitor: Union[None, str] = None,
        input_key: str = "jpg",
        ckpt_path: Union[None, str] = None,
        ignore_keys: Union[Tuple, list, ListConfig] = (),
    ):
        super().__init__()
        self.input_key = input_key
        self.use_ema = ema_decay is not None
        if monitor is not None:
            self.monitor = monitor

        if self.use_ema:
            self.model_ema = LitEma(self, decay=ema_decay)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

        if version.parse(torch.__version__) >= version.parse("2.0.0"):
            self.automatic_optimization = False

    def init_from_ckpt(
        self, path: str, ignore_keys: Union[Tuple, list, ListConfig] = tuple()
    ) -> None:
        if path.endswith("ckpt"):
            sd = torch.load(path, map_location="cpu")["state_dict"]
        elif path.endswith("safetensors"):
            sd = load_safetensors(path)
        else:
            raise NotImplementedError

        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if re.match(ik, k):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(
            f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
        )
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    @abstractmethod
    def get_input(self, batch) -> Any:
        raise NotImplementedError()

    def on_train_batch_end(self, *args, **kwargs):
        # for EMA computation
        if self.use_ema:
            self.model_ema(self)

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    @abstractmethod
    def encode(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError("encode()-method of abstract base class called")

    @abstractmethod
    def decode(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError("decode()-method of abstract base class called")

    def instantiate_optimizer_from_config(self, params, lr, cfg):
        print(f"loading >>> {cfg['target']} <<< optimizer from config")
        return get_obj_from_str(cfg["target"])(
            params, lr=lr, **cfg.get("params", dict())
        )

    def configure_optimizers(self) -> Any:
        raise NotImplementedError()


class AutoencodingEngine(AbstractAutoencoder):
    """
    Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
    (we also restore them explicitly as special cases for legacy reasons).
    Regularizations such as KL or VQ are moved to the regularizer class.
    """

    def __init__(
        self,
        *args,
        encoder_config: Dict,
        decoder_config: Dict,
        loss_config: Dict,
        regularizer_config: Dict,
        optimizer_config: Union[Dict, None] = None,
        lr_g_factor: float = 1.0,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # todo: add options to freeze encoder/decoder
        self.encoder = instantiate_from_config(encoder_config)
        self.decoder = instantiate_from_config(decoder_config)
        self.loss = instantiate_from_config(loss_config)
        self.regularization = instantiate_from_config(regularizer_config)
        self.optimizer_config = default(
            optimizer_config, {"target": "torch.optim.Adam"}
        )
        self.lr_g_factor = lr_g_factor

    def get_input(self, batch: Dict) -> torch.Tensor:
        # assuming a unified data format: the dataloader returns a dict.
        # image tensors should be scaled to -1 ... 1 and in channels-first
        # format (e.g., bchw instead of bhwc)
        return batch[self.input_key]

    def get_autoencoder_params(self) -> list:
        params = (
            list(self.encoder.parameters())
            + list(self.decoder.parameters())
            + list(self.regularization.get_trainable_parameters())
            + list(self.loss.get_trainable_autoencoder_parameters())
        )
        return params

    def get_discriminator_params(self) -> list:
        params = list(self.loss.get_trainable_parameters())  # e.g., discriminator
        return params

    def get_last_layer(self):
        return self.decoder.get_last_layer()

    def encode(self, x: Any, return_reg_log: bool = False) -> Any:
        z = self.encoder(x)
        z, reg_log = self.regularization(z)
        if return_reg_log:
            return z, reg_log
        return z

    def decode(self, z: Any) -> torch.Tensor:
        x = self.decoder(z)
        return x

    def forward(self, x: Any) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        z, reg_log = self.encode(x, return_reg_log=True)
        dec = self.decode(z)
        return z, dec, reg_log

    def training_step(self, batch, batch_idx, optimizer_idx) -> Any:
        x = self.get_input(batch)
        z, xrec, regularization_log = self(x)

        if optimizer_idx == 0:
            # autoencoder loss
            aeloss, log_dict_ae = self.loss(
                regularization_log,
                x,
                xrec,
                optimizer_idx,
                self.global_step,
                last_layer=self.get_last_layer(),
                split="train",
            )

            self.log_dict(
                log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True
            )
            return aeloss

        if optimizer_idx == 1:
            # discriminator loss
            discloss, log_dict_disc = self.loss(
                regularization_log,
                x,
                xrec,
                optimizer_idx,
                self.global_step,
                last_layer=self.get_last_layer(),
                split="train",
            )
            self.log_dict(
                log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
            )
            return discloss

    def validation_step(self, batch, batch_idx) -> Dict:
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
            log_dict.update(log_dict_ema)
        return log_dict

    def _validation_step(self, batch, batch_idx, postfix="") -> Dict:
        x = self.get_input(batch)

        z, xrec, regularization_log = self(x)
        aeloss, log_dict_ae = self.loss(
            regularization_log,
            x,
            xrec,
            0,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="val" + postfix,
        )

        discloss, log_dict_disc = self.loss(
            regularization_log,
            x,
            xrec,
            1,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="val" + postfix,
        )
        self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
        log_dict_ae.update(log_dict_disc)
        self.log_dict(log_dict_ae)
        return log_dict_ae

    def configure_optimizers(self) -> Any:
        ae_params = self.get_autoencoder_params()
        disc_params = self.get_discriminator_params()

        opt_ae = self.instantiate_optimizer_from_config(
            ae_params,
            default(self.lr_g_factor, 1.0) * self.learning_rate,
            self.optimizer_config,
        )
        opt_disc = self.instantiate_optimizer_from_config(
            disc_params, self.learning_rate, self.optimizer_config
        )

        return [opt_ae, opt_disc], []

    @torch.no_grad()
    def log_images(self, batch: Dict, **kwargs) -> Dict:
        log = dict()
        x = self.get_input(batch)
        _, xrec, _ = self(x)
        log["inputs"] = x
        log["reconstructions"] = xrec
        with self.ema_scope():
            _, xrec_ema, _ = self(x)
            log["reconstructions_ema"] = xrec_ema
        return log


class AutoencoderKL(AutoencodingEngine):
    def __init__(self, embed_dim: int, **kwargs):
        ddconfig = kwargs.pop("ddconfig")
        ckpt_path = kwargs.pop("ckpt_path", None)
        ignore_keys = kwargs.pop("ignore_keys", ())
        super().__init__(
            encoder_config={"target": "torch.nn.Identity"},
            decoder_config={"target": "torch.nn.Identity"},
            regularizer_config={"target": "torch.nn.Identity"},
            loss_config=kwargs.pop("lossconfig"),
            **kwargs,
        )
        assert ddconfig["double_z"]
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def encode(self, x):
        assert (
            not self.training
        ), f"{self.__class__.__name__} only supports inference currently"
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z, **decoder_kwargs):
        z = self.post_quant_conv(z)
        dec = self.decoder(z, **decoder_kwargs)
        return dec


class AutoencoderKLInferenceWrapper(AutoencoderKL):
    def encode(self, x):
        return super().encode(x).sample()


class IdentityFirstStage(AbstractAutoencoder):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def get_input(self, x: Any) -> Any:
        return x

    def encode(self, x: Any, *args, **kwargs) -> Any:
        return x

    def decode(self, x: Any, *args, **kwargs) -> Any:
        return x
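For orientation, here is a minimal sketch of how the AutoencoderKL / AutoencoderKLInferenceWrapper classes above could be exercised at inference time. This is not part of the commit; the ddconfig values are the usual Stable-Diffusion-style settings, shown only to make the tensor shapes concrete, and the commented-out checkpoint path is hypothetical.

import torch
from sgm.models.autoencoder import AutoencoderKLInferenceWrapper

# Illustrative SD-style ddconfig (assumed values, not taken from this commit).
ddconfig = {
    "double_z": True,          # required: encoder outputs mean and logvar
    "z_channels": 4,
    "resolution": 256,
    "in_channels": 3,
    "out_ch": 3,
    "ch": 128,
    "ch_mult": [1, 2, 4, 4],   # three downsamplings -> spatial factor 8
    "num_res_blocks": 2,
    "attn_resolutions": [],
    "dropout": 0.0,
}

ae = AutoencoderKLInferenceWrapper(
    embed_dim=4,
    ddconfig=ddconfig,
    lossconfig={"target": "torch.nn.Identity"},  # no loss is needed at inference
    # ckpt_path="checkpoints/vae.safetensors",   # hypothetical: pass a .ckpt/.safetensors to load weights
).eval()  # encode() asserts that the module is not in training mode

with torch.no_grad():
    x = torch.randn(1, 3, 256, 256)  # stands in for an image scaled to [-1, 1], bchw
    z = ae.encode(x)                 # sampled latent, shape (1, 4, 32, 32)
    x_rec = ae.decode(z)             # reconstruction, shape (1, 3, 256, 256)

Note that AutoencoderKLInferenceWrapper.encode() returns posterior.sample() directly, so the caller gets a latent tensor rather than a DiagonalGaussianDistribution.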
sgm/models/diffusion.py
ADDED
@@ -0,0 +1,320 @@
from contextlib import contextmanager
from typing import Any, Dict, List, Tuple, Union

import pytorch_lightning as pl
import torch
from omegaconf import ListConfig, OmegaConf
from safetensors.torch import load_file as load_safetensors
from torch.optim.lr_scheduler import LambdaLR

from ..modules import UNCONDITIONAL_CONFIG
from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
from ..modules.ema import LitEma
from ..util import (
    default,
    disabled_train,
    get_obj_from_str,
    instantiate_from_config,
    log_txt_as_img,
)


class DiffusionEngine(pl.LightningModule):
    def __init__(
        self,
        network_config,
        denoiser_config,
        first_stage_config,
        conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        network_wrapper: Union[None, str] = None,
        ckpt_path: Union[None, str] = None,
        use_ema: bool = False,
        ema_decay_rate: float = 0.9999,
        scale_factor: float = 1.0,
        disable_first_stage_autocast=False,
        input_key: str = "jpg",
        log_keys: Union[List, None] = None,
        no_cond_log: bool = False,
        compile_model: bool = False,
    ):
        super().__init__()
        self.log_keys = log_keys
        self.input_key = input_key
        self.optimizer_config = default(
            optimizer_config, {"target": "torch.optim.AdamW"}
        )
        model = instantiate_from_config(network_config)
        self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
            model, compile_model=compile_model
        )

        self.denoiser = instantiate_from_config(denoiser_config)
        self.sampler = (
            instantiate_from_config(sampler_config)
            if sampler_config is not None
            else None
        )
        self.conditioner = instantiate_from_config(
            default(conditioner_config, UNCONDITIONAL_CONFIG)
        )
        self.scheduler_config = scheduler_config
        self._init_first_stage(first_stage_config)

        self.loss_fn = (
            instantiate_from_config(loss_fn_config)
            if loss_fn_config is not None
            else None
        )

        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self.model, decay=ema_decay_rate)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        self.scale_factor = scale_factor
        self.disable_first_stage_autocast = disable_first_stage_autocast
        self.no_cond_log = no_cond_log

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path)

    def init_from_ckpt(
        self,
        path: str,
    ) -> None:
        if path.endswith("ckpt"):
            sd = torch.load(path, map_location="cpu")["state_dict"]
        elif path.endswith("safetensors"):
            sd = load_safetensors(path)
        else:
            raise NotImplementedError

        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(
            f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
        )
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    def _init_first_stage(self, config):
        model = instantiate_from_config(config).eval()
        model.train = disabled_train
        for param in model.parameters():
            param.requires_grad = False
        self.first_stage_model = model

    def get_input(self, batch):
        # assuming a unified data format: the dataloader returns a dict.
        # image tensors should be scaled to -1 ... 1 and in bchw format
        return batch[self.input_key]

    @torch.no_grad()
    def decode_first_stage(self, z):
        z = 1.0 / self.scale_factor * z
        with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
            out = self.first_stage_model.decode(z)
        return out

    @torch.no_grad()
    def encode_first_stage(self, x):
        with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
            z = self.first_stage_model.encode(x)
        z = self.scale_factor * z
        return z

    def forward(self, x, batch):
        loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch)
        loss_mean = loss.mean()
        loss_dict = {"loss": loss_mean}
        return loss_mean, loss_dict

    def shared_step(self, batch: Dict) -> Any:
        x = self.get_input(batch)
        x = self.encode_first_stage(x)
        batch["global_step"] = self.global_step
        loss, loss_dict = self(x, batch)
        return loss, loss_dict

    def training_step(self, batch, batch_idx):
        loss, loss_dict = self.shared_step(batch)

        self.log_dict(
            loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
        )

        self.log(
            "global_step",
            self.global_step,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=False,
        )

        # if self.scheduler_config is not None:
        lr = self.optimizers().param_groups[0]["lr"]
        self.log(
            "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
        )

        return loss

    def on_train_start(self, *args, **kwargs):
        if self.sampler is None or self.loss_fn is None:
            raise ValueError("Sampler and loss function need to be set for training.")

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self.model)

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.model.parameters())
            self.model_ema.copy_to(self.model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.model.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def instantiate_optimizer_from_config(self, params, lr, cfg):
        return get_obj_from_str(cfg["target"])(
            params, lr=lr, **cfg.get("params", dict())
        )

    def configure_optimizers(self):
        lr = self.learning_rate
        params = list(self.model.parameters())
        for embedder in self.conditioner.embedders:
            if embedder.is_trainable:
                params = params + list(embedder.parameters())
        opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
        if self.scheduler_config is not None:
            scheduler = instantiate_from_config(self.scheduler_config)
            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
                    "interval": "step",
                    "frequency": 1,
                }
            ]
            return [opt], scheduler
        return opt

    @torch.no_grad()
    def sample(
        self,
        cond: Dict,
        uc: Union[Dict, None] = None,
        batch_size: int = 16,
        shape: Union[None, Tuple, List] = None,
        **kwargs,
    ):
        randn = torch.randn(batch_size, *shape).to(self.device)

        denoiser = lambda input, sigma, c: self.denoiser(
            self.model, input, sigma, c, **kwargs
        )
        samples = self.sampler(denoiser, randn, cond, uc=uc)
        return samples

    @torch.no_grad()
    def log_conditionings(self, batch: Dict, n: int) -> Dict:
        """
        Defines heuristics to log different conditionings.
        These can be lists of strings (text-to-image), tensors, ints, ...
        """
        image_h, image_w = batch[self.input_key].shape[2:]
        log = dict()

        for embedder in self.conditioner.embedders:
            if (
                (self.log_keys is None) or (embedder.input_key in self.log_keys)
            ) and not self.no_cond_log:
                x = batch[embedder.input_key][:n]
                if isinstance(x, torch.Tensor):
                    if x.dim() == 1:
                        # class-conditional, convert integer to string
                        x = [str(x[i].item()) for i in range(x.shape[0])]
                        xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
                    elif x.dim() == 2:
                        # size and crop cond and the like
                        x = [
                            "x".join([str(xx) for xx in x[i].tolist()])
                            for i in range(x.shape[0])
                        ]
                        xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
                    else:
                        raise NotImplementedError()
                elif isinstance(x, (List, ListConfig)):
                    if isinstance(x[0], str):
                        # strings
                        xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
                    else:
                        raise NotImplementedError()
                else:
                    raise NotImplementedError()
                log[embedder.input_key] = xc
        return log

    @torch.no_grad()
    def log_images(
        self,
        batch: Dict,
        N: int = 8,
        sample: bool = True,
        ucg_keys: List[str] = None,
        **kwargs,
    ) -> Dict:
        conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
        if ucg_keys:
            assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
                "Each defined ucg key for sampling must be in the provided conditioner input keys, "
                f"but we have {ucg_keys} vs. {conditioner_input_keys}"
            )
        else:
            ucg_keys = conditioner_input_keys
        log = dict()

        x = self.get_input(batch)

        c, uc = self.conditioner.get_unconditional_conditioning(
            batch,
            force_uc_zero_embeddings=ucg_keys
            if len(self.conditioner.embedders) > 0
            else [],
        )

        sampling_kwargs = {}

        N = min(x.shape[0], N)
        x = x.to(self.device)[:N]
        log["inputs"] = x
        z = self.encode_first_stage(x)
        log["reconstructions"] = self.decode_first_stage(z)
        log.update(self.log_conditionings(batch, N))

        for k in c:
            if isinstance(c[k], torch.Tensor):
                c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))

        if sample:
            with self.ema_scope("Plotting"):
                samples = self.sample(
                    c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
                )
            samples = self.decode_first_stage(samples)
            log["samples"] = samples
        return log
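Likewise, a hedged sketch of driving DiffusionEngine for sampling, assuming a config in the usual sgm yaml layout. This is not part of the commit: the config path, the "cls" batch key, and the latent shape below are hypothetical and depend entirely on the chosen config.

import torch
from omegaconf import OmegaConf
from sgm.util import instantiate_from_config

# Hypothetical config path; any yaml whose model.target is
# sgm.models.diffusion.DiffusionEngine (with sampler_config set) should work.
config = OmegaConf.load("configs/my_diffusion_config.yaml")
engine = instantiate_from_config(config.model).eval()

# The batch must contain whatever input keys the conditioner's embedders
# expect; "cls" here is a hypothetical class-conditioning key.
batch = {"cls": torch.randint(0, 10, (4,))}
c, uc = engine.conditioner.get_unconditional_conditioning(batch)

with torch.no_grad(), engine.ema_scope("Sampling"):
    # shape is the per-sample latent shape (C, H, W); sample() prepends batch_size.
    z = engine.sample(c, uc=uc, batch_size=4, shape=(4, 32, 32))
    images = engine.decode_first_stage(z)  # back to pixel space in [-1, 1]

Note that sample() draws its initial noise as torch.randn(batch_size, *shape), so shape must match the latent geometry produced by the first stage (compare log_images, which passes shape=z.shape[1:]).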