# OneStoneBirdID / app.py
# Author: jandreanalytics — last commit 4e296fb ("Update app.py")
import warnings

import gradio as gr
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms as T
# Class labels in the exact order of the model's output logits: index i of the
# classifier output is reported as SPECIES_LIST[i] by predict(). The order is
# NOT strictly alphabetical (e.g. "Ruby-throated Hummingbird" sits among the
# H's, "Red-eyed Vireo" precedes "Red-bellied Woodpecker"), presumably mirroring
# the training label order — do not sort or reorder, or results get mislabeled.
SPECIES_LIST = [
    "Acadian Flycatcher", "American Bittern", "American Crow", "American Goldfinch",
    "American Kestrel", "American Redstart", "American Robin", "American Woodcock",
    "Baltimore Oriole", "Barn Swallow", "Bay-breasted Warbler", "Belted Kingfisher",
    "Black-and-white Warbler", "Black-billed Cuckoo", "Blackburnian Warbler",
    "Black-capped Chickadee", "Black-throated Blue Warbler", "Black-throated Green Warbler",
    "Blue Grosbeak", "Blue Jay", "Blue-gray Gnatcatcher", "Blue-headed Vireo",
    "Bobolink", "Brown Creeper", "Brown Thrasher", "Brown-headed Cowbird",
    "Canada Warbler", "Cape May Warbler", "Carolina Chickadee", "Carolina Wren",
    "Cedar Waxwing", "Chestnut-sided Warbler", "Chimney Swift", "Chipping Sparrow",
    "Clapper Rail", "Common Grackle", "Common Yellowthroat", "Connecticut Warbler",
    "Cooper's Hawk", "Dark-eyed Junco", "Downy Woodpecker", "Eastern Bluebird",
    "Eastern Kingbird", "Eastern Meadowlark", "Eastern Phoebe", "Eastern Screech-Owl",
    "Eastern Towhee", "Eastern Wood-Pewee", "European Starling", "Field Sparrow",
    "Fox Sparrow", "Golden-crowned Kinglet", "Golden-winged Warbler", "Grasshopper Sparrow",
    "Gray Catbird", "Gray-cheeked Thrush", "Great Crested Flycatcher", "Great Horned Owl",
    "Hairy Woodpecker", "Hermit Thrush", "Hooded Warbler", "House Finch",
    "House Sparrow", "House Wren", "Ruby-throated Hummingbird", "Indigo Bunting",
    "Kentucky Warbler", "Least Flycatcher", "Lincoln's Sparrow", "Louisiana Waterthrush",
    "Magnolia Warbler", "Mallard", "Marsh Wren", "Mourning Dove",
    "Mourning Warbler", "Nashville Warbler", "Northern Cardinal", "Northern Flicker",
    "Northern Mockingbird", "Northern Parula", "Northern Saw-whet Owl", "Northern Waterthrush",
    "Orange-crowned Warbler", "Ovenbird", "Palm Warbler", "Pileated Woodpecker",
    "Pine Siskin", "Pine Warbler", "Purple Finch", "Purple Martin",
    "Red-eyed Vireo", "Red-bellied Woodpecker", "Red-breasted Nuthatch", "Red-headed Woodpecker",
    "Red-shouldered Hawk", "Red-tailed Hawk", "Rose-breasted Grosbeak", "Ruby-crowned Kinglet",
    "Savannah Sparrow", "Scarlet Tanager", "Sharp-shinned Hawk", "Song Sparrow",
    "Sora", "Summer Tanager", "Swainson's Thrush", "Swamp Sparrow",
    "Tennessee Warbler", "Tufted Titmouse", "Veery", "Virginia Rail",
    "Whip-poor-will", "White-breasted Nuthatch", "White-crowned Sparrow", "White-throated Sparrow",
    "Willow Flycatcher", "Wilson's Warbler", "Winter Wren", "Wood Duck",
    "Wood Thrush", "Worm-eating Warbler", "Yellow-bellied Sapsucker", "Yellow-billed Cuckoo",
    "Yellow-breasted Chat", "Yellow Rail", "Yellow Warbler", "Yellow-bellied Flycatcher",
    "Yellow-rumped Warbler",
]
class AttentionBlock(nn.Module):
    """Element-wise sigmoid gate computed by a 1x1-convolution bottleneck.

    The input is passed through Conv1x1(channels -> channels // 4), ReLU,
    Conv1x1(channels // 4 -> channels) and Sigmoid; the result (same shape as
    the input, values in (0, 1)) multiplies the input element-wise. There is no
    spatial pooling, so the gate varies per channel *and* per position.
    """

    def __init__(self, channels):
        super().__init__()
        hidden = channels // 4
        # Keep this exact Sequential layout: the submodule indices (0 and 2)
        # appear in state_dict keys such as "attention.0.weight", which the
        # saved checkpoint is matched against.
        layers = [
            nn.Conv2d(channels, hidden, 1),
            nn.ReLU(),
            nn.Conv2d(hidden, channels, 1),
            nn.Sigmoid(),
        ]
        self.attention = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` scaled element-wise by the learned gate."""
        gate = self.attention(x)
        return x * gate
class BirdClassifier(nn.Module):
    """EfficientNet-B0 feature extractor + AttentionBlock gate + linear head.

    Attribute names (base_model, attention, classifier, projector) and the
    nn.Sequential layouts define the state_dict keys the checkpoint is loaded
    against, so they must not be renamed or reordered.

    Args:
        num_classes: Number of output logits produced by ``forward``.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Headless timm backbone (num_classes=0 removes timm's own classifier);
        # weights come from the project checkpoint, not timm's pretrained set.
        self.base_model = timm.create_model(
            "efficientnet_b0", pretrained=False, num_classes=0
        )
        n_feat = self.base_model.num_features
        self.attention = AttentionBlock(n_feat)
        head = [
            nn.AdaptiveAvgPool2d(1),  # global average over the spatial map
            nn.Flatten(),
            nn.Dropout(0.2),
            nn.Linear(n_feat, num_classes),
        ]
        self.classifier = nn.Sequential(*head)
        # NOTE(review): projector is never used by forward(); presumably a
        # leftover training-time embedding head. Kept so the parameter set
        # still matches whatever the checkpoint contains.
        self.projector = nn.Sequential(nn.Linear(n_feat, 256), nn.ReLU())

    def forward(self, x):
        """Return class logits, (batch, num_classes) assuming forward_features
        yields a (batch, C, H, W) feature map."""
        feature_map = self.base_model.forward_features(x)
        return self.classifier(self.attention(feature_map))
# Build the network and restore the trained weights. The head is sized from
# SPECIES_LIST so the number of logits always matches the label list that
# predict() indexes into.
model = BirdClassifier(num_classes=len(SPECIES_LIST))
# torch.load unpickles the file: only ever point this at trusted checkpoints.
checkpoint = torch.load("model/bird_model_final_91.23.pth", map_location="cpu")
# strict=False tolerates key mismatches, but ignoring them silently would leave
# the affected layers randomly initialised while the app keeps returning
# confident-looking predictions. Surface any mismatch in the logs instead.
load_result = model.load_state_dict(checkpoint["model_state_dict"], strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    warnings.warn(
        "Checkpoint does not fully match BirdClassifier: "
        f"missing keys={load_result.missing_keys}, "
        f"unexpected keys={load_result.unexpected_keys}",
        stacklevel=1,
    )
model.eval()  # inference mode: disables the Dropout in the classifier head
# Preprocessing applied to every uploaded image before inference:
#   1. resize to a fixed 512x512 (aspect ratio is not preserved),
#   2. convert the PIL image to a float tensor scaled to [0, 1],
#   3. normalise each RGB channel with the ImageNet mean/std.
# NOTE(review): presumably mirrors the training pipeline — keep them in sync.
transform = T.Compose(
    [
        T.Resize(size=(512, 512)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
def predict(image):
    """Classify a bird photo and return the two most likely species as Markdown.

    Args:
        image: PIL image from the upload widget, or None when nothing has been
            uploaded yet.

    Returns:
        A Markdown string with the top prediction in bold and the runner-up on
        the next line, each with its softmax probability in percent; or a
        prompt asking for an image when ``image`` is None.
    """
    if image is None:
        return "Please upload an image"

    # The Normalize step uses 3-channel statistics, so RGBA (e.g. PNG
    # screenshots), grayscale or palette images would raise a shape error.
    # Coerce every input to RGB first.
    image = image.convert("RGB")
    img_tensor = transform(image).unsqueeze(0)  # add a batch dimension

    with torch.no_grad():
        output = model(img_tensor)
    probs = F.softmax(output[0], dim=0)

    # Only the two best candidates are shown to the user.
    top2_prob, top2_idx = torch.topk(probs, 2)
    best_name = SPECIES_LIST[int(top2_idx[0])]
    second_name = SPECIES_LIST[int(top2_idx[1])]
    best_pct = float(top2_prob[0] * 100)
    second_pct = float(top2_prob[1] * 100)

    # Markdown: bold the most likely species.
    return (
        f"**Most likely: {best_name}: {best_pct:.1f}%**\n"
        f"Second possibility: {second_name}: {second_pct:.1f}%"
    )
# Build the Gradio interface. The species count shown to users is derived from
# SPECIES_LIST so the text cannot drift from the model's label set.
with gr.Blocks(theme="huggingface") as demo:
    gr.Markdown("# DMV Bird Fatality Identifier")
    # Blank lines separate Markdown paragraphs; without them consecutive lines
    # are merged into a single run-on paragraph.
    gr.Markdown(
        f"""
        This tool helps identify bird fatalities from window collisions and cat attacks in the DMV area.
        Currently covers {len(SPECIES_LIST)} species documented in the OneStone Bird Map project (onestonebirdmap.xyz).

        **Future Development:**
        Once better computational resources are available, this model will be expanded to cover all native
        and migratory birds in the DMV area, significantly increasing its utility for fatality reporting.
        """
    )
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload photo of deceased bird")
            predict_btn = gr.Button("Identify Bird")
            output_text = gr.Markdown(label="Possible Species Identification")
    # The blank line before the closing sentence keeps it from being absorbed
    # into the last bullet as a lazy continuation line.
    gr.Markdown(
        """
        Usage tips:
        - Take photo in good lighting
        - Include the whole bird if possible
        - Multiple angles can help with identification
        - This is an aid for identification, please consult with experts for confirmation

        After identification, please report the fatality at onestonebirdmap.xyz
        """
    )
    # Collapsed by default: the label invites a click, and the list is long.
    with gr.Accordion("Click to see all species this tool can identify", open=False):
        gr.Markdown("\n".join(f"- {species}" for species in sorted(SPECIES_LIST)))
    # Wire the button inside the Blocks context that owns the components.
    predict_btn.click(fn=predict, inputs=input_image, outputs=output_text)
# Start the server only when this file is executed as a script (e.g.
# `python app.py`), so importing the module (tests, tooling, reload mode)
# does not launch a server as a side effect.
# share=True requests a public tunnel link when run locally.
# NOTE(review): hosted Spaces ignore share=True — confirm this is intended.
if __name__ == "__main__":
    demo.launch(share=True)