Samuel Stevens committed
Commit 7b4abf1
1 Parent(s): 484209d

Add zero-shot example.

Files changed (3)
  1. README.md +2 -0
  2. examples/README.md +18 -0
  3. examples/zero_shot.py +298 -0
README.md CHANGED
@@ -36,6 +36,8 @@ It is trained on [TreeOfLife-10M](https://huggingface.co/datasets/imageomics/Tre
  Through rigorous benchmarking on a diverse set of fine-grained biological classification tasks, BioCLIP consistently outperformed existing baselines by 17% to 20% absolute.
  Through intrinsic evaluation, we found that BioCLIP learned a hierarchical representation aligned to the tree of life, which demonstrates its potential for robust generalizability.

+ **See the `examples/` directory for examples of how to use BioCLIP in zero-shot and few-shot settings.**
+
  ## Model Details

  ### Model Description
examples/README.md ADDED
@@ -0,0 +1,18 @@
+ # Examples
+
+ ## Zero-Shot Classification
+
+ ```sh
+ pip install torch # whatever version you want
+ pip install open_clip_torch numpy tqdm torchvision
+ ```
+
+ Suppose you want to evaluate BioCLIP on zero-shot classification on two tasks, `<DATASET-NAME>` and `<DATASET2-NAME>`.
+ You can use `examples/zero_shot.py` to get top-1 and top-5 accuracy, assuming each task is arranged the way `torchvision`'s [`ImageFolder`](https://pytorch.org/vision/stable/generated/torchvision.datasets.ImageFolder.html) expects.
+
+ ```sh
+ python examples/zero_shot.py \
+     --datasets <DATASET-NAME>=<DATASET-FOLDER> <DATASET2-NAME>=<DATASET2-FOLDER>
+ ```
+
+ This will write your results to `logs/bioclip-zero-shot/results.json`.
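For reference, each `<DATASET-FOLDER>` passed to `--datasets` should follow the standard `ImageFolder` layout: one subdirectory per class, with the images for that class inside. The class and file names below are placeholders:

```sh
<DATASET-FOLDER>/
├── Panthera_leo/
│   ├── 0001.jpg
│   └── 0002.jpg
└── Canis_lupus/
    └── 0003.jpg
```

The resulting `results.json` maps each dataset name to its top-1 and top-5 accuracy in percent, under the keys `<DATASET-NAME>-top1` and `<DATASET-NAME>-top5`.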
examples/zero_shot.py ADDED
@@ -0,0 +1,298 @@
+ """
+ Do zero-shot image classification.
+
+ Writes the results to JSON files in the logs directory.
+ """
+ import argparse
+ import ast
+ import contextlib
+ import json
+ import logging
+ import os
+ import random
+ import sys
+
+ import numpy as np
+ import open_clip
+ import torch
+ import torch.nn.functional as F
+ from torchvision import datasets
+ from tqdm import tqdm
+
+ log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
+ logging.basicConfig(level=logging.INFO, format=log_format)
+ logger = logging.getLogger("main")
+
+ openai_templates = [
+     lambda c: f"a bad photo of a {c}.",
+     lambda c: f"a photo of many {c}.",
+     lambda c: f"a sculpture of a {c}.",
+     lambda c: f"a photo of the hard to see {c}.",
+     lambda c: f"a low resolution photo of the {c}.",
+     lambda c: f"a rendering of a {c}.",
+     lambda c: f"graffiti of a {c}.",
+     lambda c: f"a bad photo of the {c}.",
+     lambda c: f"a cropped photo of the {c}.",
+     lambda c: f"a tattoo of a {c}.",
+     lambda c: f"the embroidered {c}.",
+     lambda c: f"a photo of a hard to see {c}.",
+     lambda c: f"a bright photo of a {c}.",
+     lambda c: f"a photo of a clean {c}.",
+     lambda c: f"a photo of a dirty {c}.",
+     lambda c: f"a dark photo of the {c}.",
+     lambda c: f"a drawing of a {c}.",
+     lambda c: f"a photo of my {c}.",
+     lambda c: f"the plastic {c}.",
+     lambda c: f"a photo of the cool {c}.",
+     lambda c: f"a close-up photo of a {c}.",
+     lambda c: f"a black and white photo of the {c}.",
+     lambda c: f"a painting of the {c}.",
+     lambda c: f"a painting of a {c}.",
+     lambda c: f"a pixelated photo of the {c}.",
+     lambda c: f"a sculpture of the {c}.",
+     lambda c: f"a bright photo of the {c}.",
+     lambda c: f"a cropped photo of a {c}.",
+     lambda c: f"a plastic {c}.",
+     lambda c: f"a photo of the dirty {c}.",
+     lambda c: f"a jpeg corrupted photo of a {c}.",
+     lambda c: f"a blurry photo of the {c}.",
+     lambda c: f"a photo of the {c}.",
+     lambda c: f"a good photo of the {c}.",
+     lambda c: f"a rendering of the {c}.",
+     lambda c: f"a {c} in a video game.",
+     lambda c: f"a photo of one {c}.",
+     lambda c: f"a doodle of a {c}.",
+     lambda c: f"a close-up photo of the {c}.",
+     lambda c: f"a photo of a {c}.",
+     lambda c: f"the origami {c}.",
+     lambda c: f"the {c} in a video game.",
+     lambda c: f"a sketch of a {c}.",
+     lambda c: f"a doodle of the {c}.",
+     lambda c: f"a origami {c}.",
+     lambda c: f"a low resolution photo of a {c}.",
+     lambda c: f"the toy {c}.",
+     lambda c: f"a rendition of the {c}.",
+     lambda c: f"a photo of the clean {c}.",
+     lambda c: f"a photo of a large {c}.",
+     lambda c: f"a rendition of a {c}.",
+     lambda c: f"a photo of a nice {c}.",
+     lambda c: f"a photo of a weird {c}.",
+     lambda c: f"a blurry photo of a {c}.",
+     lambda c: f"a cartoon {c}.",
+     lambda c: f"art of a {c}.",
+     lambda c: f"a sketch of the {c}.",
+     lambda c: f"a embroidered {c}.",
+     lambda c: f"a pixelated photo of a {c}.",
+     lambda c: f"itap of the {c}.",
+     lambda c: f"a jpeg corrupted photo of the {c}.",
+     lambda c: f"a good photo of a {c}.",
+     lambda c: f"a plushie {c}.",
+     lambda c: f"a photo of the nice {c}.",
+     lambda c: f"a photo of the small {c}.",
+     lambda c: f"a photo of the weird {c}.",
+     lambda c: f"the cartoon {c}.",
+     lambda c: f"art of the {c}.",
+     lambda c: f"a drawing of the {c}.",
+     lambda c: f"a photo of the large {c}.",
+     lambda c: f"a black and white photo of a {c}.",
+     lambda c: f"the plushie {c}.",
+     lambda c: f"a dark photo of a {c}.",
+     lambda c: f"itap of a {c}.",
+     lambda c: f"graffiti of the {c}.",
+     lambda c: f"a toy {c}.",
+     lambda c: f"itap of my {c}.",
+     lambda c: f"a photo of a cool {c}.",
+     lambda c: f"a photo of a small {c}.",
+     lambda c: f"a tattoo of the {c}.",
+ ]
+
+
+ def parse_args(args):
+     class ParseKwargs(argparse.Action):
+         def __call__(self, parser, namespace, values, option_string=None):
+             kw = {}
+             for value in values:
+                 key, value = value.split("=")
+                 try:
+                     kw[key] = ast.literal_eval(value)
+                 except (ValueError, SyntaxError):
+                     # fallback to string (avoid need to escape on command line)
+                     kw[key] = str(value)
+             setattr(namespace, self.dest, kw)
+
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--datasets",
+         type=str,
+         default=None,
+         nargs="+",
+         help="Path to dir(s) with validation data. In the format NAME=PATH.",
+         action=ParseKwargs,
+     )
+     parser.add_argument(
+         "--logs", type=str, default="./logs", help="Where to write logs."
+     )
+     parser.add_argument(
+         "--exp", type=str, default="bioclip-zero-shot", help="Experiment name."
+     )
+     parser.add_argument(
+         "--workers", type=int, default=8, help="Number of dataloader workers per GPU."
+     )
+     parser.add_argument(
+         "--batch-size", type=int, default=64, help="Batch size per GPU."
+     )
+     parser.add_argument(
+         "--precision",
+         choices=["amp", "amp_bf16", "amp_bfloat16", "bf16", "fp32"],
+         default="amp",
+         help="Floating point precision.",
+     )
+     parser.add_argument("--seed", type=int, default=0, help="Default random seed.")
+     args = parser.parse_args(args)
+     os.makedirs(os.path.join(args.logs, args.exp), exist_ok=True)
+
+     return args
+
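+ # Example (hypothetical paths): the ParseKwargs action above turns
+ #     --datasets birds=/data/birds insects=/data/insects
+ # into args.datasets == {"birds": "/data/birds", "insects": "/data/insects"}.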
+
+ def make_txt_features(model, classnames, templates, args):
+     tokenizer = open_clip.get_tokenizer("hf-hub:imageomics/bioclip")
+     with torch.no_grad():
+         txt_features = []
+         for classname in tqdm(classnames):
+             classname = " ".join(word for word in classname.split("_") if word)
+             texts = [template(classname) for template in templates]  # format with class
+             texts = tokenizer(texts).to(args.device)  # tokenize
+             class_embeddings = model.encode_text(texts)
+             class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
+             class_embedding /= class_embedding.norm()
+             txt_features.append(class_embedding)
+         txt_features = torch.stack(txt_features, dim=1).to(args.device)
+     return txt_features
+
+
+ def accuracy(output, target, topk=(1,)):
+     pred = output.topk(max(topk), 1, True, True)[1].t()
+     correct = pred.eq(target.view(1, -1).expand_as(pred))
+     return [correct[:k].reshape(-1).float().sum(0, keepdim=True).item() for k in topk]
+
+
+ def get_autocast(precision):
+     if precision == "amp":
+         return torch.cuda.amp.autocast
+     elif precision == "amp_bfloat16" or precision == "amp_bf16":
+         # amp_bfloat16 is more stable than amp float16 for clip training
+         return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
+     else:
+         return contextlib.suppress
+
+
+ def run(model, txt_features, dataloader, args):
+     autocast = get_autocast(args.precision)
+     cast_dtype = open_clip.get_cast_dtype(args.precision)
+
+     top1, top5, n = 0.0, 0.0, 0.0
+
+     with torch.no_grad():
+         for images, targets in tqdm(dataloader, unit_scale=args.batch_size):
+             images = images.to(args.device)
+             if cast_dtype is not None:
+                 images = images.to(dtype=cast_dtype)
+             targets = targets.to(args.device)
+
+             with autocast():
+                 image_features = model.encode_image(images)
+                 image_features = F.normalize(image_features, dim=-1)
+                 logits = model.logit_scale.exp() * image_features @ txt_features
+
+             # Measure accuracy
+             acc1, acc5 = accuracy(logits, targets, topk=(1, 5))
+             top1 += acc1
+             top5 += acc5
+             n += images.size(0)
+
+     top1 = top1 / n
+     top5 = top5 / n
+     return top1, top5
+
+
+ def evaluate(model, data, args):
+     results = {}
+
+     logger.info("Starting zero-shot classification.")
+
+     for split in data:
+         logger.info("Building zero-shot %s classifier.", split)
+
+         classnames = data[split].dataset.classes
+         classnames = [name.replace("_", " ") for name in classnames]
+
+         txt_features = make_txt_features(model, classnames, openai_templates, args)
+
+         logger.info("Got text features.")
+         top1, top5 = run(model, txt_features, data[split], args)
+
+         logger.info("%s-top1: %.3f", split, top1 * 100)
+         logger.info("%s-top5: %.3f", split, top5 * 100)
+
+         results[f"{split}-top1"] = top1 * 100
+         results[f"{split}-top5"] = top5 * 100
+
+         logger.info("Finished zero-shot %s.", split)
+
+     logger.info("Finished zero-shot classification.")
+
+     return results
+
+
+ if __name__ == "__main__":
+     args = parse_args(sys.argv[1:])
+
+     if torch.cuda.is_available():
+         # This enables tf32 on Ampere GPUs which is only 8% slower than
+         # float16 and almost as accurate as float32
+         # This was a default in pytorch until 1.12
+         torch.backends.cuda.matmul.allow_tf32 = True
+         torch.backends.cudnn.benchmark = True
+         torch.backends.cudnn.deterministic = False
+
+     # Init torch device
+     if torch.cuda.is_available():
+         device = "cuda:0"
+         torch.cuda.set_device(device)
+     else:
+         device = "cpu"
+     args.device = device
+
+     # Random seeding
+     torch.manual_seed(args.seed)
+     np.random.seed(args.seed)
+     random.seed(args.seed)
+
+     # Load model and move it to the evaluation device.
+     model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
+         "hf-hub:imageomics/bioclip"
+     )
+     model = model.to(args.device)
+
+     # Write run parameters.
+     params_file = os.path.join(args.logs, args.exp, "params.json")
+     with open(params_file, "w") as fd:
+         params = {name: getattr(args, name) for name in vars(args)}
+         json.dump(params, fd, sort_keys=True, indent=4)
+
+     # Initialize datasets.
+     data = {}
+     for split, path in args.datasets.items():
+         data[split] = torch.utils.data.DataLoader(
+             datasets.ImageFolder(path, transform=preprocess_val),
+             batch_size=args.batch_size,
+             num_workers=args.workers,
+             sampler=None,
+             shuffle=False,
+         )
+
+     model.eval()
+     results = evaluate(model, data, args)
+
+     results_file = os.path.join(args.logs, args.exp, "results.json")
+     with open(results_file, "w") as fd:
+         json.dump(results, fd, indent=4, sort_keys=True)
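
For a quick sanity check outside the full script, the same model can classify a single image in a few lines. This is a minimal sketch built on the same `open_clip` calls the script uses; the class names and image path are placeholders:

```python
import open_clip
import torch
import torch.nn.functional as F
from PIL import Image

# Load BioCLIP and its preprocessing transform from the Hugging Face Hub.
model, _, preprocess = open_clip.create_model_and_transforms("hf-hub:imageomics/bioclip")
tokenizer = open_clip.get_tokenizer("hf-hub:imageomics/bioclip")
model.eval()

# Placeholder class names and image path.
classnames = ["Panthera leo", "Canis lupus", "Ursus arctos"]
texts = tokenizer([f"a photo of a {c}." for c in classnames])
image = preprocess(Image.open("example.jpg")).unsqueeze(0)

with torch.no_grad():
    # Embed image and text, then score every class against the image.
    img_feat = F.normalize(model.encode_image(image), dim=-1)
    txt_feat = F.normalize(model.encode_text(texts), dim=-1)
    probs = (model.logit_scale.exp() * img_feat @ txt_feat.T).softmax(dim=-1)

for name, p in zip(classnames, probs[0].tolist()):
    print(f"{name}: {p:.3f}")
```

`zero_shot.py` does the same thing at dataset scale, with the template ensemble averaged into one text embedding per class.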