# MOFA-Video_Traj/models/svdxt_featureflow_forward_controlnet_s2d_fixcmp_norefine.py
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import yaml
from diffusers.configuration_utils import register_to_config
from diffusers.utils import BaseOutput

import models.cmp.models as cmp_models
import models.cmp.utils as cmp_utils
from models.controlnet_sdv import ControlNetSDVModel, zero_module
from models.softsplat import softsplat
class ArgObj(object):
    """Plain namespace used to expose the CMP yaml config entries as attributes."""

    def __init__(self):
        pass


class CMP_demo(nn.Module):
    """Wrapper around a pretrained CMP (Conditional Motion Propagation) model that
    densifies sparse flow hints into a dense flow field."""

    def __init__(self, configfn, load_iter):
        super().__init__()
args = ArgObj()
with open(configfn) as f:
config = yaml.full_load(f)
for k, v in config.items():
setattr(args, k, v)
setattr(args, 'load_iter', load_iter)
setattr(args, 'exp_path', os.path.dirname(configfn))
self.model = cmp_models.__dict__[args.model['arch']](args.model, dist_model=False)
self.model.load_state("{}/checkpoints".format(args.exp_path), args.load_iter, False)
self.model.switch_to('eval')
self.data_mean = args.data['data_mean']
self.data_div = args.data['data_div']
self.img_transform = transforms.Compose([
transforms.Normalize(self.data_mean, self.data_div)])
self.args = args
self.fuser = cmp_utils.Fuser(args.model['module']['nbins'], args.model['module']['fmax'])
torch.cuda.synchronize()
    def run(self, image, sparse, mask):
        # image: [b, 3, h, w] in [0, 1]; sparse: sparse flow hints; mask: mask marking the hint locations
        dtype = image.dtype
        image = image * 2 - 1  # rescale to [-1, 1] as expected by CMP
self.model.set_input(image.float(), torch.cat([sparse, mask], dim=1).float(), None)
cmp_output = self.model.model(self.model.image_input, self.model.sparse_input)
flow = self.fuser.convert_flow(cmp_output)
if flow.shape[2] != self.model.image_input.shape[2]:
flow = nn.functional.interpolate(
flow, size=self.model.image_input.shape[2:4],
mode="bilinear", align_corners=True)
return flow.to(dtype) # [b, 2, h, w]
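

# Illustrative sketch (not part of the model): the calling convention assumed for
# CMP_demo.run. The hint tensors below are placeholders; their channel counts
# (2-channel sparse flow, 1-channel mask) are assumptions that should be checked
# against the CMP experiment config shipped with the repository.
def _cmp_run_sketch(cmp: CMP_demo, device: str = "cuda"):
    b, h, w = 1, 384, 384
    image = torch.rand(b, 3, h, w, device=device)    # RGB in [0, 1]; run() rescales to [-1, 1]
    sparse = torch.zeros(b, 2, h, w, device=device)  # sparse flow hints (x/y displacement)
    mask = torch.zeros(b, 1, h, w, device=device)    # assumed 1-channel mask of hint pixels
    return cmp.run(image, sparse, mask)              # dense flow, [b, 2, h, w]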


class FlowControlNetConditioningEmbeddingSVD(nn.Module):
    """Embeds the conditioning image into ControlNet feature space: three stride-2
    convolutions downsample it 8x, and a zero-initialized output conv keeps the
    branch silent at the start of training."""

    def __init__(
self,
conditioning_embedding_channels: int,
conditioning_channels: int = 3,
block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
):
super().__init__()
self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
self.blocks = nn.ModuleList([])
for i in range(len(block_out_channels) - 1):
channel_in = block_out_channels[i]
channel_out = block_out_channels[i + 1]
self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
self.conv_out = zero_module(
nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
)
def forward(self, conditioning):
embedding = self.conv_in(conditioning)
embedding = F.silu(embedding)
for block in self.blocks:
embedding = block(embedding)
embedding = F.silu(embedding)
embedding = self.conv_out(embedding)
return embedding
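

# Illustrative sketch: with the default block_out_channels, the embedder above
# reduces spatial resolution by 8x, and because conv_out is zero-initialized a
# freshly constructed module outputs all zeros. The value 320 mirrors the
# ControlNet's first block width and is only an example.
def _conditioning_embedding_shape_sketch():
    embedder = FlowControlNetConditioningEmbeddingSVD(conditioning_embedding_channels=320)
    out = embedder(torch.rand(1, 3, 256, 256))
    assert out.shape == (1, 320, 32, 32)
    assert torch.all(out == 0)  # zero_module keeps the branch inactive at init
    return out

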
class FlowControlNetFirstFrameEncoderLayer(nn.Module):
def __init__(
self,
c_in,
c_out,
is_downsample=False
):
super().__init__()
self.conv_in = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, stride=2 if is_downsample else 1)
def forward(self, feature):
'''
feature: [b, c, h, w]
'''
embedding = self.conv_in(feature)
embedding = F.silu(embedding)
return embedding
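

# Illustrative sketch: with is_downsample=True the layer halves the spatial size
# while remapping channels, e.g. [1, 320, 32, 32] -> [1, 640, 16, 16]; the channel
# numbers are example values, not values prescribed by the model.
def _first_frame_encoder_layer_sketch():
    layer = FlowControlNetFirstFrameEncoderLayer(320, 640, is_downsample=True)
    out = layer(torch.rand(1, 320, 32, 32))
    assert out.shape == (1, 640, 16, 16)
    return out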


class FlowControlNetFirstFrameEncoder(nn.Module):
    """Builds a multi-scale pyramid from the embedded first-frame feature; each level
    goes through a zero-initialized 1x1 conv so it contributes nothing until trained."""

    def __init__(
self,
c_in=320,
channels=[320, 640, 1280],
downsamples=[True, True, True],
use_zeroconv=True
):
super().__init__()
self.encoders = nn.ModuleList([])
self.zeroconvs = nn.ModuleList([])
for channel, downsample in zip(channels, downsamples):
self.encoders.append(FlowControlNetFirstFrameEncoderLayer(c_in, channel, is_downsample=downsample))
self.zeroconvs.append(zero_module(nn.Conv2d(channel, channel, kernel_size=1)) if use_zeroconv else nn.Identity())
c_in = channel
def forward(self, first_frame):
feature = first_frame
deep_features = []
for encoder, zeroconv in zip(self.encoders, self.zeroconvs):
feature = encoder(feature)
# print(feature.shape)
deep_features.append(zeroconv(feature))
return deep_features
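

# Illustrative sketch: starting from a [b, 320, h, w] first-frame feature, the default
# encoder returns a three-level pyramid at 1/2, 1/4 and 1/8 of that resolution with
# 320, 640 and 1280 channels (all zeros at init because of the zero convs above).
def _first_frame_encoder_pyramid_sketch():
    encoder = FlowControlNetFirstFrameEncoder()
    levels = encoder(torch.rand(1, 320, 32, 32))
    assert [tuple(f.shape) for f in levels] == [
        (1, 320, 16, 16),
        (1, 640, 8, 8),
        (1, 1280, 4, 4),
    ]
    return levels

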
@dataclass
class FlowControlNetOutput(BaseOutput):
"""
The output of [`FlowControlNetOutput`].
Args:
down_block_res_samples (`tuple[torch.Tensor]`):
A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
used to condition the original UNet's downsampling activations.
mid_down_block_re_sample (`torch.Tensor`):
The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
`(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
Output can be used to condition the original UNet's middle block activation.
"""
down_block_res_samples: Tuple[torch.Tensor]
mid_block_res_sample: torch.Tensor
controlnet_flow: torch.Tensor
cmp_output: torch.Tensor


class FlowControlNet(ControlNetSDVModel):
    """SVD ControlNet variant that forward-warps first-frame features with a dense flow
    field (softsplat) and injects the warped features into the downsampling path."""

    _supports_gradient_checkpointing = True

@register_to_config
def __init__(
self,
sample_size: Optional[int] = None,
in_channels: int = 8,
out_channels: int = 4,
down_block_types: Tuple[str] = (
"CrossAttnDownBlockSpatioTemporal",
"CrossAttnDownBlockSpatioTemporal",
"CrossAttnDownBlockSpatioTemporal",
"DownBlockSpatioTemporal",
),
up_block_types: Tuple[str] = (
"UpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
),
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
addition_time_embed_dim: int = 256,
projection_class_embeddings_input_dim: int = 768,
layers_per_block: Union[int, Tuple[int]] = 2,
cross_attention_dim: Union[int, Tuple[int]] = 1024,
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20),
num_frames: int = 25,
conditioning_channels: int = 3,
        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
):
super().__init__()
self.flow_encoder = FlowControlNetFirstFrameEncoder()
self.controlnet_cond_embedding = FlowControlNetConditioningEmbeddingSVD(
conditioning_embedding_channels=block_out_channels[0],
block_out_channels=conditioning_embedding_out_channels,
conditioning_channels=conditioning_channels,
)
def get_warped_frames(self, first_frame, flows):
        '''
        first_frame: [b, c, h, w]
        flows: [b, t-1, 2, h, w]
        '''
dtype = first_frame.dtype
warped_frames = []
for i in range(flows.shape[1]):
            warped_frame = softsplat(tenIn=first_frame.float(), tenFlow=flows[:, i].float(), tenMetric=None, strMode='avg').to(dtype)  # [b, c, h, w]
            warped_frames.append(warped_frame.unsqueeze(1))  # [b, 1, c, h, w]
        warped_frames = torch.cat(warped_frames, dim=1)  # [b, t-1, c, h, w]
return warped_frames
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
added_time_ids: torch.Tensor,
        controlnet_cond: torch.FloatTensor = None,  # conditioning image, [b, 3, h, w]
        controlnet_flow: torch.FloatTensor = None,  # dense flow for frames 2..t, [b, t-1, 2, h, w]
image_only_indicator: Optional[torch.Tensor] = None,
return_dict: bool = True,
guess_mode: bool = False,
conditioning_scale: float = 1.0,
) -> Union[FlowControlNetOutput, Tuple]:
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
batch_size, num_frames = sample.shape[:2]
timesteps = timesteps.expand(batch_size)
t_emb = self.time_proj(timesteps)
# `Timesteps` does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb)
time_embeds = self.add_time_proj(added_time_ids.flatten())
time_embeds = time_embeds.reshape((batch_size, -1))
time_embeds = time_embeds.to(emb.dtype)
aug_emb = self.add_embedding(time_embeds)
emb = emb + aug_emb
# Flatten the batch and frames dimensions
# sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
sample = sample.flatten(0, 1)
# Repeat the embeddings num_video_frames times
# emb: [batch, channels] -> [batch * frames, channels]
emb = emb.repeat_interleave(num_frames, dim=0)
# encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
# 2. pre-process
sample = self.conv_in(sample) # [b*l, 320, h//8, w//8]
        # controlnet cond: embed the conditioning image and build a feature pyramid from it
        if controlnet_cond is not None:
            controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)  # [b, 320, h//8, w//8]
            controlnet_cond_features = [controlnet_cond] + self.flow_encoder(controlnet_cond)  # 4 scales

            # build a flow pyramid: downsample the dense flow to each feature scale
            # and rescale its magnitude by the same factor
            scales = [8, 16, 32, 64]
            scale_flows = {}
            fb, fl, fc, fh, fw = controlnet_flow.shape
            # print(controlnet_flow.shape)
            for scale in scales:
                scaled_flow = F.interpolate(controlnet_flow.reshape(-1, fc, fh, fw), scale_factor=1/scale)
                scaled_flow = scaled_flow.reshape(fb, fl, fc, fh // scale, fw // scale) / scale
                scale_flows[scale] = scaled_flow
            # forward-warp each conditioning feature level with the flow at its scale,
            # prepending the unwarped first-frame feature as frame 0
            warped_cond_features = []
            for cond_feature in controlnet_cond_features:
                cb, cc, ch, cw = cond_feature.shape
                # print(cond_feature.shape)
                warped_cond_feature = self.get_warped_frames(cond_feature, scale_flows[fh // ch])
                warped_cond_feature = torch.cat([cond_feature.unsqueeze(1), warped_cond_feature], dim=1)  # [b, t, c, h, w]
                wb, wl, wc, wh, ww = warped_cond_feature.shape
                # print(warped_cond_feature.shape)
                warped_cond_features.append(warped_cond_feature.reshape(wb * wl, wc, wh, ww))
image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)
count = 0
length = len(warped_cond_features)
# add the warped feature in the first scale
sample = sample + warped_cond_features[count]
count += 1
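
        # 3. down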
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
image_only_indicator=image_only_indicator,
)
else:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
image_only_indicator=image_only_indicator,
)
sample = sample + warped_cond_features[min(count, length - 1)]
count += 1
down_block_res_samples += res_samples
# add the warped feature in the last scale
sample = sample + warped_cond_features[-1]
# 4. mid
sample = self.mid_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
image_only_indicator=image_only_indicator,
)
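
        # 5. Control net blocks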
controlnet_down_block_res_samples = ()
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
down_block_res_sample = controlnet_block(down_block_res_sample)
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
down_block_res_samples = controlnet_down_block_res_samples
mid_block_res_sample = self.controlnet_mid_block(sample)
# 6. scaling
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
mid_block_res_sample = mid_block_res_sample * conditioning_scale
if not return_dict:
return (down_block_res_samples, mid_block_res_sample, controlnet_flow, None)
return FlowControlNetOutput(
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample, controlnet_flow=controlnet_flow, cmp_output=None
)
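

# Illustrative sketch (only runs when the file is executed directly on a CUDA machine,
# since softsplat is a CUDA kernel): rough shape contract of FlowControlNet.forward.
# The frame count, resolution and conditioning dimensions are assumed example values
# for SVD-style inputs, not values prescribed by this module.
if __name__ == "__main__" and torch.cuda.is_available():
    controlnet = FlowControlNet().cuda().eval()
    b, t, h, w = 1, 14, 256, 256
    sample = torch.randn(b, t, 8, h // 8, w // 8, device="cuda")     # noisy latents concatenated with image latents
    encoder_hidden_states = torch.randn(b, 1, 1024, device="cuda")   # CLIP image embedding
    added_time_ids = torch.zeros(b, 3, device="cuda")                # fps / motion bucket / noise-aug ids
    controlnet_cond = torch.randn(b, 3, h, w, device="cuda")         # first-frame conditioning image
    controlnet_flow = torch.zeros(b, t - 1, 2, h, w, device="cuda")  # dense flow for frames 2..t
    with torch.no_grad():
        out = controlnet(
            sample, 999, encoder_hidden_states, added_time_ids,
            controlnet_cond=controlnet_cond, controlnet_flow=controlnet_flow,
        )
    # out.down_block_res_samples: per-resolution residuals added to the UNet's down blocks
    # out.mid_block_res_sample: residual added to the UNet's mid block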