import gradio as gr
import numpy as np
import torch
from torchvision import transforms
import open_clip
import pymap3d as pm
import reverse_geocode
import json

def bounding_box_from_circle(lat_center, lon_center, radius=1000,
                             disable_latitude_compensation=False):
    '''
    radius is in meters, determined at the equator
    warning: doesn't handle the poles or the 180th meridian very well; it may
    wrap around and give a bad bounding box there
    should probably add a check to make sure the radius isn't too big
    '''
    # Sample the four compass points of the circle (theta = 0, pi/2, pi,
    # 3*pi/2; the fifth sample repeats the first).
    thetas = np.linspace(0, 2 * np.pi, 5)
    x, y = radius * np.cos(thetas), radius * np.sin(thetas)
    if not disable_latitude_compensation:
        # use tangent-plane boxes, defined in meters at the location
        lat, lon, alt = pm.enu2geodetic(x, y, 0, lat_center, lon_center, 0)
    else:
        # use lat-lon boxes, defined in meters at the equator
        lat, lon, alt = pm.enu2geodetic(x, y, 0, 0, 0, 0)
        lat = lat + lat_center
        lon = lon + lon_center
    # bottom/top from the south/north points, left/right from the west/east points
    b, t = lat[3], lat[1]
    l, r = lon[2], lon[0]
    return l, b, r, t
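
# Quick sanity check of the return order (a sketch with illustrative
# coordinates, not values from the app's data): the box comes back ordered
# west, south, east, north, matching the l,b,r,t order that the WMS bbox
# parameter in process() expects.
_l, _b, _r, _t = bounding_box_from_circle(38.6270, -90.1994, radius=1000)
assert _l < _r and _b < _t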

# Precomputed retrieval databases.
#imgs = np.load("imgs.npy")
sat_embeds = np.load("sat_embeds_new.npy")
coordinates = np.load("coordinates_new.npy")
loc_embeds = np.load("loc_embeds_new.npy")
txt_emb = torch.from_numpy(np.load("txt_emb_species.npy", mmap_mode="r"))
txt_names_json = "txt_emb_species.json"
with open(txt_names_json) as f:
    txt_names = json.load(f)
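
# Expected shapes, inferred from how the arrays are used in process() below
# (an assumption, not documented with the data files): sat_embeds and
# loc_embeds are (N, D) with one row per database entry, coordinates is
# (N, 2) with a (lat, lon) pair per satellite embedding, and txt_emb is
# (D, num_species) with one column per species, hence the transpose before
# its matmul.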

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
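
# The Resize/CenterCrop/Normalize pipeline above maps a PIL image to a
# normalized (3, 224, 224) float tensor; the mean/std values are the
# standard ImageNet statistics.
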
model, *_ = open_clip.create_model_and_transforms('hf-hub:MVRL/taxabind-vit-b-16')
model.eval()
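# create_model_and_transforms also returns train/val preprocessing transforms;
# they are discarded here (*_) in favor of the explicit Compose defined above.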

def format_name(taxon, common):
    # Join the taxonomy list into a single scientific name; append the common
    # name in parentheses when one is available.
    taxon = " ".join(taxon)
    if not common:
        return taxon
    return f"{taxon} ({common})"

def process(input_image):
    # Embed the query image with the TaxaBind image encoder.
    img_tensor = transform(input_image).unsqueeze(0)
    with torch.no_grad():
        img_embed = model(img_tensor)[0].detach().cpu()
    # Similarity of the query embedding against each database.
    sims = torch.matmul(torch.tensor(sat_embeds), img_embed.t())
    sims_locs = torch.matmul(torch.nn.functional.normalize(torch.tensor(loc_embeds), dim=-1),
                             img_embed.t())
    sims_txt = torch.matmul(txt_emb.t(), img_embed.t())
    topk = torch.topk(sims, 5, dim=0)
    topk_locs = torch.topk(sims_locs, 5, dim=0)  # computed but not used below
    topk_txt = torch.topk(sims_txt, 5, dim=0)
    images = []
    d = {}
    d_species = {}
    for i in range(5):
        # Fetch a Sentinel-2 cloudless tile centered on each retrieved location.
        lat, lon = coordinates[topk.indices[i].item()]
        l, b, r, t = bounding_box_from_circle(float(lat), float(lon), 1280,
                                              disable_latitude_compensation=True)
        image_url = f"https://tiles.maps.eox.at/wms?service=WMS&version=1.1.1&request=GetMap&layers=s2cloudless-2020&styles=&width=256&height=256&srs=EPSG:4326&bbox={l},{b},{r},{t}&format=image/png"
        images.append(image_url)
        code = reverse_geocode.get([lat, lon])
        d.update({f"{code['city']}, {code['country']} ({lat:.4f}, {lon:.4f})": topk.values[i].item()})
        d_species.update({format_name(*txt_names[topk_txt.indices[i].item()]): topk_txt.values[i].item()})
    return d_species, d, [(np.array(input_image), "Query Image")] + [(images[i], f"Result {i+1}") for i in range(5)]
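
# process() returns three values, matching the three outputs wired to the Run
# button below: a {species: score} dict for the classification Label, a
# {place: score} dict for the location Label, and a list of (image, caption)
# pairs for the satellite Gallery.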

block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown(
            """
            <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
            <div>
                <h1>TaxaBind</h1>
                <span>A Unified Embedding Space for Ecological Applications</span>
                <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
                <a href="https://vishu26.github.io/">Srikumar Sastry</a>,
                <a href="https://subash-khanal.github.io/">Subash Khanal</a>,
                <a href="https://sites.wustl.edu/aayush/">Aayush Dhakal</a>,
                <a href="https://adealgis.wixsite.com/adeel-ahmad-geog">Adeel Ahmad</a>,
                <a href="https://jacobsn.github.io/">Nathan Jacobs</a>
                </h2>
                <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>WACV 2025</h2>
            </div>
            </div>
            """
        )
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(sources='upload', type="pil", height=400)
            run_button = gr.Button(value="Run")
        with gr.Column():
            species = gr.Label(label="Species Classification", num_top_classes=5, show_label=True)
    with gr.Row():
        with gr.Column():
            coords = gr.Label(label="Image -> Location Retrieval (Top 5)", num_top_classes=5, show_label=True)
        with gr.Column():
            result_gallery = gr.Gallery(label='Image -> Satellite Image Retrieval (Top 5)', elem_id="gallery",
                                        object_fit="contain", height="auto", columns=[2], rows=[3])
    ips = [input_image]
    run_button.click(fn=process, inputs=ips, outputs=[species, coords, result_gallery])

block.launch(share=True)