Alpha-Romeo committed on
Commit 7d28380
1 Parent(s): b0afe49

add cond stage to trainable parameters

ControlNet/ControlNet.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
ControlNet/cldm/cldm.py CHANGED
@@ -2,6 +2,8 @@ import einops
 import torch
 import torch as th
 import torch.nn as nn
+import random
+import bitsandbytes as bnb
 from torchvision.transforms import Resize
 
 from ldm.modules.diffusionmodules.util import (
@@ -305,12 +307,15 @@ class ControlNet(nn.Module):
 
 class ControlInpaintLDM(LatentDiffusion):
 
-    def __init__(self, control_stage_config, control_key, only_mid_control, *args, **kwargs):
+    def __init__(self, control_stage_config, control_key, u_cond_percent, only_mid_control, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.control_model = instantiate_from_config(control_stage_config)
         self.control_key = control_key
         self.only_mid_control = only_mid_control
         self.control_scales = [1.0] * 13
+        self.learnable_vector = nn.Parameter(torch.randn((1,1,768)), requires_grad=True)
+        self.proj_out = nn.Linear(1024, 768)
+        self.u_cond_percent = u_cond_percent
 
     @torch.no_grad()
     def get_input(self, batch, k, bs=None, *args, **kwargs):
@@ -380,6 +385,7 @@ class ControlInpaintLDM(LatentDiffusion):
 
         if self.cond_stage_trainable:
             c = self.get_learned_conditioning(c)
+            c = self.proj_out(c)
 
         if sample:
             # get denoise row
@@ -412,15 +418,38 @@ class ControlInpaintLDM(LatentDiffusion):
         shape = (self.channels, h // 8, w // 8)
         samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
         return samples, intermediates
-
+
     def configure_optimizers(self):
         lr = self.learning_rate
         params = list(self.control_model.parameters())
+        if self.cond_stage_trainable:
+            print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
+            params = params + list(self.cond_stage_model.final_ln.parameters()) + list(self.cond_stage_model.mapper.parameters()) + list(self.proj_out.parameters())
+        self.params = params
+        self.params_with_white = params + list(self.learnable_vector)
         if not self.sd_locked:
             params += list(self.model.diffusion_model.output_blocks.parameters())
            params += list(self.model.diffusion_model.out.parameters())
-        opt = torch.optim.AdamW(params, lr=lr)
+        #opt = torch.optim.AdamW(params, lr=lr)
+        opt = bnb.optim.Adam8bit(params, lr=lr)
+        self.opt = opt
         return opt
+
+    def forward(self, x, c, *args, **kwargs):
+        self.opt.params = self.params
+        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
+        if self.model.conditioning_key is not None:
+            assert c is not None
+            if self.cond_stage_trainable:
+                c['c_crossattn'][0] = self.get_learned_conditioning(c['c_crossattn'][0])
+                c['c_crossattn'][0] = self.proj_out(c['c_crossattn'][0])
+        u_cond_prop = random.uniform(0, 1)
+        if u_cond_prop < self.u_cond_percent:
+            self.opt.params = self.params_with_white
+            c['c_crossattn'][0] = self.learnable_vector.repeat(x.shape[0], 1, 1)
+            return self.p_losses(x, c, t, *args, **kwargs)
+        return self.p_losses(x, c, t, *args, **kwargs)
+
 
     def low_vram_shift(self, is_diffusing):
         if is_diffusing:
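
The substance of this change to cldm.py: ControlInpaintLDM now owns a trainable 1024-to-768 projection (proj_out) and a learnable unconditional embedding (learnable_vector), and its new forward() drops the image conditioning with probability u_cond_percent, substituting that learnable vector, which is the usual classifier-free-guidance training recipe. Below is a minimal, self-contained sketch of that dropout pattern; the class name, method name, and standalone structure are illustrative, not code from this repo.

import random
import torch
import torch.nn as nn

class CondDropoutSketch(nn.Module):
    # Illustrative only: mirrors the conditioning-dropout logic added to
    # ControlInpaintLDM.forward, stripped of the surrounding LatentDiffusion machinery.
    def __init__(self, u_cond_percent=0.2):
        super().__init__()
        self.u_cond_percent = u_cond_percent
        # a single learnable "null" token, broadcast over the batch when conditioning is dropped
        self.learnable_vector = nn.Parameter(torch.randn(1, 1, 768))
        # project 1024-dim CLIP image features to the 768-dim cross-attention width
        self.proj_out = nn.Linear(1024, 768)

    def prepare_cond(self, clip_feats):
        # clip_feats: (B, 1, 1024) -> conditioning of shape (B, 1, 768)
        c = self.proj_out(clip_feats)
        if random.uniform(0, 1) < self.u_cond_percent:
            # drop the real conditioning and train the unconditional embedding instead
            c = self.learnable_vector.repeat(clip_feats.shape[0], 1, 1)
        return c

sketch = CondDropoutSketch(u_cond_percent=0.2)
cond = sketch.prepare_cond(torch.randn(4, 1, 1024))   # -> (4, 1, 768)

In the commit itself the same branch also toggles self.opt.params between self.params and self.params_with_white; the sketch above only covers the conditioning path.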
ControlNet/environment.yaml CHANGED
@@ -1,12 +1,13 @@
 name: control
 channels:
   - pytorch
+  - anaconda
   - defaults
 dependencies:
   - python=3.8.5
   - pip=20.3
   - cudatoolkit=11.3
-  - pytorch=1.12.1
+  - pytorch=1.13.1
   - torchvision=0.13.1
   - numpy=1.23.1
   - pip:
@@ -36,4 +37,5 @@ dependencies:
     - ipdb==0.13.11
     - ipython==8.11.0
     - ipykernel==6.21.2
+    - bitsandbytes==0.37.1
 
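bitsandbytes==0.37.1 is added because configure_optimizers in cldm.py above switches from torch.optim.AdamW to 8-bit Adam, which keeps optimizer state in 8-bit buffers to save GPU memory. A minimal sketch of the swap, with a placeholder module and learning rate:

import torch
import torch.nn as nn
import bitsandbytes as bnb

model = nn.Linear(1024, 768).cuda()   # stand-in for the trainable parameters; bitsandbytes' 8-bit optimizers need CUDA
lr = 1e-5                             # placeholder learning rate

# opt = torch.optim.AdamW(model.parameters(), lr=lr)   # full-precision optimizer state
opt = bnb.optim.Adam8bit(model.parameters(), lr=lr)    # 8-bit optimizer state

loss = model(torch.randn(2, 1024, device="cuda")).mean()
loss.backward()
opt.step()
opt.zero_grad()
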
ControlNet/ldm/models/diffusion/ddpm.py CHANGED
@@ -552,8 +552,6 @@ class LatentDiffusion(DDPM):
         reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False)
         ignore_keys = kwargs.pop("ignore_keys", [])
         super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
-        self.learnable_vector = nn.Parameter(torch.randn((1,1,768)), requires_grad=True)
-        self.u_cond_percent = u_cond_percent
         self.concat_mode = concat_mode
         self.cond_stage_trainable = cond_stage_trainable
         self.cond_stage_key = cond_stage_key
ControlNet/ldm/modules/encoders/modules.py CHANGED
@@ -137,7 +137,6 @@ class FrozenCLIPImageEmbedder(AbstractEncoder):
         super().__init__()
         self.transformer = CLIPVisionModel.from_pretrained(version)
         self.final_ln = LayerNorm(1024)
-        self.proj_out = nn.Linear(1024, 768)
         self.mapper = Transformer(
             1,
             1024,
@@ -162,7 +161,6 @@ class FrozenCLIPImageEmbedder(AbstractEncoder):
         z = z.unsqueeze(1)
         z = self.mapper(z)
         z = self.final_ln(z)
-        z = self.proj_out(z)
         return z
 
     def encode(self, image):
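
With proj_out removed here, FrozenCLIPImageEmbedder now returns 1024-dim features, and the 1024-to-768 projection lives on ControlInpaintLDM, where configure_optimizers collects it together with final_ln and mapper. A hedged sketch of the resulting conditioning path; the wrapper class is illustrative, and only the dimensions and the ownership of proj_out mirror the diff:

import torch
import torch.nn as nn

class ImageCondWrapperSketch(nn.Module):
    # Illustrative wiring after this commit: the frozen image encoder stops at 1024 dims,
    # and the trainable 1024 -> 768 projection belongs to the diffusion wrapper.
    def __init__(self, cond_stage_model):
        super().__init__()
        self.cond_stage_model = cond_stage_model   # e.g. FrozenCLIPImageEmbedder, returning (B, 1, 1024)
        self.proj_out = nn.Linear(1024, 768)       # trained alongside final_ln / mapper

    def get_cond(self, image):
        z = self.cond_stage_model(image)           # (B, 1, 1024)
        return self.proj_out(z)                    # (B, 1, 768), fed to cross-attention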