import flax.linen as nn
import jax
import jax.numpy as jnp


class FlaxUpsample2D(nn.Module):
    """Upsamples a feature map 2x: nearest-neighbor resize followed by a 3x3 convolution."""

    out_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv = nn.Conv(
            self.out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        # Inputs are NHWC; double the spatial dimensions, then refine
        # with the 3x3 convolution.
        batch, height, width, channels = hidden_states.shape
        hidden_states = jax.image.resize(
            hidden_states,
            shape=(batch, height * 2, width * 2, channels),
            method="nearest",
        )
        hidden_states = self.conv(hidden_states)
        return hidden_states


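# Minimal usage sketch (illustrative; shapes and seed are assumptions):
#   upsample = FlaxUpsample2D(out_channels=64)
#   x = jnp.zeros((1, 16, 16, 64))                    # NHWC input
#   params = upsample.init(jax.random.PRNGKey(0), x)
#   y = upsample.apply(params, x)                     # (1, 32, 32, 64)

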
class FlaxDownsample2D(nn.Module):
    """Downsamples a feature map 2x with a stride-2 3x3 convolution."""

    out_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv = nn.Conv(
            self.out_channels,
            kernel_size=(3, 3),
            strides=(2, 2),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        hidden_states = self.conv(hidden_states)
        return hidden_states


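# Minimal usage sketch (illustrative; shapes and seed are assumptions). With
# kernel 3, stride 2, and (1, 1) padding, even spatial dims are halved:
#   downsample = FlaxDownsample2D(out_channels=64)
#   x = jnp.zeros((1, 32, 32, 64))
#   params = downsample.init(jax.random.PRNGKey(0), x)
#   y = downsample.apply(params, x)                   # (1, 16, 16, 64)

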
class FlaxResnetBlock2D(nn.Module):
    """Residual block with GroupNorm, swish, two 3x3 convolutions, and additive time conditioning."""

    in_channels: int
    out_channels: int = None  # defaults to in_channels when None
    dropout_prob: float = 0.0
    use_nin_shortcut: bool = None  # None: use a 1x1 shortcut iff channel counts differ
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        out_channels = self.in_channels if self.out_channels is None else self.out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
        self.conv1 = nn.Conv(
            out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        self.time_emb_proj = nn.Dense(out_channels, dtype=self.dtype)

        self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
        self.dropout = nn.Dropout(self.dropout_prob)
        self.conv2 = nn.Conv(
            out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        # Project the residual with a 1x1 convolution when the channel
        # count changes, unless explicitly overridden.
        use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut

        self.conv_shortcut = None
        if use_nin_shortcut:
            self.conv_shortcut = nn.Conv(
                out_channels,
                kernel_size=(1, 1),
                strides=(1, 1),
                padding="VALID",
                dtype=self.dtype,
            )

    def __call__(self, hidden_states, temb, deterministic=True):
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states = nn.swish(hidden_states)
        hidden_states = self.conv1(hidden_states)

        # Broadcast the projected time embedding over the spatial
        # dimensions: (batch, channels) -> (batch, 1, 1, channels).
        temb = self.time_emb_proj(nn.swish(temb))
        temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1)
        hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)
        hidden_states = nn.swish(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            residual = self.conv_shortcut(residual)

        return hidden_states + residual


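# Minimal runnable sketch (illustrative assumptions: shapes, channel counts, seed):
if __name__ == "__main__":
    block = FlaxResnetBlock2D(in_channels=32, out_channels=64)
    x = jnp.ones((1, 8, 8, 32))   # NHWC feature map
    temb = jnp.ones((1, 128))     # one time-embedding vector per batch item
    params = block.init(jax.random.PRNGKey(0), x, temb)
    y = block.apply(params, x, temb)
    print(y.shape)  # (1, 8, 8, 64): channels change, so the 1x1 shortcut projects the residual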