Spaces:
Sleeping
Sleeping
import warnings

import gradio as gr
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms as T
# Human-readable class labels; ``predict`` looks a logit position up directly
# in this list, so entry i is the label for output index i. The order is NOT
# alphabetical (e.g. "Red-eyed Vireo" precedes "Red-bellied Woodpecker") and
# presumably mirrors the class order used when the checkpoint was trained —
# never sort or reorder it in place; the UI sorts a copy for display.
SPECIES_LIST = [
    "Acadian Flycatcher", "American Bittern", "American Crow", "American Goldfinch",
    "American Kestrel", "American Redstart", "American Robin", "American Woodcock",
    "Baltimore Oriole", "Barn Swallow", "Bay-breasted Warbler", "Belted Kingfisher",
    "Black-and-white Warbler", "Black-billed Cuckoo", "Blackburnian Warbler",
    "Black-capped Chickadee", "Black-throated Blue Warbler", "Black-throated Green Warbler",
    "Blue Grosbeak", "Blue Jay", "Blue-gray Gnatcatcher", "Blue-headed Vireo",
    "Bobolink", "Brown Creeper", "Brown Thrasher", "Brown-headed Cowbird",
    "Canada Warbler", "Cape May Warbler", "Carolina Chickadee", "Carolina Wren",
    "Cedar Waxwing", "Chestnut-sided Warbler", "Chimney Swift", "Chipping Sparrow",
    "Clapper Rail", "Common Grackle", "Common Yellowthroat", "Connecticut Warbler",
    "Cooper's Hawk", "Dark-eyed Junco", "Downy Woodpecker", "Eastern Bluebird",
    "Eastern Kingbird", "Eastern Meadowlark", "Eastern Phoebe", "Eastern Screech-Owl",
    "Eastern Towhee", "Eastern Wood-Pewee", "European Starling", "Field Sparrow",
    "Fox Sparrow", "Golden-crowned Kinglet", "Golden-winged Warbler", "Grasshopper Sparrow",
    "Gray Catbird", "Gray-cheeked Thrush", "Great Crested Flycatcher", "Great Horned Owl",
    "Hairy Woodpecker", "Hermit Thrush", "Hooded Warbler", "House Finch",
    "House Sparrow", "House Wren", "Ruby-throated Hummingbird", "Indigo Bunting",
    "Kentucky Warbler", "Least Flycatcher", "Lincoln's Sparrow", "Louisiana Waterthrush",
    "Magnolia Warbler", "Mallard", "Marsh Wren", "Mourning Dove",
    "Mourning Warbler", "Nashville Warbler", "Northern Cardinal", "Northern Flicker",
    "Northern Mockingbird", "Northern Parula", "Northern Saw-whet Owl", "Northern Waterthrush",
    "Orange-crowned Warbler", "Ovenbird", "Palm Warbler", "Pileated Woodpecker",
    "Pine Siskin", "Pine Warbler", "Purple Finch", "Purple Martin",
    "Red-eyed Vireo", "Red-bellied Woodpecker", "Red-breasted Nuthatch", "Red-headed Woodpecker",
    "Red-shouldered Hawk", "Red-tailed Hawk", "Rose-breasted Grosbeak", "Ruby-crowned Kinglet",
    "Savannah Sparrow", "Scarlet Tanager", "Sharp-shinned Hawk", "Song Sparrow",
    "Sora", "Summer Tanager", "Swainson's Thrush", "Swamp Sparrow",
    "Tennessee Warbler", "Tufted Titmouse", "Veery", "Virginia Rail",
    "Whip-poor-will", "White-breasted Nuthatch", "White-crowned Sparrow", "White-throated Sparrow",
    "Willow Flycatcher", "Wilson's Warbler", "Winter Wren", "Wood Duck",
    "Wood Thrush", "Worm-eating Warbler", "Yellow-bellied Sapsucker", "Yellow-billed Cuckoo",
    "Yellow-breasted Chat", "Yellow Rail", "Yellow Warbler", "Yellow-bellied Flycatcher",
    "Yellow-rumped Warbler",
]
class AttentionBlock(nn.Module):
    """Multiplicative sigmoid gate over a feature map.

    A 1x1-convolution bottleneck (``channels -> channels // 4 -> channels``)
    followed by a sigmoid produces a gate with the same shape as the input;
    the input is scaled element-wise by that gate.

    Args:
        channels: number of channels in the incoming feature map.
    """

    def __init__(self, channels):
        super().__init__()
        bottleneck = channels // 4
        # Layer positions inside the Sequential become state_dict keys
        # ("attention.0.*" and "attention.2.*"); keep this order so saved
        # checkpoints continue to load.
        gate_layers = [
            nn.Conv2d(channels, bottleneck, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(bottleneck, channels, kernel_size=1),
            nn.Sigmoid(),
        ]
        self.attention = nn.Sequential(*gate_layers)

    def forward(self, x):
        """Return ``x`` scaled by its learned gate (values in [0, 1])."""
        gate = self.attention(x)
        return x * gate
class BirdClassifier(nn.Module):
    """EfficientNet-B0 feature extractor + attention gate + linear head.

    Args:
        num_classes: number of output logits produced by ``forward``.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Headless backbone (num_classes=0); no pretrained download because
        # the weights are expected to come from a saved checkpoint.
        self.base_model = timm.create_model('efficientnet_b0', pretrained=False, num_classes=0)
        width = self.base_model.num_features

        self.attention = AttentionBlock(width)

        # Sequential positions map to state_dict keys (e.g. "classifier.3.weight");
        # keep the layer order stable for checkpoint compatibility.
        head_layers = [
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(0.2),
            nn.Linear(width, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

        # Not used by forward(); presumably kept so the checkpoint's projector
        # weights still have a matching module — TODO confirm before removing.
        self.projector = nn.Sequential(nn.Linear(width, 256), nn.ReLU())

    def forward(self, x):
        """Return class logits for an image batch ``x``."""
        return self.classifier(self.attention(self.base_model.forward_features(x)))
# Initialize model and load weights.
# The head is sized from SPECIES_LIST so every logit index has a label.
model = BirdClassifier(num_classes=len(SPECIES_LIST))
# NOTE: torch.load can execute pickled code unless weights_only=True;
# only point this at trusted checkpoint files.
checkpoint = torch.load('model/bird_model_final_91.23.pth', map_location='cpu')
# Accept both a training checkpoint ({'model_state_dict': ...}) and a bare
# state_dict saved via torch.save(model.state_dict(), ...).
state_dict = checkpoint.get('model_state_dict', checkpoint)
# strict=False tolerates key mismatches, but ignoring the report would leave
# any unmatched layers randomly initialised while the app keeps serving
# predictions. Surface mismatches at startup instead of discarding them.
incompatible = model.load_state_dict(state_dict, strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    warnings.warn(
        "Checkpoint does not fully match BirdClassifier: "
        f"missing keys={incompatible.missing_keys}, "
        f"unexpected keys={incompatible.unexpected_keys}"
    )
model.eval()  # inference mode, e.g. disables the Dropout in the head
# Preprocessing pipeline applied to every uploaded image:
#   1. resize to a fixed 512x512 (presumably the training resolution — keep in sync),
#   2. convert the PIL image to a float tensor,
#   3. normalize per channel with the commonly used ImageNet mean/std.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform = T.Compose(
    [
        T.Resize((512, 512)),
        T.ToTensor(),
        T.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
def predict(image):
    """Classify a bird photo and return the two most likely species as Markdown.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when nothing has been uploaded.

    Returns:
        A Markdown string with the top prediction in bold and the runner-up
        on its own line, or an upload prompt when ``image`` is ``None``.
    """
    if image is None:
        return "Please upload an image"

    # The Normalize step in ``transform`` expects exactly 3 channels; convert
    # RGBA/palette/grayscale images defensively before preprocessing.
    img_tensor = transform(image.convert("RGB")).unsqueeze(0)  # add batch dim

    with torch.no_grad():
        logits = model(img_tensor)
    probs = F.softmax(logits[0], dim=0)

    # Top two predictions only.
    top_probs, top_idx = torch.topk(probs, 2)
    labels = [SPECIES_LIST[i] for i in top_idx.tolist()]
    percents = (top_probs * 100).tolist()

    # A blank line is required between the entries: a single "\n" is a soft
    # break in Markdown and would render both predictions on one line.
    return (
        f"**Most likely: {labels[0]}: {percents[0]:.1f}%**\n\n"
        f"Second possibility: {labels[1]}: {percents[1]:.1f}%"
    )
# Create interface with updated UI and layout
with gr.Blocks(theme="huggingface") as demo:
    gr.Markdown("# DMV Bird Fatality Identifier")
    # The species count is derived from SPECIES_LIST so this text cannot
    # drift from the label set the model actually predicts.
    gr.Markdown(f"""
    This tool helps identify bird fatalities from window collisions and cat attacks in the DMV area.
    Currently covers {len(SPECIES_LIST)} species documented in the OneStone Bird Map project (onestonebirdmap.xyz).

    Future Development:
    Once better computational resources are available, this model will be expanded to cover all native
    and migratory birds in the DMV area, significantly increasing its utility for fatality reporting.
    """)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload photo of deceased bird")
            predict_btn = gr.Button("Identify Bird")
            output_text = gr.Markdown(label="Possible Species Identification")
    # The blank line before the closing sentence keeps Markdown from folding
    # it into the last bullet point.
    gr.Markdown("""
    Usage tips:
    - Take photo in good lighting
    - Include the whole bird if possible
    - Multiple angles can help with identification
    - This is an aid for identification, please consult with experts for confirmation

    After identification, please report the fatality at onestonebirdmap.xyz
    """)
    # Collapsed by default (Accordion defaults to open), matching the
    # "Click to see ..." label.
    with gr.Accordion("Click to see all species this tool can identify", open=False):
        gr.Markdown("\n".join(f"- {species}" for species in sorted(SPECIES_LIST)))
    predict_btn.click(
        fn=predict,
        inputs=input_image,
        outputs=output_text,
    )

# Launch with public sharing enabled; guarded so that merely importing this
# module does not start a server.
if __name__ == "__main__":
    demo.launch(share=True)