from typing import Any

import torch
from einops import rearrange
from diffusers.utils.import_utils import is_xformers_available

from canonicalize.models.transformer_mv2d import XFormersMVAttnProcessor, MVAttnProcessor
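
# Reference-only attention utilities: ReferenceOnlyAttnProc caches encoder states
# during a "write" pass over a noised conditioning latent and re-injects them during
# the "read" pass, while RefOnlyNoisedUNet orchestrates those two passes around a UNet.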
class ReferenceOnlyAttnProc(torch.nn.Module):
    """Attention-processor wrapper that shares reference features through ``ref_dict``.

    mode "w" caches the encoder states under this processor's name, mode "r" reads
    them back and appends them to the view-grouped encoder states, mode "m" appends
    them without the multi-view rearrange, and mode "n" only does the view grouping.
    """

    def __init__(
        self,
        chained_proc,
        enabled=False,
        name=None,
    ) -> None:
        super().__init__()
        self.enabled = enabled
        self.chained_proc = chained_proc
        self.name = name

    def __call__(
        self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None,
        mode="w", ref_dict: dict = None, is_cfg_guidance=False, num_views=4,
        multiview_attention=True,
        cross_domain_attention=False,
    ) -> Any:
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        if self.enabled:
            if mode == 'w':
                # Write pass: cache the reference encoder states for later reads.
                ref_dict[self.name] = encoder_hidden_states
                res = self.chained_proc(
                    attn, hidden_states, encoder_hidden_states, attention_mask,
                    num_views=1, multiview_attention=False, cross_domain_attention=False,
                )
            elif mode == 'r':
                # Read pass: group per-view tokens, append the cached reference tokens
                # (if any), then broadcast the result back to every view.
                encoder_hidden_states = rearrange(encoder_hidden_states, '(b t) d c -> b (t d) c', t=num_views)
                if self.name in ref_dict:
                    encoder_hidden_states = (
                        torch.cat([encoder_hidden_states, ref_dict.pop(self.name)], dim=1)
                        .unsqueeze(1).repeat(1, num_views, 1, 1).flatten(0, 1)
                    )
                res = self.chained_proc(
                    attn, hidden_states, encoder_hidden_states, attention_mask,
                    num_views=num_views, multiview_attention=False, cross_domain_attention=False,
                )
            elif mode == 'm':
                # Mix the cached reference tokens into the encoder states without the
                # multi-view rearrange. NOTE: the original branch never assigned `res`
                # (an UnboundLocalError at `return res`); as an assumed fix we fall back
                # to the plain chained call used in the disabled path.
                encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict[self.name]], dim=1)
                res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask)
            elif mode == 'n':
                # No-reference pass: same view grouping as "r", but without cached tokens.
                encoder_hidden_states = rearrange(encoder_hidden_states, '(b t) d c -> b (t d) c', t=num_views)
                encoder_hidden_states = (
                    encoder_hidden_states.unsqueeze(1).repeat(1, num_views, 1, 1).flatten(0, 1)
                )
                res = self.chained_proc(
                    attn, hidden_states, encoder_hidden_states, attention_mask,
                    num_views=num_views, multiview_attention=False, cross_domain_attention=False,
                )
            else:
                assert False, mode
        else:
            res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask)
        return res


class RefOnlyNoisedUNet(torch.nn.Module):
    """Wraps a UNet so each denoising step is preceded by a reference "write" pass on a
    noised conditioning latent, whose attention features feed the main "read" pass."""

    def __init__(self, unet, train_sched, val_sched) -> None:
        super().__init__()
        self.unet = unet
        self.train_sched = train_sched
        self.val_sched = val_sched

        # Wrap every attention processor; only the self-attention ("attn1") layers
        # take part in the reference-only feature sharing.
        unet_lora_attn_procs = dict()
        for name, _ in unet.attn_processors.items():
            if is_xformers_available():
                default_attn_proc = XFormersMVAttnProcessor()
            else:
                default_attn_proc = MVAttnProcessor()
            unet_lora_attn_procs[name] = ReferenceOnlyAttnProc(
                default_attn_proc, enabled=name.endswith("attn1.processor"), name=name
            )
        self.unet.set_attn_processor(unet_lora_attn_procs)

    def __getattr__(self, name: str):
        # Fall back to the wrapped UNet for attributes not defined on this module.
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.unet, name)
    def forward_cond(self, noisy_cond_lat, timestep, encoder_hidden_states, class_labels,
                     ref_dict, is_cfg_guidance, **kwargs):
        if is_cfg_guidance:
            # Drop the unconditional (classifier-free guidance) entry for the reference pass.
            encoder_hidden_states = encoder_hidden_states[1:]
            class_labels = class_labels[1:]
        # "Write" pass: run the UNet on the noised conditioning latent so every enabled
        # attention processor caches its encoder states in ref_dict.
        self.unet(
            noisy_cond_lat, timestep,
            encoder_hidden_states=encoder_hidden_states,
            class_labels=class_labels,
            cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict),
            **kwargs
        )
    def forward(
        self, sample, timestep, encoder_hidden_states, class_labels=None,
        *args, cross_attention_kwargs,
        down_block_res_samples=None, mid_block_res_sample=None,
        **kwargs
    ):
        cond_lat = cross_attention_kwargs['cond_lat']
        is_cfg_guidance = cross_attention_kwargs.get('is_cfg_guidance', False)

        # Noise the conditioning latent to the current timestep so the reference pass
        # sees features at the same noise level as the sample being denoised.
        noise = torch.randn_like(cond_lat)
        if self.training:
            noisy_cond_lat = self.train_sched.add_noise(cond_lat, noise, timestep)
            noisy_cond_lat = self.train_sched.scale_model_input(noisy_cond_lat, timestep)
        else:
            noisy_cond_lat = self.val_sched.add_noise(cond_lat, noise, timestep.reshape(-1))
            noisy_cond_lat = self.val_sched.scale_model_input(noisy_cond_lat, timestep.reshape(-1))

        # "Write" pass fills ref_dict; the main "read" pass below consumes it.
        ref_dict = {}
        self.forward_cond(
            noisy_cond_lat, timestep,
            encoder_hidden_states, class_labels,
            ref_dict, is_cfg_guidance, **kwargs
        )

        weight_dtype = self.unet.dtype
        return self.unet(
            sample, timestep,
            encoder_hidden_states, *args,
            class_labels=class_labels,
            cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict, is_cfg_guidance=is_cfg_guidance),
            down_block_additional_residuals=[
                res_sample.to(dtype=weight_dtype) for res_sample in down_block_res_samples
            ] if down_block_res_samples is not None else None,
            mid_block_additional_residual=(
                mid_block_res_sample.to(dtype=weight_dtype)
                if mid_block_res_sample is not None else None
            ),
            **kwargs
        )
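

# --- Usage sketch (illustrative, not part of the original file) -----------------------
# A minimal sketch of how this wrapper might be driven, assuming a standard diffusers
# UNet2DConditionModel and DDPM/DDIM schedulers; the checkpoint name, tensor shapes,
# and the `cond_lat` conditioning latent below are placeholders, not values from this
# repository.
#
# from diffusers import UNet2DConditionModel, DDPMScheduler, DDIMScheduler
#
# unet = UNet2DConditionModel.from_pretrained("some/checkpoint", subfolder="unet")
# train_sched = DDPMScheduler.from_pretrained("some/checkpoint", subfolder="scheduler")
# val_sched = DDIMScheduler.from_pretrained("some/checkpoint", subfolder="scheduler")
#
# ref_unet = RefOnlyNoisedUNet(unet, train_sched, val_sched).eval()
#
# out = ref_unet(
#     sample,                 # (B * num_views, C, H, W) noisy multi-view latents
#     timestep,               # current diffusion timestep(s)
#     encoder_hidden_states,  # text/image embeddings
#     cross_attention_kwargs={"cond_lat": cond_lat, "is_cfg_guidance": False},
# )
# noise_pred = out.sample     # assuming the wrapped UNet returns a UNet2DConditionOutput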