# MotionCtrl / motionctrl/lvdm_modified_modules.py
# Source snapshot: commit f1df74a ("init") by wzhouxiff, 6.05 kB.
import logging
import torch
from einops import rearrange, repeat
from lvdm.models.utils_diffusion import timestep_embedding
# Optional dependency: xformers memory-efficient attention kernels.
# The flag keeps its original (misspelled) name in case other modules reference it.
try:
    import xformers
    import xformers.ops
    XFORMERS_IS_AVAILBLE = True
except Exception:
    # Any import-time failure (missing package, incompatible build) just disables
    # xformers; unlike a bare `except:`, KeyboardInterrupt/SystemExit still propagate.
    XFORMERS_IS_AVAILBLE = False
mainlogger = logging.getLogger('mainlogger')
def TemporalTransformer_forward(self, x, context=None, is_imgbatch=False):
    """Temporal-attention forward used as a drop-in for lvdm's TemporalTransformer.

    Attention runs along the frame axis independently at every spatial location:
    the normalized input is folded into ``(b*h*w)`` token sequences of length
    ``t``, passed through ``self.transformer_blocks``, unfolded back to
    ``b c t h w`` and summed with the input as a residual.

    Args:
        x: 5-D tensor unpacked as ``(b, c, t, h, w)``.
        context: when ``self.only_self_att`` it is handed unchanged to every
            block; otherwise it is rearranged with ``'(b t) l con -> b t l con'``
            and, per sample, only the first frame's entry is used.
        is_imgbatch: when True an identity frame mask replaces any causal mask.

    Returns:
        The residual sum of the temporal branch and ``x``; with
        ``self.use_image_dataset`` the branch is multiplied by zero first.
    """
    b, c, t, h, w = x.shape
    residual = x

    hid = rearrange(self.norm(x), 'b c t h w -> (b h w) c t').contiguous()
    # proj_in runs channel-first in the non-linear variant, channel-last otherwise.
    if not self.use_linear:
        hid = self.proj_in(hid)
    hid = rearrange(hid, 'bhw c t -> bhw t c').contiguous()
    if self.use_linear:
        hid = self.proj_in(hid)

    # (1, t, t) frame mask: the image-batch identity mask wins over the causal one.
    use_causal = self.causal_attention
    if is_imgbatch:
        frame_mask = torch.eye(t).unsqueeze(0)
    elif use_causal:
        frame_mask = torch.tril(torch.ones([1, t, t]))
    else:
        frame_mask = None
    if frame_mask is None:
        mask = None
    else:
        mask = repeat(frame_mask.to(hid.device), 'l i j -> (l bhw) i j', bhw=b*h*w)

    if self.only_self_att:
        # Original note: with no context given, cross-attention defaults to self-attention.
        for block in self.transformer_blocks:
            hid = block(hid, context=context, mask=mask)
        hid = rearrange(hid, '(b hw) t c -> b hw t c', b=b).contiguous()
    else:
        hid = rearrange(hid, '(b hw) t c -> b hw t c', b=b).contiguous()
        ctx_frames = rearrange(context, '(b t) l con -> b t l con', t=t).contiguous()
        for block in self.transformer_blocks:
            # Process one sample at a time: some kernels reject dimensions above 65,535.
            for j in range(b):
                first_frame_ctx = ctx_frames[j][0:1]
                context_j = repeat(first_frame_ctx, 't l con -> (t r) l con', r=(h * w)).contiguous()
                # The frame mask is not applied on this cross-attention path.
                hid[j] = block(hid[j], context=context_j)

    if self.use_linear:
        hid = self.proj_out(hid)
        hid = rearrange(hid, 'b (h w) t c -> b c t h w', h=h, w=w).contiguous()
    else:
        hid = rearrange(hid, 'b hw t c -> (b hw) c t').contiguous()
        hid = self.proj_out(hid)
        hid = rearrange(hid, '(b h w) c t -> b c t h w', b=b, h=h, w=w).contiguous()

    if self.use_image_dataset:
        return 0.0 * hid + residual
    return hid + residual
def selfattn_forward_unet(self, x, timesteps, context=None, y=None, features_adapter=None, is_imgbatch=False, T=None, **kwargs):
    """Video UNet forward that can carry a camera-pose embedding to its blocks.

    The 5-D input is folded into per-frame maps ``(b t) c h w``, passed through
    the input, middle and output blocks (with skip connections between input
    and output blocks), and unfolded back to 5-D.

    Args:
        x: input unpacked as ``(b, c, t, h, w)``.
        timesteps: diffusion timesteps, embedded via ``timestep_embedding``.
        context: conditioning; unless ``is_imgbatch`` it is repeated ``t`` times
            along dim 0 so every frame gets its own copy.
        y: optional micro-condition values, used only when ``self.micro_condition``.
        features_adapter: optional list of 5-D adapter features; one is added
            after every third input block.
        is_imgbatch: forwarded to every block; also skips repeating ``context``.
        T: unused; kept for interface compatibility.
        **kwargs: an optional ``pose_emb`` entry is popped and packed together
            with the context as ``{'context': ..., 'pose_emb': ...}``.

    Returns:
        Output of ``self.out`` rearranged back to ``(b, c, t, h, w)``.

    Raises:
        ValueError: if ``features_adapter`` has more entries than adapter slots
            (too few entries surface as an IndexError inside the loop).
    """
    b, _, t, _, _ = x.shape
    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
    emb = self.time_embed(t_emb)
    if self.micro_condition and y is not None:
        micro_emb = timestep_embedding(y, self.model_channels, repeat_only=False)
        emb = emb + self.micro_embed(micro_emb)

    ## repeat t times for context [(b t) 77 768] & time embedding
    if not is_imgbatch:
        context = context.repeat_interleave(repeats=t, dim=0)
    # NOTE(review): this file's indentation was lost; the pose_emb packing is
    # reconstructed here as independent of is_imgbatch — confirm against upstream.
    if 'pose_emb' in kwargs:
        pose_emb = kwargs.pop('pose_emb')
        context = {'context': context, 'pose_emb': pose_emb}
    emb = emb.repeat_interleave(repeats=t, dim=0)

    ## always in shape (b t) c h w, except for temporal layer
    x = rearrange(x, 'b c t h w -> (b t) c h w')
    if features_adapter is not None:
        features_adapter = [rearrange(feature, 'b c t h w -> (b t) c h w') for feature in features_adapter]

    h = x.type(self.dtype)
    adapter_idx = 0
    hs = []
    # `block_idx` rather than `id` to avoid shadowing the builtin.
    for block_idx, module in enumerate(self.input_blocks):
        h = module(h, emb, context=context, batch_size=b, is_imgbatch=is_imgbatch)
        if block_idx == 0 and self.addition_attention:
            h = self.init_attn(h, emb, context=context, batch_size=b, is_imgbatch=is_imgbatch)
        ## plug-in adapter features after every third input block
        if (block_idx + 1) % 3 == 0 and features_adapter is not None:
            h = h + features_adapter[adapter_idx]
            adapter_idx += 1
        hs.append(h)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if features_adapter is not None and len(features_adapter) != adapter_idx:
        raise ValueError('Wrong features_adapter')

    h = self.middle_block(h, emb, context=context, batch_size=b, is_imgbatch=is_imgbatch)
    for module in self.output_blocks:
        # Skip connection: concatenate the matching input-block activation (LIFO).
        h = torch.cat([h, hs.pop()], dim=1)
        h = module(h, emb, context=context, batch_size=b, is_imgbatch=is_imgbatch)
    h = h.type(x.dtype)
    y = self.out(h)

    # reshape back to (b c t h w)
    y = rearrange(y, '(b t) c h w -> b c t h w', b=b)
    return y
def spatial_forward_BasicTransformerBlock(self, x, context=None, mask=None):
    """Spatial BasicTransformerBlock forward that accepts a packed dict context.

    If ``context`` is a dict (e.g. ``{'context': ..., 'pose_emb': ...}``), only
    its ``'context'`` entry is used. The sub-layers then run in order
    attn1 -> attn2 -> ff, each on a normalized input and added back as a
    residual. attn1 sees the conditioning only when ``self.disable_self_attn``
    is set; attn2 always receives it.
    """
    cond = context['context'] if isinstance(context, dict) else context
    hidden = x
    hidden = self.attn1(self.norm1(hidden), context=cond if self.disable_self_attn else None, mask=mask) + hidden
    hidden = self.attn2(self.norm2(hidden), context=cond, mask=mask) + hidden
    return self.ff(self.norm3(hidden)) + hidden
def temporal_selfattn_forward_BasicTransformerBlock(self, x, context=None, mask=None):
    """Temporal BasicTransformerBlock forward with optional camera-pose injection.

    The incoming ``context`` is never used as attention conditioning: both
    attention layers always receive ``context=None``. If ``context`` is a dict
    with a ``'pose_emb'`` entry, that tensor is flattened per frame, repeated
    for every row group of ``x`` and concatenated to the features after attn1;
    ``self.cc_projection`` maps the concatenation back.

    Args:
        x: features concatenated with the pose along the last dim; presumably
            ``(B * hw, t, c)`` with rows grouped per sample — TODO confirm.
        context: only a dict containing ``'pose_emb'`` has any effect. The
            pose tensor is read as ``(B, t, ...)`` (originally documented as
            ``[B, video_length, pose_dim, pose_embedding_dim]``); trailing
            dims are flattened.
        mask: forwarded to both attention layers.

    Raises:
        ValueError: if ``x.shape[0]`` is not a multiple of the pose batch size.
    """
    pose_emb = None
    if isinstance(context, dict) and 'pose_emb' in context:
        pose_emb = context['pose_emb']

    # Temporal layers attend over frames only; the conditioning is dropped.
    x = self.attn1(self.norm1(x), context=None, mask=mask) + x

    # Add camera pose
    if pose_emb is not None:
        n_batch, n_frames = pose_emb.shape[0], pose_emb.shape[1]
        rows = x.shape[0]
        if rows % n_batch != 0:
            raise ValueError(
                f'x has {rows} rows, which is not a multiple of the pose_emb batch size {n_batch}'
            )
        hw = rows // n_batch
        # repeat_interleave puts sample k's pose on rows [k*hw, (k+1)*hw).
        pose_tokens = pose_emb.reshape(n_batch, n_frames, -1).repeat_interleave(repeats=hw, dim=0)
        x = self.cc_projection(torch.cat([x, pose_tokens], dim=-1))

    x = self.attn2(self.norm2(x), context=None, mask=mask) + x
    x = self.ff(self.norm3(x)) + x
    return x