Spaces: Running on Zero
import torch | |
from typing import Tuple, Union | |
import torch | |
from einops import rearrange | |
from torch import nn | |
def make_triple(value: Union[int, Tuple[int, int, int]]) -> Tuple[int, int, int]:
    """Normalize a scalar-or-triple size argument to a 3-tuple.

    Args:
        value: Either a single int, which is repeated for all three
            positions, or a length-3 sequence of ints.

    Returns:
        A tuple ``(a, b, c)``. Sequence inputs (e.g. a list) are converted
        to a tuple so the return value always matches the annotation.

    Raises:
        ValueError: If ``value`` is a sequence whose length is not 3.
    """
    if isinstance(value, int):
        return (value,) * 3
    triple = tuple(value)
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if len(triple) != 3:
        raise ValueError(f"expected an int or a length-3 sequence, got {value!r}")
    return triple
class AudioPack(nn.Module):
    """Cut a 5-D tensor into non-overlapping 3-D patches and embed each patch.

    ``forward`` reads its input with the einops pattern
    ``b c (T t) (H h) (W w)``:
    - axes 2, 3 and 4 are split into blocks of size ``patch_size``, so each
      of those axis lengths must be divisible by the matching patch size;
    - every block is flattened in ``(t h w c)`` order;
    - the flattened block is mapped to ``dim`` features by a linear layer.

    Args:
        in_channels: Size of axis 1 of the input. Together with the patch
            volume it fixes the linear layer's input width.
        patch_size: Patch extent along axes 2, 3 and 4, given as one int for
            all three or as a 3-tuple.
        dim: Output feature size per patch.
        layernorm: When true, apply ``nn.LayerNorm(dim)`` after the projection.

    NOTE(review): the class is named for audio, but the forward argument is
    called ``vid``. What the three patched axes mean physically is not shown
    here; confirm with callers.
    """

    def __init__(
        self,
        in_channels: int,
        patch_size: Union[int, Tuple[int, int, int]],
        dim: int,
        layernorm=False,
    ):
        super().__init__()
        pt, ph, pw = make_triple(patch_size)
        self.patch_size = (pt, ph, pw)
        # A flattened patch holds every channel of every element in the block.
        patch_numel = in_channels * pt * ph * pw
        self.proj = nn.Linear(patch_numel, dim)
        # Registered after ``proj``; None means "no output normalization".
        self.norm_out = nn.LayerNorm(dim) if layernorm else None

    def forward(
        self,
        vid: torch.Tensor,
    ) -> torch.Tensor:
        """Patchify ``vid`` and project each patch.

        Returns a tensor laid out as ``b T H W dim``. Here ``T``, ``H`` and
        ``W`` are the patch counts along input axes 2, 3 and 4.
        """
        pt, ph, pw = self.patch_size
        patches = rearrange(
            vid, "b c (T t) (H h) (W w) -> b T H W (t h w c)", t=pt, h=ph, w=pw
        )
        embedded = self.proj(patches)
        if self.norm_out is None:
            return embedded
        return self.norm_out(embedded)