from typing import Optional, Union, Sequence

import jax
import jax.numpy as jnp
import flax.linen as nn
import einops


class ConvPseudo3D(nn.Module):
    features: int
    kernel_size: Sequence[int]
    strides: Union[None, int, Sequence[int]] = 1
    padding: nn.linear.PaddingLike = 'SAME'
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        # spatial (2D) convolution, applied independently to each frame
        self.spatial_conv = nn.Conv(
            features = self.features,
            kernel_size = self.kernel_size,
            strides = self.strides,
            padding = self.padding,
            dtype = self.dtype
        )
        # temporal (1D) convolution, applied across frames at each spatial position
        self.temporal_conv = nn.Conv(
            features = self.features,
            kernel_size = (3,),
            padding = 'SAME',
            dtype = self.dtype,
            bias_init = nn.initializers.zeros_init()
            # TODO dirac delta (identity) initialization impl
            # kernel_init = torch.nn.init.dirac_ <-> jax/lax
        )

    def __call__(self, x: jax.Array, convolve_across_time: bool = True) -> jax.Array:
        is_video = x.ndim == 5
        convolve_across_time = convolve_across_time and is_video
        if is_video:
            b, f, h, w, c = x.shape
            x = einops.rearrange(x, 'b f h w c -> (b f) h w c')
        x = self.spatial_conv(x)
        if is_video:
            x = einops.rearrange(x, '(b f) h w c -> b f h w c', b = b)
            b, f, h, w, c = x.shape
        if not convolve_across_time:
            return x
        if is_video:
            x = einops.rearrange(x, 'b f h w c -> (b h w) f c')
            x = self.temporal_conv(x)
            x = einops.rearrange(x, '(b h w) f c -> b f h w c', h = h, w = w)
        return x


class UpsamplePseudo3D(nn.Module):
    out_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        self.conv = ConvPseudo3D(
            features = self.out_channels,
            kernel_size = (3, 3),
            strides = (1, 1),
            padding = ((1, 1), (1, 1)),
            dtype = self.dtype
        )

    def __call__(self, hidden_states: jax.Array) -> jax.Array:
        is_video = hidden_states.ndim == 5
        if is_video:
            b, *_ = hidden_states.shape
            hidden_states = einops.rearrange(hidden_states, 'b f h w c -> (b f) h w c')
        batch, h, w, c = hidden_states.shape
        # nearest-neighbour 2x spatial upsampling, applied per frame
        hidden_states = jax.image.resize(
            image = hidden_states,
            shape = (batch, h * 2, w * 2, c),
            method = 'nearest'
        )
        if is_video:
            hidden_states = einops.rearrange(hidden_states, '(b f) h w c -> b f h w c', b = b)
        hidden_states = self.conv(hidden_states)
        return hidden_states


class DownsamplePseudo3D(nn.Module):
    out_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        # strided 3x3 convolution halves the spatial resolution
        self.conv = ConvPseudo3D(
            features = self.out_channels,
            kernel_size = (3, 3),
            strides = (2, 2),
            padding = ((1, 1), (1, 1)),
            dtype = self.dtype
        )

    def __call__(self, hidden_states: jax.Array) -> jax.Array:
        hidden_states = self.conv(hidden_states)
        return hidden_states


class ResnetBlockPseudo3D(nn.Module):
    in_channels: int
    out_channels: Optional[int] = None
    use_nin_shortcut: Optional[bool] = None
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        out_channels = self.in_channels if self.out_channels is None else self.out_channels
        self.norm1 = nn.GroupNorm(
            num_groups = 32,
            epsilon = 1e-5
        )
        self.conv1 = ConvPseudo3D(
            features = out_channels,
            kernel_size = (3, 3),
            strides = (1, 1),
            padding = ((1, 1), (1, 1)),
            dtype = self.dtype
        )
        self.time_emb_proj = nn.Dense(
            out_channels,
            dtype = self.dtype
        )
        self.norm2 = nn.GroupNorm(
            num_groups = 32,
            epsilon = 1e-5
        )
        self.conv2 = ConvPseudo3D(
            features = out_channels,
            kernel_size = (3, 3),
            strides = (1, 1),
            padding = ((1, 1), (1, 1)),
            dtype = self.dtype
        )
        use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut
        self.conv_shortcut = None
        if use_nin_shortcut:
            # 1x1 projection so the residual matches the output channel count;
            # use the resolved out_channels (self.out_channels may be None)
            self.conv_shortcut = ConvPseudo3D(
                features = out_channels,
                kernel_size = (1, 1),
                strides = (1, 1),
                padding = 'VALID',
                dtype = self.dtype
            )

    def __call__(self, hidden_states: jax.Array, temb: jax.Array) -> jax.Array:
        is_video = hidden_states.ndim == 5
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states = nn.silu(hidden_states)
        hidden_states = self.conv1(hidden_states)
        # project the timestep embedding and broadcast it over the spatial dims
        temb = nn.silu(temb)
        temb = self.time_emb_proj(temb)
        temb = jnp.expand_dims(temb, 1)
        temb = jnp.expand_dims(temb, 1)
        if is_video:
            b, f, *_ = hidden_states.shape
            hidden_states = einops.rearrange(hidden_states, 'b f h w c -> (b f) h w c')
            # repeat the per-sample embedding for every frame in the clip
            hidden_states = hidden_states + temb.repeat(f, 0)
            hidden_states = einops.rearrange(hidden_states, '(b f) h w c -> b f h w c', b = b)
        else:
            hidden_states = hidden_states + temb
        hidden_states = self.norm2(hidden_states)
        hidden_states = nn.silu(hidden_states)
        hidden_states = self.conv2(hidden_states)
        if self.conv_shortcut is not None:
            residual = self.conv_shortcut(residual)
        hidden_states = hidden_states + residual
        return hidden_states
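

# --- Sketch: dirac delta (identity) kernel initializer ------------------------
# The TODO in ConvPseudo3D.setup asks for a JAX counterpart of
# torch.nn.init.dirac_ for the temporal conv. Below is a minimal, hedged sketch
# of one possible implementation, not part of the original code: the name
# `dirac_init` and its wiring via `kernel_init = dirac_init` are assumptions.
# Flax's nn.Conv with kernel_size=(3,) expects a kernel of shape
# (kernel_width, in_features, out_features); putting an identity at the centre
# tap makes the temporal conv a pass-through at initialization, so the
# pseudo-3D block initially behaves like its purely 2D counterpart.
def dirac_init(key: jax.Array, shape: Sequence[int], dtype: jnp.dtype = jnp.float32) -> jax.Array:
    del key  # deterministic initializer; the PRNG key is unused
    kernel_width, in_features, out_features = shape
    kernel = jnp.zeros(shape, dtype)
    centre = kernel_width // 2
    channels = min(in_features, out_features)
    # identity mapping: input channel i -> output channel i at the centre tap
    return kernel.at[centre, jnp.arange(channels), jnp.arange(channels)].set(1.0)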