# NOTE: removed non-Python scrape residue (file-viewer size/commit metadata
# and a line-number gutter) that preceded the actual source below.
import gradio as gr
import numpy as np
import torch
from torchvision import transforms
import open_clip
import pymap3d as pm
import reverse_geocode
import json


def bounding_box_from_circle(lat_center, lon_center, radius = 1000,
    disable_latitude_compensation=False):
    '''
    Return a (left, bottom, right, top) lon/lat bounding box for a circle of
    the given radius (meters, as measured at the equator) centered at
    (lat_center, lon_center).

    warning: doesn't handle the poles or the 180th meridian very well, it
    might give a bad bounding box; should probably define a check to make
    sure the radius isn't too big
    '''
    # Sample the circle at the four cardinal directions: indices 0..3 are
    # east, north, west, south (index 4 closes the loop and is unused).
    angles = np.linspace(0, 2 * np.pi, 5)
    east_offsets = radius * np.cos(angles)
    north_offsets = radius * np.sin(angles)

    if disable_latitude_compensation:
        # Lat-lon box: convert the ENU offsets at the equator, then translate
        # the resulting degree offsets to the requested center.
        lat, lon, _ = pm.enu2geodetic(east_offsets, north_offsets, 0, 0, 0, 0)
        lat = lat + lat_center
        lon = lon + lon_center
    else:
        # Tangent-plane box: offsets are interpreted in meters at the actual
        # location, so the box shrinks in longitude at high latitudes.
        lat, lon, _ = pm.enu2geodetic(east_offsets, north_offsets, 0,
                                      lat_center, lon_center, 0)

    bottom, top = lat[3], lat[1]
    left, right = lon[2], lon[0]

    return left, bottom, right, top

#imgs = np.load("imgs.npy")
# Precomputed embedding banks generated offline.
# NOTE(review): rows of sat_embeds and loc_embeds presumably align with the
# rows of coordinates (one lat/lon per embedding) — confirm against the
# generation script.
sat_embeds = np.load("sat_embeds_new.npy")
coordinates = np.load("coordinates_new.npy")
loc_embeds = np.load("loc_embeds_new.npy")
# Species text embeddings; mmap'd because the bank can be large, then wrapped
# as a torch tensor for matmul in process().
txt_emb = torch.from_numpy(np.load("txt_emb_species.npy", mmap_mode="r"))
# Human-readable names for each row of txt_emb: entries unpack as
# (taxon_parts, common_name) — see format_name().
txt_names_json = "txt_emb_species.json"
with open(txt_names_json) as f:
    txt_names = json.load(f)


# Preprocessing for the query image: resize, center-crop to the model's
# 224x224 input, and normalize with the standard ImageNet mean/std.
transform = transforms.Compose([
    transforms.Resize((256,256)),
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
# TaxaBind ViT-B/16 checkpoint from the Hugging Face hub; only the model is
# kept (the returned preprocessing transforms are discarded in favor of the
# explicit pipeline above). eval() disables dropout/batch-norm updates.
model, *_ = open_clip.create_model_and_transforms('hf-hub:MVRL/taxabind-vit-b-16')
model.eval()

def format_name(taxon, common):
    """Join taxonomy parts into one scientific name, appending the common
    name in parentheses when one is present."""
    scientific = " ".join(taxon)
    return f"{scientific} ({common})" if common else scientific

def process(input_image):
    """Run the TaxaBind demo pipeline on one ground-level query image.

    Embeds the image with the TaxaBind image encoder, then ranks the
    precomputed species-text, satellite-image, and location embedding banks
    by dot-product similarity against it.

    Args:
        input_image: PIL image from the Gradio upload widget.

    Returns:
        (species_scores, location_scores, gallery) where species_scores and
        location_scores are label->score dicts for the gr.Label widgets, and
        gallery is a list of (image, caption) pairs: the query image first,
        then the top-5 retrieved satellite tiles as WMS URLs.
    """
    img_tensor = transform(input_image).unsqueeze(0)
    with torch.no_grad():
        img_embed = model(img_tensor)[0].detach().cpu()

    # torch.from_numpy shares memory with the numpy banks instead of copying,
    # and txt_emb is already a torch tensor — the original torch.tensor(...)
    # wrappers copied the data and raised a UserWarning on the tensor case.
    sims = torch.matmul(torch.from_numpy(sat_embeds), img_embed.t())
    sims_locs = torch.matmul(
        torch.nn.functional.normalize(torch.from_numpy(loc_embeds), dim=-1),
        img_embed.t())
    sims_txt = torch.matmul(txt_emb.t(), img_embed.t())

    topk = torch.topk(sims, 5, dim=0)
    # NOTE(review): topk_locs is computed but never used below — the
    # "Location Retrieval" panel is populated from the satellite ranking
    # (topk). Kept as-is pending confirmation of intent.
    topk_locs = torch.topk(sims_locs, 5, dim=0)
    topk_txt = torch.topk(sims_txt, 5, dim=0)

    images = []
    d = {}
    d_species = {}
    for i in range(5):
        lat, lon = coordinates[topk.indices[i]]
        # ~1.28 km box around the retrieved coordinate, rendered by the EOX
        # Sentinel-2 cloudless WMS service as a 256x256 PNG tile.
        l, b, r, t = bounding_box_from_circle(float(lat), float(lon), 1280,
            disable_latitude_compensation=True)
        image_url = f"https://tiles.maps.eox.at/wms?service=WMS&version=1.1.1&request=GetMap&layers=s2cloudless-2020&styles=&width=256&height=256&srs=EPSG:4326&bbox={l},{b},{r},{t}&format=image/png"
        images.append(image_url)
        code = reverse_geocode.get([lat, lon])
        # Direct assignment instead of d.update({...}) — same behavior.
        d[f"{code['city']}, {code['country']} ({lat:.4f}, {lon:.4f})"] = topk.values[i].item()
        d_species[f"{format_name(*txt_names[topk_txt.indices[i]])}"] = topk_txt.values[i].item()
    return d_species, d, [(np.array(input_image), "Query Image")] + [(images[i], f"Result {i+1}") for i in range(5)]


# Build the demo UI: a header row, an upload+run column next to the species
# classification panel, and a second row with location retrieval results and
# the satellite-image gallery. queue() enables request queuing so concurrent
# users don't run the model simultaneously.
block = gr.Blocks().queue()
with block:
    with gr.Row():
        # Static header with title, author links, and venue.
        gr.Markdown(
        """
        <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
          <div>
            <h1>TaxaBind</h1>
            <span>A Unified Embedding Space for Ecological Applications</span>
            <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
                <a href="https://vishu26.github.io/">Srikumar Sastry</a>,
                <a href="https://subash-khanal.github.io/">Subash Khanal</a>,
                <a href="https://sites.wustl.edu/aayush/">Aayush Dhakal</a>,
                <a href="https://adealgis.wixsite.com/adeel-ahmad-geog">Adeel Ahmad</a>,
                <a href="https://jacobsn.github.io/">Nathan Jacobs</a>
            </h2>
            <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>WACV 2025</h2>
          </div>
        </div>
        """
        )
    with gr.Row():
        with gr.Column():
            # Query image upload; type="pil" matches what process() expects.
            input_image = gr.Image(sources='upload', type="pil", height=400)
            run_button = gr.Button(value="Run")
        with gr.Column():
            species = gr.Label(label="Species Classification", num_top_classes=5, show_label=True)
    with gr.Row():
        with gr.Column(): 
            coords = gr.Label(label="Image -> Location Retrieval (Top 5)", num_top_classes=5, show_label=True)
        with gr.Column():
            result_gallery = gr.Gallery(label='Image -> Satellite Image Retrieval (Top 5)', elem_id="gallery", object_fit="contain", height="auto", columns=[2], rows=[3])
    # Wire the button to the pipeline; outputs map 1:1 to process()'s
    # (species dict, location dict, gallery list) return tuple.
    ips = [input_image]
    run_button.click(fn=process, inputs=ips, outputs=[species, coords, result_gallery])

# share=True exposes a public Gradio tunnel URL in addition to localhost.
block.launch(share=True)