import torch
from torch import nn
import torch.nn.functional as F
from transformers import AutoModel


class BCEWithLogitsLossLS(nn.Module):
    """Binary cross-entropy with logits, with optional label smoothing and positive-class weighting."""

    def __init__(self, label_smoothing=0.1, pos_weight=None, reduction='mean'):
        super().__init__()
        assert 0 <= label_smoothing < 1, "label_smoothing value must be between 0 and 1."
        self.label_smoothing = label_smoothing
        self.reduction = reduction
        self.bce_with_logits = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction=reduction)

    def forward(self, input, target):
        if self.label_smoothing > 0:
            # Map hard {0, 1} targets to soft {label_smoothing, 1 - label_smoothing} targets.
            positive_smoothed_labels = 1.0 - self.label_smoothing
            negative_smoothed_labels = self.label_smoothing
            target = target * positive_smoothed_labels + \
                (1 - target) * negative_smoothed_labels

        loss = self.bce_with_logits(input, target)
        return loss


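# Worked example of the smoothing map above (with the default label_smoothing=0.1):
#   target 1 -> 1 * 0.9 + 0 * 0.1 = 0.9
#   target 0 -> 0 * 0.9 + 1 * 0.1 = 0.1
# i.e. hard {0, 1} labels become soft {0.1, 0.9} targets before the weighted BCE term.

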
class WavLMForEndpointing(nn.Module):
    def __init__(self, config, n_trainable_layers=6):
        super().__init__()
        self.wavlm = AutoModel.from_pretrained('microsoft/wavlm-base-plus', config=config)
        self.config = config
        self.n_trainable_layers = n_trainable_layers

        # Freeze the whole WavLM backbone, then unfreeze only the top
        # n_trainable_layers encoder layers for fine-tuning.
        for param in self.wavlm.parameters():
            param.requires_grad = False

        if self.n_trainable_layers > 0:
            for i in range(self.n_trainable_layers):
                for param in self.wavlm.encoder.layers[-(i + 1)].parameters():
                    param.requires_grad = True

        # Scores each frame for attention pooling over the time dimension.
        self.pool_attention = nn.Sequential(
            nn.Linear(config.hidden_size, 256),
            nn.Tanh(),
            nn.Linear(256, 1)
        )

        # Binary endpointing head on top of the pooled utterance embedding.
        self.classifier = nn.Sequential(
            nn.Linear(config.hidden_size, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(256, 64),
            nn.LayerNorm(64),
            nn.GELU(),
            nn.Linear(64, 1)
        )

        # Initialise the new linear layers with small random weights and zero biases.
        for module in self.classifier:
            if isinstance(module, nn.Linear):
                module.weight.data.normal_(mean=0.0, std=0.1)
                if module.bias is not None:
                    module.bias.data.zero_()

        for module in self.pool_attention:
            if isinstance(module, nn.Linear):
                module.weight.data.normal_(mean=0.0, std=0.1)
                if module.bias is not None:
                    module.bias.data.zero_()

    def attention_pool(self, hidden_states, attention_mask):
        # Compute a scalar score per frame: (batch, time, 1).
        attention_weights = self.pool_attention(hidden_states)

        if attention_mask is None:
            raise ValueError("attention_mask must be provided for attention pooling")

        # Mask out padded frames with a large negative value before the softmax.
        attention_weights = attention_weights + (
            (1.0 - attention_mask.unsqueeze(-1).to(attention_weights.dtype)) * -1e9
        )

        # Normalise over the time dimension and take the weighted sum of frames.
        attention_weights = F.softmax(attention_weights, dim=1)
        weighted_sum = torch.sum(hidden_states * attention_weights, dim=1)

        return weighted_sum

    def forward(self, input_values, attention_mask=None, labels=None):
        outputs = self.wavlm(input_values, attention_mask=attention_mask)
        hidden_states = outputs[0]

        if attention_mask is not None:
            # Downsample the sample-level attention mask to the frame rate of the
            # WavLM hidden states by nearest-index resampling.
            input_length = attention_mask.size(1)
            hidden_length = hidden_states.size(1)
            ratio = input_length / hidden_length
            indices = (torch.arange(hidden_length, device=attention_mask.device) * ratio).long()
            attention_mask = attention_mask[:, indices]
            attention_mask = attention_mask.bool()

        pooled = self.attention_pool(hidden_states, attention_mask)
        logits = self.classifier(pooled)

        if torch.isnan(logits).any():
            raise ValueError("NaN values detected in logits")

        if labels is not None:
            # Weight positives by the negative/positive ratio in the batch,
            # clamped to avoid extreme values on unbalanced batches.
            pos_weight = ((labels == 0).sum() / (labels == 1).sum()).clamp(min=0.1, max=10.0)
            loss_fct = BCEWithLogitsLossLS(pos_weight=pos_weight)
            labels = labels.float()
            loss = loss_fct(logits.view(-1), labels.view(-1))

            # Add L2 regularisation over the classifier head only.
            l2_lambda = 0.01
            l2_reg = torch.tensor(0., device=logits.device)
            for param in self.classifier.parameters():
                l2_reg += torch.norm(param)
            loss += l2_lambda * l2_reg

            probs = torch.sigmoid(logits.detach())
            return {"loss": loss, "logits": probs}

        probs = torch.sigmoid(logits)
        return {"logits": probs}