# audio_detect/model.py
import torch
from torch import nn
import torch.nn.functional as F
from transformers import AutoModel


class BCEWithLogitsLossLS(nn.Module):
    """Binary cross-entropy with logits, with optional label smoothing."""

    def __init__(self, label_smoothing=0.1, pos_weight=None, reduction='mean'):
        super().__init__()
        assert 0 <= label_smoothing < 1, "label_smoothing must be in [0, 1)."
        self.label_smoothing = label_smoothing
        self.reduction = reduction
        self.bce_with_logits = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction=reduction)

    def forward(self, input, target):
        if self.label_smoothing > 0:
            # Smooth hard targets: y -> y * (1 - eps) + (1 - y) * eps,
            # i.e. 1 becomes 1 - eps and 0 becomes eps.
            positive_smoothed_labels = 1.0 - self.label_smoothing
            negative_smoothed_labels = self.label_smoothing
            target = target * positive_smoothed_labels + \
                (1 - target) * negative_smoothed_labels
        loss = self.bce_with_logits(input, target)
        return loss
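

# Quick usage sketch (illustrative only, not part of the original file):
# the loss expects raw logits and float targets; smoothing nudges the
# hard 0/1 targets toward eps / 1 - eps before the underlying BCE call.
#
#     criterion = BCEWithLogitsLossLS(label_smoothing=0.1)
#     logits = torch.randn(4)                    # raw scores, pre-sigmoid
#     targets = torch.tensor([1., 0., 1., 0.])
#     loss = criterion(logits, targets)
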
class WavLMForEndpointing(nn.Module):
    """WavLM encoder with attention pooling and a binary endpointing head."""

    def __init__(self, config, n_trainable_layers=6):
        super().__init__()
        self.wavlm = AutoModel.from_pretrained('microsoft/wavlm-base-plus', config=config)
        self.config = config
        self.n_trainable_layers = n_trainable_layers

        # Freeze the whole backbone, then unfreeze the last n encoder layers.
        for param in self.wavlm.parameters():
            param.requires_grad = False
        if self.n_trainable_layers > 0:
            for i in range(self.n_trainable_layers):
                for param in self.wavlm.encoder.layers[-(i + 1)].parameters():
                    param.requires_grad = True

        # Scores each frame for attention-weighted pooling.
        self.pool_attention = nn.Sequential(
            nn.Linear(config.hidden_size, 256),
            nn.Tanh(),
            nn.Linear(256, 1)
        )
        # Binary classification head over the pooled utterance embedding.
        self.classifier = nn.Sequential(
            nn.Linear(config.hidden_size, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(256, 64),
            nn.LayerNorm(64),
            nn.GELU(),
            nn.Linear(64, 1)
        )
        # Re-initialize the linear layers of both heads.
        for module in self.classifier:
            if isinstance(module, nn.Linear):
                module.weight.data.normal_(mean=0.0, std=0.1)
                if module.bias is not None:
                    module.bias.data.zero_()
        for module in self.pool_attention:
            if isinstance(module, nn.Linear):
                module.weight.data.normal_(mean=0.0, std=0.1)
                if module.bias is not None:
                    module.bias.data.zero_()
    def attention_pool(self, hidden_states, attention_mask):
        """Collapse the (batch, frames, hidden) sequence into one vector per clip."""
        attention_weights = self.pool_attention(hidden_states)

        if attention_mask is None:
            raise ValueError("attention_mask must be provided for attention pooling")

        # Mask padded frames with a large negative score before the softmax.
        attention_weights = attention_weights + (
            (1.0 - attention_mask.unsqueeze(-1).to(attention_weights.dtype)) * -1e9
        )
        attention_weights = F.softmax(attention_weights, dim=1)

        # Apply attention weights to the hidden states along the time axis.
        weighted_sum = torch.sum(hidden_states * attention_weights, dim=1)
        return weighted_sum
    def forward(self, input_values, attention_mask=None, labels=None):
        outputs = self.wavlm(input_values, attention_mask=attention_mask)
        hidden_states = outputs[0]

        if attention_mask is not None:
            # The CNN feature extractor downsamples the waveform, so resample
            # the sample-level mask to the frame rate of the hidden states.
            input_length = attention_mask.size(1)
            hidden_length = hidden_states.size(1)
            ratio = input_length / hidden_length
            indices = (torch.arange(hidden_length, device=attention_mask.device) * ratio).long()
            attention_mask = attention_mask[:, indices]
            attention_mask = attention_mask.bool()

        pooled = self.attention_pool(hidden_states, attention_mask)
        logits = self.classifier(pooled)

        if torch.isnan(logits).any():
            raise ValueError("NaN values detected in logits")
        if labels is not None:
            labels = labels.float()
            # Weight the positive class by the negative/positive ratio,
            # clamped to keep the weight in a sane range.
            pos_weight = ((labels == 0).sum() / (labels == 1).sum()).clamp(min=0.1, max=10.0)
            loss_fct = BCEWithLogitsLossLS(pos_weight=pos_weight)
            loss = loss_fct(logits.view(-1), labels.view(-1))

            # L2 penalty on the classifier head only (the backbone is mostly frozen).
            l2_lambda = 0.01
            l2_reg = torch.tensor(0., device=logits.device)
            for param in self.classifier.parameters():
                l2_reg += torch.norm(param)
            loss += l2_lambda * l2_reg

            # Note: the "logits" entry actually holds sigmoid probabilities.
            probs = torch.sigmoid(logits.detach())
            return {"loss": loss, "logits": probs}

        probs = torch.sigmoid(logits)
        return {"logits": probs}
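

if __name__ == "__main__":
    # Minimal smoke test, a sketch rather than part of the original training
    # code. It assumes network access to fetch the WavLM base-plus config and
    # weights, and feeds one second of 16 kHz audio per batch item.
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("microsoft/wavlm-base-plus")
    model = WavLMForEndpointing(config).eval()

    waveform = torch.randn(2, 16000)               # (batch, samples)
    mask = torch.ones(2, 16000, dtype=torch.long)  # all samples valid

    with torch.no_grad():
        out = model(waveform, attention_mask=mask)
    print(out["logits"].shape)  # torch.Size([2, 1]) -- per-clip probabilities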