Theivaprakasham Hari commited on
Commit
90f3f7a
1 Parent(s): 66634a1
.gitattributes CHANGED
@@ -6,6 +6,7 @@
6
  *.ftz filter=lfs diff=lfs merge=lfs -text
7
  *.gz filter=lfs diff=lfs merge=lfs -text
8
  *.h5 filter=lfs diff=lfs merge=lfs -text
 
9
  *.joblib filter=lfs diff=lfs merge=lfs -text
10
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
  *.mlmodel filter=lfs diff=lfs merge=lfs -text
 
6
  *.ftz filter=lfs diff=lfs merge=lfs -text
7
  *.gz filter=lfs diff=lfs merge=lfs -text
8
  *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.json filter=lfs diff=lfs merge=lfs -text
10
  *.joblib filter=lfs diff=lfs merge=lfs -text
11
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
12
  *.mlmodel filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import heapq
3
+ import json
4
+ import os
5
+ import logging
6
+
7
+ import gradio as gr
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from open_clip import create_model, get_tokenizer
12
+ from torchvision import transforms
13
+
14
+ from templates import openai_imagenet_template
15
+
16
+ log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
17
+ logging.basicConfig(level=logging.INFO, format=log_format)
18
+ logger = logging.getLogger()
19
+
20
+
21
+ model_str = "hf-hub:imageomics/bioclip"
22
+ tokenizer_str = "ViT-B-16"
23
+
24
+ txt_emb_npy = r"txt_emb_species.npy"
25
+ txt_names_json = r"txt_emb_species.json"
26
+
27
+
28
+ min_prob = 1e-9
29
+ k = 5
30
+
31
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
32
+
33
+ preprocess_img = transforms.Compose(
34
+ [
35
+ transforms.ToTensor(),
36
+ transforms.Resize((224, 224), antialias=True),
37
+ transforms.Normalize(
38
+ mean=(0.48145466, 0.4578275, 0.40821073),
39
+ std=(0.26862954, 0.26130258, 0.27577711),
40
+ ),
41
+ ]
42
+ )
43
+
44
+ ranks = ("Kingdom", "Phylum", "Class", "Order", "Family", "Genus", "Species")
45
+
46
+ open_domain_examples = [
47
+
48
+ ['example1_Pararge_aegeria.jpg', "Species"]
49
+
50
+ ]
51
+ zero_shot_examples = [
52
+ ['example1_Pararge_aegeria.jpg', "Pararge aegeria \nPieris brassicae \nSatyrium w-album \nDanaus chrysippus"]
53
+ ]
54
+
55
+
56
def indexed(lst, indices):
    """Return the elements of ``lst`` found at ``indices``, in that order."""
    picked = []
    for position in indices:
        picked.append(lst[position])
    return picked
58
+
59
+
60
@torch.no_grad()
def get_txt_features(classnames, templates):
    """Embed every class name with the module-level text encoder.

    For each class name, every callable in ``templates`` is applied to build a
    prompt; the prompts are tokenized, encoded, L2-normalized, averaged over
    the prompts and re-normalized.

    Returns a tensor in which the per-class vectors are stacked along
    ``dim=1`` (one column per class name).
    """
    per_class = []
    for classname in classnames:
        prompts = [make_prompt(classname) for make_prompt in templates]
        tokens = tokenizer(prompts).to(device)
        embedded = F.normalize(model.encode_text(tokens), dim=-1)
        # Average over the prompts, then bring the mean back to unit length.
        pooled = embedded.mean(dim=0)
        pooled /= pooled.norm()
        per_class.append(pooled)
    return torch.stack(per_class, dim=1)
72
+
73
+
74
@torch.no_grad()
def zero_shot_classification(img, cls_str: str) -> dict[str, float]:
    """Classify ``img`` among the user-supplied classes.

    ``cls_str`` holds one class name per line; blank lines and surrounding
    whitespace are ignored. Returns a mapping from class name to its softmax
    probability.
    """
    classes = [cls.strip() for cls in cls_str.split("\n") if cls.strip()]
    txt_features = get_txt_features(classes, openai_imagenet_template)

    img = preprocess_img(img).to(device)
    # unsqueeze(0) adds the batch axis, so the encoder sees a batch of one.
    img_features = model.encode_image(img.unsqueeze(0))
    img_features = F.normalize(img_features, dim=-1)

    # Drop only the batch axis. A bare squeeze() would also remove the class
    # axis when a single class is given, turning the result into a 0-d tensor
    # whose tolist() is a plain float that cannot be zipped with `classes`.
    logits = (model.logit_scale.exp() * img_features @ txt_features).squeeze(0)
    probs = F.softmax(logits, dim=0).to("cpu").tolist()
    return dict(zip(classes, probs))
86
+
87
+
88
def format_name(taxon, common):
    """Join the parts of ``taxon`` with spaces and append ``common`` in
    parentheses when a common name is given."""
    scientific = " ".join(taxon)
    return f"{scientific} ({common})" if common else scientific
93
+
94
+
95
@torch.no_grad()
def open_domain_classification(img, rank: int) -> dict[str, float]:
    """
    Predicts from the entire tree of life.
    If targeting a higher rank than species, then this function predicts among all
    species, then sums up species-level probabilities for the given rank.

    ``rank`` is an index into the module-level ``ranks`` tuple. Returns the
    ``k`` most probable names mapped to their probabilities as Python floats.
    """
    img = preprocess_img(img).to(device)
    img_features = model.encode_image(img.unsqueeze(0))
    img_features = F.normalize(img_features, dim=-1)

    # squeeze(0) removes just the batch axis added by unsqueeze(0) above.
    logits = (model.logit_scale.exp() * img_features @ txt_emb).squeeze(0)
    probs = F.softmax(logits, dim=0)

    # If predicting species, no need to sum probabilities.
    if rank + 1 == len(ranks):
        topk = probs.topk(k)
        # tolist() yields plain ints/floats, matching the declared return type.
        return {
            format_name(*txt_names[i]): prob
            for i, prob in zip(topk.indices.tolist(), topk.values.tolist())
        }

    # Sum up by the rank. flatten() keeps the index tensor 1-d even when only
    # one entry passes the threshold (squeeze() would make it 0-d, which
    # cannot be iterated). Values are pulled to Python in one transfer instead
    # of indexing the tensor once per element.
    keep = torch.nonzero(probs > min_prob).flatten()
    output = collections.defaultdict(float)
    for i, prob in zip(keep.tolist(), probs[keep].tolist()):
        output[" ".join(txt_names[i][0][: rank + 1])] += prob

    topk_names = heapq.nlargest(k, output, key=output.get)

    return {name: output[name] for name in topk_names}
124
+
125
+
126
def change_output(choice):
    """Build a fresh, empty ``gr.Label`` titled with the rank at index ``choice``."""
    rank_name = ranks[choice]
    return gr.Label(
        num_top_classes=k,
        label=rank_name,
        show_label=True,
        value=None,
    )
128
+
129
+
130
+
131
+ js = """
132
+ function createGradioAnimation() {
133
+ var container = document.createElement('div');
134
+ container.id = 'gradio-animation';
135
+ container.style.fontSize = '2em';
136
+ container.style.fontWeight = 'bold';
137
+ container.style.textAlign = 'center';
138
+ container.style.marginBottom = '20px';
139
+
140
+ var text = 'Global Species Identifier: Powered by Artificial Intelligence';
141
+ for (var i = 0; i < text.length; i++) {
142
+ (function(i){
143
+ setTimeout(function(){
144
+ var letter = document.createElement('span');
145
+ letter.style.opacity = '0';
146
+ letter.style.transition = 'opacity 0.5s';
147
+ letter.innerText = text[i];
148
+
149
+ container.appendChild(letter);
150
+
151
+ setTimeout(function() {
152
+ letter.style.opacity = '1';
153
+ }, 50);
154
+ }, i * 50);
155
+ })(i);
156
+ }
157
+
158
+ var gradioContainer = document.querySelector('.gradio-container');
159
+ gradioContainer.insertBefore(container, gradioContainer.firstChild);
160
+
161
+ return 'Animation created';
162
+ }
163
+ """
164
+
165
+ if __name__ == "__main__":
166
+ logger.info("Starting.")
167
+ model = create_model(model_str, output_dict=True, require_pretrained=True)
168
+ model = model.to(device)
169
+ logger.info("Created model.")
170
+
171
+ # model = torch.compile(model)
172
+ logger.info("Compiled model.")
173
+
174
+ tokenizer = get_tokenizer(tokenizer_str)
175
+
176
+ txt_emb = torch.from_numpy(np.load(txt_emb_npy, mmap_mode="r")).to(device)
177
+ with open(txt_names_json) as fd:
178
+ txt_names = json.load(fd)
179
+
180
+ done = txt_emb.any(axis=0).sum().item()
181
+ total = txt_emb.shape[1]
182
+ status_msg = ""
183
+ if done != total:
184
+ status_msg = f"{done}/{total} ({done / total * 100:.1f}%) indexed"
185
+
186
+ with gr.Blocks(title='Global Species Identifier: Powered by Artificial Intelligence', css="footer {visibility: hidden}", js=js) as app:
187
+
188
+ gr.Markdown(
189
+ """
190
+ Upload an image of any plant, animal, or other organism, and our Artificial Intelligence-powered tool will identify the species. Our database covers species from around the world, aiming to support biodiversity awareness and conservation efforts.
191
+
192
+ Features include:
193
+ - **Instant identification** of plants, animals, and other organisms.
194
+ - **Detailed information** on species, including habitat, distribution, and conservation status.
195
+ - An **interactive, user-friendly interface** designed for both experts and enthusiasts.
196
+ - **Continuous learning and improvement** of AI models to expand the app's knowledge base and accuracy.
197
+
198
+ Join us in exploring the diversity of life on Earth, powered by the intelligence of technology. Start your journey of discovery today!
199
+
200
+ """)
201
+ img_input = gr.Image()
202
+
203
+ with gr.Tab("Open-Ended"):
204
+ with gr.Row():
205
+ with gr.Column():
206
+ rank_dropdown = gr.Dropdown(
207
+ label="Taxonomic Rank",
208
+ info="Which taxonomic rank to predict. Fine-grained ranks (genus, species) are more challenging.",
209
+ choices=ranks,
210
+ value="Species",
211
+ type="index",
212
+ )
213
+ open_domain_btn = gr.Button("Submit", variant="primary")
214
+ with gr.Column():
215
+ open_domain_output = gr.Label(
216
+ num_top_classes=k,
217
+ label="Prediction",
218
+ show_label=True,
219
+ value=None,
220
+ )
221
+
222
+ with gr.Row():
223
+ gr.Examples(
224
+ examples=open_domain_examples,
225
+ inputs=[img_input, rank_dropdown],
226
+ cache_examples=True,
227
+ fn=open_domain_classification,
228
+ outputs=[open_domain_output],
229
+ )
230
+
231
+
232
+ with gr.Tab("Zero-Shot"):
233
+ with gr.Row():
234
+ with gr.Column():
235
+ classes_txt = gr.Textbox(
236
+ placeholder= "Pararge aegeria \nPieris brassicae \nSatyrium w-album \nDanaus chrysippus\n...",
237
+ lines=3,
238
+ label="Classes",
239
+ show_label=True,
240
+ info="Use taxonomic names where possible; include common names if possible.",
241
+ )
242
+ zero_shot_btn = gr.Button("Submit", variant="primary")
243
+
244
+ with gr.Column():
245
+ zero_shot_output = gr.Label(
246
+ num_top_classes=k, label="Prediction", show_label=True
247
+ )
248
+
249
+ with gr.Row():
250
+ gr.Examples(
251
+ examples=zero_shot_examples,
252
+ inputs=[img_input, classes_txt],
253
+ cache_examples=True,
254
+ fn=zero_shot_classification,
255
+ outputs=[zero_shot_output],
256
+ )
257
+
258
+
259
+ rank_dropdown.change(
260
+ fn=change_output, inputs=rank_dropdown, outputs=[open_domain_output]
261
+ )
262
+
263
+ open_domain_btn.click(
264
+ fn=open_domain_classification,
265
+ inputs=[img_input, rank_dropdown],
266
+ outputs=[open_domain_output],
267
+ )
268
+
269
+ zero_shot_btn.click(
270
+ fn=zero_shot_classification,
271
+ inputs=[img_input, classes_txt],
272
+ outputs=zero_shot_output,
273
+ )
274
+
275
+ app.queue(max_size=20)
276
+ app.launch(show_api=False)
embed_texts.sh ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ #SBATCH --nodes=1
3
+ #SBATCH --account=PAS2136
4
+ #SBATCH --gpus-per-node=1
5
+ #SBATCH --ntasks-per-node=10
6
+ #SBATCH --job-name=embed-treeoflife
7
+ #SBATCH --time=12:00:00
8
+ #SBATCH --partition=gpu
9
+
10
+ python make_txt_embedding.py \
11
+ --catalog-path /fs/ess/PAS2136/open_clip/data/evobio10m-v3.3/predicted-statistics.csv \
12
+ --out-path text_emb.bin
example1_Pararge_aegeria.jpg ADDED
gitattributes ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+
37
+ *.json filter=lfs diff=lfs merge=lfs -text
38
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
39
+ *.png filter=lfs diff=lfs merge=lfs -text
gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .venv/
2
+ __pycache__/
lib.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Mostly a TaxonomicTree class that implements a taxonomy and some helpers for easily
3
+ walking and looking in the tree.
4
+
5
+ A tree is an arrangement of TaxonomicNodes.
6
+
7
+
8
+ """
9
+
10
+
11
+ import itertools
12
+ import json
13
+
14
+
15
class TaxonomicNode:
    """A named node of a taxonomic tree.

    Each node stores its own name, the integer index it was given when it was
    created, a reference to the owning tree (``root``) and a mapping from
    child name to child node.
    """

    __slots__ = ("name", "index", "root", "_children")

    def __init__(self, name, index, root):
        self.name = name
        self.index = index
        # The owning tree; its ``size`` attribute supplies the next free index.
        self.root = root
        self._children = {}

    def add(self, name):
        """Insert the path ``name`` below this node, creating missing nodes.

        ``name`` is a sequence of names, nearest rank first. Every created
        node takes ``root.size`` as its index and then increments it.

        Returns the number of nodes that were created (0 when the whole path
        already existed or ``name`` is empty).
        """
        if not name:
            return 0

        first, rest = name[0], name[1:]
        added = 0
        child = self._children.get(first)
        if child is None:
            child = TaxonomicNode(first, self.root.size, self.root)
            self._children[first] = child
            self.root.size += 1
            added = 1

        return added + child.add(rest)

    def children(self, name):
        """Return the ``(name, index)`` pairs of the direct children of the
        node reached by following ``name`` from here; an empty set when the
        path does not exist."""
        if not name:
            return {(child.name, child.index) for child in self._children.values()}

        first, rest = name[0], name[1:]
        if first not in self._children:
            return set()

        return self._children[first].children(rest)

    def descendants(self, prefix=None):
        """Iterates over all values in the subtree that match prefix.

        Yields ``(path, index)`` pairs where ``path`` is a tuple of names
        starting with this node's name. With an empty prefix this node itself
        is yielded first, followed by everything below it.
        """
        if not prefix:
            yield (self.name,), self.index
            for child in self._children.values():
                for name, i in child.descendants():
                    yield (self.name, *name), i
            return

        first, rest = prefix[0], prefix[1:]
        if first not in self._children:
            return

        for name, i in self._children[first].descendants(rest):
            yield (self.name, *name), i

    def values(self):
        """Iterates over all (name, i) pairs in the tree."""
        yield (self.name,), self.index

        for child in self._children.values():
            for name, index in child.values():
                yield (self.name, *name), index

    @classmethod
    def from_dict(cls, dct, root):
        """Rebuild a node, recursively, from a dict with the keys ``name``,
        ``index`` and ``children``."""
        node = cls(dct["name"], dct["index"], root)
        node._children = {
            child["name"]: cls.from_dict(child, root) for child in dct["children"]
        }
        return node
78
+
79
+
80
class TaxonomicTree:
    """
    Lookup structure over taxonomic names and their descendants.
    Every name added to the tree is paired with an integer index.
    """

    def __init__(self):
        # Top-level nodes keyed by name; ``size`` is the number of nodes
        # created so far and doubles as the next index to hand out.
        self.kingdoms = {}
        self.size = 0

    def add(self, name: list[str]):
        """Insert the path ``name`` (kingdom first), creating missing nodes."""
        if not name:
            return

        head, tail = name[0], name[1:]
        kingdom = self.kingdoms.get(head)
        if kingdom is None:
            kingdom = TaxonomicNode(head, self.size, self)
            self.kingdoms[head] = kingdom
            self.size += 1

        kingdom.add(tail)

    def children(self, name=None):
        """Return ``(name, index)`` pairs for the direct children of ``name``;
        with no name, the kingdoms themselves."""
        if name:
            head, tail = name[0], name[1:]
            kingdom = self.kingdoms.get(head)
            if kingdom is None:
                return set()
            return kingdom.children(tail)

        return {(kingdom.name, kingdom.index) for kingdom in self.kingdoms.values()}

    def descendants(self, prefix=None):
        """Iterates over all values in the tree that match prefix."""
        if prefix:
            head, tail = prefix[0], prefix[1:]
            kingdom = self.kingdoms.get(head)
            if kingdom is not None:
                yield from kingdom.descendants(tail)
            return

        # No prefix: walk every kingdom in full.
        for kingdom in self.kingdoms.values():
            yield from kingdom.descendants()

    def values(self):
        """Iterates over all (name, i) pairs in the tree."""
        for kingdom in self.kingdoms.values():
            yield from kingdom.values()

    def __len__(self):
        return self.size

    @classmethod
    def from_dict(cls, dct):
        """Rebuild a tree from a dict with the keys ``kingdoms`` and ``size``."""
        tree = cls()
        restored = {}
        for kingdom in dct["kingdoms"]:
            restored[kingdom["name"]] = TaxonomicNode.from_dict(kingdom, tree)
        tree.kingdoms = restored
        tree.size = dct["size"]
        return tree
145
+
146
+
147
class TaxonomicJsonEncoder(json.JSONEncoder):
    """JSON encoder that serializes ``TaxonomicTree`` and ``TaxonomicNode``.

    The emitted dicts use the keys that the ``from_dict`` class methods read
    back (``kingdoms``/``size`` and ``name``/``index``/``children``).
    """

    def default(self, obj):
        if isinstance(obj, TaxonomicNode):
            return {
                "name": obj.name,
                "index": obj.index,
                "children": list(obj._children.values()),
            }
        if isinstance(obj, TaxonomicTree):
            return {
                "kingdoms": list(obj.kingdoms.values()),
                "size": obj.size,
            }
        # Defer to the base class so unsupported objects get the standard
        # "not JSON serializable" TypeError. super() already binds self, so
        # only obj is passed, and the result is returned.
        return super().default(obj)
162
+
163
+
164
def batched(iterable, n):
    """Yield ``iterable`` in chunks of at most ``n`` items, each transposed.

    Every chunk is passed through ``zip(*chunk)``, so a chunk of
    ``(a, b)`` pairs is yielded as an iterator over ``(a1, a2, ...)`` and
    ``(b1, b2, ...)``. Raises ValueError when ``n`` is smaller than one.
    """
    if n < 1:
        raise ValueError("n must be at least one")
    source = iter(iterable)
    while True:
        chunk = tuple(itertools.islice(source, n))
        if not chunk:
            return
        yield zip(*chunk)
make_txt_embedding.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Makes the entire set of text embeddings for all possible names in the tree of life.
3
+ Uses the catalog.csv file from TreeOfLife-10M.
4
+ """
5
+ import argparse
6
+ import csv
7
+ import json
8
+ import os
9
+ import logging
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn.functional as F
14
+
15
+ from open_clip import create_model, get_tokenizer
16
+ from tqdm import tqdm
17
+
18
+ import lib
19
+ from templates import openai_imagenet_template
20
+
21
+ log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
22
+ logging.basicConfig(level=logging.INFO, format=log_format)
23
+ logger = logging.getLogger()
24
+
25
+ model_str = "hf-hub:imageomics/bioclip"
26
+ tokenizer_str = "ViT-B-16"
27
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
28
+
29
+ ranks = ("Kingdom", "Phylum", "Class", "Order", "Family", "Genus", "Species")
30
+
31
+
32
@torch.no_grad()
def write_txt_features(name_lookup):
    """Encode every name in ``name_lookup`` and save the feature matrix.

    Each name is expanded through every template, encoded, normalized,
    averaged over the templates and re-normalized; the result is stored in
    the column given by the name's index. The matrix is saved to
    ``args.out_path`` every 100 batches and once more at the end. Batches
    whose columns already hold non-zero values are skipped, which lets an
    interrupted run resume from an existing file.
    """
    # NOTE(review): np.save appends ".npy" when the path lacks that suffix,
    # so with e.g. "text_emb.bin" this isfile/load pair never sees the file
    # written below - confirm the intended --out-path.
    if os.path.isfile(args.out_path):
        all_features = np.load(args.out_path)
    else:
        # 512 rows: presumably the model's embedding width - TODO confirm.
        all_features = np.zeros((512, len(name_lookup)), dtype=np.float32)

    n_templates = len(openai_imagenet_template)
    # args.batch_size counts prompts; every name produces one prompt per
    # template. Never go below one name per batch, since lib.batched rejects
    # n < 1.
    batch_size = max(1, args.batch_size // n_templates)
    # Ceiling division so the progress bar also counts the final short batch.
    n_batches = -(-len(name_lookup) // batch_size)

    for batch, (names, indices) in enumerate(
        tqdm(
            lib.batched(name_lookup.values(), batch_size),
            desc="txt feats",
            total=n_batches,
        )
    ):
        # Skip if any non-zero elements
        if all_features[:, indices].any():
            logger.info(f"Skipping batch {batch}")
            continue

        # NOTE(review): each ``name`` is a tuple of rank names and is
        # formatted into the prompt as-is (tuple repr) - confirm whether a
        # space-joined string was intended.
        txts = [
            template(name) for name in names for template in openai_imagenet_template
        ]
        txts = tokenizer(txts).to(device)
        txt_features = model.encode_text(txts)
        # Regroup the flat prompt batch as (name, template, feature).
        txt_features = torch.reshape(txt_features, (len(names), n_templates, 512))
        txt_features = F.normalize(txt_features, dim=2).mean(dim=1)
        txt_features /= txt_features.norm(dim=1, keepdim=True)
        all_features[:, indices] = txt_features.T.cpu().numpy()

        # Periodic checkpoint.
        if batch % 100 == 0:
            np.save(args.out_path, all_features)

    np.save(args.out_path, all_features)
68
+
69
+
70
def convert_txt_features_to_avgs(name_lookup):
    """Replace each above-species column by the normalized mean of its
    deep descendants and save the result next to ``args.out_path``.

    Ranks are processed from the deepest non-species rank upwards. For every
    name, the columns of all descendants whose path has at least six parts
    are averaged; names without such descendants are zeroed and counted.
    The output file name is ``args.out_path`` with ``_avgs`` inserted before
    the extension.
    """
    assert os.path.isfile(args.out_path)

    # Put that big boy on the GPU. We're going fast.
    all_features = torch.from_numpy(np.load(args.out_path)).to(device)
    logger.info("Loaded text features from disk to %s.", device)

    # Bucket every (name, index) pair by the depth of its path.
    names_by_rank = [set() for _ in ranks]
    for name, index in tqdm(name_lookup.values()):
        names_by_rank[len(name) - 1].add((name, index))

    zeroed = 0
    for depth in range(len(ranks) - 1, -1, -1):
        rank = ranks[depth]
        if rank == "Species":
            continue

        for name, index in tqdm(names_by_rank[depth], desc=rank):
            matches = [
                (path, column)
                for path, column in name_lookup.descendants(prefix=name)
                if len(path) >= 6
            ]
            if not matches:
                logger.warning("No species for %s.", " ".join(name))
                all_features[:, index] = 0.0
                zeroed += 1
                continue

            indices = tuple(column for _, column in matches)
            mean = all_features[:, indices].mean(dim=1)
            all_features[:, index] = F.normalize(mean, dim=0)

    stem, ext = os.path.splitext(args.out_path)
    np.save(f"{stem}_avgs{ext}", all_features.cpu().numpy())
    if zeroed:
        logger.warning(
            "Zeroed out %d nodes because they didn't have any genus or species-level labels.",
            zeroed,
        )
113
+
114
+
115
def convert_txt_features_to_species_only(name_lookup):
    """Extract the columns of all seven-part names into their own files.

    Writes the selected feature columns to ``args.out_path`` with
    ``_species`` inserted before the extension, and the matching names, in
    the same order, to ``<stem>_species.json``.
    """
    assert os.path.isfile(args.out_path)

    all_features = np.load(args.out_path)
    logger.info("Loaded text features from disk.")

    # Keep only full-depth entries (paths with exactly seven parts).
    species = [
        (path, index) for path, index in name_lookup.descendants() if len(path) == 7
    ]
    species_features = np.zeros((512, len(species)), dtype=np.float32)
    species_names = []

    for column, (name, source_index) in enumerate(tqdm(species)):
        species_features[:, column] = all_features[:, source_index]
        species_names.append(name)

    stem, ext = os.path.splitext(args.out_path)
    np.save(f"{stem}_species{ext}", species_features)
    with open(f"{stem}_species.json", "w") as fd:
        json.dump(species_names, fd, indent=2)
133
+
134
+
135
def get_name_lookup(catalog_path, cache_path):
    """Return the taxonomic tree, loading it from ``cache_path`` if present.

    When the cache file does not exist, the tree is built from the CSV at
    ``catalog_path`` (columns ``kingdom`` ... ``species``) and then written
    to ``cache_path`` as JSON.
    """
    if os.path.isfile(cache_path):
        with open(cache_path) as fd:
            return lib.TaxonomicTree.from_dict(json.load(fd))

    lookup = lib.TaxonomicTree()
    rank_columns = (
        "kingdom",
        "phylum",
        "class",
        "order",
        "family",
        "genus",
        "species",
    )

    with open(catalog_path) as fd:
        for row in tqdm(csv.DictReader(fd), desc="catalog"):
            # Keep the ranks up to (not including) the first missing one.
            # Stopping at any falsy value also covers fields that are absent
            # rather than empty strings, where list.index("") would raise.
            name = []
            for column in rank_columns:
                value = row[column]
                if not value:
                    break
                name.append(value)
            lookup.add(name)

    # Write to the path that was passed in, not to the global CLI argument,
    # so the function reads and writes the same cache file.
    with open(cache_path, "w") as fd:
        json.dump(lookup, fd, cls=lib.TaxonomicJsonEncoder)

    return lookup
163
+
164
+
165
+ if __name__ == "__main__":
166
+ parser = argparse.ArgumentParser()
167
+ parser.add_argument(
168
+ "--catalog-path",
169
+ help="Path to the catalog.csv file from TreeOfLife-10M.",
170
+ required=True,
171
+ )
172
+ parser.add_argument("--out-path", help="Path to the output file.", required=True)
173
+ parser.add_argument(
174
+ "--name-cache-path",
175
+ help="Path to the name cache file.",
176
+ default="name_lookup.json",
177
+ )
178
+ parser.add_argument("--batch-size", help="Batch size.", default=2**15, type=int)
179
+ args = parser.parse_args()
180
+
181
+ name_lookup = get_name_lookup(args.catalog_path, cache_path=args.name_cache_path)
182
+ logger.info("Got name lookup.")
183
+
184
+ model = create_model(model_str, output_dict=True, require_pretrained=True)
185
+ model = model.to(device)
186
+ logger.info("Created model.")
187
+ model = torch.compile(model)
188
+ logger.info("Compiled model.")
189
+
190
+ tokenizer = get_tokenizer(tokenizer_str)
191
+ write_txt_features(name_lookup)
192
+ convert_txt_features_to_avgs(name_lookup)
193
+ convert_txt_features_to_species_only(name_lookup)
name_lookup.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:20d731d9d901f1c17927187bc87e4a2513279845a1a6ba5982dbf779f2ac1434
3
+ size 26462858
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ open_clip_torch
2
+ torchvision
3
+ torch
4
+ gradio
templates.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ openai_imagenet_template = [
2
+ lambda c: f"a bad photo of a {c}.",
3
+ lambda c: f"a photo of many {c}.",
4
+ lambda c: f"a sculpture of a {c}.",
5
+ lambda c: f"a photo of the hard to see {c}.",
6
+ lambda c: f"a low resolution photo of the {c}.",
7
+ lambda c: f"a rendering of a {c}.",
8
+ lambda c: f"graffiti of a {c}.",
9
+ lambda c: f"a bad photo of the {c}.",
10
+ lambda c: f"a cropped photo of the {c}.",
11
+ lambda c: f"a tattoo of a {c}.",
12
+ lambda c: f"the embroidered {c}.",
13
+ lambda c: f"a photo of a hard to see {c}.",
14
+ lambda c: f"a bright photo of a {c}.",
15
+ lambda c: f"a photo of a clean {c}.",
16
+ lambda c: f"a photo of a dirty {c}.",
17
+ lambda c: f"a dark photo of the {c}.",
18
+ lambda c: f"a drawing of a {c}.",
19
+ lambda c: f"a photo of my {c}.",
20
+ lambda c: f"the plastic {c}.",
21
+ lambda c: f"a photo of the cool {c}.",
22
+ lambda c: f"a close-up photo of a {c}.",
23
+ lambda c: f"a black and white photo of the {c}.",
24
+ lambda c: f"a painting of the {c}.",
25
+ lambda c: f"a painting of a {c}.",
26
+ lambda c: f"a pixelated photo of the {c}.",
27
+ lambda c: f"a sculpture of the {c}.",
28
+ lambda c: f"a bright photo of the {c}.",
29
+ lambda c: f"a cropped photo of a {c}.",
30
+ lambda c: f"a plastic {c}.",
31
+ lambda c: f"a photo of the dirty {c}.",
32
+ lambda c: f"a jpeg corrupted photo of a {c}.",
33
+ lambda c: f"a blurry photo of the {c}.",
34
+ lambda c: f"a photo of the {c}.",
35
+ lambda c: f"a good photo of the {c}.",
36
+ lambda c: f"a rendering of the {c}.",
37
+ lambda c: f"a {c} in a video game.",
38
+ lambda c: f"a photo of one {c}.",
39
+ lambda c: f"a doodle of a {c}.",
40
+ lambda c: f"a close-up photo of the {c}.",
41
+ lambda c: f"a photo of a {c}.",
42
+ lambda c: f"the origami {c}.",
43
+ lambda c: f"the {c} in a video game.",
44
+ lambda c: f"a sketch of a {c}.",
45
+ lambda c: f"a doodle of the {c}.",
46
+ lambda c: f"a origami {c}.",
47
+ lambda c: f"a low resolution photo of a {c}.",
48
+ lambda c: f"the toy {c}.",
49
+ lambda c: f"a rendition of the {c}.",
50
+ lambda c: f"a photo of the clean {c}.",
51
+ lambda c: f"a photo of a large {c}.",
52
+ lambda c: f"a rendition of a {c}.",
53
+ lambda c: f"a photo of a nice {c}.",
54
+ lambda c: f"a photo of a weird {c}.",
55
+ lambda c: f"a blurry photo of a {c}.",
56
+ lambda c: f"a cartoon {c}.",
57
+ lambda c: f"art of a {c}.",
58
+ lambda c: f"a sketch of the {c}.",
59
+ lambda c: f"a embroidered {c}.",
60
+ lambda c: f"a pixelated photo of a {c}.",
61
+ lambda c: f"itap of the {c}.",
62
+ lambda c: f"a jpeg corrupted photo of the {c}.",
63
+ lambda c: f"a good photo of a {c}.",
64
+ lambda c: f"a plushie {c}.",
65
+ lambda c: f"a photo of the nice {c}.",
66
+ lambda c: f"a photo of the small {c}.",
67
+ lambda c: f"a photo of the weird {c}.",
68
+ lambda c: f"the cartoon {c}.",
69
+ lambda c: f"art of the {c}.",
70
+ lambda c: f"a drawing of the {c}.",
71
+ lambda c: f"a photo of the large {c}.",
72
+ lambda c: f"a black and white photo of a {c}.",
73
+ lambda c: f"the plushie {c}.",
74
+ lambda c: f"a dark photo of a {c}.",
75
+ lambda c: f"itap of a {c}.",
76
+ lambda c: f"graffiti of the {c}.",
77
+ lambda c: f"a toy {c}.",
78
+ lambda c: f"itap of my {c}.",
79
+ lambda c: f"a photo of a cool {c}.",
80
+ lambda c: f"a photo of a small {c}.",
81
+ lambda c: f"a tattoo of the {c}.",
82
+ ]
test_lib.py ADDED
@@ -0,0 +1,481 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import lib
2
+
3
+
4
def test_taxonomiclookup_empty():
    """A freshly constructed tree reports a size of zero."""
    tree = lib.TaxonomicTree()
    assert tree.size == 0
7
+
8
+
9
def test_taxonomiclookup_kingdom_size():
    """Adding a one-part name grows the tree by one node."""
    tree = lib.TaxonomicTree()
    kingdom_only = ("Animalia",)
    tree.add(kingdom_only)
    assert tree.size == 1
15
+
16
+
17
def test_taxonomiclookup_genus_size():
    """Adding a six-part name creates six nodes."""
    genus = (
        "Animalia",
        "Chordata",
        "Aves",
        "Accipitriformes",
        "Accipitridae",
        "Halieaeetus",
    )
    tree = lib.TaxonomicTree()
    tree.add(genus)
    assert tree.size == 6
32
+
33
+
34
def test_taxonomictree_kingdom_children():
    """With no argument, children() lists the kingdoms and their indices."""
    genus = (
        "Animalia",
        "Chordata",
        "Aves",
        "Accipitriformes",
        "Accipitridae",
        "Halieaeetus",
    )
    tree = lib.TaxonomicTree()
    tree.add(genus)

    assert tree.children() == {("Animalia", 0)}
51
+
52
+
53
def test_taxonomiclookup_children_of_animal_only_birds():
    """With only birds added, Animalia has the single child Chordata."""
    aves = ("Animalia", "Chordata", "Aves")
    taxa = (
        (*aves, "Accipitriformes", "Accipitridae", "Halieaeetus", "leucocephalus"),
        (*aves, "Strigiformes", "Strigidae", "Ninox", "scutulata"),
        (*aves, "Strigiformes", "Strigidae", "Ninox", "plesseni"),
    )
    tree = lib.TaxonomicTree()
    for taxon in taxa:
        tree.add(taxon)

    assert tree.children(("Animalia",)) == {("Chordata", 1)}
93
+
94
+
95
def test_taxonomiclookup_children_of_animal():
    """Animalia lists both phyla, indexed in order of first insertion."""
    aves = ("Animalia", "Chordata", "Aves")
    taxa = (
        (*aves, "Accipitriformes", "Accipitridae", "Halieaeetus", "leucocephalus"),
        (*aves, "Strigiformes", "Strigidae", "Ninox", "scutulata"),
        (*aves, "Strigiformes", "Strigidae", "Ninox", "plesseni"),
        ("Animalia", "Chordata", "Mammalia", "Primates", "Hominidae", "Gorilla", "gorilla"),
        ("Animalia", "Arthropoda", "Insecta", "Hymenoptera", "Apidae", "Bombus", "balteatus"),
    )
    tree = lib.TaxonomicTree()
    for taxon in taxa:
        tree.add(taxon)

    assert tree.children(("Animalia",)) == {("Chordata", 1), ("Arthropoda", 17)}
157
+
158
+
159
def test_taxonomiclookup_children_of_chordata():
    """Chordata lists both classes, indexed in order of first insertion."""
    aves = ("Animalia", "Chordata", "Aves")
    taxa = (
        (*aves, "Accipitriformes", "Accipitridae", "Halieaeetus", "leucocephalus"),
        (*aves, "Strigiformes", "Strigidae", "Ninox", "scutulata"),
        (*aves, "Strigiformes", "Strigidae", "Ninox", "plesseni"),
        ("Animalia", "Chordata", "Mammalia", "Primates", "Hominidae", "Gorilla", "gorilla"),
        ("Animalia", "Arthropoda", "Insecta", "Hymenoptera", "Apidae", "Bombus", "balteatus"),
    )
    tree = lib.TaxonomicTree()
    for taxon in taxa:
        tree.add(taxon)

    assert tree.children(("Animalia", "Chordata")) == {("Aves", 2), ("Mammalia", 12)}
221
+
222
+
223
def test_taxonomiclookup_children_of_strigiformes():
    """Check the children of the Strigiformes path.

    Five taxa are added in a fixed order; the test expects ``children`` of
    ("Animalia", "Chordata", "Aves", "Strigiformes") to be exactly
    ``{("Strigidae", 8)}``.
    """
    chordata = ("Animalia", "Chordata")
    aves = chordata + ("Aves",)
    strigiformes = aves + ("Strigiformes",)
    ninox = strigiformes + ("Strigidae", "Ninox")
    taxa = [
        aves + ("Accipitriformes", "Accipitridae", "Halieaeetus", "leucocephalus"),
        ninox + ("scutulata",),
        ninox + ("plesseni",),
        chordata + ("Mammalia", "Primates", "Hominidae", "Gorilla", "gorilla"),
        (
            "Animalia",
            "Arthropoda",
            "Insecta",
            "Hymenoptera",
            "Apidae",
            "Bombus",
            "balteatus",
        ),
    ]

    tree = lib.TaxonomicTree()
    for taxon in taxa:
        tree.add(taxon)

    assert tree.children(strigiformes) == {("Strigidae", 8)}
285
+
286
+
287
def test_taxonomiclookup_children_of_ninox():
    """Check the children of the Ninox path.

    Five taxa are added in a fixed order, two of them under Ninox; the test
    expects ``children`` to return ``("scutulata", 10)`` and
    ``("plesseni", 11)``.
    """
    chordata = ("Animalia", "Chordata")
    aves = chordata + ("Aves",)
    ninox = aves + ("Strigiformes", "Strigidae", "Ninox")
    taxa = [
        aves + ("Accipitriformes", "Accipitridae", "Halieaeetus", "leucocephalus"),
        ninox + ("scutulata",),
        ninox + ("plesseni",),
        chordata + ("Mammalia", "Primates", "Hominidae", "Gorilla", "gorilla"),
        (
            "Animalia",
            "Arthropoda",
            "Insecta",
            "Hymenoptera",
            "Apidae",
            "Bombus",
            "balteatus",
        ),
    ]

    tree = lib.TaxonomicTree()
    for taxon in taxa:
        tree.add(taxon)

    assert tree.children(ninox) == {("scutulata", 10), ("plesseni", 11)}
351
+
352
+
353
def test_taxonomiclookup_children_of_gorilla():
    """Check that a full seven-name path has no children.

    Five taxa are added in a fixed order; querying ``children`` with the
    complete Gorilla gorilla path is expected to return an empty set.
    """
    chordata = ("Animalia", "Chordata")
    aves = chordata + ("Aves",)
    ninox = aves + ("Strigiformes", "Strigidae", "Ninox")
    gorilla = chordata + ("Mammalia", "Primates", "Hominidae", "Gorilla", "gorilla")
    taxa = [
        aves + ("Accipitriformes", "Accipitridae", "Halieaeetus", "leucocephalus"),
        ninox + ("scutulata",),
        ninox + ("plesseni",),
        gorilla,
        (
            "Animalia",
            "Arthropoda",
            "Insecta",
            "Hymenoptera",
            "Apidae",
            "Bombus",
            "balteatus",
        ),
    ]

    tree = lib.TaxonomicTree()
    for taxon in taxa:
        tree.add(taxon)

    assert tree.children(gorilla) == set()
425
+
426
+
427
def test_taxonomictree_descendants_last():
    """Check ``descendants`` when the prefix is the full added path.

    After adding one seven-name path, the test expects ``descendants`` of
    that same path to yield a single entry: the path paired with 6.
    """
    full_path = ("A", "B", "C", "D", "E", "F", "G")

    tree = lib.TaxonomicTree()
    tree.add(full_path)

    assert list(tree.descendants(full_path)) == [(full_path, 6)]
438
+
439
+
440
def test_taxonomictree_descendants_entire_tree():
    """Check ``descendants`` called with no arguments.

    After adding ("A", "B"), the test expects the entries (("A",), 0) and
    (("A", "B"), 1), in that order.
    """
    tree = lib.TaxonomicTree()
    tree.add(("A", "B"))

    found = list(tree.descendants())

    assert found == [(("A",), 0), (("A", "B"), 1)]
452
+
453
+
454
def test_taxonomictree_descendants_entire_tree_with_prefix():
    """Check ``descendants`` called with the keyword ``prefix=("A",)``.

    After adding ("A", "B"), the test expects the prefix itself and its one
    extension: (("A",), 0) followed by (("A", "B"), 1).
    """
    tree = lib.TaxonomicTree()
    tree.add(("A", "B"))

    found = list(tree.descendants(prefix=("A",)))

    assert found == [(("A",), 0), (("A", "B"), 1)]
466
+
467
+
468
def test_taxonomictree_descendants_general():
    """Check ``descendants`` for a prefix in the middle of a single path.

    After adding one seven-name path, the test expects ``descendants`` of
    its first four names to yield every path from length four to seven,
    each paired with its length minus one (3 through 6), in that order.
    """
    full_path = ("A", "B", "C", "D", "E", "F", "G")

    tree = lib.TaxonomicTree()
    tree.add(full_path)

    found = list(tree.descendants(full_path[:4]))

    # Builds (("A".."D"), 3) through (("A".."G"), 6).
    wanted = [(full_path[: depth + 1], depth) for depth in range(3, 7)]
    assert found == wanted
txt_emb.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b4a3c3412c3dae49cf92cc760aba5ee84227362adf1eb08f04dd50ee2a756e43
3
+ size 969818240
txt_emb_species.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:844e6fabc06cac072214d566b78f40825b154efa9479eb11285030ca038b2ece
3
+ size 65731052
txt_emb_species.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:91ce02dff2433222e3138b8bf7eefa1dd74b30f4d406c16cd3301f66d65ab4ed
3
+ size 787435648