Spaces:
Running
Running
File size: 5,623 Bytes
a560c26 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convolutional module library."""
import functools
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union
from flax import linen as nn
import jax
# Type aliases used for annotations in this module.
# `Tuple[int, ...]` admits int tuples of any length; plain `Tuple[int]` would
# only describe a tuple with exactly one element.
Shape = Tuple[int, ...]
DType = Any
Array = Any  # jnp.ndarray
# Recursive alias: an array, or an iterable / string-keyed mapping of trees.
ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]]  # pytype: disable=not-supported-yet
ProcessorState = ArrayTree
PRNGKey = Array
NestedDict = Dict[str, Any]
class SimpleCNN(nn.Module):
  """Simple CNN encoder: a chain of convolutions separated by ReLUs.

  Layer `i` is a convolution configured from `features[i]`, `kernel_size[i]`
  and `strides[i]`. Between two consecutive convolutions an optional
  `nn.BatchNorm` followed by `nn.relu` is applied. Nothing is applied after
  the last convolution other than the optional resize.

  Attributes:
    features: Number of output features of each convolution.
    kernel_size: Kernel size of each convolution.
    strides: Strides of each convolution.
    transpose: If True, every layer is `nn.ConvTranspose` instead of `nn.Conv`.
    use_batch_norm: If True, insert `nn.BatchNorm` before each ReLU.
    axis_name: Over which axis to aggregate batch stats.
    padding: Padding argument passed to every convolution.
    resize_output: If set, the output is linearly resized so that the
      dimensions between the leading `shape[:-3]` axes and the last axis take
      these sizes.
  """
  features: Sequence[int]
  kernel_size: Sequence[Tuple[int, int]]
  strides: Sequence[Tuple[int, int]]
  transpose: bool = False
  use_batch_norm: bool = False
  axis_name: Optional[str] = None  # Over which axis to aggregate batch stats.
  padding: Union[str, Iterable[Tuple[int, int]]] = "SAME"
  resize_output: Optional[Iterable[int]] = None

  @nn.compact
  def __call__(self, inputs, train: bool = False):
    """Applies the convolution stack to `inputs`.

    Args:
      inputs: Input array passed to the first convolution.
      train: When True, batch norm uses the current batch statistics rather
        than the running averages.

    Returns:
      Output of the last convolution, resized if `resize_output` is set.
    """
    depth = len(self.features)
    assert len(self.kernel_size) == depth, (
        "len(kernel_size) and len(features) must match.")
    assert len(self.strides) == depth, (
        "len(strides) and len(features) must match.")
    assert depth >= 1, "Need to have at least one layer."

    make_conv = nn.ConvTranspose if self.transpose else nn.Conv
    final_idx = depth - 1

    hidden = inputs
    for idx in range(depth):
      if idx > 0:
        # Optional norm + ReLU sit between consecutive convolutions.
        if self.use_batch_norm:
          hidden = nn.BatchNorm(
              momentum=0.9,
              use_running_average=not train,
              axis_name=self.axis_name,
              name=f"bn_simple_{idx - 1}")(hidden)
        hidden = nn.relu(hidden)

      # With batch norm, the bias is dropped on the first layer and on every
      # layer that is not the last one; otherwise the bias is kept.
      # NOTE(review): for a single-layer stack with batch norm this leaves the
      # only convolution without a bias and without a following norm.
      drop_bias = self.use_batch_norm and (idx == 0 or idx < final_idx)
      hidden = make_conv(
          name=f"conv_simple_{idx}",
          features=self.features[idx],
          kernel_size=self.kernel_size[idx],
          strides=self.strides[idx],
          use_bias=not drop_bias,
          padding=self.padding)(hidden)

    if self.resize_output:
      # Keep the leading axes and the last axis; replace the ones in between.
      target_shape = [
          *hidden.shape[:-3], *self.resize_output, hidden.shape[-1]]
      hidden = jax.image.resize(
          hidden, target_shape, method=jax.image.ResizeMethod.LINEAR)
    return hidden
class CNN(nn.Module):
  """Flexible CNN model with Conv/Normalization/Pooling layers.

  Every layer applies, in order: a convolution (regular or transposed, chosen
  per layer via `layer_transpose`), an optional normalization, the activation
  function and an optional max pooling. An optional dense layer follows the
  last layer.

  Attributes:
    features: Number of output features of each convolution.
    kernel_size: Kernel size of each convolution.
    strides: Strides of each convolution.
    max_pool_strides: Per-layer max-pool window, also used as its strides.
      Pooling is skipped for a layer whose entry equals `(1, 1)`.
    layer_transpose: Per-layer flag; True selects `nn.ConvTranspose`, False
      selects `nn.Conv`.
    activation_fn: Activation applied after the (optional) normalization.
    norm_type: One of "batch", "group", "instance", "layer", or None for no
      normalization. "group" uses 32 groups; "instance" is realized as a
      `nn.GroupNorm` with as many groups as the layer has features.
    axis_name: Over which axis to aggregate batch stats.
    output_size: If set, a final `nn.Dense` with this many features is applied.
  """
  features: Sequence[int]
  kernel_size: Sequence[Tuple[int, int]]
  strides: Sequence[Tuple[int, int]]
  max_pool_strides: Sequence[Tuple[int, int]]
  layer_transpose: Sequence[bool]
  activation_fn: Callable[[Array], Array] = nn.relu
  norm_type: Optional[str] = None
  axis_name: Optional[str] = None  # Over which axis to aggregate batch stats.
  output_size: Optional[int] = None

  @nn.compact
  def __call__(self, inputs, train: bool = False):
    """Applies the network to `inputs`.

    Args:
      inputs: Input array passed to the first convolution.
      train: When True, batch norm uses the current batch statistics rather
        than the running averages. Only relevant for `norm_type="batch"`.

    Returns:
      Output of the last layer, or of the dense layer if `output_size` is set.
    """
    num_layers = len(self.features)

    assert num_layers >= 1, "Need to have at least one layer."
    assert len(self.kernel_size) == num_layers, (
        "len(kernel_size) and len(features) must match.")
    assert len(self.strides) == num_layers, (
        "len(strides) and len(features) must match.")
    assert len(self.max_pool_strides) == num_layers, (
        "len(max_pool_strides) and len(features) must match.")
    assert len(self.layer_transpose) == num_layers, (
        "len(layer_transpose) and len(features) must match.")

    if self.norm_type:
      assert self.norm_type in {"batch", "group", "instance", "layer"}, (
          f"{self.norm_type} is unrecognizaed normalization")

    # Whether transpose conv or regular conv.
    conv_module = {False: nn.Conv, True: nn.ConvTranspose}

    # "instance" is handled inside the loop because its group count depends on
    # the layer's feature count.
    if self.norm_type == "batch":
      norm_module = functools.partial(
          nn.BatchNorm, momentum=0.9, use_running_average=not train,
          axis_name=self.axis_name)
    elif self.norm_type == "group":
      norm_module = functools.partial(
          nn.GroupNorm, num_groups=32)
    elif self.norm_type == "layer":
      norm_module = nn.LayerNorm

    x = inputs
    for i in range(num_layers):
      # The convolution bias is dropped whenever a normalization follows.
      x = conv_module[self.layer_transpose[i]](
          name=f"conv_{i}",
          features=self.features[i],
          kernel_size=self.kernel_size[i],
          strides=self.strides[i],
          use_bias=not self.norm_type)(x)

      # Normalization layer.
      if self.norm_type:
        if self.norm_type == "instance":
          x = nn.GroupNorm(
              num_groups=self.features[i],
              name=f"{self.norm_type}_norm_{i}")(x)
        else:
          # The normalized activations are the module's return value and must
          # be assigned back; calling the module without using its result
          # leaves `x` un-normalized.
          x = norm_module(name=f"{self.norm_type}_norm_{i}")(x)

      # Activation layer.
      x = self.activation_fn(x)

      # Max pooling layer.
      x = x if self.max_pool_strides[i] == (1, 1) else nn.max_pool(
          x, self.max_pool_strides[i], strides=self.max_pool_strides[i],
          padding="SAME")

    # Final dense layer.
    if self.output_size:
      x = nn.Dense(self.output_size, name="output_layer", use_bias=True)(x)
    return x
|