import torch
from torch import nn
from transformers import Wav2Vec2Processor, HubertModel

# Select the compute device: prefer CUDA when a GPU is available, else fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


class HubertXCNNEnoder(nn.Module):
    """HuBERT speech encoder followed by a 1D-CNN projection into the LLM embedding space."""

    def __init__(self, audio_enc_dim, llm_dim, finetune=False):
        super().__init__()
        # Pretrained HuBERT-XLarge encoder; its hidden size must equal audio_enc_dim.
        self.encoder = HubertModel.from_pretrained('facebook/hubert-xlarge-ll60k').to(device)

        # Freeze the encoder unless fine-tuning is requested.
        if not finetune:
            for param in self.encoder.parameters():
                param.requires_grad = False

        # CNN that projects and downsamples HuBERT features to the LLM dimension.
        self.cnn = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(audio_enc_dim, llm_dim // 2, kernel_size=5, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv1d(llm_dim // 2, llm_dim, kernel_size=5, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv1d(llm_dim, llm_dim, kernel_size=3, stride=1, padding=0),
        )

    def forward(self, x):
        # x: batch of raw waveforms with shape (batch, samples).
        x = self.encoder(x).last_hidden_state            # (batch, frames, audio_enc_dim)
        # Conv1d expects channels first, so transpose before and after the CNN.
        x = self.cnn(x.transpose(1, 2)).transpose(1, 2)  # (batch, frames', llm_dim)
        return x
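

# Usage sketch (not part of the original module): the values below are assumptions.
# audio_enc_dim=1280 matches the hidden size of facebook/hubert-xlarge-ll60k, and
# llm_dim=2048 is an arbitrary example LLM embedding size. The input is a batch of
# raw 16 kHz waveforms; in practice the audio would typically be normalized first
# (e.g. with the imported Wav2Vec2Processor or a feature extractor) before encoding.
if __name__ == "__main__":
    encoder = HubertXCNNEnoder(audio_enc_dim=1280, llm_dim=2048).to(device)
    waveform = torch.randn(2, 16000).to(device)  # 2 clips of 1 second at 16 kHz
    with torch.no_grad():
        features = encoder(waveform)
    print(features.shape)  # (2, frames', 2048) after HuBERT framing and CNN downsampling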