Spaces:
Sleeping
Sleeping
from typing import Tuple, Union | |
import torch | |
import torch.nn as nn | |
from torch.utils.checkpoint import checkpoint | |
import torch.nn.functional as F | |
from collections import deque | |
from einops import rearrange | |
from timm.models.layers import trunc_normal_ | |
from IPython import embed | |
from torch import Tensor | |
from utils import ( | |
is_context_parallel_initialized, | |
get_context_parallel_group, | |
get_context_parallel_world_size, | |
get_context_parallel_rank, | |
get_context_parallel_group_rank, | |
) | |
from .context_parallel_ops import ( | |
conv_scatter_to_context_parallel_region, | |
conv_gather_from_context_parallel_region, | |
cp_pass_from_previous_rank, | |
) | |
def divisible_by(num, den):
    """Return True when ``num`` is an exact multiple of ``den``."""
    remainder = num % den
    return remainder == 0
def cast_tuple(t, length=1):
    """Normalize ``t`` to a tuple.

    A value that is already a tuple is returned unchanged (``length`` is
    ignored); anything else is repeated ``length`` times inside a new tuple.
    """
    if isinstance(t, tuple):
        return t
    return (t,) * length
def is_odd(n):
    """Return True when ``n`` leaves a non-zero remainder modulo 2."""
    return bool(n % 2)
class CausalGroupNorm(nn.GroupNorm):
    """GroupNorm applied independently to every frame of a video tensor.

    The input is laid out as (B, C, T, H, W). Time is folded into the batch
    axis before normalizing, so each frame is normalized as its own sample
    and statistics never mix information across time steps.
    """

    def forward(self, x: Tensor) -> Tensor:
        num_frames = x.shape[2]
        # Fold time into batch: each frame becomes an independent 2-D sample.
        per_frame = rearrange(x, 'b c t h w -> (b t) c h w')
        normed = super().forward(per_frame)
        # Unfold back to the original 5-D video layout.
        return rearrange(normed, '(b t) c h w -> b c t h w', t=num_frames)
class CausalConv3d(nn.Module):
    """3-D convolution that is causal along the time axis.

    Inputs are 5-D tensors laid out as (B, C, T, H, W): the padding tuples
    below pad the last three dims in (W, H, T) order, and ``forward`` compares
    ``time_pad`` against ``x.shape[2]``. Time is padded only at the front
    (``time_pad`` frames, nothing at the back), so an output frame never sees
    later input frames. Height and width get symmetric "same" padding.

    Besides the plain causal path, ``forward`` supports
      * a context-parallel path (``context_parallel_forward``), taken whenever
        ``is_context_parallel_initialized()`` is true, and
      * an inference-only chunked path (``temporal_chunk=True``) that keeps the
        trailing frames of the previous chunk in ``cache_front_feat`` and
        prepends them to the next chunk instead of zero/front padding.

    Args:
        in_channels: input channels, forwarded to ``nn.Conv3d``.
        out_channels: output channels, forwarded to ``nn.Conv3d``.
        kernel_size: int (used for all three dims) or a (T, H, W) triple.
            The H and W sizes must be odd.
        stride: int (applied to time only, i.e. ``(stride, 1, 1)``) or a
            (T, H, W) triple.
        pad_mode: ``F.pad`` mode used for padding; ``forward`` falls back to
            ``'constant'`` when the input has at most ``time_pad`` frames.
        **kwargs: forwarded to ``nn.Conv3d``. ``dilation`` is popped first and
            also used to size the padding.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Union[int, Tuple[int, int, int]] = 1,
        pad_mode: str = 'constant',
        **kwargs
    ):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = cast_tuple(kernel_size, 3)

        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
        self.time_kernel_size = time_kernel_size
        # Symmetric "same" padding below only works for odd spatial kernels.
        assert is_odd(height_kernel_size) and is_odd(width_kernel_size)

        # Assumed to be an int: it is multiplied into the pad sizes below and
        # passed to nn.Conv3d, which then applies it to all three dims.
        dilation = kwargs.pop('dilation', 1)
        self.pad_mode = pad_mode

        if isinstance(stride, int):
            # An int stride only subsamples time; space keeps stride 1.
            stride = (stride, 1, 1)

        # The whole temporal receptive field goes in front of the sequence.
        time_pad = dilation * (time_kernel_size - 1)
        # Spatial "same" padding. nn.Conv3d applies an int dilation to H and W
        # as well, so the effective spatial kernel is dilation * (k - 1) + 1 and
        # the padding must include the dilation (equals k // 2 when dilation == 1).
        height_pad = dilation * (height_kernel_size - 1) // 2
        width_pad = dilation * (width_kernel_size - 1) // 2

        self.temporal_stride = stride[0]
        self.time_pad = time_pad
        # F.pad order for the last three dims: (W_left, W_right, H_top, H_bottom, T_front, T_back).
        self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
        # Spatial padding only: used where temporal context comes from elsewhere
        # (the chunk cache, or cp_pass_from_previous_rank in context-parallel mode).
        self.time_uncausal_padding = (width_pad, width_pad, height_pad, height_pad, 0, 0)

        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=0, dilation=dilation, **kwargs)
        # Trailing frames of the previous chunk, used by temporal_chunk inference.
        self.cache_front_feat = deque()

    def _clear_context_parallel_cache(self):
        """Drop any cached frames by replacing the cache with a fresh deque."""
        self.cache_front_feat = deque()

    def _init_weights(self, m):
        """Weight-init hook for a single module (presumably used via ``apply``).

        Linear/Conv2d/Conv3d: truncated-normal weights (std 0.02), zero bias.
        LayerNorm/GroupNorm: weight 1, bias 0. Affine parameters that are
        ``None`` (e.g. ``elementwise_affine=False`` / ``affine=False``) are
        skipped instead of raising.
        """
        if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def context_parallel_forward(self, x):
        """Causal convolution when the time axis is sharded across ranks.

        NOTE(review): ``cp_pass_from_previous_rank`` is defined elsewhere; it
        presumably prepends the frames this rank needs from its predecessor
        (sized by ``kernel_size``) — confirm in context_parallel_ops.
        """
        x = cp_pass_from_previous_rank(x, dim=2, kernel_size=self.time_kernel_size)
        # Spatial padding only; temporal context is expected from the exchange above.
        x = F.pad(x, self.time_uncausal_padding, mode='constant')
        cp_rank = get_context_parallel_rank()
        if cp_rank != 0 and self.temporal_stride == 2 and self.time_kernel_size == 3:
            # Non-first ranks drop one leading frame in the stride-2, k=3 case,
            # presumably to keep the stride phase aligned across shards — TODO confirm.
            x = x[:, :, 1:]
        return self.conv(x)

    def forward(self, x, is_init_image=True, temporal_chunk=False):
        """Apply the causal 3-D convolution.

        Args:
            x: 5-D input laid out as (B, C, T, H, W).
            is_init_image: only read when ``temporal_chunk`` is true. True for
                the first chunk of a sequence (causal front padding, cache
                reset); False for later chunks, which reuse the frames cached
                by the previous call.
            temporal_chunk: enable chunked inference with the frame cache.
                Not allowed while ``self.training`` is true.

        Returns:
            ``self.conv`` applied to the padded (and possibly context-extended) input.

        Raises:
            AssertionError: if ``temporal_chunk`` is used in training mode.
            RuntimeError: if a non-initial chunk arrives while the cache is empty.
        """
        if is_context_parallel_initialized():
            # Context-parallel mode ignores is_init_image / temporal_chunk.
            return self.context_parallel_forward(x)

        # Modes such as 'reflect' need the pad to be smaller than the padded dim.
        pad_mode = self.pad_mode if self.time_pad < x.shape[2] else 'constant'

        if not temporal_chunk:
            x = F.pad(x, self.time_causal_padding, mode=pad_mode)
        else:
            assert not self.training, "The feature cache should not be used in training"
            if is_init_image:
                # First chunk: ordinary causal padding, then start a fresh cache.
                x = F.pad(x, self.time_causal_padding, mode=pad_mode)
                self._clear_context_parallel_cache()
                # Keep the last 2 frames as front context for the next chunk
                # (2 == time_pad for time_kernel_size 3 with dilation 1).
                self.cache_front_feat.append(x[:, :, -2:].clone().detach())
            else:
                x = F.pad(x, self.time_uncausal_padding, mode=pad_mode)
                if not self.cache_front_feat:
                    raise RuntimeError(
                        "temporal_chunk=True with is_init_image=False requires a "
                        "preceding call with is_init_image=True to fill the cache"
                    )
                video_front_context = self.cache_front_feat.pop()
                self._clear_context_parallel_cache()

                # Prepend cached frames in place of causal front padding. Only
                # time_kernel_size 3 is handled; other kernels get no context.
                if self.temporal_stride == 1 and self.time_kernel_size == 3:
                    x = torch.cat([video_front_context, x], dim=2)
                elif self.temporal_stride == 2 and self.time_kernel_size == 3:
                    x = torch.cat([video_front_context[:, :, -1:], x], dim=2)

                self.cache_front_feat.append(x[:, :, -2:].clone().detach())

        return self.conv(x)