|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Optional

import flax.linen as nn
import jax
import jax.numpy as jnp
|
|
|
|
|
class FlaxUpsample2D(nn.Module):
    """Doubles the spatial size of an NHWC feature map, then applies a 3x3 conv.

    The resize uses nearest-neighbor interpolation; the convolution projects
    the result to `out_channels`.

    Attributes:
        out_channels: channel count produced by the convolution.
        dtype: dtype used for the convolution parameters/computation.
    """

    out_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Same-padding 3x3 convolution applied after the nearest resize.
        self.conv = nn.Conv(
            self.out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        batch, height, width, channels = hidden_states.shape
        upsampled_shape = (batch, height * 2, width * 2, channels)
        upsampled = jax.image.resize(hidden_states, shape=upsampled_shape, method="nearest")
        return self.conv(upsampled)
|
|
|
|
|
class FlaxDownsample2D(nn.Module):
    """Halves the spatial size of an NHWC feature map via a strided 3x3 conv.

    A single stride-2 convolution with ((1, 1), (1, 1)) spatial padding does
    both the downsampling and the channel projection to `out_channels`.

    Attributes:
        out_channels: channel count produced by the convolution.
        dtype: dtype used for the convolution parameters/computation.
    """

    out_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Stride-2 3x3 convolution: downsample and project in one step.
        self.conv = nn.Conv(
            self.out_channels,
            kernel_size=(3, 3),
            strides=(2, 2),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        return self.conv(hidden_states)
|
|
|
|
|
class FlaxResnetBlock2D(nn.Module):
    """Flax ResNet block with timestep-embedding injection.

    Pipeline: GroupNorm -> swish -> 3x3 conv, add the projected timestep
    embedding (broadcast over height/width), then GroupNorm -> swish ->
    dropout -> 3x3 conv, and finally add the residual. When the channel
    count changes (or when forced via `use_nin_shortcut`), the residual is
    first projected through a 1x1 "network-in-network" shortcut conv.

    Attributes:
        in_channels: channel count of the input feature map.
        out_channels: channel count of the output; when None, defaults to
            `in_channels`.
        dropout_prob: dropout rate applied before the second convolution.
        use_nin_shortcut: force-enable/disable the 1x1 shortcut conv; when
            None it is enabled iff the in/out channel counts differ.
        dtype: dtype used for parameters/computation.
    """

    in_channels: int
    # Fixed annotations: these fields default to None, so they must be
    # Optional (the original `int = None` / `bool = None` were invalid).
    out_channels: Optional[int] = None
    dropout_prob: float = 0.0
    use_nin_shortcut: Optional[bool] = None
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        out_channels = self.in_channels if self.out_channels is None else self.out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
        self.conv1 = nn.Conv(
            out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        # Projects the timestep embedding to the block's channel count.
        self.time_emb_proj = nn.Dense(out_channels, dtype=self.dtype)

        self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
        self.dropout = nn.Dropout(self.dropout_prob)
        self.conv2 = nn.Conv(
            out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut

        self.conv_shortcut = None
        if use_nin_shortcut:
            # 1x1 conv that reshapes the residual to `out_channels`.
            self.conv_shortcut = nn.Conv(
                out_channels,
                kernel_size=(1, 1),
                strides=(1, 1),
                padding="VALID",
                dtype=self.dtype,
            )

    def __call__(self, hidden_states, temb, deterministic=True):
        """Apply the block.

        Args:
            hidden_states: NHWC input feature map.
            temb: timestep embedding; projected by a Dense layer and then
                given two singleton spatial axes, so it is expected to be
                (batch, emb_dim) — broadcast over height/width on addition.
            deterministic: when True, dropout is disabled (inference mode).

        Returns:
            NHWC feature map with `out_channels` channels.
        """
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states = nn.swish(hidden_states)
        hidden_states = self.conv1(hidden_states)

        # Inject the timestep signal, broadcasting over the spatial axes.
        temb = self.time_emb_proj(nn.swish(temb))
        temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1)
        hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)
        hidden_states = nn.swish(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic)
        hidden_states = self.conv2(hidden_states)

        # Match residual channels to the output when a shortcut conv exists.
        if self.conv_shortcut is not None:
            residual = self.conv_shortcut(residual)

        return hidden_states + residual
|
|