Spaces:
Configuration error
Configuration error
File size: 3,173 Bytes
705d528 4cc1f09 20747d8 4cc1f09 6a994e0 4cc1f09 705d528 2e74323 705d528 4cc1f09 705d528 4cc1f09 705d528 4cc1f09 705d528 4cc1f09 705d528 20747d8 705d528 4cc1f09 705d528 4cc1f09 2e74323 705d528 4cc1f09 2e74323 4cc1f09 705d528 4cc1f09 705d528 4cc1f09 705d528 4cc1f09 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
import os
import gradio as gr
import torch
import torch.nn.functional as F
from open_clip import create_model, get_tokenizer
from torchvision import transforms
from templates import openai_imagenet_template
# Hugging Face access token taken from the environment (None when HF_TOKEN is unset).
hf_token = os.getenv("HF_TOKEN")
# Flagging callback passed to the Gradio interface; flagged samples go to the
# "bioclip-demo" dataset.
hf_writer = gr.HuggingFaceDatasetSaver(hf_token, "bioclip-demo")

# Identifiers handed to open_clip's create_model / get_tokenizer at start-up.
model_str = "hf-hub:imageomics/bioclip"
tokenizer_str = "ViT-B-16"

# Use the first CUDA device when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Image -> tensor conversion followed by per-channel normalisation.
# NOTE(review): no resize/crop happens here; the input is presumably already
# sized by the UI component (gr.Image(shape=(224, 224))) -- confirm.
preprocess_img = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.48145466, 0.4578275, 0.40821073),
        std=(0.26862954, 0.26130258, 0.27577711),
    ),
])
@torch.no_grad()
def get_txt_features(classnames, templates):
    """Build one text embedding per class by averaging over prompt templates.

    Args:
        classnames: Iterable of class-name strings.
        templates: Iterable of callables; each is called with a class name and
            its result is passed to the module-level ``tokenizer``.

    Returns:
        A tensor made by stacking the per-class embeddings along ``dim=1``
        (presumably ``(embed_dim, num_classes)`` -- confirm against the
        shape returned by ``model.encode_text``).
    """
    per_class = []
    for name in classnames:
        prompts = [render(name) for render in templates]
        tokens = tokenizer(prompts).to(device)
        # Normalise each prompt embedding, then average across the prompts.
        embeddings = F.normalize(model.encode_text(tokens), dim=-1)
        centroid = embeddings.mean(dim=0)
        # The mean of unit vectors is not unit length, so re-normalise it.
        per_class.append(centroid / centroid.norm())
    return torch.stack(per_class, dim=1)
@torch.no_grad()
def predict(img, classes: list[str]) -> dict[str, float]:
    """Score an image against a user-supplied list of class names.

    Args:
        img: Image accepted by ``preprocess_img`` (whatever the Gradio image
            input delivers).
        classes: Candidate class names; entries are stripped and blank ones
            are dropped.

    Returns:
        Mapping from each remaining class name to its softmax probability.
        Empty when no non-blank class name was supplied.
    """
    classes = [cls.strip() for cls in classes if cls.strip()]
    if not classes:
        # Nothing to score; get_txt_features would fail stacking an empty list.
        return {}

    txt_features = get_txt_features(classes, openai_imagenet_template)

    img = preprocess_img(img).to(device)
    img_features = model.encode_image(img.unsqueeze(0))
    img_features = F.normalize(img_features, dim=-1)

    # Remove only the batch axis. A bare squeeze() would also remove the class
    # axis when a single class is given, turning tolist() into a plain float
    # and breaking the zip below.
    logits = (model.logit_scale.exp() * img_features @ txt_features).squeeze(0)
    probs = F.softmax(logits, dim=0).to("cpu").tolist()
    return dict(zip(classes, probs))
@torch.no_grad()
def hierarchical_predict(img) -> list[str]:
    """
    Predicts from the top of the tree of life down to the species.

    Only the image-embedding step exists so far; the traversal itself has not
    been written. The function therefore fails loudly instead of dropping into
    a debugger, which would stall or crash the serving process.

    Args:
        img: Image accepted by ``preprocess_img``.

    Raises:
        NotImplementedError: Always, until the traversal is implemented.
    """
    img = preprocess_img(img).to(device)
    img_features = model.encode_image(img.unsqueeze(0))
    img_features = F.normalize(img_features, dim=-1)
    # TODO: walk the taxonomy using img_features and return the predicted ranks.
    raise NotImplementedError(
        "Hierarchical (tree-of-life) prediction is not implemented yet; "
        "provide at least one class name."
    )
def run(img, cls_str: str) -> dict[str, float]:
    """Gradio entry point: classify ``img`` against the typed class list.

    Args:
        img: Image delivered by the Gradio image input.
        cls_str: Newline-separated class names from the textbox; may be empty.

    Returns:
        The result of ``predict`` when at least one non-blank class name was
        entered, otherwise the result of ``hierarchical_predict``.
    """
    # Parse first so whitespace-only input is treated the same as no input
    # instead of reaching predict() with an empty class list.
    classes = [cls.strip() for cls in (cls_str or "").split("\n") if cls.strip()]
    if classes:
        return predict(img, classes)
    return hierarchical_predict(img)
if __name__ == "__main__":
    # Script entry point: load the model and tokenizer, then serve the demo.
    # `model` and `tokenizer` are assigned at module level on purpose; the
    # prediction functions above read them as globals.
    print("Starting.")
    model = create_model(model_str, output_dict=True, require_pretrained=True)
    model = model.to(device)
    # The demo only runs inference, so switch off training-mode behaviour
    # before serving (and before compiling).
    model.eval()
    print("Created model.")

    model = torch.compile(model)
    print("Compiled model.")

    tokenizer = get_tokenizer(tokenizer_str)

    demo = gr.Interface(
        fn=run,
        inputs=[
            gr.Image(shape=(224, 224)),
            gr.Textbox(
                placeholder="dog\ncat\n...",
                lines=3,
                label="Classes",
                show_label=True,
                info="If empty, will predict from the entire tree of life.",
            ),
        ],
        outputs=gr.Label(num_top_classes=20, label="Predictions", show_label=True),
        # Users can flag a result manually; flags are recorded through hf_writer.
        allow_flagging="manual",
        flagging_options=["Incorrect", "Other"],
        flagging_callback=hf_writer,
    )

    demo.launch()
|