File size: 4,538 Bytes
5672777
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Definitions for MoViNet structures.

Reference: "MoViNets: Mobile Video Networks for Efficient Video Recognition"
https://arxiv.org/pdf/2103.11511.pdf

MoViNets are efficient video classification networks that are part of a model
family, ranging from the smallest model, MoViNet-A0, to the largest model,
MoViNet-A6. Each model has its own width, depth, input resolution, and input
frame rate. See the main paper for more details.
"""

import dataclasses

from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
from official.vision.configs import backbones_3d
from official.vision.configs import common
from official.vision.configs import video_classification


@dataclasses.dataclass
class Movinet(hyperparams.Config):
  """Backbone config for Base MoViNet.

  Declares the default hyperparameters for a MoViNet backbone. The per-variant
  configs in this file (`MovinetA0` ... `MovinetA5`, `MovinetT0`) subclass this
  config and override only `model_id`.
  """
  # Identifier of the MoViNet variant, e.g. 'a0'. The subclasses in this file
  # use 'a0'-'a5' and 't0'.
  model_id: str = 'a0'
  # NOTE(review): presumably switches the backbone to causal (streaming)
  # operation -- confirm against the backbone implementation.
  causal: bool = False
  use_positional_encoding: bool = False
  # Convolution type. Choose from ['3d', '2plus1d', '3d_2plus1d'].
  # 3d: default 3D convolution.
  # 2plus1d: (2+1)D convolution with Conv2D (2D reshaping).
  # 3d_2plus1d: (2+1)D convolution with Conv3D (no 2D reshaping).
  conv_type: str = '3d'
  # Squeeze-excitation pooling type. Choose from ['3d', '2d', '2plus3d'].
  # 3d: default 3D global average pooling.
  # 2d: 2D global average pooling.
  # 2plus3d: concatenation of 2D and 3D global average pooling.
  se_type: str = '3d'
  activation: str = 'swish'
  gating_activation: str = 'sigmoid'
  stochastic_depth_drop_rate: float = 0.2
  use_external_states: bool = False
  # NOTE(review): accepted values are not listed in this file; presumably
  # mirrors the '3d'/'2d' options of `se_type` -- confirm.
  average_pooling_type: str = '3d'
  output_states: bool = True


@dataclasses.dataclass
class MovinetA0(Movinet):
  """Backbone config for MoViNet-A0.

  Represents the smallest base MoViNet searched by NAS.

  Inherits every default from `Movinet`; `model_id` is set to 'a0' explicitly,
  which matches the base default.

  Reference: https://arxiv.org/pdf/2103.11511.pdf
  """
  model_id: str = 'a0'


@dataclasses.dataclass
class MovinetA1(Movinet):
  """Backbone config for MoViNet-A1.

  Inherits every default from `Movinet` and overrides only `model_id` ('a1').
  """
  model_id: str = 'a1'


@dataclasses.dataclass
class MovinetA2(Movinet):
  """Backbone config for MoViNet-A2.

  Inherits every default from `Movinet` and overrides only `model_id` ('a2').
  """
  model_id: str = 'a2'


@dataclasses.dataclass
class MovinetA3(Movinet):
  """Backbone config for MoViNet-A3.

  Inherits every default from `Movinet` and overrides only `model_id` ('a3').
  """
  model_id: str = 'a3'


@dataclasses.dataclass
class MovinetA4(Movinet):
  """Backbone config for MoViNet-A4.

  Inherits every default from `Movinet` and overrides only `model_id` ('a4').
  """
  model_id: str = 'a4'


@dataclasses.dataclass
class MovinetA5(Movinet):
  """Backbone config for MoViNet-A5.

  Represents the largest base MoViNet searched by NAS.

  Inherits every default from `Movinet` and overrides only `model_id` ('a5').
  """
  model_id: str = 'a5'


@dataclasses.dataclass
class MovinetT0(Movinet):
  """Backbone config for MoViNet-T0.

  MoViNet-T0 is a smaller version of MoViNet-A0 for even faster processing.

  Inherits every default from `Movinet` and overrides only `model_id` ('t0').
  """
  model_id: str = 't0'


@dataclasses.dataclass
class Backbone3D(backbones_3d.Backbone3D):
  """Configuration for backbones.

  Extends `backbones_3d.Backbone3D` with a `movinet` field and makes it the
  default backbone type.

  Attributes:
    type: 'str', type of backbone to be used, one of the fields below.
    movinet: movinet backbone config.
  """
  type: str = 'movinet'
  # default_factory gives every Backbone3D instance its own Movinet config
  # instead of sharing a single default object.
  movinet: Movinet = dataclasses.field(default_factory=Movinet)


@dataclasses.dataclass
class MovinetModel(video_classification.VideoClassificationModel):
  """The MoViNet model config.

  Extends `video_classification.VideoClassificationModel` with MoViNet
  defaults: a 'movinet' backbone, the normalization settings below, and
  'swish' activation.
  """
  model_type: str = 'movinet'
  # default_factory builds a fresh Backbone3D (type 'movinet') per instance.
  backbone: Backbone3D = dataclasses.field(default_factory=Backbone3D)
  # Normalization defaults; built in a factory so instances do not share one
  # NormActivation object.
  norm_activation: common.NormActivation = dataclasses.field(
      default_factory=lambda: common.NormActivation(  # pylint: disable=g-long-lambda
          activation=None,  # legacy flag, not used.
          norm_momentum=0.99,
          norm_epsilon=1e-3,
          use_sync_bn=True,
      )
  )
  activation: str = 'swish'
  # NOTE(review): the backbone config `Movinet.output_states` defaults to True
  # while this model-level flag defaults to False -- confirm that is intended.
  output_states: bool = False


@exp_factory.register_config_factory('movinet_kinetics600')
def movinet_kinetics600() -> cfg.ExperimentConfig:
  """Builds the experiment config registered as 'movinet_kinetics600'.

  Starts from `video_classification.video_classification_kinetics600()`,
  switches the training and validation input dtype to 'bfloat16', and replaces
  the task model with a default `MovinetModel`.

  Returns:
    The modified experiment config.
  """
  experiment = video_classification.video_classification_kinetics600()
  task = experiment.task

  # Training and validation inputs both use bfloat16.
  for data_config in (task.train_data, task.validation_data):
    data_config.dtype = 'bfloat16'

  task.model = MovinetModel()
  return experiment