Samuel Stevens commited on
Commit
6e5adf0
1 Parent(s): d4005aa

add open-domain classification back

Browse files
Files changed (5) hide show
  1. .gitattributes +1 -1
  2. app.py +115 -112
  3. make_txt_embedding.py +21 -0
  4. txt_emb_species.json +3 -0
  5. txt_emb_species.npy +3 -0
.gitattributes CHANGED
@@ -34,6 +34,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
 
37
- *lookup.json filter=lfs diff=lfs merge=lfs -text
38
  *.jpeg filter=lfs diff=lfs merge=lfs -text
39
  *.png filter=lfs diff=lfs merge=lfs -text
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
 
37
+ *.json filter=lfs diff=lfs merge=lfs -text
38
  *.jpeg filter=lfs diff=lfs merge=lfs -text
39
  *.png filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import json
2
  import os
3
 
@@ -8,15 +10,18 @@ import torch.nn.functional as F
8
  from open_clip import create_model, get_tokenizer
9
  from torchvision import transforms
10
 
11
- import lib
12
  from templates import openai_imagenet_template
13
 
14
  hf_token = os.getenv("HF_TOKEN")
15
 
16
  model_str = "hf-hub:imageomics/bioclip"
17
  tokenizer_str = "ViT-B-16"
18
- name_lookup_json = "name_lookup.json"
19
- txt_emb_npy = "txt_emb.npy"
 
 
 
 
20
 
21
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
22
 
@@ -33,12 +38,12 @@ preprocess_img = transforms.Compose(
33
 
34
  ranks = ("Kingdom", "Phylum", "Class", "Order", "Family", "Genus", "Species")
35
 
36
- # open_domain_examples = [
37
- # ["examples/Ursus-arctos.jpeg", "Species"],
38
- # ["examples/Phoca-vitulina.png", "Species"],
39
- # ["examples/Felis-catus.jpeg", "Genus"],
40
- # ["examples/Sarcoscypha-coccinea.jpeg", "Order"],
41
- # ]
42
  zero_shot_examples = [
43
  [
44
  "examples/Ursus-arctos.jpeg",
@@ -73,6 +78,10 @@ zero_shot_examples = [
73
  ]
74
 
75
 
 
 
 
 
76
  @torch.no_grad()
77
  def get_txt_features(classnames, templates):
78
  all_features = []
@@ -102,52 +111,38 @@ def zero_shot_classification(img, cls_str: str) -> dict[str, float]:
102
 
103
 
104
  @torch.no_grad()
105
- def open_domain_classification(img, rank: int) -> list[dict[str, float]]:
106
  """
107
- Predicts from the top of the tree of life down to the species.
 
 
108
  """
109
  img = preprocess_img(img).to(device)
110
  img_features = model.encode_image(img.unsqueeze(0))
111
  img_features = F.normalize(img_features, dim=-1)
112
 
113
- outputs = []
114
-
115
- name = []
116
- for _ in range(rank + 1):
117
- children = tuple(zip(*name_lookup.children(name)))
118
- if not children:
119
- break
120
- values, indices = children
121
- txt_features = txt_emb[:, indices].to(device)
122
- logits = (model.logit_scale.exp() * img_features @ txt_features).view(-1)
123
-
124
- probs = F.softmax(logits, dim=0).to("cpu").tolist()
125
- parent = " ".join(name)
126
- outputs.append(
127
- {f"{parent} {value}": prob for value, prob in zip(values, probs)}
128
- )
129
-
130
- top = values[logits.argmax()]
131
- name.append(top)
132
 
133
- while len(outputs) < 7:
134
- outputs.append({})
 
 
 
 
135
 
136
- return list(reversed(outputs))
 
 
 
137
 
 
138
 
139
- def change_output(choice):
140
- return [
141
- gr.Label(
142
- num_top_classes=5, label=rank, show_label=True, visible=(6 - i <= choice)
143
- )
144
- for i, rank in enumerate(reversed(ranks))
145
- ]
146
 
147
 
148
- def get_name_lookup(path):
149
- with open(path) as fd:
150
- return lib.TaxonomicTree.from_dict(json.load(fd))
151
 
152
 
153
  if __name__ == "__main__":
@@ -161,8 +156,9 @@ if __name__ == "__main__":
161
 
162
  tokenizer = get_tokenizer(tokenizer_str)
163
 
164
- name_lookup = get_name_lookup(name_lookup_json)
165
- txt_emb = torch.from_numpy(np.load(txt_emb_npy, mmap_mode="r"))
 
166
 
167
  done = txt_emb.any(axis=0).sum().item()
168
  total = txt_emb.shape[1]
@@ -173,69 +169,76 @@ if __name__ == "__main__":
173
  with gr.Blocks() as app:
174
  img_input = gr.Image(height=512)
175
 
176
- # with gr.Tab("Open-Ended"):
177
- # with gr.Row():
178
- # with gr.Column():
179
- # rank_dropdown = gr.Dropdown(
180
- # label="Taxonomic Rank",
181
- # info="Which taxonomic rank to predict. Fine-grained ranks (genus, species) are more challenging.",
182
- # choices=ranks,
183
- # value="Species",
184
- # type="index",
185
- # )
186
- # open_domain_btn = gr.Button("Submit", variant="primary")
187
- # gr.Examples(
188
- # examples=open_domain_examples,
189
- # inputs=[img_input, rank_dropdown],
190
- # )
191
-
192
- # with gr.Column():
193
- # open_domain_outputs = [
194
- # gr.Label(num_top_classes=5, label=rank, show_label=True)
195
- # for rank in reversed(ranks)
196
- # ]
197
- # open_domain_flag_btn = gr.Button("Flag Mistake", variant="primary")
198
-
199
- # open_domain_callback = gr.HuggingFaceDatasetSaver(
200
- # hf_token, "imageomics/bioclip-demo-open-domain-mistakes", private=True
201
- # )
202
- # open_domain_callback.setup(
203
- # [img_input, *open_domain_outputs], flagging_dir="logs/flagged"
204
- # )
205
- # open_domain_flag_btn.click(
206
- # lambda *args: open_domain_callback.flag(args),
207
- # [img_input, *open_domain_outputs],
208
- # None,
209
- # preprocess=False,
210
- # )
211
-
212
- # with gr.Tab("Zero-Shot"):
213
- with gr.Row():
214
- with gr.Column():
215
- classes_txt = gr.Textbox(
216
- placeholder="Canis familiaris (dog)\nFelis catus (cat)\n...",
217
- lines=3,
218
- label="Classes",
219
- show_label=True,
220
- info="Use taxonomic names where possible; include common names if possible.",
221
  )
222
- zero_shot_btn = gr.Button("Submit", variant="primary")
223
 
224
- with gr.Column():
225
- zero_shot_output = gr.Label(
226
- num_top_classes=5, label="Prediction", show_label=True
227
- )
228
- zero_shot_flag_btn = gr.Button("Flag Mistake", variant="primary")
229
-
230
- with gr.Row():
231
- gr.Examples(
232
- examples=zero_shot_examples,
233
- inputs=[img_input, classes_txt],
234
- cache_examples=True,
235
- fn=zero_shot_classification,
236
- outputs=[zero_shot_output],
237
  )
238
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
  zero_shot_callback = gr.HuggingFaceDatasetSaver(
240
  hf_token, "imageomics/bioclip-demo-zero-shot-mistakes", private=True
241
  )
@@ -249,15 +252,15 @@ if __name__ == "__main__":
249
  preprocess=False,
250
  )
251
 
252
- # rank_dropdown.change(
253
- # fn=change_output, inputs=rank_dropdown, outputs=open_domain_outputs
254
- # )
255
 
256
- # open_domain_btn.click(
257
- # fn=open_domain_classification,
258
- # inputs=[img_input, rank_dropdown],
259
- # outputs=open_domain_outputs,
260
- # )
261
 
262
  zero_shot_btn.click(
263
  fn=zero_shot_classification,
 
1
+ import collections
2
+ import heapq
3
  import json
4
  import os
5
 
 
10
  from open_clip import create_model, get_tokenizer
11
  from torchvision import transforms
12
 
 
13
  from templates import openai_imagenet_template
14
 
15
  hf_token = os.getenv("HF_TOKEN")
16
 
17
  model_str = "hf-hub:imageomics/bioclip"
18
  tokenizer_str = "ViT-B-16"
19
+
20
+ txt_emb_npy = "txt_emb_species.npy"
21
+ txt_names_json = "txt_emb_species.json"
22
+
23
+ min_prob = 1e-9
24
+ k = 5
25
 
26
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
27
 
 
38
 
39
  ranks = ("Kingdom", "Phylum", "Class", "Order", "Family", "Genus", "Species")
40
 
41
+ open_domain_examples = [
42
+ ["examples/Ursus-arctos.jpeg", "Species"],
43
+ ["examples/Phoca-vitulina.png", "Species"],
44
+ ["examples/Felis-catus.jpeg", "Genus"],
45
+ ["examples/Sarcoscypha-coccinea.jpeg", "Order"],
46
+ ]
47
  zero_shot_examples = [
48
  [
49
  "examples/Ursus-arctos.jpeg",
 
78
  ]
79
 
80
 
81
+ def indexed(lst, indices):
82
+ return [lst[i] for i in indices]
83
+
84
+
85
  @torch.no_grad()
86
  def get_txt_features(classnames, templates):
87
  all_features = []
 
111
 
112
 
113
  @torch.no_grad()
114
+ def open_domain_classification(img, rank: int) -> dict[str, float]:
115
  """
116
+ Predicts from the entire tree of life.
117
+ If targeting a higher rank than species, then this function predicts among all
118
+ species, then sums up species-level probabilities for the given rank.
119
  """
120
  img = preprocess_img(img).to(device)
121
  img_features = model.encode_image(img.unsqueeze(0))
122
  img_features = F.normalize(img_features, dim=-1)
123
 
124
+ logits = (model.logit_scale.exp() * img_features @ txt_emb).squeeze()
125
+ probs = F.softmax(logits, dim=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
+ # If predicting species, no need to sum probabilities.
128
+ if rank + 1 == len(ranks):
129
+ topk = probs.topk(k)
130
+ return {
131
+ " ".join(txt_names[i]): prob for i, prob in zip(topk.indices, topk.values)
132
+ }
133
 
134
+ # Sum up by the rank
135
+ output = collections.defaultdict(float)
136
+ for i in torch.nonzero(probs > min_prob).squeeze():
137
+ output[" ".join(txt_names[i][: rank + 1])] += probs[i]
138
 
139
+ topk_names = heapq.nlargest(k, output, key=output.get)
140
 
141
+ return {name: output[name] for name in topk_names}
 
 
 
 
 
 
142
 
143
 
144
+ def change_output(choice):
145
+ return gr.Label(num_top_classes=k, label=ranks[choice], show_label=True, value=None)
 
146
 
147
 
148
  if __name__ == "__main__":
 
156
 
157
  tokenizer = get_tokenizer(tokenizer_str)
158
 
159
+ txt_emb = torch.from_numpy(np.load(txt_emb_npy, mmap_mode="r")).to(device)
160
+ with open(txt_names_json) as fd:
161
+ txt_names = json.load(fd)
162
 
163
  done = txt_emb.any(axis=0).sum().item()
164
  total = txt_emb.shape[1]
 
169
  with gr.Blocks() as app:
170
  img_input = gr.Image(height=512)
171
 
172
+ with gr.Tab("Open-Ended"):
173
+ with gr.Row():
174
+ with gr.Column():
175
+ rank_dropdown = gr.Dropdown(
176
+ label="Taxonomic Rank",
177
+ info="Which taxonomic rank to predict. Fine-grained ranks (genus, species) are more challenging.",
178
+ choices=ranks,
179
+ value="Species",
180
+ type="index",
181
+ )
182
+ open_domain_btn = gr.Button("Submit", variant="primary")
183
+ with gr.Column():
184
+ open_domain_output = gr.Label(
185
+ num_top_classes=k,
186
+ label="Prediction",
187
+ show_label=True,
188
+ value=None,
189
+ )
190
+ open_domain_flag_btn = gr.Button("Flag Mistake", variant="primary")
191
+
192
+ with gr.Row():
193
+ gr.Examples(
194
+ examples=open_domain_examples,
195
+ inputs=[img_input, rank_dropdown],
196
+ cache_examples=True,
197
+ fn=open_domain_classification,
198
+ outputs=[open_domain_output],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  )
 
200
 
201
+ open_domain_callback = gr.HuggingFaceDatasetSaver(
202
+ hf_token, "imageomics/bioclip-demo-open-domain-mistakes", private=True
203
+ )
204
+ open_domain_callback.setup(
205
+ [img_input, rank_dropdown, open_domain_output],
206
+ flagging_dir="logs/flagged",
207
+ )
208
+ open_domain_flag_btn.click(
209
+ lambda *args: open_domain_callback.flag(args),
210
+ [img_input, rank_dropdown, open_domain_output],
211
+ None,
212
+ preprocess=False,
 
213
  )
214
 
215
+ with gr.Tab("Zero-Shot"):
216
+ with gr.Row():
217
+ with gr.Column():
218
+ classes_txt = gr.Textbox(
219
+ placeholder="Canis familiaris (dog)\nFelis catus (cat)\n...",
220
+ lines=3,
221
+ label="Classes",
222
+ show_label=True,
223
+ info="Use taxonomic names where possible; include common names if possible.",
224
+ )
225
+ zero_shot_btn = gr.Button("Submit", variant="primary")
226
+
227
+ with gr.Column():
228
+ zero_shot_output = gr.Label(
229
+ num_top_classes=k, label="Prediction", show_label=True
230
+ )
231
+ zero_shot_flag_btn = gr.Button("Flag Mistake", variant="primary")
232
+
233
+ with gr.Row():
234
+ gr.Examples(
235
+ examples=zero_shot_examples,
236
+ inputs=[img_input, classes_txt],
237
+ cache_examples=True,
238
+ fn=zero_shot_classification,
239
+ outputs=[zero_shot_output],
240
+ )
241
+
242
  zero_shot_callback = gr.HuggingFaceDatasetSaver(
243
  hf_token, "imageomics/bioclip-demo-zero-shot-mistakes", private=True
244
  )
 
252
  preprocess=False,
253
  )
254
 
255
+ rank_dropdown.change(
256
+ fn=change_output, inputs=rank_dropdown, outputs=[open_domain_output]
257
+ )
258
 
259
+ open_domain_btn.click(
260
+ fn=open_domain_classification,
261
+ inputs=[img_input, rank_dropdown],
262
+ outputs=[open_domain_output],
263
+ )
264
 
265
  zero_shot_btn.click(
266
  fn=zero_shot_classification,
make_txt_embedding.py CHANGED
@@ -112,6 +112,26 @@ def convert_txt_features_to_avgs(name_lookup):
112
  )
113
 
114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  def get_name_lookup(catalog_path, cache_path):
116
  if os.path.isfile(cache_path):
117
  with open(cache_path) as fd:
@@ -170,3 +190,4 @@ if __name__ == "__main__":
170
  tokenizer = get_tokenizer(tokenizer_str)
171
  write_txt_features(name_lookup)
172
  convert_txt_features_to_avgs(name_lookup)
 
 
112
  )
113
 
114
 
115
+ def convert_txt_features_to_species_only(name_lookup):
116
+ assert os.path.isfile(args.out_path)
117
+
118
+ all_features = np.load(args.out_path)
119
+ logger.info("Loaded text features from disk.")
120
+
121
+ species = [(d, i) for d, i in name_lookup.descendants() if len(d) == 7]
122
+ species_features = np.zeros((512, len(species)), dtype=np.float32)
123
+ species_names = [""] * len(species)
124
+
125
+ for new_i, (name, old_i) in enumerate(tqdm(species)):
126
+ species_features[:, new_i] = all_features[:, old_i]
127
+ species_names[new_i] = name
128
+
129
+ out_path, ext = os.path.splitext(args.out_path)
130
+ np.save(f"{out_path}_species{ext}", species_features)
131
+ with open(f"{out_path}_species.json", "w") as fd:
132
+ json.dump(species_names, fd, indent=2)
133
+
134
+
135
  def get_name_lookup(catalog_path, cache_path):
136
  if os.path.isfile(cache_path):
137
  with open(cache_path) as fd:
 
190
  tokenizer = get_tokenizer(tokenizer_str)
191
  write_txt_features(name_lookup)
192
  convert_txt_features_to_avgs(name_lookup)
193
+ convert_txt_features_to_species_only(name_lookup)
txt_emb_species.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c71babd1b7bc275a1dbb12fd36e6329bcc2487784c0b7be10c2f4d0031d34211
3
+ size 50445969
txt_emb_species.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:91ce02dff2433222e3138b8bf7eefa1dd74b30f4d406c16cd3301f66d65ab4ed
3
+ size 787435648