deanna-emery's picture
updates
5672777
raw
history blame
4.54 kB
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Definitions for MoViNet structures.
Reference: "MoViNets: Mobile Video Networks for Efficient Video Recognition"
https://arxiv.org/pdf/2103.11511.pdf
MoViNets are efficient video classification networks that are part of a model
family, ranging from the smallest model, MoViNet-A0, to the largest model,
MoViNet-A6. Each model has its own width, depth, input resolution, and input
frame rate. See the main paper for more details.
"""
import dataclasses
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
from official.vision.configs import backbones_3d
from official.vision.configs import common
from official.vision.configs import video_classification
@dataclasses.dataclass
class Movinet(hyperparams.Config):
  """Base backbone configuration shared by every MoViNet variant.

  The per-variant subclasses in this module only override `model_id`; every
  other field keeps the default declared here.

  Attributes:
    model_id: Identifier of the MoViNet variant (defaults to 'a0').
    causal: Whether the causal variant is requested (defaults to False).
    use_positional_encoding: Positional-encoding toggle (defaults to False).
    conv_type: Convolution flavor, one of:
      '3d' -- default 3D convolution;
      '2plus1d' -- (2+1)D convolution with Conv2D (2D reshaping);
      '3d_2plus1d' -- (2+1)D convolution with Conv3D (no 2D reshaping).
    se_type: Global average pooling variant, one of:
      '3d' -- default 3D global average pooling;
      '2d' -- 2D global average pooling;
      '2plus3d' -- concatenation of 2D and 3D global average pooling.
    activation: Name of the main activation (defaults to 'swish').
    gating_activation: Name of the gating activation (defaults to 'sigmoid').
    stochastic_depth_drop_rate: Stochastic-depth drop rate (defaults to 0.2).
    use_external_states: External-states toggle (defaults to False); its exact
      effect is defined by the backbone builder, not here.
    average_pooling_type: Average pooling type string (defaults to '3d').
    output_states: Output-states toggle. It defaults to True for the backbone,
      whereas `MovinetModel.output_states` defaults to False.
  """
  # NOTE: field order is significant for dataclasses; keep it unchanged.
  model_id: str = 'a0'
  causal: bool = False
  use_positional_encoding: bool = False
  conv_type: str = '3d'
  se_type: str = '3d'
  activation: str = 'swish'
  gating_activation: str = 'sigmoid'
  stochastic_depth_drop_rate: float = 0.2
  use_external_states: bool = False
  average_pooling_type: str = '3d'
  output_states: bool = True
@dataclasses.dataclass
class MovinetA0(Movinet):
  """MoViNet-A0 backbone configuration.

  The smallest base MoViNet found by the neural architecture search (NAS).
  Paper: https://arxiv.org/pdf/2103.11511.pdf
  """
  model_id: str = 'a0'
@dataclasses.dataclass
class MovinetA1(Movinet):
  """MoViNet-A1 backbone configuration; only `model_id` differs from the base."""
  model_id: str = 'a1'
@dataclasses.dataclass
class MovinetA2(Movinet):
  """MoViNet-A2 backbone configuration; only `model_id` differs from the base."""
  model_id: str = 'a2'
@dataclasses.dataclass
class MovinetA3(Movinet):
  """MoViNet-A3 backbone configuration; only `model_id` differs from the base."""
  model_id: str = 'a3'
@dataclasses.dataclass
class MovinetA4(Movinet):
  """MoViNet-A4 backbone configuration; only `model_id` differs from the base."""
  model_id: str = 'a4'
@dataclasses.dataclass
class MovinetA5(Movinet):
  """MoViNet-A5 backbone configuration.

  The largest base MoViNet found by the neural architecture search (NAS).
  """
  model_id: str = 'a5'
@dataclasses.dataclass
class MovinetT0(Movinet):
  """MoViNet-T0 backbone configuration.

  A smaller variant than MoViNet-A0, aimed at even faster processing.
  """
  model_id: str = 't0'
@dataclasses.dataclass
class Backbone3D(backbones_3d.Backbone3D):
  """Configuration for MoViNet 3D backbones.

  Attributes:
    type: 'str', type of backbone to be used, one of the fields below
      (defaults to 'movinet').
    movinet: MoViNet backbone config; a fresh `Movinet()` is created per
      instance via `default_factory`, so instances never share it.
  """
  type: str = 'movinet'
  movinet: Movinet = dataclasses.field(default_factory=Movinet)
def _default_movinet_norm_activation() -> common.NormActivation:
  """Builds the default normalization/activation config for `MovinetModel`.

  Returns:
    A new `common.NormActivation` with momentum 0.99, epsilon 1e-3 and
    synchronized batch norm enabled.
  """
  return common.NormActivation(
      activation=None,  # Legacy flag that is not used.
      norm_momentum=0.99,
      norm_epsilon=1e-3,
      use_sync_bn=True,
  )


@dataclasses.dataclass
class MovinetModel(video_classification.VideoClassificationModel):
  """Video classification model config that uses a MoViNet backbone.

  Attributes:
    model_type: Model type string (defaults to 'movinet').
    backbone: 3D backbone config; a fresh `Backbone3D()` per instance.
    norm_activation: Normalization settings; a fresh config is built per
      instance by `_default_movinet_norm_activation`.
    activation: Name of the model-level activation (defaults to 'swish').
    output_states: Output-states toggle (defaults to False at the model level).
  """
  model_type: str = 'movinet'
  backbone: Backbone3D = dataclasses.field(default_factory=Backbone3D)
  norm_activation: common.NormActivation = dataclasses.field(
      default_factory=_default_movinet_norm_activation)
  activation: str = 'swish'
  output_states: bool = False
@exp_factory.register_config_factory('movinet_kinetics600')
def movinet_kinetics600() -> cfg.ExperimentConfig:
  """Video classification on Kinetics-600 with a MoViNet backbone.

  Starts from the generic Kinetics-600 video classification experiment,
  switches the training and validation input dtypes to bfloat16, and replaces
  the task's model with a default `MovinetModel`.

  Returns:
    The experiment config registered under 'movinet_kinetics600'.
  """
  exp = video_classification.video_classification_kinetics600()
  # Use bfloat16 for both the training and the validation input data.
  exp.task.train_data.dtype = 'bfloat16'
  exp.task.validation_data.dtype = 'bfloat16'
  exp.task.model = MovinetModel()
  return exp