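"""Bird species classifier.

EfficientNet-B0 backbone (via timm) combined with a channel-attention block,
per-group specialized attention blocks, a projection head for normalized
feature embeddings, and a linear classification head.
"""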
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import os
import sys
from config import *

# On Windows, force UTF-8 console output so printing non-ASCII text does not fail
if sys.platform.startswith('win'):
    os.environ['PYTHONIOENCODING'] = 'utf-8'  # picked up by any child processes
    if hasattr(sys.stdout, 'reconfigure'):
        sys.stdout.reconfigure(encoding='utf-8')
        sys.stderr.reconfigure(encoding='utf-8')

def download_model_weights():
    """Prepare a local cache directory for pretrained backbone weights."""
    try:
        # Ensure the weights directory exists
        weights_dir = os.path.join(os.getcwd(), 'model_weights')
        os.makedirs(weights_dir, exist_ok=True)

        # Depending on the timm version, pretrained weights are fetched via
        # torch.hub or the Hugging Face Hub; point both caches at the same dir.
        os.environ['TORCH_HOME'] = weights_dir
        os.environ['HF_HOME'] = weights_dir

        print(f"Pretrained weights will be cached in {weights_dir}")
        return True
    except Exception as e:
        print(f"Error setting up model weight cache: {e}")
        return False

class AttentionBlock(nn.Module):
    """Attention mechanism for focusing on important bird features"""
    def __init__(self, channels):
        super().__init__()
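        # Two 1x1 convolutions form a channel bottleneck; the sigmoid yields
        # per-position, per-channel gates in [0, 1] used to rescale the input.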
        self.attention = nn.Sequential(
            nn.Conv2d(channels, channels // 4, 1),
            nn.ReLU(),
            nn.Conv2d(channels // 4, channels, 1),
            nn.Sigmoid()
        )
        # Initialize Conv2d layers properly
        for m in self.attention.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        return x * self.attention(x)

class BirdClassifier(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.base_model = timm.create_model(
            'efficientnet_b0',
            pretrained=True,
            num_classes=0,
            drop_rate=0.2,
            drop_path_rate=0.1
        )
        
        feature_dim = self.base_model.num_features
        
        # Match the original attention block structure
        self.attention = AttentionBlock(feature_dim)
        
        # Memory-efficient classifier
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(0.2, inplace=True),
            nn.Linear(feature_dim, num_classes)
        )
        
        # Match original projector dimensions (256 instead of 128)
        self.projector = nn.Sequential(
            nn.Linear(feature_dim, 256),
            nn.ReLU(inplace=True)
        )
        
        # Specialized attention blocks for bird groups targeted for improved accuracy
        self.specialized_attention = nn.ModuleDict({
            'nuthatch': AttentionBlock(feature_dim),
            'hawk': AttentionBlock(feature_dim),
            'warbler': AttentionBlock(feature_dim),
            'thrush': AttentionBlock(feature_dim),
            'flycatcher': AttentionBlock(feature_dim)
        })

    def forward(self, x, return_features=False):
        # Mixed-precision autocast; infer the device type from the input tensor
        device_type = 'cuda' if x.is_cuda else 'cpu'
        with torch.amp.autocast(device_type=device_type):
            features = self.base_model.forward_features(x)
            
            # Apply specialized attention for different bird groups
            attended_features = features
            for att_name, att_block in self.specialized_attention.items():
                attended_features = attended_features + att_block(features)
            
            features = attended_features / (len(self.specialized_attention) + 1)
            
            features = self.attention(features)
            
            pooled = features.mean(dim=(2, 3))
            
            if return_features:
                proj = F.normalize(self.projector(pooled), dim=1)
                logits = self.classifier(features)
                return proj, logits
                
            return self.classifier(features)

    def get_attention_maps(self, x):
        """Get attention maps for visualization"""
        features = self.base_model.forward_features(x)
        att1_map = self.attention(features)
        return att1_map

    def initialize_specialized_attention(self):
        """Initialize specialized attention layers properly"""
        for name, module in self.specialized_attention.items():
            # Properly initialize each Conv2d layer in the attention block
            for m in module.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)

def create_model(num_classes):
    """Create model with specified number of classes"""
    if not isinstance(num_classes, int) or num_classes < 1:
        raise ValueError(f"Invalid num_classes: {num_classes}")
    print(f"Creating model with {num_classes} classes")
    return BirdClassifier(num_classes)
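
# Minimal smoke test: build the classifier and run a dummy forward pass.
# The class count and input size below are placeholders; real values come from
# config and the training code. Pretrained backbone weights are downloaded by
# timm on first use.
if __name__ == "__main__":
    download_model_weights()
    model = create_model(num_classes=10)
    model.eval()
    with torch.no_grad():
        dummy = torch.randn(2, 3, 224, 224)  # EfficientNet-B0 default input size
        logits = model(dummy)
        proj, _ = model(dummy, return_features=True)
    print(f"logits: {tuple(logits.shape)}, projection: {tuple(proj.shape)}")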