import torch
from torch import nn
from transformers import Wav2Vec2Processor, HubertModel


class HubertXCNNEnoder(nn.Module):
    """Frozen HuBERT-XLarge encoder followed by a 1D CNN that projects
    audio features to the LLM embedding dimension and downsamples time."""

    def __init__(self, audio_enc_dim, llm_dim, finetune=False):
        super().__init__()
        self.encoder = HubertModel.from_pretrained('facebook/hubert-xlarge-ll60k')
        # Freeze the HuBERT encoder unless fine-tuning is requested.
        for param in self.encoder.parameters():
            param.requires_grad = finetune

        # CNN connector: audio_enc_dim -> llm_dim, with a stride-2 conv
        # that halves the number of time frames.
        self.cnn = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(audio_enc_dim, llm_dim // 2, kernel_size=5, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv1d(llm_dim // 2, llm_dim, kernel_size=5, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv1d(llm_dim, llm_dim, kernel_size=3, stride=1, padding=0),
        )

    def forward(self, x):
        # x: raw waveform batch of shape (batch, samples)
        x = self.encoder(x).last_hidden_state            # (batch, frames, audio_enc_dim)
        # Conv1d expects (batch, channels, frames), so transpose around the CNN.
        x = self.cnn(x.transpose(1, 2)).transpose(1, 2)  # (batch, frames', llm_dim)
        return x
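

# Minimal usage sketch (assumptions, not part of the original module):
# audio_enc_dim=1280 matches the hidden size of facebook/hubert-xlarge-ll60k,
# llm_dim=2048 is an illustrative LLM hidden size, and the input is a batch of
# raw 16 kHz waveforms already preprocessed the way the HuBERT checkpoint expects.
if __name__ == "__main__":
    encoder = HubertXCNNEnoder(audio_enc_dim=1280, llm_dim=2048)
    encoder.eval()

    waveform = torch.randn(2, 16000)  # two dummy one-second clips
    with torch.no_grad():
        features = encoder(waveform)

    # Expected shape: (2, frames', 2048), where frames' is reduced by the
    # unpadded convolutions and the stride-2 layer.
    print(features.shape)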