# muvis/modeling.py
from transformers import ASTModel, ViTModel, PretrainedConfig, PreTrainedModel
import numpy as np
import torch
import torch.nn as nn
from einops import reduce


class MuVis(nn.Module):
    """Dual encoder that projects audio (AST) and image (ViT) inputs into a shared latent space."""

    def __init__(self, embed_dims=768, latent_dims=128, sampling_rate=16000):
        super().__init__()
        self.sampling_rate = sampling_rate
        # CLIP-style learnable temperature, initialised to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        # Audio branch: pretrained Audio Spectrogram Transformer + linear projection.
        self.ast = ASTModel.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593", low_cpu_mem_usage=True)
        self.wav_lin = nn.Linear(embed_dims, latent_dims)
        # Image branch: pretrained ViT + linear projection.
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k", low_cpu_mem_usage=True)
        self.img_lin = nn.Linear(embed_dims, latent_dims)

    def forward(self, wav=None, img=None):
        wav_out = None
        img_out = None
        if wav is not None:
            # Audio path: AST token embeddings -> linear projection -> mean pool -> L2-normalise.
            wav_out = self.ast(**wav)["last_hidden_state"]
            wav_out = self.wav_lin(wav_out)
            wav_out = reduce(wav_out, "b n d -> b d", "mean")
            wav_out = wav_out / wav_out.norm(dim=-1, keepdim=True)
        if img is not None:
            # Image path: ViT token embeddings -> linear projection -> mean pool -> L2-normalise.
            img_out = self.vit(**img)["last_hidden_state"]
            img_out = self.img_lin(img_out)
            img_out = reduce(img_out, "b n d -> b d", "mean")
            img_out = img_out / img_out.norm(dim=-1, keepdim=True)
        # At least one modality must be provided.
        assert wav_out is not None or img_out is not None
        if wav_out is None or img_out is None:
            # Single-modality call: return whichever embedding was computed.
            return wav_out if img_out is None else img_out
        return (wav_out, img_out)
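

# Note: logit_scale is defined above but not applied inside forward(). The sketch
# below is NOT part of the original file; it only illustrates how a CLIP-style
# similarity matrix would typically be built from the two normalised embeddings
# returned by MuVis.forward(). The function name is hypothetical.
def clip_style_logits(wav_emb, img_emb, logit_scale):
    # Scaled cosine similarities between audio and image embeddings, shape (batch_wav, batch_img).
    return logit_scale.exp() * wav_emb @ img_emb.t()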


class MuVisConfig(PretrainedConfig):
    model_type = "muvis"

    def __init__(
        self,
        embed_dims=768,
        latent_dims=128,
        sampling_rate=16000,
        **kwargs,
    ):
        self.embed_dims = embed_dims
        self.latent_dims = latent_dims
        self.sampling_rate = sampling_rate
        super().__init__(**kwargs)


class MuVisModel(PreTrainedModel):
    config_class = MuVisConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = MuVis(
            embed_dims=config.embed_dims,
            latent_dims=config.latent_dims,
            sampling_rate=config.sampling_rate,
        )

    def forward(self, wav=None, img=None):
        return self.model(wav=wav, img=img)
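

# Minimal usage sketch (not part of the original file). It assumes the default
# preprocessors of the two backbones, ASTFeatureExtractor for audio and
# ViTImageProcessor for images; the random inputs are placeholders.
if __name__ == "__main__":
    from transformers import ASTFeatureExtractor, ViTImageProcessor

    audio_fe = ASTFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
    image_proc = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

    model = MuVisModel(MuVisConfig())
    model.eval()

    # One second of dummy audio at 16 kHz and one dummy 224x224 RGB image.
    waveform = np.random.randn(16000).astype(np.float32)
    image = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)

    wav_inputs = audio_fe(waveform, sampling_rate=16000, return_tensors="pt")
    img_inputs = image_proc(images=image, return_tensors="pt")

    with torch.no_grad():
        wav_emb, img_emb = model(wav=wav_inputs, img=img_inputs)
    print(wav_emb.shape, img_emb.shape)  # both (1, latent_dims) = (1, 128)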