import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import os
import sys
from config import *

# Work around console encoding issues on Windows
if sys.platform.startswith('win'):
    import locale
    locale.setlocale(locale.LC_ALL, 'C')
    os.environ['PYTHONIOENCODING'] = 'utf-8'


def download_model_weights():
    """Configure a local cache directory for pretrained model weights."""
    try:
        # Silence download progress bars, which can trip the Windows console encoding
        os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'

        # Ensure the weights directory exists
        weights_dir = os.path.join(os.getcwd(), 'model_weights')
        os.makedirs(weights_dir, exist_ok=True)

        # Point the torch.hub and Hugging Face Hub caches at the local weights
        # directory so timm's pretrained download lands there
        os.environ['TORCH_HOME'] = weights_dir
        os.environ['HF_HOME'] = weights_dir

        print(f"Pretrained weights will be cached in {weights_dir}")
        return True
    except Exception as e:
        print(f"Error setting up model download: {e}")
        return False


class AttentionBlock(nn.Module):
    """Channel attention for focusing on important bird features."""

    def __init__(self, channels):
        super().__init__()
        self.attention = nn.Sequential(
            nn.Conv2d(channels, channels // 4, 1),
            nn.ReLU(),
            nn.Conv2d(channels // 4, channels, 1),
            nn.Sigmoid()
        )
        # Initialize Conv2d layers properly
        for m in self.attention.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        return x * self.attention(x)


class BirdClassifier(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.base_model = timm.create_model(
            'efficientnet_b0',
            pretrained=True,
            num_classes=0,
            drop_rate=0.2,
            drop_path_rate=0.1
        )
        feature_dim = self.base_model.num_features

        # Match the original attention block structure
        self.attention = AttentionBlock(feature_dim)

        # Memory-efficient classifier head
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(0.2, inplace=True),
            nn.Linear(feature_dim, num_classes)
        )

        # Match original projector dimensions (256 instead of 128)
        self.projector = nn.Sequential(
            nn.Linear(feature_dim, 256),
            nn.ReLU(inplace=True)
        )

        # Specialized attention blocks for bird groups that need extra discrimination
        self.specialized_attention = nn.ModuleDict({
            'nuthatch': AttentionBlock(feature_dim),
            'hawk': AttentionBlock(feature_dim),
            'warbler': AttentionBlock(feature_dim),
            'thrush': AttentionBlock(feature_dim),
            'flycatcher': AttentionBlock(feature_dim)
        })

    def forward(self, x, return_features=False):
        # Mixed precision on CUDA, full precision elsewhere
        # (torch.amp.autocast replaces the deprecated torch.cuda.amp.autocast)
        device_type = 'cuda' if x.is_cuda else 'cpu'
        with torch.amp.autocast(device_type=device_type, enabled=x.is_cuda):
            features = self.base_model.forward_features(x)

            # Average the base features with each group-specific attention output
            attended_features = features
            for att_block in self.specialized_attention.values():
                attended_features = attended_features + att_block(features)
            features = attended_features / (len(self.specialized_attention) + 1)

            features = self.attention(features)
            pooled = features.mean(dim=(2, 3))

            if return_features:
                proj = F.normalize(self.projector(pooled), dim=1)
                logits = self.classifier(features)
                return proj, logits

            return self.classifier(features)

    def get_attention_maps(self, x):
        """Return attention-weighted feature maps for visualization."""
        features = self.base_model.forward_features(x)
        att1_map = self.attention(features)
        return att1_map

    def initialize_specialized_attention(self):
        """Re-initialize the specialized attention layers."""
        for module in self.specialized_attention.values():
            # Properly initialize each Conv2d layer in the attention block
            for m in module.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)


def create_model(num_classes):
    """Create model with specified number of classes"""
    if not isinstance(num_classes, int) or num_classes < 1:
        raise ValueError(f"Invalid num_classes: {num_classes}")
    print(f"Creating model with {num_classes} classes")
    return BirdClassifier(num_classes)
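

# Minimal usage sketch (illustrative only): builds the classifier and runs a
# dummy batch through both forward paths. The class count (200) and input size
# (224x224) are placeholder assumptions; the real values come from config.py and
# the training pipeline. Note that pretrained=True downloads EfficientNet-B0
# weights on first use, so download_model_weights() is called first to set the cache.
if __name__ == "__main__":
    download_model_weights()
    model = create_model(num_classes=200)  # hypothetical class count
    model.eval()

    dummy_batch = torch.randn(2, 3, 224, 224)  # two fake RGB images
    with torch.no_grad():
        logits = model(dummy_batch)                                # classification path
        proj, logits2 = model(dummy_batch, return_features=True)  # projection path

    print(f"logits: {tuple(logits.shape)}")    # expected: (2, 200)
    print(f"projection: {tuple(proj.shape)}")  # expected: (2, 256)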