anugunj commited on
Commit
21d5ffb
1 Parent(s): e7055a5
Files changed (3)
  1. __init__.py +74 -0
  2. configuration_omnivore.py +128 -0
  3. modelling.py +1130 -0
__init__.py ADDED
@@ -0,0 +1,74 @@
1
+ # flake8: noqa
2
+ # There's no way to ignore "F401 '...' imported but unused" warnings in this
3
+ # module, but to preserve other warnings. So, don't check this module at all.
4
+
5
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+ from typing import TYPE_CHECKING
19
+
20
+ # rely on isort to merge the imports
21
+ from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
22
+
23
+
24
+ _import_structure = {"configuration_omnivore": ["OMNIVORE_PRETRAINED_CONFIG_ARCHIVE_MAP", "OmnivoreConfig"]}
25
+
26
+ try:
27
+ if not is_vision_available():
28
+ raise OptionalDependencyNotAvailable()
29
+ except OptionalDependencyNotAvailable:
30
+ pass
31
+ else:
32
+ _import_structure["feature_extraction_omnivore"] = ["OmnivoreFeatureExtractor"]
33
+
34
+ try:
35
+ if not is_torch_available():
36
+ raise OptionalDependencyNotAvailable()
37
+ except OptionalDependencyNotAvailable:
38
+ pass
39
+ else:
40
+ _import_structure["modeling_omnivore"] = [
41
+ "OMNIVORE_PRETRAINED_MODEL_ARCHIVE_LIST",
42
+ "OmnivoreForJointClassification",
43
+ "OmnivoreModel",
44
+ "OmnivorePreTrainedModel",
45
+ ]
46
+
47
+ if TYPE_CHECKING:
48
+ from .configuration_omnivore import OMNIVORE_PRETRAINED_CONFIG_ARCHIVE_MAP, OmnivoreConfig
49
+
50
+ try:
51
+ if not is_vision_available():
52
+ raise OptionalDependencyNotAvailable()
53
+ except OptionalDependencyNotAvailable:
54
+ pass
55
+ else:
56
+ from .feature_extraction_omnivore import OmnivoreFeatureExtractor
57
+
58
+ try:
59
+ if not is_torch_available():
60
+ raise OptionalDependencyNotAvailable()
61
+ except OptionalDependencyNotAvailable:
62
+ pass
63
+ else:
64
+ from .modeling_omnivore import (
65
+ OMNIVORE_PRETRAINED_MODEL_ARCHIVE_LIST,
66
+ OmnivoreForImageClassification,
67
+ OmnivoreModel,
68
+ OmnivorePreTrainedModel,
69
+ )
70
+
71
+ else:
72
+ import sys
73
+
74
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
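
The `__init__.py` above wires the package so that the heavy submodules are only imported on first attribute access. The sketch below is a minimal, self-contained illustration of that lazy-module pattern; it is not the actual `_LazyModule` from `transformers.utils`, and the class name is a placeholder.

```python
# Minimal sketch of the lazy-import pattern used above (a simplified stand-in,
# not the real transformers.utils._LazyModule).
import importlib
import types


class LazyModule(types.ModuleType):
    def __init__(self, name, import_structure):
        super().__init__(name)
        # map every exported name to the submodule that defines it
        self._name_to_module = {
            attr: module for module, attrs in import_structure.items() for attr in attrs
        }
        self.__all__ = list(self._name_to_module)

    def __getattr__(self, attr):
        # resolve the submodule lazily, then cache the attribute on the package module
        module = importlib.import_module("." + self._name_to_module[attr], self.__name__)
        value = getattr(module, attr)
        setattr(self, attr, value)
        return value


# The __init__.py above ends with the equivalent of:
# sys.modules[__name__] = LazyModule(__name__, _import_structure)
```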
configuration_omnivore.py ADDED
@@ -0,0 +1,128 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Omnivore model configuration"""
16
+
17
+ from torch import nn
18
+
19
+ from ...configuration_utils import PretrainedConfig
20
+ from ...utils import logging
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+ OMNIVORE_PRETRAINED_CONFIG_ARCHIVE_MAP = {
26
+ "anugunj/omnivore": "https://huggingface.co/anugunj/omnivore/resolve/main/config.json",
27
+ }
28
+
29
+
30
+ class OmnivoreConfig(PretrainedConfig):
31
+ r"""
32
+ This is the configuration class to store the configuration of a [`OmnivoreModel`]. It is used to instantiate an
33
+ Omnivore model according to the specified arguments, defining the model architecture. Instantiating a configuration
34
+ with the defaults will yield a similar configuration to that of the Omnivore
35
+ [anugunj/omnivore](https://huggingface.co/anugunj/omnivore) architecture.
36
+
37
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
38
+ documentation from [`PretrainedConfig`] for more information.
39
+
40
+ Args:
41
+ input_channels (`int`, *optional*, defaults to 3):
42
+ The number of input channels.
43
+ patch_size (`List[int]`, *optional*, defaults to `[2, 4, 4]`):
44
+ Patch size to use in the patch embedding layer.
45
+ embed_dim (`int`, *optional*, defaults to 96):
46
+ Number of linear projection output channels.
47
+ depths (`List[int]`, *optional*, defaults to `[2, 2, 18, 2]`):
48
+ Depth (number of layers) for each stage.
49
+ num_heads (`List[int]`, *optional*, defaults to `[3, 6, 12, 24]`):
50
+ Number of attention heads for each stage.
51
+ window_size (`Tuple[int]`, *optional*, defaults to `(8, 7, 7)`):
52
+ Size of the attention window used by the Swin Transformer blocks.
53
+ mlp_ratio (`float`, *optional*, defaults to 4.0):
54
+ Ratio of the hidden size of the MLPs in the encoder blocks to the embedding dimension.
55
+ attention_dropout_rate (`float`, *optional*, defaults to 0.0):
56
+ The dropout ratio for the attention probabilities.
57
+ dropout_rate (`float`, *optional*, defaults to 0.0):
58
+ The dropout ratio for the patch embeddings and the attention projections.
59
+ drop_path_rate (`float`, *optional*, defaults to 0.3):
60
+ The drop probability for stochastic depth, used in the blocks of the Transformer encoder.
61
+ qkv_bias (`bool`, *optional*, defaults to True):
62
+ Whether to add a learnable bias to the query, key and value projections.
63
+ qk_scale (`float`, *optional*, defaults to None):
64
+ Override the default query-key scale of `head_dim ** -0.5` if set.
65
+ depth_mode (`str`, *optional*, defaults to `"summed_rgb_d_tokens"`):
66
+ How the depth channel of RGB-D inputs is tokenized; one of `"summed_rgb_d_tokens"`,
67
+ `"separate_d_tokens"` or `"rgbd"`.
68
+ patch_norm (`bool`, *optional*, defaults to True):
69
+ If True, add normalization after patch embedding.
70
+ frozen_stages (`int`, *optional*, defaults to -1):
71
+ Stages to be frozen (stop gradients and set to eval mode). -1 means no stage is frozen.
72
+ initializer_range (`float`, *optional*, defaults to 0.02):
73
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
74
+
75
+ Example:
76
+ ```python
77
+ >>> from transformers import OmnivoreModel, OmnivoreConfig
78
+
79
+ >>> # Initializing a Omnivore omnivore-tiny-224 style configuration
80
+ >>> configuration = OmnivoreConfig()
81
+ >>> # Initializing a model from the omnivore-tiny-224 style configuration
82
+ >>> model = OmnivoreModel(configuration)
83
+ >>> # Accessing the model configuration
84
+ >>> configuration = model.config
85
+ ```"""
86
+ model_type = "omnivore"
87
+
88
+ def __init__(
89
+ self,
90
+ input_channels=3,
91
+ patch_size=[2, 4, 4],
92
+ embed_dim=96,
93
+ depths=[2, 2, 18, 2],
94
+ num_heads=[3, 6, 12, 24],
95
+ window_size=(8, 7, 7),
96
+ mlp_ratio=4.0,
97
+ qkv_bias=True,
98
+ qk_scale=None,
99
+ dropout_rate=0.0,
100
+ attention_dropout_rate=0.0,
101
+ drop_path_rate=0.3,
102
+ patch_norm=True,
103
+ frozen_stages=-1,
104
+ depth_mode="summed_rgb_d_tokens",
105
+ initializer_range=0.02,
106
+ **kwargs
107
+ ):
108
+ super().__init__(**kwargs)
109
+ self.input_channels = input_channels
110
+ self.patch_size = patch_size
111
+ self.embed_dim = embed_dim
112
+ self.depths = depths
113
+ self.num_heads = num_heads
114
+ self.window_size = window_size
115
+ self.mlp_ratio = mlp_ratio
116
+ self.qkv_bias = qkv_bias
117
+ self.qk_scale = qk_scale
118
+ self.dropout_rate = dropout_rate
119
+ self.attention_dropout_rate = attention_dropout_rate
120
+ self.drop_path_rate = drop_path_rate
121
+ self.patch_norm = patch_norm
122
+ self.frozen_stages = frozen_stages
123
+ self.initializer_range = initializer_range
124
+ self.head_dim_in = embed_dim * 8
125
+ self.depth_mode = depth_mode
126
+ self.num_image_labels = 1000
127
+ self.num_video_labels = 400
128
+ self.num_rgbd_labels = 19
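
Since `OmnivoreConfig` only stores hyperparameters, it can be built and tweaked before instantiating a model. A hedged usage sketch, assuming the commit's modules end up importable under `transformers.models.omnivore` (an assumption; this commit does not register that path):

```python
# Hedged sketch: build the default config and a shallower variant.
# The import path is hypothetical and only used for illustration.
from transformers.models.omnivore.configuration_omnivore import OmnivoreConfig

config = OmnivoreConfig()
print(config.depths)       # [2, 2, 18, 2], the per-stage depths from the defaults above
print(config.head_dim_in)  # 768 == embed_dim * 8, the input size of the classification heads

# override defaults for a shallower variant; unknown kwargs are passed to PretrainedConfig
small_config = OmnivoreConfig(depths=[2, 2, 6, 2], drop_path_rate=0.1)
```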
modelling.py ADDED
@@ -0,0 +1,1130 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch Omnivore model."""
16
+
17
+ import math
18
+ import warnings
19
+ from dataclasses import dataclass
20
+ from functools import lru_cache, reduce
21
+ from operator import mul
22
+ from typing import Optional, Tuple
23
+
24
+ import numpy as np
25
+ import torch
26
+ import torch.utils.checkpoint as checkpoint
28
+ from torch import nn
29
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
30
+ from torch.nn import functional as F
31
+
32
+ from transformers.utils.generic import ModelOutput
33
+
34
+ from ...activations import ACT2FN
35
+ from ...modeling_outputs import BaseModelOutputWithNoAttention, BaseModelOutputWithPoolingAndNoAttention, ImageClassifierOutputWithNoAttention
36
+ from ...modeling_utils import PreTrainedModel
37
+ from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
38
+ from .configuration_omnivore import OmnivoreConfig
39
+
40
+
41
+ logger = logging.get_logger(__name__)
42
+
43
+ # General docstring
44
+ _CONFIG_FOR_DOC = "OmnivoreConfig"
45
+ _FEAT_EXTRACTOR_FOR_DOC = "OmnivoreFeatureExtractor"
46
+
47
+ # Base docstring
48
+ _CHECKPOINT_FOR_DOC = "anugunj/omnivore"
49
+ _EXPECTED_OUTPUT_SHAPE = [1, 768]
50
+
51
+ # Image classification docstring
52
+ _IMAGE_CLASS_CHECKPOINT = "anugunj/omnivore"
53
+ _IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
54
+
55
+ OMNIVORE_PRETRAINED_MODEL_ARCHIVE_LIST = [
56
+ "anugunj/omnivore",
57
+ # See all Omnivore models at https://huggingface.co/models?filter=omnivore
58
+ ]
59
+
60
+
61
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
62
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
63
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
64
+ def norm_cdf(x):
65
+ # Computes standard normal cumulative distribution function
66
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
67
+
68
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
69
+ warnings.warn(
70
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
71
+ "The distribution of values may be incorrect.",
72
+ stacklevel=2,
73
+ )
74
+
75
+ with torch.no_grad():
76
+ # Values are generated by using a truncated uniform distribution and
77
+ # then using the inverse CDF for the normal distribution.
78
+ # Get upper and lower cdf values
79
+ l = norm_cdf((a - mean) / std)
80
+ u = norm_cdf((b - mean) / std)
81
+
82
+ # Uniformly fill tensor with values from [l, u], then translate to
83
+ # [2l-1, 2u-1].
84
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
85
+
86
+ # Use inverse cdf transform for normal distribution to get truncated
87
+ # standard normal
88
+ tensor.erfinv_()
89
+
90
+ # Transform to proper mean, std
91
+ tensor.mul_(std * math.sqrt(2.0))
92
+ tensor.add_(mean)
93
+
94
+ # Clamp to ensure it's in the proper range
95
+ tensor.clamp_(min=a, max=b)
96
+ return tensor
97
+
98
+
99
+ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
100
+ r"""Fills the input Tensor with values drawn from a truncated
101
+ Args:
102
+ normal distribution. The values are effectively drawn from the normal distribution :math:`\mathcal{N}(\text{mean},
103
+ \text{std}^2)` with values outside :math:`[a, b]` redrawn until they are within the bounds. The method used for
104
+ generating the random values works best when :math:`a \leq \text{mean} \leq b`.
105
+ tensor: an n-dimensional `torch.Tensor` mean: the mean of the normal distribution std: the standard deviation
106
+ of the normal distribution a: the minimum cutoff value b: the maximum cutoff value
107
+ Examples:
108
+ >>> w = torch.empty(3, 5) >>> nn.init.trunc_normal_(w)
109
+ """
110
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
111
+
112
+
113
+ # Stochastic depth implementation
114
+ # Taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
115
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
116
+ """
117
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the
118
+ DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop
119
+ Connect' is a different form of dropout in a separate paper... See discussion:
120
+ https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and
121
+ argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument.
122
+ """
123
+ if drop_prob == 0.0 or not training:
124
+ return x
125
+ keep_prob = 1 - drop_prob
126
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
127
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
128
+ random_tensor.floor_() # binarize
129
+ output = x.div(keep_prob) * random_tensor
130
+ return output
131
+
132
+
133
+ class OmnivoreDropPath(nn.Module):
134
+ def __init__(self, drop_prob=None):
135
+ super().__init__()
136
+ self.drop_prob = drop_prob
137
+
138
+ def forward(self, x: torch.Tensor):
139
+ return drop_path(x, self.drop_prob, self.training)
140
+
141
+
142
+ class OmnivoreLayerNorm(nn.Module):
143
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
144
+ super().__init__()
145
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
146
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
147
+ self.eps = eps
148
+ self.data_format = data_format
149
+ if self.data_format not in ["channels_last", "channels_first"]:
150
+ raise NotImplementedError(f"Unsupported data format: {self.data_format}")
151
+ self.normalized_shape = (normalized_shape,)
152
+
153
+ def forward(self, x: torch.Tensor):
154
+ if self.data_format == "channels_last":
155
+ x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
156
+ elif self.data_format == "channels_first":
157
+ u = x.mean(1, keepdim=True)
158
+ s = (x - u).pow(2).mean(1, keepdim=True)
159
+ x = (x - u) / torch.sqrt(s + self.eps)
160
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
161
+ return x
162
+
163
+
164
+ class OmnivoreIm2Video(nn.Module):
165
+ """Convert Image into a trivial video"""
166
+
167
+ def forward(self, pixel_values):
168
+ if pixel_values.ndim == 4:
169
+ return pixel_values.unsqueeze(2)
170
+ elif pixel_values.ndim == 5:
171
+ return pixel_values
172
+ else:
173
+ raise ValueError(f"Dimension incorrect {pixel_values.shape}")
174
+
175
+
176
+ class OmnivoreMLP(nn.Module):
177
+ def __init__(self, in_features, hidden_features=None, out_features=None, dropout_rate=0.0, act_layer=nn.GELU):
178
+ super().__init__()
179
+ out_features = out_features or in_features
180
+ hidden_features = hidden_features or in_features
181
+ self.linear1 = nn.Linear(in_features, hidden_features)
182
+ self.activation = act_layer()
183
+ self.linear2 = nn.Linear(hidden_features, out_features)
184
+ self.drop_out = nn.Dropout(dropout_rate)
185
+
186
+ def forward(self, hidden_state):
187
+ hidden_state = self.linear1(hidden_state)
188
+ hidden_state = self.activation(hidden_state)
189
+ hidden_state = self.drop_out(hidden_state)
190
+ hidden_state = self.linear2(hidden_state)
191
+ hidden_state = self.drop_out(hidden_state)
192
+ return hidden_state
193
+
194
+
195
+ def window_partition(input_feature, window_size):
196
+ batch_size, D, height, width, channels = input_feature.shape
197
+ input_feature = input_feature.view(
198
+ batch_size,
199
+ D // window_size[0],
200
+ window_size[0],
201
+ height // window_size[1],
202
+ window_size[1],
203
+ width // window_size[2],
204
+ window_size[2],
205
+ channels,
206
+ )
207
+ windows = input_feature.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, reduce(mul, window_size), channels)
208
+ return windows
209
+
210
+
211
+ def window_partition_image(input_feature, window_size):
212
+ batch_size, height, width, channels = input_feature.shape
213
+ input_feature = input_feature.view(
214
+ batch_size, height // window_size[1], window_size[1], width // window_size[2], window_size[2], channels
215
+ )
216
+ windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[1], window_size[2], channels)
217
+ return windows
218
+
219
+
220
+ def window_reverse(windows, windows_size, batch_size, D, height, width):
221
+ input_feature = windows.view(
222
+ batch_size,
223
+ D // windows_size[0],
224
+ height // windows_size[1],
225
+ width // windows_size[2],
226
+ windows_size[0],
227
+ windows_size[1],
228
+ windows_size[2],
229
+ -1,
230
+ )
231
+ input_feature = input_feature.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(batch_size, D, height, width, -1)
232
+ return input_feature
233
+
234
+
235
+ def get_window_size(input_size, window_size, shift_size=None):
236
+ use_window_size = list(window_size)
237
+ if shift_size is not None:
238
+ use_shift_size = list(shift_size)
239
+ for i in range(len(input_size)):
240
+ if input_size[i] <= window_size[i]:
241
+ use_window_size[i] = input_size[i]
242
+ if shift_size is not None:
243
+ use_shift_size[i] = 0
244
+
245
+ if shift_size is None:
246
+ return tuple(use_window_size)
247
+ else:
248
+ return tuple(use_window_size), tuple(use_shift_size)
249
+
250
+
251
+ class OmnivoreWindowAttention3D(nn.Module):
252
+ def __init__(
253
+ self,
254
+ dim,
255
+ window_size,
256
+ num_heads,
257
+ qkv_bias=False,
258
+ qk_scale=None,
259
+ attention_dropout_rate=0.0,
260
+ projection_dropout_rate=0.0,
261
+ ):
262
+
263
+ super().__init__()
264
+ self.dim = dim
265
+ self.window_size = window_size
266
+ self.num_heads = num_heads
267
+ head_dim = dim // num_heads
268
+ self.scale = qk_scale or head_dim**-0.5
269
+
270
+ # define a parameter table of relative position bias
271
+ self.relative_position_bias_table = nn.Parameter(
272
+ torch.zeros(
273
+ (2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1),
274
+ num_heads,
275
+ )
276
+ )
277
+
278
+ # get pair-wise relative position index for each token inside the window
279
+ coords_d = torch.arange(self.window_size[0])
280
+ coords_h = torch.arange(self.window_size[1])
281
+ coords_w = torch.arange(self.window_size[2])
282
+ coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w))
283
+ coords_flatten = torch.flatten(coords, 1)
284
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
285
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
286
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
287
+ relative_coords[:, :, 1] += self.window_size[1] - 1
288
+ relative_coords[:, :, 2] += self.window_size[2] - 1
289
+
290
+ relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
291
+ relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1
292
+ relative_position_index = relative_coords.sum(-1)
293
+ self.register_buffer("relative_position_index", relative_position_index)
294
+
295
+ self.queries_keys_values = nn.Linear(dim, dim * 3, bias=qkv_bias)
296
+ self.attention_dropout = nn.Dropout(attention_dropout_rate)
297
+ self.projection = nn.Linear(dim, dim)
298
+ self.projection_dropout = nn.Dropout(projection_dropout_rate)
299
+
300
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
301
+ self.softmax = nn.Softmax(dim=-1)
302
+
303
+ def forward(self, hidden_state, attention_mask=None):
304
+ batch_size, seq_len, channels = hidden_state.shape
305
+ queries_keys_values = (
306
+ self.queries_keys_values(hidden_state)
307
+ .reshape(batch_size, seq_len, 3, self.num_heads, channels // self.num_heads)
308
+ .permute(2, 0, 3, 1, 4)
309
+ )
310
+ queries, keys, values = queries_keys_values[0], queries_keys_values[1], queries_keys_values[2]
311
+
312
+ queries = queries * self.scale
313
+ attention = queries @ keys.transpose(-2, -1)
314
+
315
+ relative_position_bias = self.relative_position_bias_table[
316
+ self.relative_position_index[:seq_len, :seq_len].reshape(-1)
317
+ ].reshape(seq_len, seq_len, -1)
318
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
319
+ attention = attention + relative_position_bias.unsqueeze(0)
320
+
321
+ if attention_mask is not None:
322
+ nW = attention_mask.shape[0]
323
+ attention = attention.view(
324
+ batch_size // nW, nW, self.num_heads, seq_len, seq_len
325
+ ) + attention_mask.unsqueeze(1).unsqueeze(0)
326
+ attention = attention.view(-1, self.num_heads, seq_len, seq_len)
327
+ attention = self.softmax(attention)
328
+ else:
329
+ attention = self.softmax(attention)
330
+
331
+ attention = self.attention_dropout(attention)
332
+
333
+ hidden_state = (attention @ values).transpose(1, 2).reshape(batch_size, seq_len, channels)
334
+ hidden_state = self.projection(hidden_state)
335
+ hidden_state = self.projection_dropout(hidden_state)
336
+ return hidden_state
337
+
338
+
339
+ class OmnivoreSwinTransformer3DLayer(nn.Module):
340
+ def __init__(
341
+ self,
342
+ dim,
343
+ num_heads,
344
+ window_size=(2, 7, 7),
345
+ shift_size=(0, 0, 0),
346
+ mlp_ratio=4.0,
347
+ qkv_bias=True,
348
+ qk_scale=None,
349
+ dropout_rate=0.0,
350
+ attention_dropout_rate=0.0,
351
+ drop_path_rate=0.0,
352
+ act_layer=nn.GELU,
353
+ norm_layer=nn.LayerNorm,
354
+ ):
355
+ super().__init__()
356
+ self.dim = dim
357
+ self.num_heads = num_heads
358
+ self.window_size = window_size
359
+ self.shift_size = shift_size
360
+ self.mlp_ratio = mlp_ratio
361
+
362
+ assert 0 <= self.shift_size[0] < self.window_size[0], "shift_size must in 0-window_size"
363
+ assert 0 <= self.shift_size[1] < self.window_size[1], "shift_size must in 0-window_size"
364
+ assert 0 <= self.shift_size[2] < self.window_size[2], "shift_size must in 0-window_size"
365
+
366
+ self.norm1 = norm_layer(dim)
367
+ self.attention = OmnivoreWindowAttention3D(
368
+ dim,
369
+ window_size=self.window_size,
370
+ num_heads=num_heads,
371
+ qkv_bias=qkv_bias,
372
+ qk_scale=qk_scale,
373
+ attention_dropout_rate=attention_dropout_rate,
374
+ projection_dropout_rate=dropout_rate,
375
+ )
376
+
377
+ self.drop_path = OmnivoreDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
378
+ self.norm2 = norm_layer(dim)
379
+ mlp_hidden_dim = int(dim * mlp_ratio)
380
+ self.mlp = OmnivoreMLP(
381
+ in_features=dim, hidden_features=mlp_hidden_dim, dropout_rate=dropout_rate, act_layer=act_layer
382
+ )
383
+
384
+ def forward_before(self, hidden_state, attention_mask):
385
+ batch_size, D, height, width, channels = hidden_state.shape
386
+ window_size, shift_size = get_window_size((D, height, width), self.window_size, self.shift_size)
387
+
388
+ hidden_state = self.norm1(hidden_state)
389
+ # pad feature maps to multiples of window size
390
+ pad_l = pad_t = pad_d0 = 0
391
+ pad_d1 = (window_size[0] - D % window_size[0]) % window_size[0]
392
+ pad_b = (window_size[1] - height % window_size[1]) % window_size[1]
393
+ pad_r = (window_size[2] - width % window_size[2]) % window_size[2]
394
+ hidden_state = F.pad(hidden_state, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1))
395
+ _, Dp, Hp, Wp, _ = hidden_state.shape
396
+ # cyclic shift
397
+ if any(i > 0 for i in shift_size):
398
+ shifted_hidden_state = torch.roll(
399
+ hidden_state, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3)
400
+ )
401
+ attention_mask = attention_mask
402
+ else:
403
+ shifted_hidden_state = hidden_state
404
+ attention_mask = None
405
+ # partition windows
406
+ hidden_state_windows = window_partition(shifted_hidden_state, window_size)
407
+ # W-MSA/SW-MSA
408
+ attention_windows = self.attention(hidden_state_windows, attention_mask=attention_mask)
409
+ # merge windows
410
+ attention_windows = attention_windows.view(-1, *(window_size + (channels,)))
411
+ shifted_hidden_state = window_reverse(attention_windows, window_size, batch_size, Dp, Hp, Wp)
412
+ # reverse cyclic shift
413
+ if any(i > 0 for i in shift_size):
414
+ hidden_state = torch.roll(
415
+ shifted_hidden_state, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3)
416
+ )
417
+ else:
418
+ hidden_state = shifted_hidden_state
419
+
420
+ if pad_d1 > 0 or pad_r > 0 or pad_b > 0:
421
+ hidden_state = hidden_state[:, :D, :height, :width, :].contiguous()
422
+ return hidden_state
423
+
424
+ def forward_after(self, hidden_state):
425
+ hidden_state = self.norm2(hidden_state)
426
+ hidden_state = self.mlp(hidden_state)
427
+ hidden_state = self.drop_path(hidden_state)
428
+ return hidden_state
429
+
430
+ def forward(self, hidden_state, mask_matrix, use_checkpoint=False):
431
+ shortcut = hidden_state
432
+ if use_checkpoint:
433
+ hidden_state = checkpoint.checkpoint(self.forward_before, hidden_state, mask_matrix)
434
+ else:
435
+ hidden_state = self.forward_before(hidden_state, mask_matrix)
436
+ hidden_state = shortcut + self.drop_path(hidden_state)
437
+
438
+ if use_checkpoint:
439
+ hidden_state = hidden_state + checkpoint.checkpoint(self.forward_after, hidden_state)
440
+ else:
441
+ hidden_state = hidden_state + self.forward_after(hidden_state)
442
+
443
+ return hidden_state
444
+
445
+
446
+ class OmnivorePatchMerging(nn.Module):
447
+ """
448
+ Args:
449
+ Patch Merging Layer
450
+ dim (`int`): Number of input channels. norm_layer (`nn.Module`, *optional*): Normalization layer. Default:
451
+ `nn.LayerNorm`
452
+ """
453
+
454
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
455
+ super().__init__()
456
+ self.dim = dim
457
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
458
+ self.norm = norm_layer(4 * dim)
459
+
460
+ def forward(self, hidden_state, height=None, width=None):
461
+ if height is None:
462
+ batch_size, D, height, width, channels = hidden_state.shape
463
+
464
+ # padding
465
+ pad_input = (height % 2 == 1) or (width % 2 == 1)
466
+ if pad_input:
467
+ hidden_state = F.pad(hidden_state, (0, 0, 0, width % 2, 0, height % 2))
468
+
469
+ hidden_state0 = hidden_state[:, :, 0::2, 0::2, :]
470
+ hidden_state1 = hidden_state[:, :, 1::2, 0::2, :]
471
+ hidden_state2 = hidden_state[:, :, 0::2, 1::2, :]
472
+ hidden_state3 = hidden_state[:, :, 1::2, 1::2, :]
473
+ hidden_state = torch.cat([hidden_state0, hidden_state1, hidden_state2, hidden_state3], -1)
474
+
475
+ hidden_state = self.norm(hidden_state)
476
+ hidden_state = self.reduction(hidden_state)
477
+
478
+ return hidden_state
479
+
480
+
481
+ @lru_cache()
482
+ def compute_mask(D, height, width, window_size, shift_size, device):
483
+ img_mask = torch.zeros((1, D, height, width, 1), device=device) # 1 Dp Hp Wp 1
484
+ cnt = 0
485
+ for d in (
486
+ slice(-window_size[0]),
487
+ slice(-window_size[0], -shift_size[0]),
488
+ slice(-shift_size[0], None),
489
+ ):
490
+ for h in (
491
+ slice(-window_size[1]),
492
+ slice(-window_size[1], -shift_size[1]),
493
+ slice(-shift_size[1], None),
494
+ ):
495
+ for w in (
496
+ slice(-window_size[2]),
497
+ slice(-window_size[2], -shift_size[2]),
498
+ slice(-shift_size[2], None),
499
+ ):
500
+ img_mask[:, d, h, w, :] = cnt
501
+ cnt += 1
502
+ mask_windows = window_partition(img_mask, window_size)
503
+ mask_windows = mask_windows.squeeze(-1)
504
+ attention_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
505
+ attention_mask = attention_mask.masked_fill(attention_mask != 0, float(-100.0)).masked_fill(
506
+ attention_mask == 0, float(0.0)
507
+ )
508
+ return attention_mask
509
+
510
+
511
+ class OmnivoreSwinTransformerStage(nn.Module):
512
+ def __init__(
513
+ self,
514
+ dim,
515
+ depth,
516
+ num_heads,
517
+ window_size=(1, 7, 7),
518
+ mlp_ratio=4.0,
519
+ qkv_bias=False,
520
+ qk_scale=None,
521
+ dropout_rate=0.0,
522
+ attention_dropout_rate=0.0,
523
+ drop_path_rate=0.0,
524
+ norm_layer=nn.LayerNorm,
525
+ downsample=None,
526
+ ):
527
+ super().__init__()
528
+ self.window_size = window_size
529
+ self.shift_size = tuple(i // 2 for i in window_size)
530
+ self.depth = depth
531
+
532
+ # build layers
533
+ self.layers = nn.ModuleList(
534
+ [
535
+ OmnivoreSwinTransformer3DLayer(
536
+ dim=dim,
537
+ num_heads=num_heads,
538
+ window_size=window_size,
539
+ shift_size=(0, 0, 0) if (i % 2 == 0) else self.shift_size,
540
+ mlp_ratio=mlp_ratio,
541
+ qkv_bias=qkv_bias,
542
+ qk_scale=qk_scale,
543
+ dropout_rate=dropout_rate,
544
+ attention_dropout_rate=attention_dropout_rate,
545
+ drop_path_rate=drop_path_rate[i] if isinstance(drop_path_rate, list) else drop_path_rate,
546
+ norm_layer=norm_layer,
547
+ )
548
+ for i in range(depth)
549
+ ]
550
+ )
551
+
552
+ self.downsample = downsample
553
+ if self.downsample is not None:
554
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
555
+
556
+ def forward(self, hidden_state, use_checkpoint=False, height=None, width=None, use_seg=False):
557
+ if use_seg:
558
+ return self.forward_seg(hidden_state, height, width)
559
+ batch_size, channels, D, height, width = hidden_state.shape
560
+ window_size, shift_size = get_window_size((D, height, width), self.window_size, self.shift_size)
561
+ hidden_state = hidden_state.permute(0, 2, 3, 4, 1)
562
+
563
+ Dp = int(np.ceil(D / window_size[0])) * window_size[0]
564
+ Hp = int(np.ceil(height / window_size[1])) * window_size[1]
565
+ Wp = int(np.ceil(width / window_size[2])) * window_size[2]
566
+
567
+ attention_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, hidden_state.device)
568
+
569
+ for layer in self.layers:
570
+ hidden_state = layer(hidden_state, attention_mask, use_checkpoint=use_checkpoint)
571
+ hidden_state = hidden_state.view(batch_size, D, height, width, -1)
572
+
573
+ if self.downsample is not None:
574
+ hidden_state = self.downsample(hidden_state)
575
+
576
+ hidden_state = hidden_state.permute(0, 4, 1, 2, 3)
577
+
578
+ return hidden_state
579
+
580
+ def forward_seg(self, hidden_state, height, width):
581
+
582
+ Hp = int(np.ceil(height / self.window_size[1])) * self.window_size[1]
583
+ Wp = int(np.ceil(width / self.window_size[2])) * self.window_size[2]
584
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=hidden_state.device) # 1 Hp Wp 1
585
+ h_slices = (
586
+ slice(0, -self.window_size[1]),
587
+ slice(-self.window_size[1], -self.shift_size[1]),
588
+ slice(-self.shift_size[1], None),
589
+ )
590
+ w_slices = (
591
+ slice(0, -self.window_size[2]),
592
+ slice(-self.window_size[2], -self.shift_size[2]),
593
+ slice(-self.shift_size[2], None),
594
+ )
595
+ cnt = 0
596
+ for h in h_slices:
597
+ for w in w_slices:
598
+ img_mask[:, h, w, :] = cnt
599
+ cnt += 1
600
+
601
+ mask_windows = window_partition_image(img_mask, self.window_size) # nW, window_size, window_size, 1
602
+ mask_windows = mask_windows.view(-1, self.window_size[1] * self.window_size[2])
603
+ attention_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
604
+ attention_mask = attention_mask.masked_fill(attention_mask != 0, float(-100.0)).masked_fill(
605
+ attention_mask == 0, float(0.0)
606
+ )
607
+
608
+ for layer in self.layers:
609
+ layer.height, layer.width = height, width
610
+ if hidden_state.ndim == 4:
611
+ batch_size, D, channels, seq_len = hidden_state.shape
612
+ assert seq_len == height * width, "input feature has wrong size"
613
+ hidden_state = hidden_state.reshape(batch_size, D, channels, height, width)
614
+ hidden_state = hidden_state.permute(0, 1, 3, 4, 2)
615
+ assert hidden_state.shape[2] == height
616
+ assert hidden_state.shape[3] == width
617
+ hidden_state = layer(hidden_state, attention_mask)
618
+ if self.downsample is not None:
619
+ x_down = self.downsample(hidden_state, height, width)
620
+ Wh, Ww = (height + 1) // 2, (width + 1) // 2
621
+ return hidden_state, height, width, x_down, Wh, Ww
622
+ else:
623
+ return hidden_state, height, width, hidden_state, height, width
624
+
625
+
626
+ class OmnivorePatchEmbeddings3D(nn.Module):
627
+ """Video to Patch Embedding"""
628
+
629
+ def __init__(
630
+ self,
631
+ patch_size=(2, 4, 4),
632
+ input_channels=3,
633
+ embed_dim=96,
634
+ norm_layer=None,
635
+ additional_variable_channels=None,
636
+ ):
637
+ super().__init__()
638
+ self.patch_size = patch_size
639
+
640
+ self.input_channels = input_channels
641
+ self.embed_dim = embed_dim
642
+ self.additional_variable_channels = additional_variable_channels
643
+
644
+ self.projection = nn.Conv3d(input_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
645
+ if additional_variable_channels:
646
+ # we create var_proj separately from proj
647
+ # this makes it convenient to ignore var_proj on downstream tasks
648
+ # where we only use RGB
649
+ self.var_projection = [
650
+ nn.Conv3d(x, embed_dim, kernel_size=patch_size, stride=patch_size)
651
+ for x in additional_variable_channels
652
+ ]
653
+ self.var_projection = nn.ModuleList(self.var_projection)
654
+
655
+ if norm_layer is not None:
656
+ self.norm = norm_layer(embed_dim)
657
+ else:
658
+ self.norm = None
659
+
660
+ def run_variable_channel_forward(self, hidden_state):
661
+ sidx = 0
662
+ out = None
663
+ for idx in range(len(self.additional_variable_channels)):
664
+ eidx = sidx + self.additional_variable_channels[idx]
665
+ c_out = self.var_projection[idx](hidden_state[:, sidx:eidx, ...])
666
+ if idx == 0:
667
+ out = c_out
668
+ else:
669
+ out += c_out
670
+ sidx = eidx
671
+ return out
672
+
673
+ def forward(self, hidden_state):
674
+ _, _, D, height, width = hidden_state.size()
675
+ if width % self.patch_size[2] != 0:
676
+ hidden_state = F.pad(hidden_state, (0, self.patch_size[2] - width % self.patch_size[2]))
677
+ if height % self.patch_size[1] != 0:
678
+ hidden_state = F.pad(hidden_state, (0, 0, 0, self.patch_size[1] - height % self.patch_size[1]))
679
+ if D % self.patch_size[0] != 0:
680
+ hidden_state = F.pad(hidden_state, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))
681
+
682
+ if self.additional_variable_channels:
683
+ hidden_state_rgb = hidden_state[:, :3, ...]
684
+ hidden_state_rem = hidden_state[:, 3:, ...]
685
+ hidden_state_rgb = self.projection(hidden_state_rgb)
686
+ if hidden_state.shape[1] > 3:
687
+ hidden_state_rem = self.run_variable_channel_forward(hidden_state_rem)
688
+ hidden_state = hidden_state_rgb + hidden_state_rem
689
+ else:
690
+ hidden_state = hidden_state_rgb
691
+ else:
692
+ hidden_state = self.projection(hidden_state) # B C D Wh Ww
693
+ if self.norm is not None:
694
+ D, Wh, Ww = hidden_state.size(2), hidden_state.size(3), hidden_state.size(4)
695
+ hidden_state = hidden_state.flatten(2).transpose(1, 2)
696
+ hidden_state = self.norm(hidden_state)
697
+ hidden_state = hidden_state.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
698
+
699
+ return hidden_state
700
+
701
+
702
+ class OmnivoreSwinTransformer3DModel(nn.Module):
703
+ def __init__(self, config):
704
+ super().__init__()
705
+ self.config = config
706
+ self.im2vid = OmnivoreIm2Video()
707
+ self.num_stages = len(self.config.depths)
708
+ self.patch_size = self.config.patch_size
709
+ self.input_channels = self.config.input_channels
710
+ self.embed_dim = self.config.embed_dim
711
+ self.depths = self.config.depths
712
+ self.num_heads = self.config.num_heads
713
+ self.window_size = self.config.window_size
714
+ self.mlp_ratio = self.config.mlp_ratio
715
+ self.qkv_bias = self.config.qkv_bias
716
+ self.qk_scale = self.config.qk_scale
717
+ self.dropout_rate = self.config.dropout_rate
718
+ self.attention_dropout_rate = self.config.attention_dropout_rate
719
+ self.drop_path_rate = self.config.drop_path_rate
720
+ self.norm_layer = nn.LayerNorm
721
+ self.patch_norm = self.config.patch_norm
722
+ self.frozen_stages = self.config.frozen_stages
723
+ self.depth_patch_embed_separate_params = True
724
+ self.depth_mode = self.config.depth_mode
725
+ depth_chans = None
726
+ assert self.input_channels == 3, "Only 3 channels supported"
727
+
728
+ # split image into non-overlapping patches
729
+ self.patch_embed = OmnivorePatchEmbeddings3D(
730
+ patch_size=self.patch_size,
731
+ input_channels=self.input_channels,
732
+ embed_dim=self.embed_dim,
733
+ norm_layer=self.norm_layer if self.patch_norm else None,
734
+ )
735
+
736
+ if self.depth_mode is not None:
737
+ msg = f"Using depth mode {self.depth_mode}"
738
+ logger.info(msg)
739
+ assert self.depth_mode in ["separate_d_tokens", "summed_rgb_d_tokens", "rgbd"]
740
+ if self.depth_mode in ["separate_d_tokens", "summed_rgb_d_tokens"]:
741
+ depth_chans = 1
742
+ assert self.depth_patch_embed_separate_params, "separate tokenization needs separate parameters"
743
+ if self.depth_mode == "separate_d_tokens":
744
+ raise NotImplementedError()
745
+ else:
746
+ assert self.depth_mode == "rgbd"
747
+ depth_chans = 4
748
+
749
+ self.depth_patch_embed_separate_params = self.depth_patch_embed_separate_params
750
+
751
+ if self.depth_patch_embed_separate_params:
752
+ self.depth_patch_embed = OmnivorePatchEmbeddings3D(
753
+ patch_size=self.patch_size,
754
+ input_channels=depth_chans,
755
+ embed_dim=self.embed_dim,
756
+ norm_layer=self.norm_layer if self.patch_norm else None,
757
+ )
758
+ else:
759
+ del self.patch_embed
760
+ assert depth_chans == 4
761
+ logger.info("Certain channels of patch projection may not be used in forward pass")
762
+ logger.info("Make sure config.DISTRIBUTED.FIND_UNUSED_PARAMETERS is set to True")
763
+ self.patch_embed = OmnivorePatchEmbeddings3D(
764
+ patch_size=self.patch_size,
765
+ input_channels=3,
766
+ embed_dim=self.embed_dim,
767
+ additional_variable_channels=[1],
768
+ norm_layer=self.norm_layer if self.patch_norm else None,
769
+ )
770
+
771
+ self.pos_drop = nn.Dropout(p=self.dropout_rate)
772
+
773
+ # stochastic depth
774
+ dpr = [
775
+ x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))
776
+ ] # stochastic depth decay rule
777
+
778
+ # build stages
779
+ self.stages = nn.ModuleList()
780
+ for stage in range(self.num_stages):
781
+ stage_module = OmnivoreSwinTransformerStage(
782
+ dim=int(self.embed_dim * 2**stage),
783
+ depth=self.depths[stage],
784
+ num_heads=self.num_heads[stage],
785
+ window_size=self.window_size,
786
+ mlp_ratio=self.mlp_ratio,
787
+ qkv_bias=self.qkv_bias,
788
+ qk_scale=self.qk_scale,
789
+ dropout_rate=self.dropout_rate,
790
+ attention_dropout_rate=self.attention_dropout_rate,
791
+ drop_path_rate=dpr[sum(self.depths[:stage]) : sum(self.depths[: stage + 1])],
792
+ norm_layer=self.norm_layer,
793
+ downsample=OmnivorePatchMerging if stage < self.num_stages - 1 else None,
794
+ )
795
+ self.stages.append(stage_module)
796
+
797
+ self.num_features = int(self.embed_dim * 2 ** (self.num_stages - 1))
798
+ self.norm = self.norm_layer(self.num_features)
799
+ self._freeze_stages()
800
+
801
+ def _freeze_stages(self):
802
+ if self.frozen_stages >= 0:
803
+ self.patch_embed.eval()
804
+ for param in self.patch_embed.parameters():
805
+ param.requires_grad = False
806
+
807
+ if self.frozen_stages >= 1:
808
+ self.pos_drop.eval()
809
+ for i in range(0, self.frozen_stages):
810
+ m = self.stages[i]
811
+ m.eval()
812
+ for param in m.parameters():
813
+ param.requires_grad = False
814
+
815
+ def _apply_norm(self, x):
816
+ x = x.permute(0, 2, 3, 4, 1)
817
+ x = self.norm(x)
818
+ x = x.permute(0, 4, 1, 2, 3)
819
+ return x
820
+
821
+ def forward_intermediate_features(self, stage_outputs, out_feat_keys):
822
+ """
823
+ Inputs
824
+ - stage_outputs: list of features without self.norm() applied to them
825
+ - out_feat_keys: list of feature names (str)
826
+ specified as "stage<int>" for feature with norm or "interim<int>" for feature without norm
827
+ """
828
+ out_features = []
829
+ for key in out_feat_keys:
830
+ if key.startswith("stage"):
831
+ rep = "stage"
832
+ elif key.startswith("interim"):
833
+ rep = "interim"
834
+ else:
835
+ raise ValueError(f"Invalid key {key}")
836
+ idx = int(key.replace(rep, ""))
837
+ feat = stage_outputs[idx]
838
+ if rep == "stage":
839
+ feat = self._apply_norm(feat)
840
+ out_features.append(feat)
841
+ return out_features
842
+
843
+ def get_patch_embedding(self, hidden_state):
844
+ assert hidden_state.ndim == 5
845
+ has_depth = hidden_state.shape[1] == 4
846
+
847
+ if has_depth:
848
+ if self.depth_mode in ["summed_rgb_d_tokens"]:
849
+ hidden_state_rgb = hidden_state[:, :3, ...]
850
+ hidden_state_d = hidden_state[:, 3:, ...]
851
+ hidden_state_d = self.depth_patch_embed(hidden_state_d)
852
+ hidden_state_rgb = self.patch_embed(hidden_state_rgb)
853
+ # sum the two sets of tokens
854
+ hidden_state = hidden_state_rgb + hidden_state_d
855
+ elif self.depth_mode == "rgbd":
856
+ if self.depth_patch_embed_separate_params:
857
+ hidden_state = self.depth_patch_embed(hidden_state)
858
+ else:
859
+ hidden_state = self.patch_embed(hidden_state)
860
+ else:
861
+ logger.info("Depth mode %s not supported" % self.depth_mode)
862
+ raise NotImplementedError()
863
+ else:
864
+ hidden_state = self.patch_embed(hidden_state)
865
+ return hidden_state
866
+
867
+ def forward(
868
+ self, hidden_state, out_feat_keys=None, use_checkpoint=False, output_hidden_states=False, return_dict=True
869
+ ):
870
+ all_hidden_states = () if output_hidden_states else None
871
+ hidden_state = self.im2vid(hidden_state)
872
+ hidden_state = self.get_patch_embedding(hidden_state)
873
+ hidden_state = self.pos_drop(hidden_state)
874
+
875
+ stage_outputs = []
876
+
877
+ for stage in self.stages:
878
+ hidden_state = stage(hidden_state.contiguous(), use_checkpoint=use_checkpoint)
879
+ if output_hidden_states:
880
+ all_hidden_states = all_hidden_states + (hidden_state,)
881
+ stage_outputs.append(hidden_state)
882
+
883
+ if out_feat_keys is not None and len(out_feat_keys) > 0:
884
+ final_hidden_state = self.forward_intermediate_features(stage_outputs, out_feat_keys)
885
+ else:
886
+ hidden_state = self._apply_norm(hidden_state)
887
+ # Mean over the spatiotemporal dimensions
888
+ hidden_state = torch.mean(hidden_state, [-3, -2, -1])
889
+
890
+ final_hidden_state = hidden_state
891
+
892
+ if not return_dict:
893
+ return tuple(v for v in [final_hidden_state, all_hidden_states] if v is not None)
894
+ return BaseModelOutputWithNoAttention(last_hidden_state=final_hidden_state, hidden_states=all_hidden_states)
895
+
896
+ def train(self, mode=True):
897
+ """Convert the model into training mode while keep layers freezed."""
898
+ super(OmnivoreSwinTransformer3DModel, self).train(mode)
899
+ self._freeze_stages()
900
+
901
+
902
+ class OmnivoreImageClassificationHead(nn.Module):
903
+ def __init__(self, in_features=1024, out_features=1000, bias=True):
904
+ super().__init__()
905
+ self.image_head = nn.Linear(in_features, out_features, bias)
906
+
907
+ def forward(self, hidden_state):
908
+ logits = self.image_head(hidden_state)
909
+ return logits
910
+
911
+
912
+ class OmnivoreVideoClassificationHead(nn.Module):
913
+ def __init__(self, in_features=1024, out_features=400, bias=True):
914
+ super().__init__()
915
+ self.video_head = nn.Linear(in_features, out_features, bias)
916
+ self.dropout = nn.Dropout(p=0.5)
917
+
918
+ def forward(self, hidden_state):
919
+ logits = self.video_head(hidden_state)
920
+ logits = self.dropout(logits)
921
+ return logits
922
+
923
+
924
+ class OmnivoreRGBDClassificationHead(nn.Module):
925
+ def __init__(self, in_features=1024, out_features=19, bias=True):
926
+ super().__init__()
927
+ self.rgbd_head = nn.Linear(in_features, out_features, bias)
928
+
929
+ def forward(self, hidden_state):
930
+ logits = self.rgbd_head(hidden_state)
931
+ return logits
932
+
933
+
934
+ class OmnivorePreTrainedModel(PreTrainedModel):
935
+ """
936
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
937
+ models.
938
+ """
939
+
940
+ config_class = OmnivoreConfig
941
+ base_model_prefix = "omnivore"
942
+ main_input_name = "pixel_values"
943
+ supports_gradient_checkpointing = True
944
+
945
+ def _init_weights(self, module):
946
+ """Initialize the weights"""
947
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
948
+ # Slightly different from the TF version which uses truncated_normal for initialization
949
+ # cf https://github.com/pytorch/pytorch/pull/5617
950
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
951
+ if module.bias is not None:
952
+ module.bias.data.zero_()
953
+ elif isinstance(module, nn.LayerNorm):
954
+ module.bias.data.zero_()
955
+ module.weight.data.fill_(1.0)
956
+
957
+ def _set_gradient_checkpointing(self, module, value=False):
958
+ if isinstance(module, OmnivoreModel):
959
+ module.gradient_checkpointing = value
960
+
961
+
962
+ OMNIVORE_START_DOCSTRING = r"""
963
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
964
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
965
+ behavior.
966
+
967
+ Parameters:
968
+ config ([`OmnivoreConfig`]): Model configuration class with all the parameters of the model.
969
+ Initializing with a config file does not load the weights associated with the model, only the
970
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
971
+ """
972
+
973
+ OMNIVORE_INPUTS_DOCSTRING = r"""
974
+ Args:
975
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
976
+ Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See
977
+ [`AutoFeatureExtractor.__call__`] for details.
978
+
979
+ output_hidden_states (`bool`, *optional*):
980
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
981
+ more detail.
982
+ return_dict (`bool`, *optional*):
983
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
984
+ """
985
+
986
+
987
+ @add_start_docstrings(
988
+ "The bare Omnivore model outputting raw features without any specific head on top.",
989
+ OMNIVORE_START_DOCSTRING,
990
+ )
991
+ class OmnivoreModel(OmnivorePreTrainedModel):
992
+ def __init__(self, config):
993
+ super().__init__(config)
994
+ self.config = config
995
+ self.model = OmnivoreSwinTransformer3DModel(config)
996
+ self.post_init()
997
+
998
+ @add_start_docstrings_to_model_forward(OMNIVORE_INPUTS_DOCSTRING)
999
+ @add_code_sample_docstrings(
1000
+ processor_class=_FEAT_EXTRACTOR_FOR_DOC,
1001
+ checkpoint=_CHECKPOINT_FOR_DOC,
1002
+ output_type=BaseModelOutputWithPoolingAndNoAttention,
1003
+ config_class=_CONFIG_FOR_DOC,
1004
+ modality="vision",
1005
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
1006
+ )
1007
+ def forward(
1008
+ self,
1009
+ pixel_values: torch.FloatTensor = None,
1010
+ output_hidden_states: Optional[bool] = None,
1011
+ return_dict: Optional[bool] = None,
1012
+ ):
1013
+ output_hidden_states = (
1014
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1015
+ )
1016
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1017
+
1018
+ if pixel_values is None:
1019
+ raise ValueError("You have to specify pixel_values")
1020
+
1021
+ outputs = self.model(pixel_values, output_hidden_states=output_hidden_states, return_dict=True)
1022
+ last_hidden_state = outputs[0]
1023
+ # the backbone already averages over the spatiotemporal dimensions, so the output is (batch_size, hidden_size)
1024
+ pooled_output = last_hidden_state
1025
+
1026
+ if not return_dict:
1027
+ return (last_hidden_state, pooled_output) + outputs[1:]
1028
+
1029
+ return BaseModelOutputWithPoolingAndNoAttention(
1030
+ last_hidden_state=last_hidden_state,
1031
+ pooler_output=pooled_output,
1032
+ hidden_states=outputs.hidden_states,
1033
+ )
1034
+
1035
+
1036
+ @add_start_docstrings(
1037
+ """
1038
+ Omnivore Model with classification heads on top (a linear layer on top of the pooled features) for image,
1040
+ video and single-view 3D (RGB-D) inputs, e.g. for ImageNet.
1040
+ """,
1041
+ OMNIVORE_START_DOCSTRING,
1042
+ )
1043
+ class OmnivoreForImageClassification(OmnivorePreTrainedModel):
1044
+ def __init__(self, config):
1045
+ super().__init__(config)
1046
+
1047
+ self.num_image_labels = config.num_image_labels or config.num_labels
1048
+ self.num_video_labels = config.num_video_labels or config.num_labels
1049
+ self.num_rgbd_labels = config.num_rgbd_labels or config.num_labels
1050
+ self.omnivore = OmnivoreModel(config)
1051
+ self.image_classifier = OmnivoreImageClassificationHead(config.head_dim_in, self.num_image_labels)
1052
+ self.rgbd_classifier = OmnivoreRGBDClassificationHead(config.head_dim_in, self.num_rgbd_labels)
1053
+ self.video_classifier = OmnivoreVideoClassificationHead(config.head_dim_in, self.num_video_labels)
1054
+ # Initialize weights and apply final processing
1055
+ self.post_init()
1056
+
1057
+ @add_start_docstrings_to_model_forward(OMNIVORE_INPUTS_DOCSTRING)
1058
+ @add_code_sample_docstrings(
1059
+ processor_class=_FEAT_EXTRACTOR_FOR_DOC,
1060
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
1061
+ output_type=ImageClassifierOutputWithNoAttention,
1062
+ config_class=_CONFIG_FOR_DOC,
1063
+ modality="vision",
1064
+ expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
1065
+ )
1066
+ def forward(
1067
+ self,
1068
+ pixel_values: torch.FloatTensor = None,
1069
+ pixel_input_type: str = None,
1070
+ labels: Optional[torch.LongTensor] = None,
1071
+ output_hidden_states: Optional[bool] = None,
1072
+ return_dict: Optional[bool] = None,
1073
+ ):
1074
+ r"""
1075
+ pixel_input_type (`str`):
1076
+ Which classification head to use for the classification of given pixel_values
1077
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1078
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
1079
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1080
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1081
+
1082
+ Returns:
1083
+
1084
+ """
1085
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1086
+
1087
+ outputs = self.omnivore(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
1088
+ sequence_output = outputs[0]
1089
+
1090
+ if pixel_input_type == "image":
1091
+ logits = self.image_classifier(sequence_output)
1092
+ elif pixel_input_type == "video":
1093
+ logits = self.video_classifier(sequence_output)
1094
+ elif pixel_input_type == "rgbd":
1095
+ logits = self.rgbd_classifier(sequence_output)
1096
+ else:
1097
+ raise ValueError(f"`pixel_input_type` must be 'image', 'video' or 'rgbd', got {pixel_input_type}.")
1098
+ # compute the task-specific loss below when labels are provided
1099
+
1100
+ loss = None
1101
+ if labels is not None:
1102
+ if self.config.problem_type is None:
1103
+ if self.config.num_labels == 1:
1104
+ self.config.problem_type = "regression"
1105
+ elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1106
+ self.config.problem_type = "single_label_classification"
1107
+ else:
1108
+ self.config.problem_type = "multi_label_classification"
1109
+
1110
+ if self.config.problem_type == "regression":
1111
+ loss_fct = MSELoss()
1112
+ if self.config.num_labels == 1:
1113
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1114
+ else:
1115
+ loss = loss_fct(logits, labels)
1116
+ elif self.config.problem_type == "single_label_classification":
1117
+ loss_fct = CrossEntropyLoss()
1118
+ loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
1119
+ elif self.config.problem_type == "multi_label_classification":
1120
+ loss_fct = BCEWithLogitsLoss()
1121
+ loss = loss_fct(logits, labels)
1122
+ if not return_dict:
1123
+ output = (logits,) + outputs[2:]
1124
+ return ((loss,) + output) if loss is not None else output
1125
+
1126
+ return ImageClassifierOutputWithNoAttention(
1127
+ loss=loss,
1128
+ logits=logits,
1129
+ hidden_states=outputs.hidden_states,
1130
+ )
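
To sanity-check the classes above end to end, the model can be fed random tensors directly; the feature extractor referenced in `__init__.py` is not part of this commit. A minimal smoke-test sketch, assuming the committed files are importable under `transformers.models.omnivore` (a hypothetical path; the committed file is still named `modelling.py`) and that the relative `transformers` imports resolve:

```python
# Hedged smoke test for OmnivoreForImageClassification; import paths are assumptions,
# and raw normalized tensors stand in for the (not yet committed) OmnivoreFeatureExtractor.
import torch

from transformers.models.omnivore.configuration_omnivore import OmnivoreConfig
from transformers.models.omnivore.modeling_omnivore import OmnivoreForImageClassification

config = OmnivoreConfig()
model = OmnivoreForImageClassification(config).eval()

# a single RGB image: (batch, channels, height, width); OmnivoreIm2Video lifts it to a 1-frame video
image = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    image_logits = model(pixel_values=image, pixel_input_type="image").logits  # (1, 1000)

# a short clip: (batch, channels, frames, height, width), routed to the video head
video = torch.randn(1, 3, 16, 224, 224)
with torch.no_grad():
    video_logits = model(pixel_values=video, pixel_input_type="video").logits  # (1, 400)
```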