|
import os |
|
|
|
import gradio as gr |
|
import torch |
|
import torch.nn.functional as F |
|
from open_clip import create_model, get_tokenizer |
|
from torchvision import transforms |
|
|
|
from templates import openai_imagenet_template |
|
|
|
# Credentials and sink for examples that users flag in the UI; the saver is
# handed to the interface as its flagging callback in the __main__ block.
hf_token = os.environ.get("HF_TOKEN")
hf_writer = gr.HuggingFaceDatasetSaver(hf_token, "bioclip-demo")

# Identifiers passed to open_clip's create_model / get_tokenizer at startup.
model_str = "hf-hub:imageomics/bioclip"
tokenizer_str = "ViT-B-16"

# Use the first CUDA device when one is available, otherwise run on the CPU.
device = (
    torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
)
|
|
|
# Image -> normalised tensor pipeline applied before the image encoder.
# NOTE(review): the per-channel mean/std look like the standard OpenAI CLIP
# statistics -- confirm they match what the checkpoint was trained with.
# No resize happens here; the input is used at whatever size it arrives.
preprocess_img = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.48145466, 0.4578275, 0.40821073),
        std=(0.26862954, 0.26130258, 0.27577711),
    ),
])
|
|
|
|
|
@torch.no_grad()
def get_txt_features(classnames, templates):
    """Build one unit-length text embedding per class name.

    Every callable in ``templates`` is applied to a class name to produce a
    prompt. The prompts are tokenized and encoded with the module-level
    ``tokenizer`` and ``model``, each prompt embedding is L2-normalised, the
    embeddings are averaged over the prompts, and the average is normalised
    again.

    Args:
        classnames: Iterable of class names.
        templates: Iterable of callables mapping a class name to a prompt.

    Returns:
        The per-class embeddings stacked along ``dim=1`` (one column per
        class, assuming ``encode_text`` returns one row per prompt).
    """
    per_class = []
    for name in classnames:
        prompts = [make_prompt(name) for make_prompt in templates]
        tokens = tokenizer(prompts).to(device)
        embedded = F.normalize(model.encode_text(tokens), dim=-1)
        # Average over the prompts, then bring the mean back to unit length.
        centroid = embedded.mean(dim=0)
        per_class.append(centroid / centroid.norm())
    return torch.stack(per_class, dim=1)
|
|
|
|
|
@torch.no_grad()
def predict(img, classes: list[str]) -> dict[str, float]:
    """Score an image against a user-supplied list of class names.

    Args:
        img: Image in whatever form ``preprocess_img`` accepts.
        classes: Candidate class names. Each entry is stripped of surrounding
            whitespace and blank entries are dropped.

    Returns:
        Mapping from each remaining class name to its softmax probability.
        Duplicate names collapse into a single key.
    """
    classes = [name.strip() for name in classes if name.strip()]
    txt_features = get_txt_features(classes, openai_imagenet_template)

    img = preprocess_img(img).to(device)
    img_features = model.encode_image(img.unsqueeze(0))
    img_features = F.normalize(img_features, dim=-1)

    # Remove only the batch axis added by ``unsqueeze(0)`` above. A bare
    # ``squeeze()`` would also remove the class axis when exactly one class is
    # given, so ``tolist()`` would return a float instead of a list and the
    # ``zip`` below would raise TypeError.
    logits = (model.logit_scale.exp() * img_features @ txt_features).squeeze(0)
    probs = F.softmax(logits, dim=0).to("cpu").tolist()
    return dict(zip(classes, probs))
|
|
|
|
|
def hierarchical_predict(img) -> list[str]:
    """
    Predicts from the top of the tree of life down to the species.

    This mode is not implemented yet. The previous body ended in a leftover
    ``breakpoint()``, which drops a served request into an interactive
    debugger and then returns ``None``. Failing explicitly is safer.

    Args:
        img: Input image (currently unused).

    Raises:
        NotImplementedError: Always, until the taxonomy walk is written.
    """
    # TODO: encode ``img`` the same way ``predict`` does (preprocess, encode,
    # L2-normalise, all under ``torch.no_grad()``), then descend the taxonomy
    # rank by rank and return the chosen names.
    raise NotImplementedError(
        "Hierarchical (tree-of-life) prediction is not implemented yet; "
        "please provide a list of classes."
    )
|
|
|
|
|
def run(img, cls_str: str) -> dict[str, float]:
    """Gradio callback: classify ``img`` against the classes typed by the user.

    Args:
        img: Image supplied by the image input component.
        cls_str: Newline-separated class names; may be empty or ``None``.

    Returns:
        The result of ``predict`` when at least one non-blank class name is
        given, otherwise whatever ``hierarchical_predict`` returns.
    """
    # Parse before branching: whitespace-only text is truthy but yields no
    # classes, and ``predict`` cannot handle an empty class list.
    classes = [
        line.strip() for line in (cls_str or "").split("\n") if line.strip()
    ]
    if classes:
        return predict(img, classes)
    return hierarchical_predict(img)
|
|
|
|
|
if __name__ == "__main__":
    # ``model`` and ``tokenizer`` are assigned at module scope here and are
    # read as globals by ``get_txt_features`` and ``predict``.
    print("Starting.")
    model = create_model(model_str, output_dict=True, require_pretrained=True)
    model = model.to(device)
    # This script only serves predictions, so switch to inference mode:
    # freshly created modules start in training mode.
    model.eval()
    print("Created model.")

    model = torch.compile(model)
    print("Compiled model.")

    tokenizer = get_tokenizer(tokenizer_str)

    demo = gr.Interface(
        fn=run,
        inputs=[
            gr.Image(shape=(224, 224)),
            gr.Textbox(
                placeholder="dog\ncat\n...",
                lines=3,
                label="Classes",
                show_label=True,
                info="If empty, will predict from the entire tree of life.",
            ),
        ],
        outputs=gr.Label(num_top_classes=20, label="Predictions", show_label=True),
        # Manual flagging: flagged examples are handed to ``hf_writer``.
        allow_flagging="manual",
        flagging_options=["Incorrect", "Other"],
        flagging_callback=hf_writer,
    )

    demo.launch()
|
|