import numpy as np
import torch
import torch.nn as nn
from einops import reduce
from transformers import ASTModel, PretrainedConfig, PreTrainedModel, ViTModel


class MuVis(nn.Module):
    """CLIP-style dual encoder: an Audio Spectrogram Transformer (AST) and a
    Vision Transformer (ViT), each followed by a linear projection into a
    shared latent space of dimension `latent_dims`."""

    def __init__(self, embed_dims=768, latent_dims=128, sampling_rate=16000):
        super().__init__()
        self.sampling_rate = sampling_rate
        # Learnable temperature for contrastive logits, initialized as in CLIP.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        # Audio branch: pretrained AST backbone plus projection head.
        self.ast = ASTModel.from_pretrained(
            "MIT/ast-finetuned-audioset-10-10-0.4593", low_cpu_mem_usage=True
        )
        self.wav_lin = nn.Linear(embed_dims, latent_dims)
        # Image branch: pretrained ViT backbone plus projection head.
        self.vit = ViTModel.from_pretrained(
            "google/vit-base-patch16-224-in21k", low_cpu_mem_usage=True
        )
        self.img_lin = nn.Linear(embed_dims, latent_dims)

    def forward(self, wav=None, img=None):
        wav_out = None
        img_out = None
        if wav is not None:
            # Encode, project each token, mean-pool over the sequence, then
            # L2-normalize so embeddings lie on the unit hypersphere.
            wav_out = self.ast(**wav)["last_hidden_state"]
            wav_out = self.wav_lin(wav_out)
            wav_out = reduce(wav_out, "b n d -> b d", "mean")
            wav_out = wav_out / wav_out.norm(dim=-1, keepdim=True)
        if img is not None:
            img_out = self.vit(**img)["last_hidden_state"]
            img_out = self.img_lin(img_out)
            img_out = reduce(img_out, "b n d -> b d", "mean")
            img_out = img_out / img_out.norm(dim=-1, keepdim=True)
        assert wav_out is not None or img_out is not None, "Pass at least one of wav or img."
        # Single-modality call: return that embedding alone; otherwise the pair.
        if wav_out is None or img_out is None:
            return wav_out if img_out is None else img_out
        return (wav_out, img_out)


class MuVisConfig(PretrainedConfig):
    model_type = "muvis"

    def __init__(
        self,
        embed_dims=768,
        latent_dims=128,
        sampling_rate=16000,
        **kwargs,
    ):
        self.embed_dims = embed_dims
        self.latent_dims = latent_dims
        self.sampling_rate = sampling_rate
        super().__init__(**kwargs)


class MuVisModel(PreTrainedModel):
    """Hugging Face wrapper so MuVis works with the PreTrainedModel API
    (save_pretrained / from_pretrained / push_to_hub)."""

    config_class = MuVisConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = MuVis(
            embed_dims=config.embed_dims,
            latent_dims=config.latent_dims,
            sampling_rate=config.sampling_rate,
        )

    def forward(self, wav=None, img=None):
        return self.model(wav=wav, img=img)
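
# --- Usage sketch (illustrative, not part of the model definition) ---
# A minimal example of running the dual encoder end to end on dummy inputs.
# It assumes the standard Hugging Face preprocessors for the two backbone
# checkpoints (ASTFeatureExtractor and ViTImageProcessor); the zero-filled
# waveform and image are placeholders, not real data.
if __name__ == "__main__":
    from transformers import ASTFeatureExtractor, ViTImageProcessor

    config = MuVisConfig()
    model = MuVisModel(config).eval()

    # Dummy 1-second waveform at the model's sampling rate, and a dummy RGB image.
    audio_fe = ASTFeatureExtractor.from_pretrained(
        "MIT/ast-finetuned-audioset-10-10-0.4593"
    )
    image_fe = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
    wav_inputs = audio_fe(
        np.zeros(config.sampling_rate, dtype=np.float32),
        sampling_rate=config.sampling_rate,
        return_tensors="pt",
    )
    img_inputs = image_fe(
        np.zeros((224, 224, 3), dtype=np.uint8), return_tensors="pt"
    )

    with torch.no_grad():
        wav_emb, img_emb = model(wav=wav_inputs, img=img_inputs)

    # Both embeddings are unit-normalized, so their dot product is a cosine
    # similarity, ready for a CLIP-style contrastive loss scaled by logit_scale.
    print(wav_emb.shape, img_emb.shape)  # torch.Size([1, 128]) torch.Size([1, 128])
    print((wav_emb * img_emb).sum(dim=-1))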