|
import torch |
|
from torch import nn |
|
|
|
|
|
|
|
class AudioClassifier(nn.Module):
    """MLP classification head over fixed-size feature vectors.

    Layout: an input block ``Linear -> BatchNorm1d -> Mish -> Dropout`` mapping
    ``feature_dim`` to ``hidden_dim``, then ``num_hidden_layers`` blocks of the
    same shape at ``hidden_dim``, then a final ``Linear`` producing one logit
    per label.

    Args:
        label2id: Mapping from label name to output index. The values are
            assumed to be exactly ``0 .. len(label2id) - 1``. The output layer
            has ``len(label2id)`` units, and predicted indices are mapped back
            through the inverted dict.
        feature_dim: Length of each input feature vector.
        hidden_dim: Width of every hidden block.
        device: Device handed to ``extract_features`` by ``infer_from_file``.
            The module's own weights are NOT moved here; use ``.to(...)``.
        dropout_rate: Dropout probability used in every block.
        num_hidden_layers: Number of hidden blocks after the input block.
    """

    def __init__(
        self,
        label2id: dict,
        feature_dim=256,
        hidden_dim=256,
        device="cpu",
        dropout_rate=0.5,
        num_hidden_layers=2,
    ):
        super().__init__()
        self.num_classes = len(label2id)
        self.device = device
        self.label2id = label2id
        # Inverse mapping used to turn predicted output indices back into names.
        self.id2label = {v: k for k, v in self.label2id.items()}

        # Attribute names and Sequential ordering are kept exactly as before so
        # existing state_dict checkpoints (e.g. "fc1.0.weight",
        # "hidden_layers.0.1.running_mean", "fc_last.weight") still load.
        self.fc1 = self._make_block(feature_dim, hidden_dim, dropout_rate)
        self.hidden_layers = nn.ModuleList(
            self._make_block(hidden_dim, hidden_dim, dropout_rate)
            for _ in range(num_hidden_layers)
        )
        self.fc_last = nn.Linear(hidden_dim, self.num_classes)

    @staticmethod
    def _make_block(in_dim, out_dim, dropout_rate):
        """Build one ``Linear -> BatchNorm1d -> Mish -> Dropout`` block."""
        return nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.BatchNorm1d(out_dim),
            nn.Mish(),
            nn.Dropout(dropout_rate),
        )

    def forward(self, x):
        """Return unnormalised logits, one column per class.

        Assumes ``x`` is ``(batch, feature_dim)``. BatchNorm1d in training mode
        needs batch > 1.
        """
        x = self.fc1(x)
        for layer in self.hidden_layers:
            x = layer(x)
        return self.fc_last(x)

    def infer_from_features(self, features):
        """Classify one un-batched feature vector.

        Args:
            features: A single feature vector of length ``feature_dim``
                (list, NumPy array or tensor).

        Returns:
            A list of ``(label, probability)`` pairs covering every class,
            sorted by descending probability.

        The forward pass runs in eval mode (no dropout, BatchNorm running
        statistics) without gradients. The module's previous train/eval mode
        is restored afterwards.
        """
        # Send the input to wherever the weights actually are, so this also
        # works when the module was moved with .to() independently of
        # self.device.
        param_device = next(self.parameters()).device
        # as_tensor avoids the copy-construct warning when given a tensor.
        batch = torch.as_tensor(
            features, dtype=torch.float32, device=param_device
        ).unsqueeze(0)

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                logits = self(batch)  # __call__ so registered hooks run
        finally:
            self.train(was_training)

        # Take row 0 explicitly. squeeze() would collapse a single-class
        # result to a 0-d array that zip() cannot iterate.
        probs = torch.softmax(logits, dim=-1)[0]
        sorted_probs, sorted_ids = torch.sort(probs, descending=True)
        sorted_probs = sorted_probs.cpu().numpy()
        sorted_ids = sorted_ids.cpu().numpy()
        return [
            (self.id2label[int(i)], p) for i, p in zip(sorted_ids, sorted_probs)
        ]

    def infer_from_file(self, file_path):
        """Extract features from ``file_path`` on ``self.device`` and classify them.

        Relies on the module-level ``extract_features`` helper. Returns the
        same sorted ``(label, probability)`` list as ``infer_from_features``.
        """
        feature = extract_features(file_path, device=self.device)
        return self.infer_from_features(feature)
|
|
|
|
|
from pyannote.audio import Inference, Model |
|
|
|
# Pretrained speaker-embedding checkpoint used for feature extraction.
# NOTE(review): this is loaded at import time, so importing the module
# fetches or loads the model.
_EMBEDDING_MODEL_ID = "pyannote/wespeaker-voxceleb-resnet34-LM"
emb_model = Model.from_pretrained(_EMBEDDING_MODEL_ID)
# window="whole" asks pyannote for one embedding per input file rather than
# a sliding-window sequence.
inference = Inference(emb_model, window="whole")
|
|
|
|
|
def extract_features(file_path, device="cpu"):
    """Return the whole-file embedding that ``inference`` computes for ``file_path``.

    Side effect: moves the shared module-level ``inference`` pipeline to
    ``device`` before running it, which also affects later calls.
    """
    # Normalise string/device input to a torch.device. pyannote's
    # Inference.to appears to require an actual torch.device instance.
    target_device = torch.device(device)
    inference.to(target_device)
    embedding = inference(file_path)
    return embedding
|
|