""" Makes the entire set of text emebeddings for all possible names in the tree of life. Uses the catalog.csv file from TreeOfLife-10M. """ import argparse import csv import json import numpy as np import torch import torch.nn.functional as F from open_clip import create_model, get_tokenizer from tqdm import tqdm import lib from templates import openai_imagenet_template model_str = "hf-hub:imageomics/bioclip" tokenizer_str = "ViT-B-16" device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @torch.no_grad() def write_txt_features(name_lookup): all_features = np.memmap( args.out_path, dtype=np.float32, mode="w+", shape=(512, name_lookup.size) ) batch_size = args.batch_size // len(openai_imagenet_template) for names, indices in tqdm(lib.batched(name_lookup, batch_size)): txts = [template(name) for name in names for template in openai_imagenet_template] txts = tokenizer(txts).to(device) txt_features = model.encode_text(txts) txt_features = torch.reshape(txt_features, (batch_size, len(openai_imagenet_template), 512)) txt_features = F.normalize(txt_features, dim=2).mean(dim=1) txt_features /= txt_features.norm(dim=1, keepdim=True) all_features[:, indices] = txt_features.cpu().numpy().T all_features.flush() def get_name_lookup(catalog_path): lookup = lib.TaxonomicTree() with open(catalog_path) as fd: reader = csv.DictReader(fd) for row in tqdm(reader): name = [ row["kingdom"], row["phylum"], row["class"], row["order"], row["family"], row["genus"], row["species"], ] if any(not value for value in name): name = name[: name.index("")] lookup.add(name) return lookup if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--catalog-path", help="Path to the catalog.csv file from TreeOfLife-10M.", required=True, ) parser.add_argument("--out-path", help="Path to the output file.", required=True) parser.add_argument("--name-cache-path", help="Path to the name cache file.", default=".name_lookup_cache.json") parser.add_argument("--batch-size", help="Batch size.", default=2 ** 15, type=int) args = parser.parse_args() name_lookup = get_name_lookup(args.catalog_path) with open(args.name_cache_path, "w") as fd: json.dump(name_lookup, fd, cls=lib.TaxonomicJsonEncoder) print("Starting.") model = create_model(model_str, output_dict=True, require_pretrained=True) model = model.to(device) print("Created model.") model = torch.compile(model) print("Compiled model.") tokenizer = get_tokenizer(tokenizer_str) write_txt_features(name_lookup)