Spaces:
Runtime error
Runtime error
import torch | |
from torch import nn | |
from transformers import Wav2Vec2Processor, HubertModel | |
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
class HubertXCNNEnoder(nn.Module):
    """HuBERT encoder followed by a small 1-D CNN projecting to ``llm_dim``.

    The pretrained ``facebook/hubert-xlarge-ll60k`` checkpoint produces a
    sequence of hidden states; three unpadded ``Conv1d`` layers (the middle
    one with stride 2) then map the feature dimension from ``audio_enc_dim``
    to ``llm_dim`` and shorten the sequence.

    Args:
        audio_enc_dim: Number of input channels of the first convolution.
            Presumably this must equal the HuBERT hidden size -- it is not
            checked here.
        llm_dim: Number of output channels of the final convolution.
        finetune: When ``False`` (default) the HuBERT weights are frozen.
            When ``True`` they are left trainable.
    """

    def __init__(self, audio_enc_dim, llm_dim, finetune=False):
        super().__init__()
        # Only the encoder is moved to `device` here; the CNN stays where
        # nn.Module created it, so callers still need to move the whole
        # module (and their inputs) to a common device.
        self.encoder = HubertModel.from_pretrained('facebook/hubert-xlarge-ll60k').to(device)

        # Honour the `finetune` flag: it was previously accepted but ignored,
        # so the encoder was frozen even when fine-tuning was requested.
        for param in self.encoder.parameters():
            param.requires_grad = finetune

        # No padding anywhere, so a length-T input becomes
        # floor((T - 4 - 5) / 2) + 1 - 2 frames at the output.
        self.cnn = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(audio_enc_dim, llm_dim // 2, kernel_size=5,
                      stride=1, padding=0),
            nn.ReLU(),
            nn.Conv1d(llm_dim // 2, llm_dim, kernel_size=5,
                      stride=2, padding=0),
            nn.ReLU(),
            nn.Conv1d(llm_dim, llm_dim, kernel_size=3,
                      stride=1, padding=0),
        )

    def forward(self, x):
        """Encode ``x`` with HuBERT and project it through the CNN.

        Args:
            x: Input forwarded unchanged to the HuBERT model (presumably a
                batch of raw waveforms -- confirm against the caller).

        Returns:
            Tensor with ``llm_dim`` features on the last axis and a shortened
            sequence axis in position 1.
        """
        x = self.encoder(x).last_hidden_state
        # Conv1d convolves over the last axis, so swap the sequence and
        # feature axes going in and swap them back coming out.
        x = self.cnn(x.transpose(1, 2)).transpose(1, 2)
        return x