File size: 2,905 Bytes
705d528
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""
Makes the entire set of text embeddings for all possible names in the tree of life.
Uses the catalog.csv file from TreeOfLife-10M.
"""
import argparse
import csv
import json

import numpy as np
import torch
import torch.nn.functional as F
from open_clip import create_model, get_tokenizer
from tqdm import tqdm

import lib
from templates import openai_imagenet_template

# Model identifier passed to open_clip's create_model in the script entry point.
model_str = "hf-hub:imageomics/bioclip"
# Tokenizer identifier passed to open_clip's get_tokenizer in the script entry point.
tokenizer_str = "ViT-B-16"
# Use the first CUDA device when torch reports one is available; otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


@torch.no_grad()
def write_txt_features(name_lookup):
    """Encode every name in ``name_lookup`` and write the features to ``args.out_path``.

    Each name is rendered with every template in ``openai_imagenet_template``
    and encoded with the module-level ``model``. The per-template features are
    L2-normalized, averaged over the templates, re-normalized, and stored as
    one column of a float32 memmap of shape ``(512, name_lookup.size)``.

    Relies on the module-level ``args``, ``model``, ``tokenizer`` and
    ``device`` that the script entry point sets up before calling this.

    Args:
        name_lookup: Object exposing ``size`` and accepted by ``lib.batched``,
            which is assumed to yield ``(names, indices)`` pairs where
            ``indices`` selects the output columns for ``names``.
    """
    embed_dim = 512
    n_templates = len(openai_imagenet_template)

    all_features = np.memmap(
        args.out_path, dtype=np.float32, mode="w+", shape=(embed_dim, name_lookup.size)
    )

    # A --batch-size smaller than the number of templates would otherwise
    # floor-divide to a batch size of zero.
    batch_size = max(1, args.batch_size // n_templates)
    for names, indices in tqdm(lib.batched(name_lookup, batch_size)):
        txts = [template(name) for name in names for template in openai_imagenet_template]
        if not txts:
            continue
        # Derive the row count from the prompts actually built: a batch (in
        # practice the final one) may hold fewer than ``batch_size`` names, and
        # reshaping with the nominal batch size would then raise.
        n_names = len(txts) // n_templates
        txts = tokenizer(txts).to(device)
        txt_features = model.encode_text(txts)
        txt_features = torch.reshape(txt_features, (n_names, n_templates, embed_dim))
        # Normalize each template's feature, average across templates, then
        # re-normalize the mean so every stored column has unit length.
        txt_features = F.normalize(txt_features, dim=2).mean(dim=1)
        txt_features /= txt_features.norm(dim=1, keepdim=True)
        # Features are stored column-wise, hence the transpose.
        all_features[:, indices] = txt_features.cpu().numpy().T

    all_features.flush()


def get_name_lookup(catalog_path):
    """Build a ``lib.TaxonomicTree`` from the rows of a catalog CSV file.

    For every row, the seven taxonomic ranks (kingdom through species) are
    read in order. The name is cut off at the first rank whose value is empty
    or missing, and the resulting list is added to the tree.

    Args:
        catalog_path: Path to a CSV file with a header row containing the
            columns ``kingdom``, ``phylum``, ``class``, ``order``, ``family``,
            ``genus`` and ``species``.

    Returns:
        The populated ``lib.TaxonomicTree``.

    Raises:
        KeyError: If one of the rank columns is absent from the CSV header.
    """
    ranks = ("kingdom", "phylum", "class", "order", "family", "genus", "species")
    lookup = lib.TaxonomicTree()

    # newline="" is what the csv module asks for so that it can handle
    # newlines embedded in quoted fields itself.
    with open(catalog_path, newline="") as fd:
        reader = csv.DictReader(fd)
        for row in tqdm(reader):
            name = []
            for rank in ranks:
                value = row[rank]
                # Stop at the first falsy value. DictReader fills short rows
                # with None, which a search for "" alone would not find.
                if not value:
                    break
                name.append(value)
            lookup.add(name)

    return lookup


if __name__ == "__main__":
    # Command-line interface. ``args`` stays a module-level name because
    # write_txt_features reads it as a global.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--catalog-path",
        required=True,
        help="Path to the catalog.csv file from TreeOfLife-10M.",
    )
    cli.add_argument(
        "--out-path",
        required=True,
        help="Path to the output file.",
    )
    cli.add_argument(
        "--name-cache-path",
        default=".name_lookup_cache.json",
        help="Path to the name cache file.",
    )
    cli.add_argument(
        "--batch-size",
        type=int,
        default=2 ** 15,
        help="Batch size.",
    )
    args = cli.parse_args()

    # Build the taxonomic name tree and save it as JSON before loading the model.
    name_lookup = get_name_lookup(args.catalog_path)
    with open(args.name_cache_path, "w") as cache_file:
        json.dump(name_lookup, cache_file, cls=lib.TaxonomicJsonEncoder)

    print("Starting.")
    # ``model`` and ``tokenizer`` are likewise read as globals by
    # write_txt_features, so their names must not change.
    model = create_model(model_str, output_dict=True, require_pretrained=True).to(device)
    print("Created model.")

    model = torch.compile(model)
    print("Compiled model.")

    tokenizer = get_tokenizer(tokenizer_str)
    write_txt_features(name_lookup)