Spaces:
Sleeping
Sleeping
import warnings

import gradio as gr
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms as T
# Class labels for the classifier. Position i names the species for model
# output index i (predict() indexes this list with topk indices), so the
# order must not be changed -- note it is intentionally NOT strictly
# alphabetical (e.g. "Ruby-throated Hummingbird" sits among the H's).
SPECIES_LIST = [
    "Acadian Flycatcher",
    "American Bittern",
    "American Crow",
    "American Goldfinch",
    "American Kestrel",
    "American Redstart",
    "American Robin",
    "American Woodcock",
    "Baltimore Oriole",
    "Barn Swallow",
    "Bay-breasted Warbler",
    "Belted Kingfisher",
    "Black-and-white Warbler",
    "Black-billed Cuckoo",
    "Blackburnian Warbler",
    "Black-capped Chickadee",
    "Black-throated Blue Warbler",
    "Black-throated Green Warbler",
    "Blue Grosbeak",
    "Blue Jay",
    "Blue-gray Gnatcatcher",
    "Blue-headed Vireo",
    "Bobolink",
    "Brown Creeper",
    "Brown Thrasher",
    "Brown-headed Cowbird",
    "Canada Warbler",
    "Cape May Warbler",
    "Carolina Chickadee",
    "Carolina Wren",
    "Cedar Waxwing",
    "Chestnut-sided Warbler",
    "Chimney Swift",
    "Chipping Sparrow",
    "Clapper Rail",
    "Common Grackle",
    "Common Yellowthroat",
    "Connecticut Warbler",
    "Cooper's Hawk",
    "Dark-eyed Junco",
    "Downy Woodpecker",
    "Eastern Bluebird",
    "Eastern Kingbird",
    "Eastern Meadowlark",
    "Eastern Phoebe",
    "Eastern Screech-Owl",
    "Eastern Towhee",
    "Eastern Wood-Pewee",
    "European Starling",
    "Field Sparrow",
    "Fox Sparrow",
    "Golden-crowned Kinglet",
    "Golden-winged Warbler",
    "Grasshopper Sparrow",
    "Gray Catbird",
    "Gray-cheeked Thrush",
    "Great Crested Flycatcher",
    "Great Horned Owl",
    "Hairy Woodpecker",
    "Hermit Thrush",
    "Hooded Warbler",
    "House Finch",
    "House Sparrow",
    "House Wren",
    "Ruby-throated Hummingbird",
    "Indigo Bunting",
    "Kentucky Warbler",
    "Least Flycatcher",
    "Lincoln's Sparrow",
    "Louisiana Waterthrush",
    "Magnolia Warbler",
    "Mallard",
    "Marsh Wren",
    "Mourning Dove",
    "Mourning Warbler",
    "Nashville Warbler",
    "Northern Cardinal",
    "Northern Flicker",
    "Northern Mockingbird",
    "Northern Parula",
    "Northern Saw-whet Owl",
    "Northern Waterthrush",
    "Orange-crowned Warbler",
    "Ovenbird",
    "Palm Warbler",
    "Pileated Woodpecker",
    "Pine Siskin",
    "Pine Warbler",
    "Purple Finch",
    "Purple Martin",
    "Red-eyed Vireo",
    "Red-bellied Woodpecker",
    "Red-breasted Nuthatch",
    "Red-headed Woodpecker",
    "Red-shouldered Hawk",
    "Red-tailed Hawk",
    "Rose-breasted Grosbeak",
    "Ruby-crowned Kinglet",
    "Savannah Sparrow",
    "Scarlet Tanager",
    "Sharp-shinned Hawk",
    "Song Sparrow",
    "Sora",
    "Summer Tanager",
    "Swainson's Thrush",
    "Swamp Sparrow",
    "Tennessee Warbler",
    "Tufted Titmouse",
    "Veery",
    "Virginia Rail",
    "Whip-poor-will",
    "White-breasted Nuthatch",
    "White-crowned Sparrow",
    "White-throated Sparrow",
    "Willow Flycatcher",
    "Wilson's Warbler",
    "Winter Wren",
    "Wood Duck",
    "Wood Thrush",
    "Worm-eating Warbler",
    "Yellow-bellied Sapsucker",
    "Yellow-billed Cuckoo",
    "Yellow-breasted Chat",
    "Yellow Rail",
    "Yellow Warbler",
    "Yellow-bellied Flycatcher",
    "Yellow-rumped Warbler",
]
class AttentionBlock(nn.Module):
    """Learned sigmoid gate applied element-wise to a feature map.

    Two 1x1 convolutions (``channels -> hidden -> channels``) with a ReLU in
    between, followed by a sigmoid, produce a gate with the same shape as the
    input; the input is multiplied by that gate. There is no spatial pooling,
    so the gate is computed independently at every spatial position.

    Args:
        channels: Number of input (and output) channels.
        reduction: Divisor for the bottleneck width
            (``hidden = channels // reduction``, at least 1). Defaults to 4,
            the value the block has always used.
    """

    def __init__(self, channels, reduction=4):
        super().__init__()
        # Clamp to one hidden channel: for channels < reduction the integer
        # division would yield a zero-width bottleneck, making the gate
        # collapse to sigmoid(bias) and ignore the input entirely.
        hidden = max(1, channels // reduction)
        # Keep this exact nn.Sequential layout (indices 0..3) so saved
        # state_dict keys like "attention.0.weight" keep matching.
        self.attention = nn.Sequential(
            nn.Conv2d(channels, hidden, 1),
            nn.ReLU(),
            nn.Conv2d(hidden, channels, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` multiplied by its gate; output has the same shape as ``x``."""
        return x * self.attention(x)
class BirdClassifier(nn.Module):
    """EfficientNet-B0 feature extractor, an AttentionBlock gate, and a linear head.

    Args:
        num_classes: Length of the logit vector produced by ``forward``.

    The submodule attribute names (``base_model``, ``attention``,
    ``classifier``, ``projector``) become the state_dict key prefixes, so they
    must stay as they are for saved checkpoints to load.
    """

    def __init__(self, num_classes):
        super().__init__()
        # num_classes=0 requests the timm model without its own classifier;
        # pretrained=False means the backbone starts randomly initialized and
        # trained weights are expected to be loaded afterwards.
        self.base_model = timm.create_model('efficientnet_b0', pretrained=False, num_classes=0)
        width = self.base_model.num_features

        self.attention = AttentionBlock(width)

        head_layers = [
            nn.AdaptiveAvgPool2d(1),  # (B, C, H, W) -> (B, C, 1, 1)
            nn.Flatten(),             # -> (B, C)
            nn.Dropout(0.2),
            nn.Linear(width, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

        # NOTE(review): forward() never uses this 256-d projection; presumably
        # a training-time head kept so checkpoint keys line up -- confirm.
        self.projector = nn.Sequential(nn.Linear(width, 256), nn.ReLU())

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for image batch ``x``."""
        gated = self.attention(self.base_model.forward_features(x))
        return self.classifier(gated)
# Initialize model and load weights.
MODEL_PATH = 'model/bird_model_final_91.23.pth'

# Size the head from SPECIES_LIST so the number of logits and the label
# lookups in predict() cannot drift apart.
model = BirdClassifier(num_classes=len(SPECIES_LIST))

checkpoint = torch.load(MODEL_PATH, map_location='cpu')
# Accept a training checkpoint that nests the weights under
# 'model_state_dict' as well as a bare state_dict.
state_dict = checkpoint.get('model_state_dict', checkpoint)

# strict=False tolerates benign key differences (presumably e.g. the
# projector head, which forward() never uses), but it would just as happily
# accept a checkpoint that does not fit the model at all, leaving parameters
# at their random initialization. Report mismatches instead of hiding them.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    warnings.warn(
        f"{len(load_result.missing_keys)} model entries were not found in "
        f"{MODEL_PATH} and keep their random initialization "
        f"(first few: {load_result.missing_keys[:10]})"
    )
if load_result.unexpected_keys:
    warnings.warn(
        f"{len(load_result.unexpected_keys)} checkpoint entries in {MODEL_PATH} "
        f"matched no model parameter and were ignored "
        f"(first few: {load_result.unexpected_keys[:10]})"
    )

# Switch Dropout (and any normalization layers) to inference behavior.
model.eval()
# Preprocessing applied to every uploaded image before inference: resize to a
# fixed 512x512, convert to a float tensor, then normalize per channel.
# The statistics below match the commonly used ImageNet mean/std values.
# NOTE(review): must agree with the pipeline the checkpoint was trained with,
# which is not visible here -- confirm the 512x512 size.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

transform = T.Compose(
    [
        T.Resize((512, 512)),
        T.ToTensor(),
        T.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)
def predict(image):
    """Identify the bird in ``image`` and report the two most likely species.

    Args:
        image: PIL image delivered by the ``gr.Image(type="pil")`` input, or
            ``None`` when the button is clicked without an upload.

    Returns:
        ``"Please upload an image"`` when ``image`` is None; otherwise a
        Markdown string with the top prediction in bold, followed on its own
        line by the runner-up, each with its softmax probability in percent.
    """
    if image is None:
        return "Please upload an image"

    # Normalize expects exactly three channels; converting first keeps
    # grayscale, palette or RGBA uploads from failing inside the transform.
    img_tensor = transform(image.convert("RGB")).unsqueeze(0)  # add batch dim

    with torch.no_grad():
        output = model(img_tensor)
    probs = F.softmax(output[0], dim=0)

    # Top 2 predictions only; indices map directly into SPECIES_LIST.
    top2_prob, top2_idx = torch.topk(probs, 2)

    # Markdown renders a single "\n" as a space, which would put both
    # results on one line; a blank line makes them separate paragraphs.
    return (
        f"**Most likely: {SPECIES_LIST[top2_idx[0]]}: {top2_prob[0]*100:.1f}%**\n\n"
        f"Second possibility: {SPECIES_LIST[top2_idx[1]]}: {top2_prob[1]*100:.1f}%"
    )
# Create the interface: image upload + button beside the Markdown result,
# usage tips underneath, and a collapsed list of every supported species.
with gr.Blocks(theme="huggingface") as demo:
    gr.Markdown("# DMV Bird Fatality Identifier")
    # Explicit "\n\n" paragraph breaks; the species count is derived from
    # SPECIES_LIST so the text cannot go stale.
    gr.Markdown(
        "This tool helps identify bird fatalities from window collisions and cat attacks in the DMV area.\n"
        f"Currently covers {len(SPECIES_LIST)} species documented in the OneStone Bird Map project "
        "(onestonebirdmap.xyz).\n"
        "\n"
        "Future Development:\n"
        "Once better computational resources are available, this model will be expanded to cover all native\n"
        "and migratory birds in the DMV area, significantly increasing its utility for fatality reporting."
    )

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload photo of deceased bird")
            predict_btn = gr.Button("Identify Bird")
        output_text = gr.Markdown(label="Possible Species Identification")

    # The blank line after the list keeps the closing sentence from being
    # rendered as a continuation of the last bullet.
    gr.Markdown(
        "Usage tips:\n"
        "\n"
        "- Take photo in good lighting\n"
        "- Include the whole bird if possible\n"
        "- Multiple angles can help with identification\n"
        "- This is an aid for identification, please consult with experts for confirmation\n"
        "\n"
        "After identification, please report the fatality at onestonebirdmap.xyz"
    )

    # gr.Accordion defaults to open=True; start collapsed to match the
    # "Click to see..." label.
    with gr.Accordion("Click to see all species this tool can identify", open=False):
        gr.Markdown("\n".join(f"- {species}" for species in sorted(SPECIES_LIST)))

    predict_btn.click(
        fn=predict,
        inputs=input_image,
        outputs=output_text,
    )

# Launch with public sharing enabled. Guarded so that importing this module
# (e.g. from tests or another script) does not start the server.
# NOTE(review): on Hugging Face Spaces Gradio ignores share=True with a
# warning -- confirm for the pinned Gradio version.
if __name__ == "__main__":
    demo.launch(share=True)