import gradio as gr
import numpy as np
import torch
from torchvision import transforms
import open_clip
import pymap3d as pm
import reverse_geocode
import json


def bounding_box_from_circle(lat_center, lon_center, radius=1000, disable_latitude_compensation=False):
    '''
    radius is in meters, measured at the equator

    warning: doesn't handle the poles or the 180th meridian very well; the box
    might wrap around and come out wrong. A check that the radius isn't too big
    would be a sensible addition (see the sketch below this function).
    '''
    # sample the circle at 0, 90, 180, 270, 360 degrees: indices 0..3 are the
    # east, north, west, and south points of the circle
    thetas = np.linspace(0, 2 * np.pi, 5)
    x, y = radius * np.cos(thetas), radius * np.sin(thetas)
    if not disable_latitude_compensation:
        # tangent-plane box: offsets are true meters at the query location
        lat, lon, _ = pm.enu2geodetic(x, y, 0, lat_center, lon_center, 0)
    else:
        # lat-lon box: offsets are meters at the equator, shifted to the location
        lat, lon, _ = pm.enu2geodetic(x, y, 0, 0, 0, 0)
        lat = lat + lat_center
        lon = lon + lon_center
    b, t = lat[3], lat[1]  # south and north edges
    l, r = lon[2], lon[0]  # west and east edges
    return l, b, r, t
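
# A hedged sketch of the radius check that the docstring above calls for (this
# helper is an addition, not part of the original app): treat a radius as safe
# only if the resulting box stays inside valid lat/lon ranges with the expected
# edge ordering. The bounds test is an illustrative heuristic, not a proof.
def radius_gives_valid_box(lat_center, lon_center, radius):
    l, b, r, t = bounding_box_from_circle(lat_center, lon_center, radius)
    return -180.0 <= l < r <= 180.0 and -90.0 <= b < t <= 90.0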

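# Another hedged sketch (not in the original demo): fetch one of the EOX
# s2cloudless tiles as a PIL image for offline inspection. Assumes network
# access and the `requests` package; the endpoint and query parameters mirror
# the URL template built in process() below.
import io

import requests
from PIL import Image


def fetch_tile(l, b, r, t, size=256):
    url = (
        "https://tiles.maps.eox.at/wms?service=WMS&version=1.1.1&request=GetMap"
        f"&layers=s2cloudless-2020&styles=&width={size}&height={size}"
        f"&srs=EPSG:4326&bbox={l},{b},{r},{t}&format=image/png"
    )
    resp = requests.get(url, timeout=30)
    resp.raise_for_status()
    return Image.open(io.BytesIO(resp.content))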
# imgs = np.load("imgs.npy")
sat_embeds = np.load("sat_embeds_new.npy")
coordinates = np.load("coordinates_new.npy")
loc_embeds = np.load("loc_embeds_new.npy")
txt_emb = torch.from_numpy(np.load("txt_emb_species.npy", mmap_mode="r"))
txt_names_json = "txt_emb_species.json"
with open(txt_names_json) as f:
    txt_names = json.load(f)

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

model, *_ = open_clip.create_model_and_transforms('hf-hub:MVRL/taxabind-vit-b-16')
model.eval()


def format_name(taxon, common):
    taxon = " ".join(taxon)
    if not common:
        return taxon
    return f"{taxon} ({common})"


def process(input_image):
    img_tensor = transform(input_image).unsqueeze(0)
    with torch.no_grad():
        img_embed = model(img_tensor)[0].detach().cpu()

    # similarity of the query image embedding against each candidate gallery;
    # squeeze the trailing singleton dimension so the top-k indices are scalars
    sims = torch.matmul(torch.tensor(sat_embeds), img_embed.t()).squeeze(-1)
    sims_locs = torch.matmul(
        torch.nn.functional.normalize(torch.tensor(loc_embeds), dim=-1),
        img_embed.t(),
    ).squeeze(-1)
    sims_txt = torch.matmul(txt_emb.t(), img_embed.t()).squeeze(-1)

    topk = torch.topk(sims, 5, dim=0)
    topk_locs = torch.topk(sims_locs, 5, dim=0)  # computed but currently unused below
    topk_txt = torch.topk(sims_txt, 5, dim=0)

    images = []
    d = {}
    d_species = {}
    for i in range(5):
        lat, lon = coordinates[topk.indices[i]]
        l, b, r, t = bounding_box_from_circle(float(lat), float(lon), 1280, disable_latitude_compensation=True)
        image_url = (
            "https://tiles.maps.eox.at/wms?service=WMS&version=1.1.1&request=GetMap"
            "&layers=s2cloudless-2020&styles=&width=256&height=256"
            f"&srs=EPSG:4326&bbox={l},{b},{r},{t}&format=image/png"
        )
        images.append(image_url)
        code = reverse_geocode.get([lat, lon])
        d[f"{code['city']}, {code['country']} ({lat:.4f}, {lon:.4f})"] = topk.values[i].item()
        d_species[format_name(*txt_names[topk_txt.indices[i]])] = topk_txt.values[i].item()
    return (
        d_species,
        d,
        [(np.array(input_image), "Query Image")] + [(images[i], f"Result {i + 1}") for i in range(5)],
    )


block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown(
            """
# TaxaBind

### A Unified Embedding Space for Ecological Applications

Srikumar Sastry, Subash Khanal, Aayush Dhakal, Adeel Ahmad, Nathan Jacobs

WACV 2025
"""
        )
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(sources='upload', type="pil", height=400)
            run_button = gr.Button(value="Run")
        with gr.Column():
            species = gr.Label(label="Species Classification", num_top_classes=5, show_label=True)
    with gr.Row():
        with gr.Column():
            coords = gr.Label(label="Image -> Location Retrieval (Top 5)", num_top_classes=5, show_label=True)
        with gr.Column():
            result_gallery = gr.Gallery(
                label='Image -> Satellite Image Retrieval (Top 5)',
                elem_id="gallery",
                object_fit="contain",
                height="auto",
                columns=[2],
                rows=[3],
            )
    ips = [input_image]
    run_button.click(fn=process, inputs=ips, outputs=[species, coords, result_gallery])

block.launch(share=True)