import math
from dataclasses import dataclass
from typing import Optional, Tuple, Any

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange, repeat

from transformers.utils import ModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config

from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType
@dataclass
class ImageBindSegmaskAudioEncoderOutput(ModelOutput):
    """
    Output of [`ImageBindSegmaskAudioEncoder`].

    Args:
        audio_embeds (`torch.Tensor` of shape `(batch_size, output_dim)`):
            The pooled (CLS) audio embeddings produced by the ImageBind audio head.
        audio_encodings (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim)`):
            The per-token encodings from the ImageBind audio trunk, after the final layer norm.
        audio_segment_masks (`torch.BoolTensor` of shape `(batch_size, n_segment, seq_len)`):
            Boolean masks marking which tokens (the CLS token plus each segment's patch tokens)
            belong to each audio segment.
    """
    audio_embeds: torch.Tensor = None
    audio_encodings: torch.Tensor = None
    audio_segment_masks: torch.BoolTensor = None

    def to_tuple(self) -> Tuple[Any, ...]:
        return tuple(self[k] for k in self.keys())
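
# A minimal sketch of how the output container behaves (hypothetical values,
# shapes assume the ImageBind-huge audio branch used below):
#
#   out = ImageBindSegmaskAudioEncoderOutput(
#       audio_embeds=torch.zeros(2, 1024),
#       audio_encodings=torch.zeros(2, 229, 768),
#       audio_segment_masks=torch.zeros(2, 4, 229).bool(),
#   )
#   embeds, encodings, masks = out.to_tuple()  # non-None fields, declaration order
#   out["audio_embeds"] is out.audio_embeds    # ModelOutput also allows dict-style access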
class ImageBindSegmaskAudioEncoder(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self,
                 n_segment=4,
                 pretrained_model_name="imagebind-huge"
                 ):
        super().__init__()
        self.n_segment = n_segment
        self.pretrained_model_name = pretrained_model_name
        if pretrained_model_name == "imagebind-huge":
            pretrained_model = imagebind_model.imagebind_huge(pretrained=True)
        else:
            raise ValueError(f"Unknown pretrained model: {pretrained_model_name}")
        # Reuse ImageBind's audio branch: preprocessor (patchify + positional
        # embedding), transformer trunk, projection head, and postprocessor.
        self.preprocessor = pretrained_model.modality_preprocessors[ModalityType.AUDIO]
        self.trunk = pretrained_model.modality_trunks[ModalityType.AUDIO]
        self.head = pretrained_model.modality_heads[ModalityType.AUDIO]
        self.postprocessor = pretrained_model.modality_postprocessors[ModalityType.AUDIO]
        self.final_layer_norm = nn.LayerNorm(normalized_shape=768, eps=1e-6)
    def _auto_split(self, n, n_chunk):
        '''
        Split `n` positions into `n_chunk` windows of equal size
        ceil(n / n_chunk), with window starts evenly spaced. When `n` is not
        divisible by `n_chunk`, adjacent windows overlap so that every
        position is covered. Returns a boolean mask of shape (n_chunk, n).
        '''
        chunk_size = int(math.ceil(n / n_chunk))
        assert chunk_size >= 1, chunk_size
        # Evenly spaced window starts from 0 to n - chunk_size (inclusive).
        chunk_start_indices = np.round(np.linspace(0, n - chunk_size, n_chunk, endpoint=True)).astype(np.int32)
        mask = torch.zeros(n_chunk, n).bool()
        for chunk_index, chunk_start_index in enumerate(chunk_start_indices):
            mask[chunk_index, chunk_start_index:chunk_start_index + chunk_size] = 1
        mask = mask.contiguous()
        # Each window contributes exactly chunk_size positions.
        assert mask.long().sum() == chunk_size * n_chunk, mask.long().sum()
        return mask
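
    # A worked example of what `_auto_split` produces, assuming n=19 time
    # patches and n_chunk=4 segments (the values used in `forward` below):
    #
    #   chunk_size = ceil(19 / 4) = 5
    #   starts     = round(linspace(0, 14, 4)) = [0, 5, 9, 14]
    #   windows    = [0:5], [5:10], [9:14], [14:19]   # rows 2 and 3 overlap at index 9
    #
    # Every one of the 19 positions is covered and each mask row sums to 5.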
    def forward(self,
                input_features: Optional[torch.Tensor],
                normalize: bool = False,
                return_dict: Optional[bool] = None):
        n_segment = self.n_segment
        batchsize = input_features.size(0)
        # 1. patchify the audio spectrogram and add positional embeddings
        audio_inputs = self.preprocessor(input_features)
        trunk_inputs = audio_inputs["trunk"]  # dict of {"tokens": (b, l, d)}
        # 2. run the audio transformer trunk (its final layer norm not yet applied)
        audio_encodings = self.trunk(**trunk_inputs)  # (b, seq_len, c)
        head_inputs = audio_inputs["head"]
        cls_embeds = self.head(audio_encodings, **head_inputs)
        # 3. optionally normalize and logit-scale the pooled embedding
        if normalize:
            cls_embeds = self.postprocessor(cls_embeds)  # (b, c)
        audio_encodings = self.final_layer_norm(audio_encodings)
        # 4. build segment masks over the token sequence
        # Hard-coded token grid for ImageBind-huge's audio stem (16x16 patches,
        # stride 10 over a 128 x 204 log-mel spectrogram): n=12 frequency
        # patches, t=19 time patches.
        n, t = 12, 19
        # Split the time axis into segments; each segment spans all frequency patches.
        segment_mask = self._auto_split(t, n_segment).unsqueeze(1).expand(n_segment, n, t).contiguous()  # (s, n, t)
        segment_mask = rearrange(
            segment_mask, "s n t -> s (n t)"
        )
        # Prepend a column of ones so every segment also covers the CLS token.
        segment_mask = torch.cat([
            torch.ones(n_segment, 1).bool(),
            segment_mask
        ], dim=1)  # (s, 1 + n*t)
        segment_masks = repeat(segment_mask, "s l -> b s l", b=batchsize).contiguous().bool().to(self.device)  # (b, s, 1 + n*t)
        if not return_dict:
            return cls_embeds, audio_encodings, segment_masks
        return ImageBindSegmaskAudioEncoderOutput(
            audio_embeds=cls_embeds,
            audio_encodings=audio_encodings,
            audio_segment_masks=segment_masks
        )
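

# --- Usage sketch (not part of the original module) -------------------------
# A minimal sketch, assuming the ImageBind-huge checkpoint can be loaded by
# `imagebind_model.imagebind_huge(pretrained=True)` and that inputs are
# log-mel spectrograms shaped (batch, 1, 128, 204), as produced by
# ImageBind's audio preprocessing pipeline.
if __name__ == "__main__":
    encoder = ImageBindSegmaskAudioEncoder(n_segment=4)
    encoder.eval()
    dummy = torch.randn(2, 1, 128, 204)  # two dummy spectrogram clips
    with torch.no_grad():
        out = encoder(dummy, normalize=True, return_dict=True)
    print(out.audio_embeds.shape)         # (2, 1024): pooled, normalized embeddings
    print(out.audio_encodings.shape)      # (2, 229, 768): 1 CLS + 12*19 patch tokens
    print(out.audio_segment_masks.shape)  # (2, 4, 229): one boolean mask per segment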