from typing import Any, Dict, Optional

import torch
from einops import rearrange
from diffusers.utils.import_utils import is_xformers_available

from canonicalize.models.transformer_mv2d import MVAttnProcessor, XFormersMVAttnProcessor
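
# Reference-only attention for a multi-view diffusion UNet: the UNet is run
# twice per step. A "write" pass caches self-attention features of a noised
# condition latent in a shared ref_dict; a "read" pass widens the key/value
# sequence of each self-attention layer with those cached features.
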
class ReferenceOnlyAttnProc(torch.nn.Module):
    """Attention-processor wrapper implementing reference-only feature sharing.

    Wraps a base processor (`chained_proc`) and, when enabled, either caches the
    layer's hidden states into a shared `ref_dict` ("w" mode), reads the cache and
    broadcasts it across views ("r" mode), mixes it in without consuming it
    ("m" mode), or shares states across views without the cache ("n" mode).
    """

    def __init__(
        self,
        chained_proc,
        enabled=False,
        name=None
    ) -> None:
        super().__init__()
        self.enabled = enabled          # only self-attention ("attn1") layers are enabled
        self.chained_proc = chained_proc
        self.name = name                # processor key, used to index ref_dict
    def __call__(
        self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None,
        mode="w", ref_dict: dict = None, is_cfg_guidance=False, num_views=4,
        multiview_attention=True,
        cross_domain_attention=False,
    ) -> Any:
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        if self.enabled:
            if mode == 'w':
                # Write pass: cache this layer's states for the later read pass.
                ref_dict[self.name] = encoder_hidden_states
                res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask,
                                        num_views=1,
                                        multiview_attention=False,
                                        cross_domain_attention=False)
            elif mode == 'r':
                # Read pass: gather all views into one sequence, append the cached
                # reference states, then broadcast the result back to every view.
                encoder_hidden_states = rearrange(encoder_hidden_states, '(b t) d c -> b (t d) c', t=num_views)
                if self.name in ref_dict:
                    encoder_hidden_states = torch.cat(
                        [encoder_hidden_states, ref_dict.pop(self.name)], dim=1
                    ).unsqueeze(1).repeat(1, num_views, 1, 1).flatten(0, 1)
                res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask,
                                        num_views=num_views,
                                        multiview_attention=False,
                                        cross_domain_attention=False)
            elif mode == 'm':
                # Mix mode: append the cached reference states, leaving them in ref_dict.
                encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict[self.name]], dim=1)
                # Without a processor call, `res` would be unbound below; the
                # arguments mirror the 'r' branch (an assumption, not in the source).
                res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask,
                                        num_views=num_views,
                                        multiview_attention=False,
                                        cross_domain_attention=False)
            elif mode == 'n':
                # No-reference mode: share states across views only, no cached features.
                encoder_hidden_states = rearrange(encoder_hidden_states, '(b t) d c -> b (t d) c', t=num_views)
                encoder_hidden_states = encoder_hidden_states.unsqueeze(1).repeat(1, num_views, 1, 1).flatten(0, 1)
                res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask,
                                        num_views=num_views,
                                        multiview_attention=False,
                                        cross_domain_attention=False)
            else:
                raise ValueError(f"unknown mode: {mode!r}")
        else:
            # Disabled layers (e.g. cross-attention) run the base processor unchanged.
            res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask)
        return res
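
# Note on plumbing: diffusers forwards the contents of `cross_attention_kwargs`
# to every attention processor call, which is how `mode`, `ref_dict`,
# `num_views`, and `is_cfg_guidance` reach ReferenceOnlyAttnProc.__call__.
# RefOnlyNoisedUNet below drives the protocol: one "w" pass on a noised
# condition latent, then one "r" pass for the actual denoising step.
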
class RefOnlyNoisedUNet(torch.nn.Module):
    """UNet wrapper that runs a noised reference ("write") pass before each denoising pass."""

    def __init__(self, unet, train_sched, val_sched) -> None:
        super().__init__()
        self.unet = unet
        self.train_sched = train_sched
        self.val_sched = val_sched

        # Replace every attention processor with a ReferenceOnlyAttnProc; only
        # self-attention layers ("attn1") participate in reference sharing.
        unet_lora_attn_procs = dict()
        for name, _ in unet.attn_processors.items():
            if is_xformers_available():
                default_attn_proc = XFormersMVAttnProcessor()
            else:
                default_attn_proc = MVAttnProcessor()
            unet_lora_attn_procs[name] = ReferenceOnlyAttnProc(
                default_attn_proc, enabled=name.endswith("attn1.processor"), name=name)
        self.unet.set_attn_processor(unet_lora_attn_procs)

    def __getattr__(self, name: str):
        # Fall back to the wrapped UNet for attributes not defined on the wrapper
        # (e.g. dtype, config), so the wrapper is a drop-in replacement.
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.unet, name)
    def forward_cond(self, noisy_cond_lat, timestep, encoder_hidden_states, class_labels, ref_dict, is_cfg_guidance, **kwargs):
        if is_cfg_guidance:
            # Drop the unconditional (first) entry: the reference pass only
            # needs the conditional branch.
            encoder_hidden_states = encoder_hidden_states[1:]
            class_labels = class_labels[1:]
        # The output is discarded; this pass only populates ref_dict via mode="w".
        self.unet(
            noisy_cond_lat, timestep,
            encoder_hidden_states=encoder_hidden_states,
            class_labels=class_labels,
            cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict),
            **kwargs
        )
    def forward(
        self, sample, timestep, encoder_hidden_states, class_labels=None,
        *args, cross_attention_kwargs,
        down_block_res_samples=None, mid_block_res_sample=None,
        **kwargs
    ):
        cond_lat = cross_attention_kwargs['cond_lat']
        is_cfg_guidance = cross_attention_kwargs.get('is_cfg_guidance', False)

        # Noise the condition latent to the current timestep so the reference
        # pass sees features at the same noise level as the denoising pass.
        noise = torch.randn_like(cond_lat)
        if self.training:
            noisy_cond_lat = self.train_sched.add_noise(cond_lat, noise, timestep)
            noisy_cond_lat = self.train_sched.scale_model_input(noisy_cond_lat, timestep)
        else:
            noisy_cond_lat = self.val_sched.add_noise(cond_lat, noise, timestep.reshape(-1))
            noisy_cond_lat = self.val_sched.scale_model_input(noisy_cond_lat, timestep.reshape(-1))

        # Write pass: populate ref_dict with reference self-attention features.
        ref_dict = {}
        self.forward_cond(
            noisy_cond_lat, timestep,
            encoder_hidden_states, class_labels,
            ref_dict, is_cfg_guidance, **kwargs
        )

        # Read pass: the actual denoising step, consuming ref_dict via mode="r".
        weight_dtype = self.unet.dtype
        return self.unet(
            sample, timestep,
            encoder_hidden_states, *args,
            class_labels=class_labels,
            cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict, is_cfg_guidance=is_cfg_guidance),
            down_block_additional_residuals=[
                res_sample.to(dtype=weight_dtype) for res_sample in down_block_res_samples
            ] if down_block_res_samples is not None else None,
            mid_block_additional_residual=(
                mid_block_res_sample.to(dtype=weight_dtype)
                if mid_block_res_sample is not None else None
            ),
            **kwargs
        )
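
# Hedged usage sketch (not part of the original module): wiring the wrapper into
# a sampling step. `pipe`, `ddpm_sched`, `text_emb`, and `cond_lat` are
# illustrative placeholders, not names defined above.
def demo_denoise_step(pipe, ddpm_sched, latents, timestep, text_emb, cond_lat):
    # Wrap the UNet (in practice, do this once outside the sampling loop); the
    # wrapper installs ReferenceOnlyAttnProc on every attention layer.
    ref_unet = RefOnlyNoisedUNet(pipe.unet, train_sched=ddpm_sched, val_sched=pipe.scheduler)
    ref_unet.eval()
    # `cond_lat` is the only key forward() requires in cross_attention_kwargs;
    # it is noised to `timestep` internally before the reference pass runs.
    return ref_unet(
        latents, timestep, text_emb,
        cross_attention_kwargs=dict(cond_lat=cond_lat),
    ).sample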