Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,568 Bytes
7c1a14b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
from typing import Union, Tuple
import torch
from diffusers import UNetSpatioTemporalConditionModel
from diffusers.models.unets.unet_spatio_temporal_condition import UNetSpatioTemporalConditionOutput
class DiffusersUNetSpatioTemporalConditionModelDepthCrafter(
    UNetSpatioTemporalConditionModel
):
    """Spatio-temporal UNet variant with a custom ``forward`` pass.

    Only ``forward`` is overridden here; every layer it uses (``time_proj``,
    ``time_embedding``, ``add_time_proj``, ``add_embedding``, ``conv_in``,
    ``down_blocks``, ``mid_block``, ``up_blocks``, ``conv_norm_out``,
    ``conv_act``, ``conv_out``) is expected to be provided by the parent
    class.

    In this forward pass ``encoder_hidden_states`` carries one embedding per
    frame, and all inputs are cast to the dtype of ``conv_in``'s weights.
    """

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        added_time_ids: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
        """Run the UNet on a batch of frame sequences.

        Args:
            sample: Input of shape ``[batch, frames, channels, height, width]``.
            timestep: Diffusion timestep as a Python number or a tensor. It is
                broadcast to one value per batch element.
            encoder_hidden_states: Conditioning of shape
                ``[batch, frames, channels]``; it is reshaped to
                ``[batch * frames, 1, channels]`` before reaching the blocks.
            added_time_ids: Extra conditioning ids. They are flattened,
                embedded, and regrouped into one row per batch element.
            return_dict: When ``False`` a 1-tuple ``(sample,)`` is returned
                instead of an output object.

        Returns:
            ``UNetSpatioTemporalConditionOutput`` (or a 1-tuple) holding the
            result reshaped to ``[batch, frames, ...]``.
        """
        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # Building a tensor from a Python number requires a CPU/GPU sync,
            # so callers should pass tensors when they can.
            # 64-bit dtypes are avoided when the sample lives on an MPS device.
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        else:
            if timesteps.ndim == 0:
                timesteps = timesteps[None]
            # Keep tensor timesteps on the sample's device whatever their
            # rank, so the time embedding never mixes devices.
            timesteps = timesteps.to(sample.device)

        # Broadcast to the batch dimension in a way that's compatible with
        # ONNX/Core ML.
        batch_size, num_frames = sample.shape[:2]
        timesteps = timesteps.expand(batch_size)

        # `Timesteps` does not contain any weights and will always return f32
        # tensors, but time_embedding might actually be running in fp16, so
        # the projection is cast to the dtype of the model weights.
        model_dtype = self.conv_in.weight.dtype
        t_emb = self.time_proj(timesteps).to(dtype=model_dtype)
        emb = self.time_embedding(t_emb)  # [batch, channels]

        time_embeds = self.add_time_proj(added_time_ids.flatten())
        time_embeds = time_embeds.reshape((batch_size, -1)).to(emb.dtype)
        emb = emb + self.add_embedding(time_embeds)

        # Fold the frame axis into the batch axis.
        # sample: [batch, frames, channels, height, width]
        #      -> [batch * frames, channels, height, width]
        sample = sample.flatten(0, 1)
        # emb: [batch, channels] -> [batch * frames, channels]
        emb = emb.repeat_interleave(num_frames, dim=0)
        # encoder_hidden_states: [batch, frames, channels]
        #                     -> [batch * frames, 1, channels]
        encoder_hidden_states = encoder_hidden_states.flatten(0, 1).unsqueeze(1)

        # 2. pre-process
        # The cast guarantees conv_in sees its own weight dtype.
        sample = self.conv_in(sample.to(dtype=model_dtype))

        image_only_indicator = torch.zeros(
            batch_size, num_frames, dtype=sample.dtype, device=sample.device
        )
        # Passed only to blocks that declare cross-attention support.
        cross_attention_kwargs = {"encoder_hidden_states": encoder_hidden_states}

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            extra_kwargs = (
                cross_attention_kwargs
                if getattr(downsample_block, "has_cross_attention", False)
                else {}
            )
            sample, res_samples = downsample_block(
                hidden_states=sample,
                temb=emb,
                image_only_indicator=image_only_indicator,
                **extra_kwargs,
            )
            down_block_res_samples += res_samples

        # 4. mid
        sample = self.mid_block(
            hidden_states=sample,
            temb=emb,
            encoder_hidden_states=encoder_hidden_states,
            image_only_indicator=image_only_indicator,
        )

        # 5. up
        for upsample_block in self.up_blocks:
            # Each up block consumes one skip connection per resnet, taken
            # from the end of the collected down-path activations.
            num_skips = len(upsample_block.resnets)
            res_samples = down_block_res_samples[-num_skips:]
            down_block_res_samples = down_block_res_samples[:-num_skips]

            extra_kwargs = (
                cross_attention_kwargs
                if getattr(upsample_block, "has_cross_attention", False)
                else {}
            )
            sample = upsample_block(
                hidden_states=sample,
                res_hidden_states_tuple=res_samples,
                temb=emb,
                image_only_indicator=image_only_indicator,
                **extra_kwargs,
            )

        # 6. post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        # 7. Restore the separate batch and frame axes.
        sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])

        if not return_dict:
            return (sample,)
        return UNetSpatioTemporalConditionOutput(sample=sample)
|