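# NOTE: page chrome captured along with this file ("Spaces:", "Sleeping", "Sleeping"); kept as a comment so the module parses.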
import operator
import os
import sys

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F

from config import *
# Encoding workaround, applied only on Windows hosts.
_RUNNING_ON_WINDOWS = sys.platform.startswith('win')
if _RUNNING_ON_WINDOWS:
    import locale

    # Pin every locale category to the portable "C" locale.
    locale.setlocale(locale.LC_ALL, 'C')
    # NOTE(review): PYTHONIOENCODING is read when an interpreter starts, so this
    # only influences child Python processes, not this process's sys.stdout.
    os.environ.update(PYTHONIOENCODING='utf-8')
def download_model_weights(weights_dir=None):
    """Prepare a local directory for cached model weights.

    Despite its name, this function downloads nothing itself. It sets the
    Keras progress/cache environment variables, creates the weights
    directory, prints a status line and reports whether setup succeeded.

    Args:
        weights_dir: Directory to use for cached weights (str or path-like).
            Defaults to ``<current working directory>/model_weights``, the
            previously hard-coded location.

    Returns:
        True if the directory exists (or was created), False if it could not
        be created.
    """
    try:
        # Plain (non-animated) progress output avoids console-encoding issues.
        os.environ['KERAS_PROGRESS_MODE'] = 'plain'
        if weights_dir is None:
            weights_dir = os.path.join(os.getcwd(), 'model_weights')
        # os.environ only accepts str, so normalise path-like inputs first.
        weights_dir = os.fspath(weights_dir)
        os.makedirs(weights_dir, exist_ok=True)
        # NOTE(review): KERAS_HOME is a Keras setting; the timm/torch models in
        # this module presumably cache weights elsewhere (TORCH_HOME / HF hub
        # cache) -- confirm whether this variable has any effect here.
        os.environ['KERAS_HOME'] = weights_dir
        print("Downloading model weights silently...")
        return True
    except OSError as e:
        # getcwd/makedirs failures (missing cwd, permissions, a file in the
        # way) are reported and turned into a False result, as before.
        print(f"Error setting up model download: {e}")
        return False
class AttentionBlock(nn.Module):
    """Attention mechanism for focusing on important bird features.

    A 1x1-conv bottleneck (channels -> hidden -> channels) followed by a
    sigmoid produces a mask the same shape as the input, with values in
    (0, 1); the input is multiplied by that mask element-wise.

    Args:
        channels: Number of input (and output) channels; must be >= 1.
        reduction: Bottleneck reduction factor (default 4). The hidden width
            is ``channels // reduction`` clamped to at least 1, so small
            channel counts no longer produce a zero-width layer.

    Raises:
        ValueError: If ``channels`` or ``reduction`` is smaller than 1.
    """

    def __init__(self, channels, reduction=4):
        super().__init__()
        if channels < 1 or reduction < 1:
            raise ValueError(
                f"channels and reduction must be >= 1, got {channels}, {reduction}"
            )
        # Clamp so e.g. channels=2 yields a 1-wide bottleneck instead of 0.
        hidden = max(1, channels // reduction)
        # Layer indices (0 and 2 for the convs) define the state_dict keys;
        # keep this layout so existing checkpoints still load.
        self.attention = nn.Sequential(
            nn.Conv2d(channels, hidden, 1),
            nn.ReLU(),
            nn.Conv2d(hidden, channels, 1),
            nn.Sigmoid(),
        )
        # Kaiming-normal weights (fan_out, ReLU gain) and zero biases.
        for m in self.attention.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return ``x`` scaled element-wise by the learned sigmoid mask."""
        return x * self.attention(x)
class BirdClassifier(nn.Module):
    """EfficientNet-B0 backbone with fused attention and two output heads.

    The backbone's spatial feature map is averaged together with the outputs
    of five "specialized" attention branches, re-weighted by a shared
    attention block, and fed to a linear classifier and (optionally) a 256-d
    L2-normalised projection head.

    Args:
        num_classes: Number of logits produced by the classifier head.
    """

    def __init__(self, num_classes):
        super().__init__()
        # num_classes=0 drops timm's own head; this module supplies its own.
        # pretrained=True makes timm fetch pretrained weights at construction.
        self.base_model = timm.create_model(
            'efficientnet_b0',
            pretrained=True,
            num_classes=0,
            drop_rate=0.2,
            drop_path_rate=0.1,
        )
        feature_dim = self.base_model.num_features
        # Match the original attention block structure (checkpoint compatibility).
        self.attention = AttentionBlock(feature_dim)
        # Memory-efficient classifier: global average pool -> flatten ->
        # dropout -> linear.
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(0.2, inplace=True),
            nn.Linear(feature_dim, num_classes),
        )
        # Match original projector dimensions (256 instead of 128).
        self.projector = nn.Sequential(
            nn.Linear(feature_dim, 256),
            nn.ReLU(inplace=True),
        )
        # Specialized attention branches for selected bird groups. The keys
        # become state_dict keys, so renaming them breaks checkpoint loading.
        self.specialized_attention = nn.ModuleDict({
            'nuthatch': AttentionBlock(feature_dim),
            'hawk': AttentionBlock(feature_dim),
            'warbler': AttentionBlock(feature_dim),
            'thrush': AttentionBlock(feature_dim),
            'flycatcher': AttentionBlock(feature_dim),
        })

    def forward(self, x, return_features=False):
        """Classify a batch of images.

        Args:
            x: Input image batch -- presumably (B, 3, H, W) as expected by the
                timm backbone; TODO confirm preprocessing with the caller.
            return_features: If True, also return the projection embedding.

        Returns:
            ``logits`` of shape (B, num_classes), or ``(proj, logits)`` where
            ``proj`` is the L2-normalised 256-d projection of the pooled
            features, when ``return_features`` is True.
        """
        # Enable CUDA autocast only for CUDA inputs. With device_type='cuda'
        # hard-wired, CPU-only hosts get a warning on every call and autocast
        # disables itself; CUDA autocast never touches CPU ops, so numerics
        # are unchanged. Everything stays inside the context so half-precision
        # backbone features never meet fp32 conv weights outside autocast.
        with torch.amp.autocast(device_type='cuda', enabled=x.is_cuda):
            features = self.base_model.forward_features(x)
            # Mean of the raw features and every branch output (the raw map
            # counts as one extra term); branch order fixes summation order.
            fused = features
            for branch in self.specialized_attention.values():
                fused = fused + branch(features)
            features = fused / (len(self.specialized_attention) + 1)
            features = self.attention(features)
            if return_features:
                # Global average pool over the spatial dims for the projector.
                pooled = features.mean(dim=(2, 3))
                proj = F.normalize(self.projector(pooled), dim=1)
                return proj, self.classifier(features)
            return self.classifier(features)

    def get_attention_maps(self, x):
        """Get attention output for visualization.

        NOTE(review): this returns the shared attention block's output, i.e.
        the backbone features multiplied by the sigmoid mask -- not the bare
        mask -- and it skips the specialized-attention fusion and autocast
        used by ``forward``. Confirm this is what visualization callers expect.
        """
        features = self.base_model.forward_features(x)
        return self.attention(features)

    def initialize_specialized_attention(self):
        """Re-initialize every Conv2d in the specialized attention branches.

        Weights get Kaiming-normal init (fan_out, ReLU gain); biases are set
        to zero. ``ModuleDict.modules()`` visits the branches in insertion
        order, so the sequence of random draws matches a per-branch loop.
        """
        for m in self.specialized_attention.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
def create_model(num_classes):
    """Create model with specified number of classes.

    Args:
        num_classes: Positive integer class count. Any integer-like value
            (anything supporting ``__index__``, e.g. a NumPy integer) is
            accepted; ``bool`` is rejected even though it subclasses ``int``.

    Returns:
        A newly constructed ``BirdClassifier`` with ``num_classes`` outputs.

    Raises:
        ValueError: If ``num_classes`` is not a positive integer.
    """
    # bool subclasses int, so True would otherwise silently mean "1 class".
    if isinstance(num_classes, bool):
        raise ValueError(f"Invalid num_classes: {num_classes}")
    try:
        count = operator.index(num_classes)
    except TypeError:
        # Floats, strings, None, ... -- keep the single ValueError contract.
        raise ValueError(f"Invalid num_classes: {num_classes}") from None
    if count < 1:
        raise ValueError(f"Invalid num_classes: {num_classes}")
    print(f"Creating model with {count} classes")
    return BirdClassifier(count)