from typing import Callable, Dict, List, Optional, Union

import numpy as np
import torch
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
    _resize_with_antialiasing,
    StableVideoDiffusionPipelineOutput,
    StableVideoDiffusionPipeline,
    retrieve_timesteps,
)
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

class DepthCrafterPipeline(StableVideoDiffusionPipeline):

    def encode_video(
        self,
        video: torch.Tensor,
        chunk_size: int = 14,
    ) -> torch.Tensor:
        """
        Encode the video frames into CLIP image embeddings, chunk by chunk.

        :param video: [b, c, h, w] in range [-1, 1]; b may span multiple videos or frames
        :param chunk_size: number of frames to encode per chunk
        :return: image embeddings of shape [b, 1024]
        """
        video_224 = _resize_with_antialiasing(video.float(), (224, 224))
        video_224 = (video_224 + 1.0) / 2.0  # [-1, 1] -> [0, 1]

        embeddings = []
        for i in range(0, video_224.shape[0], chunk_size):
            tmp = self.feature_extractor(
                images=video_224[i : i + chunk_size],
                do_normalize=True,
                do_center_crop=False,
                do_resize=False,
                do_rescale=False,
                return_tensors="pt",
            ).pixel_values.to(video.device, dtype=video.dtype)
            embeddings.append(self.image_encoder(tmp).image_embeds)  # [b, 1024]

        embeddings = torch.cat(embeddings, dim=0)  # [t, 1024]
        return embeddings

    def encode_vae_video(
        self,
        video: torch.Tensor,
        chunk_size: int = 14,
    ):
        """
        Encode the video frames into VAE latents, chunk by chunk.

        :param video: [b, c, h, w] in range [-1, 1]; b may span multiple videos or frames
        :param chunk_size: number of frames to encode per chunk
        :return: VAE latents of shape [b, c, h, w] (latent channels and spatial size)
        """
        video_latents = []
        for i in range(0, video.shape[0], chunk_size):
            video_latents.append(
                self.vae.encode(video[i : i + chunk_size]).latent_dist.mode()
            )
        video_latents = torch.cat(video_latents, dim=0)
        return video_latents

    @staticmethod
    def check_inputs(video, height, width):
        """
        Validate the input video type and the target resolution.
        """
        if not isinstance(video, torch.Tensor) and not isinstance(video, np.ndarray):
            raise ValueError(
                f"Expected `video` to be a `torch.Tensor` or `np.ndarray`, but got {type(video)}"
            )
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(
                f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
            )

    @torch.no_grad()
    def __call__(
        self,
        video: Union[np.ndarray, torch.Tensor],
        height: int = 576,
        width: int = 1024,
        num_inference_steps: int = 25,
        guidance_scale: float = 1.0,
        window_size: Optional[int] = 110,
        noise_aug_strength: float = 0.02,
        decode_chunk_size: Optional[int] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        return_dict: bool = True,
        overlap: int = 25,
        track_time: bool = False,
    ):
""" | |
:param video: in shape [t, h, w, c] if np.ndarray or [t, c, h, w] if torch.Tensor, in range [0, 1] | |
:param height: | |
:param width: | |
:param num_inference_steps: | |
:param guidance_scale: | |
:param window_size: sliding window processing size | |
:param fps: | |
:param motion_bucket_id: | |
:param noise_aug_strength: | |
:param decode_chunk_size: | |
:param generator: | |
:param latents: | |
:param output_type: | |
:param callback_on_step_end: | |
:param callback_on_step_end_tensor_inputs: | |
:param return_dict: | |
:return: | |
""" | |
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        num_frames = video.shape[0]
        decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else 8
        if num_frames <= window_size:
            window_size = num_frames
            overlap = 0
        stride = window_size - overlap

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(video, height, width)

        # 2. Define call parameters
        batch_size = 1
        device = self._execution_device
        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier-free guidance.
        self._guidance_scale = guidance_scale

        # 3. Encode input video
        if isinstance(video, np.ndarray):
            video = torch.from_numpy(video.transpose(0, 3, 1, 2))
        else:
            assert isinstance(video, torch.Tensor)
        video = video.to(device=device, dtype=self.dtype)
        video = video * 2.0 - 1.0  # [0, 1] -> [-1, 1], in [t, c, h, w]

        if track_time:
            start_event = torch.cuda.Event(enable_timing=True)
            encode_event = torch.cuda.Event(enable_timing=True)
            denoise_event = torch.cuda.Event(enable_timing=True)
            decode_event = torch.cuda.Event(enable_timing=True)
            start_event.record()

        video_embeddings = self.encode_video(
            video, chunk_size=decode_chunk_size
        ).unsqueeze(0)  # [1, t, 1024]
        torch.cuda.empty_cache()

        # 4. Encode input video using VAE, with a small amount of noise augmentation
        noise = randn_tensor(
            video.shape, generator=generator, device=device, dtype=video.dtype
        )
        video = video + noise_aug_strength * noise  # in [t, c, h, w]

        needs_upcasting = (
            self.vae.dtype == torch.float16 and self.vae.config.force_upcast
        )
        if needs_upcasting:
            self.vae.to(dtype=torch.float32)

        video_latents = self.encode_vae_video(
            video.to(self.vae.dtype),
            chunk_size=decode_chunk_size,
        ).unsqueeze(0)  # [1, t, c, h, w]

        if track_time:
            encode_event.record()
            torch.cuda.synchronize()
            elapsed_time_ms = start_event.elapsed_time(encode_event)
            print(f"Elapsed time for encoding video: {elapsed_time_ms} ms")

        torch.cuda.empty_cache()

        # cast back to fp16 if needed
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)

        # 5. Get Added Time IDs
        added_time_ids = self._get_add_time_ids(
            7,
            127,
            noise_aug_strength,
            video_embeddings.dtype,
            batch_size,
            1,
            False,
        )  # [1 or 2, 3]
        added_time_ids = added_time_ids.to(device)

        # 6. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, None, None
        )
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)

        # 7. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents_init = self.prepare_latents(
            batch_size,
            window_size,
            num_channels_latents,
            height,
            width,
            video_embeddings.dtype,
            device,
            generator,
            latents,
        )  # [1, t, c, h, w]
        latents_all = None

        idx_start = 0
        if overlap > 0:
            # linearly increasing blend weights for stitching overlapping windows
            weights = torch.linspace(0, 1, overlap, device=device)
            weights = weights.view(1, overlap, 1, 1, 1)
        else:
            weights = None

        torch.cuda.empty_cache()

        # 8. Inference strategy for long videos, combining two ideas:
        #    1) initialize the noise of each new window from the previous window's latents
        #    2) stitch the denoised segments together by blending the overlapping frames
        while idx_start < num_frames - overlap:
            idx_end = min(idx_start + window_size, num_frames)
            self.scheduler.set_timesteps(num_inference_steps, device=device)

            # 9. Denoising loop for the current window
            latents = latents_init[:, : idx_end - idx_start].clone()
            latents_init = torch.cat(
                [latents_init[:, -overlap:], latents_init[:, :stride]], dim=1
            )

            video_latents_current = video_latents[:, idx_start:idx_end]
            video_embeddings_current = video_embeddings[:, idx_start:idx_end]

            with self.progress_bar(total=num_inference_steps) as progress_bar:
                for i, t in enumerate(timesteps):
                    if latents_all is not None and i == 0:
                        # re-initialize the overlapping frames from the previous
                        # window's result plus freshly scaled noise
                        latents[:, :overlap] = (
                            latents_all[:, -overlap:]
                            + latents[:, :overlap]
                            / self.scheduler.init_noise_sigma
                            * self.scheduler.sigmas[i]
                        )

                    latent_model_input = latents  # [1, t, c, h, w]
                    latent_model_input = self.scheduler.scale_model_input(
                        latent_model_input, t
                    )  # [1, t, c, h, w]
                    latent_model_input = torch.cat(
                        [latent_model_input, video_latents_current], dim=2
                    )
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=video_embeddings_current,
                        added_time_ids=added_time_ids,
                        return_dict=False,
                    )[0]

                    # perform classifier-free guidance with zeroed-out conditioning
                    if self.do_classifier_free_guidance:
                        latent_model_input = latents
                        latent_model_input = self.scheduler.scale_model_input(
                            latent_model_input, t
                        )
                        latent_model_input = torch.cat(
                            [latent_model_input, torch.zeros_like(latent_model_input)],
                            dim=2,
                        )
                        noise_pred_uncond = self.unet(
                            latent_model_input,
                            t,
                            encoder_hidden_states=torch.zeros_like(
                                video_embeddings_current
                            ),
                            added_time_ids=added_time_ids,
                            return_dict=False,
                        )[0]
                        noise_pred = noise_pred_uncond + self.guidance_scale * (
                            noise_pred - noise_pred_uncond
                        )

                    latents = self.scheduler.step(noise_pred, t, latents).prev_sample

                    if callback_on_step_end is not None:
                        callback_kwargs = {}
                        for k in callback_on_step_end_tensor_inputs:
                            callback_kwargs[k] = locals()[k]
                        callback_outputs = callback_on_step_end(
                            self, i, t, callback_kwargs
                        )
                        latents = callback_outputs.pop("latents", latents)

                    if i == len(timesteps) - 1 or (
                        (i + 1) > num_warmup_steps
                        and (i + 1) % self.scheduler.order == 0
                    ):
                        progress_bar.update()

            if latents_all is None:
                latents_all = latents.clone()
            else:
                assert weights is not None
                # blend the overlapping frames linearly, then append the new frames
                latents_all[:, -overlap:] = latents[
                    :, :overlap
                ] * weights + latents_all[:, -overlap:] * (1 - weights)
                latents_all = torch.cat([latents_all, latents[:, overlap:]], dim=1)

            idx_start += stride

        if track_time:
            denoise_event.record()
            torch.cuda.synchronize()
            elapsed_time_ms = encode_event.elapsed_time(denoise_event)
            print(f"Elapsed time for denoising video: {elapsed_time_ms} ms")

        if output_type != "latent":
            # cast back to fp16 if needed
            if needs_upcasting:
                self.vae.to(dtype=torch.float16)
            frames = self.decode_latents(latents_all, num_frames, decode_chunk_size)

            if track_time:
                decode_event.record()
                torch.cuda.synchronize()
                elapsed_time_ms = denoise_event.elapsed_time(decode_event)
                print(f"Elapsed time for decoding video: {elapsed_time_ms} ms")

            frames = self.video_processor.postprocess_video(
                video=frames, output_type=output_type
            )
        else:
            frames = latents_all

        self.maybe_free_model_hooks()

        if not return_dict:
            return frames

        return StableVideoDiffusionPipelineOutput(frames=frames)
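

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the pipeline itself. The checkpoint ids and
# the plain SVD UNet class below are assumptions made for illustration (the
# released DepthCrafter weights may ship their own UNet subclass); replace them
# with the weights and classes you actually use.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from diffusers import UNetSpatioTemporalConditionModel

    unet = UNetSpatioTemporalConditionModel.from_pretrained(
        "tencent/DepthCrafter",  # assumed checkpoint id
        torch_dtype=torch.float16,
    )
    pipe = DepthCrafterPipeline.from_pretrained(
        "stabilityai/stable-video-diffusion-img2vid-xt",  # assumed base weights
        unet=unet,
        torch_dtype=torch.float16,
    ).to("cuda")

    # A random clip in [0, 1] with shape [t, h, w, c] stands in for a real video.
    dummy_video = np.random.rand(30, 320, 640, 3).astype(np.float32)

    result = pipe(
        dummy_video,
        height=320,
        width=640,
        num_inference_steps=5,
        guidance_scale=1.0,
        window_size=110,
        overlap=25,
        output_type="np",
        track_time=False,
    )
    print(result.frames.shape)  # [1, t, h, w, c] for output_type="np"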