Spaces:
Runtime error
Runtime error
File size: 4,538 Bytes
5672777 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Definitions for MoViNet structures.
Reference: "MoViNets: Mobile Video Networks for Efficient Video Recognition"
https://arxiv.org/pdf/2103.11511.pdf
MoViNets are efficient video classification networks that are part of a model
family, ranging from the smallest model, MoViNet-A0, to the largest model,
MoViNet-A6. Each model has various width, depth, input resolution, and input
frame-rate associated with them. See the main paper for more details.
"""
import dataclasses
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
from official.vision.configs import backbones_3d
from official.vision.configs import common
from official.vision.configs import video_classification
@dataclasses.dataclass
class Movinet(hyperparams.Config):
  """Base backbone configuration shared by all MoViNet variants.

  The variant configs in this module subclass this config and override only
  `model_id`; every other field keeps the default declared here.

  Attributes:
    model_id: Name of the MoViNet variant. Defaults to 'a0'.
    causal: Defaults to False.
    use_positional_encoding: Defaults to False.
    conv_type: Convolution flavor, one of:
      '3d' - default 3D convolution.
      '2plus1d' - (2+1)D convolution with Conv2D (2D reshaping).
      '3d_2plus1d' - (2+1)D convolution with Conv3D (no 2D reshaping).
    se_type: Global average pooling flavor selected by this field, one of:
      '3d' - default 3D global average pooling.
      '2d' - 2D global average pooling.
      '2plus3d' - concatenation of 2D and 3D global average pooling.
    activation: Name of the activation function. Defaults to 'swish'.
    gating_activation: Name of the gating activation. Defaults to 'sigmoid'.
    stochastic_depth_drop_rate: Defaults to 0.2.
    use_external_states: Defaults to False.
    average_pooling_type: Defaults to '3d'.
    output_states: Defaults to True.
  """
  model_id: str = 'a0'
  causal: bool = False
  use_positional_encoding: bool = False
  # Valid values: '3d', '2plus1d', '3d_2plus1d' (see class docstring).
  conv_type: str = '3d'
  # Valid values: '3d', '2d', '2plus3d' (see class docstring).
  se_type: str = '3d'
  activation: str = 'swish'
  gating_activation: str = 'sigmoid'
  stochastic_depth_drop_rate: float = 0.2
  use_external_states: bool = False
  # NOTE(review): accepted values are not listed in this file; presumably a
  # subset of the se_type options -- confirm against the backbone builder.
  average_pooling_type: str = '3d'
  output_states: bool = True
@dataclasses.dataclass
class MovinetA0(Movinet):
  """MoViNet-A0 backbone config.

  The smallest base MoViNet found by NAS. Only `model_id` is set here; all
  other fields inherit the `Movinet` defaults.

  Reference: https://arxiv.org/pdf/2103.11511.pdf
  """
  model_id: str = 'a0'
@dataclasses.dataclass
class MovinetA1(Movinet):
  """MoViNet-A1 backbone config.

  Only `model_id` is set here; all other fields inherit the `Movinet`
  defaults.
  """
  model_id: str = 'a1'
@dataclasses.dataclass
class MovinetA2(Movinet):
  """MoViNet-A2 backbone config.

  Only `model_id` is set here; all other fields inherit the `Movinet`
  defaults.
  """
  model_id: str = 'a2'
@dataclasses.dataclass
class MovinetA3(Movinet):
  """MoViNet-A3 backbone config.

  Only `model_id` is set here; all other fields inherit the `Movinet`
  defaults.
  """
  model_id: str = 'a3'
@dataclasses.dataclass
class MovinetA4(Movinet):
  """MoViNet-A4 backbone config.

  Only `model_id` is set here; all other fields inherit the `Movinet`
  defaults.
  """
  model_id: str = 'a4'
@dataclasses.dataclass
class MovinetA5(Movinet):
  """MoViNet-A5 backbone config.

  The largest base MoViNet found by NAS. Only `model_id` is set here; all
  other fields inherit the `Movinet` defaults.
  """
  model_id: str = 'a5'
@dataclasses.dataclass
class MovinetT0(Movinet):
  """MoViNet-T0 backbone config.

  A smaller version of MoViNet-A0 intended for even faster processing. Only
  `model_id` is set here; all other fields inherit the `Movinet` defaults.
  """
  model_id: str = 't0'
@dataclasses.dataclass
class Backbone3D(backbones_3d.Backbone3D):
  """Backbone selector config that adds a MoViNet option.

  Extends `backbones_3d.Backbone3D` with a `movinet` field and makes it the
  default selection.

  Attributes:
    type: Name of the backbone to use; must be one of the fields below.
      Defaults to 'movinet'.
    movinet: MoViNet backbone config. A fresh `Movinet` instance is created
      for each `Backbone3D` instance.
  """
  type: str = 'movinet'
  movinet: Movinet = dataclasses.field(default_factory=Movinet)
def _default_norm_activation() -> common.NormActivation:
  """Builds the default `norm_activation` value for `MovinetModel`."""
  return common.NormActivation(
      activation=None,  # legacy flag, not used.
      norm_momentum=0.99,
      norm_epsilon=1e-3,
      use_sync_bn=True,
  )


@dataclasses.dataclass
class MovinetModel(video_classification.VideoClassificationModel):
  """Model config for MoViNet video classification.

  Attributes:
    model_type: Model type name. Defaults to 'movinet'.
    backbone: Backbone selector config; a fresh `Backbone3D` per instance.
    norm_activation: Normalization settings; a fresh `common.NormActivation`
      per instance with momentum 0.99, epsilon 1e-3 and sync batch norm
      enabled. Its `activation` entry is a legacy flag and is left as None.
    activation: Name of the activation function. Defaults to 'swish'.
    output_states: Defaults to False. Note that the backbone-level
      `Movinet.output_states` field defaults to True.
  """
  model_type: str = 'movinet'
  backbone: Backbone3D = dataclasses.field(default_factory=Backbone3D)
  norm_activation: common.NormActivation = dataclasses.field(
      default_factory=_default_norm_activation
  )
  activation: str = 'swish'
  output_states: bool = False
@exp_factory.register_config_factory('movinet_kinetics600')
def movinet_kinetics600() -> cfg.ExperimentConfig:
  """Builds the 'movinet_kinetics600' experiment config.

  Starts from `video_classification.video_classification_kinetics600()`, sets
  the dtype of both the training and validation input configs to 'bfloat16',
  and replaces the task model with a default `MovinetModel`.

  Returns:
    The modified experiment config.
  """
  experiment = video_classification.video_classification_kinetics600()
  task = experiment.task
  # Same order as before: training input first, then validation input.
  for data_config in (task.train_data, task.validation_data):
    data_config.dtype = 'bfloat16'
  task.model = MovinetModel()
  return experiment
|