|
from typing import Any, Dict, Optional |
|
from diffusers.models import AutoencoderKL, UNet2DConditionModel |
|
from diffusers.schedulers import KarrasDiffusionSchedulers |
|
|
|
import numpy |
|
import torch |
|
import torch.nn as nn |
|
import torch.utils.checkpoint |
|
import torch.distributed |
|
import transformers |
|
from collections import OrderedDict |
|
from PIL import Image |
|
from torchvision import transforms |
|
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer |
|
|
|
import diffusers |
|
from diffusers import ( |
|
AutoencoderKL, |
|
DDPMScheduler, |
|
DiffusionPipeline, |
|
EulerAncestralDiscreteScheduler, |
|
UNet2DConditionModel, |
|
ImagePipelineOutput |
|
) |
|
from diffusers.image_processor import VaeImageProcessor |
|
from diffusers.models.attention_processor import Attention, AttnProcessor, XFormersAttnProcessor, AttnProcessor2_0 |
|
from diffusers.utils.import_utils import is_xformers_available |
|
|
|
|
|
def to_rgb_image(maybe_rgba: "Image.Image"):
    """Return an RGB version of ``maybe_rgba``.

    Args:
        maybe_rgba: Image whose ``mode`` is ``'RGB'`` or ``'RGBA'``.

    Returns:
        The input object itself when it is already ``'RGB'``; otherwise a new
        RGB image in which the RGBA input is pasted, using its alpha channel
        as the mask, over a constant background of value 127.

    Raises:
        ValueError: If the image mode is neither ``'RGB'`` nor ``'RGBA'``.
    """
    if maybe_rgba.mode == 'RGB':
        return maybe_rgba
    if maybe_rgba.mode == 'RGBA':
        # ``size`` is (width, height); the array layout is (rows, cols, channels).
        width, height = maybe_rgba.size
        # Constant background; stated directly instead of drawing "random"
        # integers from the single-value range [127, 128).
        background = numpy.full((height, width, 3), 127, dtype=numpy.uint8)
        img = Image.fromarray(background, 'RGB')
        img.paste(maybe_rgba, mask=maybe_rgba.getchannel('A'))
        return img
    raise ValueError("Unsupported image type.", maybe_rgba.mode)
|
|
|
|
|
class ReferenceOnlyAttnProc(torch.nn.Module):
    """Attention-processor wrapper that shares attention context between two passes.

    The wrapper delegates to ``chained_proc``. When ``enabled`` it additionally
    either records the ``encoder_hidden_states`` it was called with into
    ``ref_dict`` under ``name`` (mode ``"w"``), or concatenates a previously
    recorded entry onto them along ``dim=1`` before delegating (modes ``"r"``
    and ``"m"``). When not enabled it is a plain pass-through.
    """

    def __init__(
        self,
        chained_proc,
        enabled=False,
        name=None
    ) -> None:
        """Store the wrapped processor, the enable flag and the ``ref_dict`` key."""
        super().__init__()
        self.enabled = enabled
        self.chained_proc = chained_proc
        self.name = name

    def __call__(
        self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None,
        mode="w", ref_dict: dict = None, is_cfg_guidance=False
    ) -> Any:
        """Run the wrapped processor, optionally writing to / reading from ``ref_dict``.

        Args:
            attn: Attention module, forwarded to ``chained_proc`` untouched.
            hidden_states: Query-side input; also used as ``encoder_hidden_states``
                when that argument is ``None``.
            encoder_hidden_states: Key/value-side input, or ``None``.
            attention_mask: Forwarded to ``chained_proc`` untouched.
            mode: ``"w"`` stores the context in ``ref_dict``; ``"r"`` appends the
                stored context and removes it from ``ref_dict``; ``"m"`` appends
                the stored context and keeps it.
            ref_dict: Dictionary shared between the writing and reading pass.
                Required whenever the processor is enabled.
            is_cfg_guidance: When true (and enabled), the first batch element is
                processed separately without any reference context —
                presumably the unconditional branch of classifier-free guidance.

        Returns:
            Whatever ``chained_proc`` returns; in the split case the two partial
            results are concatenated back along the batch dimension.

        Raises:
            ValueError: If the processor is enabled and ``mode`` is not one of
                ``"w"``, ``"r"`` or ``"m"``.
        """
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        split_first = self.enabled and is_cfg_guidance
        if split_first:
            # First batch element attends only to its own context.
            res0 = self.chained_proc(attn, hidden_states[:1], encoder_hidden_states[:1], attention_mask)
            hidden_states = hidden_states[1:]
            encoder_hidden_states = encoder_hidden_states[1:]

        if self.enabled:
            if mode == 'w':
                ref_dict[self.name] = encoder_hidden_states
            elif mode == 'r':
                encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict.pop(self.name)], dim=1)
            elif mode == 'm':
                encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict[self.name]], dim=1)
            else:
                # An explicit exception: an ``assert`` would vanish under ``python -O``
                # and let an unknown mode fall through silently.
                raise ValueError(f"Unsupported reference attention mode: {mode!r}")

        res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask)
        if split_first:
            res = torch.cat([res0, res])
        return res
|
|
|
|
|
class RefOnlyNoisedUNet(torch.nn.Module):
    """UNet wrapper that conditions denoising on a noised reference latent.

    Every attention processor of the wrapped UNet is replaced by a
    ``ReferenceOnlyAttnProc``; those whose name ends in ``attn1.processor``
    are enabled. Each ``forward`` runs the UNet twice: once on the noised
    condition latent in write mode (filling a ``ref_dict``), then on the real
    sample in read mode (consuming it).
    """

    def __init__(self, unet: "UNet2DConditionModel", train_sched: "DDPMScheduler",
                 val_sched: "EulerAncestralDiscreteScheduler") -> None:
        """Wrap ``unet`` and install reference-only attention processors on it.

        Args:
            unet: UNet to wrap; its attention processors are replaced in place.
            train_sched: Scheduler used to noise the condition latent in training mode.
            val_sched: Scheduler used to noise the condition latent in eval mode.
        """
        super().__init__()
        self.unet = unet
        self.train_sched = train_sched
        self.val_sched = val_sched

        # Pick the processor class once. A feature check replaces the former
        # string comparison of ``torch.__version__``, which orders versions
        # lexicographically (e.g. '10.0' < '2.0').
        if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
            default_proc_cls = AttnProcessor2_0
        elif is_xformers_available():
            default_proc_cls = XFormersAttnProcessor
        else:
            default_proc_cls = AttnProcessor

        ref_attn_procs = {
            name: ReferenceOnlyAttnProc(
                default_proc_cls(), enabled=name.endswith("attn1.processor"), name=name
            )
            for name in unet.attn_processors
        }
        unet.set_attn_processor(ref_attn_procs)

    def __getattr__(self, name: str):
        """Fall back to the wrapped UNet for attributes this module lacks."""
        try:
            return super().__getattr__(name)
        except AttributeError:
            if name == "unet":
                # ``unet`` itself is missing (e.g. before ``__init__`` finished);
                # looking it up again here would recurse without bound.
                raise
            return getattr(self.unet, name)

    def forward_cond(self, noisy_cond_lat, timestep, encoder_hidden_states, class_labels, ref_dict,
                     is_cfg_guidance, **kwargs):
        """Run the UNet on the noised condition latent in write mode.

        The call is made only for its side effect of populating ``ref_dict``;
        the UNet output is discarded. With ``is_cfg_guidance`` the first batch
        element of ``encoder_hidden_states`` (and of ``class_labels`` when
        given) is dropped before the call.
        """
        if is_cfg_guidance:
            encoder_hidden_states = encoder_hidden_states[1:]
            if class_labels is not None:  # ``forward`` defaults class_labels to None
                class_labels = class_labels[1:]
        self.unet(
            noisy_cond_lat, timestep,
            encoder_hidden_states=encoder_hidden_states,
            class_labels=class_labels,
            cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict),
            **kwargs
        )

    def forward(
        self, sample, timestep, encoder_hidden_states, class_labels=None,
        *args, cross_attention_kwargs,
        down_block_res_samples=None, mid_block_res_sample=None,
        **kwargs
    ):
        """Denoise ``sample`` while attending to the noised condition latent.

        Args:
            sample: Latent to denoise.
            timestep: Current timestep; also used to noise the condition latent.
            encoder_hidden_states: Conditioning passed to both UNet passes.
            class_labels: Optional class labels passed to both UNet passes.
            *args: Extra positional arguments for the main UNet pass only.
            cross_attention_kwargs: Must contain ``'cond_lat'``; may contain
                ``'is_cfg_guidance'`` (default ``False``).
            down_block_res_samples: Optional additional down-block residuals,
                cast to the UNet dtype before use.
            mid_block_res_sample: Optional additional mid-block residual,
                cast to the UNet dtype before use.
            **kwargs: Forwarded to both UNet passes.

        Returns:
            The output of the wrapped UNet for the main (read-mode) pass.
        """
        cond_lat = cross_attention_kwargs['cond_lat']
        is_cfg_guidance = cross_attention_kwargs.get('is_cfg_guidance', False)

        # Fresh noise is drawn for the condition latent on every call.
        noise = torch.randn_like(cond_lat)
        if self.training:
            noisy_cond_lat = self.train_sched.add_noise(cond_lat, noise, timestep)
            noisy_cond_lat = self.train_sched.scale_model_input(noisy_cond_lat, timestep)
        else:
            flat_timestep = timestep.reshape(-1)
            noisy_cond_lat = self.val_sched.add_noise(cond_lat, noise, flat_timestep)
            noisy_cond_lat = self.val_sched.scale_model_input(noisy_cond_lat, flat_timestep)

        ref_dict = {}
        self.forward_cond(
            noisy_cond_lat, timestep,
            encoder_hidden_states, class_labels,
            ref_dict, is_cfg_guidance, **kwargs
        )

        weight_dtype = self.unet.dtype
        if down_block_res_samples is not None:
            down_residuals = [res.to(dtype=weight_dtype) for res in down_block_res_samples]
        else:
            down_residuals = None
        if mid_block_res_sample is not None:
            mid_residual = mid_block_res_sample.to(dtype=weight_dtype)
        else:
            mid_residual = None

        # ``class_labels`` goes in positionally ahead of ``*args`` so that extra
        # positionals cannot collide with it in the wrapped UNet's signature.
        return self.unet(
            sample, timestep,
            encoder_hidden_states, class_labels, *args,
            cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict, is_cfg_guidance=is_cfg_guidance),
            down_block_additional_residuals=down_residuals,
            mid_block_additional_residual=mid_residual,
            **kwargs
        )
|
|
|
|
|
def scale_latents(latents):
    """Shift ``latents`` down by 0.22 and multiply by 0.75 (undone by ``unscale_latents``)."""
    shifted = latents - 0.22
    return shifted * 0.75
|
|
|
|
|
def unscale_latents(latents):
    """Divide ``latents`` by 0.75 and shift up by 0.22 (undoes ``scale_latents``)."""
    rescaled = latents / 0.75
    return rescaled + 0.22
|
|
|
|
|
def scale_image(image):
    """Multiply ``image`` by 0.5, then divide by 0.8 (undone by ``unscale_image``)."""
    halved = image * 0.5
    return halved / 0.8
|
|
|
|
|
def unscale_image(image):
    """Divide ``image`` by 0.5, then multiply by 0.8 (undoes ``scale_image``)."""
    doubled = image / 0.5
    return doubled * 0.8
|
|
|
|
|
class DepthControlUNet(torch.nn.Module):
    """Wrapper that feeds ControlNet residuals into a ``RefOnlyNoisedUNet``.

    A ``ControlNetModel`` is created from the innermost UNet. On each forward
    pass it is run with the ``'control_depth'`` entry of
    ``cross_attention_kwargs`` as its conditioning input, and its down/mid
    block residuals are handed to the wrapped UNet.
    """

    def __init__(self, unet: "RefOnlyNoisedUNet") -> None:
        """Wrap ``unet`` and build a ControlNet initialised from ``unet.unet``."""
        super().__init__()
        self.unet = unet
        self.controlnet = diffusers.ControlNetModel.from_unet(unet.unet)

        # Same preference as before (xformers first, otherwise the SDPA-based
        # processor), plus a plain fallback when neither is available.
        if is_xformers_available():
            attn_proc_cls = XFormersAttnProcessor
        elif hasattr(torch.nn.functional, "scaled_dot_product_attention"):
            attn_proc_cls = AttnProcessor2_0
        else:
            attn_proc_cls = AttnProcessor
        self.controlnet.set_attn_processor(attn_proc_cls())

    def __getattr__(self, name: str):
        """Fall back to the wrapped UNet for attributes this module lacks."""
        try:
            return super().__getattr__(name)
        except AttributeError:
            if name == "unet":
                # ``unet`` itself is missing; retrying through ``self.unet``
                # would recurse without bound.
                raise
            return getattr(self.unet, name)

    def forward(self, sample, timestep, encoder_hidden_states, class_labels=None, *args,
                cross_attention_kwargs: dict, **kwargs):
        """Run the ControlNet, then the wrapped UNet with its residuals.

        Args:
            sample: Latent to denoise; given to both the ControlNet and the UNet.
            timestep: Current timestep; given to both.
            encoder_hidden_states: Conditioning; given to both.
            class_labels: Optional class labels, forwarded to the wrapped UNet.
            *args: Extra positional arguments, forwarded to the wrapped UNet.
            cross_attention_kwargs: Must contain ``'control_depth'``, which is
                removed from a copy (the caller's dict is left untouched) and
                used as ``controlnet_cond``. The remaining entries are
                forwarded to the wrapped UNet.
            **kwargs: Extra keyword arguments, forwarded to the wrapped UNet.

        Returns:
            The output of the wrapped UNet.

        Raises:
            KeyError: If ``cross_attention_kwargs`` has no ``'control_depth'``.
        """
        cross_attention_kwargs = dict(cross_attention_kwargs)
        control_depth = cross_attention_kwargs.pop('control_depth')
        down_block_res_samples, mid_block_res_sample = self.controlnet(
            sample,
            timestep,
            encoder_hidden_states=encoder_hidden_states,
            controlnet_cond=control_depth,
            return_dict=False,
        )
        # class_labels / *args / **kwargs used to be accepted and then silently
        # discarded; pass them on so the wrapped UNet sees what the caller sent.
        return self.unet(
            sample,
            timestep,
            encoder_hidden_states,
            class_labels,
            *args,
            down_block_res_samples=down_block_res_samples,
            mid_block_res_sample=mid_block_res_sample,
            cross_attention_kwargs=cross_attention_kwargs,
            **kwargs
        )
|
|
|
|
|
class ModuleListDict(torch.nn.Module):
    """Key-addressable container of modules backed by a ``ModuleList``.

    Keys are kept in sorted order in ``self.keys``; the module for
    ``self.keys[i]`` is ``self.values[i]``.
    """

    def __init__(self, procs: dict) -> None:
        """Register the modules in ``procs``, ordered by their sorted keys."""
        super().__init__()
        ordered_names = sorted(procs)
        self.keys = ordered_names
        self.values = torch.nn.ModuleList([procs[name] for name in ordered_names])

    def __getitem__(self, key):
        """Return the module stored under ``key``."""
        position = self.keys.index(key)
        return self.values[position]
|
|
|
|
|
class SuperNet(torch.nn.Module):
    """Container that saves and loads several named modules under their own names.

    The values of ``state_dict`` are stored in a ``ModuleList`` (``self.layers``)
    in sorted-key order. Two hooks translate between the internal
    ``layers.<index>`` prefix and the original name, so a saved state dict uses
    the original names and a state dict using those names can be loaded back.
    """

    def __init__(self, state_dict: Dict[str, torch.Tensor]):
        """Register the given modules and install the key-translation hooks.

        Args:
            state_dict: Mapping from name to the object stored in
                ``self.layers``. The values are placed in a ``ModuleList`` —
                presumably they are modules despite the tensor annotation.
        """
        super().__init__()
        state_dict = OrderedDict((k, state_dict[k]) for k in sorted(state_dict.keys()))
        self.layers = torch.nn.ModuleList(state_dict.values())
        self.mapping = dict(enumerate(state_dict.keys()))
        self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}

        # Markers that end the "module name" part of an external key.
        self.split_keys = [".processor", ".self_attn"]

        def map_to(module, state_dict, *args, **kwargs):
            # Saving: ``layers.<n>.rest`` -> ``<original name>.rest``.
            new_state_dict = {}
            for key, value in state_dict.items():
                num = int(key.split(".")[1])
                # Only the leading prefix is rewritten; later occurrences of the
                # same text inside the key must stay as they are.
                new_key = key.replace(f"layers.{num}", module.mapping[num], 1)
                new_state_dict[new_key] = value
            return new_state_dict

        def remap_key(module, key):
            # External key -> the registered name it belongs to.
            for marker in module.split_keys:
                if marker in key:
                    return key.split(marker)[0] + marker
            return key.split('.')[0]

        def map_from(module, state_dict, *args, **kwargs):
            # Loading: ``<original name>.rest`` -> ``layers.<n>.rest``, in place.
            for key in list(state_dict.keys()):
                replace_key = remap_key(module, key)
                # ``count=1`` matters: a name such as ``controlnet`` also occurs
                # inside sub-keys like ``controlnet_cond_embedding``; replacing
                # every occurrence corrupts the key and breaks loading.
                new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}", 1)
                state_dict[new_key] = state_dict.pop(key)

        self._register_state_dict_hook(map_to)
        self._register_load_state_dict_pre_hook(map_from, with_module=True)
|
|
|
|
|
class Zero123PlusPipeline(diffusers.StableDiffusionPipeline):
    """Image-conditioned generation pipeline built on ``StableDiffusionPipeline``.

    The input image is used twice: its VAE latent is the reference latent
    passed to the UNet through ``cross_attention_kwargs['cond_lat']``, and its
    projected vision-encoder embedding, scaled by the configured
    ``ramping_coefficients``, is added to the text prompt embedding.
    """

    tokenizer: transformers.CLIPTokenizer
    text_encoder: transformers.CLIPTextModel
    vision_encoder: transformers.CLIPVisionModelWithProjection

    feature_extractor_clip: transformers.CLIPImageProcessor
    unet: UNet2DConditionModel
    scheduler: diffusers.schedulers.KarrasDiffusionSchedulers

    vae: AutoencoderKL
    ramping: nn.Linear

    feature_extractor_vae: transformers.CLIPImageProcessor

    # Preprocessing applied to ``depth_image`` before it is given to the ControlNet.
    depth_transforms_multi = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        vision_encoder: transformers.CLIPVisionModelWithProjection,
        feature_extractor_clip: CLIPImageProcessor,
        feature_extractor_vae: CLIPImageProcessor,
        ramping_coefficients: Optional[list] = None,
        safety_checker=None,
    ):
        """Register the pipeline components.

        ``safety_checker`` is accepted but not used: the pipeline always
        registers ``safety_checker=None``. ``ramping_coefficients`` is stored in
        the pipeline config and read back in ``__call__``.
        """
        # Deliberately the base-class initialiser: the parent's ``__init__`` is skipped.
        DiffusionPipeline.__init__(self)

        self.register_modules(
            vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
            unet=unet, scheduler=scheduler, safety_checker=None,
            vision_encoder=vision_encoder,
            feature_extractor_clip=feature_extractor_clip,
            feature_extractor_vae=feature_extractor_vae
        )
        self.register_to_config(ramping_coefficients=ramping_coefficients)
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    def prepare(self):
        """Wrap a plain ``UNet2DConditionModel`` in ``RefOnlyNoisedUNet`` (eval mode).

        Does nothing to ``self.unet`` when it is not a ``UNet2DConditionModel``
        instance, so calling it repeatedly is harmless.
        """
        train_sched = DDPMScheduler.from_config(self.scheduler.config)
        if isinstance(self.unet, UNet2DConditionModel):
            self.unet = RefOnlyNoisedUNet(self.unet, train_sched, self.scheduler).eval()

    def add_controlnet(self):
        """Wrap ``self.unet`` in ``DepthControlUNet`` and return its ControlNet as a ``SuperNet``."""
        self.unet = DepthControlUNet(self.unet)
        return SuperNet(OrderedDict([('controlnet', self.unet.controlnet)]))

    def encode_condition_image(self, image: torch.Tensor):
        """Encode ``image`` with the VAE and return a sample from the latent distribution."""
        image = self.vae.encode(image).latent_dist.sample()
        return image

    @torch.no_grad()
    def __call__(
        self,
        image: Image.Image = None,
        prompt="",
        *args,
        num_images_per_prompt: Optional[int] = 1,
        guidance_scale=4.0,
        depth_image: Image.Image = None,
        output_type: Optional[str] = "pil",
        width=640,
        height=960,
        num_inference_steps=28,
        return_dict=True,
        **kwargs
    ):
        """Generate output conditioned on ``image``.

        Args:
            image: Conditioning image (not a tensor). Required.
            prompt: Text prompt; its embedding is combined with the image embedding.
            *args: Extra positional arguments for the parent pipeline call.
            num_images_per_prompt: Forwarded to prompt encoding and the parent call.
            guidance_scale: Forwarded to the parent call. When greater than 1,
                the latent of an all-zero image is prepended to the condition latent.
            depth_image: Optional depth image; only used when ``self.unet``
                exposes a ``controlnet`` attribute.
            output_type: ``"latent"`` returns the unscaled latents; any other
                value decodes them with the VAE before post-processing.
            width: Forwarded to the parent call.
            height: Forwarded to the parent call.
            num_inference_steps: Forwarded to the parent call.
            return_dict: When false, a one-element tuple is returned instead of
                an ``ImagePipelineOutput``.
            **kwargs: Extra keyword arguments for the parent pipeline call.

        Returns:
            ``ImagePipelineOutput`` with the result in ``images``, or
            ``(images,)`` when ``return_dict`` is false.

        Raises:
            ValueError: If ``image`` is ``None``.
            TypeError: If ``image`` is a ``torch.Tensor``.
        """
        self.prepare()
        if image is None:
            raise ValueError("Inputting embeddings not supported for this pipeline. Please pass an image.")
        if isinstance(image, torch.Tensor):
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            raise TypeError("Tensor inputs are not supported for this pipeline. Please pass an image.")

        image_1 = self.feature_extractor_vae(images=image, return_tensors="pt").pixel_values
        image_2 = self.feature_extractor_clip(images=image, return_tensors="pt").pixel_values
        if depth_image is not None and hasattr(self.unet, "controlnet"):
            depth_image = self.depth_transforms_multi(depth_image).to(
                device=self.unet.controlnet.device, dtype=self.unet.controlnet.dtype
            )
        image = image_1.to(device=self.vae.device, dtype=self.vae.dtype)
        # NOTE(review): the vision-encoder input is moved to the VAE's device/dtype —
        # confirm the vision encoder always shares them.
        image_2 = image_2.to(device=self.vae.device, dtype=self.vae.dtype)

        cond_lat = self.encode_condition_image(image)
        if guidance_scale > 1:
            negative_lat = self.encode_condition_image(torch.zeros_like(image))
            cond_lat = torch.cat([negative_lat, cond_lat])

        encoded = self.vision_encoder(image_2, output_hidden_states=False)
        global_embeds = encoded.image_embeds
        global_embeds = global_embeds.unsqueeze(-2)

        # Prefer ``encode_prompt`` where the installed diffusers provides it: the
        # deprecated ``_encode_prompt`` shim concatenates the negative and
        # positive embeddings, and the negative one is ``None`` when
        # classifier-free guidance is disabled as it is here.
        if hasattr(self, "encode_prompt"):
            encoder_hidden_states = self.encode_prompt(
                prompt,
                self.device,
                num_images_per_prompt,
                False
            )[0]
        else:
            encoder_hidden_states = self._encode_prompt(
                prompt,
                self.device,
                num_images_per_prompt,
                False
            )
        ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1)
        encoder_hidden_states = encoder_hidden_states + global_embeds * ramp

        cak = dict(cond_lat=cond_lat)
        if hasattr(self.unet, "controlnet"):
            cak['control_depth'] = depth_image

        latents: torch.Tensor = super().__call__(
            None,
            *args,
            cross_attention_kwargs=cak,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images_per_prompt,
            prompt_embeds=encoder_hidden_states,
            num_inference_steps=num_inference_steps,
            output_type='latent',
            width=width,
            height=height,
            **kwargs
        ).images
        latents = unscale_latents(latents)
        if output_type != "latent":
            image = unscale_image(self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0])
        else:
            image = latents

        image = self.image_processor.postprocess(image, output_type=output_type)
        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
|
|