| | import torch |
| | from torch import nn |
| | import einops |
| | from typing import Tuple |
| | import random |
| | import numpy as np |
| | from tqdm import tqdm |
| | from .modules import DuoFrameDownEncoder,Upsampler,MapConv,MotionDownEncoder |
| | from .loss import l1,l2 |
| | from .transformer import (MotionTransformer, |
| | AMDDiffusionTransformerModel, |
| | MotionEncoderLearnTokenTransformer, |
| | AMDReconstructTransformerModel, |
| | AMDDiffusionTransformerModelDualStream, |
| | AMDDiffusionTransformerModelImgSpatial, |
| | AMDDiffusionTransformerModelImgSpatialDoubleRef, |
| | AMDReconstructTransformerModelSpatial) |
| | from .rectified_flow import RectifiedFlow |
| | from diffusers.configuration_utils import ConfigMixin, register_to_config |
| | from diffusers.models.modeling_utils import ModelMixin |
| | from diffusers.models.resnet import ResnetBlock2D |
| | import einops |
| | import torch.nn.functional as F |
| |
|
| | from diffusers.utils import export_to_gif |
| |
|
class AMDModel(ModelMixin, ConfigMixin):
    """Motion-conditioned rectified-flow model over (reference, video) latents.

    A ``MotionEncoderLearnTokenTransformer`` turns the concatenation of
    reference frames and video frames into per-frame motion tokens. A
    diffusion transformer, selected by ``diffusion_model_type``, is trained
    with an L2 loss to predict the velocity returned by
    ``RectifiedFlow.get_train_tuple``. Its conditioning is:

    * the reference latents,
    * the source motion tokens (from the reference half),
    * the target motion tokens (from the video half).
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(self,
                 image_inchannel: int = 4,
                 image_height: int = 32,
                 image_width: int = 32,
                 video_frames: int = 16,
                 scheduler_num_step: int = 1000,

                 # motion encoder
                 motion_token_num: int = 12,
                 motion_token_channel: int = 128,
                 enc_num_layers: int = 8,
                 enc_nhead: int = 8,
                 enc_ndim: int = 64,
                 enc_dropout: float = 0.0,
                 motion_need_norm_out: bool = False,

                 # optional motion transformer over target motion tokens
                 need_motion_transformer: bool = False,
                 motion_transformer_attn_head_dim: int = 64,
                 motion_transformer_attn_num_heads: int = 16,
                 motion_transformer_num_layers: int = 4,

                 # diffusion transformer
                 diffusion_model_type: str = 'default',
                 diffusion_attn_head_dim: int = 64,
                 diffusion_attn_num_heads: int = 16,
                 diffusion_out_channels: int = 4,
                 diffusion_num_layers: int = 16,
                 image_patch_size: int = 2,
                 motion_patch_size: int = 1,
                 motion_drop_ratio: float = 0.0,
                 refimg_drop: bool = False,

                 extract_motion_with_motion_transformer=False,
                 **kwargs,
                 ):
        """Build the encoder, the optional motion transformer and the DiT.

        ``diffusion_model_type`` selects the diffusion transformer:
        ``'default'``, ``'dual'``, ``'spatial'`` or ``'doubleref'``. Any other
        value raises ``IndexError``.

        ``motion_patch_size`` and ``motion_drop_ratio`` are accepted, and
        recorded in the config by ``register_to_config``, but are not used by
        this constructor.
        """
        super().__init__()

        self.num_step = scheduler_num_step
        self.scheduler = RectifiedFlow(num_steps=scheduler_num_step)
        self.need_motion_transformer = need_motion_transformer
        self.extract_motion_with_motion_transformer = extract_motion_with_motion_transformer
        self.diffusion_model_type = diffusion_model_type
        self.target_frame = video_frames
        # When True, the reference latents given to the DiT are zeroed out.
        self.refimg_drop = refimg_drop

        self.motion_encoder = MotionEncoderLearnTokenTransformer(img_height=image_height,
                                                                 img_width=image_width,
                                                                 img_inchannel=image_inchannel,
                                                                 img_patch_size=image_patch_size,
                                                                 motion_token_num=motion_token_num,
                                                                 motion_channel=motion_token_channel,
                                                                 need_norm_out=motion_need_norm_out,
                                                                 num_attention_heads=enc_nhead,
                                                                 attention_head_dim=enc_ndim,
                                                                 num_layers=enc_num_layers,
                                                                 dropout=enc_dropout,
                                                                 attention_bias=True,)

        if need_motion_transformer:
            self.motion_transformer = MotionTransformer(motion_token_num=motion_token_num,
                                                        motion_token_channel=motion_token_channel,
                                                        attention_head_dim=motion_transformer_attn_head_dim,
                                                        num_attention_heads=motion_transformer_attn_num_heads,
                                                        num_layers=motion_transformer_num_layers,)

        # Keyword arguments shared by every diffusion-transformer variant.
        dit_kwargs = dict(num_attention_heads=diffusion_attn_num_heads,
                          attention_head_dim=diffusion_attn_head_dim,
                          out_channels=diffusion_out_channels,
                          num_layers=diffusion_num_layers,
                          image_width=image_width,
                          image_height=image_height,
                          image_patch_size=image_patch_size,
                          motion_token_num=motion_token_num,
                          motion_in_channels=motion_token_channel,)

        # The image input is the channel-wise concat (ref latents, noisy latents),
        # hence the doubled channel count for every variant except 'doubleref'.
        if diffusion_model_type == 'default':
            self.diffusion_transformer = AMDDiffusionTransformerModel(
                image_in_channels=image_inchannel * 2, **dit_kwargs)
        elif diffusion_model_type == 'dual':
            self.diffusion_transformer = AMDDiffusionTransformerModelDualStream(
                image_in_channels=image_inchannel * 2,
                motion_target_num_frame=video_frames, **dit_kwargs)
        elif diffusion_model_type == 'spatial':
            self.diffusion_transformer = AMDDiffusionTransformerModelImgSpatial(
                image_in_channels=image_inchannel * 2,
                motion_target_num_frame=video_frames, **dit_kwargs)
        elif diffusion_model_type == 'doubleref':
            self.diffusion_transformer = AMDDiffusionTransformerModelImgSpatialDoubleRef(
                image_in_channels=image_inchannel,
                motion_target_num_frame=video_frames, **dit_kwargs)
        else:
            raise IndexError(
                f"unknown diffusion_model_type {diffusion_model_type!r}; "
                "expected one of 'default', 'dual', 'spatial', 'doubleref'")

    def forward(self,
                video: torch.Tensor,
                ref_img: torch.Tensor,
                randomref_img: torch.Tensor = None,
                time_step: torch.Tensor = None,
                return_meta_info=False,
                mask_ratio=None,
                **kwargs,):
        """Compute the rectified-flow training loss for one batch.

        Args:
            video: (N,T,C,H,W) target latents.
            ref_img: (N,T,C,H,W) reference latents; must have video's shape.
            randomref_img: (N,C,H,W) or (N,T,C,H,W). Required when
                diffusion_model_type == 'doubleref', where it replaces
                ref_img as the motion-encoder reference.
            time_step: optional timesteps. When None they are drawn with
                ``prepare_timestep``:
                * 'default': one per frame;
                * other types: one per video, repeated over T.
            return_meta_info: if True, return a dict of intermediates instead
                of (pre, vel, loss_dict).
            mask_ratio: optional upper bound. The encoder receives
                ``rand() * mask_ratio``.

        Returns:
            (pre, vel, loss_dict), or a dict with keys
            motion/zi/zj/zt/gt/pre/time_step.
        """
        device = video.device
        n, t, c, h, w = video.shape

        assert video.shape == ref_img.shape, f'video.shape:{video.shape}should be equal to ref_img.shape:{ref_img.shape}'
        if self.diffusion_model_type == 'doubleref':
            assert randomref_img is not None, "when diffusion_model_type == doubleref, randomref_img should be given"

        if mask_ratio is not None:
            mask_ratio = torch.rand(1).item() * mask_ratio

        use_randomref = self.diffusion_model_type == 'doubleref' and randomref_img is not None
        if use_randomref:
            if randomref_img.dim() == 4:
                randomref_img = randomref_img.unsqueeze(1).repeat(1, t, 1, 1, 1)
            refimg_and_video = torch.cat([randomref_img, video], dim=1)
        else:
            refimg_and_video = torch.cat([ref_img, video], dim=1)
        motion = self.motion_encoder(refimg_and_video, mask_ratio)

        # The first T frames of motion come from the reference, the last T from the video.
        source_motion = motion[:, :t].flatten(0, 1)
        target_motion = motion[:, t:].flatten(0, 1)
        assert source_motion.shape == target_motion.shape, f'source_motion.shape {source_motion.shape} != target_motion.shape {target_motion.shape}'

        if self.need_motion_transformer:
            target_motion = einops.rearrange(target_motion, '(n f) l d -> n f l d', n=n)
            target_motion = self.motion_transformer(target_motion)
            target_motion = einops.rearrange(target_motion, 'n f l d -> (n f) l d', n=n)

        zi = ref_img.flatten(0, 1)
        zj = video.flatten(0, 1)
        if use_randomref:
            randomref_img = randomref_img.flatten(0, 1)

        if time_step is None:
            if self.diffusion_model_type == 'default':
                time_step = self.prepare_timestep(batch_size=zj.shape[0], device=device)
            else:
                # Non-default variants share one timestep across all frames of a video.
                time_step = self.prepare_timestep(batch_size=n, device=device).repeat_interleave(t)
        zt, vel = self.scheduler.get_train_tuple(z1=zj, time_step=time_step)

        if self.refimg_drop:
            zi = torch.zeros_like(zi)
        image_hidden_states = torch.cat((zi, zt), dim=1)

        pre = self.diffusion_transformer(motion_source_hidden_states=source_motion,
                                         motion_target_hidden_states=target_motion,
                                         image_hidden_states=image_hidden_states,
                                         randomref_image_hidden_states=randomref_img,
                                         timestep=time_step,)

        diff_loss = l2(pre, vel)

        # rec_loss is reported only; the optimized loss is diff_loss.
        rec_zj = self.scheduler.get_target_with_zt_vel(zt, pre, time_step)
        rec_loss = l2(rec_zj, zj)

        loss = diff_loss
        loss_dict = {'loss': loss, 'diff_loss': diff_loss, 'rec_loss': rec_loss}

        if return_meta_info:
            return {'motion': motion,
                    'zi': zi,
                    'zj': zj,
                    'zt': zt,
                    'gt': vel,
                    'pre': pre,
                    'time_step': time_step,
                    }
        return pre, vel, loss_dict

    def get_noise_latent_pair(self,
                              video: torch.Tensor,
                              ref_img: torch.Tensor,
                              randomref_img: torch.Tensor,
                              sample_step: int = 50,
                              ):
        """Placeholder: not implemented; always returns None."""
        pass

    @torch.no_grad()
    def sample(self, video: torch.Tensor,
               ref_img: torch.Tensor,
               randomref_img: torch.Tensor = None,
               sample_step: int = 50,
               mask_ratio=None,
               start_step: int = None,
               return_meta_info=False,
               **kwargs,):
        """Re-generate ``video`` with motion extracted from the video itself.

        The latents are noised to ``start_step`` with
        ``scheduler.get_train_tuple``. The predicted velocity is then
        integrated with ``sample_step`` Euler steps over descending timesteps.

        Args:
            video: (N,T,C,H,W) latents.
            ref_img: (N,T,C,H,W), or (N,C,H,W) which is repeated over T.
            randomref_img: required for 'doubleref'; (N,C,H,W) or (N,T,C,H,W).
            sample_step: number of Euler steps.
            mask_ratio: passed unchanged to the motion encoder.
            start_step: starting timestep, default ``scheduler.num_step``.
            return_meta_info: return a dict including per-step caches.

        Returns:
            (zi, sample, zj), each (N,T,C,H,W), or a dict.
        """
        device = video.device
        n, t, c, h, w = video.shape

        if start_step is None:
            start_step = self.scheduler.num_step
        assert start_step <= self.scheduler.num_step, 'start_step cant be larger than scheduler.num_step'

        if self.diffusion_model_type == 'doubleref':
            assert randomref_img is not None, "when diffusion_model_type == doubleref, randomref_img should be given"

        if ref_img.dim() == 4:
            ref_img = ref_img.unsqueeze(1).repeat(1, t, 1, 1, 1)

        if mask_ratio is not None:
            print(f'* Sampling with Mask_Ratio = {mask_ratio}')

        use_randomref = self.diffusion_model_type == 'doubleref' and randomref_img is not None
        if use_randomref:
            if randomref_img.dim() == 4:
                randomref_img = randomref_img.unsqueeze(1).repeat(1, t, 1, 1, 1)
            refimg_and_video = torch.cat([randomref_img, video], dim=1)
        else:
            refimg_and_video = torch.cat([ref_img, video], dim=1)

        motion = self.motion_encoder(refimg_and_video, mask_ratio)

        source_motion = motion[:, :t].flatten(0, 1)
        target_motion = motion[:, t:].flatten(0, 1)
        assert source_motion.shape == target_motion.shape, f'source_motion.shape {source_motion.shape} != target_motion.shape {target_motion.shape}'

        if self.need_motion_transformer:
            target_motion = einops.rearrange(target_motion, '(n f) l d -> n f l d', n=n)
            target_motion = self.motion_transformer(target_motion)
            target_motion = einops.rearrange(target_motion, 'n f l d -> (n f) l d', n=n)

        time_step = torch.ones((source_motion.shape[0],), device=device) * start_step

        zi = ref_img.flatten(0, 1)
        zj = video.flatten(0, 1)
        if use_randomref:
            randomref_img = randomref_img.flatten(0, 1)
        zt, vel = self.scheduler.get_train_tuple(z1=zj, time_step=time_step)
        noise = zj - vel

        # Descending timesteps start_step, ..., start_step / sample_step.
        step_seq = np.linspace(0, start_step, num=sample_step + 1, endpoint=True, dtype=int)
        step_seq = list(reversed(step_seq[1:]))

        # The Euler steps must cover only the remaining fraction
        # start_step / num_step of flow time. This assumes the scheduler maps
        # step k linearly to time k / num_step (TODO confirm in RectifiedFlow).
        # With the default start_step == num_step this equals 1 / sample_step.
        dt = start_step / (self.scheduler.num_step * sample_step)

        if self.refimg_drop:
            zi = torch.zeros_like(zi)

        pre_cache = []
        sample_cache = []
        for i in tqdm(step_seq):
            time_step = torch.ones((zt.shape[0],), device=zt.device) * i

            zt = zt.to(video.dtype)
            image_hidden_states = torch.cat((zi, zt), dim=1)

            pre = self.diffusion_transformer(motion_source_hidden_states=source_motion,
                                             motion_target_hidden_states=target_motion,
                                             image_hidden_states=image_hidden_states,
                                             randomref_image_hidden_states=randomref_img,
                                             timestep=time_step,)
            zt = zt + pre * dt
            # Per-step tensors are only kept when the caller asks for them.
            if return_meta_info:
                pre_cache.append(pre)
                sample_cache.append(zt)

        zi = einops.rearrange(zi, '(n t) c h w -> n t c h w', n=n)
        zt = einops.rearrange(zt, '(n t) c h w -> n t c h w', n=n)
        zj = einops.rearrange(zj, '(n t) c h w -> n t c h w', n=n)

        if return_meta_info:
            return {'zi': zi,
                    'zj': zj,
                    'sample': zt,
                    'pre_cache': pre_cache,
                    'sample_cache': sample_cache,
                    'step_seq': step_seq,
                    'motion': target_motion,
                    "noise": noise
                    }
        return zi, zt, zj

    @torch.no_grad()
    def sample_with_refimg_motion(self,
                                  ref_img: torch.Tensor,
                                  motion: torch.Tensor = None,
                                  randomref_img: torch.Tensor = None,
                                  sample_step: int = 10,
                                  mask_ratio=None,
                                  return_meta_info=False,
                                  **kwargs,):
        """Generate a video from a reference image and given motion tokens.

        Sampling starts from Gaussian noise at ``scheduler.num_step`` and runs
        ``sample_step`` Euler steps.

        Args:
            ref_img: (N,C,H,W).
            motion: (N,F,L,D) target motion tokens. Required; the previous
                default was the ``torch.Tensor`` class, which was never a
                usable value.
            randomref_img: (N,C,H,W); required for 'doubleref'.
            sample_step: number of Euler steps.
            mask_ratio: passed unchanged to the motion encoder.
            return_meta_info: return a dict including per-step caches.

        Returns:
            (zi, sample, zj). zi and sample are (N,F,C,H,W); zj is the
            reference repeated F times, left flattened as (N*F,C,H,W).

        Raises:
            ValueError: if ``motion`` is not given.
        """
        if motion is None:
            raise ValueError('motion must be provided as a (N,F,L,D) tensor')
        n, t, l, d = motion.shape

        start_step = self.scheduler.num_step

        refimg = ref_img.unsqueeze(1)
        if self.diffusion_model_type == 'doubleref':
            assert randomref_img is not None, "when diffusion_model_type == doubleref, randomref_img should be given"

        use_randomref = self.diffusion_model_type == 'doubleref' and randomref_img is not None
        if use_randomref:
            print('* Warnning * diffusion_model_type:doubleref')
            if randomref_img.dim() == 4:
                randomref_img = randomref_img.unsqueeze(1)
            source_motion = self.motion_encoder(randomref_img, mask_ratio)
        else:
            source_motion = self.motion_encoder(refimg, mask_ratio)

        # The single-frame source motion is shared by every target frame.
        source_motion = source_motion.repeat(1, t, 1, 1).flatten(0, 1)
        target_motion = motion.flatten(0, 1)
        assert source_motion.shape == target_motion.shape, f'source_motion.shape {source_motion.shape} != target_motion.shape {target_motion.shape}'

        # When extract_motion already applied the motion transformer, it is
        # not applied a second time here.
        if self.need_motion_transformer and not self.extract_motion_with_motion_transformer:
            target_motion = einops.rearrange(target_motion, '(n f) l d -> n f l d', n=n)
            target_motion = self.motion_transformer(target_motion)
            target_motion = einops.rearrange(target_motion, 'n f l d -> (n f) l d', n=n)

        zi = refimg.repeat(1, t, 1, 1, 1).flatten(0, 1)
        zj = zi
        if use_randomref:
            randomref_img = randomref_img.repeat(1, t, 1, 1, 1).flatten(0, 1)

        zt = torch.randn_like(zj)

        step_seq = np.linspace(0, start_step, num=sample_step + 1, endpoint=True, dtype=int)
        step_seq = list(reversed(step_seq[1:]))
        # start_step == num_step here, so the steps span the full unit of flow time.
        dt = 1. / sample_step

        if self.refimg_drop:
            zi = torch.zeros_like(zi)

        pre_cache = []
        sample_cache = []
        for i in tqdm(step_seq):
            time_step = torch.ones((zt.shape[0],), device=zt.device) * i

            zt = zt.to(ref_img.dtype)
            image_hidden_states = torch.cat((zi, zt), dim=1)

            pre = self.diffusion_transformer(motion_source_hidden_states=source_motion,
                                             motion_target_hidden_states=target_motion,
                                             image_hidden_states=image_hidden_states,
                                             randomref_image_hidden_states=randomref_img,
                                             timestep=time_step,)
            zt = zt + pre * dt
            # These caches used to be returned empty; they are now filled on request.
            if return_meta_info:
                pre_cache.append(pre)
                sample_cache.append(zt)

        zi = einops.rearrange(zi, '(n t) c h w -> n t c h w', n=n, t=t)
        zt = einops.rearrange(zt, '(n t) c h w -> n t c h w', n=n, t=t)
        # NOTE(review): zj stays flattened (N*F,C,H,W), unlike zi/zt; kept for
        # backward compatibility with existing callers.

        if return_meta_info:
            return {'zi': zi,
                    'zj': zj,
                    'sample': zt,
                    'pre_cache': pre_cache,
                    'sample_cache': sample_cache,
                    'step_seq': step_seq,
                    'motion': target_motion,
                    }
        return zi, zt, zj

    def extract_motion(self, video: torch.Tensor, mask_ratio=None):
        """Encode a (N,T,C,H,W) video into motion tokens.

        When both ``need_motion_transformer`` and
        ``extract_motion_with_motion_transformer`` are set, the motion
        transformer is also applied.

        Raises:
            ValueError: if ``video`` is not 5-D.
        """
        if video.dim() != 5:
            raise ValueError(f'extract_motion expects a 5-D (N,T,C,H,W) tensor, got {tuple(video.shape)}')

        motion = self.motion_encoder(video, mask_ratio)

        if self.need_motion_transformer and self.extract_motion_with_motion_transformer:
            motion = self.motion_transformer(motion)

        return motion

    def prepare_timestep(self, batch_size: int, device, time_step=None):
        """Return ``time_step`` moved to ``device``, or random integer timesteps.

        The random timesteps are ``batch_size`` values in [0, num_step]
        (inclusive).
        """
        if time_step is not None:
            return time_step.to(device)
        return torch.randint(0, self.num_step + 1, (batch_size,)).to(device)

    def prepare_encoder_input(self, video: torch.Tensor):
        """Pair consecutive frames channel-wise.

        A (B,T,C,H,W) input becomes ((B*(T-1)), 2C, H, W).
        """
        assert len(video.shape) == 5, f'only support video data : 5D tensor , but got {video.shape}'

        pre = video[:, :-1, :, :, :]
        post = video[:, 1:, :, :, :]
        duo_frame_mix = torch.cat([pre, post], dim=2)
        duo_frame_mix = einops.rearrange(duo_frame_mix, 'b t c h w -> (b t) c h w')

        return duo_frame_mix

    def unpatchify(self, x, patch_size):
        """Fold square patch tokens back into images.

        x: (N, S, patch_size**2 * C) with S a perfect square.
        Returns imgs: (N, C, H, W) with H = W = sqrt(S) * patch_size.
        """
        p = patch_size
        # Round to guard against sqrt returning e.g. 15.999999 for S = 256.
        h = w = int(round(x.shape[1] ** .5))
        c = x.shape[2] // (p ** 2)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def reset_infer_num_frame(self, num: int):
        """Set ``diffusion_transformer.target_frame`` to ``num`` and log the change.

        Assumes the selected diffusion transformer exposes ``target_frame``.
        """
        old_num = self.diffusion_transformer.target_frame
        self.diffusion_transformer.target_frame = num
        print(f'* Reset infer frame from {old_num} to {self.diffusion_transformer.target_frame} *')
| |
|
| |
|
class AMDModel_Rec(ModelMixin, ConfigMixin):
    """Deterministic reconstruction counterpart of ``AMDModel``.

    Uses the same motion encoder, but the transformer regresses the target
    latents directly (L2 against the video latents). Its conditioning is:

    * the reference latents,
    * a learned placeholder latent ``zt_token``,
    * source and target motion tokens.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(self,
                 image_inchannel: int = 4,
                 image_height: int = 32,
                 image_width: int = 32,
                 video_frames: int = 16,
                 scheduler_num_step: int = 1000,

                 # motion encoder
                 motion_token_num: int = 12,
                 motion_token_channel: int = 128,
                 enc_num_layers: int = 8,
                 enc_nhead: int = 8,
                 enc_ndim: int = 64,
                 enc_dropout: float = 0.0,
                 motion_need_norm_out: bool = True,

                 # optional motion transformer
                 need_motion_transformer: bool = False,
                 motion_transformer_attn_head_dim: int = 64,
                 motion_transformer_attn_num_heads: int = 16,
                 motion_transformer_num_layers: int = 4,

                 # reconstruction transformer
                 diffusion_model_type: str = 'default',
                 diffusion_attn_head_dim: int = 64,
                 diffusion_attn_num_heads: int = 16,
                 diffusion_out_channels: int = 4,
                 diffusion_num_layers: int = 16,
                 image_patch_size: int = 2,
                 motion_patch_size: int = 1,
                 motion_drop_ratio: float = 0.0,
                 **kwargs,
                 ):
        """Build the encoder, the optional motion transformer and the transformer.

        ``diffusion_model_type`` must be ``'default'`` or ``'spatial'``.
        Anything else raises ``IndexError``; previously it silently left
        ``self.transformer`` undefined.
        """
        super().__init__()

        self.num_step = scheduler_num_step
        self.scheduler = RectifiedFlow(num_steps=scheduler_num_step)
        self.need_motion_transformer = need_motion_transformer

        # Learned latent that stands in for the unknown target frame.
        INIT_CONST = 0.02
        self.zt_token = nn.Parameter(torch.randn(1, image_inchannel, image_height, image_width) * INIT_CONST)

        self.motion_encoder = MotionEncoderLearnTokenTransformer(img_height=image_height,
                                                                 img_width=image_width,
                                                                 img_inchannel=image_inchannel,
                                                                 img_patch_size=image_patch_size,
                                                                 motion_token_num=motion_token_num,
                                                                 motion_channel=motion_token_channel,
                                                                 need_norm_out=motion_need_norm_out,
                                                                 num_attention_heads=enc_nhead,
                                                                 attention_head_dim=enc_ndim,
                                                                 num_layers=enc_num_layers,
                                                                 dropout=enc_dropout,
                                                                 attention_bias=True,)

        if need_motion_transformer:
            self.motion_transformer = MotionTransformer(motion_token_num=motion_token_num,
                                                        motion_token_channel=motion_token_channel,
                                                        attention_head_dim=motion_transformer_attn_head_dim,
                                                        num_attention_heads=motion_transformer_attn_num_heads,
                                                        num_layers=motion_transformer_num_layers,)

        # Input is the channel-wise concat of (ref latents, zt_token).
        dit_kwargs = dict(num_attention_heads=diffusion_attn_num_heads,
                          attention_head_dim=diffusion_attn_head_dim,
                          out_channels=diffusion_out_channels,
                          num_layers=diffusion_num_layers,
                          image_width=image_width,
                          image_height=image_height,
                          image_patch_size=image_patch_size,
                          image_in_channels=image_inchannel * 2,
                          motion_token_num=motion_token_num,
                          motion_in_channels=motion_token_channel,)

        if diffusion_model_type == 'default':
            self.transformer = AMDReconstructTransformerModel(**dit_kwargs)
        elif diffusion_model_type == 'spatial':
            self.transformer = AMDReconstructTransformerModelSpatial(
                motion_target_num_frame=video_frames, **dit_kwargs)
        else:
            raise IndexError(
                f"unknown diffusion_model_type {diffusion_model_type!r}; "
                "expected 'default' or 'spatial'")

    def _encode_motion_pair(self, ref_img, video, n, t):
        """Encode cat(ref_img, video) over time.

        Returns (motion, source_motion, target_motion); the last two are
        flattened to (N*T, L, D).
        """
        refimg_and_video = torch.cat([ref_img, video], dim=1)
        motion = self.motion_encoder(refimg_and_video)

        source_motion = motion[:, :t].flatten(0, 1)
        target_motion = motion[:, t:].flatten(0, 1)
        assert source_motion.shape == target_motion.shape, f'source_motion.shape {source_motion.shape} != target_motion.shape {target_motion.shape}'

        target_motion = self._maybe_motion_transformer(target_motion, n)
        return motion, source_motion, target_motion

    def _maybe_motion_transformer(self, target_motion, n):
        """Apply the motion transformer to (N*F, L, D) tokens when enabled."""
        if self.need_motion_transformer:
            target_motion = einops.rearrange(target_motion, '(n f) l d -> n f l d', n=n)
            target_motion = self.motion_transformer(target_motion)
            target_motion = einops.rearrange(target_motion, 'n f l d -> (n f) l d', n=n)
        return target_motion

    def forward(self,
                video: torch.Tensor,
                ref_img: torch.Tensor,
                time_step: torch.Tensor = None,
                return_meta_info=False,
                **kwargs,):
        """Compute the reconstruction loss for one batch.

        Args:
            video: (N,T,C,H,W) target latents.
            ref_img: (N,T,C,H,W) reference latents; must have video's shape.
            time_step: unused except for being echoed in the meta dict.
            return_meta_info: return intermediates instead of (pre, zj, loss_dict).
        """
        n, t, c, h, w = video.shape

        assert video.shape == ref_img.shape, f'video.shape:{video.shape}should be equal to ref_img.shape:{ref_img.shape}'

        motion, source_motion, target_motion = self._encode_motion_pair(ref_img, video, n, t)

        zi = ref_img.flatten(0, 1)
        zj = video.flatten(0, 1)
        zt = self.zt_token.repeat(zj.shape[0], 1, 1, 1)

        image_hidden_states = torch.cat((zi, zt), dim=1)
        pre = self.transformer(motion_source_hidden_states=source_motion,
                               motion_target_hidden_states=target_motion,
                               image_hidden_states=image_hidden_states,)

        rec_loss = l2(pre, zj)
        loss = rec_loss
        loss_dict = {'loss': loss, 'rec_loss': rec_loss}

        if return_meta_info:
            return {'motion': motion,
                    'zi': zi,
                    'zj': zj,
                    'zt': zt,
                    'pre': pre,
                    'time_step': time_step,
                    }
        return pre, zj, loss_dict

    @torch.no_grad()
    def sample(self,
               video: torch.Tensor,
               ref_img: torch.Tensor,
               sample_step: int = 50,
               start_step: int = None,
               return_meta_info=False,
               **kwargs,):
        """Reconstruct ``video`` in a single transformer pass.

        ``sample_step`` is unused. ``start_step`` is only validated against
        ``scheduler.num_step``; both are kept for API parity with
        ``AMDModel.sample``.

        Returns:
            (zi, reconstruction, zj), each (N,T,C,H,W), or a dict with
            zi/zj/sample.
        """
        n, t, c, h, w = video.shape

        if start_step is None:
            start_step = self.scheduler.num_step
        assert start_step <= self.scheduler.num_step, 'start_step cant be larger than scheduler.num_step'

        _, source_motion, target_motion = self._encode_motion_pair(ref_img, video, n, t)

        zi = ref_img.flatten(0, 1)
        zj = video.flatten(0, 1)
        zt = self.zt_token.repeat(zj.shape[0], 1, 1, 1).to(video.dtype)

        image_hidden_states = torch.cat((zi, zt), dim=1)
        pre = self.transformer(motion_source_hidden_states=source_motion,
                               motion_target_hidden_states=target_motion,
                               image_hidden_states=image_hidden_states,)

        zi = einops.rearrange(zi, '(n t) c h w -> n t c h w', n=n)
        zt = einops.rearrange(pre, '(n t) c h w -> n t c h w', n=n)
        zj = einops.rearrange(zj, '(n t) c h w -> n t c h w', n=n)

        if return_meta_info:
            return {'zi': zi,
                    'zj': zj,
                    'sample': zt,
                    }
        return zi, zt, zj

    def sample_with_refimg_motion(self,
                                  ref_img: torch.Tensor,
                                  motion: torch.Tensor = None,
                                  sample_step: int = 10,
                                  return_meta_info=False,
                                  **kwargs,):
        """Reconstruct frames from a reference image and given motion tokens.

        Args:
            ref_img: (N,C,H,W).
            motion: (N,F,L,D). Required; the previous default was the
                ``torch.Tensor`` class, which was never a usable value.
            sample_step: unused; kept for API parity.
            return_meta_info: return a dict with zi/zj instead of a tuple.

        Returns:
            (zi, reconstruction, zj), each (N,F,C,H,W).

        Raises:
            ValueError: if ``motion`` is not given.
        """
        # NOTE(review): unlike ``sample`` this is not wrapped in torch.no_grad();
        # left as-is in case a caller back-propagates through it.
        if motion is None:
            raise ValueError('motion must be provided as a (N,F,L,D) tensor')
        n, t, l, d = motion.shape

        refimg = ref_img.unsqueeze(1)
        source_motion = self.motion_encoder(refimg)

        source_motion = source_motion.repeat(1, t, 1, 1).flatten(0, 1)
        target_motion = motion.flatten(0, 1)
        assert source_motion.shape == target_motion.shape, f'source_motion.shape {source_motion.shape} != target_motion.shape {target_motion.shape}'

        target_motion = self._maybe_motion_transformer(target_motion, n)

        zi = refimg.repeat(1, t, 1, 1, 1).flatten(0, 1)
        zj = zi
        zt = self.zt_token.repeat(zj.shape[0], 1, 1, 1).to(zj.dtype)

        image_hidden_states = torch.cat((zi, zt), dim=1)
        pre = self.transformer(motion_source_hidden_states=source_motion,
                               motion_target_hidden_states=target_motion,
                               image_hidden_states=image_hidden_states,)

        zi = einops.rearrange(zi, '(n t) c h w -> n t c h w', n=n)
        zt = einops.rearrange(pre, '(n t) c h w -> n t c h w', n=n)
        zj = einops.rearrange(zj, '(n t) c h w -> n t c h w', n=n)

        if return_meta_info:
            return {'zi': zi,
                    'zj': zj,
                    }
        return zi, zt, zj

    def extract_motion(self, video: torch.Tensor):
        """Encode ``video`` into motion tokens.

        The motion transformer is also applied when it is enabled.
        """
        motion = self.motion_encoder(video)

        if self.need_motion_transformer:
            motion = self.motion_transformer(motion)

        return motion
| | |
| |
|
def AMD_S(**kwargs) -> AMDModel:
    """Build the small ``AMDModel`` preset.

    Preset hyper-parameters can be overridden via ``kwargs``; every other
    keyword is forwarded to ``AMDModel`` unchanged.
    """
    preset = dict(
        enc_num_layers=8,
        enc_nhead=8,
        enc_ndim=64,

        diffusion_attn_head_dim=64,
        diffusion_attn_num_heads=16,
        diffusion_out_channels=4,
        diffusion_num_layers=12,
    )
    # Caller values win. Previously, repeating a preset key raised
    # "got multiple values for keyword argument".
    return AMDModel(**{**preset, **kwargs})
| |
|
def AMD_L(**kwargs) -> AMDModel:
    """Build the large ``AMDModel`` preset.

    Compared with AMD_S: 16 encoder heads, head dim 96, 16 DiT layers.
    Preset values can be overridden via ``kwargs``.
    """
    preset = dict(
        enc_num_layers=8,
        enc_nhead=16,
        enc_ndim=64,

        diffusion_attn_head_dim=96,
        diffusion_attn_num_heads=16,
        diffusion_out_channels=4,
        diffusion_num_layers=16,
    )
    # Caller values win instead of raising a duplicate-keyword TypeError.
    return AMDModel(**{**preset, **kwargs})
| |
|
def AMD_S_Rec(**kwargs) -> AMDModel_Rec:
    """Build the small ``AMDModel_Rec`` (reconstruction) preset.

    The return annotation previously said ``AMDModel``, but this function
    constructs an ``AMDModel_Rec``. Preset values can be overridden via
    ``kwargs``.
    """
    preset = dict(
        enc_num_layers=8,
        enc_nhead=8,
        enc_ndim=64,

        diffusion_attn_head_dim=64,
        diffusion_attn_num_heads=16,
        diffusion_out_channels=4,
        diffusion_num_layers=12,
    )
    return AMDModel_Rec(**{**preset, **kwargs})
| |
|
def AMD_S_RecSplit(**kwargs) -> AMDModel_Rec:
    """Build the small ``AMDModel_Rec`` preset with ``is_split=True``.

    ``is_split`` is not a named parameter of ``AMDModel_Rec.__init__``; it
    is received through that constructor's ``**kwargs``. Preset values
    (including ``is_split``) can be overridden via ``kwargs``.
    """
    preset = dict(
        enc_num_layers=8,
        enc_nhead=8,
        enc_ndim=64,

        diffusion_attn_head_dim=64,
        diffusion_attn_num_heads=16,
        diffusion_out_channels=4,
        diffusion_num_layers=12,
        is_split=True,
    )
    return AMDModel_Rec(**{**preset, **kwargs})
| |
|
| |
|
# Registry: preset name -> factory function that builds the model.
AMD_models = dict(
    AMD_S=AMD_S,
    AMD_L=AMD_L,
    AMD_S_Rec=AMD_S_Rec,
    AMD_S_RecSplit=AMD_S_RecSplit,
)