# Source: Hugging Face file view — commit ec9a712 ("quantization added") by shangeth; 1.1 kB
import torch
from torch import nn
from transformers import Wav2Vec2Processor, HubertModel
# Select the compute device once at import time: prefer CUDA when available.
device = "cuda" if torch.cuda.is_available() else "cpu"
class HubertXCNNEnoder(nn.Module):
    """Adapter that maps HuBERT speech features into an LLM's embedding space.

    A frozen (by default) pretrained HuBERT-XLarge encoder produces frame-level
    audio features, which a small 1-D CNN stack projects and downsamples to
    ``llm_dim`` so they can be consumed as LLM input embeddings.

    Args:
        audio_enc_dim: Feature dimension of the HuBERT encoder output
            (channel dim fed to the first Conv1d).
        llm_dim: Target embedding dimension of the language model.
        finetune: If True, leave the HuBERT encoder trainable; if False
            (default), freeze all of its parameters.
    """

    def __init__(self, audio_enc_dim, llm_dim, finetune=False):
        super().__init__()
        self.encoder = HubertModel.from_pretrained('facebook/hubert-xlarge-ll60k').to(device)
        # Bug fix: the original ignored `finetune` and always froze the
        # encoder. Honor the flag; the default (False) preserves the old
        # behavior of freezing every encoder parameter.
        if not finetune:
            for param in self.encoder.parameters():
                param.requires_grad = False

        # Project audio_enc_dim -> llm_dim and downsample in time (stride=2
        # in the middle conv); valid padding trims a few boundary frames.
        self.cnn = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(audio_enc_dim, llm_dim//2, kernel_size=5,
                      stride=1, padding=0),
            nn.ReLU(),
            nn.Conv1d(llm_dim//2, llm_dim, kernel_size=5,
                      stride=2, padding=0),
            nn.ReLU(),
            nn.Conv1d(llm_dim, llm_dim, kernel_size=3,
                      stride=1, padding=0),
        )

    def forward(self, x):
        """Encode raw audio ``x`` into LLM-dimension features.

        Args:
            x: Raw waveform batch accepted by HubertModel
               (assumed shape (batch, samples) — TODO confirm with caller).

        Returns:
            Tensor of shape (batch, frames', llm_dim) — the CNN operates on
            (batch, channels, frames), hence the transpose round-trip.
        """
        x = self.encoder(x).last_hidden_state
        x = self.cnn(x.transpose(1, 2)).transpose(1, 2)
        return x