# Source: Hugging Face Space (status shown on the page at capture time: "Runtime error").
import json
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import wget
# Local directory that holds the speech-to-text model files.
STT_FOLDER = "./STTModel"
STT_MODEL_NAME = "wav2vec2"

# File names used for the local copies inside STT_FOLDER.
STT_MODEL_WEIGHTS = "pytorch_model.bin"
STT_CONFIG = "config.json"
STT_VOCAB = "vocab.json"

# Remote locations of the files above.
STT_MODEL_WEIGHTS_URL = "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/pytorch_model.bin"
STT_CONFIG_URL = "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/config.json"
STT_VOCAB_URL = "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/vocab.json"

# (remote URL, local file name) pairs.
STT_FILES_URLS = [
    (STT_MODEL_WEIGHTS_URL, STT_MODEL_WEIGHTS),
    (STT_CONFIG_URL, STT_CONFIG),
    (STT_VOCAB_URL, STT_VOCAB),
]
def ensure_stt_files_exist():
    """Download any speech-to-text model file that is not yet on disk.

    Creates ``STT_FOLDER`` if needed, then walks ``STT_FILES_URLS`` and
    fetches every entry whose target path inside ``STT_FOLDER`` does not
    exist. Paths that already exist are left untouched, so repeated calls
    do not download again.

    Returns:
        None.
    """
    os.makedirs(STT_FOLDER, exist_ok=True)
    for url, filename in STT_FILES_URLS:
        filepath = os.path.join(STT_FOLDER, filename)
        if os.path.exists(filepath):
            # Existence is the only check performed; contents are not verified.
            continue
        wget.download(url, out=filepath)
class Wav2Vec2ForCTC(nn.Module):
    """Small 1-D convolutional classifier over single-channel sequences.

    The network is two ``Conv1d -> ReLU -> MaxPool1d`` stages followed by a
    single ``Linear`` layer applied to the flattened feature map.

    NOTE(review): despite the class name, the layers defined here are only
    the two convolution stages and one linear layer; whether the downloaded
    wav2vec2 checkpoint is meant to be loaded into this module cannot be
    told from this code -- confirm before calling ``load_state_dict``.

    Args:
        num_classes: Size of the output (logit) dimension.
        input_length: Optional length of the input sequences. When given,
            the linear layer is sized for inputs of exactly this length.
            When omitted, the historical fixed size ``32 * 39 * 40`` is
            used, which matches inputs whose time axis reduces to 1560
            frames (e.g. a length of 24960).

    Raises:
        ValueError: If ``input_length`` is given but too short to leave at
            least one frame after the conv/pool stack.
    """

    def __init__(self, num_classes, input_length=None):
        super().__init__()
        self.conv1 = nn.Conv1d(1, 16, kernel_size=5, stride=2, padding=2)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv1d(16, 32, kernel_size=3, stride=2, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)

        if input_length is None:
            # Backward-compatible default: the original hard-coded size.
            flat_features = 32 * 39 * 40
        else:
            frames = self._output_frames(input_length)
            if frames < 1:
                raise ValueError(
                    f"input_length={input_length} is too short for the "
                    "conv/pool stack"
                )
            flat_features = 32 * frames  # 32 = conv2 output channels
        self.fc = nn.Linear(flat_features, num_classes)

    @staticmethod
    def _output_frames(input_length):
        """Return the time-axis length left after both conv/pool stages."""
        frames = (input_length + 2 * 2 - 5) // 2 + 1  # conv1: k=5, s=2, p=2
        frames = (frames - 2) // 2 + 1                # pool1: k=2, s=2
        frames = (frames + 2 * 1 - 3) // 2 + 1        # conv2: k=3, s=2, p=1
        frames = (frames - 2) // 2 + 1                # pool2: k=2, s=2
        return frames

    def forward(self, x):
        """Compute logits for a batch of sequences.

        Args:
            x: Tensor of shape ``(batch, 1, length)``; ``conv1`` expects a
                single input channel, and the flattened feature count must
                match the linear layer built in ``__init__``.

        Returns:
            Tensor of shape ``(batch, num_classes)``.
        """
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        # Keep the batch axis and merge channels and frames for the linear layer.
        x = torch.flatten(x, start_dim=1)
        logits = self.fc(x)
        return logits