# OneStoneBirdID / model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import os
import sys
from config import *
# Fix encoding issues
if sys.platform.startswith('win'):
    import locale
    locale.setlocale(locale.LC_ALL, 'C')
    os.environ['PYTHONIOENCODING'] = 'utf-8'

def download_model_weights():
    """Prepare the cache directory and environment used for model weight downloads."""
    try:
        # Set environment variable to handle progress bar encoding
        os.environ['KERAS_PROGRESS_MODE'] = 'plain'

        # Ensure the weights directory exists
        weights_dir = os.path.join(os.getcwd(), 'model_weights')
        os.makedirs(weights_dir, exist_ok=True)

        # Set Keras cache directory
        os.environ['KERAS_HOME'] = weights_dir

        print("Downloading model weights silently...")
        return True
    except Exception as e:
        print(f"Error setting up model download: {e}")
        return False

class AttentionBlock(nn.Module):
    """Attention mechanism for focusing on important bird features"""

    def __init__(self, channels):
        super().__init__()
        self.attention = nn.Sequential(
            nn.Conv2d(channels, channels // 4, 1),
            nn.ReLU(),
            nn.Conv2d(channels // 4, channels, 1),
            nn.Sigmoid()
        )
        # Initialize Conv2d layers properly
        for m in self.attention.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        return x * self.attention(x)
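
# Example (illustration added in this edit, not from the original file): for a
# feature map of shape [B, C, H, W], AttentionBlock(C) returns a tensor of the
# same shape, with every channel/position rescaled by a sigmoid gate in (0, 1):
#
#     block = AttentionBlock(1280)              # 1280 = efficientnet_b0 feature dim
#     out = block(torch.randn(2, 1280, 7, 7))   # -> same shape: [2, 1280, 7, 7]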

class BirdClassifier(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.base_model = timm.create_model(
            'efficientnet_b0',
            pretrained=True,
            num_classes=0,
            drop_rate=0.2,
            drop_path_rate=0.1
        )
        feature_dim = self.base_model.num_features

        # Match the original attention block structure
        self.attention = AttentionBlock(feature_dim)

        # Memory-efficient classifier
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(0.2, inplace=True),
            nn.Linear(feature_dim, num_classes)
        )

        # Match original projector dimensions (256 instead of 128)
        self.projector = nn.Sequential(
            nn.Linear(feature_dim, 256),
            nn.ReLU(inplace=True)
        )

        # Add specialized attention layers for improved groups
        self.specialized_attention = nn.ModuleDict({
            'nuthatch': AttentionBlock(feature_dim),
            'hawk': AttentionBlock(feature_dim),
            'warbler': AttentionBlock(feature_dim),
            'thrush': AttentionBlock(feature_dim),
            'flycatcher': AttentionBlock(feature_dim)
        })

    def forward(self, x, return_features=False):
        # Use the non-deprecated torch.amp.autocast and match the input device
        # so CPU inference does not trigger CUDA autocast warnings
        device_type = 'cuda' if x.is_cuda else 'cpu'
        with torch.amp.autocast(device_type=device_type):
            features = self.base_model.forward_features(x)

            # Apply specialized attention for different bird groups and average
            # the attended maps with the original feature map
            attended_features = features
            for att_block in self.specialized_attention.values():
                attended_features = attended_features + att_block(features)
            features = attended_features / (len(self.specialized_attention) + 1)

            features = self.attention(features)
            pooled = features.mean(dim=(2, 3))

            if return_features:
                proj = F.normalize(self.projector(pooled), dim=1)
                logits = self.classifier(features)
                return proj, logits

            return self.classifier(features)

    def get_attention_maps(self, x):
        """Get attention maps for visualization"""
        features = self.base_model.forward_features(x)
        att1_map = self.attention(features)
        return att1_map
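
    # Example (illustration added in this edit, not from the original file): the
    # returned map has the same [B, C, H, W] shape as the backbone feature map
    # ([B, 1280, 7, 7] for a 224x224 input); averaging over the channel dimension
    # gives one heat map per image that can be upsampled to the input resolution:
    #
    #     maps = model.get_attention_maps(images)      # [B, 1280, 7, 7]
    #     heat = maps.mean(dim=1, keepdim=True)        # [B, 1, 7, 7]
    #     heat = F.interpolate(heat, size=images.shape[-2:], mode='bilinear')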

    def initialize_specialized_attention(self):
        """Initialize specialized attention layers properly"""
        for module in self.specialized_attention.values():
            # Properly initialize each Conv2d layer in the attention block
            for m in module.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)

def create_model(num_classes):
    """Create model with specified number of classes"""
    if not isinstance(num_classes, int) or num_classes < 1:
        raise ValueError(f"Invalid num_classes: {num_classes}")
    print(f"Creating model with {num_classes} classes")
    return BirdClassifier(num_classes)
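
# Minimal smoke test (added for illustration; not part of the original training
# code). The class count of 400 and the 224x224 input size are assumptions made
# only for this example; real values come from config.py. Note that creating the
# model downloads the ImageNet-pretrained efficientnet_b0 backbone on first use.
if __name__ == "__main__":
    model = create_model(400)
    model.eval()

    dummy = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        logits = model(dummy)                               # plain class logits
        proj, logits_again = model(dummy, return_features=True)

    print(f"Logits shape: {tuple(logits.shape)}")           # (1, 400)
    print(f"Projection shape: {tuple(proj.shape)}")         # (1, 256), L2-normalized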