tolgacangoz commited on
Commit
ca01bef
·
verified ·
1 Parent(s): 04b9c24

Upload matryoshka.py

Browse files
Files changed (1) hide show
  1. unet/nesting_level_0/matryoshka.py +34 -7
unet/nesting_level_0/matryoshka.py CHANGED
@@ -20,6 +20,7 @@
20
 
21
 
22
  import inspect
 
23
  import math
24
  from dataclasses import dataclass
25
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -3782,8 +3783,6 @@ class MatryoshkaPipeline(
3782
  else:
3783
  raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")
3784
 
3785
- unet = unet.to(self.device)
3786
-
3787
  if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
3788
  deprecation_message = (
3789
  f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
@@ -3832,6 +3831,9 @@ class MatryoshkaPipeline(
3832
  new_config["sample_size"] = 64
3833
  unet._internal_dict = FrozenDict(new_config)
3834
 
 
 
 
3835
  self.register_modules(
3836
  text_encoder=text_encoder,
3837
  tokenizer=tokenizer,
@@ -3840,10 +3842,32 @@ class MatryoshkaPipeline(
3840
  feature_extractor=feature_extractor,
3841
  image_encoder=image_encoder,
3842
  )
3843
- if hasattr(unet, "nest_ratio"):
3844
- scheduler.scales = unet.nest_ratio + [1]
3845
  self.image_processor = VaeImageProcessor(do_resize=False)
3846
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3847
  def encode_prompt(
3848
  self,
3849
  prompt,
@@ -4625,9 +4649,12 @@ class MatryoshkaPipeline(
4625
  image = latents
4626
 
4627
  if self.scheduler.scales is not None:
4628
- for i in range(len(image)):
4629
- image[i] = image[i] * self.scheduler.scales[i]
4630
- image[i] = self.image_processor.postprocess(image[i], output_type=output_type)
 
 
 
4631
  else:
4632
  image = self.image_processor.postprocess(image, output_type=output_type)
4633
 
 
20
 
21
 
22
  import inspect
23
+ import gc
24
  import math
25
  from dataclasses import dataclass
26
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
3783
  else:
3784
  raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")
3785
 
 
 
3786
  if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
3787
  deprecation_message = (
3788
  f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
 
3831
  new_config["sample_size"] = 64
3832
  unet._internal_dict = FrozenDict(new_config)
3833
 
3834
+ if hasattr(unet, "nest_ratio"):
3835
+ scheduler.scales = unet.nest_ratio + [1]
3836
+
3837
  self.register_modules(
3838
  text_encoder=text_encoder,
3839
  tokenizer=tokenizer,
 
3842
  feature_extractor=feature_extractor,
3843
  image_encoder=image_encoder,
3844
  )
3845
+ self.register_to_config(nesting_level=nesting_level)
 
3846
  self.image_processor = VaeImageProcessor(do_resize=False)
3847
 
3848
+ def change_nesting_level(self, nesting_level: int):
3849
+ if nesting_level == 0:
3850
+ if hasattr(self.unet, "nest_ratio"):
3851
+ self.scheduler.scales = None
3852
+ self.unet = MatryoshkaUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
3853
+ subfolder="unet/nesting_level_0").to(self.device)
3854
+ self.config.nesting_level = 0
3855
+ elif nesting_level == 1:
3856
+ self.unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
3857
+ subfolder="unet/nesting_level_1").to(self.device)
3858
+ self.config.nesting_level = 1
3859
+ self.scheduler.scales = self.unet.nest_ratio + [1]
3860
+ elif nesting_level == 2:
3861
+ self.unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
3862
+ subfolder="unet/nesting_level_2").to(self.device)
3863
+ self.config.nesting_level = 2
3864
+ self.scheduler.scales = self.unet.nest_ratio + [1]
3865
+ else:
3866
+ raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")
3867
+
3868
+ gc.collect()
3869
+ torch.cuda.empty_cache()
3870
+
3871
  def encode_prompt(
3872
  self,
3873
  prompt,
 
4649
  image = latents
4650
 
4651
  if self.scheduler.scales is not None:
4652
+ scales = [
4653
+ image[i].size(-1) / image[-1].size(-1)
4654
+ for i in range(len(image))
4655
+ ]
4656
+ for i, (img, scale) in enumerate(zip(image, scales)):
4657
+ image[i] = self.image_processor.postprocess(img * scale, output_type=output_type)[0]
4658
  else:
4659
  image = self.image_processor.postprocess(image, output_type=output_type)
4660