XuyaoWang commited on
Commit
d458a68
·
verified ·
1 Parent(s): 49ce301

Update model files

Browse files
.gitattributes CHANGED
@@ -1,35 +1,4 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ .bin filter=lfs diff=lfs merge=lfs -text
2
+ .pth filter=lfs diff=lfs merge=lfs -text
3
+ imagebind/imagebind_huge.pth filter=lfs diff=lfs merge=lfs -text
4
+ pytorch_model.bin filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
README.md CHANGED
@@ -1 +1,17 @@
1
  # AnyRewardModel
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # AnyRewardModel
2
+
3
+ ## Usage
4
+ ```python
5
+ from transformers import AutoModel, AutoProcessor
6
+
7
+ model = AutoModel.from_pretrained("PKU-Alignment/AnyRewardModel", trust_remote_code=True)
8
+ processor = AutoProcessor.from_pretrained("PKU-Alignment/AnyRewardModel", trust_remote_code=True)
9
+ ```
10
+
11
+ ## Note:
12
+
13
+ If you encounter the following error:
14
+ ```
15
+ ModuleNotFoundError: No module named 'torchvision.transforms.functional_tensor'
16
+ ```
17
+ Please refer to the guide in this [blog post](https://blog.csdn.net/lanxing147/article/details/136625264) for detailed resolution steps.
any_model.py ADDED
@@ -0,0 +1,863 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from enum import Enum
3
+ from dataclasses import dataclass
4
+ from typing import List, Optional, Union, Tuple
5
+
6
+ import torch
7
+ import torch.utils.checkpoint
8
+ from torch import nn
9
+
10
+ from transformers import AutoModelForCausalLM
11
+ from transformers.models.auto import CONFIG_MAPPING
12
+ from transformers.activations import ACT2FN
13
+ from transformers.cache_utils import Cache
14
+ from transformers.processing_utils import ProcessorMixin
15
+ from transformers.configuration_utils import PretrainedConfig
16
+ from transformers.modeling_utils import PreTrainedModel
17
+ from transformers.modeling_outputs import ModelOutput
18
+ from transformers.feature_extraction_utils import BatchFeature
19
+ from transformers.tokenization_utils_base import (
20
+ TextInput,
21
+ TensorType,
22
+ PaddingStrategy,
23
+ PreTokenizedInput,
24
+ TruncationStrategy
25
+ )
26
+ from transformers.utils import (
27
+ add_start_docstrings,
28
+ add_start_docstrings_to_model_forward,
29
+ logging,
30
+ replace_return_docstrings,
31
+ )
32
+
33
+ from .processor_mm import (
34
+ load_and_transform_image_data,
35
+ load_and_transform_video_data,
36
+ load_and_transform_audio_data
37
+ )
38
+ from .imagebind_model import *
39
+ from .helpers import *
40
+ from .multimodal_preprocessors import *
41
+ from .transformer import *
42
+
43
class ModalityType(Enum):
    """Closed set of modality names used by the processor and the model.

    A member compares equal to another member carrying the same value and to
    the plain string holding that value; any other operand compares unequal.
    Members hash like their string value.
    """

    TEXT = "text"
    IMAGE = "image"
    VIDEO = "video"
    AUDIO = "audio"
    VISION = "vision"  # For Imagebind

    def __str__(self):
        """Return the member's raw string value."""
        return self.value

    def __eq__(self, other):
        """Compare by value against another member or a plain string."""
        # Reduce a member to its string value so a single comparison handles both cases.
        if isinstance(other, ModalityType):
            other = other.value
        return isinstance(other, str) and self.value == other

    def __hash__(self):
        """Hash on the string value, consistent with ``__eq__``."""
        return hash(self.value)
62
+
63
+ _CONFIG_FOR_DOC = "AnyModelConfig"
64
+
65
class AnyModelConfig(PretrainedConfig):
    """Configuration holding the text-backbone config plus modality settings.

    Args:
        modality_config: Settings for the modality encoder, stored as given.
        text_config: Configuration of the language model. A dict is converted
            into the config class registered in ``CONFIG_MAPPING`` under its
            ``"model_type"`` entry (``"llama"`` when the entry is absent);
            ``None`` produces a default ``"llama"`` config; any other value is
            stored unchanged.
        ignore_index: Stored as ``self.ignore_index``.
        image_token_index: Token id stored as ``self.image_token_index``.
        video_token_index: Token id stored as ``self.video_token_index``.
        audio_token_index: Token id stored as ``self.audio_token_index``.
        projector_hidden_act: Activation name stored as
            ``self.projector_hidden_act``.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "any_model"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        modality_config=None,
        text_config=None,
        ignore_index=-100,
        image_token_index=128256,
        video_token_index=128257,
        audio_token_index=128258,
        projector_hidden_act="gelu",
        **kwargs,
    ):

        if isinstance(text_config, dict):
            # Work on a copy: the default "model_type" must not leak back into
            # the dict object owned by the caller.
            text_config = dict(text_config)
            text_config.setdefault("model_type", "llama")
            text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
        elif text_config is None:
            text_config = CONFIG_MAPPING["llama"]()

        self.modality_config = modality_config
        self.text_config = text_config
        self.ignore_index = ignore_index
        self.image_token_index = image_token_index
        self.video_token_index = video_token_index
        self.audio_token_index = audio_token_index
        self.projector_hidden_act = projector_hidden_act

        super().__init__(
            **kwargs,
        )
98
+
99
class AnyModelProcessor(ProcessorMixin):
    """Processor pairing a tokenizer with file-based image/video/audio loading.

    On construction the tokenizer (when one is given) is extended with the
    special tokens ``<image>``, ``<video>`` and ``<audio>``.
    """

    # TODO: Add support for any_model_processor
    # attributes = ["any_model_processor", "tokenizer"]
    attributes = ["tokenizer"]
    valid_kwargs = ["chat_template"]
    any_model_processor_class = "AnyModelProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, tokenizer=None, **kwargs):
        super().__init__(tokenizer, **kwargs)
        if self.tokenizer is not None:
            self.tokenizer.add_special_tokens({"additional_special_tokens": ["<image>", "<video>", "<audio>"]})

    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        data_paths: Union[str, List[str]] = None,
        modality: Optional[Union[ModalityType, List[ModalityType]]] = None,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = None,
        max_length=None,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
    ) -> BatchFeature:
        """Tokenize ``text`` and/or load the media files at ``data_paths``.

        Args:
            text: Text passed to the tokenizer; ``None`` skips tokenization.
            data_paths: A single path or a list of paths. A list is loaded
                item by item and the results are stacked along a new first
                dimension.
            modality: Modality of ``data_paths``. A list must contain exactly
                one distinct modality; its first element is used.
            padding, truncation, max_length, return_tensors: Forwarded to the
                tokenizer.

        Returns:
            A ``BatchFeature`` with the tokenizer outputs (if any) plus the
            keys ``"pixel_values"`` (``None`` without ``data_paths``) and
            ``"modality"``.

        Raises:
            ValueError: If ``data_paths`` is given without ``modality``, if a
                modality list mixes several modalities, or if the modality is
                not image, video or audio.
        """

        if data_paths is not None:
            if modality is None:
                raise ValueError("modality must be specified when data_paths is provided")
            if isinstance(modality, list):
                # Explicit raise rather than `assert`, which is removed under `python -O`.
                if len(set(modality)) != 1:
                    raise ValueError("only one kind modality can be provided in a batch")
                modality = modality[0]

            processor_func = None
            if modality == ModalityType.IMAGE:
                processor_func = load_and_transform_image_data
            elif modality == ModalityType.VIDEO:
                processor_func = load_and_transform_video_data
            elif modality == ModalityType.AUDIO:
                processor_func = load_and_transform_audio_data
            else:
                raise ValueError("modality must be one of ModalityType.IMAGE, ModalityType.VIDEO, ModalityType.AUDIO")

            if isinstance(data_paths, str):
                pixel_values = processor_func(data_paths)
            else:
                pixel_values = torch.stack([processor_func(data_path) for data_path in data_paths], dim=0)
        else:
            pixel_values = None
        if text is None:
            text_inputs = {}
        else:
            text_inputs = self.tokenizer(
                text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
            )

        return BatchFeature(data={**text_inputs, "pixel_values": pixel_values, "modality": modality})

    def batch_decode(self, *args, **kwargs):
        """
        Forward all arguments to the tokenizer's `batch_decode` and return its result.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        Forward all arguments to the tokenizer's `decode` and return its result.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        """Tokenizer input names followed by the extra keys produced by `__call__`."""
        # This processor has no feature extractor (attributes is only ["tokenizer"]),
        # so the extra names are listed explicitly instead of being read from one.
        tokenizer_input_names = self.tokenizer.model_input_names
        return list(dict.fromkeys(list(tokenizer_input_names) + ["pixel_values", "modality"]))
177
+
178
@dataclass
# Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->AnyModel
class AnyModelCausalLMOutputWithPast(ModelOutput):
    """
    Base class for AnyModel causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        modality_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Hidden states associated with the non-text input.
            NOTE(review): the original wording was inherited from Idefics (image embeddings of shape
            `(batch_size, num_images, sequence_length, hidden_size)`, produced by a vision encoder and optionally a
            perceiver); confirm what this model actually stores here.
        modality (`ModalityType`, *optional*):
            Modality tag carried with the output; how it is populated is not visible in this class.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    modality_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    modality: Optional[ModalityType] = None
220
+
221
+
222
class AnyModelMultiModalProjector(nn.Module):
    """Linear -> activation -> linear projection into the text model's hidden size.

    The input width is ``config.modality_config["hidden_size"]``, the output
    width is ``config.text_config.hidden_size``, and the activation is looked
    up in ``ACT2FN`` by ``config.projector_hidden_act``.
    """

    def __init__(self, config: AnyModelConfig):
        super().__init__()

        modality_width = config.modality_config["hidden_size"]
        text_width = config.text_config.hidden_size

        self.linear_1 = nn.Linear(modality_width, text_width, bias=True)
        self.act = ACT2FN[config.projector_hidden_act]
        self.linear_2 = nn.Linear(text_width, text_width, bias=True)

    def forward(self, modality_features):
        """Project ``modality_features`` and return the result."""
        return self.linear_2(self.act(self.linear_1(modality_features)))
235
+
236
class AnyModelPreTrainedModel(PreTrainedModel):
    """Shared base class: binds `AnyModelConfig` and provides weight initialisation."""

    config_class = AnyModelConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["AnyModelAttention"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True

    def __init__(self, config: AnyModelConfig):
        # The config is attached before the parent constructor runs, as in the original.
        self.config = config
        super().__init__(config)

    def _init_weights(self, module):
        """Initialise ``module`` with a normal distribution and zeroed bias/padding rows."""
        # This port targets inference and fine-tuning only, not training from scratch, so the
        # full initialisation scheme is not reproduced here. For that purpose see the codebase
        # at https://github.com/haotian-liu/LLaVA/tree/main/llava
        if hasattr(self.config, "initializer_range"):
            std = self.config.initializer_range
        else:
            std = self.config.text_config.initializer_range

        if hasattr(module, "class_embedding"):
            module.class_embedding.data.normal_(mean=0.0, std=std)

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            padding_row = module.padding_idx
            if padding_row is not None:
                module.weight.data[padding_row].zero_()
        elif isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()

    @property
    def _supports_sdpa(self):
        """
        Report the `_supports_sdpa` attribute of `self.language_model`.
        """
        return self.language_model._supports_sdpa
278
+
279
+ ANYMODEL_INPUTS_DOCSTRING = r"""
280
+ Args:
281
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
282
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
283
+ it.
284
+
285
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
286
+ [`PreTrainedTokenizer.__call__`] for details.
287
+
288
+ [What are input IDs?](../glossary#input-ids)
289
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
290
+ The tensors corresponding to the input images. Pixel values can be obtained using
291
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([]`AnyModelProcessor`] uses
292
+ [`CLIPImageProcessor`] for processing images).
293
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
294
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
295
+
296
+ - 1 for tokens that are **not masked**,
297
+ - 0 for tokens that are **masked**.
298
+
299
+ [What are attention masks?](../glossary#attention-mask)
300
+
301
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
302
+ [`PreTrainedTokenizer.__call__`] for details.
303
+
304
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
305
+ `past_key_values`).
306
+
307
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
308
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
309
+ information on the default strategy.
310
+
311
+ - 1 indicates the head is **not masked**,
312
+ - 0 indicates the head is **masked**.
313
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
314
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
315
+ config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
316
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
317
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
318
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
319
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
320
+
321
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
322
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
323
+
324
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
325
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
326
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
327
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
328
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
329
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
330
+ model's internal embedding lookup matrix.
331
+ vision_feature_layer (`int`, *optional*, defaults to -2):
332
+ The index of the layer to select the vision feature.
333
+ use_cache (`bool`, *optional*):
334
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
335
+ `past_key_values`).
336
+ output_attentions (`bool`, *optional*):
337
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
338
+ tensors for more detail.
339
+ output_hidden_states (`bool`, *optional*):
340
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
341
+ more detail.
342
+ return_dict (`bool`, *optional*):
343
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
344
+ """
345
+
346
+ class AnyModelForConditionalGeneration(AnyModelPreTrainedModel):
347
def __init__(self, config: AnyModelConfig):
    """Create the three projectors, the language model and the modality tower.

    The modality tower is the first element returned by ``imagebind_huge``,
    loaded from ``config.modality_config["imagebind_ckpt_path"]`` joined onto
    ``config._name_or_path``, then moved to the language model's device and
    dtype.
    """
    super().__init__(config)

    self.image_projector = AnyModelMultiModalProjector(config)
    self.video_projector = AnyModelMultiModalProjector(config)
    self.audio_projector = AnyModelMultiModalProjector(config)
    self.language_model = AutoModelForCausalLM.from_config(
        config.text_config, attn_implementation=config._attn_implementation
    )

    # -1 stands in for a missing pad token id.
    configured_pad_id = self.config.pad_token_id
    self.pad_token_id = -1 if configured_pad_id is None else configured_pad_id

    checkpoint_path = os.path.join(config._name_or_path, config.modality_config["imagebind_ckpt_path"])
    tower, _ = imagebind_huge(pretrained=True, store_path=checkpoint_path)
    tower = tower.to(self.language_model.device)
    self.modality_tower = tower.to(self.language_model.dtype)

    self.post_init()
364
+
365
def get_input_embeddings(self):
    """Return the input embeddings reported by the wrapped language model."""
    return self.language_model.get_input_embeddings()
367
+
368
def set_input_embeddings(self, value):
    """Forward ``value`` to the wrapped language model's ``set_input_embeddings``."""
    self.language_model.set_input_embeddings(value)
370
+
371
def get_output_embeddings(self):
    """Return the output embeddings reported by the wrapped language model."""
    return self.language_model.get_output_embeddings()
373
+
374
def set_output_embeddings(self, new_embeddings):
    """Forward ``new_embeddings`` to the wrapped language model's ``set_output_embeddings``."""
    self.language_model.set_output_embeddings(new_embeddings)
376
+
377
def set_decoder(self, decoder):
    """Forward ``decoder`` to the wrapped language model's ``set_decoder``."""
    self.language_model.set_decoder(decoder)
379
+
380
def get_decoder(self):
    """Return the decoder reported by the wrapped language model."""
    return self.language_model.get_decoder()
382
+
383
def tie_weights(self):
    """Delegate weight tying to the wrapped language model and return its result."""
    return self.language_model.tie_weights()
385
+
386
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
    """Resize the language model's token embeddings and record the new vocabulary size.

    Both ``self.config.text_config.vocab_size`` and ``self.vocab_size`` are set
    to the ``num_embeddings`` of the object returned by the language model,
    and that object is returned.
    """
    resized = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
    # Keep the text config and this wrapper in sync with the resized embeddings.
    vocab_size = resized.num_embeddings
    self.config.text_config.vocab_size = vocab_size
    self.vocab_size = vocab_size
    return resized
392
+
393
def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
    """Splice image features into the text embeddings at the image-token positions.

    Every position of `input_ids` equal to `config.image_token_index` is expanded into
    `num_image_patches` positions filled from `image_features`; all other tokens keep their
    embedding, attention-mask value and label, shifted to their new positions.

    Args:
        image_features: 3-D tensor unpacked as `(num_images, num_image_patches, embed_dim)`.
        inputs_embeds: Text embeddings indexed as `[batch, position]`; its dtype and device are
            used for the merged embedding.
        input_ids: 2-D tensor unpacked as `(batch_size, sequence_length)`.
        attention_mask: Mask indexed the same way as `input_ids`.
        labels: Optional labels indexed the same way as `input_ids`, or `None`.

    Returns:
        `(final_embedding, final_attention_mask, final_labels, position_ids)`; `final_labels`
        is `None` when `labels` is `None`.

    Raises:
        ValueError: If the number of slots reserved for image features differs from
            `num_images * num_image_patches`.
    """
    num_images, num_image_patches, embed_dim = image_features.shape
    batch_size, sequence_length = input_ids.shape
    # Left padding is assumed when no row ends with the pad token.
    left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
    # 1. Create a mask to know where special image tokens are
    special_image_token_mask = input_ids == self.config.image_token_index
    num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
    # Merged sequence length: the row with the most image tokens decides the size
    max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
    batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)

    # 2. Compute the positions where text should be written
    # Each image token occupies `num_image_patches` positions in the merged sequence, i.e. it
    # shifts every later token by `num_image_patches - 1`.
    # `torch.cumsum` accumulates these widths; the trailing `- 1` converts to zero-based indices.
    new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
    # Unused trailing slots per row (rows with fewer image tokens than the maximum)
    nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
    if left_padding:
        new_token_positions += nb_image_pad[:, None]  # offset for left padding
    text_to_overwrite = new_token_positions[batch_indices, non_image_indices]

    # 3. Create the full embedding, already padded to the maximum position
    final_embedding = torch.zeros(
        batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
    )
    final_attention_mask = torch.zeros(
        batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
    )
    if labels is not None:
        final_labels = torch.full(
            (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
        )
    # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
    # set the corresponding tensors into their correct target device.
    target_device = inputs_embeds.device
    batch_indices, non_image_indices, text_to_overwrite = (
        batch_indices.to(target_device),
        non_image_indices.to(target_device),
        text_to_overwrite.to(target_device),
    )
    attention_mask = attention_mask.to(target_device)

    # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
    # the text tokens are copied to their shifted positions and the gap in between is left
    # for the image features.
    final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
    final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
    if labels is not None:
        final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]

    # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
    image_to_overwrite = torch.full(
        (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
    )
    image_to_overwrite[batch_indices, text_to_overwrite] = False
    # Drop the first `nb_image_pad` non-text slots of each row: those are padding, not image slots
    image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)

    if image_to_overwrite.sum() != image_features.shape[:-1].numel():
        raise ValueError(
            f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
            f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
        )

    final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
    final_attention_mask |= image_to_overwrite
    # Position ids count attended positions from 0; non-attended positions are set to 1
    position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

    # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
    batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
    indices_to_mask = new_token_positions[batch_indices, pad_indices]

    final_embedding[batch_indices, indices_to_mask] = 0

    if labels is None:
        final_labels = None

    return final_embedding, final_attention_mask, final_labels, position_ids
470
+
471
def _merge_input_ids_with_video_features(self, video_features, inputs_embeds, input_ids, attention_mask, labels):
    """Splice video features into the text embeddings at the video-token positions.

    Every position of `input_ids` equal to `config.video_token_index` is expanded into
    `num_video_patches` positions filled from `video_features`; all other tokens keep their
    embedding, attention-mask value and label, shifted to their new positions.

    Args:
        video_features: 3-D tensor unpacked as `(num_videos, num_video_patches, embed_dim)`.
        inputs_embeds: Text embeddings indexed as `[batch, position]`; its dtype and device are
            used for the merged embedding.
        input_ids: 2-D tensor unpacked as `(batch_size, sequence_length)`.
        attention_mask: Mask indexed the same way as `input_ids`.
        labels: Optional labels indexed the same way as `input_ids`, or `None`.

    Returns:
        `(final_embedding, final_attention_mask, final_labels, position_ids)`; `final_labels`
        is `None` when `labels` is `None`.

    Raises:
        ValueError: If the number of slots reserved for video features differs from
            `num_videos * num_video_patches`.
    """
    num_videos, num_video_patches, embed_dim = video_features.shape
    batch_size, sequence_length = input_ids.shape
    # Left padding is assumed when no row ends with the pad token.
    left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
    # 1. Create a mask to know where special video tokens are
    special_video_token_mask = input_ids == self.config.video_token_index
    num_special_video_tokens = torch.sum(special_video_token_mask, dim=-1)
    # Merged sequence length: the row with the most video tokens decides the size
    max_embed_dim = (num_special_video_tokens.max() * (num_video_patches - 1)) + sequence_length
    batch_indices, non_video_indices = torch.where(input_ids != self.config.video_token_index)

    # 2. Compute the positions where text should be written
    # Each video token occupies `num_video_patches` positions in the merged sequence, i.e. it
    # shifts every later token by `num_video_patches - 1`.
    # `torch.cumsum` accumulates these widths; the trailing `- 1` converts to zero-based indices.
    new_token_positions = torch.cumsum((special_video_token_mask * (num_video_patches - 1) + 1), -1) - 1
    # Unused trailing slots per row (rows with fewer video tokens than the maximum)
    nb_video_pad = max_embed_dim - 1 - new_token_positions[:, -1]
    if left_padding:
        new_token_positions += nb_video_pad[:, None]  # offset for left padding
    text_to_overwrite = new_token_positions[batch_indices, non_video_indices]

    # 3. Create the full embedding, already padded to the maximum position
    final_embedding = torch.zeros(
        batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
    )
    final_attention_mask = torch.zeros(
        batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
    )
    if labels is not None:
        final_labels = torch.full(
            (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
        )
    # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
    # set the corresponding tensors into their correct target device.
    target_device = inputs_embeds.device
    batch_indices, non_video_indices, text_to_overwrite = (
        batch_indices.to(target_device),
        non_video_indices.to(target_device),
        text_to_overwrite.to(target_device),
    )
    attention_mask = attention_mask.to(target_device)

    # 4. Fill the embeddings based on the mask. If we have ["hey" "<video>", "how", "are"]
    # the text tokens are copied to their shifted positions and the gap in between is left
    # for the video features.
    final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_video_indices]
    final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_video_indices]
    if labels is not None:
        final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_video_indices]

    # 5. Fill the embeddings corresponding to the videos. Anything that is not `text_positions` needs filling (#29835)
    video_to_overwrite = torch.full(
        (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
    )
    video_to_overwrite[batch_indices, text_to_overwrite] = False
    # Drop the first `nb_video_pad` non-text slots of each row: those are padding, not video slots
    video_to_overwrite &= video_to_overwrite.cumsum(-1) - 1 >= nb_video_pad[:, None].to(target_device)

    if video_to_overwrite.sum() != video_features.shape[:-1].numel():
        raise ValueError(
            f"The input provided to the model are wrong. The number of video tokens is {torch.sum(special_video_token_mask)} while"
            f" the number of video given to the model is {num_videos}. This prevents correct indexing and breaks batch generation."
        )

    final_embedding[video_to_overwrite] = video_features.contiguous().reshape(-1, embed_dim).to(target_device)
    final_attention_mask |= video_to_overwrite
    # Position ids count attended positions from 0; non-attended positions are set to 1
    position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

    # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
    batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
    indices_to_mask = new_token_positions[batch_indices, pad_indices]

    final_embedding[batch_indices, indices_to_mask] = 0

    if labels is None:
        final_labels = None

    return final_embedding, final_attention_mask, final_labels, position_ids
548
+
549
def _merge_input_ids_with_audio_features(self, audio_features, inputs_embeds, input_ids, attention_mask, labels):
    """Splice projected audio features into the text embedding sequence.

    Every position of `input_ids` equal to `self.config.audio_token_index` is expanded into
    `num_audio_patches` slots that are filled from `audio_features`; all other positions keep
    their text embedding, attention-mask value and (optionally) label.

    Args:
        audio_features: 3-D tensor unpacked as `(num_audios, num_audio_patches, embed_dim)`.
        inputs_embeds: text embeddings indexed as `[batch, position]`; supplies the output dtype/device.
        input_ids: 2-D tensor unpacked as `(batch_size, sequence_length)`.
        attention_mask: mask indexed as `[batch, position]`, aligned with `input_ids`.
        labels: optional labels aligned with `input_ids`, or `None`.

    Returns:
        Tuple `(final_embedding, final_attention_mask, final_labels, position_ids)`.
        `final_labels` is `None` when `labels` is `None`.

    Raises:
        ValueError: if the number of slots reserved for audio does not equal
            `num_audios * num_audio_patches`.
    """
    num_audios, num_audio_patches, embed_dim = audio_features.shape
    batch_size, sequence_length = input_ids.shape
    # Treated as left-padded when no row ends with the pad token.
    left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
    # 1. Create a mask to know where special audio tokens are
    special_audio_token_mask = input_ids == self.config.audio_token_index
    num_special_audio_tokens = torch.sum(special_audio_token_mask, dim=-1)
    # Compute the maximum merged length: each audio token grows by `num_audio_patches - 1` slots
    max_embed_dim = (num_special_audio_tokens.max() * (num_audio_patches - 1)) + sequence_length
    batch_indices, non_audio_indices = torch.where(input_ids != self.config.audio_token_index)

    # 2. Compute the positions where text should be written
    # Calculate new positions for text tokens in the merged audio-text sequence.
    # `special_audio_token_mask` identifies audio tokens. Each audio token occupies `num_audio_patches`
    # positions, i.e. it shifts every later token by `num_audio_patches - 1`.
    # `torch.cumsum` accumulates those shifts along the sequence.
    # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
    new_token_positions = torch.cumsum((special_audio_token_mask * (num_audio_patches - 1) + 1), -1) - 1
    # Number of unused slots per row (rows with fewer audio tokens than the batch maximum)
    nb_audio_pad = max_embed_dim - 1 - new_token_positions[:, -1]
    if left_padding:
        new_token_positions += nb_audio_pad[:, None]  # offset for left padding
    text_to_overwrite = new_token_positions[batch_indices, non_audio_indices]

    # 3. Create the full embedding, already padded to the maximum position
    final_embedding = torch.zeros(
        batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
    )
    final_attention_mask = torch.zeros(
        batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
    )
    if labels is not None:
        final_labels = torch.full(
            (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
        )
    # In case some sub-model has been offloaded to another device, we need to manually
    # move the index tensors onto the device that holds `inputs_embeds`.
    target_device = inputs_embeds.device
    batch_indices, non_audio_indices, text_to_overwrite = (
        batch_indices.to(target_device),
        non_audio_indices.to(target_device),
        text_to_overwrite.to(target_device),
    )
    attention_mask = attention_mask.to(target_device)

    # 4. Fill the embeddings based on the mask. With P = num_audio_patches and the (unpadded) row
    # ["hey", "<audio>", "how", "are"], text is written to positions [0, P + 1, P + 2] and the
    # audio features fill positions 1..P.
    final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_audio_indices]
    final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_audio_indices]
    if labels is not None:
        final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_audio_indices]

    # 5. Fill the embeddings corresponding to the audios. Anything that is not `text_positions` needs filling (#29835)
    audio_to_overwrite = torch.full(
        (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
    )
    audio_to_overwrite[batch_indices, text_to_overwrite] = False
    # Drop the first `nb_audio_pad` non-text slots of each row: they are padding, not audio.
    audio_to_overwrite &= audio_to_overwrite.cumsum(-1) - 1 >= nb_audio_pad[:, None].to(target_device)

    if audio_to_overwrite.sum() != audio_features.shape[:-1].numel():
        raise ValueError(
            f"The input provided to the model are wrong. The number of audio tokens is {torch.sum(special_audio_token_mask)} while"
            f" the number of audio given to the model is {num_audios}. This prevents correct indexing and breaks batch generation."
        )

    final_embedding[audio_to_overwrite] = audio_features.contiguous().reshape(-1, embed_dim).to(target_device)
    final_attention_mask |= audio_to_overwrite
    # Positions count attended tokens only; masked slots get the placeholder value 1.
    position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

    # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
    batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
    indices_to_mask = new_token_positions[batch_indices, pad_indices]

    final_embedding[batch_indices, indices_to_mask] = 0

    if labels is None:
        final_labels = None

    return final_embedding, final_attention_mask, final_labels, position_ids
626
+
627
@add_start_docstrings_to_model_forward(ANYMODEL_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=AnyModelCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
    self,
    input_ids: torch.LongTensor = None,
    pixel_values_1: torch.FloatTensor = None,
    pixel_values_2: torch.FloatTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    modality: Optional[ModalityType] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    vision_feature_layer: Optional[int] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, AnyModelCausalLMOutputWithPast]:
    r"""
    Args:
        pixel_values_1 (`torch.FloatTensor`, *optional*):
            Raw input for the first modality slot; it is fed to the modality tower selected by
            `modality[0][0]`.
        pixel_values_2 (`torch.FloatTensor`, *optional*):
            Raw input for the second modality slot; it is fed to the modality tower selected by
            `modality[0][1]`.
        modality (*optional*):
            Modality types of the two slots, read as `modality[0][i]`. Required whenever
            `pixel_values_1` is given.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    Returns:
    """

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if inputs_embeds is None:
        # 1. Extract the input embeddings
        inputs_embeds = self.get_input_embeddings()(input_ids)

        # 2. Merge text with the (up to two) non-text modality inputs.
        # NOTE(review): the original condition tested `pixel_values_1 is not None` twice, so
        # `pixel_values_2` was never checked. Behaviour is kept (only slot 1 gates the merge) because
        # a TEXT second slot legitimately has no tensor; confirm whether `pixel_values_2` should gate too.
        if pixel_values_1 is not None and input_ids.shape[1] != 1:
            assert modality is not None, "modality must be provided when pixel_values is not None"
            for i in range(2):
                pixel_values = pixel_values_1 if i == 0 else pixel_values_2
                slot_modality = modality[0][i]
                if slot_modality == ModalityType.IMAGE:
                    tower_key = str(ModalityType.VISION)
                    projector = self.image_projector
                    merge_fn = self._merge_input_ids_with_image_features
                elif slot_modality == ModalityType.VIDEO:
                    tower_key = str(ModalityType.VISION)
                    projector = self.video_projector
                    merge_fn = self._merge_input_ids_with_video_features
                elif slot_modality == ModalityType.AUDIO:
                    tower_key = str(ModalityType.AUDIO)
                    projector = self.audio_projector
                    merge_fn = self._merge_input_ids_with_audio_features
                elif slot_modality == ModalityType.TEXT:
                    # Plain text needs no feature merging.
                    continue
                else:
                    # Report the value that was actually dispatched on (`modality[0][i]`).
                    raise ValueError(f"modality {slot_modality} is not supported")

                modality_outputs = self.modality_tower({tower_key: pixel_values})[tower_key]  # size = (b, h)
                features = projector(modality_outputs).unsqueeze(1)  # size = (b, 1, h)

                inputs_embeds = inputs_embeds.to(features.dtype)
                # The merge function is kept local instead of being stored on `self`, so a forward
                # pass does not mutate module state.
                inputs_embeds, attention_mask, labels, position_ids = merge_fn(
                    features, inputs_embeds, input_ids, attention_mask, labels
                )

    # NOTE(review): the nesting level of this statement was lost in the source this was recovered
    # from; it is applied at function level. The `None` guard avoids an AttributeError when no
    # attention mask is supplied.
    if attention_mask is not None:
        position_ids = (attention_mask.cumsum(-1) - 1).masked_fill_((attention_mask == 0), 1)

    outputs = self.language_model(
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    logits = outputs[0]

    loss = None
    if labels is not None:
        # Shift so that tokens < n predict n
        if attention_mask is not None:
            # Only positions that are attended to contribute to the loss.
            shift_attention_mask = attention_mask[..., 1:]
            shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
            shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
        else:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
        )

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return AnyModelCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
752
+
753
def prepare_inputs_for_generation(
    self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs
):
    """Assemble the keyword arguments for one generation step.

    Trims `input_ids` (and, for size-limited caches, `attention_mask`) to the tokens that are not
    yet covered by `past_key_values`, derives `position_ids` from `attention_mask` when none are
    passed in `kwargs`, and returns a dict of model inputs.

    NOTE(review): the returned dict contains a `pixel_values` key, while the `forward` defined on
    this class takes `pixel_values_1` / `pixel_values_2` and no `**kwargs` - confirm how this dict
    is consumed.
    """
    if past_key_values is not None:
        if isinstance(past_key_values, Cache):
            cache_length = past_key_values.get_seq_length()
            past_length = past_key_values.seen_tokens
        else:
            # Legacy tuple cache: the length is read from dimension 2 of the first cached tensor.
            cache_length = past_length = past_key_values[0][0].shape[2]

        # Keep only the unprocessed tokens:
        # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
        # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
        # input)
        if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
        # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
        # input_ids based on the past_length.
        elif past_length < input_ids.shape[1]:
            input_ids = input_ids[:, past_length:]
        # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
        elif self.config.image_token_index in input_ids:
            input_ids = input_ids[:, input_ids.shape[1] - 1 :]
        # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
        # older attention values, as their corresponding values are not part of the input.
        if cache_length < past_length and attention_mask is not None:
            attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]

    position_ids = kwargs.get("position_ids", None)
    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            # Only the positions of the tokens still to be processed are needed.
            position_ids = position_ids[:, -input_ids.shape[1] :]

    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
    if inputs_embeds is not None and past_key_values is None:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}

    model_inputs.update(
        {
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
            "pixel_values": pixel_values,
        }
    )
    return model_inputs
805
+
806
def _reorder_cache(self, *args, **kwargs):
    """Forward all arguments unchanged to `self.language_model._reorder_cache` and return its result."""
    return self.language_model._reorder_cache(*args, **kwargs)
808
+
809
@dataclass
class ScoreModelOutput(ModelOutput):
    """Output of the score model.

    Every field defaults to `None`. The size comments use B = batch, L = sequence length,
    D = score dimension, E = hidden size; the meaning of `I` in `L-I` is not defined in this
    file (presumably a clipped prefix length - TODO confirm).
    """

    # Per-token scores.
    scores: torch.FloatTensor | None = None  # size = (B, L, D)
    clipped_scores: torch.FloatTensor | None = None  # size = (B, L-I, D)
    # Score taken at the end position of each sequence.
    end_scores: torch.FloatTensor | None = None  # size = (B, D)
    # Hidden states the scores were computed from.
    last_hidden_state: torch.FloatTensor | None = None  # size = (B, L, E)
    clipped_states: torch.FloatTensor | None = None  # size = (B, L-I, D)
    # Hidden state at the end position of each sequence.
    end_last_hidden_state: torch.FloatTensor | None = None  # size = (B, E)
    # Index of the end position of each sequence.
    end_index: torch.LongTensor | None = None  # size = (B,)
820
+
821
class AnyRewardModel(AnyModelForConditionalGeneration):
    """Reward model: the base multimodal model plus a bias-free linear head producing one score per token."""

    supports_gradient_checkpointing = True

    def __init__(self, config: AnyModelConfig):
        super().__init__(config)
        # NOTE(review): the input width 4096 is hard-coded; it must equal the hidden size of the
        # language model's last hidden state.
        self.score_head = nn.Linear(4096, 1, bias=False)

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        **kwargs,
    ) -> ScoreModelOutput:
        """Score every token and expose the score of the last sequence position.

        Args:
            input_ids: token ids passed positionally to `self.model`.
            attention_mask: attention mask passed to `self.model`.
            **kwargs: forwarded unchanged to `self.model`.

        Returns:
            A `ScoreModelOutput` with `scores`, `end_scores`, `last_hidden_state`,
            `end_last_hidden_state` and `end_index` populated.
        """
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            **kwargs,
        )

        last_hidden_state = outputs.hidden_states[-1]
        scores = self.score_head(last_hidden_state).float()
        B, _, _ = scores.size()

        # The end position is always the last index (-1). Build it as an integer tensor on the same
        # device as the scores: the field is typed `LongTensor`, and a float CPU tensor would not
        # be usable as an index or sit alongside the other outputs.
        end_index = torch.full((B,), -1, dtype=torch.long, device=scores.device)  # size = (B,)
        end_last_hidden_state = last_hidden_state[:, -1, :].unsqueeze(1)
        end_scores = self.score_head(end_last_hidden_state).float()
        end_last_hidden_state = end_last_hidden_state.squeeze(dim=1)  # size = (B, E)
        end_scores = end_scores.squeeze(dim=1)  # size = (B, D)

        return ScoreModelOutput(
            scores=scores,  # size = (B, L, D)
            end_scores=end_scores,  # size = (B, D)
            last_hidden_state=last_hidden_state,  # size = (B, L, E)
            end_last_hidden_state=end_last_hidden_state,  # size = (B, E)
            end_index=end_index,  # size = (B,)
        )
858
+
859
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

AutoConfig.register("any_model", AnyModelConfig)
# Each auto class keeps a single model per config class, so registering two models with
# `AutoModel` left only the last one (the reward model) reachable. The generation model is
# registered under `AutoModelForCausalLM` instead, mirroring the `auto_map` in config.json.
AutoModelForCausalLM.register(AnyModelConfig, AnyModelForConditionalGeneration)
AutoModel.register(AnyModelConfig, AnyRewardModel)
config.json CHANGED
@@ -2,6 +2,11 @@
2
  "architectures": [
3
  "AnyModelForConditionalGeneration"
4
  ],
 
 
 
 
 
5
  "audio_token_index": 128258,
6
  "bos_token_id": 128000,
7
  "eos_token_id": 128001,
 
2
  "architectures": [
3
  "AnyModelForConditionalGeneration"
4
  ],
5
+ "auto_map": {
6
+ "AutoConfig": "any_model.AnyModelConfig",
7
+ "AutoModel": "any_model.AnyRewardModel",
8
+ "AutoModelForCausalLM": "any_model.AnyModelForConditionalGeneration"
9
+ },
10
  "audio_token_index": 128258,
11
  "bos_token_id": 128000,
12
  "eos_token_id": 128001,
helpers.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Portions Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ import math
9
+
10
+ import einops
11
+ import numpy as np
12
+ import torch
13
+
14
+ import torch.nn as nn
15
+
16
+
17
class Normalize(nn.Module):
    """Scale the input to unit L2 norm along a fixed dimension."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        # Dimension along which the norm is taken.
        self.dim = dim

    def forward(self, x):
        """Return `x` L2-normalised along `self.dim`."""
        unit_norm = nn.functional.normalize(x, p=2, dim=self.dim)
        return unit_norm
24
+
25
+
26
class LearnableLogitScaling(nn.Module):
    """Multiply the input by `exp(log_logit_scale)`, clipped from above at `max_logit_scale`.

    The log-scale is stored as an `nn.Parameter` when `learnable` is true and as a buffer otherwise.
    """

    def __init__(
        self,
        logit_scale_init: float = 1 / 0.07,
        learnable: bool = True,
        max_logit_scale: float = 100,
    ) -> None:
        super().__init__()
        self.max_logit_scale = max_logit_scale
        self.logit_scale_init = logit_scale_init
        self.learnable = learnable
        # The scale is kept in log space as a 0-d tensor.
        initial_log_scale = torch.ones([]) * np.log(self.logit_scale_init)
        if not learnable:
            self.register_buffer("log_logit_scale", initial_log_scale)
        else:
            self.log_logit_scale = nn.Parameter(initial_log_scale)

    def forward(self, x):
        """Return `x` multiplied by the clipped scale."""
        scale = self.log_logit_scale.exp()
        clipped_scale = torch.clip(scale, max=self.max_logit_scale)
        return clipped_scale * x

    def extra_repr(self):
        """Return the constructor settings for the module repr."""
        return f"logit_scale_init={self.logit_scale_init},learnable={self.learnable}, max_logit_scale={self.max_logit_scale}"
49
+
50
+
51
class EinOpsRearrange(nn.Module):
    """Module wrapper that applies `einops.rearrange` with a fixed expression."""

    def __init__(self, rearrange_expr: str, **kwargs) -> None:
        super().__init__()
        # Extra keyword arguments are forwarded to `einops.rearrange` on every call.
        self.kwargs = kwargs
        self.rearrange_expr = rearrange_expr

    def forward(self, x):
        """Rearrange tensor `x` according to the stored expression."""
        assert isinstance(x, torch.Tensor)
        pattern, axis_sizes = self.rearrange_expr, self.kwargs
        return einops.rearrange(x, pattern, **axis_sizes)
60
+
61
+
62
class VerboseNNModule(nn.Module):
    """
    nn.Module base class whose repr also lists directly registered parameters and all buffers.
    """

    @staticmethod
    def get_readable_tensor_repr(name: str, tensor: torch.Tensor) -> str:
        """Format one entry line; `tensor` is indexed with `[1]`, i.e. a `(name, tensor)` pair is passed."""
        value = tensor[1]
        return f"({name}): tensor({tuple(value.shape)}, requires_grad={value.requires_grad})\n"

    def extra_repr(self) -> str:
        """Return one line per parameter not owned by a named submodule, followed by one line per buffer."""
        submodule_names = {module_name for module_name, _ in self.named_modules()}

        lines = []
        for entry in self.named_parameters():
            top_level = entry[0].split(".")[0]
            # Parameters whose first name component is a submodule are already shown by that submodule.
            if top_level in submodule_names:
                continue
            lines.append(self.get_readable_tensor_repr(top_level, entry))

        for entry in self.named_buffers():
            lines.append(self.get_readable_tensor_repr(entry[0].split(".")[0], entry))

        return "".join(lines)
98
+
99
+
100
def cast_if_src_dtype(
    tensor: torch.Tensor, src_dtype: torch.dtype, tgt_dtype: torch.dtype
):
    """Cast `tensor` to `tgt_dtype` only when its dtype is `src_dtype`.

    Returns:
        `(tensor, updated)`, where `updated` tells whether a cast took place. When it did not,
        the input tensor object itself is returned.
    """
    if tensor.dtype == src_dtype:
        return tensor.to(dtype=tgt_dtype), True
    return tensor, False
108
+
109
+
110
class QuickGELU(nn.Module):
    """Sigmoid-gated GELU approximation: `x * sigmoid(1.702 * x)`."""

    # From https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L166
    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
114
+
115
+
116
class SelectElement(nn.Module):
    """Pick a fixed index along dimension 1 of an input with at least three dimensions."""

    def __init__(self, index) -> None:
        super().__init__()
        # Index applied to dimension 1 in `forward`.
        self.index = index

    def forward(self, x):
        """Return `x[:, index, ...]`."""
        assert x.ndim >= 3
        position = self.index
        return x[:, position, ...]
124
+
125
+
126
class SelectEOSAndProject(nn.Module):
    """
    Text pooling as used in OpenCLIP: take one token per sequence, then project it.
    """

    def __init__(self, proj: nn.Module) -> None:
        super().__init__()
        # Projection applied to the selected token features.
        self.proj = proj

    def forward(self, x, seq_len):
        """Select `x[b, seq_len[b]]` for every batch row `b` and pass the result through `proj`."""
        assert x.ndim == 3
        # x is of shape B x L x D
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        batch_rows = torch.arange(x.shape[0])
        eos_features = x[batch_rows, seq_len]
        return self.proj(eos_features)
imagebind_model.py ADDED
@@ -0,0 +1,521 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Portions Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+
9
+ import os
10
+ import urllib
11
+ from functools import partial
12
+ from types import SimpleNamespace
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+
17
+ from .helpers import (
18
+ EinOpsRearrange,
19
+ LearnableLogitScaling,
20
+ Normalize,
21
+ SelectElement,
22
+ SelectEOSAndProject,
23
+ )
24
+ from .multimodal_preprocessors import (
25
+ AudioPreprocessor,
26
+ IMUPreprocessor,
27
+ PadIm2Video,
28
+ PatchEmbedGeneric,
29
+ RGBDTPreprocessor,
30
+ SpatioTemporalPosEmbeddingHelper,
31
+ TextPreprocessor,
32
+ ThermalPreprocessor,
33
+ )
34
+
35
+ from .transformer import MultiheadAttention, SimpleTransformer
36
+
37
+
38
# String keys naming each modality; used below as the keys of the per-modality
# preprocessor / trunk / head / postprocessor dictionaries.
ModalityType = SimpleNamespace(
    VISION="vision",
    TEXT="text",
    AUDIO="audio",
    THERMAL="thermal",
    DEPTH="depth",
    IMU="imu",
)
46
+
47
+
48
class ImageBindModel(nn.Module):
    """Multi-modality encoder.

    For each modality in `ModalityType` the model owns four stages, stored in `nn.ModuleDict`s
    keyed by modality name: a preprocessor, a transformer trunk, a projection head and a
    postprocessor. `forward` runs the stages in that order for every modality present in the input.
    """

    def __init__(
        self,
        video_frames=2,
        kernel_size=(2, 14, 14),
        audio_kernel_size=16,
        audio_stride=10,
        out_embed_dim=768,
        vision_embed_dim=1024,
        vision_num_blocks=24,
        vision_num_heads=16,
        audio_embed_dim=768,
        audio_num_blocks=12,
        audio_num_heads=12,
        audio_num_mel_bins=128,
        audio_target_len=204,
        audio_drop_path=0.1,
        text_embed_dim=768,
        text_num_blocks=12,
        text_num_heads=12,
        depth_embed_dim=384,
        depth_kernel_size=16,
        depth_num_blocks=12,
        depth_num_heads=8,
        depth_drop_path=0.0,
        thermal_embed_dim=768,
        thermal_kernel_size=16,
        thermal_num_blocks=12,
        thermal_num_heads=12,
        thermal_drop_path=0.0,
        imu_embed_dim=512,
        imu_kernel_size=8,
        imu_num_blocks=6,
        imu_num_heads=8,
        imu_drop_path=0.7,
    ):
        """Build the four per-modality stages from the given hyper-parameters.

        NOTE(review): `imu_kernel_size` is accepted but not forwarded to any builder; the IMU
        preprocessor below uses a fixed kernel size of 8.
        """
        super().__init__()

        # Arguments are passed positionally; their order must match the builder signatures.
        self.modality_preprocessors = self._create_modality_preprocessors(
            video_frames,
            vision_embed_dim,
            kernel_size,
            text_embed_dim,
            audio_embed_dim,
            audio_kernel_size,
            audio_stride,
            audio_num_mel_bins,
            audio_target_len,
            depth_embed_dim,
            depth_kernel_size,
            thermal_embed_dim,
            thermal_kernel_size,
            imu_embed_dim,
        )

        self.modality_trunks = self._create_modality_trunks(
            vision_embed_dim,
            vision_num_blocks,
            vision_num_heads,
            text_embed_dim,
            text_num_blocks,
            text_num_heads,
            audio_embed_dim,
            audio_num_blocks,
            audio_num_heads,
            audio_drop_path,
            depth_embed_dim,
            depth_num_blocks,
            depth_num_heads,
            depth_drop_path,
            thermal_embed_dim,
            thermal_num_blocks,
            thermal_num_heads,
            thermal_drop_path,
            imu_embed_dim,
            imu_num_blocks,
            imu_num_heads,
            imu_drop_path,
        )

        self.modality_heads = self._create_modality_heads(
            out_embed_dim,
            vision_embed_dim,
            text_embed_dim,
            audio_embed_dim,
            depth_embed_dim,
            thermal_embed_dim,
            imu_embed_dim,
        )

        self.modality_postprocessors = self._create_modality_postprocessors(
            out_embed_dim
        )

    def _create_modality_preprocessors(
        self,
        video_frames=2,
        vision_embed_dim=1024,
        kernel_size=(2, 14, 14),
        text_embed_dim=768,
        audio_embed_dim=768,
        audio_kernel_size=16,
        audio_stride=10,
        audio_num_mel_bins=128,
        audio_target_len=204,
        depth_embed_dim=768,
        depth_kernel_size=16,
        thermal_embed_dim=768,
        thermal_kernel_size=16,
        imu_embed_dim=512,
    ):
        """Return an `nn.ModuleDict` mapping each modality name to its input preprocessor."""
        # Vision: pad to a 2-frame clip, then patchify with a strided 3-D convolution.
        rgbt_stem = PatchEmbedGeneric(
            proj_stem=[
                PadIm2Video(pad_type="repeat", ntimes=2),
                nn.Conv3d(
                    in_channels=3,
                    kernel_size=kernel_size,
                    out_channels=vision_embed_dim,
                    stride=kernel_size,
                    bias=False,
                ),
            ]
        )
        rgbt_preprocessor = RGBDTPreprocessor(
            img_size=[3, video_frames, 224, 224],
            num_cls_tokens=1,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            rgbt_stem=rgbt_stem,
            depth_stem=None,
        )

        text_preprocessor = TextPreprocessor(
            context_length=77,
            vocab_size=49408,
            embed_dim=text_embed_dim,
            causal_masking=True,
        )

        # Audio: single-channel 2-D convolution whose stride differs from its kernel size.
        audio_stem = PatchEmbedGeneric(
            proj_stem=[
                nn.Conv2d(
                    in_channels=1,
                    kernel_size=audio_kernel_size,
                    stride=audio_stride,
                    out_channels=audio_embed_dim,
                    bias=False,
                ),
            ],
            norm_layer=nn.LayerNorm(normalized_shape=audio_embed_dim),
        )
        audio_preprocessor = AudioPreprocessor(
            img_size=[1, audio_num_mel_bins, audio_target_len],
            num_cls_tokens=1,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            audio_stem=audio_stem,
        )

        depth_stem = PatchEmbedGeneric(
            [
                nn.Conv2d(
                    kernel_size=depth_kernel_size,
                    in_channels=1,
                    out_channels=depth_embed_dim,
                    stride=depth_kernel_size,
                    bias=False,
                ),
            ],
            norm_layer=nn.LayerNorm(normalized_shape=depth_embed_dim),
        )

        depth_preprocessor = RGBDTPreprocessor(
            img_size=[1, 224, 224],
            num_cls_tokens=1,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            rgbt_stem=None,
            depth_stem=depth_stem,
        )

        thermal_stem = PatchEmbedGeneric(
            [
                nn.Conv2d(
                    kernel_size=thermal_kernel_size,
                    in_channels=1,
                    out_channels=thermal_embed_dim,
                    stride=thermal_kernel_size,
                    bias=False,
                ),
            ],
            norm_layer=nn.LayerNorm(normalized_shape=thermal_embed_dim),
        )
        thermal_preprocessor = ThermalPreprocessor(
            img_size=[1, 224, 224],
            num_cls_tokens=1,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            thermal_stem=thermal_stem,
        )

        # IMU: linear patch embedding. The input width (48), image size and kernel size are
        # fixed constants here rather than constructor parameters.
        imu_stem = PatchEmbedGeneric(
            [
                nn.Linear(
                    in_features=48,
                    out_features=imu_embed_dim,
                    bias=False,
                ),
            ],
            norm_layer=nn.LayerNorm(normalized_shape=imu_embed_dim),
        )

        imu_preprocessor = IMUPreprocessor(
            img_size=[6, 2000],
            num_cls_tokens=1,
            kernel_size=8,
            embed_dim=imu_embed_dim,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            imu_stem=imu_stem,
        )

        modality_preprocessors = {
            ModalityType.VISION: rgbt_preprocessor,
            ModalityType.TEXT: text_preprocessor,
            ModalityType.AUDIO: audio_preprocessor,
            ModalityType.DEPTH: depth_preprocessor,
            ModalityType.THERMAL: thermal_preprocessor,
            ModalityType.IMU: imu_preprocessor,
        }

        return nn.ModuleDict(modality_preprocessors)

    def _create_modality_trunks(
        self,
        vision_embed_dim=1024,
        vision_num_blocks=24,
        vision_num_heads=16,
        text_embed_dim=768,
        text_num_blocks=12,
        text_num_heads=12,
        audio_embed_dim=768,
        audio_num_blocks=12,
        audio_num_heads=12,
        audio_drop_path=0.0,
        depth_embed_dim=768,
        depth_num_blocks=12,
        depth_num_heads=12,
        depth_drop_path=0.0,
        thermal_embed_dim=768,
        thermal_num_blocks=12,
        thermal_num_heads=12,
        thermal_drop_path=0.0,
        imu_embed_dim=512,
        imu_num_blocks=6,
        imu_num_heads=8,
        imu_drop_path=0.7,
    ):
        """Return an `nn.ModuleDict` mapping each modality name to its `SimpleTransformer` trunk."""

        def instantiate_trunk(
            embed_dim, num_blocks, num_heads, pre_transformer_ln, add_bias_kv, drop_path
        ):
            """Build one trunk; tokens are rearranged to `l b d` before the blocks and back to `b l d` after."""
            return SimpleTransformer(
                embed_dim=embed_dim,
                num_blocks=num_blocks,
                ffn_dropout_rate=0.0,
                drop_path_rate=drop_path,
                attn_target=partial(
                    MultiheadAttention,
                    embed_dim=embed_dim,
                    num_heads=num_heads,
                    bias=True,
                    add_bias_kv=add_bias_kv,
                ),
                pre_transformer_layer=nn.Sequential(
                    nn.LayerNorm(embed_dim, eps=1e-6)
                    if pre_transformer_ln
                    else nn.Identity(),
                    EinOpsRearrange("b l d -> l b d"),
                ),
                post_transformer_layer=EinOpsRearrange("l b d -> b l d"),
            )

        modality_trunks = {}
        # Only the vision trunk applies a LayerNorm before the transformer blocks.
        modality_trunks[ModalityType.VISION] = instantiate_trunk(
            vision_embed_dim,
            vision_num_blocks,
            vision_num_heads,
            pre_transformer_ln=True,
            add_bias_kv=False,
            drop_path=0.0,
        )
        modality_trunks[ModalityType.TEXT] = instantiate_trunk(
            text_embed_dim,
            text_num_blocks,
            text_num_heads,
            pre_transformer_ln=False,
            add_bias_kv=False,
            drop_path=0.0,
        )
        modality_trunks[ModalityType.AUDIO] = instantiate_trunk(
            audio_embed_dim,
            audio_num_blocks,
            audio_num_heads,
            pre_transformer_ln=False,
            add_bias_kv=True,
            drop_path=audio_drop_path,
        )
        modality_trunks[ModalityType.DEPTH] = instantiate_trunk(
            depth_embed_dim,
            depth_num_blocks,
            depth_num_heads,
            pre_transformer_ln=False,
            add_bias_kv=True,
            drop_path=depth_drop_path,
        )
        modality_trunks[ModalityType.THERMAL] = instantiate_trunk(
            thermal_embed_dim,
            thermal_num_blocks,
            thermal_num_heads,
            pre_transformer_ln=False,
            add_bias_kv=True,
            drop_path=thermal_drop_path,
        )
        modality_trunks[ModalityType.IMU] = instantiate_trunk(
            imu_embed_dim,
            imu_num_blocks,
            imu_num_heads,
            pre_transformer_ln=False,
            add_bias_kv=True,
            drop_path=imu_drop_path,
        )

        return nn.ModuleDict(modality_trunks)

    def _create_modality_heads(
        self,
        out_embed_dim,
        vision_embed_dim,
        text_embed_dim,
        audio_embed_dim,
        depth_embed_dim,
        thermal_embed_dim,
        imu_embed_dim,
    ):
        """Return an `nn.ModuleDict` of heads projecting each trunk output to `out_embed_dim`."""
        modality_heads = {}

        # Non-text heads: LayerNorm, take the token at index 0, then a bias-free linear projection.
        modality_heads[ModalityType.VISION] = nn.Sequential(
            nn.LayerNorm(normalized_shape=vision_embed_dim, eps=1e-6),
            SelectElement(index=0),
            nn.Linear(vision_embed_dim, out_embed_dim, bias=False),
        )

        # Text head: select one token per sequence by index, then normalise and project.
        modality_heads[ModalityType.TEXT] = SelectEOSAndProject(
            proj=nn.Sequential(
                nn.LayerNorm(normalized_shape=text_embed_dim, eps=1e-6),
                nn.Linear(text_embed_dim, out_embed_dim, bias=False),
            )
        )

        modality_heads[ModalityType.AUDIO] = nn.Sequential(
            nn.LayerNorm(normalized_shape=audio_embed_dim, eps=1e-6),
            SelectElement(index=0),
            nn.Linear(audio_embed_dim, out_embed_dim, bias=False),
        )

        modality_heads[ModalityType.DEPTH] = nn.Sequential(
            nn.LayerNorm(normalized_shape=depth_embed_dim, eps=1e-6),
            SelectElement(index=0),
            nn.Linear(depth_embed_dim, out_embed_dim, bias=False),
        )

        modality_heads[ModalityType.THERMAL] = nn.Sequential(
            nn.LayerNorm(normalized_shape=thermal_embed_dim, eps=1e-6),
            SelectElement(index=0),
            nn.Linear(thermal_embed_dim, out_embed_dim, bias=False),
        )

        # The IMU head additionally applies dropout before the projection.
        modality_heads[ModalityType.IMU] = nn.Sequential(
            nn.LayerNorm(normalized_shape=imu_embed_dim, eps=1e-6),
            SelectElement(index=0),
            nn.Dropout(p=0.5),
            nn.Linear(imu_embed_dim, out_embed_dim, bias=False),
        )

        return nn.ModuleDict(modality_heads)

    def _create_modality_postprocessors(self, out_embed_dim):
        """Return an `nn.ModuleDict` of postprocessors (`out_embed_dim` is accepted but not used)."""
        modality_postprocessors = {}

        # Every modality is L2-normalised; all but vision are then scaled by a logit scale.
        modality_postprocessors[ModalityType.VISION] = Normalize(dim=-1)
        modality_postprocessors[ModalityType.TEXT] = nn.Sequential(
            Normalize(dim=-1), LearnableLogitScaling(learnable=True)
        )
        modality_postprocessors[ModalityType.AUDIO] = nn.Sequential(
            Normalize(dim=-1),
            LearnableLogitScaling(logit_scale_init=20.0, learnable=False),
        )
        modality_postprocessors[ModalityType.DEPTH] = nn.Sequential(
            Normalize(dim=-1),
            LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
        )
        modality_postprocessors[ModalityType.THERMAL] = nn.Sequential(
            Normalize(dim=-1),
            LearnableLogitScaling(logit_scale_init=10.0, learnable=False),
        )
        modality_postprocessors[ModalityType.IMU] = nn.Sequential(
            Normalize(dim=-1),
            LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
        )
        return nn.ModuleDict(modality_postprocessors)

    def forward(self, inputs):
        """Encode every modality present in `inputs`.

        Args:
            inputs: mapping from modality name to an input tensor. Entries whose value is `None`
                are skipped.

        Returns:
            dict mapping each processed modality name to its embedding.
        """
        outputs = {}
        for modality_key, modality_value in inputs.items():
            # Skip absent modalities up front: the original read `.ndim` before its own `None`
            # check, so a `None` entry raised AttributeError instead of being ignored.
            if modality_value is None:
                continue

            reduce_list = (
                modality_value.ndim >= 5
            )  # Audio and Video inputs consist of multiple clips
            if reduce_list:
                # Fold the clip dimension into the batch dimension.
                B, S = modality_value.shape[:2]
                modality_value = modality_value.reshape(
                    B * S, *modality_value.shape[2:]
                )

            modality_value = self.modality_preprocessors[modality_key](
                **{modality_key: modality_value}
            )
            trunk_inputs = modality_value["trunk"]
            head_inputs = modality_value["head"]
            modality_value = self.modality_trunks[modality_key](**trunk_inputs)
            modality_value = self.modality_heads[modality_key](
                modality_value, **head_inputs
            )
            if modality_key in [ModalityType.AUDIO]:
                # Audio applies only the first postprocessor stage (normalisation) and
                # skips the logit scaling.
                modality_value = self.modality_postprocessors[modality_key][0](
                    modality_value
                )
            else:
                modality_value = self.modality_postprocessors[modality_key](
                    modality_value
                )

            if reduce_list:
                # Average the per-clip embeddings of each sample.
                modality_value = modality_value.reshape(B, S, -1)
                modality_value = modality_value.mean(dim=1)

            outputs[modality_key] = modality_value

        return outputs
492
+
493
+
494
def imagebind_huge(pretrained=False, store_path=r'.checkpoints'):
    """Build the "huge" ImageBind configuration, optionally loading pretrained weights.

    Args:
        pretrained: when true, load `imagebind_huge.pth` from `store_path`, downloading it first
            if the file does not exist.
        store_path: directory holding (or receiving) the checkpoint.

    Returns:
        Tuple `(model, 1024)`; 1024 is the `out_embed_dim` the model is built with.
    """
    model = ImageBindModel(
        vision_embed_dim=1280,
        vision_num_blocks=32,
        vision_num_heads=16,
        text_embed_dim=1024,
        text_num_blocks=24,
        text_num_heads=16,
        out_embed_dim=1024,
        audio_drop_path=0.1,
        imu_drop_path=0.7,
    )

    if not pretrained:
        return model, 1024

    checkpoint_path = "{}/imagebind_huge.pth".format(store_path)
    checkpoint_missing = not os.path.exists(checkpoint_path)
    if checkpoint_missing:
        print(
            "Downloading imagebind weights to {}/imagebind_huge.pth ...".format(store_path)
        )
        os.makedirs(store_path, exist_ok=True)
        torch.hub.download_url_to_file(
            "https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth",
            checkpoint_path,
            progress=True,
        )
    print("Loading imagebind weights from {}/imagebind_huge.pth ...".format(store_path))
    state_dict = torch.load(checkpoint_path)
    model.load_state_dict(state_dict)

    return model, 1024
multimodal_preprocessors.py ADDED
@@ -0,0 +1,687 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Portions Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ import gzip
9
+ import html
10
+ import io
11
+ import math
12
+ from functools import lru_cache
13
+ from typing import Callable, List, Optional
14
+
15
+ import ftfy
16
+
17
+ import numpy as np
18
+ import regex as re
19
+ import torch
20
+ import torch.nn as nn
21
+ from iopath.common.file_io import g_pathmgr
22
+ from timm.models.layers import trunc_normal_
23
+
24
+ from .helpers import cast_if_src_dtype, VerboseNNModule
25
+
26
+
27
def get_sinusoid_encoding_table(n_position, d_hid):
    """Build a fixed sinusoidal position-encoding table.

    Column ``j`` of row ``pos`` holds ``pos / 10000 ** (2 * (j // 2) / d_hid)``
    passed through ``sin`` for even ``j`` and ``cos`` for odd ``j``.

    Returns:
        A float32 tensor of shape ``(1, n_position, d_hid)``.
    """
    # TODO: make it with torch instead of numpy
    # The divisor only depends on the column, so compute it once per column.
    divisors = [
        np.power(10000, 2 * (col // 2) / d_hid) for col in range(d_hid)
    ]
    table = np.array(
        [[pos / divisor for divisor in divisors] for pos in range(n_position)]
    )
    table[:, 0::2] = np.sin(table[:, 0::2])  # dim 2i
    table[:, 1::2] = np.cos(table[:, 1::2])  # dim 2i+1

    return torch.FloatTensor(table).unsqueeze(0)
44
+
45
+
46
def interpolate_pos_encoding_2d(target_spatial_size, pos_embed):
    """Bicubically resize a square grid of position embeddings.

    Args:
        target_spatial_size: Number of spatial tokens wanted after resizing.
        pos_embed: Tensor whose dim 1 indexes tokens and whose last dim is the
            embedding size; it is reshaped to a ``sqrt(N) x sqrt(N)`` grid.

    Returns:
        ``pos_embed`` itself when it already has ``target_spatial_size``
        tokens, otherwise the interpolated embeddings viewed as ``(1, -1, dim)``.
    """
    num_tokens = pos_embed.shape[1]
    if num_tokens == target_spatial_size:
        return pos_embed

    dim = pos_embed.shape[-1]
    side = int(math.sqrt(num_tokens))
    # nn.functional.interpolate doesn't work with bfloat16 so we cast to float32
    pos_embed, updated = cast_if_src_dtype(pos_embed, torch.bfloat16, torch.float32)
    # Lay the tokens out as a (1, dim, side, side) image for interpolation.
    grid = pos_embed.reshape(1, side, side, dim).permute(0, 3, 1, 2)
    grid = nn.functional.interpolate(
        grid,
        scale_factor=math.sqrt(target_spatial_size / num_tokens),
        mode="bicubic",
    )
    if updated:
        grid, _ = cast_if_src_dtype(grid, torch.float32, torch.bfloat16)
    return grid.permute(0, 2, 3, 1).view(1, -1, dim)
64
+
65
+
66
def interpolate_pos_encoding(
    npatch_per_img,
    pos_embed,
    patches_layout,
    input_shape=None,
    first_patch_idx=1,
):
    """Resize patch position embeddings to ``npatch_per_img`` tokens.

    The first ``first_patch_idx`` entries (0 or 1 CLS position) are kept as-is
    and re-attached in front of the interpolated patch embeddings. When the
    layout has a temporal axis (``patches_layout[0] > 1``) only the embedding
    of the zeroth frame is interpolated.

    Raises:
        AssertionError: on an unsupported ``first_patch_idx``, a non-square
            layout, or a temporal layout whose ``input_shape`` is not 4-long.
        ValueError: when ``patches_layout[0]`` is neither 1 nor greater than 1.
    """
    assert first_patch_idx in (0, 1), "there is 1 CLS token or none"
    # since it's 1 if cls_token exists
    num_patch_positions = pos_embed.shape[1] - first_patch_idx
    if npatch_per_img == num_patch_positions:
        return pos_embed

    assert (
        patches_layout[-1] == patches_layout[-2]
    ), "Interpolation of pos embed not supported for non-square layouts"

    class_emb, patch_emb = (
        pos_embed[:, :first_patch_idx],
        pos_embed[:, first_patch_idx:],
    )

    if input_shape is None or patches_layout[0] == 1:
        # simple 2D pos embedding, no temporal component
        source = patch_emb
    elif patches_layout[0] > 1:
        # pos embed has a temporal component
        assert len(input_shape) == 4, "temporal interpolation not supported"
        # we only support 2D interpolation in this case
        per_frame = patch_emb.view(
            1, patches_layout[0], patches_layout[1] * patches_layout[2], -1
        )
        # interpolate embedding for zeroth frame
        source = per_frame[0, 0, ...].unsqueeze(0)
    else:
        raise ValueError("This type of interpolation isn't implemented")

    patch_emb = interpolate_pos_encoding_2d(npatch_per_img, source)
    return torch.cat((class_emb, patch_emb), dim=1)
103
+
104
+
105
def _get_pos_embedding(
    npatch_per_img,
    pos_embed,
    patches_layout,
    input_shape,
    first_patch_idx=1,
):
    """Return ``pos_embed`` resized to ``npatch_per_img`` patch tokens.

    Thin wrapper that forwards every argument to ``interpolate_pos_encoding``.
    """
    return interpolate_pos_encoding(
        npatch_per_img,
        pos_embed,
        patches_layout,
        input_shape=input_shape,
        first_patch_idx=first_patch_idx,
    )
120
+
121
+
122
class PatchEmbedGeneric(nn.Module):
    """
    PatchEmbed from Hydra

    Wraps a projection stem (one module, or several chained with
    ``nn.Sequential``) and an optional normalization layer, and flattens the
    stem output into a ``B x L x C`` token sequence.
    """

    def __init__(self, proj_stem, norm_layer: Optional[nn.Module] = None):
        super().__init__()

        if len(proj_stem) > 1:
            self.proj = nn.Sequential(*proj_stem)
        else:
            # Special case to be able to load pre-trained models that were
            # trained with a standard stem
            self.proj = proj_stem[0]
        self.norm_layer = norm_layer

    def get_patch_layout(self, img_size):
        """Probe the stem with a zero input to discover its output layout.

        Args:
            img_size: Per-sample input shape (without the batch dimension).
                Any sequence is accepted, e.g. a list or a tuple.

        Returns:
            ``(patches_layout, num_patches, embed_dim)`` where
            ``patches_layout`` is the stem output shape after the channel
            dimension and ``num_patches`` is the product of its entries.
        """
        with torch.no_grad():
            # Unpack instead of ``[1] + img_size`` so tuples work as well as
            # lists (list + tuple raises TypeError).
            dummy_img = torch.zeros([1, *img_size])
            dummy_out = self.proj(dummy_img)
        embed_dim = dummy_out.shape[1]
        patches_layout = tuple(dummy_out.shape[2:])
        num_patches = np.prod(patches_layout)
        return patches_layout, num_patches, embed_dim

    def forward(self, x):
        x = self.proj(x)
        # B C (T) H W -> B (T)HW C
        x = x.flatten(2).transpose(1, 2)
        if self.norm_layer is not None:
            x = self.norm_layer(x)
        return x
159
+
160
+
161
class SpatioTemporalPosEmbeddingHelper(VerboseNNModule):
    """Own a position embedding and resize it to the incoming token count.

    The embedding covers ``num_cls_tokens + num_patches`` positions. It is
    either a learnable parameter (truncated-normal init, std 0.02) or a
    sinusoidal table registered as a buffer.
    """

    def __init__(
        self,
        patches_layout: List,
        num_patches: int,
        num_cls_tokens: int,
        embed_dim: int,
        learnable: bool,
    ) -> None:
        super().__init__()
        self.num_cls_tokens = num_cls_tokens
        self.patches_layout = patches_layout
        self.num_patches = num_patches
        self.num_tokens = num_cls_tokens + num_patches
        self.learnable = learnable

        if not self.learnable:
            # Fixed sinusoidal table, stored as a buffer.
            self.register_buffer(
                "pos_embed", get_sinusoid_encoding_table(self.num_tokens, embed_dim)
            )
            return

        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim))
        trunc_normal_(self.pos_embed, std=0.02)

    def get_pos_embedding(self, vision_input, all_vision_tokens):
        """Return the position embedding matching ``all_vision_tokens``.

        The patch count is taken from dim 1 of ``all_vision_tokens`` minus the
        CLS tokens; the shape of ``vision_input`` is passed on as
        ``input_shape``.
        """
        num_patch_tokens = all_vision_tokens.size(1) - self.num_cls_tokens
        return _get_pos_embedding(
            num_patch_tokens,
            pos_embed=self.pos_embed,
            patches_layout=self.patches_layout,
            input_shape=vision_input.shape,
            first_patch_idx=self.num_cls_tokens,
        )
194
+
195
+
196
class RGBDTPreprocessor(VerboseNNModule):
    """Turn RGB(T) and/or depth inputs into a token sequence for the trunk.

    Each input is patchified by its stem, optionally prefixed with CLS tokens,
    and optionally offset by a position embedding and a type embedding. When
    both inputs are given their token sequences are summed.
    """

    def __init__(
        self,
        rgbt_stem: PatchEmbedGeneric,
        depth_stem: PatchEmbedGeneric,
        img_size: List = (3, 224, 224),
        num_cls_tokens: int = 1,
        pos_embed_fn: Optional[Callable] = None,
        use_type_embed: bool = False,
        init_param_style: str = "openclip",
    ) -> None:
        super().__init__()
        stem = rgbt_stem if rgbt_stem is not None else depth_stem
        # Pass a list: the stem builds its probe shape by list concatenation,
        # which would reject the tuple default of ``img_size``.
        (
            self.patches_layout,
            self.num_patches,
            self.embed_dim,
        ) = stem.get_patch_layout(list(img_size))
        self.rgbt_stem = rgbt_stem
        self.depth_stem = depth_stem
        self.use_pos_embed = pos_embed_fn is not None
        self.use_type_embed = use_type_embed
        self.num_cls_tokens = num_cls_tokens

        if self.use_pos_embed:
            self.pos_embedding_helper = pos_embed_fn(
                patches_layout=self.patches_layout,
                num_cls_tokens=num_cls_tokens,
                num_patches=self.num_patches,
                embed_dim=self.embed_dim,
            )
        if self.num_cls_tokens > 0:
            self.cls_token = nn.Parameter(
                torch.zeros(1, self.num_cls_tokens, self.embed_dim)
            )
        if self.use_type_embed:
            self.type_embed = nn.Parameter(torch.zeros(1, 1, self.embed_dim))

        self.init_parameters(init_param_style)

    @torch.no_grad()
    def init_parameters(self, init_param_style):
        """Initialize embeddings in ``"openclip"`` or ``"vit"`` style.

        Raises:
            ValueError: for any other ``init_param_style``.
        """
        if init_param_style == "openclip":
            # OpenCLIP style initialization
            scale = self.embed_dim**-0.5
            if self.use_pos_embed:
                nn.init.normal_(self.pos_embedding_helper.pos_embed)
                self.pos_embedding_helper.pos_embed *= scale

            if self.num_cls_tokens > 0:
                nn.init.normal_(self.cls_token)
                self.cls_token *= scale
        elif init_param_style == "vit":
            # cls_token only exists when CLS tokens were requested.
            if self.num_cls_tokens > 0:
                self.cls_token.data.fill_(0)
        else:
            raise ValueError(f"Unknown init {init_param_style}")

        if self.use_type_embed:
            nn.init.normal_(self.type_embed)

    def tokenize_input_and_cls_pos(self, input, stem, mask):
        """Patchify ``input`` with ``stem`` and add CLS/position/type terms.

        ``mask`` is accepted for interface compatibility but is not used here.
        """
        # tokens is of shape B x L x D
        tokens = stem(input)
        assert tokens.ndim == 3
        assert tokens.shape[2] == self.embed_dim
        B = tokens.shape[0]
        if self.num_cls_tokens > 0:
            class_tokens = self.cls_token.expand(
                B, -1, -1
            )  # stole class_tokens impl from Phil Wang, thanks
            tokens = torch.cat((class_tokens, tokens), dim=1)
        if self.use_pos_embed:
            pos_embed = self.pos_embedding_helper.get_pos_embedding(input, tokens)
            tokens = tokens + pos_embed
        if self.use_type_embed:
            tokens = tokens + self.type_embed.expand(B, -1, -1)
        return tokens

    def forward(self, vision=None, depth=None, patch_mask=None):
        """Tokenize the given inputs.

        Returns:
            ``{"trunk": {"tokens": ...}, "head": {}}``.

        Raises:
            NotImplementedError: when ``patch_mask`` is given.
            ValueError: when neither ``vision`` nor ``depth`` is given.
        """
        if patch_mask is not None:
            raise NotImplementedError()
        if vision is None and depth is None:
            # Fail with a clear message instead of an unbound-variable error.
            raise ValueError("At least one of `vision` or `depth` must be provided")

        if vision is not None:
            vision_tokens = self.tokenize_input_and_cls_pos(
                vision, self.rgbt_stem, patch_mask
            )

        if depth is not None:
            depth_tokens = self.tokenize_input_and_cls_pos(
                depth, self.depth_stem, patch_mask
            )

        # aggregate tokens
        if vision is not None and depth is not None:
            final_tokens = vision_tokens + depth_tokens
        else:
            final_tokens = vision_tokens if vision is not None else depth_tokens
        return_dict = {
            "trunk": {
                "tokens": final_tokens,
            },
            "head": {},
        }
        return return_dict
300
+
301
+
302
class AudioPreprocessor(RGBDTPreprocessor):
    """``RGBDTPreprocessor`` variant whose single input is called ``audio``.

    The audio stem takes the place of the RGBT stem; no depth stem is used.
    """

    def __init__(self, audio_stem: PatchEmbedGeneric, **kwargs) -> None:
        super().__init__(rgbt_stem=audio_stem, depth_stem=None, **kwargs)

    def forward(self, audio=None):
        # Route the audio input through the parent's ``vision`` path.
        tokens = super().forward(vision=audio)
        return tokens
308
+
309
+
310
class ThermalPreprocessor(RGBDTPreprocessor):
    """``RGBDTPreprocessor`` variant whose single input is called ``thermal``.

    The thermal stem takes the place of the RGBT stem; no depth stem is used.
    """

    def __init__(self, thermal_stem: PatchEmbedGeneric, **kwargs) -> None:
        super().__init__(rgbt_stem=thermal_stem, depth_stem=None, **kwargs)

    def forward(self, thermal=None):
        # Route the thermal input through the parent's ``vision`` path.
        tokens = super().forward(vision=thermal)
        return tokens
316
+
317
+
318
def build_causal_attention_mask(context_length):
    """Return an additive causal mask of shape ``(context_length, context_length)``.

    Entries strictly above the diagonal are ``-inf``; the diagonal and
    everything below it are 0.
    """
    # pytorch uses additive attention mask; start from all -inf
    mask = torch.full((context_length, context_length), float("-inf"))
    # keep -inf only strictly above the diagonal, zero the rest
    return mask.triu_(1)
325
+
326
+
327
class TextPreprocessor(VerboseNNModule):
    """Embed token ids and add position embeddings for the text trunk.

    Optionally prepends CLS tokens, attaches a causal attention mask for the
    trunk and supplies a per-sample sequence index to the head.
    """

    def __init__(
        self,
        vocab_size: int,
        context_length: int,
        embed_dim: int,
        causal_masking: bool,
        supply_seq_len_to_head: bool = True,
        num_cls_tokens: int = 0,
        init_param_style: str = "openclip",
    ) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        self.context_length = context_length
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Parameter(
            torch.empty(1, self.context_length + num_cls_tokens, embed_dim)
        )
        self.causal_masking = causal_masking
        if self.causal_masking:
            mask = build_causal_attention_mask(self.context_length)
            # register the mask as a buffer so it can be moved to the right device
            self.register_buffer("mask", mask)

        self.supply_seq_len_to_head = supply_seq_len_to_head
        self.num_cls_tokens = num_cls_tokens
        self.embed_dim = embed_dim
        if num_cls_tokens > 0:
            assert self.causal_masking is False, "Masking + CLS token isn't implemented"
            self.cls_token = nn.Parameter(
                torch.zeros(1, self.num_cls_tokens, embed_dim)
            )

        self.init_parameters(init_param_style)

    @torch.no_grad()
    def init_parameters(self, init_param_style="openclip"):
        """Initialize the embeddings in ``"openclip"`` or ``"vit"`` style.

        Raises:
            ValueError: for any other ``init_param_style``.
        """
        # OpenCLIP style initialization
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.pos_embed, std=0.01)

        if init_param_style == "openclip":
            # OpenCLIP style initialization
            scale = self.embed_dim**-0.5
            if self.num_cls_tokens > 0:
                nn.init.normal_(self.cls_token)
                self.cls_token *= scale
        elif init_param_style == "vit":
            # cls_token only exists when CLS tokens were requested; the
            # default num_cls_tokens=0 must not crash here.
            if self.num_cls_tokens > 0:
                self.cls_token.data.fill_(0)
        else:
            raise ValueError(f"Unknown init {init_param_style}")

    def forward(self, text):
        """Embed ``text`` token ids.

        Returns:
            A dict with ``"trunk"`` holding ``tokens`` (and ``attn_mask`` when
            causal masking is on) and ``"head"`` holding ``seq_len`` when
            ``supply_seq_len_to_head`` is set.
        """
        # text tokens are of shape B x L x D
        text_tokens = self.token_embedding(text)
        # concat CLS tokens if any
        if self.num_cls_tokens > 0:
            B = text_tokens.shape[0]
            class_tokens = self.cls_token.expand(
                B, -1, -1
            )  # stole class_tokens impl from Phil Wang, thanks
            text_tokens = torch.cat((class_tokens, text_tokens), dim=1)
        text_tokens = text_tokens + self.pos_embed
        return_dict = {
            "trunk": {
                "tokens": text_tokens,
            },
            "head": {},
        }
        if self.supply_seq_len_to_head:
            # Index of the largest token id in each row of the raw input,
            # taken before any CLS tokens are prepended.
            # NOTE(review): presumably this locates the end-of-text token
            # (the highest id in SimpleTokenizer's vocab) - confirm the head
            # accounts for the CLS offset when num_cls_tokens > 0.
            text_lengths = text.argmax(dim=-1)
            return_dict["head"] = {
                "seq_len": text_lengths,
            }
        if self.causal_masking:
            return_dict["trunk"].update({"attn_mask": self.mask})
        return return_dict
405
+
406
+
407
class Im2Video(nn.Module):
    """Convert an image into a trivial video."""

    def __init__(self, time_dim=2):
        super().__init__()
        self.time_dim = time_dim

    def forward(self, x):
        """Insert a length-1 time axis into 4-D input; pass 5-D input through.

        Raises:
            ValueError: when ``x`` is neither 4-D nor 5-D.
        """
        if x.ndim == 5:
            return x
        if x.ndim == 4:
            # B, C, H, W -> B, C, T, H, W
            return x.unsqueeze(self.time_dim)
        raise ValueError(f"Dimension incorrect {x.shape}")


class PadIm2Video(Im2Video):
    """Convert an image into a video of ``ntimes`` frames.

    A length-1 time axis is extended either by repeating the frame
    (``pad_type="repeat"``) or by appending zero frames (``pad_type="zero"``).
    Inputs whose time axis is already longer than 1 are returned unchanged.
    """

    def __init__(self, ntimes, pad_type, time_dim=2):
        super().__init__(time_dim=time_dim)
        assert ntimes > 0
        assert pad_type in ["zero", "repeat"]
        self.ntimes = ntimes
        self.pad_type = pad_type

    def forward(self, x):
        x = super().forward(x)
        if x.shape[self.time_dim] != 1:
            return x

        if self.pad_type == "repeat":
            repeats = [1] * x.ndim
            repeats[self.time_dim] = self.ntimes
            x = x.repeat(repeats)
        elif self.pad_type == "zero":
            # nn.functional.pad lists (left, right) pairs starting from the
            # LAST dimension, so the time axis must be counted from the end.
            # Indexing with 2 * time_dim + 1 only hits the time axis when it
            # is the middle dimension (time_dim=2 on 5-D input).
            time_dim = self.time_dim % x.ndim
            padarg = [0, 0] * x.ndim
            padarg[2 * (x.ndim - 1 - time_dim) + 1] = self.ntimes - x.shape[time_dim]
            x = nn.functional.pad(x, padarg)
        return x
444
+
445
+
446
+ # Modified from github.com/openai/CLIP
447
@lru_cache()
def bytes_to_unicode():
    """
    Return a dict mapping every byte value (0-255) to a unicode character.

    The reversible BPE codes operate on unicode strings, so each byte needs a
    character that is neither whitespace nor a control character (the BPE code
    chokes on those). Bytes in the ranges "!".."~", "¡".."¬" and "®".."ÿ" map
    to themselves; every remaining byte, in increasing order, maps to the next
    unused code point starting at 256. Insertion order is the self-mapped
    bytes first, then the remapped ones.
    """
    self_mapped = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    already_mapped = set(self_mapped)
    remapped = [byte for byte in range(2**8) if byte not in already_mapped]

    mapping = {byte: chr(byte) for byte in self_mapped}
    for offset, byte in enumerate(remapped):
        mapping[byte] = chr(2**8 + offset)
    return mapping
472
+
473
+
474
def get_pairs(word):
    """Return set of symbol pairs in a word.

    Word is represented as tuple of symbols (symbols being variable-length
    strings). A word with fewer than two symbols, including an empty word,
    has no pairs and yields an empty set.
    """
    # Pair each symbol with its successor; zip stops at the shorter input,
    # so empty and single-symbol words fall out naturally.
    return set(zip(word, word[1:]))
484
+
485
+
486
def basic_clean(text):
    """Fix the text with ftfy, HTML-unescape it twice and strip the ends."""
    repaired = ftfy.fix_text(text)
    # Two unescape passes, so doubly escaped entities are decoded as well.
    return html.unescape(html.unescape(repaired)).strip()
490
+
491
+
492
def whitespace_clean(text):
    """Collapse each whitespace run to one space and strip both ends."""
    return re.sub(r"\s+", " ", text).strip()
496
+
497
+
498
class SimpleTokenizer(object):
    """Byte-level BPE tokenizer (modified from github.com/openai/CLIP).

    Calling the tokenizer on a string or a list of strings returns zero-padded
    token-id tensors of length ``context_length``.
    """

    def __init__(self, bpe_path: str, context_length=77):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}

        with g_pathmgr.open(bpe_path, "rb") as fh:
            bpe_bytes = io.BytesIO(fh.read())
        # Close the gzip reader deterministically once the merges are read.
        with gzip.open(bpe_bytes) as gz:
            merges = gz.read().decode("utf-8").split("\n")
        # Skip the first line and keep the fixed number of merge rules.
        merges = merges[1 : 49152 - 256 - 2 + 1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v + "</w>" for v in vocab]
        for merge in merges:
            vocab.append("".join(merge))
        vocab.extend(["<|startoftext|>", "<|endoftext|>"])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {
            "<|startoftext|>": "<|startoftext|>",
            "<|endoftext|>": "<|endoftext|>",
        }
        self.pat = re.compile(
            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            re.IGNORECASE,
        )
        self.context_length = context_length

    def bpe(self, token):
        """Apply the learned merges to ``token``.

        Returns the resulting symbols joined by single spaces; the last symbol
        carries the ``</w>`` end-of-word marker. Results are memoized in
        ``self.cache``.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + "</w>",)
        pairs = get_pairs(word)

        if not pairs:
            return token + "</w>"

        while True:
            # Merge the adjacent pair with the lowest rank first.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    # tuple.index raises ValueError when `first` no longer
                    # occurs; copy the remainder and stop scanning.
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Clean and lower-case ``text`` and return its list of BPE token ids."""
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            # Map the token's UTF-8 bytes to their stand-in unicode characters.
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(
                self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
            )
        return bpe_tokens

    def decode(self, tokens):
        """Turn token ids back into text; ``</w>`` markers become spaces."""
        text = "".join([self.decoder[token] for token in tokens])
        text = (
            bytearray([self.byte_decoder[c] for c in text])
            .decode("utf-8", errors="replace")
            .replace("</w>", " ")
        )
        return text

    def __call__(self, texts, context_length=None):
        """Tokenize one string or a list of strings.

        Each sequence is wrapped in start/end-of-text ids, cut to
        ``context_length`` ids and right-padded with zeros.

        Returns:
            A ``(len(texts), context_length)`` long tensor, or a 1-D tensor of
            length ``context_length`` when there is exactly one text.
        """
        if not context_length:
            context_length = self.context_length

        if isinstance(texts, str):
            texts = [texts]

        sot_token = self.encoder["<|startoftext|>"]
        eot_token = self.encoder["<|endoftext|>"]
        all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

        for i, tokens in enumerate(all_tokens):
            # NOTE(review): truncation can drop the end-of-text id of an
            # over-long sequence - confirm downstream code tolerates that.
            tokens = tokens[:context_length]
            result[i, : len(tokens)] = torch.tensor(tokens)

        if len(result) == 1:
            return result[0]
        return result
605
+
606
+
607
class IMUPreprocessor(VerboseNNModule):
    """Patchify IMU signals along their last axis and embed the patches.

    The input is cut into non-overlapping windows of ``kernel_size`` samples;
    each window is projected and normalized by ``imu_stem``, CLS tokens are
    optionally prepended and a learned position embedding optionally added.
    """

    def __init__(
        self,
        kernel_size: int,
        imu_stem: PatchEmbedGeneric,
        embed_dim: int,
        img_size: List = (6, 2000),
        num_cls_tokens: int = 1,
        pos_embed_fn: Optional[Callable] = None,
        init_param_style: str = "openclip",
    ) -> None:
        super().__init__()
        self.imu_stem = imu_stem
        self.embed_dim = embed_dim
        # pos_embed_fn only acts as an on/off switch here; the embedding
        # itself is the parameter created below.
        self.use_pos_embed = pos_embed_fn is not None
        self.num_cls_tokens = num_cls_tokens
        self.kernel_size = kernel_size
        self.pos_embed = nn.Parameter(
            torch.empty(1, (img_size[1] // kernel_size) + num_cls_tokens, embed_dim)
        )

        if self.num_cls_tokens > 0:
            self.cls_token = nn.Parameter(
                torch.zeros(1, self.num_cls_tokens, self.embed_dim)
            )

        self.init_parameters(init_param_style)

    @torch.no_grad()
    def init_parameters(self, init_param_style):
        """Initialize the embeddings in ``"openclip"`` or ``"vit"`` style.

        Raises:
            ValueError: for any other ``init_param_style``.
        """
        nn.init.normal_(self.pos_embed, std=0.01)

        if init_param_style == "openclip":
            # OpenCLIP style initialization
            scale = self.embed_dim**-0.5

            if self.num_cls_tokens > 0:
                nn.init.normal_(self.cls_token)
                self.cls_token *= scale
        elif init_param_style == "vit":
            # cls_token only exists when CLS tokens were requested.
            if self.num_cls_tokens > 0:
                self.cls_token.data.fill_(0)
        else:
            raise ValueError(f"Unknown init {init_param_style}")

    def tokenize_input_and_cls_pos(self, input, stem):
        """Project ``input`` with ``stem`` and add CLS/position terms."""
        # tokens is of shape B x L x D
        tokens = stem.norm_layer(stem.proj(input))
        assert tokens.ndim == 3
        assert tokens.shape[2] == self.embed_dim
        B = tokens.shape[0]
        if self.num_cls_tokens > 0:
            class_tokens = self.cls_token.expand(
                B, -1, -1
            )  # stole class_tokens impl from Phil Wang, thanks
            tokens = torch.cat((class_tokens, tokens), dim=1)
        if self.use_pos_embed:
            tokens = tokens + self.pos_embed
        return tokens

    def forward(self, imu):
        """Tokenize ``imu``.

        Returns:
            ``{"trunk": {"tokens": ...}, "head": {}}``.
        """
        # Patchify: non-overlapping windows over the last axis, then move the
        # window index next to the batch axis and flatten each window.
        imu = imu.unfold(
            -1,
            self.kernel_size,
            self.kernel_size,
        ).permute(0, 2, 1, 3)
        imu = imu.reshape(imu.size(0), imu.size(1), -1)

        imu_tokens = self.tokenize_input_and_cls_pos(
            imu,
            self.imu_stem,
        )

        return_dict = {
            "trunk": {
                "tokens": imu_tokens,
            },
            "head": {},
        }
        return return_dict
processor_mm.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Portions Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ import math
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torchaudio
13
+ import logging
14
+
15
+ from PIL import Image
16
+ from pytorchvideo import transforms as pv_transforms
17
+ from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler
18
+ from pytorchvideo.data.encoded_video import EncodedVideo
19
+
20
+ from torchvision import transforms
21
+ from torchvision.transforms._transforms_video import NormalizeVideo
22
+
23
+ DEFAULT_AUDIO_FRAME_SHIFT_MS = 10 # in milliseconds
24
+
25
def waveform2melspec(waveform, sample_rate, num_mel_bins, target_length):
    """Convert a waveform into a fixed-length Kaldi-style mel filterbank.

    Args:
        waveform: Audio tensor handed to ``torchaudio.compliance.kaldi.fbank``.
            It is not modified.
        sample_rate: Sampling frequency of ``waveform``.
        num_mel_bins: Number of mel bins to compute.
        target_length: Number of frames in the result; shorter inputs are
            zero-padded on the right, longer ones are cut.

    Returns:
        A tensor of shape ``(1, num_mel_bins, target_length)``.
    """
    # Based on https://github.com/YuanGongND/ast/blob/d7d8b4b8e06cdaeb6c843cdb38794c1c7692234c/src/dataloader.py#L102
    # Subtract the mean out of place: the argument may be a slice view of a
    # longer waveform, and ``-=`` would silently change the caller's data.
    waveform = waveform - waveform.mean()
    fbank = torchaudio.compliance.kaldi.fbank(
        waveform,
        htk_compat=True,
        sample_frequency=sample_rate,
        use_energy=False,
        window_type="hanning",
        num_mel_bins=num_mel_bins,
        dither=0.0,
        frame_length=25,
        frame_shift=DEFAULT_AUDIO_FRAME_SHIFT_MS,
    )
    # Convert to [mel_bins, num_frames] shape
    fbank = fbank.transpose(0, 1)
    # Pad to target_length
    n_frames = fbank.size(1)
    p = target_length - n_frames
    # if p is too large (say >20%), flash a warning
    if abs(p) / n_frames > 0.2:
        logging.warning(
            "Large gap between audio n_frames(%d) and "
            "target_length (%d). Is the audio_target_length "
            "setting correct?",
            n_frames,
            target_length,
        )
    # cut and pad
    if p > 0:
        fbank = torch.nn.functional.pad(fbank, (0, p), mode="constant", value=0)
    elif p < 0:
        fbank = fbank[:, 0:target_length]
    # Convert to [1, mel_bins, num_frames] shape, essentially like a 1
    # channel image
    fbank = fbank.unsqueeze(0)
    return fbank
62
+
63
def load_and_transform_image_data(image_path):
    """Load an image and turn it into a normalized 224x224 tensor.

    Args:
        image_path: Either a path to an image file or an already opened
            ``PIL.Image.Image``.

    Returns:
        The image after bicubic resize to 224, center crop to 224,
        conversion to a tensor and per-channel normalization.
    """
    data_transform = transforms.Compose(
        [
            transforms.Resize(
                224, interpolation=transforms.InterpolationMode.BICUBIC
            ),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=(0.48145466, 0.4578275, 0.40821073),
                std=(0.26862954, 0.26130258, 0.27577711),
            ),
        ]
    )
    if isinstance(image_path, Image.Image):
        # Normalize uses three channel statistics, so in-memory images need
        # the same RGB conversion that images read from disk get.
        image = image_path.convert("RGB")
    else:
        with open(image_path, "rb") as fopen:
            image = Image.open(fopen).convert("RGB")
    return data_transform(image)
83
+
84
def load_and_transform_audio_data(
    audio_path,
    num_mel_bins=128,
    target_length=204,
    sample_rate=16000,
    clip_duration=2,
    clips_per_video=3,
    mean=-4.268,
    std=9.138,
):
    """Load an audio file and return normalized mel spectrogram clips.

    The waveform is resampled to ``sample_rate`` if needed and split into
    ``clips_per_video`` clips of ``clip_duration`` seconds. Each clip becomes
    a mel spectrogram that is normalized with ``mean`` and ``std``.

    Returns:
        ``None`` when ``audio_path`` is ``None``, otherwise the clips stacked
        along a new leading dimension.
    """
    if audio_path is None:
        return None

    sampler = ConstantClipsPerVideoSampler(
        clip_duration=clip_duration, clips_per_video=clips_per_video
    )

    waveform, native_rate = torchaudio.load(audio_path)
    if native_rate != sample_rate:
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=native_rate, new_freq=sample_rate
        )

    duration_seconds = waveform.size(1) / sample_rate
    normalize = transforms.Normalize(mean=mean, std=std)

    clips = []
    for start, end in get_clip_timepoints(sampler, duration_seconds):
        # Seconds -> sample indices for this clip.
        segment = waveform[:, int(start * sample_rate): int(end * sample_rate)]
        melspec = waveform2melspec(segment, sample_rate, num_mel_bins, target_length)
        clips.append(normalize(melspec))

    return torch.stack(clips, dim=0)
125
+
126
+
127
def get_clip_timepoints(clip_sampler, duration):
    """Collect the ``(start, end)`` pairs of every clip the sampler yields.

    ``clip_sampler`` is called with the previous clip's end time (0.0 for the
    first call), ``duration`` and ``annotation=None`` until it reports the
    last clip; it is always called at least once.
    """
    timepoints = []
    previous_end = 0.0
    while True:
        start, previous_end, _, _, reached_last = clip_sampler(
            previous_end, duration, annotation=None
        )
        timepoints.append((start, previous_end))
        if reached_last:
            break
    return timepoints
136
+
137
+
138
def crop_boxes(boxes, x_offset, y_offset):
    """
    Shift bounding boxes into the coordinate frame of a crop.
    Args:
        boxes (ndarray): bounding boxes to shift. The dimension is
            `num boxes` x 4; columns 0 and 2 are treated as x coordinates,
            columns 1 and 3 as y coordinates. The array is not modified.
        x_offset (int): cropping offset in the x axis.
        y_offset (int): cropping offset in the y axis.
    Returns:
        shifted (ndarray): a copy of `boxes` with the offsets subtracted,
            with dimension of `num boxes` x 4.
    """
    x_columns, y_columns = [0, 2], [1, 3]
    shifted = boxes.copy()
    shifted[:, x_columns] = boxes[:, x_columns] - x_offset
    shifted[:, y_columns] = boxes[:, y_columns] - y_offset
    return shifted
155
+
156
+
157
def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
    """
    Take one of three evenly spaced square crops from the images and shift
    the corresponding boxes accordingly.
    Args:
        images (tensor): images to crop. Dims 2 and 3 are treated as height
            and width; a 3-D input gets a temporary leading dimension.
        size (int): height and width of the crop.
        spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
            is larger than or equal to height. Or 0, 1, or 2 for top, center,
            and bottom crop if height is larger than width.
        boxes (ndarray or None): optional. Corresponding boxes to images.
            Dimension is `num boxes` x 4.
        scale_size (int): optional. If not None, resize the images so their
            short side equals scale_size before performing any crop.
    Returns:
        cropped (tensor): the cropped images, `size` x `size` spatially.
        cropped_boxes (ndarray or None): the shifted boxes with dimension of
            `num boxes` x 4, or None when no boxes were given.
    """
    assert spatial_idx in [0, 1, 2]
    added_leading_dim = len(images.shape) == 3
    if added_leading_dim:
        images = images.unsqueeze(0)
    height, width = images.shape[2], images.shape[3]

    if scale_size is not None:
        # Scale the short side to scale_size, keeping the aspect ratio.
        if width <= height:
            width, height = scale_size, int(height / width * scale_size)
        else:
            width, height = int(width / height * scale_size), scale_size
        images = torch.nn.functional.interpolate(
            images,
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        )

    # Center crop by default; only the long axis moves for indices 0 and 2.
    centered_y = int(math.ceil((height - size) / 2))
    centered_x = int(math.ceil((width - size) / 2))
    y_offset, x_offset = centered_y, centered_x
    if height > width:
        y_offset = {0: 0, 2: height - size}.get(spatial_idx, centered_y)
    else:
        x_offset = {0: 0, 2: width - size}.get(spatial_idx, centered_x)

    cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size]
    cropped_boxes = None
    if boxes is not None:
        cropped_boxes = crop_boxes(boxes, x_offset, y_offset)
    if added_leading_dim:
        cropped = cropped.squeeze(0)
    return cropped, cropped_boxes
214
+
215
+
216
class SpatialCrop(nn.Module):
    """
    Convert the video into 3 smaller clips spatially. Must be used after the
    temporal crops to get spatial crops, and should be used with
    -2 in the spatial crop at the slowfast augmentation stage (so full
    frames are passed in here). Will return a larger list with the
    3x spatial crops as well.
    """

    def __init__(self, crop_size: int = 224, num_crops: int = 3):
        super().__init__()
        self.crop_size = crop_size
        if num_crops == 3:
            self.crops_to_ext = [0, 1, 2]
            self.flipped_crops_to_ext = []
        elif num_crops == 1:
            self.crops_to_ext = [1]
            self.flipped_crops_to_ext = []
        else:
            raise NotImplementedError("Nothing else supported yet")

    def forward(self, videos):
        """
        Args:
            videos: A list of C, T, H, W videos.
        Returns:
            videos: A list with one entry per (video, crop index) pair, i.e.
                3x the number of elements for ``num_crops=3``. Each video is
                converted to C, T, H', W' by spatial cropping.
        """
        assert isinstance(videos, list), "Must be a list of videos after temporal crops"
        # The message previously read "(C,T_I_V_A.txt,H,W)", a stray
        # search-and-replace artefact; the intended layout is (C,T,H,W).
        assert all(video.ndim == 4 for video in videos), "Must be (C,T,H,W)"
        res = []
        for video in videos:
            for spatial_idx in self.crops_to_ext:
                res.append(uniform_crop(video, self.crop_size, spatial_idx)[0])
            if not self.flipped_crops_to_ext:
                continue
            flipped_video = transforms.functional.hflip(video)
            for spatial_idx in self.flipped_crops_to_ext:
                res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0])
        return res
257
+
258
+
259
def load_and_transform_video_data(
    video_path,
    clip_duration=2,
    clips_per_video=5,
    sample_rate=16000,
):
    """Load a video and return its normalized, spatially cropped clips.

    The video is split into ``clips_per_video`` clips of ``clip_duration``
    seconds. Each clip is temporally subsampled, scaled to [0, 1], resized so
    its short side is 224, normalized, and cut into three 224x224 crops.
    ``sample_rate`` is accepted but not used.

    Returns:
        ``None`` when ``video_path`` is ``None``, otherwise all crops stacked
        along a new leading dimension.

    Raises:
        ValueError: when the decoder returns no clip for a sampled interval.
    """
    if video_path is None:
        return None

    video_transform = transforms.Compose(
        [
            pv_transforms.ShortSideScale(224),
            NormalizeVideo(
                mean=(0.48145466, 0.4578275, 0.40821073),
                std=(0.26862954, 0.26130258, 0.27577711),
            ),
        ]
    )
    clip_sampler = ConstantClipsPerVideoSampler(
        clip_duration=clip_duration, clips_per_video=clips_per_video
    )
    # NOTE(review): the number of sampled frames is tied to clip_duration
    # (seconds) - confirm this coupling is intended.
    frame_sampler = pv_transforms.UniformTemporalSubsample(num_samples=clip_duration)

    video = EncodedVideo.from_path(
        video_path,
        decoder="decord",
        decode_audio=False,
    )

    transformed_clips = []
    for start, end in get_clip_timepoints(clip_sampler, video.duration):
        # Read the clip, get frames
        clip = video.get_clip(start, end)
        if clip is None:
            raise ValueError("No clip found")
        frames = frame_sampler(clip["video"]) / 255.0  # since this is float, need 0-1
        transformed_clips.append(video_transform(frames))

    crops = SpatialCrop(224, num_crops=3)(transformed_clips)
    return torch.stack(crops, dim=0)
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:0330356010a6d02f34ff0c7af8f3e1e749bc61cd0408450b305e84e755c7dfd6
3
- size 18588600241
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:76f23f3f18d7a2b44aa3f35395777426ff82293c470487ae6d58d123e750c9e9
3
+ size 16186598334
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ ftfy
2
+ timm
3
+ regex
4
+ einops
5
+ fvcore
6
+ decord
7
+ torchaudio
8
+ torchvision
9
+ pytorchvideo
transformer.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Portions Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ # Code modified from
9
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py ;
10
+ # https://github.com/facebookresearch/deit/blob/main/models.py
11
+ # and https://github.com/facebookresearch/vissl/blob/main/vissl/models/trunks/vision_transformer.py
12
+
13
+
14
+ import copy
15
+ import fnmatch
16
+ import logging
17
+ from functools import partial
18
+ from typing import Callable, List
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.utils.checkpoint as checkpoint
23
+
24
+ from timm.models.layers import DropPath, trunc_normal_
25
+
26
+
27
class Attention(nn.Module):
    """Multi-head self-attention over a rank-3 ``(B, N, C)`` token tensor.

    Queries, keys and values come from a single fused linear projection;
    the attended values are merged back to width ``C`` and passed through
    an output projection.

    Args:
        dim: Token embedding width ``C``. Must be divisible by ``num_heads``.
        num_heads: Number of attention heads (positive).
        qkv_bias: Whether the fused q/k/v projection has a bias term.
        qk_scale: Optional override for the attention logit scale. A falsy
            value (``None`` or ``0``) selects ``head_dim ** -0.5``.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``dim``. Without this check the mismatch only surfaces later
            as an opaque reshape error inside ``forward``.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version,
        # can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``(B, N, C)``; the output has the same shape."""
        B, N, C = x.shape
        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # (B, heads, N, head_dim) -> (B, N, C)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
70
+
71
+
72
class Mlp(nn.Module):
    """Feed-forward block: Linear -> activation -> dropout -> Linear -> dropout.

    Args:
        in_features: Width of the input features.
        hidden_features: Width of the hidden layer; a falsy value
            (``None`` or ``0``) falls back to ``in_features``.
        out_features: Width of the output; a falsy value falls back to
            ``in_features``.
        act_layer: Zero-argument factory for the activation module.
        drop: Dropout probability shared by both dropout applications.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        width_out = out_features if out_features else in_features
        width_hidden = hidden_features if hidden_features else in_features
        self.fc1 = nn.Linear(in_features, width_hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(width_hidden, width_out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Run ``x`` through both linear layers, reusing one dropout module after each stage."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
96
+
97
+
98
class MultiheadAttention(nn.MultiheadAttention):
    """Self-attention adapter around ``nn.MultiheadAttention``.

    Exposes the ``forward(x, attn_mask)`` call shape: ``x`` serves as query,
    key and value, attention weights are not requested, and only the first
    element of the parent's return value (the attended output) is returned.
    """

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
        result = super().forward(x, x, x, need_weights=False, attn_mask=attn_mask)
        attended = result[0]
        return attended
101
+
102
+
103
class ViTAttention(Attention):
    """``Attention`` with the ``forward(x, attn_mask)`` call shape.

    The parent ``forward`` takes no mask, so a non-``None`` mask is rejected
    by assertion instead of being silently dropped.
    """

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
        assert attn_mask is None
        unmasked_out = super().forward(x)
        return unmasked_out
107
+
108
+
109
class BlockWithMasking(nn.Module):
    """Pre-norm transformer block (attention + MLP) that forwards an attention mask.

    Each sub-layer output passes through ``drop_path`` and is added back to
    its input. With layer scale enabled, each sub-layer output is also
    multiplied by a learnable gamma before the residual addition.

    Args:
        dim: Token embedding width.
        attn_target: Zero-argument factory returning a fresh attention
            module. It must not be an ``nn.Module`` instance, since that
            instance would be shared across blocks.
        mlp_ratio: MLP hidden width as a multiple of ``dim``.
        act_layer: Activation factory forwarded to ``Mlp``.
        norm_layer: Factory called with ``dim`` to build each norm layer.
        ffn_dropout_rate: Dropout probability forwarded to ``Mlp``.
        drop_path: Rate passed to ``DropPath``; ``nn.Identity`` is used
            when it is not greater than 0.
        layer_scale_type: ``None``, ``"per_channel"`` or ``"scalar"``.
        layer_scale_init_value: Initial value of the layer-scale gammas.
    """

    def __init__(
        self,
        dim: int,
        attn_target: Callable,
        mlp_ratio: int = 4,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
        ffn_dropout_rate: float = 0.0,
        drop_path: float = 0.0,
        layer_scale_type: str = None,
        layer_scale_init_value: float = 1e-4,
    ):
        super().__init__()

        assert not isinstance(
            attn_target, nn.Module
        ), "attn_target should be a Callable. Otherwise attn_target is shared across blocks!"
        # Sub-modules are created in a fixed order (attn, drop_path, norm_1,
        # mlp, norm_2) so registration order stays stable.
        self.attn = attn_target()
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm_1 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(mlp_ratio * dim),
            act_layer=act_layer,
            drop=ffn_dropout_rate,
        )
        self.norm_2 = norm_layer(dim)
        self.layer_scale_type = layer_scale_type
        if self.layer_scale_type is None:
            return

        assert self.layer_scale_type in [
            "per_channel",
            "scalar",
        ], f"Found Layer scale type {self.layer_scale_type}"
        if self.layer_scale_type == "per_channel":
            # one gamma value per channel
            gamma_shape = [1, 1, dim]
        elif self.layer_scale_type == "scalar":
            # single gamma value for all channels
            gamma_shape = [1, 1, 1]
        # two gammas: one for the attention branch, one for the MLP branch
        self.layer_scale_gamma1 = nn.Parameter(
            torch.ones(size=gamma_shape) * layer_scale_init_value,
            requires_grad=True,
        )
        self.layer_scale_gamma2 = nn.Parameter(
            torch.ones(size=gamma_shape) * layer_scale_init_value,
            requires_grad=True,
        )

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
        """Apply the attention and MLP sub-layers with residual connections."""
        use_layer_scale = self.layer_scale_type is not None

        attn_branch = self.drop_path(self.attn(self.norm_1(x), attn_mask))
        if use_layer_scale:
            attn_branch = attn_branch * self.layer_scale_gamma1
        x = x + attn_branch

        mlp_branch = self.drop_path(self.mlp(self.norm_2(x)))
        if use_layer_scale:
            mlp_branch = mlp_branch * self.layer_scale_gamma2
        return x + mlp_branch
175
+
176
+
177
# Default norm factory for SimpleTransformer: LayerNorm with eps=1e-6.
_LAYER_NORM = partial(nn.LayerNorm, eps=1e-6)


class SimpleTransformer(nn.Module):
    """
    Simple Transformer with the following features
    1. Supports masked attention
    2. Supports DropPath
    3. Supports LayerScale
    4. Supports Dropout in Attention and FFN
    5. Makes few assumptions about the input except that it is a Tensor

    Args:
        attn_target: Zero-argument factory producing one attention module
            per block.
        embed_dim: Token embedding width passed to every block as ``dim``.
        num_blocks: Number of blocks to stack.
        block: Block factory (defaults to ``BlockWithMasking``).
        pre_transformer_layer: Optional callable applied to the tokens
            before the blocks.
        post_transformer_layer: Optional callable applied to the tokens
            after the blocks.
        drop_path_rate: Maximum (progressive) or constant (uniform)
            drop-path rate.
        drop_path_type: ``"progressive"`` (linearly spaced from 0 to
            ``drop_path_rate`` across blocks) or ``"uniform"``.
        norm_layer: Norm factory forwarded to every block.
        mlp_ratio: MLP expansion ratio forwarded to every block.
        ffn_dropout_rate: FFN dropout probability forwarded to every block.
        layer_scale_type: From CaiT; ``None``, ``"per_channel"`` or
            ``"scalar"``.
        layer_scale_init_value: From CaiT; initial layer-scale value.
        weight_init_style: ``"jax"`` or ``"pytorch"``.

    Raises:
        ValueError: On an unknown ``drop_path_type`` or
            ``weight_init_style``.
    """

    def __init__(
        self,
        attn_target: Callable,
        embed_dim: int,
        num_blocks: int,
        block: Callable = BlockWithMasking,
        pre_transformer_layer: Callable = None,
        post_transformer_layer: Callable = None,
        drop_path_rate: float = 0.0,
        drop_path_type: str = "progressive",
        norm_layer: Callable = _LAYER_NORM,
        mlp_ratio: int = 4,
        ffn_dropout_rate: float = 0.0,
        layer_scale_type: str = None,  # from cait; possible values are None, "per_channel", "scalar"
        layer_scale_init_value: float = 1e-4,  # from cait; float
        weight_init_style: str = "jax",  # possible values jax or pytorch
    ):
        super().__init__()
        # Fail fast, mirroring the drop_path_type check below: an unknown
        # style used to be accepted and silently skipped weight init.
        if weight_init_style not in ("jax", "pytorch"):
            raise ValueError(f"Unknown weight_init_style: {weight_init_style}")

        self.pre_transformer_layer = pre_transformer_layer
        if drop_path_type == "progressive":
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_blocks)]
        elif drop_path_type == "uniform":
            dpr = [drop_path_rate] * num_blocks
        else:
            raise ValueError(f"Unknown drop_path_type: {drop_path_type}")

        self.blocks = nn.Sequential(
            *[
                block(
                    dim=embed_dim,
                    attn_target=attn_target,
                    mlp_ratio=mlp_ratio,
                    ffn_dropout_rate=ffn_dropout_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    layer_scale_type=layer_scale_type,
                    layer_scale_init_value=layer_scale_init_value,
                )
                for i in range(num_blocks)
            ]
        )
        self.post_transformer_layer = post_transformer_layer
        self.weight_init_style = weight_init_style
        # Applies to every registered sub-module, including any pre/post
        # layers that are nn.Modules.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise ``nn.Linear`` and ``nn.LayerNorm`` sub-modules in place."""
        if isinstance(m, nn.Linear):
            if self.weight_init_style == "jax":
                # Based on MAE and official Jax ViT implementation
                torch.nn.init.xavier_uniform_(m.weight)
            elif self.weight_init_style == "pytorch":
                # PyTorch ViT uses trunc_normal_
                trunc_normal_(m.weight, std=0.02)

            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # Affine parameters are None when the norm was built without
            # them; skip instead of crashing inside nn.init.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def forward(
        self,
        tokens: torch.Tensor,
        attn_mask: torch.Tensor = None,
        use_checkpoint: bool = False,
        checkpoint_every_n: int = 1,
        checkpoint_blk_ids: List[int] = None,
    ):
        """
        Inputs
        - tokens: data of shape N x L x D (or L x N x D depending on the attention implementation)
        - attn: mask of shape L x L
        - use_checkpoint: run the selected blocks under activation checkpointing
        - checkpoint_every_n: when no explicit ids are given, checkpoint every n-th block (n >= 1)
        - checkpoint_blk_ids: explicit block indices to checkpoint

        Output
        - x: data of shape N x L x D (or L x N x D depending on the attention implementation)

        Raises
        - ValueError: if checkpoint_every_n < 1 while it is needed to select blocks
        """
        # Identity test, not truthiness: a module that defines __len__ and
        # reports 0 would otherwise be skipped silently.
        if self.pre_transformer_layer is not None:
            tokens = self.pre_transformer_layer(tokens)

        # Resolve the checkpointed block indices once; empty when disabled.
        ckpt_ids = frozenset()
        if use_checkpoint:
            if checkpoint_blk_ids is None:
                if checkpoint_every_n < 1:
                    raise ValueError(
                        f"checkpoint_every_n must be >= 1, got {checkpoint_every_n}"
                    )
                ckpt_ids = frozenset(range(0, len(self.blocks), checkpoint_every_n))
            else:
                ckpt_ids = frozenset(checkpoint_blk_ids)

        for blk_id, blk in enumerate(self.blocks):
            if blk_id in ckpt_ids:
                tokens = checkpoint.checkpoint(
                    blk, tokens, attn_mask, use_reentrant=False
                )
            else:
                tokens = blk(tokens, attn_mask=attn_mask)

        if self.post_transformer_layer is not None:
            tokens = self.post_transformer_layer(tokens)
        return tokens