yyk19 commited on
Commit
4da8d8c
·
1 Parent(s): 1f7ae51

archive the files.

Browse files
app.py CHANGED
@@ -8,7 +8,7 @@ import torch
8
  import time
9
  from PIL import Image
10
  from cldm.hack import disable_verbosity, enable_sliced_attention
11
- from pytorch_lightning import seed_everything
12
 
13
  def process_multi_wrapper(rendered_txt_0, rendered_txt_1, rendered_txt_2, rendered_txt_3,
14
  shared_prompt,
@@ -87,13 +87,13 @@ def load_ckpt(model_ckpt = "LAION-Glyph-10M-Epoch-5"):
87
  # if model_ckpt == "LAION-Glyph-10M-Epoch-5":
88
  # model = load_model_ckpt(model, "laion10M_epoch_5_model_wo_ema.ckpt")
89
  if model_ckpt == "LAION-Glyph-10M-Epoch-6":
90
- model = load_model_ckpt(model, "laion10M_epoch_6_model_wo_ema.ckpt")
91
  elif model_ckpt == "TextCaps-5K-Epoch-10":
92
- model = load_model_ckpt(model, "textcaps5K_epoch_10_model_wo_ema.ckpt")
93
  elif model_ckpt == "TextCaps-5K-Epoch-20":
94
- model = load_model_ckpt(model, "textcaps5K_epoch_20_model_wo_ema.ckpt")
95
  elif model_ckpt == "TextCaps-5K-Epoch-40":
96
- model = load_model_ckpt(model, "textcaps5K_epoch_40_model_wo_ema.ckpt")
97
 
98
  render_tool = Render_Text(model, save_memory = SAVE_MEMORY)
99
  output_str = f"already change the model checkpoint to {model_ckpt}"
@@ -107,20 +107,11 @@ def load_ckpt(model_ckpt = "LAION-Glyph-10M-Epoch-5"):
107
  return output_str, None, allow_run_generation
108
 
109
  SAVE_MEMORY = False
110
- shared_seed = 0
111
- if shared_seed == -1:
112
- shared_seed = random.randint(0, 65535)
113
- seed_everything(shared_seed)
114
-
115
  disable_verbosity()
116
  if SAVE_MEMORY:
117
  enable_sliced_attention()
118
  cfg = OmegaConf.load("config.yaml")
119
- model = load_model_from_config(cfg, "laion10M_epoch_6_model_wo_ema.ckpt", verbose=True)
120
- # model = load_model_from_config(cfg, "model_wo_ema.ckpt", verbose=True)
121
- # model = load_model_from_config(cfg, "model_states.pt", verbose=True)
122
- # model = load_model_from_config(cfg, "model.ckpt", verbose=True)
123
- # ddim_sampler = DDIMSampler(model)
124
  render_tool = Render_Text(model, save_memory = SAVE_MEMORY)
125
 
126
 
 
8
  import time
9
  from PIL import Image
10
  from cldm.hack import disable_verbosity, enable_sliced_attention
11
+ # from pytorch_lightning import seed_everything
12
 
13
  def process_multi_wrapper(rendered_txt_0, rendered_txt_1, rendered_txt_2, rendered_txt_3,
14
  shared_prompt,
 
87
  # if model_ckpt == "LAION-Glyph-10M-Epoch-5":
88
  # model = load_model_ckpt(model, "laion10M_epoch_5_model_wo_ema.ckpt")
89
  if model_ckpt == "LAION-Glyph-10M-Epoch-6":
90
+ model = load_model_ckpt(model, "checkpoints/laion10M_epoch_6_model_wo_ema.ckpt")
91
  elif model_ckpt == "TextCaps-5K-Epoch-10":
92
+ model = load_model_ckpt(model, "checkpoints/textcaps5K_epoch_10_model_wo_ema.ckpt")
93
  elif model_ckpt == "TextCaps-5K-Epoch-20":
94
+ model = load_model_ckpt(model, "checkpoints/textcaps5K_epoch_20_model_wo_ema.ckpt")
95
  elif model_ckpt == "TextCaps-5K-Epoch-40":
96
+ model = load_model_ckpt(model, "checkpoints/textcaps5K_epoch_40_model_wo_ema.ckpt")
97
 
98
  render_tool = Render_Text(model, save_memory = SAVE_MEMORY)
99
  output_str = f"already change the model checkpoint to {model_ckpt}"
 
107
  return output_str, None, allow_run_generation
108
 
109
  SAVE_MEMORY = False
 
 
 
 
 
110
  disable_verbosity()
111
  if SAVE_MEMORY:
112
  enable_sliced_attention()
113
  cfg = OmegaConf.load("config.yaml")
114
+ model = load_model_from_config(cfg, "checkpoints/laion10M_epoch_6_model_wo_ema.ckpt", verbose=True)
 
 
 
 
115
  render_tool = Render_Text(model, save_memory = SAVE_MEMORY)
116
 
117
 
laion10M_epoch_6_model_wo_ema.ckpt → checkpoints/laion10M_epoch_6_model_wo_ema.ckpt RENAMED
File without changes
textcaps5K_epoch_10_model_wo_ema.ckpt → checkpoints/textcaps5K_epoch_10_model_wo_ema.ckpt RENAMED
File without changes
textcaps5K_epoch_20_model_wo_ema.ckpt → checkpoints/textcaps5K_epoch_20_model_wo_ema.ckpt RENAMED
File without changes
textcaps5K_epoch_40_model_wo_ema.ckpt → checkpoints/textcaps5K_epoch_40_model_wo_ema.ckpt RENAMED
File without changes
cldm/ddim_hacked.py CHANGED
@@ -79,15 +79,7 @@ class DDIMSampler(object):
79
  ):
80
  if conditioning is not None:
81
  if isinstance(conditioning, dict):
82
- # ctmp = conditioning[list(conditioning.keys())[0]]
83
- # while isinstance(ctmp, list): ctmp = ctmp[0]
84
- # cbs = ctmp.shape[0]
85
- # if cbs != batch_size:
86
- # print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
87
- # for ctmp in conditioning.values():
88
  for key, ctmp in conditioning.items():
89
- if key == "c_glyph":
90
- continue
91
  if ctmp is None:
92
  continue
93
  else:
 
79
  ):
80
  if conditioning is not None:
81
  if isinstance(conditioning, dict):
 
 
 
 
 
 
82
  for key, ctmp in conditioning.items():
 
 
83
  if ctmp is None:
84
  continue
85
  else:
config_ema.yaml DELETED
@@ -1,88 +0,0 @@
1
- model:
2
- base_learning_rate: 1.0e-6 #1.0e-5 #1.0e-4
3
- target: cldm.cldm.ControlLDM
4
- params:
5
- linear_start: 0.00085
6
- linear_end: 0.0120
7
- num_timesteps_cond: 1
8
- log_every_t: 200
9
- timesteps: 1000
10
- first_stage_key: "jpg"
11
- cond_stage_key: "txt"
12
- control_key: "hint"
13
- image_size: 64
14
- channels: 4
15
- cond_stage_trainable: false
16
- conditioning_key: crossattn
17
- monitor: #val/loss_simple_ema
18
- scale_factor: 0.18215
19
- only_mid_control: False
20
- sd_locked: True
21
- use_ema: True #TODO: specify
22
-
23
- control_stage_config:
24
- target: cldm.cldm.ControlNet
25
- params:
26
- use_checkpoint: True
27
- image_size: 32 # unused
28
- in_channels: 4
29
- hint_channels: 3
30
- model_channels: 320
31
- attention_resolutions: [ 4, 2, 1 ]
32
- num_res_blocks: 2
33
- channel_mult: [ 1, 2, 4, 4 ]
34
- num_head_channels: 64 # need to fix for flash-attn
35
- use_spatial_transformer: True
36
- use_linear_in_transformer: True
37
- transformer_depth: 1
38
- context_dim: 1024
39
- legacy: False
40
-
41
- unet_config:
42
- target: cldm.cldm.ControlledUnetModel
43
- params:
44
- use_checkpoint: True
45
- image_size: 32 # unused
46
- in_channels: 4
47
- out_channels: 4
48
- model_channels: 320
49
- attention_resolutions: [ 4, 2, 1 ]
50
- num_res_blocks: 2
51
- channel_mult: [ 1, 2, 4, 4 ]
52
- num_head_channels: 64 # need to fix for flash-attn
53
- use_spatial_transformer: True
54
- use_linear_in_transformer: True
55
- transformer_depth: 1
56
- context_dim: 1024
57
- legacy: False
58
-
59
- first_stage_config:
60
- target: ldm.models.autoencoder.AutoencoderKL
61
- params:
62
- embed_dim: 4
63
- monitor: val/rec_loss
64
- ddconfig:
65
- #attn_type: "vanilla-xformers"
66
- double_z: true
67
- z_channels: 4
68
- resolution: 256
69
- in_channels: 3
70
- out_ch: 3
71
- ch: 128
72
- ch_mult:
73
- - 1
74
- - 2
75
- - 4
76
- - 4
77
- num_res_blocks: 2
78
- attn_resolutions: []
79
- dropout: 0.0
80
- lossconfig:
81
- target: torch.nn.Identity
82
-
83
- cond_stage_config:
84
- target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
85
- params:
86
- freeze: True
87
- layer: "penultimate"
88
- # device: "cpu" #TODO: specify
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
config_ema_unlock.yaml DELETED
@@ -1,88 +0,0 @@
1
- model:
2
- base_learning_rate: 1.0e-6 #1.0e-5 #1.0e-4
3
- target: cldm.cldm.ControlLDM
4
- params:
5
- linear_start: 0.00085
6
- linear_end: 0.0120
7
- num_timesteps_cond: 1
8
- log_every_t: 200
9
- timesteps: 1000
10
- first_stage_key: "jpg"
11
- cond_stage_key: "txt"
12
- control_key: "hint"
13
- image_size: 64
14
- channels: 4
15
- cond_stage_trainable: false
16
- conditioning_key: crossattn
17
- monitor: #val/loss_simple_ema
18
- scale_factor: 0.18215
19
- only_mid_control: False
20
- sd_locked: False #True
21
- use_ema: True #TODO: specify
22
-
23
- control_stage_config:
24
- target: cldm.cldm.ControlNet
25
- params:
26
- use_checkpoint: True
27
- image_size: 32 # unused
28
- in_channels: 4
29
- hint_channels: 3
30
- model_channels: 320
31
- attention_resolutions: [ 4, 2, 1 ]
32
- num_res_blocks: 2
33
- channel_mult: [ 1, 2, 4, 4 ]
34
- num_head_channels: 64 # need to fix for flash-attn
35
- use_spatial_transformer: True
36
- use_linear_in_transformer: True
37
- transformer_depth: 1
38
- context_dim: 1024
39
- legacy: False
40
-
41
- unet_config:
42
- target: cldm.cldm.ControlledUnetModel
43
- params:
44
- use_checkpoint: True
45
- image_size: 32 # unused
46
- in_channels: 4
47
- out_channels: 4
48
- model_channels: 320
49
- attention_resolutions: [ 4, 2, 1 ]
50
- num_res_blocks: 2
51
- channel_mult: [ 1, 2, 4, 4 ]
52
- num_head_channels: 64 # need to fix for flash-attn
53
- use_spatial_transformer: True
54
- use_linear_in_transformer: True
55
- transformer_depth: 1
56
- context_dim: 1024
57
- legacy: False
58
-
59
- first_stage_config:
60
- target: ldm.models.autoencoder.AutoencoderKL
61
- params:
62
- embed_dim: 4
63
- monitor: val/rec_loss
64
- ddconfig:
65
- #attn_type: "vanilla-xformers"
66
- double_z: true
67
- z_channels: 4
68
- resolution: 256
69
- in_channels: 3
70
- out_ch: 3
71
- ch: 128
72
- ch_mult:
73
- - 1
74
- - 2
75
- - 4
76
- - 4
77
- num_res_blocks: 2
78
- attn_resolutions: []
79
- dropout: 0.0
80
- lossconfig:
81
- target: torch.nn.Identity
82
-
83
- cond_stage_config:
84
- target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
85
- params:
86
- freeze: True
87
- layer: "penultimate"
88
- # device: "cpu" #TODO: specify
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ldm/models/ldm_autoencoder.py DELETED
@@ -1,443 +0,0 @@
1
- import torch
2
- import pytorch_lightning as pl
3
- import torch.nn.functional as F
4
- from contextlib import contextmanager
5
-
6
- from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
7
-
8
- from ldm.modules.diffusionmodules.model import Encoder, Decoder
9
- from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
10
-
11
- from ldm.util import instantiate_from_config
12
-
13
-
14
- class VQModel(pl.LightningModule):
15
- def __init__(self,
16
- ddconfig,
17
- lossconfig,
18
- n_embed,
19
- embed_dim,
20
- ckpt_path=None,
21
- ignore_keys=[],
22
- image_key="image",
23
- colorize_nlabels=None,
24
- monitor=None,
25
- batch_resize_range=None,
26
- scheduler_config=None,
27
- lr_g_factor=1.0,
28
- remap=None,
29
- sane_index_shape=False, # tell vector quantizer to return indices as bhw
30
- use_ema=False
31
- ):
32
- super().__init__()
33
- self.embed_dim = embed_dim
34
- self.n_embed = n_embed
35
- self.image_key = image_key
36
- self.encoder = Encoder(**ddconfig)
37
- self.decoder = Decoder(**ddconfig)
38
- self.loss = instantiate_from_config(lossconfig)
39
- self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
40
- remap=remap,
41
- sane_index_shape=sane_index_shape)
42
- self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
43
- self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
44
- if colorize_nlabels is not None:
45
- assert type(colorize_nlabels)==int
46
- self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
47
- if monitor is not None:
48
- self.monitor = monitor
49
- self.batch_resize_range = batch_resize_range
50
- if self.batch_resize_range is not None:
51
- print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
52
-
53
- self.use_ema = use_ema
54
- if self.use_ema:
55
- self.model_ema = LitEma(self)
56
- print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
57
-
58
- if ckpt_path is not None:
59
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
60
- self.scheduler_config = scheduler_config
61
- self.lr_g_factor = lr_g_factor
62
-
63
- @contextmanager
64
- def ema_scope(self, context=None):
65
- if self.use_ema:
66
- self.model_ema.store(self.parameters())
67
- self.model_ema.copy_to(self)
68
- if context is not None:
69
- print(f"{context}: Switched to EMA weights")
70
- try:
71
- yield None
72
- finally:
73
- if self.use_ema:
74
- self.model_ema.restore(self.parameters())
75
- if context is not None:
76
- print(f"{context}: Restored training weights")
77
-
78
- def init_from_ckpt(self, path, ignore_keys=list()):
79
- sd = torch.load(path, map_location="cpu")["state_dict"]
80
- keys = list(sd.keys())
81
- for k in keys:
82
- for ik in ignore_keys:
83
- if k.startswith(ik):
84
- print("Deleting key {} from state_dict.".format(k))
85
- del sd[k]
86
- missing, unexpected = self.load_state_dict(sd, strict=False)
87
- print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
88
- if len(missing) > 0:
89
- print(f"Missing Keys: {missing}")
90
- print(f"Unexpected Keys: {unexpected}")
91
-
92
- def on_train_batch_end(self, *args, **kwargs):
93
- if self.use_ema:
94
- self.model_ema(self)
95
-
96
- def encode(self, x):
97
- h = self.encoder(x)
98
- h = self.quant_conv(h)
99
- quant, emb_loss, info = self.quantize(h)
100
- return quant, emb_loss, info
101
-
102
- def encode_to_prequant(self, x):
103
- h = self.encoder(x)
104
- h = self.quant_conv(h)
105
- return h
106
-
107
- def decode(self, quant):
108
- quant = self.post_quant_conv(quant)
109
- dec = self.decoder(quant)
110
- return dec
111
-
112
- def decode_code(self, code_b):
113
- quant_b = self.quantize.embed_code(code_b)
114
- dec = self.decode(quant_b)
115
- return dec
116
-
117
- def forward(self, input, return_pred_indices=False):
118
- quant, diff, (_,_,ind) = self.encode(input)
119
- dec = self.decode(quant)
120
- if return_pred_indices:
121
- return dec, diff, ind
122
- return dec, diff
123
-
124
- def get_input(self, batch, k):
125
- x = batch[k]
126
- if len(x.shape) == 3:
127
- x = x[..., None]
128
- x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
129
- if self.batch_resize_range is not None:
130
- lower_size = self.batch_resize_range[0]
131
- upper_size = self.batch_resize_range[1]
132
- if self.global_step <= 4:
133
- # do the first few batches with max size to avoid later oom
134
- new_resize = upper_size
135
- else:
136
- new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
137
- if new_resize != x.shape[2]:
138
- x = F.interpolate(x, size=new_resize, mode="bicubic")
139
- x = x.detach()
140
- return x
141
-
142
- def training_step(self, batch, batch_idx, optimizer_idx):
143
- # https://github.com/pytorch/pytorch/issues/37142
144
- # try not to fool the heuristics
145
- x = self.get_input(batch, self.image_key)
146
- xrec, qloss, ind = self(x, return_pred_indices=True)
147
-
148
- if optimizer_idx == 0:
149
- # autoencode
150
- aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
151
- last_layer=self.get_last_layer(), split="train",
152
- predicted_indices=ind)
153
-
154
- self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
155
- return aeloss
156
-
157
- if optimizer_idx == 1:
158
- # discriminator
159
- discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
160
- last_layer=self.get_last_layer(), split="train")
161
- self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
162
- return discloss
163
-
164
- def validation_step(self, batch, batch_idx):
165
- log_dict = self._validation_step(batch, batch_idx)
166
- with self.ema_scope():
167
- log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
168
- return log_dict
169
-
170
- def _validation_step(self, batch, batch_idx, suffix=""):
171
- x = self.get_input(batch, self.image_key)
172
- xrec, qloss, ind = self(x, return_pred_indices=True)
173
- aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
174
- self.global_step,
175
- last_layer=self.get_last_layer(),
176
- split="val"+suffix,
177
- predicted_indices=ind
178
- )
179
-
180
- discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
181
- self.global_step,
182
- last_layer=self.get_last_layer(),
183
- split="val"+suffix,
184
- predicted_indices=ind
185
- )
186
- rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
187
- self.log(f"val{suffix}/rec_loss", rec_loss,
188
- prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
189
- self.log(f"val{suffix}/aeloss", aeloss,
190
- prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
191
- if version.parse(pl.__version__) >= version.parse('1.4.0'):
192
- del log_dict_ae[f"val{suffix}/rec_loss"]
193
- self.log_dict(log_dict_ae)
194
- self.log_dict(log_dict_disc)
195
- return self.log_dict
196
-
197
- def configure_optimizers(self):
198
- lr_d = self.learning_rate
199
- lr_g = self.lr_g_factor*self.learning_rate
200
- print("lr_d", lr_d)
201
- print("lr_g", lr_g)
202
- opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
203
- list(self.decoder.parameters())+
204
- list(self.quantize.parameters())+
205
- list(self.quant_conv.parameters())+
206
- list(self.post_quant_conv.parameters()),
207
- lr=lr_g, betas=(0.5, 0.9))
208
- opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
209
- lr=lr_d, betas=(0.5, 0.9))
210
-
211
- if self.scheduler_config is not None:
212
- scheduler = instantiate_from_config(self.scheduler_config)
213
-
214
- print("Setting up LambdaLR scheduler...")
215
- scheduler = [
216
- {
217
- 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
218
- 'interval': 'step',
219
- 'frequency': 1
220
- },
221
- {
222
- 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
223
- 'interval': 'step',
224
- 'frequency': 1
225
- },
226
- ]
227
- return [opt_ae, opt_disc], scheduler
228
- return [opt_ae, opt_disc], []
229
-
230
- def get_last_layer(self):
231
- return self.decoder.conv_out.weight
232
-
233
- def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
234
- log = dict()
235
- x = self.get_input(batch, self.image_key)
236
- x = x.to(self.device)
237
- if only_inputs:
238
- log["inputs"] = x
239
- return log
240
- xrec, _ = self(x)
241
- if x.shape[1] > 3:
242
- # colorize with random projection
243
- assert xrec.shape[1] > 3
244
- x = self.to_rgb(x)
245
- xrec = self.to_rgb(xrec)
246
- log["inputs"] = x
247
- log["reconstructions"] = xrec
248
- if plot_ema:
249
- with self.ema_scope():
250
- xrec_ema, _ = self(x)
251
- if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
252
- log["reconstructions_ema"] = xrec_ema
253
- return log
254
-
255
- def to_rgb(self, x):
256
- assert self.image_key == "segmentation"
257
- if not hasattr(self, "colorize"):
258
- self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
259
- x = F.conv2d(x, weight=self.colorize)
260
- x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
261
- return x
262
-
263
-
264
- class VQModelInterface(VQModel):
265
- def __init__(self, embed_dim, *args, **kwargs):
266
- super().__init__(embed_dim=embed_dim, *args, **kwargs)
267
- self.embed_dim = embed_dim
268
-
269
- def encode(self, x):
270
- h = self.encoder(x)
271
- h = self.quant_conv(h)
272
- return h
273
-
274
- def decode(self, h, force_not_quantize=False):
275
- # also go through quantization layer
276
- if not force_not_quantize:
277
- quant, emb_loss, info = self.quantize(h)
278
- else:
279
- quant = h
280
- quant = self.post_quant_conv(quant)
281
- dec = self.decoder(quant)
282
- return dec
283
-
284
-
285
- class AutoencoderKL(pl.LightningModule):
286
- def __init__(self,
287
- ddconfig,
288
- lossconfig,
289
- embed_dim,
290
- ckpt_path=None,
291
- ignore_keys=[],
292
- image_key="image",
293
- colorize_nlabels=None,
294
- monitor=None,
295
- ):
296
- super().__init__()
297
- self.image_key = image_key
298
- self.encoder = Encoder(**ddconfig)
299
- self.decoder = Decoder(**ddconfig)
300
- self.loss = instantiate_from_config(lossconfig)
301
- assert ddconfig["double_z"]
302
- self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
303
- self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
304
- self.embed_dim = embed_dim
305
- if colorize_nlabels is not None:
306
- assert type(colorize_nlabels)==int
307
- self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
308
- if monitor is not None:
309
- self.monitor = monitor
310
- if ckpt_path is not None:
311
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
312
-
313
- def init_from_ckpt(self, path, ignore_keys=list()):
314
- sd = torch.load(path, map_location="cpu")["state_dict"]
315
- keys = list(sd.keys())
316
- for k in keys:
317
- for ik in ignore_keys:
318
- if k.startswith(ik):
319
- print("Deleting key {} from state_dict.".format(k))
320
- del sd[k]
321
- self.load_state_dict(sd, strict=False)
322
- print(f"Restored from {path}")
323
-
324
- def encode(self, x):
325
- h = self.encoder(x)
326
- moments = self.quant_conv(h)
327
- posterior = DiagonalGaussianDistribution(moments)
328
- return posterior
329
-
330
- def decode(self, z):
331
- z = self.post_quant_conv(z)
332
- dec = self.decoder(z)
333
- return dec
334
-
335
- def forward(self, input, sample_posterior=True):
336
- posterior = self.encode(input)
337
- if sample_posterior:
338
- z = posterior.sample()
339
- else:
340
- z = posterior.mode()
341
- dec = self.decode(z)
342
- return dec, posterior
343
-
344
- def get_input(self, batch, k):
345
- x = batch[k]
346
- if len(x.shape) == 3:
347
- x = x[..., None]
348
- x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
349
- return x
350
-
351
- def training_step(self, batch, batch_idx, optimizer_idx):
352
- inputs = self.get_input(batch, self.image_key)
353
- reconstructions, posterior = self(inputs)
354
-
355
- if optimizer_idx == 0:
356
- # train encoder+decoder+logvar
357
- aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
358
- last_layer=self.get_last_layer(), split="train")
359
- self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
360
- self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
361
- return aeloss
362
-
363
- if optimizer_idx == 1:
364
- # train the discriminator
365
- discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
366
- last_layer=self.get_last_layer(), split="train")
367
-
368
- self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
369
- self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
370
- return discloss
371
-
372
- def validation_step(self, batch, batch_idx):
373
- inputs = self.get_input(batch, self.image_key)
374
- reconstructions, posterior = self(inputs)
375
- aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
376
- last_layer=self.get_last_layer(), split="val")
377
-
378
- discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
379
- last_layer=self.get_last_layer(), split="val")
380
-
381
- self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
382
- self.log_dict(log_dict_ae)
383
- self.log_dict(log_dict_disc)
384
- return self.log_dict
385
-
386
- def configure_optimizers(self):
387
- lr = self.learning_rate
388
- opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
389
- list(self.decoder.parameters())+
390
- list(self.quant_conv.parameters())+
391
- list(self.post_quant_conv.parameters()),
392
- lr=lr, betas=(0.5, 0.9))
393
- opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
394
- lr=lr, betas=(0.5, 0.9))
395
- return [opt_ae, opt_disc], []
396
-
397
- def get_last_layer(self):
398
- return self.decoder.conv_out.weight
399
-
400
- @torch.no_grad()
401
- def log_images(self, batch, only_inputs=False, **kwargs):
402
- log = dict()
403
- x = self.get_input(batch, self.image_key)
404
- x = x.to(self.device)
405
- if not only_inputs:
406
- xrec, posterior = self(x)
407
- if x.shape[1] > 3:
408
- # colorize with random projection
409
- assert xrec.shape[1] > 3
410
- x = self.to_rgb(x)
411
- xrec = self.to_rgb(xrec)
412
- log["samples"] = self.decode(torch.randn_like(posterior.sample()))
413
- log["reconstructions"] = xrec
414
- log["inputs"] = x
415
- return log
416
-
417
- def to_rgb(self, x):
418
- assert self.image_key == "segmentation"
419
- if not hasattr(self, "colorize"):
420
- self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
421
- x = F.conv2d(x, weight=self.colorize)
422
- x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
423
- return x
424
-
425
-
426
- class IdentityFirstStage(torch.nn.Module):
427
- def __init__(self, *args, vq_interface=False, **kwargs):
428
- self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
429
- super().__init__()
430
-
431
- def encode(self, x, *args, **kwargs):
432
- return x
433
-
434
- def decode(self, x, *args, **kwargs):
435
- return x
436
-
437
- def quantize(self, x, *args, **kwargs):
438
- if self.vq_interface:
439
- return x, None, [None, None, None]
440
- return x
441
-
442
- def forward(self, x, *args, **kwargs):
443
- return x