from typing import Any, Dict, List, Optional, Tuple, Union
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.configuration_utils import register_to_config
from diffusers.utils import BaseOutput

from models.controlnet_sdv import ControlNetSDVModel, zero_module
from models.softsplat import softsplat
import models.cmp.models as cmp_models
import models.cmp.utils as cmp_utils

import yaml
import os
import torchvision.transforms as transforms


class ArgObj(object):
    def __init__(self):
        pass


class CMP_demo(nn.Module):
    def __init__(self, configfn, load_iter):
        super().__init__()
        args = ArgObj()
        with open(configfn) as f:
            config = yaml.full_load(f)
        for k, v in config.items():
            setattr(args, k, v)
        setattr(args, 'load_iter', load_iter)
        setattr(args, 'exp_path', os.path.dirname(configfn))

        self.model = cmp_models.__dict__[args.model['arch']](args.model, dist_model=False)
        self.model.load_state("{}/checkpoints".format(args.exp_path), args.load_iter, False)
        self.model.switch_to('eval')

        self.data_mean = args.data['data_mean']
        self.data_div = args.data['data_div']
        self.img_transform = transforms.Compose([
            transforms.Normalize(self.data_mean, self.data_div)])

        self.args = args
        self.fuser = cmp_utils.Fuser(args.model['module']['nbins'], args.model['module']['fmax'])
        torch.cuda.synchronize()

    def run(self, image, sparse, mask):
        dtype = image.dtype
        image = image * 2 - 1  # assumes image in [0, 1]; map to [-1, 1]
        self.model.set_input(image.float(), torch.cat([sparse, mask], dim=1).float(), None)
        cmp_output = self.model.model(self.model.image_input, self.model.sparse_input)
        flow = self.fuser.convert_flow(cmp_output)
        if flow.shape[2] != self.model.image_input.shape[2]:
            flow = nn.functional.interpolate(
                flow, size=self.model.image_input.shape[2:4],
                mode="bilinear", align_corners=True)

        return flow.to(dtype)  # [b, 2, h, w]
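

# CMP_demo wraps a pretrained CMP (Conditional Motion Propagation) network that
# propagates sparse motion hints into a dense flow field. A minimal usage sketch,
# with an illustrative config path, iteration number, resolution and hint layout
# (2-channel sparse flow plus a 1-channel validity mask); none of these values are
# fixed by this file:
#
#   cmp = CMP_demo('path/to/cmp/config.yaml', load_iter=42000).cuda()
#   image = torch.rand(1, 3, 384, 384).cuda()     # first frame, values in [0, 1]
#   sparse = torch.zeros(1, 2, 384, 384).cuda()   # sparse flow hints (in pixels)
#   mask = torch.zeros(1, 1, 384, 384).cuda()     # 1 where a hint is given
#   dense_flow = cmp.run(image, sparse, mask)     # dense flow, [1, 2, 384, 384]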


class FlowControlNetConditioningEmbeddingSVD(nn.Module):
    def __init__(
        self,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
    ):
        super().__init__()

        self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)

        self.blocks = nn.ModuleList([])
        for i in range(len(block_out_channels) - 1):
            channel_in = block_out_channels[i]
            channel_out = block_out_channels[i + 1]
            self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
            self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))

        self.conv_out = zero_module(
            nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
        )

    def forward(self, conditioning):
        embedding = self.conv_in(conditioning)
        embedding = F.silu(embedding)

        for block in self.blocks:
            embedding = block(embedding)
            embedding = F.silu(embedding)

        embedding = self.conv_out(embedding)

        return embedding
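

# With the default block_out_channels (16, 32, 96, 256), the embedding above applies
# three stride-2 convolutions, so a conditioning image of shape [b, 3, h, w] comes out
# as [b, conditioning_embedding_channels, h // 8, w // 8], matching the resolution of
# the latent features produced by conv_in in the main model. The final zero-initialized
# convolution makes the module contribute nothing at the start of training.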


class FlowControlNetFirstFrameEncoderLayer(nn.Module):
    def __init__(
        self,
        c_in,
        c_out,
        is_downsample=False
    ):
        super().__init__()
        self.conv_in = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, stride=2 if is_downsample else 1)

    def forward(self, feature):
        '''
        feature: [b, c, h, w]
        '''
        embedding = self.conv_in(feature)
        embedding = F.silu(embedding)
        return embedding


class FlowControlNetFirstFrameEncoder(nn.Module):
    def __init__(
        self,
        c_in=320,
        channels=[320, 640, 1280],
        downsamples=[True, True, True],
        use_zeroconv=True
    ):
        super().__init__()
        self.encoders = nn.ModuleList([])
        self.zeroconvs = nn.ModuleList([])
        for channel, downsample in zip(channels, downsamples):
            self.encoders.append(FlowControlNetFirstFrameEncoderLayer(c_in, channel, is_downsample=downsample))
            self.zeroconvs.append(zero_module(nn.Conv2d(channel, channel, kernel_size=1)) if use_zeroconv else nn.Identity())
            c_in = channel

    def forward(self, first_frame):
        feature = first_frame
        deep_features = []
        for encoder, zeroconv in zip(self.encoders, self.zeroconvs):
            feature = encoder(feature)
            # print(feature.shape)
            deep_features.append(zeroconv(feature))
        return deep_features
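

# With the defaults (c_in=320, channels=[320, 640, 1280], all levels downsampling), the
# encoder above turns the [b, 320, h // 8, w // 8] conditioning embedding into a feature
# pyramid of [b, 320, h // 16, w // 16], [b, 640, h // 32, w // 32] and
# [b, 1280, h // 64, w // 64]. Together with the input itself this yields the four
# scales (8, 16, 32, 64) that FlowControlNet.forward warps with correspondingly
# downsampled flow fields. The 1x1 zero convolutions again make each level start as a no-op.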


@dataclass
class FlowControlNetOutput(BaseOutput):
    """
    The output of [`FlowControlNet`].

    Args:
        down_block_res_samples (`tuple[torch.Tensor]`):
            A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
            be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
            used to condition the original UNet's downsampling activations.
        mid_block_res_sample (`torch.Tensor`):
            The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
            `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
            Output can be used to condition the original UNet's middle block activation.
        controlnet_flow (`torch.Tensor`):
            The dense flow used for feature warping, passed through from the input.
        cmp_output (`torch.Tensor`):
            Reserved for the raw CMP output; this model returns `None` here.
    """

    down_block_res_samples: Tuple[torch.Tensor]
    mid_block_res_sample: torch.Tensor
    controlnet_flow: torch.Tensor
    cmp_output: torch.Tensor


class FlowControlNet(ControlNetSDVModel):
    _supports_gradient_checkpointing = True

    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 8,
        out_channels: int = 4,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlockSpatioTemporal",
            "CrossAttnDownBlockSpatioTemporal",
            "CrossAttnDownBlockSpatioTemporal",
            "DownBlockSpatioTemporal",
        ),
        up_block_types: Tuple[str] = (
            "UpBlockSpatioTemporal",
            "CrossAttnUpBlockSpatioTemporal",
            "CrossAttnUpBlockSpatioTemporal",
            "CrossAttnUpBlockSpatioTemporal",
        ),
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        addition_time_embed_dim: int = 256,
        projection_class_embeddings_input_dim: int = 768,
        layers_per_block: Union[int, Tuple[int]] = 2,
        cross_attention_dim: Union[int, Tuple[int]] = 1024,
        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
        num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20),
        num_frames: int = 25,
        conditioning_channels: int = 3,
        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
    ):
        super().__init__()

        self.flow_encoder = FlowControlNetFirstFrameEncoder()

        self.controlnet_cond_embedding = FlowControlNetConditioningEmbeddingSVD(
            conditioning_embedding_channels=block_out_channels[0],
            block_out_channels=conditioning_embedding_out_channels,
            conditioning_channels=conditioning_channels,
        )

    def get_warped_frames(self, first_frame, flows):
        '''
        first_frame: [b, c, h, w]
        flows: [b, t-1, 2, h, w]
        '''
        dtype = first_frame.dtype
        warped_frames = []
        for i in range(flows.shape[1]):
            warped_frame = softsplat(tenIn=first_frame.float(), tenFlow=flows[:, i].float(), tenMetric=None, strMode='avg').to(dtype)  # [b, c, h, w]
            warped_frames.append(warped_frame.unsqueeze(1))  # [b, 1, c, h, w]
        warped_frames = torch.cat(warped_frames, dim=1)  # [b, t-1, c, h, w]
        return warped_frames
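
    # Note on warping: `softsplat` forward-splats the first-frame feature map along
    # each flow field. With strMode='avg', the contributions of all source pixels
    # that land on the same target location are averaged, and locations that receive
    # no contribution remain zero, so disoccluded regions stay empty in the warped
    # features.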

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        added_time_ids: torch.Tensor,
        controlnet_cond: torch.FloatTensor = None,  # [b, 3, h, w]
        controlnet_flow: torch.FloatTensor = None,  # [b, 13, 2, h, w]
        image_only_indicator: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        guess_mode: bool = False,
        conditioning_scale: float = 1.0,
    ) -> Union[FlowControlNetOutput, Tuple]:

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        batch_size, num_frames = sample.shape[:2]
        timesteps = timesteps.expand(batch_size)

        t_emb = self.time_proj(timesteps)

        # `Timesteps` does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb)

        time_embeds = self.add_time_proj(added_time_ids.flatten())
        time_embeds = time_embeds.reshape((batch_size, -1))
        time_embeds = time_embeds.to(emb.dtype)
        aug_emb = self.add_embedding(time_embeds)
        emb = emb + aug_emb

        # Flatten the batch and frames dimensions
        # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
        sample = sample.flatten(0, 1)
        # Repeat the embeddings num_video_frames times
        # emb: [batch, channels] -> [batch * frames, channels]
        emb = emb.repeat_interleave(num_frames, dim=0)
        # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
        encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)

        # 2. pre-process
        sample = self.conv_in(sample)  # [b*l, 320, h//8, w//8]

        # controlnet cond
        if controlnet_cond is not None:
            controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)  # [b, 320, h//8, w//8]
            controlnet_cond_features = [controlnet_cond] + self.flow_encoder(controlnet_cond)  # 4 feature maps

        # build a flow pyramid: downsample the dense flow to each feature scale and
        # rescale the flow values (measured in pixels) by the same factor
        scales = [8, 16, 32, 64]
        scale_flows = {}
        fb, fl, fc, fh, fw = controlnet_flow.shape
        # print(controlnet_flow.shape)
        for scale in scales:
            scaled_flow = F.interpolate(controlnet_flow.reshape(-1, fc, fh, fw), scale_factor=1 / scale)
            scaled_flow = scaled_flow.reshape(fb, fl, fc, fh // scale, fw // scale) / scale
            scale_flows[scale] = scaled_flow

        # warp each first-frame feature map with the flow pyramid at the matching scale,
        # prepending the unwarped feature as frame 0
        warped_cond_features = []
        for cond_feature in controlnet_cond_features:
            cb, cc, ch, cw = cond_feature.shape
            # print(cond_feature.shape)
            warped_cond_feature = self.get_warped_frames(cond_feature, scale_flows[fh // ch])
            warped_cond_feature = torch.cat([cond_feature.unsqueeze(1), warped_cond_feature], dim=1)  # [b, t, c, h, w]
            wb, wl, wc, wh, ww = warped_cond_feature.shape
            # print(warped_cond_feature.shape)
            warped_cond_features.append(warped_cond_feature.reshape(wb * wl, wc, wh, ww))

        image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)

        count = 0
        length = len(warped_cond_features)

        # add the warped feature in the first scale
        sample = sample + warped_cond_features[count]
        count += 1

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    image_only_indicator=image_only_indicator,
                )
            else:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    image_only_indicator=image_only_indicator,
                )
            sample = sample + warped_cond_features[min(count, length - 1)]
            count += 1

            down_block_res_samples += res_samples

        # add the warped feature in the last scale
        sample = sample + warped_cond_features[-1]

        # 4. mid
        sample = self.mid_block(
            hidden_states=sample,
            temb=emb,
            encoder_hidden_states=encoder_hidden_states,
            image_only_indicator=image_only_indicator,
        )

        # 5. controlnet blocks
        controlnet_down_block_res_samples = ()
        for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
            down_block_res_sample = controlnet_block(down_block_res_sample)
            controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
        down_block_res_samples = controlnet_down_block_res_samples

        mid_block_res_sample = self.controlnet_mid_block(sample)

        # 6. scaling
        down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
        mid_block_res_sample = mid_block_res_sample * conditioning_scale

        if not return_dict:
            return (down_block_res_samples, mid_block_res_sample, controlnet_flow, None)

        return FlowControlNetOutput(
            down_block_res_samples=down_block_res_samples,
            mid_block_res_sample=mid_block_res_sample,
            controlnet_flow=controlnet_flow,
            cmp_output=None,
        )
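

# A hedged sketch of how this controlnet is typically driven during denoising. The
# variable names, checkpoint path, 384x384 resolution and 14-frame clip length are
# illustrative assumptions (the flow comment above implies frames = flow length + 1);
# only the argument names and shape layout come from the signature and comments above:
#
#   controlnet = FlowControlNet.from_pretrained(ckpt_dir)
#   out = controlnet(
#       sample=noisy_latents,                 # [b, 14, 8, 48, 48]
#       timestep=t,
#       encoder_hidden_states=image_embeds,   # [b, 1, 1024]
#       added_time_ids=added_time_ids,        # [b, 3]
#       controlnet_cond=first_frame,          # [b, 3, 384, 384]
#       controlnet_flow=dense_flow,           # [b, 13, 2, 384, 384]
#   )
#   # out.down_block_res_samples and out.mid_block_res_sample are then added to the
#   # corresponding activations of the SVD UNet to inject the motion control.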