Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import timm | |
| import os | |
| import sys | |
| from config import * | |
# Fix encoding issues on Windows consoles.
if sys.platform.startswith('win'):
    import locale
    locale.setlocale(locale.LC_ALL, 'C')
    # PYTHONIOENCODING is only read at interpreter start-up: setting it here
    # affects child processes, not this process's already-open streams.
    os.environ['PYTHONIOENCODING'] = 'utf-8'
    # Re-encode this process's own console streams as UTF-8 (Python 3.7+).
    # Streams can be None (e.g. under pythonw) or replaced by wrappers that
    # lack reconfigure(), so check before calling.
    for _stream in (sys.stdout, sys.stderr):
        if hasattr(_stream, 'reconfigure'):
            _stream.reconfigure(encoding='utf-8')
def download_model_weights(weights_dir=None):
    """Prepare a local weights cache directory and Keras-related env vars.

    Despite its name, this function does not download anything. It:
      * sets ``KERAS_PROGRESS_MODE=plain`` (plain progress output),
      * creates the cache directory if it does not exist,
      * points ``KERAS_HOME`` at that directory.

    NOTE(review): the model in this file is built with ``timm`` (PyTorch);
    it is unclear from here whether anything still reads the KERAS_*
    variables -- confirm before relying on them.

    Args:
        weights_dir: Directory to use as the cache (str or path-like).
            Defaults to ``<cwd>/model_weights``, the previous hard-coded
            location.

    Returns:
        True on success; False if a filesystem error occurred (the error
        is printed, not raised).
    """
    try:
        os.environ['KERAS_PROGRESS_MODE'] = 'plain'
        if weights_dir is None:
            weights_dir = os.path.join(os.getcwd(), 'model_weights')
        # os.environ only accepts str values, so normalize path-like input.
        weights_dir = os.fspath(weights_dir)
        os.makedirs(weights_dir, exist_ok=True)
        os.environ['KERAS_HOME'] = weights_dir
        print("Downloading model weights silently...")
        return True
    except OSError as e:
        # Best-effort setup: report filesystem failures (e.g. cwd deleted,
        # path exists as a file, permission denied) instead of raising.
        print(f"Error setting up model download: {e}")
        return False
class AttentionBlock(nn.Module):
    """Gate an input feature map with a learned element-wise sigmoid mask.

    The gate is a 1x1-conv bottleneck (conv -> ReLU -> conv -> sigmoid)
    computed from the input itself. Its output has the same channel count
    as the input and is multiplied element-wise with it. Because no spatial
    pooling is applied, the gate varies per channel and per position.

    Args:
        channels: Number of input (and output) channels.
        reduction: Bottleneck reduction factor. The hidden width is
            ``channels // reduction``, floored at 1 so small ``channels``
            values still give a valid layer. The default of 4 keeps
            parameter shapes and state_dict keys unchanged.

    Raises:
        ValueError: If ``reduction`` is less than 1.
    """

    def __init__(self, channels, reduction=4):
        super().__init__()
        if reduction < 1:
            raise ValueError(f"reduction must be >= 1, got {reduction}")
        # Without the floor, channels < reduction would give a 0-width hidden layer.
        hidden = max(1, channels // reduction)
        self.attention = nn.Sequential(
            nn.Conv2d(channels, hidden, 1),
            nn.ReLU(),
            nn.Conv2d(hidden, channels, 1),
            nn.Sigmoid(),
        )
        # Kaiming-normal (fan_out, relu) weights and zero biases for both convs.
        for m in self.attention.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return ``x`` times its sigmoid gate (same shape as ``x``)."""
        return x * self.attention(x)
class BirdClassifier(nn.Module):
    """EfficientNet-B0 backbone with attention gating and a linear head.

    ``forward`` pipeline:
      1. Backbone feature map from ``base_model.forward_features``.
      2. Average of the raw features and every specialized
         ``AttentionBlock`` output.
      3. A shared ``AttentionBlock``.
      4. Global average pool + dropout + linear classifier.

    With ``return_features=True`` it also returns a 256-d, L2-normalized
    projection of the pooled features.

    Args:
        num_classes: Size of the classifier output.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Headless backbone (num_classes=0) with pretrained weights from timm.
        self.base_model = timm.create_model(
            'efficientnet_b0',
            pretrained=True,
            num_classes=0,
            drop_rate=0.2,
            drop_path_rate=0.1,
        )
        feature_dim = self.base_model.num_features
        # Match the original attention block structure
        self.attention = AttentionBlock(feature_dim)
        # Memory-efficient classifier
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(0.2, inplace=True),
            nn.Linear(feature_dim, num_classes),
        )
        # Match original projector dimensions (256 instead of 128)
        self.projector = nn.Sequential(
            nn.Linear(feature_dim, 256),
            nn.ReLU(inplace=True),
        )
        # One attention block per bird group. forward() applies ALL of them to
        # every input and averages the results (no per-group routing). The
        # keys are part of the state_dict, so keep names stable.
        self.specialized_attention = nn.ModuleDict({
            'nuthatch': AttentionBlock(feature_dim),
            'hawk': AttentionBlock(feature_dim),
            'warbler': AttentionBlock(feature_dim),
            'thrush': AttentionBlock(feature_dim),
            'flycatcher': AttentionBlock(feature_dim),
        })

    def forward(self, x, return_features=False):
        """Classify a batch of images.

        Args:
            x: Input batch for the timm backbone (presumably a
                (B, 3, H, W) float tensor -- TODO confirm against the
                data pipeline).
            return_features: If True, also return the projection embedding.

        Returns:
            ``logits`` of shape (B, num_classes), or ``(proj, logits)``,
            where ``proj`` is the L2-normalized (B, 256) projection of the
            spatially averaged features.
        """
        # Mixed precision only when the input lives on CUDA. A hard-coded
        # device_type='cuda' without `enabled` warns on every call on
        # CPU-only hosts. The whole pass stays inside the context so that
        # half-precision features are never fed to float32 layers outside
        # autocast.
        with torch.amp.autocast(device_type='cuda', enabled=x.is_cuda):
            features = self.base_model.forward_features(x)
            # Average the raw features with each specialized attention output.
            attended_features = features
            for att_block in self.specialized_attention.values():
                attended_features = attended_features + att_block(features)
            features = attended_features / (len(self.specialized_attention) + 1)
            features = self.attention(features)
            if return_features:
                pooled = features.mean(dim=(2, 3))
                proj = F.normalize(self.projector(pooled), dim=1)
                logits = self.classifier(features)
                return proj, logits
            return self.classifier(features)

    def get_attention_maps(self, x):
        """Return the shared attention block's output on raw backbone features.

        NOTE(review): this returns the gated features
        (``features * sigmoid_gate``), not the sigmoid gate itself. It also
        skips the specialized attention blocks that ``forward`` applies.
        Confirm this is what the visualization code expects.
        """
        features = self.base_model.forward_features(x)
        return self.attention(features)

    def initialize_specialized_attention(self):
        """Re-initialize the 1x1 convs of every specialized attention block.

        Uses Kaiming-normal (fan_out, relu) weights and zero biases, the same
        scheme AttentionBlock applies in its constructor.
        """
        for att_block in self.specialized_attention.values():
            for m in att_block.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)
def create_model(num_classes):
    """Create a BirdClassifier with ``num_classes`` output classes.

    Prints a short message before construction.

    Args:
        num_classes: Positive integer number of classes. Any object
            implementing ``__index__`` (e.g. a NumPy integer) is accepted and
            converted to ``int``. ``bool`` is rejected even though it
            subclasses ``int``.

    Returns:
        A new ``BirdClassifier`` instance.

    Raises:
        ValueError: If ``num_classes`` is not a positive integer.
    """
    import operator

    # bool subclasses int, so isinstance(True, int) passes. Reject it so that
    # create_model(True) does not silently build a 1-class head.
    if isinstance(num_classes, bool):
        raise ValueError(f"Invalid num_classes: {num_classes}")
    try:
        n_classes = operator.index(num_classes)
    except TypeError:
        raise ValueError(f"Invalid num_classes: {num_classes}") from None
    if n_classes < 1:
        raise ValueError(f"Invalid num_classes: {num_classes}")
    print(f"Creating model with {n_classes} classes")
    return BirdClassifier(n_classes)