Upload matryoshka.py
Browse files
unet/nesting_level_0/matryoshka.py
CHANGED
@@ -20,6 +20,7 @@
|
|
20 |
|
21 |
|
22 |
import inspect
|
|
|
23 |
import math
|
24 |
from dataclasses import dataclass
|
25 |
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
@@ -3782,8 +3783,6 @@ class MatryoshkaPipeline(
|
|
3782 |
else:
|
3783 |
raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")
|
3784 |
|
3785 |
-
unet = unet.to(self.device)
|
3786 |
-
|
3787 |
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
3788 |
deprecation_message = (
|
3789 |
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
@@ -3832,6 +3831,9 @@ class MatryoshkaPipeline(
|
|
3832 |
new_config["sample_size"] = 64
|
3833 |
unet._internal_dict = FrozenDict(new_config)
|
3834 |
|
|
|
|
|
|
|
3835 |
self.register_modules(
|
3836 |
text_encoder=text_encoder,
|
3837 |
tokenizer=tokenizer,
|
@@ -3840,10 +3842,32 @@ class MatryoshkaPipeline(
|
|
3840 |
feature_extractor=feature_extractor,
|
3841 |
image_encoder=image_encoder,
|
3842 |
)
|
3843 |
-
|
3844 |
-
scheduler.scales = unet.nest_ratio + [1]
|
3845 |
self.image_processor = VaeImageProcessor(do_resize=False)
|
3846 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3847 |
def encode_prompt(
|
3848 |
self,
|
3849 |
prompt,
|
@@ -4625,9 +4649,12 @@ class MatryoshkaPipeline(
|
|
4625 |
image = latents
|
4626 |
|
4627 |
if self.scheduler.scales is not None:
|
4628 |
-
|
4629 |
-
image[i]
|
4630 |
-
|
|
|
|
|
|
|
4631 |
else:
|
4632 |
image = self.image_processor.postprocess(image, output_type=output_type)
|
4633 |
|
|
|
20 |
|
21 |
|
22 |
import inspect
|
23 |
+
import gc
|
24 |
import math
|
25 |
from dataclasses import dataclass
|
26 |
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
|
3783 |
else:
|
3784 |
raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")
|
3785 |
|
|
|
|
|
3786 |
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
3787 |
deprecation_message = (
|
3788 |
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
|
|
3831 |
new_config["sample_size"] = 64
|
3832 |
unet._internal_dict = FrozenDict(new_config)
|
3833 |
|
3834 |
+
if hasattr(unet, "nest_ratio"):
|
3835 |
+
scheduler.scales = unet.nest_ratio + [1]
|
3836 |
+
|
3837 |
self.register_modules(
|
3838 |
text_encoder=text_encoder,
|
3839 |
tokenizer=tokenizer,
|
|
|
3842 |
feature_extractor=feature_extractor,
|
3843 |
image_encoder=image_encoder,
|
3844 |
)
|
3845 |
+
self.register_to_config(nesting_level=nesting_level)
|
|
|
3846 |
self.image_processor = VaeImageProcessor(do_resize=False)
|
3847 |
|
3848 |
+
def change_nesting_level(self, nesting_level: int):
|
3849 |
+
if nesting_level == 0:
|
3850 |
+
if hasattr(self.unet, "nest_ratio"):
|
3851 |
+
self.scheduler.scales = None
|
3852 |
+
self.unet = MatryoshkaUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
|
3853 |
+
subfolder="unet/nesting_level_0").to(self.device)
|
3854 |
+
self.config.nesting_level = 0
|
3855 |
+
elif nesting_level == 1:
|
3856 |
+
self.unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
|
3857 |
+
subfolder="unet/nesting_level_1").to(self.device)
|
3858 |
+
self.config.nesting_level = 1
|
3859 |
+
self.scheduler.scales = self.unet.nest_ratio + [1]
|
3860 |
+
elif nesting_level == 2:
|
3861 |
+
self.unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
|
3862 |
+
subfolder="unet/nesting_level_2").to(self.device)
|
3863 |
+
self.config.nesting_level = 2
|
3864 |
+
self.scheduler.scales = self.unet.nest_ratio + [1]
|
3865 |
+
else:
|
3866 |
+
raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")
|
3867 |
+
|
3868 |
+
gc.collect()
|
3869 |
+
torch.cuda.empty_cache()
|
3870 |
+
|
3871 |
def encode_prompt(
|
3872 |
self,
|
3873 |
prompt,
|
|
|
4649 |
image = latents
|
4650 |
|
4651 |
if self.scheduler.scales is not None:
|
4652 |
+
scales = [
|
4653 |
+
image[i].size(-1) / image[-1].size(-1)
|
4654 |
+
for i in range(len(image))
|
4655 |
+
]
|
4656 |
+
for i, (img, scale) in enumerate(zip(image, scales)):
|
4657 |
+
image[i] = self.image_processor.postprocess(img * scale, output_type=output_type)[0]
|
4658 |
else:
|
4659 |
image = self.image_processor.postprocess(image, output_type=output_type)
|
4660 |
|