import logging

import torch
from einops import rearrange, repeat

from lvdm.models.utils_diffusion import timestep_embedding

try:
    import xformers
    import xformers.ops
    XFORMERS_IS_AVAILBLE = True
except Exception:
    # xformers is optional. Catch Exception rather than using a bare
    # ``except`` so KeyboardInterrupt / SystemExit still propagate.
    XFORMERS_IS_AVAILBLE = False

mainlogger = logging.getLogger('mainlogger')


def TemporalTransformer_forward(self, x, context=None, is_imgbatch=False):
    """Replacement ``forward`` for a temporal transformer module.

    Attention runs along the ``t`` axis of ``x``: every spatial location is
    folded into the batch dimension before the transformer blocks are applied.

    Args:
        x: Tensor unpacked as ``(b, c, t, h, w)``.
        context: Conditioning forwarded to the blocks. In the cross-attention
            branch (``self.only_self_att`` false) it is rearranged with the
            pattern ``'(b t) l con'`` and therefore must be a tensor.
        is_imgbatch: When true, an identity mask replaces any causal mask so
            each frame attends only to itself.

    Returns:
        Tensor laid out as ``(b, c, t, h, w)`` with the residual ``x_in``
        added.
    """
    b, _, t, h, w = x.shape
    x_in = x
    x = self.norm(x)
    x = rearrange(x, 'b c t h w -> (b h w) c t').contiguous()
    # proj_in is applied channels-first when it is not linear, and
    # channels-last when it is linear.
    if not self.use_linear:
        x = self.proj_in(x)
    x = rearrange(x, 'bhw c t -> bhw t c').contiguous()
    if self.use_linear:
        x = self.proj_in(x)

    # Build the temporal mask directly on the target device instead of
    # allocating on CPU and copying on every call.
    temp_mask = None
    if self.causal_attention:
        temp_mask = torch.tril(torch.ones([1, t, t], device=x.device))
    if is_imgbatch:
        # Takes precedence over the causal mask.
        temp_mask = torch.eye(t, device=x.device).unsqueeze(0)

    mask = None
    if temp_mask is not None:
        mask = repeat(temp_mask, 'l i j -> (l bhw) i j', bhw=b * h * w)

    if self.only_self_att:
        ## note: if no context is given, cross-attention defaults to self-attention
        for block in self.transformer_blocks:
            x = block(x, context=context, mask=mask)
        x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous()
    else:
        x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous()
        context = rearrange(context, '(b t) l con -> b t l con', t=t).contiguous()
        for block in self.transformer_blocks:
            # Process one batch element at a time (some packages cannot
            # handle a shape dimension greater than 65,535).
            for j in range(b):
                # Only the first entry along the t axis of the context is
                # used; it is repeated for every spatial location.
                unit_context = context[j][0:1]
                context_j = repeat(unit_context, 't l con -> (t r) l con', r=(h * w)).contiguous()
                ## note: the causal mask is not applied in the cross-attention case
                x[j] = block(x[j], context=context_j)

    if self.use_linear:
        x = self.proj_out(x)
        x = rearrange(x, 'b (h w) t c -> b c t h w', h=h, w=w).contiguous()
    if not self.use_linear:
        x = rearrange(x, 'b hw t c -> (b hw) c t').contiguous()
        x = self.proj_out(x)
        x = rearrange(x, '(b h w) c t -> b c t h w', b=b, h=h, w=w).contiguous()

    if self.use_image_dataset:
        # The temporal branch is multiplied by zero, so the output equals x_in.
        # NOTE(review): presumably this keeps the temporal parameters in the
        # autograd graph while training on images - confirm.
        x = 0.0 * x + x_in
    else:
        x = x + x_in
    return x


def selfattn_forward_unet(self, x, timesteps, context=None, y=None, features_adapter=None, is_imgbatch=False, T=None, **kwargs):
    """Replacement ``forward`` for the UNet that can carry ``pose_emb`` in the context.

    Args:
        x: Tensor unpacked as ``(b, c, t, h, w)``.
        timesteps: Passed to ``timestep_embedding`` to build the time embedding.
        context: Conditioning tensor. Unless ``is_imgbatch`` is set it is
            repeated ``t`` times along dim 0. ``None`` is passed through
            unchanged.
        y: Optional micro-condition; embedded and added to the time embedding
            when ``self.micro_condition`` is set.
        features_adapter: Optional list of tensors in ``'b c t h w'`` layout;
            one is added to the hidden state after every third input block.
        is_imgbatch: Forwarded to every block; also disables the context
            repeat.
        T: Not used by this function.
        **kwargs: If it contains ``'pose_emb'``, the value is popped and the
            context handed to the blocks becomes the dict
            ``{'context': context, 'pose_emb': pose_emb}``.

    Returns:
        Tensor rearranged back to ``(b, c, t, h, w)``.
    """
    b, _, t, _, _ = x.shape
    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
    emb = self.time_embed(t_emb)
    if self.micro_condition and y is not None:
        micro_emb = timestep_embedding(y, self.model_channels, repeat_only=False)
        emb = emb + self.micro_embed(micro_emb)

    ## repeat t times for context [(b t) 77 768] & time embedding
    # ``context`` defaults to None, so guard before dereferencing it.
    if not is_imgbatch and context is not None:
        context = context.repeat_interleave(repeats=t, dim=0)
    if 'pose_emb' in kwargs:
        pose_emb = kwargs.pop('pose_emb')
        context = {
            'context': context,
            'pose_emb': pose_emb,
        }
    emb = emb.repeat_interleave(repeats=t, dim=0)

    ## always in shape (b t) c h w, except for temporal layer
    x = rearrange(x, 'b c t h w -> (b t) c h w')
    if features_adapter is not None:
        features_adapter = [rearrange(feature, 'b c t h w -> (b t) c h w') for feature in features_adapter]

    h = x.type(self.dtype)
    adapter_idx = 0
    hs = []
    for block_idx, module in enumerate(self.input_blocks):
        h = module(h, emb, context=context, batch_size=b, is_imgbatch=is_imgbatch)
        if block_idx == 0 and self.addition_attention:
            h = self.init_attn(h, emb, context=context, batch_size=b, is_imgbatch=is_imgbatch)
        ## plug-in adapter features
        if (block_idx + 1) % 3 == 0 and features_adapter is not None:
            h = h + features_adapter[adapter_idx]
            adapter_idx += 1
        hs.append(h)
    if features_adapter is not None:
        # Every supplied adapter feature must have been consumed.
        assert len(features_adapter) == adapter_idx, 'Wrong features_adapter'

    h = self.middle_block(h, emb, context=context, batch_size=b, is_imgbatch=is_imgbatch)
    for module in self.output_blocks:
        # Skip connection: concatenate the matching input-block activation.
        h = torch.cat([h, hs.pop()], dim=1)
        h = module(h, emb, context=context, batch_size=b, is_imgbatch=is_imgbatch)
    h = h.type(x.dtype)
    y = self.out(h)

    # reshape back to (b c t h w)
    y = rearrange(y, '(b t) c h w -> b c t h w', b=b)
    return y


def spatial_forward_BasicTransformerBlock(self, x, context=None, mask=None):
    """Replacement ``forward`` for a spatial transformer block.

    Accepts either a plain context or a dict context; for a dict, only the
    ``'context'`` entry is used and any other entry is ignored here.

    Args:
        x: Input to the block.
        context: Conditioning, or a dict holding it under ``'context'``.
        mask: Forwarded to both attention layers.

    Returns:
        ``x`` after attn1, attn2 and the feed-forward layer, each with a
        residual connection.
    """
    if isinstance(context, dict):
        context = context['context']
    # attn1 only sees the context when self-attention is disabled.
    x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None, mask=mask) + x
    x = self.attn2(self.norm2(x), context=context, mask=mask) + x
    x = self.ff(self.norm3(x)) + x
    return x


def temporal_selfattn_forward_BasicTransformerBlock(self, x, context=None, mask=None):
    """Replacement ``forward`` for a temporal block that injects ``pose_emb``.

    The incoming context is never used for attention: both attention layers
    are called with ``context=None``. If ``context`` is a dict containing
    ``'pose_emb'``, that embedding is concatenated to the hidden state on the
    last dimension and projected by ``self.cc_projection`` between the two
    attention layers.

    Args:
        x: Input to the block.
            NOTE(review): presumably ``(B * hw, t, c)`` given how ``pose_emb``
            is expanded below - confirm against the caller.
        context: Optional dict; only its ``'pose_emb'`` entry is read.
        mask: Forwarded to both attention layers.

    Returns:
        ``x`` after attn1, the optional pose projection, attn2 and the
        feed-forward layer.
    """
    if isinstance(context, dict) and 'pose_emb' in context:
        pose_emb = context['pose_emb']  # {channel_num: [B, video_length, pose_dim, pose_embedding_dim]}
    else:
        pose_emb = None

    # The context is always discarded, so attn1 and attn2 both run without one.
    x = self.attn1(self.norm1(x), context=None, mask=mask) + x

    # Add camera pose
    if pose_emb is not None:
        B, t, _, _ = pose_emb.shape  # [B, video_length, pose_dim, pose_embedding_dim]
        hw = x.shape[0] // B
        # Flatten the two trailing pose dimensions, then repeat each batch
        # entry hw times so dim 0 matches x.
        pose_emb = pose_emb.reshape(B, t, -1)
        pose_emb = pose_emb.repeat_interleave(repeats=hw, dim=0)
        # No residual here: the projection output replaces x.
        x = self.cc_projection(torch.cat([x, pose_emb], dim=-1))

    x = self.attn2(self.norm2(x), context=None, mask=mask) + x
    x = self.ff(self.norm3(x)) + x
    return x