#!/usr/bin/python
# -*- coding: utf-8 -*-
"""
Transfer Learning Implementation for SyncNet

This module provides pre-trained backbone integration for improved performance.

Implemented backbones:
- Video: 3D ResNet variants (Kinetics-400, via torchvision), I3D
- Audio: VGGish (AudioSet), wav2vec 2.0

Author: Enhanced version
Date: 2025-11-22
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
# ==================== VIDEO BACKBONES ====================

class ResNet3D_Backbone(nn.Module):
    """
    3D ResNet backbone pre-trained on Kinetics-400.
    Uses torchvision's video models.
    """

    def __init__(self, embedding_dim=512, pretrained=True, model_type='r3d_18'):
        super(ResNet3D_Backbone, self).__init__()
        try:
            import torchvision.models.video as video_models

            # Load pre-trained model
            if model_type == 'r3d_18':
                backbone = video_models.r3d_18(pretrained=pretrained)
            elif model_type == 'mc3_18':
                backbone = video_models.mc3_18(pretrained=pretrained)
            elif model_type == 'r2plus1d_18':
                backbone = video_models.r2plus1d_18(pretrained=pretrained)
            else:
                raise ValueError(f"Unknown model type: {model_type}")

            # Remove final FC and pooling layers
            self.features = nn.Sequential(*list(backbone.children())[:-2])

            # Add custom head
            self.conv_head = nn.Sequential(
                nn.Conv3d(512, embedding_dim, kernel_size=1),
                nn.BatchNorm3d(embedding_dim),
                nn.ReLU(inplace=True),
            )
            print(f"Loaded {model_type} with pretrained={pretrained}")
        except ImportError:
            print("Warning: torchvision not found. Using random initialization.")
            self.features = self._build_simple_3dcnn()
            self.conv_head = nn.Conv3d(512, embedding_dim, 1)

    def _build_simple_3dcnn(self):
        """Fallback if torchvision is not available."""
        return nn.Sequential(
            nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3)),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)),
            nn.Conv3d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm3d(128),
            nn.ReLU(inplace=True),
            nn.Conv3d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm3d(256),
            nn.ReLU(inplace=True),
            nn.Conv3d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm3d(512),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """
        Args:
            x: [B, 3, T, H, W]
        Returns:
            features: [B, C, T', H', W']
        """
        x = self.features(x)
        x = self.conv_head(x)
        return x
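
# Usage sketch (illustrative, not executed on import): the backbone maps a clip of RGB
# frames to a spatio-temporal feature volume; the exact T'/H'/W' depend on the variant.
#
#   enc = ResNet3D_Backbone(embedding_dim=512, pretrained=False, model_type='r3d_18')
#   feats = enc(torch.randn(1, 3, 25, 112, 112))   # -> [1, 512, T', H', W']
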
class I3D_Backbone(nn.Module):
    """
    Inflated 3D ConvNet (I3D) backbone.
    Requires an external I3D implementation.
    """

    def __init__(self, embedding_dim=512, pretrained=True):
        super(I3D_Backbone, self).__init__()
        try:
            # Try to import I3D (needs to be installed separately)
            from i3d import InceptionI3d

            self.i3d = InceptionI3d(400, in_channels=3)
            if pretrained:
                # Load pre-trained weights
                state_dict = torch.load('models/rgb_imagenet.pt', map_location='cpu')
                self.i3d.load_state_dict(state_dict)
                print("Loaded I3D with ImageNet+Kinetics pre-training")

            # Adaptation layer
            self.adapt = nn.Conv3d(1024, embedding_dim, kernel_size=1)
        except Exception:
            print("Warning: I3D not available. Install from: https://github.com/piergiaj/pytorch-i3d")
            # Fallback to simple 3D CNN
            self.i3d = self._build_fallback()
            self.adapt = nn.Conv3d(512, embedding_dim, 1)

    def _build_fallback(self):
        return nn.Sequential(
            nn.Conv3d(3, 64, kernel_size=(5, 7, 7), stride=(1, 2, 2), padding=(2, 3, 3)),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True),
            nn.Conv3d(64, 512, kernel_size=3, padding=1),
            nn.BatchNorm3d(512),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # Use I3D's feature extractor when available, otherwise the fallback CNN
        features = self.i3d.extract_features(x) if hasattr(self.i3d, 'extract_features') else self.i3d(x)
        features = self.adapt(features)
        return features
# ==================== AUDIO BACKBONES ====================

class VGGish_Backbone(nn.Module):
    """
    VGGish audio encoder pre-trained on AudioSet.
    Processes log-mel spectrograms.
    """

    def __init__(self, embedding_dim=512, pretrained=True):
        super(VGGish_Backbone, self).__init__()
        try:
            import torchvggish

            # Load VGGish
            self.vggish = torchvggish.vggish()
            if pretrained:
                # Download and load pre-trained weights
                self.vggish.load_state_dict(
                    torch.hub.load_state_dict_from_url(
                        'https://github.com/harritaylor/torchvggish/releases/download/v0.1/vggish-10086976.pth',
                        map_location='cpu'
                    )
                )
                print("Loaded VGGish pre-trained on AudioSet")

            # Use the convolutional part only
            self.features = self.vggish.features

            # Adaptation layer
            self.adapt = nn.Sequential(
                nn.Conv2d(512, embedding_dim, kernel_size=1),
                nn.BatchNorm2d(embedding_dim),
                nn.ReLU(inplace=True),
            )
        except ImportError:
            print("Warning: torchvggish not found. Install: pip install torchvggish")
            self.features = self._build_fallback()
            self.adapt = nn.Conv2d(512, embedding_dim, 1)

    def _build_fallback(self):
        """Simple audio CNN if VGGish is unavailable."""
        return nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """
        Args:
            x: [B, 1, F, T] log-mel spectrogram (e.g. [B, 1, 96, T])
        Returns:
            features: [B, C, F', T']
        """
        x = self.features(x)
        x = self.adapt(x)
        return x
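
# Input preparation sketch (illustrative, assumes torchaudio is installed; the mel
# settings below are example values, not necessarily those VGGish was trained with):
#
#   import torchaudio
#   mel = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=400,
#                                              hop_length=160, n_mels=96)
#   to_db = torchaudio.transforms.AmplitudeToDB()
#   spec = to_db(mel(waveform))        # [1, 96, T] for a mono [1, N] waveform
#   spec = spec.unsqueeze(0)           # -> [B=1, 1, 96, T] as expected above
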
class Wav2Vec_Backbone(nn.Module):
    """
    wav2vec 2.0 backbone for speech representation.
    Processes raw waveforms.
    """

    def __init__(self, embedding_dim=512, pretrained=True, model_name='facebook/wav2vec2-base'):
        super(Wav2Vec_Backbone, self).__init__()
        try:
            from transformers import Wav2Vec2Model

            if pretrained:
                self.wav2vec = Wav2Vec2Model.from_pretrained(model_name)
                print(f"Loaded {model_name} from HuggingFace")
            else:
                from transformers import Wav2Vec2Config
                config = Wav2Vec2Config()
                self.wav2vec = Wav2Vec2Model(config)

            # Freeze early layers for fine-tuning
            self._freeze_layers(num_layers_to_freeze=6)

            # Adaptation layer
            wav2vec_dim = self.wav2vec.config.hidden_size
            self.adapt = nn.Sequential(
                nn.Linear(wav2vec_dim, embedding_dim),
                nn.LayerNorm(embedding_dim),
                nn.ReLU(),
            )
        except ImportError:
            print("Warning: transformers not found. Install: pip install transformers")
            raise

    def _freeze_layers(self, num_layers_to_freeze):
        """Freeze the CNN feature extractor and the first few transformer layers."""
        for param in self.wav2vec.feature_extractor.parameters():
            param.requires_grad = False
        for i, layer in enumerate(self.wav2vec.encoder.layers):
            if i < num_layers_to_freeze:
                for param in layer.parameters():
                    param.requires_grad = False

    def forward(self, waveform):
        """
        Args:
            waveform: [B, T] - raw audio waveform (16 kHz)
        Returns:
            features: [B, C, T'] - temporal features
        """
        # Extract features from wav2vec
        outputs = self.wav2vec(waveform, output_hidden_states=True)
        features = outputs.last_hidden_state  # [B, T', D]

        # Adapt to the target dimension
        features = self.adapt(features)       # [B, T', embedding_dim]

        # Reshape to [B, C, T']
        features = features.transpose(1, 2)
        return features
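
# Usage sketch (illustrative): wav2vec 2.0 consumes 16 kHz mono waveforms and emits
# roughly one feature vector every 20 ms, so one second of audio gives ~50 steps.
#
#   enc = Wav2Vec_Backbone(embedding_dim=512, pretrained=False)
#   feats = enc(torch.randn(2, 16000))   # -> [2, 512, T'] with T' on the order of 50
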
# ==================== INTEGRATED SYNCNET WITH TRANSFER LEARNING ====================

class SyncNet_TransferLearning(nn.Module):
    """
    SyncNet with transfer learning from pre-trained backbones.

    Args:
        video_backbone: 'resnet3d', 'i3d', or 'simple'
        audio_backbone: 'vggish', 'wav2vec', or 'simple'
        embedding_dim: Dimension of the shared embedding space
        max_offset: Maximum temporal offset to consider
        freeze_backbone: Whether to freeze backbone weights
    """

    def __init__(self,
                 video_backbone='resnet3d',
                 audio_backbone='vggish',
                 embedding_dim=512,
                 max_offset=15,
                 freeze_backbone=False):
        super(SyncNet_TransferLearning, self).__init__()
        self.embedding_dim = embedding_dim
        self.max_offset = max_offset

        # Initialize video encoder
        if video_backbone == 'resnet3d':
            self.video_encoder = ResNet3D_Backbone(embedding_dim, pretrained=True)
        elif video_backbone == 'i3d':
            self.video_encoder = I3D_Backbone(embedding_dim, pretrained=True)
        else:
            from SyncNetModel_FCN import FCN_VideoEncoder
            self.video_encoder = FCN_VideoEncoder(embedding_dim)

        # Initialize audio encoder
        if audio_backbone == 'vggish':
            self.audio_encoder = VGGish_Backbone(embedding_dim, pretrained=True)
        elif audio_backbone == 'wav2vec':
            self.audio_encoder = Wav2Vec_Backbone(embedding_dim, pretrained=True)
        else:
            from SyncNetModel_FCN import FCN_AudioEncoder
            self.audio_encoder = FCN_AudioEncoder(embedding_dim)

        # Freeze backbones if requested
        if freeze_backbone:
            self._freeze_backbones()

        # Temporal pooling to handle variable spatial/frequency dimensions
        self.video_temporal_pool = nn.AdaptiveAvgPool3d((None, 1, 1))
        self.audio_temporal_pool = nn.AdaptiveAvgPool2d((1, None))

        # Correlation and sync prediction (from the FCN model)
        from SyncNetModel_FCN import TemporalCorrelation
        self.correlation = TemporalCorrelation(max_displacement=max_offset)

        self.sync_predictor = nn.Sequential(
            nn.Conv1d(2 * max_offset + 1, 128, kernel_size=3, padding=1),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Conv1d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            nn.Conv1d(64, 2 * max_offset + 1, kernel_size=1),
        )

    def _freeze_backbones(self):
        """Freeze backbone parameters so that only the head is trained."""
        for param in self.video_encoder.parameters():
            param.requires_grad = False
        for param in self.audio_encoder.parameters():
            param.requires_grad = False
        print("Backbones frozen. Only training sync predictor.")

    def forward_video(self, video):
        """
        Extract video features.

        Args:
            video: [B, 3, T, H, W]
        Returns:
            features: [B, C, T']
        """
        features = self.video_encoder(video)           # [B, C, T', H', W']
        features = self.video_temporal_pool(features)  # [B, C, T', 1, 1]
        B, C, T, _, _ = features.shape
        features = features.view(B, C, T)              # [B, C, T']
        return features

    def forward_audio(self, audio):
        """
        Extract audio features.

        Args:
            audio: [B, 1, F, T] spectrogram, or [B, T] raw waveform for wav2vec
        Returns:
            features: [B, C, T']
        """
        if isinstance(self.audio_encoder, Wav2Vec_Backbone):
            # wav2vec expects a raw waveform of shape [B, T]
            if audio.dim() == 4:
                # Converting a spectrogram back to a waveform is not supported
                raise NotImplementedError("Need raw waveform for wav2vec")
            features = self.audio_encoder(audio)
        else:
            features = self.audio_encoder(audio)           # [B, C, F', T']
            features = self.audio_temporal_pool(features)  # [B, C, 1, T']
            B, C, _, T = features.shape
            features = features.view(B, C, T)              # [B, C, T']
        return features

    def forward(self, audio, video):
        """
        Full forward pass with sync prediction.

        Args:
            audio: [B, 1, F, T] - audio features
            video: [B, 3, T', H, W] - video frames
        Returns:
            sync_probs: [B, 2K+1, T''] - sync probabilities
            audio_features: [B, C, T_a]
            video_features: [B, C, T_v]
        """
        # Extract features
        audio_features = self.forward_audio(audio)
        video_features = self.forward_video(video)

        # Align temporal dimensions
        min_time = min(audio_features.size(2), video_features.size(2))
        audio_features = audio_features[:, :, :min_time]
        video_features = video_features[:, :, :min_time]

        # Compute correlation
        correlation = self.correlation(video_features, audio_features)

        # Predict sync probabilities
        sync_logits = self.sync_predictor(correlation)
        sync_probs = F.softmax(sync_logits, dim=1)

        return sync_probs, audio_features, video_features

    def compute_offset(self, sync_probs):
        """
        Compute the offset from the sync probability map.

        Args:
            sync_probs: [B, 2K+1, T] - sync probabilities
        Returns:
            offsets: [B, T] - predicted offset for each frame
            confidences: [B, T] - confidence scores
        """
        max_probs, max_indices = torch.max(sync_probs, dim=1)
        offsets = self.max_offset - max_indices
        median_probs = torch.median(sync_probs, dim=1)[0]
        confidences = max_probs - median_probs
        return offsets, confidences
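
# Reading the offset head above: the correlation/prediction stack produces
# 2*max_offset+1 channels per time step, where channel index max_offset is the
# zero-offset bin. compute_offset therefore maps the argmax channel back to a
# signed offset via `max_offset - argmax`, and uses the gap between the peak
# probability and the per-step median as a simple confidence score.
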
# ==================== TRAINING UTILITIES ====================

def fine_tune_with_transfer_learning(model,
                                     train_loader,
                                     val_loader,
                                     num_epochs=10,
                                     lr=1e-4,
                                     device='cuda'):
    """
    Fine-tune a pre-trained model on the SyncNet task.

    Strategy (as implemented):
    1. Epochs 0-2: freeze the backbones and train only the sync head at the base lr.
    2. Epoch 3 onward: unfreeze everything and continue training at lr / 10.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)

    for epoch in range(num_epochs):
        # Phase 1: freeze backbones
        if epoch < 3:
            model._freeze_backbones()
            current_lr = lr
        # Phase 2: unfreeze everything at a lower learning rate
        elif epoch == 3:
            for param in model.parameters():
                param.requires_grad = True
            current_lr = lr / 10
            optimizer = torch.optim.Adam(model.parameters(), lr=current_lr)
            # Rebuild the scheduler so it drives the new optimizer
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs - epoch)

        model.train()
        total_loss = 0
        for batch_idx, (audio, video, labels) in enumerate(train_loader):
            audio, video = audio.to(device), video.to(device)
            labels = labels.to(device)

            # Forward pass
            sync_probs, _, _ = model(audio, video)

            # Per-frame cross-entropy over offset classes. sync_probs is already
            # softmax-normalized, so take NLL on its log; labels are assumed to hold
            # offset class indices in [0, 2*max_offset] with shape [B, T].
            loss = F.nll_loss(torch.log(sync_probs.clamp_min(1e-8)), labels)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # Validation
        model.eval()
        val_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for audio, video, labels in val_loader:
                audio, video = audio.to(device), video.to(device)
                labels = labels.to(device)
                sync_probs, _, _ = model(audio, video)
                val_loss += F.nll_loss(torch.log(sync_probs.clamp_min(1e-8)), labels).item()

                # Accuracy over offset classes (argmax of the probability map)
                pred_classes = sync_probs.argmax(dim=1)
                correct += (pred_classes == labels).sum().item()
                total += labels.numel()

        scheduler.step()
        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"  Train Loss: {total_loss/len(train_loader):.4f}")
        print(f"  Val Loss: {val_loss/len(val_loader):.4f}")
        print(f"  Val Accuracy: {100*correct/total:.2f}%")
# ==================== EXAMPLE USAGE ====================

if __name__ == "__main__":
    print("Testing Transfer Learning SyncNet...")

    # Create model with pre-trained backbones
    model = SyncNet_TransferLearning(
        video_backbone='resnet3d',   # or 'i3d'
        audio_backbone='vggish',     # or 'wav2vec'
        embedding_dim=512,
        max_offset=15,
        freeze_backbone=False
    )

    print("\nModel architecture:")
    print(f"  Video encoder: {type(model.video_encoder).__name__}")
    print(f"  Audio encoder: {type(model.audio_encoder).__name__}")

    # Test forward pass
    dummy_audio = torch.randn(2, 1, 13, 100)
    dummy_video = torch.randn(2, 3, 25, 112, 112)

    try:
        sync_probs, audio_feat, video_feat = model(dummy_audio, dummy_video)
        print("\nForward pass successful!")
        print(f"  Sync probs: {sync_probs.shape}")
        print(f"  Audio features: {audio_feat.shape}")
        print(f"  Video features: {video_feat.shape}")

        offsets, confidences = model.compute_offset(sync_probs)
        print(f"  Offsets: {offsets.shape}")
        print(f"  Confidences: {confidences.shape}")
    except Exception as e:
        print(f"Error: {e}")

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("\nParameters:")
    print(f"  Total: {total_params:,}")
    print(f"  Trainable: {trainable_params:,}")