File size: 3,637 Bytes
290c238
 
 
 
 
 
 
2cfb891
290c238
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2cfb891
 
 
 
290c238
 
2cfb891
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290c238
 
2cfb891
 
 
290c238
 
2cfb891
 
 
 
290c238
2cfb891
290c238
 
2cfb891
 
 
 
 
 
290c238
 
 
 
2cfb891
290c238
 
 
 
 
 
 
 
 
 
 
 
 
2cfb891
 
 
290c238
 
 
 
 
 
 
 
 
 
 
2cfb891
 
 
 
 
 
290c238
 
2cfb891
 
290c238
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""
Computes the full set of text embeddings for every possible name in the tree of life.
Uses the catalog.csv file from TreeOfLife-10M.
"""
import argparse
import csv
import json
import os

import numpy as np
import torch
import torch.nn.functional as F
from open_clip import create_model, get_tokenizer
from tqdm import tqdm

import lib
from templates import openai_imagenet_template

# open_clip model identifier: the pretrained BioCLIP checkpoint on the Hugging Face hub.
model_str = "hf-hub:imageomics/bioclip"
# Tokenizer config name handed to open_clip.get_tokenizer.
tokenizer_str = "ViT-B-16"

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


def _save_features(path, features):
    """Atomically write ``features`` to ``path`` in ``.npy`` format.

    The array is written to a temporary file which then replaces ``path`` via
    ``os.replace``, so an interrupted checkpoint leaves the previous file intact
    rather than a truncated one that ``np.load`` cannot read on resume.
    """
    tmp_path = path + ".tmp"
    # Passing a file object (not a path string) stops np.save from appending ".npy".
    with open(tmp_path, "wb") as fd:
        np.save(fd, features)
    os.replace(tmp_path, path)


@torch.no_grad()
def write_txt_features(name_lookup):
    """Compute one text embedding per name in ``name_lookup`` and save them.

    Each name is expanded with every template in ``openai_imagenet_template``.
    The per-template embeddings are L2-normalized, averaged, and re-normalized.
    The result is a ``(512, len(name_lookup))`` float32 matrix whose column for a
    name is taken from the indices yielded by ``lib.batched``.

    The run is resumable. If the output file already exists it is loaded, and any
    batch whose columns already hold non-zero values is skipped. Progress is
    checkpointed every 100 batches and once more at the end.

    Relies on the module-level ``args``, ``model`` and ``tokenizer`` set up in
    ``__main__``.

    Raises:
        ValueError: if an existing output file's shape does not match
            ``(512, len(name_lookup))`` (e.g. the name cache was rebuilt).
    """
    embed_dim = 512  # width of the text embeddings reshaped below

    # np.save appends ".npy" to paths that lack it, so the file on disk may not be
    # args.out_path itself. Resolve the real name once so the resume check and
    # every save agree on where the features live.
    out_path = args.out_path
    if not out_path.endswith(".npy"):
        out_path += ".npy"

    expected_shape = (embed_dim, len(name_lookup))
    if os.path.isfile(out_path):
        all_features = np.load(out_path)
        if all_features.shape != expected_shape:
            raise ValueError(
                f"Existing features in {out_path} have shape {all_features.shape}, "
                f"expected {expected_shape}; was the name cache rebuilt?"
            )
    else:
        all_features = np.zeros(expected_shape, dtype=np.float32)

    # Each name expands into one prompt per template, so split the prompt budget
    # across templates. Keep at least one name per batch so batching cannot stall.
    batch_size = max(1, args.batch_size // len(openai_imagenet_template))
    # Ceil division so the progress bar counts the final partial batch.
    num_batches = (len(name_lookup) + batch_size - 1) // batch_size

    for batch, (names, indices) in enumerate(
        tqdm(
            lib.batched(name_lookup, batch_size),
            desc="txt feats",
            total=num_batches,
        )
    ):
        # Any non-zero value means this batch was computed in an earlier run.
        if all_features[:, indices].any():
            print(f"Skipping batch {batch}")
            continue

        # Name-major ordering: all templates for names[0], then names[1], ...
        txts = [
            template(name) for name in names for template in openai_imagenet_template
        ]
        txts = tokenizer(txts).to(device)
        txt_features = model.encode_text(txts)
        txt_features = torch.reshape(
            txt_features, (len(names), len(openai_imagenet_template), embed_dim)
        )
        # Normalize each template embedding, average over templates, re-normalize.
        txt_features = F.normalize(txt_features, dim=2).mean(dim=1)
        txt_features /= txt_features.norm(dim=1, keepdim=True)
        all_features[:, indices] = txt_features.T.cpu().numpy()

        if batch % 100 == 0:
            _save_features(out_path, all_features)

    _save_features(out_path, all_features)


def _taxonomic_name(row):
    """Return a catalog row's rank values, truncated at the first missing rank.

    Ranks are read from kingdom down to species. A rank counts as missing when its
    value is empty or ``None``; ``csv.DictReader`` fills absent trailing columns
    with ``None``. That rank and everything after it are dropped.
    """
    ranks = ("kingdom", "phylum", "class", "order", "family", "genus", "species")
    name = []
    for rank in ranks:
        # Index (not .get) so a catalog missing a rank column fails loudly.
        value = row[rank]
        if not value:
            break
        name.append(value)
    return name


def get_name_lookup(catalog_path, cache_path):
    """Build, or load from cache, the tree of taxonomic names in the catalog.

    Args:
        catalog_path: Path to TreeOfLife-10M's ``catalog.csv``.
        cache_path: JSON cache file. If it exists, the tree is loaded from it and
            the catalog is not read. Otherwise the tree is built from the catalog
            and written to this path.

    Returns:
        A ``lib.TaxonomicTree`` to which every catalog row's (truncated) rank
        list has been added.
    """
    if os.path.isfile(cache_path):
        with open(cache_path) as fd:
            return lib.TaxonomicTree.from_dict(json.load(fd))

    lookup = lib.TaxonomicTree()

    # newline="" is what the csv module expects for files it parses.
    with open(catalog_path, newline="") as fd:
        for row in tqdm(csv.DictReader(fd), desc="catalog"):
            lookup.add(_taxonomic_name(row))

    # Write to the cache_path argument (not a global), so the cache is read back
    # from the same place it was written and the function works when imported.
    with open(cache_path, "w") as fd:
        json.dump(lookup, fd, cls=lib.TaxonomicJsonEncoder)

    return lookup


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--catalog-path",
        help="Path to the catalog.csv file from TreeOfLife-10M.",
        required=True,
    )
    parser.add_argument("--out-path", help="Path to the output file.", required=True)
    parser.add_argument(
        "--name-cache-path",
        help="Path to the name cache file.",
        default="name_lookup.json",
    )
    # Prompt budget per forward pass; write_txt_features divides it by the
    # number of templates to get names per batch.
    parser.add_argument("--batch-size", help="Batch size.", default=2**15, type=int)
    # args, model and tokenizer are module globals read by write_txt_features.
    args = parser.parse_args()

    name_lookup = get_name_lookup(args.catalog_path, cache_path=args.name_cache_path)
    print("Got name lookup.")

    model = create_model(model_str, output_dict=True, require_pretrained=True)
    # open_clip hands the model back in train mode; switch to eval so any
    # dropout / stochastic-depth layers are disabled for inference.
    model = model.to(device).eval()
    print("Created model.")

    model = torch.compile(model)
    print("Compiled model.")

    tokenizer = get_tokenizer(tokenizer_str)
    write_txt_features(name_lookup)