juliagsy committed on
Commit
0db0a20
1 Parent(s): 89d7700

Create modeling.py

Browse files
Files changed (1) hide show
  1. modeling.py +73 -0
modeling.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import ASTModel, ViTModel, PretrainedConfig, PreTrainedModel
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ from einops import reduce
6
+
7
+
8
class MuVis(nn.Module):
    """Audio/image embedding model built on pretrained AST and ViT encoders.

    Each modality is run through its encoder, projected token-wise to
    ``latent_dims`` with a linear layer, mean-pooled over the token axis and
    divided by its L2 norm along the last dimension.

    Args:
        embed_dims: Input width of both projection layers. It has to match the
            size of the last axis of each encoder's ``last_hidden_state``,
            because the projections are applied directly to that tensor.
        latent_dims: Output width of both projection layers, i.e. the size of
            the returned embeddings.
        sampling_rate: Stored on the instance as ``self.sampling_rate``; this
            class does not read it anywhere else.
    """

    def __init__(self, embed_dims=768, latent_dims=128, sampling_rate=16000):
        super().__init__()
        self.sampling_rate = sampling_rate
        # Learnable scalar initialised to log(1 / 0.07). It is not used in
        # ``forward`` -- presumably a temperature consumed by an external
        # contrastive loss (TODO confirm against the training code).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        # Audio branch: pretrained AST encoder followed by a linear projection.
        self.ast = ASTModel.from_pretrained(
            "MIT/ast-finetuned-audioset-10-10-0.4593", low_cpu_mem_usage=True
        )
        self.wav_lin = nn.Linear(embed_dims, latent_dims)

        # Image branch: pretrained ViT encoder followed by a linear projection.
        self.vit = ViTModel.from_pretrained(
            "google/vit-base-patch16-224-in21k", low_cpu_mem_usage=True
        )
        self.img_lin = nn.Linear(embed_dims, latent_dims)

    @staticmethod
    def _embed(encoder, projection, inputs):
        """Encode ``inputs``, project each token, mean-pool and L2-normalise."""
        hidden = encoder(**inputs)["last_hidden_state"]
        pooled = reduce(projection(hidden), "b n d -> b d", "mean")
        return pooled / pooled.norm(dim=-1, keepdim=True)

    def forward(self, wav=None, img=None):
        """Embed audio and/or image inputs.

        Args:
            wav: Mapping of keyword arguments unpacked into the AST encoder
                call, or ``None`` to skip the audio branch.
            img: Mapping of keyword arguments unpacked into the ViT encoder
                call, or ``None`` to skip the image branch.

        Returns:
            The embedding of the single modality that was supplied, or the
            tuple ``(wav_out, img_out)`` when both were supplied.

        Raises:
            ValueError: If both ``wav`` and ``img`` are ``None``.
        """
        # Validate up front with a real exception: an ``assert`` is removed
        # under ``python -O`` and would let this call silently return None.
        if wav is None and img is None:
            raise ValueError(
                "MuVis.forward requires at least one of `wav` or `img`."
            )

        wav_out = None
        if wav is not None:
            wav_out = self._embed(self.ast, self.wav_lin, wav)

        img_out = None
        if img is not None:
            img_out = self._embed(self.vit, self.img_lin, img)

        if wav_out is None:
            return img_out
        if img_out is None:
            return wav_out
        return (wav_out, img_out)
43
+
44
+
45
class MuVisConfig(PretrainedConfig):
    """Configuration holding the constructor arguments of ``MuVis``.

    Args:
        embed_dims: Stored as ``self.embed_dims``.
        latent_dims: Stored as ``self.latent_dims``.
        sampling_rate: Stored as ``self.sampling_rate``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "muvis"

    def __init__(
        self,
        embed_dims: int = 768,
        latent_dims: int = 128,
        sampling_rate: int = 16000,
        **kwargs,
    ) -> None:
        # The model-specific fields are assigned before delegating to the
        # base initialiser, which receives only the remaining keyword args.
        self.embed_dims = embed_dims
        self.latent_dims = latent_dims
        self.sampling_rate = sampling_rate
        super().__init__(**kwargs)
59
+
60
+
61
class MuVisModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper that owns a ``MuVis`` network as ``self.model``."""

    config_class = MuVisConfig

    def __init__(self, config: "MuVisConfig"):
        """Build the wrapped ``MuVis`` from the fields of ``config``."""
        super().__init__(config)
        # Only these three config fields are read when building the network.
        self.model = MuVis(
            embed_dims=config.embed_dims,
            latent_dims=config.latent_dims,
            sampling_rate=config.sampling_rate,
        )

    def forward(self, wav=None, img=None):
        """Delegate to the wrapped ``MuVis`` and return its result unchanged."""
        outputs = self.model(wav=wav, img=img)
        return outputs