|
""" |
|
Makes the entire set of text embeddings for all possible names in the tree of life. |
|
Uses the catalog.csv file from TreeOfLife-10M. |
|
""" |
|
import argparse |
|
import csv |
|
import json |
|
import os |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
from open_clip import create_model, get_tokenizer |
|
from tqdm import tqdm |
|
|
|
import lib |
|
from templates import openai_imagenet_template |
|
|
|
# Model identifier handed to open_clip's create_model.
model_str = "hf-hub:imageomics/bioclip"

# Tokenizer name handed to open_clip's get_tokenizer.
tokenizer_str = "ViT-B-16"

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
|
|
|
|
|
@torch.no_grad()
def write_txt_features(name_lookup):
    """
    Compute a text embedding for every entry of `name_lookup` and save them to disk.

    Each name is rendered through every template in `openai_imagenet_template`,
    encoded with the module-level `model`, L2-normalized, averaged over the
    templates and normalized again. Each result becomes one column of a
    (512, len(name_lookup)) float32 array, which is saved with `np.save` every
    100 batches and once more after the last batch.

    If the output file already exists it is loaded first, and every batch whose
    columns contain a non-zero value is skipped, so an interrupted run can resume.

    Reads the module-level `args` (`out_path`, `batch_size`), `model`, `tokenizer`
    and `device`.

    Args:
        name_lookup: Sized collection of names. It is passed to `lib.batched`,
            whose items are unpacked as `(names, indices)` pairs.
    """
    embed_dim = 512  # number of rows in the output array, one per embedding dimension
    n_templates = len(openai_imagenet_template)

    # np.save appends ".npy" to string paths that lack the suffix. Resolve the real
    # on-disk name once so the resume check below looks at the file that is written.
    save_path = args.out_path
    if not save_path.endswith(".npy"):
        save_path = save_path + ".npy"

    if os.path.isfile(save_path):
        all_features = np.load(save_path)
    else:
        all_features = np.zeros((embed_dim, len(name_lookup)), dtype=np.float32)

    # Every name expands into one prompt per template, so divide the prompt budget
    # by the template count. Clamp to 1 so a small --batch-size cannot yield zero.
    batch_size = max(1, args.batch_size // n_templates)
    # Round up so the progress bar also counts the final partial batch.
    n_batches = (len(name_lookup) + batch_size - 1) // batch_size

    batches = tqdm(
        lib.batched(name_lookup, batch_size),
        desc="txt feats",
        total=n_batches,
    )
    for batch, (names, indices) in enumerate(batches):
        # A non-zero value in these columns means the batch was already written.
        if all_features[:, indices].any():
            print(f"Skipping batch {batch}")
            continue

        txts = [
            template(name) for name in names for template in openai_imagenet_template
        ]
        txts = tokenizer(txts).to(device)
        txt_features = model.encode_text(txts)
        # Group the prompt embeddings by name: (names, templates, embed_dim).
        txt_features = torch.reshape(
            txt_features, (len(names), n_templates, embed_dim)
        )
        # Normalize each prompt embedding, average over templates, renormalize.
        txt_features = F.normalize(txt_features, dim=2).mean(dim=1)
        txt_features /= txt_features.norm(dim=1, keepdim=True)
        all_features[:, indices] = txt_features.T.cpu().numpy()

        # Periodic checkpoint so that progress survives an interruption.
        if batch % 100 == 0:
            np.save(save_path, all_features)

    np.save(save_path, all_features)
|
|
|
|
|
def get_name_lookup(catalog_path, cache_path):
    """
    Return the taxonomic name tree, building it from the catalog if no cache exists.

    If `cache_path` is an existing file, the tree is restored from that JSON file
    with `lib.TaxonomicTree.from_dict`. Otherwise every row of the CSV at
    `catalog_path` is read, its ranks from kingdom down to species are collected,
    the list is cut at the first missing rank, and the result is added to a new
    `lib.TaxonomicTree`. The finished tree is then written to `cache_path` as JSON
    using `lib.TaxonomicJsonEncoder`.

    Args:
        catalog_path: Path to the catalog CSV. Its rows must provide the columns
            kingdom, phylum, class, order, family, genus and species.
        cache_path: Path of the JSON cache file to read from or write to.

    Returns:
        The `lib.TaxonomicTree` holding all names.
    """
    if os.path.isfile(cache_path):
        with open(cache_path) as fd:
            return lib.TaxonomicTree.from_dict(json.load(fd))

    ranks = ("kingdom", "phylum", "class", "order", "family", "genus", "species")
    lookup = lib.TaxonomicTree()

    # newline="" lets the csv module handle line endings inside quoted fields.
    with open(catalog_path, newline="") as fd:
        for row in tqdm(csv.DictReader(fd), desc="catalog"):
            name = [row[rank] for rank in ranks]
            # Keep only the ranks above the first missing one. Any falsy value
            # counts as missing, e.g. "" or the None that DictReader fills in
            # for short rows.
            for depth, value in enumerate(name):
                if not value:
                    name = name[:depth]
                    break
            lookup.add(name)

    # Write to the same path the cache is read from, not to the global CLI args.
    with open(cache_path, "w") as fd:
        json.dump(lookup, fd, cls=lib.TaxonomicJsonEncoder)

    return lookup
|
|
|
|
|
if __name__ == "__main__":
    # Command-line options.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--catalog-path",
        required=True,
        help="Path to the catalog.csv file from TreeOfLife-10M.",
    )
    parser.add_argument("--out-path", required=True, help="Path to the output file.")
    parser.add_argument(
        "--name-cache-path",
        default="name_lookup.json",
        help="Path to the name cache file.",
    )
    parser.add_argument("--batch-size", type=int, default=2**15, help="Batch size.")
    args = parser.parse_args()

    # Taxonomic tree of all names, read from the cache file when it exists.
    name_lookup = get_name_lookup(
        catalog_path=args.catalog_path, cache_path=args.name_cache_path
    )
    print("Got name lookup.")

    # Load the pretrained model and move it to the selected device.
    model = create_model(model_str, output_dict=True, require_pretrained=True).to(
        device
    )
    print("Created model.")

    model = torch.compile(model)
    print("Compiled model.")

    # write_txt_features reads `args`, `model` and `tokenizer` as module globals.
    tokenizer = get_tokenizer(tokenizer_str)
    write_txt_features(name_lookup)
|
|