Vishu26 committed on
Commit
9ff2b84
0 Parent(s):

taxabind-demo

Browse files
.gitattributes ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ txt_emb_species.npy filter=lfs diff=lfs merge=lfs -text
37
+ txt_emb_species.json filter=lfs diff=lfs merge=lfs -text
38
+ coordinates_new.npy filter=lfs diff=lfs merge=lfs -text
39
+ loc_embeds_new.npy filter=lfs diff=lfs merge=lfs -text
40
+ sat_embeds_new.npy filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Taxabind Demo
3
+ emoji: 🔥
4
+ colorFrom: indigo
5
+ colorTo: indigo
6
+ sdk: gradio
7
+ sdk_version: 5.4.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ short_description: 'TaxaBind: A Unified Embedding Space for Ecological Applicati'
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+ from torchvision import transforms
5
+ import open_clip
6
+ import pymap3d as pm
7
+ import reverse_geocode
8
+ import json
9
+
10
+
11
def bounding_box_from_circle(lat_center, lon_center, radius = 1000,
                             disable_latitude_compensation=False):
    """Return a ``(left, bottom, right, top)`` lon/lat box around a point.

    Offsets of length ``radius`` are sampled at five evenly spaced angles
    from 0 to 2*pi (the last angle repeats the first) and converted to
    geodetic coordinates with ``pm.enu2geodetic``. The box edges are read
    from the samples at angles 0, pi/2, pi and 3*pi/2.

    Args:
        lat_center: Latitude of the circle centre.
        lon_center: Longitude of the circle centre.
        radius: Circle radius in meters.
        disable_latitude_compensation: When false, the offsets are
            converted on the tangent plane at the centre (meters at the
            location). When true, they are converted at (0, 0) and the
            resulting lat/lon deltas are added to the centre (meters at
            the equator).

    Returns:
        Tuple ``(left, bottom, right, top)``.

    Warning: the poles and the 180th meridian are not handled, so the box
    may be wrong there; the radius is not checked for being too large.
    """
    angles = np.linspace(0,2*np.pi, 5)
    east = radius*np.cos(angles)
    north = radius*np.sin(angles)

    if disable_latitude_compensation:
        # Lat-lon box: convert the offsets at the origin, then shift the
        # resulting deltas onto the requested centre.
        lats, lons, _ = pm.enu2geodetic(east, north, 0, 0, 0, 0)
        lats = lats + lat_center
        lons = lons + lon_center
    else:
        # Tangent-plane box: convert the offsets directly at the centre.
        lats, lons, _ = pm.enu2geodetic(east, north, 0, lat_center, lon_center, 0)

    # Sample order follows the angles: 0 -> +x, 1 -> +y, 2 -> -x, 3 -> -y.
    left, right = lons[2], lons[0]
    bottom, top = lats[3], lats[1]

    return left, bottom, right, top
37
+
38
# Module-level setup: load the precomputed retrieval data and the image
# encoder once at import time so that each request only has to embed the
# query image.

#imgs = np.load("imgs.npy")
sat_embeds = np.load("sat_embeds_new.npy")
coordinates = np.load("coordinates_new.npy")
loc_embeds = np.load("loc_embeds_new.npy")
# The text-embedding file is opened memory-mapped and read-only instead of
# being read fully into memory; the tensor shares that buffer.
txt_emb = torch.from_numpy(np.load("txt_emb_species.npy", mmap_mode="r"))
txt_names_json = "txt_emb_species.json"
# JSON text is UTF-8; pass the encoding explicitly so decoding does not
# depend on the locale of the machine running the app.
with open(txt_names_json, encoding="utf-8") as f:
    txt_names = json.load(f)


# Preprocessing applied to every query image: resize, centre-crop to
# 224x224, convert to a tensor and normalise each channel.
transform = transforms.Compose([
    transforms.Resize((256,256)),
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
# Only the model is kept; the remaining return values are discarded.
model, *_ = open_clip.create_model_and_transforms('hf-hub:MVRL/taxabind-vit-b-16')
model.eval()
57
+
58
def format_name(taxon, common):
    """Build a display name from taxonomy parts and an optional common name.

    Args:
        taxon: Iterable of strings, joined with single spaces.
        common: Common name; any falsy value (``""``, ``None``) omits it.

    Returns:
        ``"<taxon>"`` or ``"<taxon> (<common>)"``.
    """
    scientific = " ".join(taxon)
    return f"{scientific} ({common})" if common else scientific
63
+
64
def process(input_image):
    """Embed a query image and retrieve its top-5 species and satellite matches.

    Args:
        input_image: PIL image from the Gradio image component, or ``None``
            when nothing was uploaded.

    Returns:
        A 3-tuple ``(d_species, d, gallery)``:
        * ``d_species`` maps formatted species names to similarity scores.
        * ``d`` maps ``"city, country (lat, lon)"`` labels to similarity
          scores of the best satellite-embedding matches.
        * ``gallery`` is a list of ``(image, caption)`` pairs: the query
          image followed by one WMS tile URL per match.

    Raises:
        gr.Error: If ``input_image`` is ``None``.
    """
    if input_image is None:
        # Surface a readable message in the UI instead of an opaque failure
        # inside the preprocessing pipeline.
        raise gr.Error("Please upload an image before pressing Run.")

    top_k = 5

    img_tensor = transform(input_image).unsqueeze(0)
    with torch.no_grad():
        # The first element of the model output is used as the image embedding.
        img_embed = model(img_tensor)[0].detach().cpu()
    query = img_embed.t()

    # torch.as_tensor reuses the existing buffers; torch.tensor would
    # deep-copy these large module-level arrays on every request.
    sims = torch.matmul(torch.as_tensor(sat_embeds), query)
    sims_txt = torch.matmul(torch.as_tensor(txt_emb).t(), query)
    # NOTE(review): the original also computed similarities against the
    # L2-normalised `loc_embeds` but never read the result, so that dead
    # computation is dropped here. If the "Location Retrieval" output was
    # meant to be ranked by `loc_embeds`, that ranking has to be wired into
    # `d` below - confirm the intent.
    topk = torch.topk(sims, top_k, dim=0)
    topk_txt = torch.topk(sims_txt, top_k, dim=0)

    images = []
    d = {}
    d_species = {}
    for rank in range(top_k):
        # Convert the one-element index tensors to plain ints before
        # indexing the NumPy array and the Python list.
        match_idx = int(topk.indices[rank])
        species_idx = int(topk_txt.indices[rank])

        lat, lon = coordinates[match_idx]
        l,b,r,t = bounding_box_from_circle(float(lat),float(lon),1280,
                                           disable_latitude_compensation=True)
        image_url = f"https://tiles.maps.eox.at/wms?service=WMS&version=1.1.1&request=GetMap&layers=s2cloudless-2020&styles=&width=256&height=256&srs=EPSG:4326&bbox={l},{b},{r},{t}&format=image/png"
        images.append(image_url)

        code = reverse_geocode.get([lat, lon])
        d[f"{code['city']}, {code['country']} ({lat:.4f}, {lon:.4f})"] = topk.values[rank].item()
        d_species[format_name(*txt_names[species_idx])] = topk_txt.values[rank].item()

    gallery = [(np.array(input_image), "Query Image")]
    gallery += [(url, f"Result {n}") for n, url in enumerate(images, start=1)]
    return d_species, d, gallery
88
+
89
+
90
# Gradio UI: a header row, a row with the image upload / Run button next to
# the species label, and a row with the location label next to the result
# gallery.
# NOTE(review): the nesting below was reconstructed from a copy of the file
# whose indentation had been stripped - confirm it against the original.
block = gr.Blocks().queue()
with block:
    # Header: project title, subtitle, author links and venue as raw HTML.
    with gr.Row():
        gr.Markdown(
"""
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div>
<h1>TaxaBind</h1>
<span>A Unified Embedding Space for Ecological Applications</span>
<h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
<a href="https://vishu26.github.io/">Srikumar Sastry</a>,
<a href="https://subash-khanal.github.io/">Subash Khanal</a>,
<a href="https://sites.wustl.edu/aayush/">Aayush Dhakal</a>,
<a href="https://adealgis.wixsite.com/adeel-ahmad-geog">Adeel Ahmad</a>,
<a href="https://jacobsn.github.io/">Nathan Jacobs</a>
</h2>
<h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>WACV 2025</h2>
</div>
</div>
"""
        )
    # Query input on the left, species scores on the right.
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(sources='upload', type="pil", height=400)
            run_button = gr.Button(value="Run")
        with gr.Column():
            species = gr.Label(label="Species Classification", num_top_classes=5, show_label=True)
    # Location scores on the left, retrieved satellite tiles on the right.
    with gr.Row():
        with gr.Column():
            coords = gr.Label(label="Image -> Location Retrieval (Top 5)", num_top_classes=5, show_label=True)
        with gr.Column():
            result_gallery = gr.Gallery(label='Image -> Satellite Image Retrieval (Top 5)', elem_id="gallery", object_fit="contain", height="auto", columns=[2], rows=[3])
    ips = [input_image]
    # `process` returns (species scores, location scores, gallery items),
    # matching the order of `outputs`.
    run_button.click(fn=process, inputs=ips, outputs=[species, coords, result_gallery])

# NOTE(review): share=True asks Gradio for a public share link; this is
# presumably unnecessary when the app is hosted as a Space - confirm.
block.launch(share=True)
coordinates_new.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ac50785e868d4a98fb866b331f52d42bab559b976c4ee607e112fec5b6806307
3
+ size 797328
loc_embeds_new.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec310669f4e077ba614e476c36eff7934359b87a93d4c4f9e4f32e66f0023602
3
+ size 204083328
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ numpy
2
+ reverse_geocode
3
+ pymap3d
4
+ torch
5
+ open_clip
6
+ torchvision
sat_embeds_new.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:065752a12952c2f2fdb641466b2cc6722e3d1213cfa3f7a18c7dd97fe6d034f9
3
+ size 204083328
txt_emb_species.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:844e6fabc06cac072214d566b78f40825b154efa9479eb11285030ca038b2ece
3
+ size 65731052
txt_emb_species.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:91ce02dff2433222e3138b8bf7eefa1dd74b30f4d406c16cd3301f66d65ab4ed
3
+ size 787435648