# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implementation of ResNet V1 in Flax.

"Deep Residual Learning for Image Recognition"
He et al., 2015, [https://arxiv.org/abs/1512.03385]
"""

import functools
from typing import Any, Callable, List, Optional, Sequence, Tuple, Type

import flax.linen as nn
import jax.numpy as jnp
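
# Bias-free 1x1 and 3x3 convolution shortcuts used throughout this module:
# every convolution below is immediately followed by a normalization layer,
# so a bias term would be redundant.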
Conv1x1 = functools.partial(nn.Conv, kernel_size=(1, 1), use_bias=False)
Conv3x3 = functools.partial(nn.Conv, kernel_size=(3, 3), use_bias=False)


class ResNetBlock(nn.Module):
  """ResNet block without bottleneck used in ResNet-18 and ResNet-34."""
  filters: int
  norm: Any
  kernel_dilation: Tuple[int, int] = (1, 1)
  strides: Tuple[int, int] = (1, 1)

  @nn.compact
  def __call__(self, x):
    residual = x
    x = Conv3x3(
        self.filters,
        strides=self.strides,
        kernel_dilation=self.kernel_dilation,
        name="conv1")(x)
    x = self.norm(name="bn1")(x)
    x = nn.relu(x)
    x = Conv3x3(self.filters, name="conv2")(x)
    # Initializing the scale to 0 has been common practice since "Fixup
    # Initialization: Residual Learning Without Normalization" Zhang et al.,
    # 2019, [https://openreview.net/forum?id=H1gsz30cKX].
    x = self.norm(scale_init=nn.initializers.zeros, name="bn2")(x)
    if residual.shape != x.shape:
      residual = Conv1x1(
          self.filters, strides=self.strides, name="proj_conv")(residual)
      residual = self.norm(name="proj_bn")(residual)
    x = nn.relu(residual + x)
    return x


class BottleneckResNetBlock(ResNetBlock):
  """Bottleneck ResNet block used in ResNet-50 and larger."""

  @nn.compact
  def __call__(self, x):
    residual = x
    x = Conv1x1(self.filters, name="conv1")(x)
    x = self.norm(name="bn1")(x)
    x = nn.relu(x)
    x = Conv3x3(
        self.filters,
        strides=self.strides,
        kernel_dilation=self.kernel_dilation,
        name="conv2")(x)
    x = self.norm(name="bn2")(x)
    x = nn.relu(x)
    x = Conv1x1(4 * self.filters, name="conv3")(x)
    # Initializing the scale to 0 has been common practice since "Fixup
    # Initialization: Residual Learning Without Normalization" Zhang et al.,
    # 2019, [https://openreview.net/forum?id=H1gsz30cKX].
    x = self.norm(scale_init=nn.initializers.zeros, name="bn3")(x)
    if residual.shape != x.shape:
      residual = Conv1x1(
          4 * self.filters, strides=self.strides, name="proj_conv")(residual)
      residual = self.norm(name="proj_bn")(residual)
    x = nn.relu(residual + x)
    return x


class ResNetStage(nn.Module):
  """ResNet stage consisting of multiple ResNet blocks."""
  stage_size: int
  filters: int
  block_cls: Type[ResNetBlock]
  norm: Any
  first_block_strides: Tuple[int, int]

  @nn.compact
  def __call__(self, x):
    for i in range(self.stage_size):
      x = self.block_cls(
          filters=self.filters,
          norm=self.norm,
          strides=self.first_block_strides if i == 0 else (1, 1),
          name=f"block{i + 1}")(x)
    return x


class ResNet(nn.Module):
  """Construct ResNet V1 with `num_classes` outputs.

  A short usage example appears at the bottom of this module.

  Attributes:
    num_classes: Number of nodes in the final layer.
    block_cls: Class for the blocks. ResNet-50 and larger use
      `BottleneckResNetBlock` (convolutions: 1x1, 3x3, 1x1); ResNet-18 and
      ResNet-34 use `ResNetBlock` without bottleneck (two 3x3 convolutions).
    stage_sizes: List with the number of ResNet blocks in each stage. The
      number of stages can be varied.
    norm_type: Which type of normalization layer to apply. Options are:
      "batch": BatchNorm, "group": GroupNorm, "layer": LayerNorm. Defaults to
      BatchNorm.
    width_factor: Factor applied to the number of filters: the first stage
      uses 64 * width_factor filters, and every subsequent stage doubles the
      number of filters.
    small_inputs: If True, use a 3x3 kernel with stride 1 in the root block
      and skip the max pooling there, which suits small input images.
    stage_strides: Strides for the first block of each stage. If given, this
      overrides the default strides ((1, 1) for the first stage, (2, 2) for
      every later stage).
    include_top: Whether to include the fully-connected layer at the top of
      the network.
    axis_name: Axis name over which to aggregate batchnorm statistics.
    output_initializer: Initializer for the kernel of the output `Dense`
      layer; defaults to zeros.
  """
  num_classes: int
  block_cls: Type[ResNetBlock]
  stage_sizes: List[int]
  norm_type: str = "batch"
  width_factor: int = 1
  small_inputs: bool = False
  stage_strides: Optional[List[Tuple[int, int]]] = None
  include_top: bool = False
  axis_name: Optional[str] = None
  output_initializer: Callable[[Any, Sequence[int], Any], Any] = (
      nn.initializers.zeros)

  @nn.compact
  def __call__(self, x, *, train):
    """Apply the ResNet to the inputs `x`.

    Args:
      x: Inputs.
      train: Whether to use BatchNorm in training or inference mode.

    Returns:
      If `include_top` is True, the output head with `num_classes` entries;
      otherwise the feature map produced by the final stage.
    """
    width = 64 * self.width_factor

    if self.norm_type == "batch":
      norm = functools.partial(
          nn.BatchNorm, use_running_average=not train, momentum=0.9,
          axis_name=self.axis_name)
    elif self.norm_type == "layer":
      norm = nn.LayerNorm
    elif self.norm_type == "group":
      norm = nn.GroupNorm
    else:
      raise ValueError(f"Invalid norm_type: {self.norm_type}")

    # Root block.
    x = nn.Conv(
        features=width,
        kernel_size=(7, 7) if not self.small_inputs else (3, 3),
        strides=(2, 2) if not self.small_inputs else (1, 1),
        use_bias=False,
        name="init_conv")(x)
    x = norm(name="init_bn")(x)
    if not self.small_inputs:
      x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")

    # Stages.
    for i, stage_size in enumerate(self.stage_sizes):
      if i == 0:
        first_block_strides = (
            (1, 1) if self.stage_strides is None else self.stage_strides[i])
      else:
        first_block_strides = (
            (2, 2) if self.stage_strides is None else self.stage_strides[i])
      x = ResNetStage(
          stage_size,
          filters=width * 2**i,
          block_cls=self.block_cls,
          norm=norm,
          first_block_strides=first_block_strides,
          name=f"stage{i + 1}")(x)

    # Head.
    if self.include_top:
      x = jnp.mean(x, axis=(1, 2))
      x = nn.Dense(
          self.num_classes, kernel_init=self.output_initializer,
          name="head")(x)
    return x


ResNetWithBasicBlk = functools.partial(ResNet, block_cls=ResNetBlock)
ResNetWithBottleneckBlk = functools.partial(
    ResNet, block_cls=BottleneckResNetBlock)

ResNet18 = functools.partial(ResNetWithBasicBlk, stage_sizes=[2, 2, 2, 2])
ResNet34 = functools.partial(ResNetWithBasicBlk, stage_sizes=[3, 4, 6, 3])
ResNet50 = functools.partial(ResNetWithBottleneckBlk, stage_sizes=[3, 4, 6, 3])
ResNet101 = functools.partial(
    ResNetWithBottleneckBlk, stage_sizes=[3, 4, 23, 3])
ResNet152 = functools.partial(
    ResNetWithBottleneckBlk, stage_sizes=[3, 8, 36, 3])
ResNet200 = functools.partial(
    ResNetWithBottleneckBlk, stage_sizes=[3, 24, 36, 3])
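

# Minimal usage sketch: build ResNet18 with the classification head,
# initialize its variables, and run a forward pass in inference mode. The
# input shape and class count below are illustrative; any NHWC input large
# enough for the strided root block works.
if __name__ == "__main__":
  import jax

  model = ResNet18(num_classes=10, include_top=True)
  dummy = jnp.ones((1, 224, 224, 3))  # NHWC inputs.
  variables = model.init(jax.random.PRNGKey(0), dummy, train=False)
  # With train=False, BatchNorm uses its (initialized) running statistics,
  # so no mutable collections are needed during apply.
  logits = model.apply(variables, dummy, train=False)
  print(logits.shape)  # (1, 10)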