import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch import tensor
from transformers import Wav2Vec2FeatureExtractor, WavLMModel
import transformers.models.wavlm.modeling_wavlm as wavlm
from huggingface_hub import PyTorchModelHubMixin
from speechbrain.lobes.models.huggingface_transformers.huggingface import make_padding_masks


class RevGrad(Function):
    """Gradient reversal: identity on the forward pass, scaled negative gradient on the backward pass."""

    @staticmethod
    def forward(ctx, input_, alpha_):
        ctx.save_for_backward(input_, alpha_)
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        _, alpha_ = ctx.saved_tensors
        grad_input = -grad_output * alpha_ if ctx.needs_input_grad[0] else None
        return grad_input, None


revgrad = RevGrad.apply


class RevGradLayer(nn.Module):
    def __init__(self, alpha=1.0):
        super().__init__()
        self._alpha = tensor(alpha, requires_grad=False)

    def forward(self, x):
        return revgrad(x, self._alpha)


class WavLMEncoderLayer(nn.Module):
    def __init__(self, layer_idx, config, has_relative_position_bias: bool = True):
        super().__init__()
        self.attention = wavlm.WavLMAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            num_buckets=config.num_buckets,
            max_distance=config.max_bucket_distance,
            has_relative_position_bias=has_relative_position_bias,
        )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = wavlm.WavLMFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.config = config

    def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False, index=0):
        attn_residual = hidden_states
        hidden_states, attn_weights, position_bias = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            output_attentions=output_attentions,
            index=index,
        )
        hidden_states = self.dropout(hidden_states)
        hidden_states = attn_residual + hidden_states
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = hidden_states + self.feed_forward(hidden_states)
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states, position_bias)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs


class WavLMEncoderLayerStableLayerNorm(nn.Module):
    def __init__(self, layer_idx, config, has_relative_position_bias: bool = True):
        super().__init__()
        self.attention = wavlm.WavLMAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            num_buckets=config.num_buckets,
            max_distance=config.max_bucket_distance,
            has_relative_position_bias=has_relative_position_bias,
        )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = wavlm.WavLMFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.config = config

    def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False):
        attn_residual = hidden_states
        hidden_states = self.layer_norm(hidden_states)
        hidden_states, attn_weights, position_bias = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            output_attentions=output_attentions,
        )
        hidden_states = self.dropout(hidden_states)
        hidden_states = attn_residual + hidden_states
        hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))

        outputs = (hidden_states, position_bias)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs


class WavLMWrapper(nn.Module, PyTorchModelHubMixin):
    def __init__(
        self,
        pretrain_model="wavlm_large",
        hidden_dim=256,
        freeze_params=True,
        output_class_num=4,
        use_conv_output=True,
        apply_reg=False,
    ):
        super().__init__()
        self.pretrain_model = pretrain_model
        self.use_conv_output = use_conv_output

        # Load backbone
        if self.pretrain_model == "wavlm":
            self.backbone_model = WavLMModel.from_pretrained(
                "microsoft/wavlm-base-plus",
                output_hidden_states=True,
            )
        elif self.pretrain_model == "wavlm_large":
            self.processor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/wavlm-large")
            self.backbone_model = WavLMModel.from_pretrained(
                "microsoft/wavlm-large",
                output_hidden_states=True,
            )

        # Swap in the plain encoder layers defined above (no LoRA), then restore the pretrained weights
        state_dict = self.backbone_model.state_dict()
        self.model_config = self.backbone_model.config
        if self.pretrain_model == "wavlm":
            self.backbone_model.encoder.layers = nn.ModuleList(
                [WavLMEncoderLayer(i, self.model_config, has_relative_position_bias=(i == 0))
                 for i in range(self.model_config.num_hidden_layers)]
            )
        else:
            self.backbone_model.encoder.layers = nn.ModuleList(
                [WavLMEncoderLayerStableLayerNorm(i, self.model_config, has_relative_position_bias=(i == 0))
                 for i in range(self.model_config.num_hidden_layers)]
            )
        self.backbone_model.load_state_dict(state_dict, strict=False)

        # Freeze backbone weights if requested
        if freeze_params:
            for p in self.backbone_model.parameters():
                p.requires_grad = False

        # Conv projection layers
        self.model_seq = nn.Sequential(
            nn.Conv1d(self.model_config.hidden_size, hidden_dim, 1),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv1d(hidden_dim, hidden_dim, 1),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv1d(hidden_dim, hidden_dim, 1),
        )

        # Learnable weights for the weighted sum over hidden layers
        # (+1 when the convolutional feature-extractor output is included)
        num_layers = (
            self.model_config.num_hidden_layers + 1
            if use_conv_output
            else self.model_config.num_hidden_layers
        )
        self.weights = nn.Parameter(torch.ones(num_layers) / num_layers)

        # Output heads: age as a sigmoid regression in [0, 1] when apply_reg is set,
        # otherwise as 7 age-group logits; sex always as 2 logits
        if apply_reg:
            self.age_dist_layer = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, 1),
                nn.Sigmoid(),
            )
        else:
            self.age_dist_layer = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, 7),
            )
        self.sex_layer = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),
        )

    def forward(self, x, length=None, return_feature=False, pred="age_dist_sex"):
        # Feature extraction
        if self.pretrain_model == "wavlm_large":
            with torch.no_grad():
                signal, attention_mask = [], []
                if length is not None:
                    attention_mask = make_padding_masks(x, wav_len=length / length.max()).to(x.device)
                else:
                    attention_mask = make_padding_masks(x, wav_len=torch.tensor([1]).to(x.device)).to(x.device)
                for idx in range(len(x)):
                    # The feature extractor expects CPU/NumPy input
                    input_vals = self.processor(
                        x[idx].detach().cpu().numpy(),
                        sampling_rate=16_000,
                        return_tensors="pt",
                        padding=True,
                    )
                    signal.append(input_vals["input_values"][0].to(x.device))
                signal = torch.stack(signal)

        # Convert waveform lengths to frame lengths at the encoder's frame rate
        if length is not None:
            length = self.get_feat_extract_output_lengths(length.detach().cpu()).to(x.device)

        if self.pretrain_model == "wavlm":
            x = self.backbone_model(x, output_hidden_states=True).hidden_states
        else:
            x = self.backbone_model(signal, attention_mask=attention_mask, output_hidden_states=True).hidden_states

        # Weighted sum of layers
        stacked_feature = torch.stack(x, dim=0) if self.use_conv_output else torch.stack(x, dim=0)[1:]
        _, *origin_shape = stacked_feature.shape
        stacked_feature = stacked_feature.view(stacked_feature.shape[0], -1)
        norm_weights = F.softmax(self.weights, dim=-1)
        weighted_feature = (norm_weights.unsqueeze(-1) * stacked_feature).sum(dim=0)
        features = weighted_feature.view(*origin_shape)

        # Conv projection
        features = self.model_seq(features.transpose(1, 2)).transpose(1, 2)

        # Pooling: average over the valid frames of each utterance
        if length is not None:
            mean = []
            for snt_id in range(features.shape[0]):
                actual_size = length[snt_id]
                mean.append(torch.mean(features[snt_id, 0:actual_size, ...], dim=0))
            features = torch.stack(mean)
        else:
            features = torch.mean(features, dim=1)

        # Predictions
        age_pred = self.age_dist_layer(features)
        sex_pred = self.sex_layer(features)

        if return_feature:
            return age_pred, sex_pred, features
        return age_pred, sex_pred

    # Hugging Face conv output length helper
    def get_feat_extract_output_lengths(self, input_length):
        def _conv_out_length(input_length, kernel_size, stride):
            return (input_length - kernel_size) // stride + 1

        for kernel_size, stride in zip(self.backbone_model.config.conv_kernel, self.backbone_model.config.conv_stride):
            input_length = _conv_out_length(input_length, kernel_size, stride)
        return input_length


def age_gender(audio_waveform_np, model, device):
    # Convert NumPy input to a tensor; accept tensors as-is
    if isinstance(audio_waveform_np, np.ndarray):
        waveform = torch.from_numpy(audio_waveform_np)
    elif isinstance(audio_waveform_np, torch.Tensor):
        waveform = audio_waveform_np
    else:
        raise TypeError(f"Expected np.ndarray or torch.Tensor, got {type(audio_waveform_np)}")

    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)
    waveform = waveform.to(torch.device(device))
    if waveform.dtype not in (torch.float32, torch.float16):
        waveform = waveform.float()

    with torch.no_grad():
        wavlm_outputs, wavlm_sex_outputs = model(waveform)

    # Assumes a regression age head (apply_reg=True): sigmoid output in [0, 1] scaled to years
    age_pred = wavlm_outputs.detach().cpu().numpy().flatten() * 100.0
    sex_prob = F.softmax(wavlm_sex_outputs, dim=1)
    sex_labels_es = ["Femenino", "Masculino"]  # Spanish labels: female, male
    sex_idx = int(torch.argmax(sex_prob).detach().cpu().item())
    sex_pred = sex_labels_es[sex_idx]

    # Spanish age-group labels: young (<20), adult (20-35), middle-aged (35-60), older (60+), unknown
    try:
        age_value = int(round(float(age_pred[0])))
        if age_value < 20:
            age_group = "joven (menor de 20)"
        elif age_value < 35:
            age_group = "adulto (20–35)"
        elif age_value < 60:
            age_group = "mediana edad (35–60)"
        else:
            age_group = "mayor (60+)"
    except Exception:
        age_value = None
        age_group = "desconocido"

    return str(age_value) if age_value is not None else "N/A", sex_pred, age_group
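

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): shows how WavLMWrapper and age_gender()
# fit together. The repo id and audio path below are placeholders, not real
# resources; torchaudio is assumed to be installed, and the checkpoint is
# assumed to have been trained with apply_reg=True so that age_gender() can
# scale the sigmoid age output to years.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torchaudio  # assumed available for loading/resampling audio

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Placeholder repo id; WavLMWrapper inherits from_pretrained() from
    # PyTorchModelHubMixin, so fine-tuned weights published with the mixin
    # can be loaded directly. Alternatively, build the wrapper locally with
    # WavLMWrapper(pretrain_model="wavlm_large", apply_reg=True), keeping in
    # mind that the prediction heads would then be untrained.
    model = WavLMWrapper.from_pretrained("your-namespace/wavlm-age-sex")
    model = model.to(device).eval()

    # Any 16 kHz mono waveform works; resample if needed.
    waveform, sr = torchaudio.load("speech.wav")  # placeholder path
    if sr != 16_000:
        waveform = torchaudio.functional.resample(waveform, sr, 16_000)

    age, sex, age_group = age_gender(waveform.squeeze(0).numpy(), model, device)
    print(f"age={age}, sex={sex}, group={age_group}")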