# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Convolutional module library."""

import functools
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union

from flax import linen as nn
import jax

Shape = Tuple[int, ...]
DType = Any
Array = Any  # jnp.ndarray
ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]]  # pytype: disable=not-supported-yet
ProcessorState = ArrayTree
PRNGKey = Array
NestedDict = Dict[str, Any]


class SimpleCNN(nn.Module):
  """Simple CNN encoder with multiple Conv+ReLU layers."""

  features: Sequence[int]
  kernel_size: Sequence[Tuple[int, int]]
  strides: Sequence[Tuple[int, int]]
  transpose: bool = False
  use_batch_norm: bool = False
  axis_name: Optional[str] = None  # Over which axis to aggregate batch stats.
  padding: Union[str, Iterable[Tuple[int, int]]] = "SAME"
  resize_output: Optional[Iterable[int]] = None

  @nn.compact
  def __call__(self, inputs, train = False):
    num_layers = len(self.features)

    assert len(self.kernel_size) == num_layers, (
        "len(kernel_size) and len(features) must match.")
    assert len(self.strides) == num_layers, (
        "len(strides) and len(features) must match.")
    assert num_layers >= 1, "Need to have at least one layer."

    if self.transpose:
      conv_module = nn.ConvTranspose
    else:
      conv_module = nn.Conv

    # First conv layer. The bias is redundant when batch norm follows, since
    # the norm's learned shift subsumes it.
    x = conv_module(
        name="conv_simple_0",
        features=self.features[0],
        kernel_size=self.kernel_size[0],
        strides=self.strides[0],
        use_bias=not self.use_batch_norm,
        padding=self.padding)(inputs)

    for i in range(1, num_layers):
      if self.use_batch_norm:
        x = nn.BatchNorm(
            momentum=0.9,
            use_running_average=not train,
            axis_name=self.axis_name,
            name=f"bn_simple_{i-1}")(x)
      x = nn.relu(x)
      # The final conv layer keeps its bias because no batch norm follows it.
      x = conv_module(
          name=f"conv_simple_{i}",
          features=self.features[i],
          kernel_size=self.kernel_size[i],
          strides=self.strides[i],
          use_bias=not (self.use_batch_norm and i < num_layers - 1),
          padding=self.padding)(x)

    if self.resize_output:
      # Bilinearly resize the spatial dimensions; batch and channel
      # dimensions are kept unchanged.
      x = jax.image.resize(
          x, list(x.shape[:-3]) + list(self.resize_output) + [x.shape[-1]],
          method=jax.image.ResizeMethod.LINEAR)
    return x
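
# Example usage of SimpleCNN (a minimal sketch; the NHWC input shape and
# layer configuration below are illustrative assumptions, not values from
# this library):
#
#   import jax.numpy as jnp
#
#   encoder = SimpleCNN(
#       features=[32, 64],
#       kernel_size=[(3, 3), (3, 3)],
#       strides=[(1, 1), (2, 2)])
#   variables = encoder.init(jax.random.PRNGKey(0), jnp.zeros((1, 64, 64, 3)))
#   out = encoder.apply(variables, jnp.ones((1, 64, 64, 3)))
#   # out.shape == (1, 32, 32, 64): stride (2, 2) halves the spatial dims.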

class CNN(nn.Module):
  """Flexible CNN model with Conv/Normalization/Pooling layers."""

  features: Sequence[int]
  kernel_size: Sequence[Tuple[int, int]]
  strides: Sequence[Tuple[int, int]]
  max_pool_strides: Sequence[Tuple[int, int]]
  layer_transpose: Sequence[bool]
  activation_fn: Callable[[Array], Array] = nn.relu
  norm_type: Optional[str] = None
  axis_name: Optional[str] = None  # Over which axis to aggregate batch stats.
  output_size: Optional[int] = None

  @nn.compact
  def __call__(self, inputs, train = False):
    num_layers = len(self.features)

    assert num_layers >= 1, "Need to have at least one layer."
    assert len(self.kernel_size) == num_layers, (
        "len(kernel_size) and len(features) must match.")
    assert len(self.strides) == num_layers, (
        "len(strides) and len(features) must match.")
    assert len(self.max_pool_strides) == num_layers, (
        "len(max_pool_strides) and len(features) must match.")
    assert len(self.layer_transpose) == num_layers, (
        "len(layer_transpose) and len(features) must match.")

    if self.norm_type:
      assert self.norm_type in {"batch", "group", "instance", "layer"}, (
          f"{self.norm_type} is an unrecognized normalization type.")

    # Whether to use transposed conv or regular conv, per layer.
    conv_module = {False: nn.Conv, True: nn.ConvTranspose}

    if self.norm_type == "batch":
      norm_module = functools.partial(
          nn.BatchNorm, momentum=0.9, use_running_average=not train,
          axis_name=self.axis_name)
    elif self.norm_type == "group":
      norm_module = functools.partial(nn.GroupNorm, num_groups=32)
    elif self.norm_type == "layer":
      norm_module = nn.LayerNorm

    x = inputs
    for i in range(num_layers):
      # Convolution layer. The bias is redundant when a normalization layer
      # with a learned shift follows.
      x = conv_module[self.layer_transpose[i]](
          name=f"conv_{i}",
          features=self.features[i],
          kernel_size=self.kernel_size[i],
          strides=self.strides[i],
          use_bias=not self.norm_type)(x)

      # Normalization layer.
      if self.norm_type:
        if self.norm_type == "instance":
          # Instance norm is group norm with one group per channel.
          x = nn.GroupNorm(
              num_groups=self.features[i],
              name=f"{self.norm_type}_norm_{i}")(x)
        else:
          x = norm_module(name=f"{self.norm_type}_norm_{i}")(x)

      # Activation layer.
      x = self.activation_fn(x)

      # Max pooling layer (skipped for unit strides).
      if self.max_pool_strides[i] != (1, 1):
        x = nn.max_pool(
            x, self.max_pool_strides[i], strides=self.max_pool_strides[i],
            padding="SAME")

    # Final dense layer.
    if self.output_size:
      x = nn.Dense(self.output_size, name="output_layer", use_bias=True)(x)
    return x
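
if __name__ == "__main__":
  # Minimal smoke test for CNN (a sketch with illustrative shapes and
  # configuration, assuming an NHWC input; not part of the original library).
  import jax.numpy as jnp

  model = CNN(
      features=[32, 64],
      kernel_size=[(3, 3), (3, 3)],
      strides=[(1, 1), (1, 1)],
      max_pool_strides=[(2, 2), (2, 2)],
      layer_transpose=[False, False],
      norm_type="group")
  variables = model.init(jax.random.PRNGKey(0), jnp.zeros((1, 64, 64, 3)))
  out = model.apply(variables, jnp.ones((1, 64, 64, 3)))
  # Each (2, 2) max pool halves the spatial dims: 64 -> 32 -> 16.
  print(out.shape)  # Expected: (1, 16, 16, 64).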