# Model definition module (exported from a Hugging Face Space).
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import os
import sys
from config import *
# Work around console-encoding problems on Windows terminals: force the
# "C" locale and UTF-8 stdio so progress output does not crash on cp1252.
if sys.platform.startswith('win'):
    import locale

    locale.setlocale(locale.LC_ALL, 'C')
    os.environ['PYTHONIOENCODING'] = 'utf-8'
def download_model_weights():
    """Prepare a local cache directory for pretrained weights.

    Despite the name, nothing is downloaded here: it creates ./model_weights
    and points the KERAS_HOME cache variable at it; the actual download
    happens lazily when the backbone is built with ``pretrained=True``.

    Returns:
        bool: True if the cache directory was set up, False on failure.
    """
    try:
        # Plain progress output avoids encoding errors on some consoles.
        os.environ['KERAS_PROGRESS_MODE'] = 'plain'
        # Cache weights next to the working directory so repeated runs reuse them.
        weights_dir = os.path.join(os.getcwd(), 'model_weights')
        os.makedirs(weights_dir, exist_ok=True)
        os.environ['KERAS_HOME'] = weights_dir
        print("Downloading model weights silently...")
        return True
    except OSError as e:
        # Directory creation is the only operation here expected to fail;
        # catching OSError (not bare Exception) avoids masking real bugs.
        print(f"Error setting up model download: {e}")
        return False
class AttentionBlock(nn.Module):
    """Channel-attention gate: scales the input by a learned [0, 1] mask."""

    def __init__(self, channels):
        super().__init__()
        bottleneck = channels // 4
        self.attention = nn.Sequential(
            nn.Conv2d(channels, bottleneck, 1),
            nn.ReLU(),
            nn.Conv2d(bottleneck, channels, 1),
            nn.Sigmoid(),
        )
        # Kaiming init matches the ReLU non-linearity; biases start at zero.
        for layer in self.attention.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        mask = self.attention(x)
        return x * mask
class BirdClassifier(nn.Module):
    """EfficientNet-B0 bird classifier with global and per-group attention.

    ``forward(x)`` returns logits of shape (batch, num_classes);
    ``forward(x, return_features=True)`` returns a tuple of a normalized
    256-d projection and the logits, useful for auxiliary losses.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Headless EfficientNet-B0 backbone; forward_features keeps the
        # 4-D (B, C, H, W) feature map.
        self.base_model = timm.create_model(
            'efficientnet_b0',
            pretrained=True,
            num_classes=0,
            drop_rate=0.2,
            drop_path_rate=0.1,
        )
        feature_dim = self.base_model.num_features

        # Global attention, applied after the specialized branches are averaged.
        self.attention = AttentionBlock(feature_dim)

        # Pool -> flatten -> dropout -> linear classification head.
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(0.2, inplace=True),
            nn.Linear(feature_dim, num_classes),
        )

        # 256-d projection head (matches the original checkpoint layout).
        self.projector = nn.Sequential(
            nn.Linear(feature_dim, 256),
            nn.ReLU(inplace=True),
        )

        # One extra attention branch per historically hard bird group.
        self.specialized_attention = nn.ModuleDict({
            'nuthatch': AttentionBlock(feature_dim),
            'hawk': AttentionBlock(feature_dim),
            'warbler': AttentionBlock(feature_dim),
            'thrush': AttentionBlock(feature_dim),
            'flycatcher': AttentionBlock(feature_dim),
        })

    def forward(self, x, return_features=False):
        # BUGFIX: autocast was hard-coded to device_type='cuda', which fails
        # (or is a noisy no-op) on CPU-only machines; derive it from the input.
        device_type = 'cuda' if x.is_cuda else 'cpu'
        with torch.amp.autocast(device_type=device_type):
            features = self.base_model.forward_features(x)

            # Average the identity path with every specialized branch.
            attended = features
            for att_block in self.specialized_attention.values():
                attended = attended + att_block(features)
            features = attended / (len(self.specialized_attention) + 1)

            features = self.attention(features)
            pooled = features.mean(dim=(2, 3))

            if return_features:
                proj = F.normalize(self.projector(pooled), dim=1)
                logits = self.classifier(features)
                return proj, logits
            return self.classifier(features)

    def get_attention_maps(self, x):
        """Return the global attention output for visualization.

        NOTE(review): unlike forward(), this skips the specialized branches,
        so the map differs slightly from what the classifier actually sees.
        """
        features = self.base_model.forward_features(x)
        return self.attention(features)

    def initialize_specialized_attention(self):
        """Re-initialize the specialized attention convolutions.

        Redundant with AttentionBlock.__init__ (which already runs this
        init), but kept for callers that reset these branches after loading
        a partial checkpoint.
        """
        for module in self.specialized_attention.values():
            for m in module.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)
def create_model(num_classes):
    """Build a BirdClassifier after validating the class count.

    Args:
        num_classes: positive integer number of output classes.

    Returns:
        BirdClassifier: freshly constructed model.

    Raises:
        ValueError: if num_classes is not a positive int. Bools are rejected
            explicitly because isinstance(True, int) is True, and
            create_model(True) would otherwise silently build a 1-class model.
    """
    if isinstance(num_classes, bool) or not isinstance(num_classes, int) or num_classes < 1:
        raise ValueError(f"Invalid num_classes: {num_classes}")
    print(f"Creating model with {num_classes} classes")
    return BirdClassifier(num_classes)