File size: 6,663 Bytes
92e1bed
 
baf423e
92e1bed
 
baf423e
 
6815f9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
baf423e
 
 
 
 
 
 
 
 
92e1bed
baf423e
 
 
 
89305b2
 
 
 
 
baf423e
89305b2
baf423e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89305b2
baf423e
 
 
89305b2
baf423e
89305b2
9d2e0bc
baf423e
92e1bed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d987a0
 
92e1bed
5d987a0
 
 
 
 
92e1bed
 
4e296fb
 
 
 
 
5d987a0
db699b3
 
 
 
 
4e296fb
 
 
 
 
 
 
5d987a0
4e296fb
5d987a0
 
 
 
 
 
 
4e296fb
db699b3
4e296fb
 
db699b3
4e296fb
 
 
 
 
92e1bed
db699b3
4e296fb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms as T
import timm

# Class labels for the classifier: one entry per output logit (127 in total).
# predict() uses the model's top-k indices directly as positions in this list,
# so the order matters. It is not strictly alphabetical and presumably mirrors
# the label order used at training time -- do not sort or reorder it without
# confirming against the checkpoint.
SPECIES_LIST = [
    'Acadian Flycatcher', 'American Bittern', 'American Crow', 'American Goldfinch',
    'American Kestrel', 'American Redstart', 'American Robin', 'American Woodcock',
    'Baltimore Oriole', 'Barn Swallow', 'Bay-breasted Warbler', 'Belted Kingfisher',
    'Black-and-white Warbler', 'Black-billed Cuckoo', 'Blackburnian Warbler',
    'Black-capped Chickadee', 'Black-throated Blue Warbler', 'Black-throated Green Warbler',
    'Blue Grosbeak', 'Blue Jay', 'Blue-gray Gnatcatcher', 'Blue-headed Vireo',
    'Bobolink', 'Brown Creeper', 'Brown Thrasher', 'Brown-headed Cowbird',
    'Canada Warbler', 'Cape May Warbler', 'Carolina Chickadee', 'Carolina Wren',
    'Cedar Waxwing', 'Chestnut-sided Warbler', 'Chimney Swift', 'Chipping Sparrow',
    'Clapper Rail', 'Common Grackle', 'Common Yellowthroat', 'Connecticut Warbler',
    "Cooper's Hawk", 'Dark-eyed Junco', 'Downy Woodpecker', 'Eastern Bluebird',
    'Eastern Kingbird', 'Eastern Meadowlark', 'Eastern Phoebe', 'Eastern Screech-Owl',
    'Eastern Towhee', 'Eastern Wood-Pewee', 'European Starling', 'Field Sparrow',
    'Fox Sparrow', 'Golden-crowned Kinglet', 'Golden-winged Warbler', 'Grasshopper Sparrow',
    'Gray Catbird', 'Gray-cheeked Thrush', 'Great Crested Flycatcher', 'Great Horned Owl',
    'Hairy Woodpecker', 'Hermit Thrush', 'Hooded Warbler', 'House Finch',
    'House Sparrow', 'House Wren', 'Ruby-throated Hummingbird', 'Indigo Bunting',
    'Kentucky Warbler', 'Least Flycatcher', 'Lincoln\'s Sparrow', 'Louisiana Waterthrush',
    'Magnolia Warbler', 'Mallard', 'Marsh Wren', 'Mourning Dove',
    'Mourning Warbler', 'Nashville Warbler', 'Northern Cardinal', 'Northern Flicker',
    'Northern Mockingbird', 'Northern Parula', 'Northern Saw-whet Owl', 'Northern Waterthrush',
    'Orange-crowned Warbler', 'Ovenbird', 'Palm Warbler', 'Pileated Woodpecker',
    'Pine Siskin', 'Pine Warbler', 'Purple Finch', 'Purple Martin',
    'Red-eyed Vireo', 'Red-bellied Woodpecker', 'Red-breasted Nuthatch', 'Red-headed Woodpecker',
    'Red-shouldered Hawk', 'Red-tailed Hawk', 'Rose-breasted Grosbeak', 'Ruby-crowned Kinglet',
    'Savannah Sparrow', 'Scarlet Tanager', 'Sharp-shinned Hawk', 'Song Sparrow',
    'Sora', 'Summer Tanager', 'Swainson\'s Thrush', 'Swamp Sparrow',
    'Tennessee Warbler', 'Tufted Titmouse', 'Veery', 'Virginia Rail',
    'Whip-poor-will', 'White-breasted Nuthatch', 'White-crowned Sparrow', 'White-throated Sparrow',
    'Willow Flycatcher', 'Wilson\'s Warbler', 'Winter Wren', 'Wood Duck',
    'Wood Thrush', 'Worm-eating Warbler', 'Yellow-bellied Sapsucker', 'Yellow-billed Cuckoo',
    'Yellow-breasted Chat', 'Yellow Rail', 'Yellow Warbler', 'Yellow-bellied Flycatcher',
    'Yellow-rumped Warbler'
]

class AttentionBlock(nn.Module):
    """Gate a feature map with a learned, element-wise sigmoid mask.

    The mask comes from a bottleneck of two 1x1 convolutions
    (``channels -> channels // 4 -> channels``) with a ReLU in between and a
    sigmoid at the end, so it has the same shape as the input.
    """

    def __init__(self, channels):
        super().__init__()
        squeezed = channels // 4
        layers = [
            nn.Conv2d(channels, squeezed, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(squeezed, channels, kernel_size=1),
            nn.Sigmoid(),
        ]
        # The attribute name and layer order determine the state-dict keys
        # (attention.0.*, attention.2.*), so both are kept as they were.
        self.attention = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` multiplied element-wise by its attention mask."""
        mask = self.attention(x)
        return x * mask

class BirdClassifier(nn.Module):
    """EfficientNet-B0 backbone followed by an attention gate and a linear head.

    ``forward`` runs the backbone's ``forward_features``, gates the resulting
    feature map with ``AttentionBlock``, then pools, flattens, applies dropout
    and projects to ``num_classes`` logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        # num_classes=0 requests the backbone without timm's own classifier;
        # pretrained=False means no weights are fetched at construction time.
        backbone = timm.create_model('efficientnet_b0', pretrained=False, num_classes=0)
        self.base_model = backbone

        width = backbone.num_features
        self.attention = AttentionBlock(width)

        head_layers = [
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(0.2),
            nn.Linear(width, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

        # Not used by forward(). Kept as defined because the saved checkpoint
        # presumably contains matching weights -- TODO confirm before removing.
        self.projector = nn.Sequential(nn.Linear(width, 256), nn.ReLU())

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        feature_map = self.base_model.forward_features(x)
        gated = self.attention(feature_map)
        logits = self.classifier(gated)
        return logits

# Build the network and restore the trained weights.
# The head size is derived from SPECIES_LIST so the number of logits and the
# number of labels cannot drift apart.
model = BirdClassifier(num_classes=len(SPECIES_LIST))
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. Consider weights_only=True once it is confirmed that the
# checkpoint holds nothing but tensors and plain Python containers.
checkpoint = torch.load('model/bird_model_final_91.23.pth', map_location='cpu')
# strict=False tolerates key mismatches, but ignoring them silently could leave
# layers randomly initialised. Report them so a bad checkpoint shows up in the
# logs instead of as poor predictions.
load_result = model.load_state_dict(checkpoint['model_state_dict'], strict=False)
if load_result.missing_keys:
    print(f"Warning: checkpoint has no weights for: {load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"Warning: checkpoint weights not used by the model: {load_result.unexpected_keys}")
# Inference mode: disables dropout and uses running batch-norm statistics.
model.eval()

# Preprocessing applied to every uploaded image: resize to a fixed 512x512,
# convert to a tensor, then normalise each of the three channels.
# NOTE(review): the mean/std values look like the standard ImageNet statistics;
# confirm they match what was used during training.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]
transform = T.Compose([
    T.Resize((512, 512)),
    T.ToTensor(),
    T.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])

def predict(image):
    """Classify an uploaded bird photo and report the two most likely species.

    Args:
        image: PIL image supplied by the Gradio image input, or ``None`` when
            nothing has been uploaded.

    Returns:
        A Markdown string with the top prediction in bold and the runner-up on
        the following line, each with its softmax probability as a percentage.
        When ``image`` is ``None``, a prompt asking the user to upload an image.
    """
    if image is None:
        return "Please upload an image"

    # Normalize() is configured for exactly three channels, so grayscale,
    # palette or RGBA images must be converted first or preprocessing fails.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Preprocess and add the leading batch dimension.
    img_tensor = transform(image).unsqueeze(0)

    with torch.no_grad():
        output = model(img_tensor)
        probs = F.softmax(output[0], dim=0)
        top2_prob, top2_idx = torch.topk(probs, 2)

    # Convert to plain Python values so list indexing and number formatting do
    # not depend on implicit tensor conversions.
    best_idx, second_idx = top2_idx.tolist()
    best_pct, second_pct = (top2_prob * 100).tolist()

    # Markdown bold marks the top prediction.
    return (
        f"**Most likely: {SPECIES_LIST[best_idx]}: {best_pct:.1f}%**\n"
        f"Second possibility: {SPECIES_LIST[second_idx]}: {second_pct:.1f}%"
    )

# Build the Blocks UI top to bottom: title, project description, the
# upload / identify / result column, usage tips, and a collapsible list of
# every species the model knows.
with gr.Blocks(theme="huggingface") as demo:
    gr.Markdown("# DMV Bird Fatality Identifier")

    gr.Markdown("""
    This tool helps identify bird fatalities from window collisions and cat attacks in the DMV area.
    Currently covers 127 species documented in the OneStone Bird Map project (onestonebirdmap.xyz).
    
    Future Development:
    Once better computational resources are available, this model will be expanded to cover all native
    and migratory birds in the DMV area, significantly increasing its utility for fatality reporting.
    """)

    # Single column inside a row: image input, trigger button, Markdown result.
    with gr.Row(), gr.Column():
        input_image = gr.Image(type="pil", label="Upload photo of deceased bird")
        predict_btn = gr.Button("Identify Bird")
        output_text = gr.Markdown(label="Possible Species Identification")

    gr.Markdown("""
    Usage tips:
    - Take photo in good lighting
    - Include the whole bird if possible
    - Multiple angles can help with identification
    - This is an aid for identification, please consult with experts for confirmation
    
    After identification, please report the fatality at onestonebirdmap.xyz
    """)

    with gr.Accordion("Click to see all species this tool can identify"):
        # Shown alphabetically for readers; SPECIES_LIST itself is not modified.
        species_bullets = "\n".join(f"- {name}" for name in sorted(SPECIES_LIST))
        gr.Markdown(species_bullets)

    # Run predict() on the uploaded image when the button is pressed.
    predict_btn.click(fn=predict, inputs=input_image, outputs=output_text)

# Start the Gradio server. share=True additionally asks Gradio to create a
# public share link for the app.
demo.launch(share=True)