# NOTE: page residue from the hosting site ("Spaces: Runtime error"); not part
# of the module itself.
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Definitions for MoViNet structures.

Reference: "MoViNets: Mobile Video Networks for Efficient Video Recognition"
https://arxiv.org/pdf/2103.11511.pdf

MoViNets are efficient video classification networks that are part of a model
family, ranging from the smallest model, MoViNet-A0, to the largest model,
MoViNet-A6. Each model has its own width, depth, input resolution, and input
frame-rate. See the main paper for more details.
"""
import dataclasses | |
from official.core import config_definitions as cfg | |
from official.core import exp_factory | |
from official.modeling import hyperparams | |
from official.vision.configs import backbones_3d | |
from official.vision.configs import common | |
from official.vision.configs import video_classification | |
# The dataclass decorator registers the annotated attributes below as dataclass
# fields, so they become constructor arguments with the listed defaults.
@dataclasses.dataclass
class Movinet(hyperparams.Config):
  """Backbone config for Base MoViNet.

  Every attribute is a plain configuration value with the default shown below.
  The allowed choices for `conv_type` and `se_type` are listed in the comments
  next to those fields.
  """
  # Identifier of the MoViNet variant (e.g. 'a0').
  model_id: str = 'a0'
  causal: bool = False
  use_positional_encoding: bool = False
  # Choose from ['3d', '2plus1d', '3d_2plus1d']
  # 3d: default 3D convolution
  # 2plus1d: (2+1)D convolution with Conv2D (2D reshaping)
  # 3d_2plus1d: (2+1)D convolution with Conv3D (no 2D reshaping)
  conv_type: str = '3d'
  # Choose from ['3d', '2d', '2plus3d']
  # 3d: default 3D global average pooling.
  # 2d: 2D global average pooling.
  # 2plus3d: concatenation of 2D and 3D global average pooling.
  se_type: str = '3d'
  activation: str = 'swish'
  gating_activation: str = 'sigmoid'
  stochastic_depth_drop_rate: float = 0.2
  use_external_states: bool = False
  # NOTE(review): presumably accepts the same choices as `se_type` -- confirm
  # against the backbone implementation.
  average_pooling_type: str = '3d'
  output_states: bool = True
# Decorated so that the `model_id` override below is registered as the
# dataclass field default for this subclass.
@dataclasses.dataclass
class MovinetA0(Movinet):
  """Backbone config for MoViNet-A0.

  Represents the smallest base MoViNet searched by NAS.

  Reference: https://arxiv.org/pdf/2103.11511.pdf
  """
  model_id: str = 'a0'
# Decorated so that the `model_id` override below is registered as the
# dataclass field default for this subclass.
@dataclasses.dataclass
class MovinetA1(Movinet):
  """Backbone config for MoViNet-A1."""
  model_id: str = 'a1'
# Decorated so that the `model_id` override below is registered as the
# dataclass field default for this subclass.
@dataclasses.dataclass
class MovinetA2(Movinet):
  """Backbone config for MoViNet-A2."""
  model_id: str = 'a2'
# Decorated so that the `model_id` override below is registered as the
# dataclass field default for this subclass.
@dataclasses.dataclass
class MovinetA3(Movinet):
  """Backbone config for MoViNet-A3."""
  model_id: str = 'a3'
# Decorated so that the `model_id` override below is registered as the
# dataclass field default for this subclass.
@dataclasses.dataclass
class MovinetA4(Movinet):
  """Backbone config for MoViNet-A4."""
  model_id: str = 'a4'
# Decorated so that the `model_id` override below is registered as the
# dataclass field default for this subclass.
@dataclasses.dataclass
class MovinetA5(Movinet):
  """Backbone config for MoViNet-A5.

  Represents the largest base MoViNet searched by NAS.
  """
  model_id: str = 'a5'
# Decorated so that the `model_id` override below is registered as the
# dataclass field default for this subclass.
@dataclasses.dataclass
class MovinetT0(Movinet):
  """Backbone config for MoViNet-T0.

  MoViNet-T0 is a smaller version of MoViNet-A0 for even faster processing.
  """
  model_id: str = 't0'
# The dataclass decorator is required here: without it the
# `dataclasses.field(...)` default below would remain a raw `Field` object on
# the class instead of producing a fresh `Movinet` config per instance.
@dataclasses.dataclass
class Backbone3D(backbones_3d.Backbone3D):
  """Configuration for backbones.

  Attributes:
    type: 'str', type of backbone to be used, one of the fields below.
    movinet: movinet backbone config.
  """
  type: str = 'movinet'
  # default_factory gives every Backbone3D its own Movinet instance rather than
  # sharing one mutable default across configs.
  movinet: Movinet = dataclasses.field(default_factory=Movinet)
# The dataclass decorator is required here: without it the
# `dataclasses.field(...)` defaults below would remain raw `Field` objects on
# the class instead of being turned into per-instance default values.
@dataclasses.dataclass
class MovinetModel(video_classification.VideoClassificationModel):
  """The MoViNet model config.

  Attributes:
    model_type: model type name, 'movinet' by default.
    backbone: backbone config; a fresh `Backbone3D` per instance.
    norm_activation: normalization settings; defaults to momentum 0.99,
      epsilon 1e-3 and synchronized batch norm enabled.
    activation: activation name, 'swish' by default.
    output_states: boolean flag, False by default.
  """
  model_type: str = 'movinet'
  backbone: Backbone3D = dataclasses.field(default_factory=Backbone3D)
  norm_activation: common.NormActivation = dataclasses.field(
      default_factory=lambda: common.NormActivation(  # pylint: disable=g-long-lambda
          activation=None,  # legacy flag, not used.
          norm_momentum=0.99,
          norm_epsilon=1e-3,
          use_sync_bn=True,
      )
  )
  activation: str = 'swish'
  output_states: bool = False
# Register the factory under its own name so the experiment can be looked up
# through `exp_factory`.
# NOTE(review): the registration name mirrors the function name -- confirm no
# other module already registers 'movinet_kinetics600'.
@exp_factory.register_config_factory('movinet_kinetics600')
def movinet_kinetics600() -> cfg.ExperimentConfig:
  """Video classification on Kinetics-600 with MoViNet backbone.

  Starts from the stock Kinetics-600 video classification experiment, switches
  the train and validation input dtype to bfloat16, and replaces the task model
  with a default `MovinetModel`.

  Returns:
    The modified experiment config.
  """
  exp = video_classification.video_classification_kinetics600()
  exp.task.train_data.dtype = 'bfloat16'
  exp.task.validation_data.dtype = 'bfloat16'

  # Swap in the MoViNet model config with all defaults.
  exp.task.model = MovinetModel()

  return exp