import os

import gradio as gr
import torch
import torch.nn.functional as F
from open_clip import create_model, get_tokenizer
from torchvision import transforms

from templates import openai_imagenet_template

# Flagged examples are written to the "bioclip-demo" dataset on the Hub;
# this requires an HF_TOKEN with write access in the environment.
hf_token = os.getenv("HF_TOKEN")
hf_writer = gr.HuggingFaceDatasetSaver(hf_token, "bioclip-demo")

model_str = "hf-hub:imageomics/bioclip"
tokenizer_str = "ViT-B-16"

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Match the OpenAI CLIP normalization statistics; resizing to 224x224 is
# handled by the gr.Image component configured below.
preprocess_img = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)


@torch.no_grad()
def get_txt_features(classnames, templates):
    """Builds one text embedding per class by filling every prompt template
    with the class name and averaging the results (CLIP prompt ensembling)."""
    all_features = []
    for classname in classnames:
        txts = [template(classname) for template in templates]
        txts = tokenizer(txts).to(device)
        txt_features = model.encode_text(txts)
        # Average the normalized per-template embeddings, then re-normalize.
        txt_features = F.normalize(txt_features, dim=-1).mean(dim=0)
        txt_features /= txt_features.norm()
        all_features.append(txt_features)
    # Stack as columns: shape (embedding_dim, num_classes).
    all_features = torch.stack(all_features, dim=1)
    return all_features
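
# Example (hypothetical class names, for illustration only): one column of
# text features per class.
#   feats = get_txt_features(["dog", "cat"], openai_imagenet_template)
#   # feats has shape (embedding_dim, 2)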


@torch.no_grad()
def predict(img, classes: list[str]) -> dict[str, float]:
    classes = [cls.strip() for cls in classes if cls.strip()]
    txt_features = get_txt_features(classes, openai_imagenet_template)

    img = preprocess_img(img).to(device)
    img_features = model.encode_image(img.unsqueeze(0))
    img_features = F.normalize(img_features, dim=-1)

    # Cosine similarities scaled by the model's learned temperature.
    logits = (model.logit_scale.exp() * img_features @ txt_features).squeeze()
    probs = F.softmax(logits, dim=0).to("cpu").tolist()
    return {cls: prob for cls, prob in zip(classes, probs)}
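
# Example usage (hypothetical image path, for illustration only):
#   from PIL import Image
#   img = Image.open("bird.jpg").convert("RGB").resize((224, 224))
#   predict(img, ["Cardinalis cardinalis", "Cyanocitta cristata"])
#   # returns a dict mapping each class name to its softmax probability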


def hierarchical_predict(img) -> list[str]:
    """
    Predicts from the top of the tree of life down to the species.
    """
    img = preprocess_img(img).to(device)
    img_features = model.encode_image(img.unsqueeze(0))
    img_features = F.normalize(img_features, dim=-1)

    # Tree-of-life traversal is not implemented yet; fail loudly rather than
    # dropping into the debugger (`breakpoint()` must not ship in a demo).
    raise NotImplementedError("hierarchical prediction is not implemented yet")
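

# A minimal sketch of what hierarchical_predict could grow into. Everything
# below is an assumption rather than the released implementation: TAXONOMY is
# a hypothetical parent -> children lookup, and the traversal greedily picks
# the best-scoring child at each rank using the same zero-shot scoring as
# predict.
TAXONOMY: dict[str, list[str]] = {}  # hypothetical; left empty as a placeholder


@torch.no_grad()
def hierarchical_predict_sketch(img, root: str = "Eukarya") -> list[str]:
    img = preprocess_img(img).to(device)
    img_features = F.normalize(model.encode_image(img.unsqueeze(0)), dim=-1)

    path = []
    node = root
    while node in TAXONOMY:
        children = TAXONOMY[node]
        txt_features = get_txt_features(children, openai_imagenet_template)
        logits = (model.logit_scale.exp() * img_features @ txt_features).squeeze()
        # Descend to the highest-scoring child and record it in the path.
        node = children[logits.argmax().item()]
        path.append(node)
    return path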


def run(img, cls_str: str) -> dict[str, float]:
    if cls_str:
        classes = [cls.strip() for cls in cls_str.split("\n") if cls.strip()]
        return predict(img, classes)
    else:
        return hierarchical_predict(img)


if __name__ == "__main__":
    print("Starting.")
    model = create_model(model_str, output_dict=True, require_pretrained=True)
    model = model.to(device)
    print("Created model.")

    model = torch.compile(model)
    print("Compiled model.")

    tokenizer = get_tokenizer(tokenizer_str)

    demo = gr.Interface(
        fn=run,
        inputs=[
            gr.Image(shape=(224, 224)),
            gr.Textbox(
                placeholder="dog\ncat\n...",
                lines=3,
                label="Classes",
                show_label=True,
                info="If empty, will predict from the entire tree of life.",
            ),
        ],
        outputs=gr.Label(num_top_classes=20, label="Predictions", show_label=True),
        allow_flagging="manual",
        flagging_options=["Incorrect", "Other"],
        flagging_callback=hf_writer,
    )

    demo.launch()
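
# To run locally (assuming this file is named app.py and HF_TOKEN is set with
# write access for the flagging dataset):
#   python app.py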