"""
Classification Heads — Pooling strategies and MLP classifier for deepfake detection.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
class AttentiveStatsPooling(nn.Module):
"""
Attentive Statistics Pooling.
Learns which frames are most important, then computes weighted mean + std.
    Used in ECAPA-TDNN and other state-of-the-art speaker verification systems.
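
    Example (a minimal sketch; shapes and mask layout are illustrative):
        >>> pool = AttentiveStatsPooling(hidden_size=768)
        >>> x = torch.randn(2, 50, 768)
        >>> mask = torch.ones(2, 50, dtype=torch.bool)
        >>> mask[1, 30:] = False  # second utterance is padded past frame 30
        >>> pool(x, mask).shape
        torch.Size([2, 1536])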
"""
def __init__(self, hidden_size: int, attention_dim: int = 128):
super().__init__()
self.attention = nn.Sequential(
nn.Linear(hidden_size, attention_dim),
nn.Tanh(),
nn.Linear(attention_dim, 1),
)
self.output_size = hidden_size * 2 # mean + std concatenated
    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: (batch, time, hidden_size)
            mask: optional (batch, time) boolean mask, True for valid frames
        Returns:
            (batch, hidden_size * 2), the weighted mean and std concatenated
        """
# Compute attention weights
attn_weights = self.attention(x).squeeze(-1) # (batch, time)
if mask is not None:
attn_weights = attn_weights.masked_fill(~mask, float("-inf"))
attn_weights = F.softmax(attn_weights, dim=-1).unsqueeze(-1) # (batch, time, 1)
# Weighted mean
mean = torch.sum(x * attn_weights, dim=1) # (batch, hidden)
# Weighted std
var = torch.sum(attn_weights * (x - mean.unsqueeze(1)) ** 2, dim=1)
std = torch.sqrt(var.clamp(min=1e-6))
return torch.cat([mean, std], dim=-1) # (batch, hidden*2)
class MultiHeadAttentionPooling(nn.Module):
    """
    Multi-Head Attention Pooling.
    A learned query vector cross-attends to the frame sequence via multi-head
    attention; the attended output serves as the pooled utterance embedding.
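
    Example (a minimal sketch; shapes are illustrative):
        >>> pool = MultiHeadAttentionPooling(hidden_size=768, num_heads=4)
        >>> x = torch.randn(2, 50, 768)
        >>> pool(x).shape
        torch.Size([2, 768])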
"""
def __init__(self, hidden_size: int, num_heads: int = 4):
super().__init__()
self.num_heads = num_heads
self.query = nn.Parameter(torch.randn(1, 1, hidden_size))
self.mha = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
self.output_size = hidden_size
nn.init.xavier_uniform_(self.query)
    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: (batch, time, hidden_size)
            mask: optional (batch, time) boolean mask, True for valid frames
        Returns:
            (batch, hidden_size)
        """
        batch_size = x.size(0)
        query = self.query.expand(batch_size, -1, -1)  # (batch, 1, hidden)
        # nn.MultiheadAttention expects key_padding_mask to be True at IGNORED
        # positions, hence the inversion of the valid-frame mask.
        key_padding_mask = ~mask if mask is not None else None
        out, _ = self.mha(query, x, x, key_padding_mask=key_padding_mask)  # (batch, 1, hidden)
        return out.squeeze(1)  # (batch, hidden)
class MeanPooling(nn.Module):
"""Simple mean pooling over the time axis."""
def __init__(self, hidden_size: int):
super().__init__()
self.output_size = hidden_size
    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        if mask is not None:
            mask = mask.unsqueeze(-1).float()  # (batch, time, 1)
            # Masked mean: zero out padded frames, divide by the valid-frame count
            return (x * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        return x.mean(dim=1)
class DeepfakeClassifier(nn.Module):
"""
Full classification model = Backbone + Pooling + MLP Head.
"""
def __init__(self, backbone: nn.Module, hidden_size: int,
num_labels: int = 2, classifier_hidden: int = 256,
dropout: float = 0.3, pooling_type: str = "attentive_stats"):
super().__init__()
self.backbone = backbone
# Select pooling strategy
if pooling_type == "attentive_stats":
self.pooling = AttentiveStatsPooling(hidden_size)
elif pooling_type == "multi_head":
self.pooling = MultiHeadAttentionPooling(hidden_size)
elif pooling_type == "mean":
self.pooling = MeanPooling(hidden_size)
else:
raise ValueError(f"Unknown pooling: {pooling_type}")
pool_output_size = self.pooling.output_size
# MLP classification head with batch norm
self.classifier = nn.Sequential(
nn.Linear(pool_output_size, classifier_hidden),
nn.BatchNorm1d(classifier_hidden),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(classifier_hidden, classifier_hidden // 2),
nn.BatchNorm1d(classifier_hidden // 2),
nn.ReLU(),
nn.Dropout(dropout / 2),
nn.Linear(classifier_hidden // 2, num_labels),
)
# Initialize weights
self._init_weights()
def _init_weights(self):
for m in self.classifier:
if isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
nn.init.zeros_(m.bias)
    def forward(self, input_values: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            input_values: (batch, samples) raw waveform
            attention_mask: optional (batch, samples) mask over the waveform
        Returns:
            logits: (batch, num_labels)
        """
        # Extract frame-level features from the backbone
        outputs = self.backbone(input_values, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state  # (batch, seq_len, hidden)
        # Pool across time. attention_mask is at waveform resolution, not at the
        # backbone's downsampled feature resolution, so pooling runs unmasked here.
        pooled = self.pooling(hidden_states)  # (batch, pool_dim)
# Classify
logits = self.classifier(pooled) # (batch, num_labels)
return logits
    @torch.no_grad()
    def extract_embeddings(self, input_values: torch.Tensor,
                           attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Extract pooled embeddings (before the classification head) for analysis."""
        outputs = self.backbone(input_values, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state
        return self.pooling(hidden_states)
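

if __name__ == "__main__":
    # Smoke test with a stand-in backbone. The backbone here is an assumption
    # for illustration only; any HF-style encoder that exposes
    # .last_hidden_state (e.g. Wav2Vec2Model) slots in the same way.
    from types import SimpleNamespace

    class _DummyBackbone(nn.Module):
        """Maps a raw waveform to fake frame-level features (320x downsampling)."""

        def __init__(self, hidden_size: int = 768):
            super().__init__()
            self.proj = nn.Linear(1, hidden_size)

        def forward(self, input_values, attention_mask=None):
            # Collapse each 320-sample window to one frame, then project to hidden_size
            frames = input_values.unfold(1, 320, 320).mean(dim=-1)  # (batch, frames)
            return SimpleNamespace(last_hidden_state=self.proj(frames.unsqueeze(-1)))

    model = DeepfakeClassifier(_DummyBackbone(), hidden_size=768,
                               pooling_type="attentive_stats")
    wav = torch.randn(4, 16000)  # four one-second clips at 16 kHz
    logits = model(wav)
    print(logits.shape)  # torch.Size([4, 2])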