# BIOtest / make_txt_embedding.py
# Author: Samuel Stevens
# Commit 290c238: "wip: hierarchical prediction"
# (Hugging Face file-page metadata: raw | history | blame; no virus detected; 2.91 kB)
"""
Computes text embeddings for every name in the tree of life, averaged over the OpenAI ImageNet prompt templates.
Uses the catalog.csv file from TreeOfLife-10M.
"""
import argparse
import csv
import json
import numpy as np
import torch
import torch.nn.functional as F
from open_clip import create_model, get_tokenizer
from tqdm import tqdm
import lib
from templates import openai_imagenet_template
# Pretrained BioCLIP checkpoint on the Hugging Face Hub, loaded with open_clip's create_model.
model_str = "hf-hub:imageomics/bioclip"
# open_clip tokenizer name to pair with the model above (passed to get_tokenizer).
tokenizer_str = "ViT-B-16"
# Prefer the first CUDA GPU when one is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
@torch.no_grad()
def write_txt_features(name_lookup):
    """Encode every name in ``name_lookup`` and write the embeddings to disk.

    Each name is expanded into one prompt per template in
    ``openai_imagenet_template``. The per-template text embeddings are
    L2-normalized, averaged, and L2-normalized again, which yields one
    unit-length 512-d vector per name.

    Results are written column-wise into a float32 ``np.memmap`` of shape
    ``(512, name_lookup.size)`` at ``args.out_path``. Each batch goes into the
    column ``indices`` that ``lib.batched`` yields for it.

    Relies on the module-level ``args``, ``model``, ``tokenizer`` and
    ``device`` set up elsewhere in this script.

    Args:
        name_lookup: the ``lib.TaxonomicTree`` built by ``get_name_lookup``.
            Must expose ``.size`` and be accepted by ``lib.batched``, which is
            expected to yield ``(names, indices)`` pairs.
            NOTE(review): ``lib.batched`` semantics are not visible here.
    """
    embed_dim = 512
    n_templates = len(openai_imagenet_template)
    # --batch-size counts prompts, not names: every name costs n_templates
    # prompts. Clamp to at least one name so a small --batch-size cannot
    # produce a zero batch size.
    names_per_batch = max(1, args.batch_size // n_templates)

    all_features = np.memmap(
        args.out_path, dtype=np.float32, mode="w+", shape=(embed_dim, name_lookup.size)
    )
    for names, indices in tqdm(lib.batched(name_lookup, names_per_batch)):
        txts = [template(name) for name in names for template in openai_imagenet_template]
        # Derive the real batch size from the prompts. The final batch is
        # usually smaller than names_per_batch, and reshaping with the
        # nominal size would raise.
        n_names = len(txts) // n_templates
        tokens = tokenizer(txts).to(device)
        txt_features = model.encode_text(tokens)
        txt_features = torch.reshape(txt_features, (n_names, n_templates, embed_dim))
        # Average the unit-normalized per-template embeddings, then renormalize.
        # F.normalize clamps the norm, so an all-zero vector stays zero
        # instead of becoming NaN.
        txt_features = F.normalize(txt_features, dim=2).mean(dim=1)
        txt_features = F.normalize(txt_features, dim=1)
        all_features[:, indices] = txt_features.cpu().numpy().T
    all_features.flush()
def get_name_lookup(catalog_path):
    """Build a ``lib.TaxonomicTree`` from the names in a catalog CSV.

    The CSV at ``catalog_path`` is the TreeOfLife-10M ``catalog.csv``. Each
    row contributes one name: its ``kingdom``, ``phylum``, ``class``,
    ``order``, ``family``, ``genus`` and ``species`` values, read in that
    order. The name is truncated at the first empty or missing rank, so only
    the populated leading part of the lineage is added.

    Args:
        catalog_path: path to ``catalog.csv``.

    Returns:
        The populated ``lib.TaxonomicTree``.

    Raises:
        KeyError: if one of the rank columns is absent from the CSV header.
    """
    ranks = ("kingdom", "phylum", "class", "order", "family", "genus", "species")
    lookup = lib.TaxonomicTree()
    # The csv module requires files opened with newline="".
    with open(catalog_path, newline="", encoding="utf-8") as fd:
        for row in tqdm(csv.DictReader(fd)):
            name = []
            for rank in ranks:
                value = row[rank]
                # Stop at the first falsy rank. DictReader yields None (not "")
                # for fields missing from a short row. The previous
                # name.index("") lookup raised ValueError in that case.
                if not value:
                    break
                name.append(value)
            # NOTE(review): a row with no kingdom adds an empty name, as
            # before. Confirm that TaxonomicTree.add([]) is intended.
            lookup.add(name)
    return lookup
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--catalog-path",
        help="Path to the catalog.csv file from TreeOfLife-10M.",
        required=True,
    )
    parser.add_argument("--out-path", help="Path to the output file.", required=True)
    parser.add_argument("--name-cache-path", help="Path to the name cache file.", default=".name_lookup_cache.json")
    parser.add_argument("--batch-size", help="Batch size.", default=2 ** 15, type=int)
    args = parser.parse_args()

    # write_txt_features divides --batch-size by the number of prompt templates
    # to get names per batch. Reject values that would round that down to zero.
    if args.batch_size < len(openai_imagenet_template):
        parser.error(
            f"--batch-size must be at least {len(openai_imagenet_template)} "
            "(one prompt per template for a single name)"
        )

    name_lookup = get_name_lookup(args.catalog_path)
    with open(args.name_cache_path, "w", encoding="utf-8") as fd:
        json.dump(name_lookup, fd, cls=lib.TaxonomicJsonEncoder)

    print("Starting.")
    model = create_model(model_str, output_dict=True, require_pretrained=True)
    # Inference only: switch to eval mode so any dropout or other
    # train-time-only layers are disabled.
    model = model.to(device).eval()
    print("Created model.")
    # NOTE(review): torch.compile wraps forward(). write_txt_features calls
    # model.encode_text directly, which likely reaches the uncompiled
    # submodule method. Verify whether compilation is actually in effect.
    model = torch.compile(model)
    print("Compiled model.")
    tokenizer = get_tokenizer(tokenizer_str)

    # write_txt_features reads the module-level args, model and tokenizer set above.
    write_txt_features(name_lookup)