# bioclip-demo / make_txt_embedding.py
# Samuel Stevens
# wip: hierarchical prediction
# 705d528
"""
Makes the entire set of text embeddings for all possible names in the tree of life.
Uses the catalog.csv file from TreeOfLife-10M.
"""
import argparse
import csv
import json
import numpy as np
import torch
import torch.nn.functional as F
from open_clip import create_model, get_tokenizer
from tqdm import tqdm
import lib
from templates import openai_imagenet_template
# open_clip model identifier for the BioCLIP checkpoint on the Hugging Face hub.
model_str = "hf-hub:imageomics/bioclip"
# Tokenizer config name handed to open_clip.get_tokenizer.
tokenizer_str = "ViT-B-16"
# Prefer the first CUDA device; fall back to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
@torch.no_grad()
def write_txt_features(name_lookup):
    """Encode every name in ``name_lookup`` and write the embeddings to disk.

    Each name is expanded into one prompt per template in
    ``openai_imagenet_template``. The per-prompt text embeddings are
    L2-normalized and averaged, and the mean is L2-normalized again. The result
    is stored as one column of a float32 ``np.memmap`` at ``args.out_path``
    with shape ``(512, name_lookup.size)``.

    Relies on the module-level ``args``, ``model`` and ``tokenizer`` that the
    ``__main__`` block creates.

    Args:
        name_lookup: the ``lib.TaxonomicTree`` built by ``get_name_lookup``.
            ``lib.batched`` yields ``(names, indices)`` pairs; ``indices`` are
            presumably the memmap column positions for ``names`` -- TODO
            confirm against ``lib``.
    """
    embed_dim = 512  # text embedding width; must match the memmap's first axis
    n_templates = len(openai_imagenet_template)
    all_features = np.memmap(
        args.out_path, dtype=np.float32, mode="w+", shape=(embed_dim, name_lookup.size)
    )
    # --batch-size counts prompts, and every name expands to n_templates prompts.
    # Clamp to 1 so a --batch-size smaller than n_templates cannot yield 0.
    batch_size = max(1, args.batch_size // n_templates)
    for names, indices in tqdm(lib.batched(name_lookup, batch_size)):
        txts = [template(name) for name in names for template in openai_imagenet_template]
        txts = tokenizer(txts).to(device)
        txt_features = model.encode_text(txts)
        # The last batch is usually shorter than batch_size. Infer the number of
        # names (-1) instead of assuming a full batch.
        txt_features = torch.reshape(txt_features, (-1, n_templates, embed_dim))
        # Average the unit-normalized template embeddings, then re-normalize the mean.
        txt_features = F.normalize(txt_features, dim=2).mean(dim=1)
        txt_features /= txt_features.norm(dim=1, keepdim=True)
        all_features[:, indices] = txt_features.cpu().numpy().T
    all_features.flush()
# Taxonomic rank columns of catalog.csv, ordered from most general to most specific.
_RANKS = ("kingdom", "phylum", "class", "order", "family", "genus", "species")


def get_name_lookup(catalog_path):
    """Build a ``lib.TaxonomicTree`` from the rank columns of ``catalog.csv``.

    Every row contributes the list of its rank values, from kingdom downward.
    The list stops just before the first empty or missing value. For example,
    a row whose genus cell is blank contributes only kingdom..family.

    Args:
        catalog_path: path to the TreeOfLife-10M ``catalog.csv`` file.

    Returns:
        The populated ``lib.TaxonomicTree``.

    Raises:
        KeyError: if a rank column read for a row is absent from the CSV header.
    """
    lookup = lib.TaxonomicTree()
    # The csv module expects files it parses to be opened with newline="".
    with open(catalog_path, newline="", encoding="utf-8") as fd:
        for row in tqdm(csv.DictReader(fd)):
            name = []
            for rank in _RANKS:
                value = row[rank]
                # DictReader gives "" for empty cells and None for cells missing
                # from a short row; either one ends the usable rank prefix.
                if not value:
                    break
                name.append(value)
            lookup.add(name)
    return lookup
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--catalog-path",
        help="Path to the catalog.csv file from TreeOfLife-10M.",
        required=True,
    )
    parser.add_argument("--out-path", help="Path to the output file.", required=True)
    parser.add_argument("--name-cache-path", help="Path to the name cache file.", default=".name_lookup_cache.json")
    parser.add_argument("--batch-size", help="Batch size.", default=2 ** 15, type=int)
    # args, model and tokenizer are module globals read by write_txt_features,
    # so they are bound here at module scope rather than inside a main().
    args = parser.parse_args()

    name_lookup = get_name_lookup(args.catalog_path)
    # Persist the name tree as JSON (via lib.TaxonomicJsonEncoder) before the
    # expensive model work starts.
    with open(args.name_cache_path, "w") as fd:
        json.dump(name_lookup, fd, cls=lib.TaxonomicJsonEncoder)

    print("Starting.")
    model = create_model(model_str, output_dict=True, require_pretrained=True)
    # open_clip documents that models start in train mode. Switch to eval so
    # dropout-style layers are inactive while computing embeddings.
    model = model.to(device).eval()
    print("Created model.")
    model = torch.compile(model)
    print("Compiled model.")
    tokenizer = get_tokenizer(tokenizer_str)
    write_txt_features(name_lookup)