# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn.functional as F

def split_with_overlap(video_BCTHW, num_video_frames, overlap=2, tobf16=True):
    """
    Splits a video tensor into chunks of num_video_frames frames with a specified overlap.

    Args:
    - video_BCTHW (torch.Tensor): Input tensor with shape [Batch, Channels, Time, Height, Width].
    - num_video_frames (int): Number of frames per chunk.
    - overlap (int): Number of overlapping frames between consecutive chunks.
    - tobf16 (bool): If True, cast each chunk to torch.bfloat16.

    Returns:
    - List of torch.Tensor: Video chunks with the requested overlap.
    """
    # Get the dimensions of the input tensor
    B, C, T, H, W = video_BCTHW.shape
    # Ensure overlap is less than num_video_frames
    assert overlap < num_video_frames, "Overlap should be less than num_video_frames."
    # List to store the chunks
    chunks = []
    # Step size for the sliding window
    step = num_video_frames - overlap
    # Loop through the time dimension (T) with the sliding window
    for start in range(0, T - overlap, step):
        end = start + num_video_frames
        if end > T:
            # The last window runs past the end of the video: reflect-pad the
            # tail chunk along the time dimension up to num_video_frames.
            # Note: reflection padding requires num_padding_frames < T - start.
            num_padding_frames = end - T
            chunk = F.pad(video_BCTHW[:, :, start:T, :, :], (0, 0, 0, 0, 0, num_padding_frames), mode="reflect")
        else:
            # Regular case: no padding needed
            chunk = video_BCTHW[:, :, start:end, :, :]
        chunks.append(chunk.to(torch.bfloat16) if tobf16 else chunk)
    return chunks
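
# A minimal usage sketch (illustrative shapes, not part of the original file):
# a 16-frame clip split into 9-frame chunks with a 2-frame overlap yields
# window starts at t = 0 and t = 7, so two chunks sharing frames 7 and 8.
#
#   video = torch.randn(1, 3, 16, 32, 32)
#   chunks = split_with_overlap(video, num_video_frames=9, overlap=2, tobf16=False)
#   assert len(chunks) == 2 and chunks[0].shape == (1, 3, 9, 32, 32)
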
def linear_blend_video_list(videos, D):
    """
    Linearly blends a list of equally sized videos along the time dimension,
    cross-fading over an overlap of D frames at each seam.

    Args:
    - videos: list of video tensors, each of shape [b, c, t, h, w].
    - D (int): Overlap length in frames.

    Returns:
    - output_video: blended video tensor of shape [b, c, L, h, w],
      where L = N * t - D * (N - 1) for N input videos.
    """
    assert len(videos) >= 2, "At least two videos are required."
    b, c, t, h, w = videos[0].shape
    N = len(videos)
    assert 0 < D < t, "Overlap D must be positive and smaller than the chunk length t."
    # Ensure all videos have the same shape
    for video in videos:
        assert video.shape == (b, c, t, h, w), "All videos must have the same shape."
    # Total output length: N chunks of t frames, minus N - 1 overlaps of D frames
    L = N * t - D * (N - 1)
    output_video = torch.zeros((b, c, L, h, w), device=videos[0].device, dtype=videos[0].dtype)
    output_index = 0  # Current index in the output video
    for i in range(N):
        if i == 0:
            # Copy the first video's frames up to the start of the first overlap
            output_video[:, :, output_index : output_index + t - D, :, :] = videos[i][:, :, : t - D, :, :]
            output_index += t - D
        else:
            # Cross-fade the D overlapping frames: the weight ramps linearly
            # from the tail of videos[i-1] to the head of videos[i]
            blend_weights = torch.linspace(0, 1, steps=D, device=videos[0].device).view(1, 1, D, 1, 1)
            prev_tail = videos[i - 1][:, :, t - D :, :, :]
            curr_head = videos[i][:, :, :D, :, :]
            output_video[:, :, output_index : output_index + D, :, :] = (
                1 - blend_weights
            ) * prev_tail + blend_weights * curr_head
            output_index += D
            if i < N - 1:
                # Copy the current video's frames between its two overlap regions
                frames_to_copy = t - 2 * D
                if frames_to_copy > 0:
                    output_video[:, :, output_index : output_index + frames_to_copy, :, :] = videos[i][
                        :, :, D : t - D, :, :
                    ]
                    output_index += frames_to_copy
            else:
                # For the last video, copy all remaining frames from D to t
                frames_to_copy = t - D
                output_video[:, :, output_index : output_index + frames_to_copy, :, :] = videos[i][:, :, D:, :, :]
                output_index += frames_to_copy
    return output_video
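

# Hypothetical round-trip check (not part of the original file): splitting and
# then re-blending unmodified chunks should reproduce the source video, since
# the cross-fade endpoints land on frames where both chunks hold identical data.
if __name__ == "__main__":
    video = torch.randn(1, 3, 16, 32, 32)
    chunks = split_with_overlap(video, num_video_frames=9, overlap=2, tobf16=False)
    merged = linear_blend_video_list(chunks, D=2)
    assert merged.shape == video.shape
    assert torch.allclose(merged, video, atol=1e-5)
    print("round-trip OK:", tuple(merged.shape))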