Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class CogVideoXUpsample3D(nn.Module):
    r"""
    Upsampling layer for video tensors used in CogVideoX by Tsinghua University & ZhipuAI.

    The input is upsampled by a factor of 2 with `F.interpolate` (default interpolation mode) and then passed,
    frame by frame, through a 2D convolution. Although the layer consumes 5D video tensors, the convolution itself
    is an `nn.Conv2d` that is applied to every frame independently.

    Args:
        in_channels (`int`):
            Number of channels in the input image.
        out_channels (`int`):
            Number of channels produced by the convolution.
        kernel_size (`int`, defaults to `3`):
            Size of the convolving kernel.
        stride (`int`, defaults to `1`):
            Stride of the convolution.
        padding (`int`, defaults to `1`):
            Padding added to all four sides of each frame before the convolution.
        compress_time (`bool`, defaults to `False`):
            Whether the frame axis is upsampled together with the spatial axes. When `False`, only height and width
            are upsampled and the number of frames is unchanged.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        compress_time: bool = False,
    ) -> None:
        super().__init__()

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.compress_time = compress_time

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        r"""
        Upsample `inputs` and apply the per-frame convolution.

        Args:
            inputs (`torch.Tensor`):
                Tensor laid out as `(batch, channels, frames, height, width)`.

        Returns:
            `torch.Tensor`: Tensor laid out as `(batch, out_channels, frames_out, height_out, width_out)`. With
            `compress_time=False`, `frames_out` equals the input frame count. With `compress_time=True`, an even
            frame count is doubled, an odd frame count `t > 1` becomes `2 * t - 1` (the first frame is kept as a
            single frame), and a single frame stays a single frame.
        """
        if self.compress_time:
            num_frames = inputs.shape[2]
            if num_frames > 1 and num_frames % 2 == 1:
                # Odd frame count: the first frame is upsampled spatially only (4D input to interpolate),
                # while the remaining frames are upsampled along the frame axis as well (5D input).
                x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
                x_first = F.interpolate(x_first, scale_factor=2.0)
                x_rest = F.interpolate(x_rest, scale_factor=2.0)
                inputs = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
            elif num_frames > 1:
                # Even frame count: upsample frames, height and width together.
                inputs = F.interpolate(inputs, scale_factor=2.0)
            else:
                # Single frame: drop the frame axis, upsample spatially, then restore the axis.
                inputs = F.interpolate(inputs.squeeze(2), scale_factor=2.0)
                inputs = inputs[:, :, None, :, :]

        # Fold the frame axis into the batch axis so that 2D operators can process every frame at once.
        b, c, t, h, w = inputs.shape
        frames = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)

        if not self.compress_time:
            # Spatial-only upsampling is done directly in the folded layout; the result feeds the convolution
            # without being unfolded to 5D and folded again in between.
            frames = F.interpolate(frames, scale_factor=2.0)

        frames = self.conv(frames)

        # Unfold back to (batch, channels, frames, height, width).
        return frames.reshape(b, t, *frames.shape[1:]).permute(0, 2, 1, 3, 4)