shangeth committed on
Commit
e94759d
1 Parent(s): c8fbf2f

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +27 -0
model.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from transformers import Wav2Vec2Processor, HubertModel
4
+
5
class HubertXCNNEnoder(nn.Module):
    """HuBERT encoder followed by a three-layer 1-D convolutional projector.

    The pretrained ``facebook/hubert-xlarge-ll60k`` checkpoint produces a
    sequence of hidden states; the convolution stack then maps the feature
    dimension from ``audio_enc_dim`` to ``llm_dim`` and shortens the
    sequence (no padding is used and the middle convolution has stride 2).

    Args:
        audio_enc_dim: Number of input channels of the first convolution.
            It has to match the feature size of the encoder's
            ``last_hidden_state``, because that tensor is fed straight into
            the convolution stack.
        llm_dim: Number of output channels of the convolution stack. The
            intermediate layer uses ``llm_dim // 2`` channels.
        finetune: When ``False`` (default) every encoder parameter is frozen
            and only the convolution stack is trainable. When ``True`` the
            encoder parameters keep receiving gradients as well.
    """

    def __init__(self, audio_enc_dim, llm_dim, finetune=False):
        super().__init__()
        self.encoder = HubertModel.from_pretrained('facebook/hubert-xlarge-ll60k')

        # Honour the ``finetune`` flag: previously the encoder was frozen
        # unconditionally, so passing ``finetune=True`` had no effect.
        for param in self.encoder.parameters():
            param.requires_grad = finetune

        # No padding anywhere, so for an encoder sequence of length L the
        # output length is (L - 9) // 2 - 1; L must be at least 13.
        self.cnn = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(audio_enc_dim, llm_dim // 2, kernel_size=5,
                      stride=1, padding=0),
            nn.ReLU(),
            nn.Conv1d(llm_dim // 2, llm_dim, kernel_size=5,
                      stride=2, padding=0),
            nn.ReLU(),
            nn.Conv1d(llm_dim, llm_dim, kernel_size=3,
                      stride=1, padding=0),
        )

    def forward(self, x):
        """Encode ``x`` and project the hidden states to ``llm_dim`` features.

        Args:
            x: Input passed unchanged to the HuBERT encoder (presumably a
                batch of raw waveforms -- confirm against the caller).

        Returns:
            Tensor laid out as ``(batch, reduced_time, llm_dim)``.
        """
        x = self.encoder(x).last_hidden_state
        # Conv1d expects channels on dim 1, so swap (time, feature) going in
        # and swap back on the way out.
        x = self.cnn(x.transpose(1, 2)).transpose(1, 2)
        return x