"""MuVis: a dual-encoder model (AST for audio, ViT for images) projecting both modalities into a shared latent space."""

from transformers import ASTModel, ViTModel, PretrainedConfig, PreTrainedModel
import numpy as np
import torch
import torch.nn as nn
from einops import reduce


class MuVis(nn.Module):
    """CLIP-style dual encoder mapping audio and images into a shared latent space."""

    def __init__(self, embed_dims=768, latent_dims=128, sampling_rate=16000):
        super().__init__()
        self.sampling_rate = sampling_rate
        # Learnable temperature for contrastive training, initialized as in CLIP.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        # Audio branch: pretrained Audio Spectrogram Transformer plus a projection head.
        self.ast = ASTModel.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593", low_cpu_mem_usage=True)
        self.wav_lin = nn.Linear(embed_dims, latent_dims)

        # Image branch: pretrained ViT plus a projection head.
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k", low_cpu_mem_usage=True)
        self.img_lin = nn.Linear(embed_dims, latent_dims)

    def forward(self, wav=None, img=None):
        wav_out = None
        img_out = None

        if wav is not None:
            # Encode audio, project to the latent space, mean-pool over tokens,
            # then L2-normalize so dot products become cosine similarities.
            wav_out = self.ast(**wav)["last_hidden_state"]
            wav_out = self.wav_lin(wav_out)
            wav_out = reduce(wav_out, "b n d -> b d", "mean")
            wav_out = wav_out / wav_out.norm(dim=-1, keepdim=True)

        if img is not None:
            # Same pipeline for the image branch.
            img_out = self.vit(**img)["last_hidden_state"]
            img_out = self.img_lin(img_out)
            img_out = reduce(img_out, "b n d -> b d", "mean")
            img_out = img_out / img_out.norm(dim=-1, keepdim=True)

        assert wav_out is not None or img_out is not None, "Pass at least one of `wav` or `img`."

        # Return a single embedding when only one modality is given, else both.
        if wav_out is None or img_out is None:
            return wav_out if img_out is None else img_out
        return (wav_out, img_out)
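

# Sketch of how `logit_scale` would typically be used: the original file defines
# it but never applies it, so this symmetric InfoNCE loss (as in CLIP) is an
# assumption about the intended training objective, not part of the original model.
def clip_contrastive_loss(wav_emb, img_emb, logit_scale):
    # Embeddings are already L2-normalized, so the matmul yields cosine similarities.
    logits = logit_scale.exp() * wav_emb @ img_emb.t()
    targets = torch.arange(logits.size(0), device=logits.device)
    # Cross-entropy in both directions: audio-to-image and image-to-audio.
    loss_w2i = nn.functional.cross_entropy(logits, targets)
    loss_i2w = nn.functional.cross_entropy(logits.t(), targets)
    return (loss_w2i + loss_i2w) / 2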


class MuVisConfig(PretrainedConfig):
    """Hugging Face config wrapper exposing the MuVis hyperparameters."""

    model_type = "muvis"

    def __init__(
        self,
        embed_dims=768,
        latent_dims=128,
        sampling_rate=16000,
        **kwargs,
    ):
        self.embed_dims = embed_dims
        self.latent_dims = latent_dims
        self.sampling_rate = sampling_rate
        super().__init__(**kwargs)


class MuVisModel(PreTrainedModel):
    """PreTrainedModel wrapper so MuVis can be saved and loaded via the transformers API."""

    config_class = MuVisConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = MuVis(
            embed_dims=config.embed_dims,
            latent_dims=config.latent_dims,
            sampling_rate=config.sampling_rate,
        )

    def forward(self, wav=None, img=None):
        return self.model(wav=wav, img=img)
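

# Usage sketch (an assumption, not part of the original file): embed one second
# of random audio and a random image, then compare them in the shared latent
# space. The feature-extractor checkpoints match the encoders loaded above.
if __name__ == "__main__":
    from transformers import ASTFeatureExtractor, ViTImageProcessor

    config = MuVisConfig()
    model = MuVisModel(config).eval()

    ast_extractor = ASTFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
    vit_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

    waveform = np.random.randn(config.sampling_rate)                  # 1 s of fake audio
    image = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)  # fake RGB image

    wav_inputs = ast_extractor(waveform, sampling_rate=config.sampling_rate, return_tensors="pt")
    img_inputs = vit_processor(images=image, return_tensors="pt")

    with torch.no_grad():
        wav_emb, img_emb = model(wav=wav_inputs, img=img_inputs)

    # Both embeddings are unit-length, so the dot product is a cosine similarity.
    print(wav_emb.shape, img_emb.shape)    # torch.Size([1, 128]) each
    print((wav_emb * img_emb).sum(dim=-1))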