import gradio as gr
import numpy as np
import torch
from torchvision import transforms
import open_clip
import pymap3d as pm
import reverse_geocode
import json

def bounding_box_from_circle(lat_center, lon_center, radius=1000,
                             disable_latitude_compensation=False):
    '''
    Return a (left, bottom, right, top) lon/lat bounding box around a circle.

    radius is in meters, measured at the equator.
    Warning: this doesn't handle the poles or the 180th meridian very well;
    the box may wrap around and come out invalid there. A check that the
    radius isn't too large would also be sensible.
    '''
    # Sample the circle at 0, 90, 180, 270 (and 360) degrees, i.e. the points
    # due east, north, west, and south of the center.
    thetas = np.linspace(0, 2 * np.pi, 5)
    x, y = radius * np.cos(thetas), radius * np.sin(thetas)
    if not disable_latitude_compensation:
        # Use tangent-plane boxes, defined in meters at the query location.
        lat, lon, alt = pm.enu2geodetic(x, y, 0, lat_center, lon_center, 0)
    else:
        # Use lat-lon boxes, defined in meters at the equator.
        lat, lon, alt = pm.enu2geodetic(x, y, 0, 0, 0, 0)
        lat = lat + lat_center
        lon = lon + lon_center
    b, t = lat[3], lat[1]  # south and north sample points
    l, r = lon[2], lon[0]  # west and east sample points
    return l, b, r, t
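
# Illustrative sanity check (the coordinates below are placeholders, not taken
# from the app's data): with latitude compensation enabled, a 1 km circle at a
# mid latitude spans more degrees of longitude than of latitude, so the box is
# "wider" in degree units.
# l, b, r, t = bounding_box_from_circle(38.63, -90.20, radius=1000)
# assert (r - l) > (t - b)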

# Precomputed embeddings and metadata for the retrieval databases.
# imgs = np.load("imgs.npy")
sat_embeds = np.load("sat_embeds_new.npy")
coordinates = np.load("coordinates_new.npy")
loc_embeds = np.load("loc_embeds_new.npy")

# Species text embeddings and their names (memory-mapped; the array is large).
txt_emb = torch.from_numpy(np.load("txt_emb_species.npy", mmap_mode="r"))
txt_names_json = "txt_emb_species.json"
with open(txt_names_json) as f:
    txt_names = json.load(f)
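
# Assumed shapes, inferred from how these arrays are used in process() below
# (not verified against the actual .npy files): sat_embeds and loc_embeds are
# (N, D), coordinates is (N, 2) holding (lat, lon) rows, and txt_emb is (D, C)
# with one column per entry of txt_names.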

# Standard preprocessing for the image encoder (the mean/std values are the
# usual ImageNet statistics).
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
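
# Example of running the preprocessing by hand (illustrative; "example.jpg" is
# a placeholder path, not a file shipped with the app):
# from PIL import Image
# x = transform(Image.open("example.jpg").convert("RGB"))  # -> (3, 224, 224)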

# Load the TaxaBind ViT-B/16 image encoder from the Hugging Face Hub.
model, *_ = open_clip.create_model_and_transforms('hf-hub:MVRL/taxabind-vit-b-16')
model.eval()

def format_name(taxon, common):
    taxon = " ".join(taxon)
    if not common:
        return taxon
    return f"{taxon} ({common})"

def process(input_image):
    img_tensor = transform(input_image).unsqueeze(0)
    with torch.no_grad():
        # The first element of the open_clip forward output is the image
        # features; drop the batch dimension so the similarity vectors below
        # are 1-D and topk returns scalar indices.
        img_embed = model(img_tensor)[0].detach().cpu().squeeze(0)
    # Dot-product similarities between the query embedding and each database.
    sims = torch.matmul(torch.tensor(sat_embeds), img_embed)
    sims_locs = torch.matmul(torch.nn.functional.normalize(torch.tensor(loc_embeds), dim=-1), img_embed)
    sims_txt = torch.matmul(txt_emb.t(), img_embed)
    topk = torch.topk(sims, 5, dim=0)
    topk_locs = torch.topk(sims_locs, 5, dim=0)  # currently unused in the outputs below
    topk_txt = torch.topk(sims_txt, 5, dim=0)
    images = []
    d = {}
    d_species = {}
    for i in range(5):
        lat, lon = coordinates[topk.indices[i]]
        # Fetch a 256x256 Sentinel-2 cloudless tile around the retrieved
        # location from the EOX WMS service.
        l, b, r, t = bounding_box_from_circle(float(lat), float(lon), 1280,
                                              disable_latitude_compensation=True)
        image_url = f"https://tiles.maps.eox.at/wms?service=WMS&version=1.1.1&request=GetMap&layers=s2cloudless-2020&styles=&width=256&height=256&srs=EPSG:4326&bbox={l},{b},{r},{t}&format=image/png"
        images.append(image_url)
        code = reverse_geocode.get((lat, lon))
        d[f"{code['city']}, {code['country']} ({lat:.4f}, {lon:.4f})"] = topk.values[i].item()
        d_species[format_name(*txt_names[topk_txt.indices[i]])] = topk_txt.values[i].item()
    return d_species, d, [(np.array(input_image), "Query Image")] + [(images[i], f"Result {i+1}") for i in range(5)]
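
# Minimal sketch of invoking the pipeline without the UI ("query.jpg" is a
# placeholder path, not part of the app):
# from PIL import Image
# species_scores, location_scores, gallery = process(Image.open("query.jpg").convert("RGB"))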

block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown(
            """
            <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
            <div>
            <h1>TaxaBind</h1>
            <span>A Unified Embedding Space for Ecological Applications</span>
            <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
            <a href="https://vishu26.github.io/">Srikumar Sastry</a>,
            <a href="https://subash-khanal.github.io/">Subash Khanal</a>,
            <a href="https://sites.wustl.edu/aayush/">Aayush Dhakal</a>,
            <a href="https://adealgis.wixsite.com/adeel-ahmad-geog">Adeel Ahmad</a>,
            <a href="https://jacobsn.github.io/">Nathan Jacobs</a>
            </h2>
            <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>WACV 2025</h2>
            </div>
            </div>
            """
        )
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(sources='upload', type="pil", height=400)
            run_button = gr.Button(value="Run")
        with gr.Column():
            species = gr.Label(label="Species Classification", num_top_classes=5, show_label=True)
    with gr.Row():
        with gr.Column():
            coords = gr.Label(label="Image -> Location Retrieval (Top 5)", num_top_classes=5, show_label=True)
        with gr.Column():
            result_gallery = gr.Gallery(label='Image -> Satellite Image Retrieval (Top 5)', elem_id="gallery", object_fit="contain", height="auto", columns=[2], rows=[3])
    ips = [input_image]
    run_button.click(fn=process, inputs=ips, outputs=[species, coords, result_gallery])

block.launch(share=True)