Samuel Stevens commited on
Commit
a33c93d
1 Parent(s): 6ee7e7c

try hierarchical averaging

Browse files
Files changed (5) hide show
  1. app.py +1 -0
  2. examples/Sarcoscypha-coccinea.jpeg +3 -0
  3. lib.py +50 -6
  4. make_txt_embedding.py +48 -5
  5. test_lib.py +57 -0
app.py CHANGED
@@ -37,6 +37,7 @@ open_domain_examples = [
37
  ["examples/Ursus-arctos.jpeg", "Species"],
38
  ["examples/Phoca-vitulina.png", "Species"],
39
  ["examples/Felis-catus.jpeg", "Genus"],
 
40
  ]
41
  zero_shot_examples = [
42
  [
 
37
  ["examples/Ursus-arctos.jpeg", "Species"],
38
  ["examples/Phoca-vitulina.png", "Species"],
39
  ["examples/Felis-catus.jpeg", "Genus"],
40
+ ["examples/Sarcoscypha-coccinea.jpeg", "Order"],
41
  ]
42
  zero_shot_examples = [
43
  [
examples/Sarcoscypha-coccinea.jpeg ADDED

Git LFS Details

  • SHA256: 84dfec1fe373d375cd31f129dfd961dfa9d0b400575f9dd9610a08d900fd1cf9
  • Pointer size: 131 Bytes
  • Size of remote file: 409 kB
lib.py CHANGED
@@ -1,3 +1,13 @@
 
 
 
 
 
 
 
 
 
 
1
  import itertools
2
  import json
3
 
@@ -33,12 +43,30 @@ class TaxonomicNode:
33
 
34
  return self._children[first].children(rest)
35
 
36
- def __iter__(self):
37
- yield self.name, self.index
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  for child in self._children.values():
40
- for name, index in child:
41
- yield f"{self.name} {name}", index
42
 
43
  @classmethod
44
  def from_dict(cls, dct, root):
@@ -82,9 +110,25 @@ class TaxonomicTree:
82
 
83
  return self.kingdoms[first].children(rest)
84
 
85
- def __iter__(self):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  for kingdom in self.kingdoms.values():
87
- yield from kingdom
88
 
89
  def __len__(self):
90
  return self.size
 
1
+ """
2
+ Mostly a TaxonomicTree class that implements a taxonomy and some helpers for easily
3
+ walking and looking in the tree.
4
+
5
+ A tree is an arrangement of TaxonomicNodes.
6
+
7
+
8
+ """
9
+
10
+
11
  import itertools
12
  import json
13
 
 
43
 
44
  return self._children[first].children(rest)
45
 
46
+ def descendants(self, prefix=None):
47
+ """Iterates over all values in the subtree that match prefix."""
48
+
49
+ if not prefix:
50
+ yield (self.name,), self.index
51
+ for child in self._children.values():
52
+ for name, i in child.descendants():
53
+ yield (self.name, *name), i
54
+ return
55
+
56
+ first, rest = prefix[0], prefix[1:]
57
+ if first not in self._children:
58
+ return
59
+
60
+ for name, i in self._children[first].descendants(rest):
61
+ yield (self.name, *name), i
62
+
63
+ def values(self):
64
+ """Iterates over all (name, i) pairs in the tree."""
65
+ yield (self.name,), self.index
66
 
67
  for child in self._children.values():
68
+ for name, index in child.values():
69
+ yield (self.name, *name), index
70
 
71
  @classmethod
72
  def from_dict(cls, dct, root):
 
110
 
111
  return self.kingdoms[first].children(rest)
112
 
113
+ def descendants(self, prefix=None):
114
+ """Iterates over all values in the tree that match prefix."""
115
+ if not prefix:
116
+ # Give them all the subnodes
117
+ for kingdom in self.kingdoms.values():
118
+ yield from kingdom.descendants()
119
+
120
+ return
121
+
122
+ first, rest = prefix[0], prefix[1:]
123
+ if first not in self.kingdoms:
124
+ return
125
+
126
+ yield from self.kingdoms[first].descendants(rest)
127
+
128
+ def values(self):
129
+ """Iterates over all (name, i) pairs in the tree."""
130
  for kingdom in self.kingdoms.values():
131
+ yield from kingdom.values()
132
 
133
  def __len__(self):
134
  return self.size
make_txt_embedding.py CHANGED
@@ -6,20 +6,28 @@ import argparse
6
  import csv
7
  import json
8
  import os
 
9
 
10
  import numpy as np
11
  import torch
12
  import torch.nn.functional as F
 
13
  from open_clip import create_model, get_tokenizer
14
  from tqdm import tqdm
15
 
16
  import lib
17
  from templates import openai_imagenet_template
18
 
 
 
 
 
19
  model_str = "hf-hub:imageomics/bioclip"
20
  tokenizer_str = "ViT-B-16"
21
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
22
 
 
 
23
 
24
  @torch.no_grad()
25
  def write_txt_features(name_lookup):
@@ -38,7 +46,7 @@ def write_txt_features(name_lookup):
38
  ):
39
  # Skip if any non-zero elements
40
  if all_features[:, indices].any():
41
- print(f"Skipping batch {batch}")
42
  continue
43
 
44
  txts = [
@@ -59,6 +67,41 @@ def write_txt_features(name_lookup):
59
  np.save(args.out_path, all_features)
60
 
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  def get_name_lookup(catalog_path, cache_path):
63
  if os.path.isfile(cache_path):
64
  with open(cache_path) as fd:
@@ -106,14 +149,14 @@ if __name__ == "__main__":
106
  args = parser.parse_args()
107
 
108
  name_lookup = get_name_lookup(args.catalog_path, cache_path=args.name_cache_path)
109
- print("Got name lookup.")
110
 
111
  model = create_model(model_str, output_dict=True, require_pretrained=True)
112
  model = model.to(device)
113
- print("Created model.")
114
-
115
  model = torch.compile(model)
116
- print("Compiled model.")
117
 
118
  tokenizer = get_tokenizer(tokenizer_str)
119
  write_txt_features(name_lookup)
 
 
6
  import csv
7
  import json
8
  import os
9
+ import logging
10
 
11
  import numpy as np
12
  import torch
13
  import torch.nn.functional as F
14
+
15
  from open_clip import create_model, get_tokenizer
16
  from tqdm import tqdm
17
 
18
  import lib
19
  from templates import openai_imagenet_template
20
 
21
+ log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
22
+ logging.basicConfig(level=logging.INFO, format=log_format)
23
+ logger = logging.getLogger()
24
+
25
  model_str = "hf-hub:imageomics/bioclip"
26
  tokenizer_str = "ViT-B-16"
27
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
28
 
29
+ ranks = ("Kingdom", "Phylum", "Class", "Order", "Family", "Genus", "Species")
30
+
31
 
32
  @torch.no_grad()
33
  def write_txt_features(name_lookup):
 
46
  ):
47
  # Skip if any non-zero elements
48
  if all_features[:, indices].any():
49
+ logger.info(f"Skipping batch {batch}")
50
  continue
51
 
52
  txts = [
 
67
  np.save(args.out_path, all_features)
68
 
69
 
70
+ def convert_txt_features_to_avgs(name_lookup):
71
+ assert os.path.isfile(args.out_path)
72
+
73
+ # Put that big boy on the GPU. We're going fast.
74
+ all_features = torch.from_numpy(np.load(args.out_path)).to(device)
75
+ logger.info("Loaded text features from disk to %s.", device)
76
+
77
+ all_names = [set() for rank in ranks]
78
+ for name, index in tqdm(name_lookup.values()):
79
+ i = len(name) - 1
80
+ all_names[i].add((name, index))
81
+
82
+ zeroed = 0
83
+ for i, rank in reversed(list(enumerate(ranks))):
84
+ if rank == "Species":
85
+ continue
86
+ for name, index in tqdm(all_names[i], desc=rank):
87
+ species = tuple(zip(*((d, i) for d, i in name_lookup.descendants(prefix=name) if len(d) >= 7)))
88
+ if not species:
89
+ logger.warning("No species for %s.", " ".join(name))
90
+ all_features[:, index] = 0.0
91
+ zeroed += 1
92
+ continue
93
+
94
+
95
+ values, indices = species
96
+ mean = all_features[:, indices].mean(dim=1)
97
+ all_features[:, index] = F.normalize(mean, dim=0)
98
+
99
+ out_path, ext = os.path.splitext(args.out_path)
100
+ np.save(f"{out_path}_avgs{ext}", all_features.cpu().numpy())
101
+ if zeroed:
102
+ logger.warning("Zeroed out %d nodes because they didn't have any genus or species-level labels.", zeroed)
103
+
104
+
105
  def get_name_lookup(catalog_path, cache_path):
106
  if os.path.isfile(cache_path):
107
  with open(cache_path) as fd:
 
149
  args = parser.parse_args()
150
 
151
  name_lookup = get_name_lookup(args.catalog_path, cache_path=args.name_cache_path)
152
+ logger.info("Got name lookup.")
153
 
154
  model = create_model(model_str, output_dict=True, require_pretrained=True)
155
  model = model.to(device)
156
+ logger.info("Created model.")
 
157
  model = torch.compile(model)
158
+ logger.info("Compiled model.")
159
 
160
  tokenizer = get_tokenizer(tokenizer_str)
161
  write_txt_features(name_lookup)
162
+ convert_txt_features_to_avgs(name_lookup)
test_lib.py CHANGED
@@ -422,3 +422,60 @@ def test_taxonomiclookup_children_of_gorilla():
422
  )
423
  expected = set()
424
  assert actual == expected
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
422
  )
423
  expected = set()
424
  assert actual == expected
425
+
426
+
427
+ def test_taxonomictree_descendants_last():
428
+ lookup = lib.TaxonomicTree()
429
+
430
+ lookup.add(("A", "B", "C", "D", "E", "F", "G"))
431
+
432
+ actual = list(lookup.descendants(("A", "B", "C", "D", "E", "F", "G")))
433
+
434
+ expected = [
435
+ (("A", "B", "C", "D", "E", "F", "G"), 6),
436
+ ]
437
+ assert actual == expected
438
+
439
+
440
+ def test_taxonomictree_descendants_entire_tree():
441
+ lookup = lib.TaxonomicTree()
442
+
443
+ lookup.add(("A", "B"))
444
+
445
+ actual = list(lookup.descendants())
446
+
447
+ expected = [
448
+ (("A",), 0),
449
+ (("A", "B"), 1),
450
+ ]
451
+ assert actual == expected
452
+
453
+
454
+ def test_taxonomictree_descendants_entire_tree_with_prefix():
455
+ lookup = lib.TaxonomicTree()
456
+
457
+ lookup.add(("A", "B"))
458
+
459
+ actual = list(lookup.descendants(prefix=("A",)))
460
+
461
+ expected = [
462
+ (("A",), 0),
463
+ (("A", "B"), 1),
464
+ ]
465
+ assert actual == expected
466
+
467
+
468
+ def test_taxonomictree_descendants_general():
469
+ lookup = lib.TaxonomicTree()
470
+
471
+ lookup.add(("A", "B", "C", "D", "E", "F", "G"))
472
+
473
+ actual = list(lookup.descendants(("A", "B", "C", "D")))
474
+
475
+ expected = [
476
+ (("A", "B", "C", "D"), 3),
477
+ (("A", "B", "C", "D", "E"), 4),
478
+ (("A", "B", "C", "D", "E", "F"), 5),
479
+ (("A", "B", "C", "D", "E", "F", "G"), 6),
480
+ ]
481
+ assert actual == expected