import math
from dataclasses import dataclass
from typing import Any, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange, repeat
from transformers.utils import ModelOutput

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType


@dataclass
class ImageBindSegmaskAudioEncoderOutput(ModelOutput):
    """
    Args:
        audio_embeds (`torch.Tensor` of shape `(batch_size, output_dim)`):
            The pooled audio embeddings obtained by applying the ImageBind audio
            head to the trunk output (optionally normalized and logit-scaled by
            the postprocessor).
        audio_encodings (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim)`):
            The per-token trunk outputs (CLS token plus patch tokens) after the
            final layer norm.
        audio_segment_masks (`torch.BoolTensor` of shape `(batch_size, n_segment, seq_len)`):
            Boolean masks marking which tokens belong to each temporal segment.
    """

    audio_embeds: Optional[torch.Tensor] = None
    audio_encodings: Optional[torch.Tensor] = None
    audio_segment_masks: Optional[torch.BoolTensor] = None

    def to_tuple(self) -> Tuple[Any, ...]:
        return tuple(self[k] for k in self.keys())


class ImageBindSegmaskAudioEncoder(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, n_segment=4, pretrained_model_name="imagebind-huge"):
        super().__init__()
        self.n_segment = n_segment
        self.pretrained_model_name = pretrained_model_name
        if pretrained_model_name == "imagebind-huge":
            pretrained_model = imagebind_model.imagebind_huge(pretrained=True)
        else:
            raise ValueError(f"Unsupported pretrained model: {pretrained_model_name}")
        self.preprocessor = pretrained_model.modality_preprocessors[ModalityType.AUDIO]
        self.trunk = pretrained_model.modality_trunks[ModalityType.AUDIO]
        self.head = pretrained_model.modality_heads[ModalityType.AUDIO]
        self.postprocessor = pretrained_model.modality_postprocessors[ModalityType.AUDIO]
        self.final_layer_norm = nn.LayerNorm(normalized_shape=768, eps=1e-6)

    def _auto_split(self, n, n_chunk):
        """
        Split a length-n sequence into n_chunk windows of equal size
        ceil(n / n_chunk). When n is not divisible by n_chunk, the window
        start positions are spaced evenly, so adjacent windows may overlap
        slightly. Returns a boolean mask of shape (n_chunk, n) with one row
        per window.
        """
        chunk_size = int(math.ceil(n / n_chunk))
        assert chunk_size >= 1, chunk_size
        # Evenly spaced start indices; the last window ends exactly at n.
        chunk_start_indices = np.round(
            np.linspace(0, n - chunk_size, n_chunk, endpoint=True)
        ).astype(np.int32)
        mask = torch.zeros(n_chunk, n).bool()
        for chunk_index, chunk_start_index in enumerate(chunk_start_indices):
            mask[chunk_index, chunk_start_index:chunk_start_index + chunk_size] = 1
        mask = mask.contiguous()
        assert mask.long().sum() == chunk_size * n_chunk, mask.long().sum()
        return mask

    def forward(
        self,
        input_features: Optional[torch.Tensor],
        normalize: bool = False,
        return_dict: Optional[bool] = None,
    ):
        n_segment = self.n_segment

        # 1. record the batch size for broadcasting the segment masks later
        batchsize = input_features.size(0)

        # 2. patchify the audio spectrogram and add positional embeddings
        audio_inputs = self.preprocessor(input_features)
        trunk_inputs = audio_inputs["trunk"]  # dict of {"tokens": (b, l, d)}

        # 3. run the audio encoder trunk and head
        audio_encodings = self.trunk(**trunk_inputs)  # w/o layer norm, (b, seq_len, c)
        head_inputs = audio_inputs["head"]
        cls_embeds = self.head(audio_encodings, **head_inputs)
        # normalize and apply logit scaling
        if normalize:
            cls_embeds = self.postprocessor(cls_embeds)  # (b, c)
        audio_encodings = self.final_layer_norm(audio_encodings)
        # 4. build per-segment token masks over the hard-coded ImageBind audio
        #    patch grid: 12 frequency x 19 time patches, plus 1 CLS token
        n, t = 12, 19
        segment_mask = (
            self._auto_split(t, n_segment)
            .unsqueeze(1)
            .expand(n_segment, n, t)
            .contiguous()
        )  # (s, n, t): split along time, shared across frequency bins
        segment_mask = rearrange(segment_mask, "s n t -> s (n t)")
        segment_mask = torch.cat(
            [torch.ones(n_segment, 1).bool(), segment_mask], dim=1
        )  # (s, 1 + n*t): every segment also covers the CLS token
        segment_masks = (
            repeat(segment_mask, "s l -> b s l", b=batchsize)
            .contiguous()
            .bool()
            .to(self.device)
        )

        if not return_dict:
            return cls_embeds, audio_encodings, segment_masks

        return ImageBindSegmaskAudioEncoderOutput(
            audio_embeds=cls_embeds,
            audio_encodings=audio_encodings,
            audio_segment_masks=segment_masks,
        )
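

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module).
# Assumptions worth flagging: the input is a batch of ImageBind-style log-mel
# spectrogram clips of shape (batch, 1, 128, 204), the layout that yields the
# hard-coded 12 x 19 patch grid above, and the pretrained "imagebind-huge"
# checkpoint is available (instantiation downloads it otherwise). Random
# tensors stand in for real audio, so only the output shapes are meaningful.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    encoder = ImageBindSegmaskAudioEncoder(n_segment=4)
    encoder.eval()

    dummy_audio = torch.randn(2, 1, 128, 204)  # (batch, channels, mel bins, time)
    with torch.no_grad():
        output = encoder(dummy_audio, normalize=True, return_dict=True)

    print(output.audio_embeds.shape)         # (2, 1024): pooled clip embeddings
    print(output.audio_encodings.shape)      # (2, 229, 768): 1 CLS + 12*19 patches
    print(output.audio_segment_masks.shape)  # (2, 4, 229): boolean segment masks

    # _auto_split tiles the 19 time patches into 4 windows of ceil(19/4) = 5
    # patches each, so adjacent windows may overlap by one patch; each mask
    # covers 5 time patches x 12 frequency patches plus the CLS token.
    print(output.audio_segment_masks[0].long().sum(dim=-1))  # 5*12 + 1 = 61 per segment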