import os

import gradio as gr
import torch
import torch.nn.functional as F
from open_clip import create_model, get_tokenizer
from torchvision import transforms

from templates import openai_imagenet_template

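# Flagged examples are saved to a Hugging Face dataset; HF_TOKEN must be a
# token with write access.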
hf_token = os.getenv("HF_TOKEN")
hf_writer = gr.HuggingFaceDatasetSaver(hf_token, "bioclip-demo")

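# BioCLIP checkpoint from the Hugging Face Hub; its text tower uses the
# standard CLIP ViT-B-16 tokenizer.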
model_str = "hf-hub:imageomics/bioclip"
tokenizer_str = "ViT-B-16"

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

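# ToTensor plus OpenAI CLIP's standard normalization constants; the Gradio
# Image component below resizes uploads to 224x224 before they reach us.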
preprocess_img = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)


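# Prompt ensembling: embed each class name under every template, average the
# normalized embeddings, and re-normalize. Returns a (dim, num_classes) matrix
# so that img_features @ result yields per-class similarities.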
@torch.no_grad()
def get_txt_features(classnames, templates):
    all_features = []
    for classname in classnames:
        txts = [template(classname) for template in templates]
        txts = tokenizer(txts).to(device)
        txt_features = model.encode_text(txts)
        txt_features = F.normalize(txt_features, dim=-1).mean(dim=0)
        txt_features /= txt_features.norm()
        all_features.append(txt_features)
    all_features = torch.stack(all_features, dim=1)
    return all_features


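# Zero-shot classification over the user-supplied class names.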
@torch.no_grad()
def predict(img, classes: list[str]) -> dict[str, float]:
    classes = [cls.strip() for cls in classes if cls.strip()]
    txt_features = get_txt_features(classes, openai_imagenet_template)

    img = preprocess_img(img).to(device)
    img_features = model.encode_image(img.unsqueeze(0))
    img_features = F.normalize(img_features, dim=-1)

    # Scaled cosine similarities; squeeze(0) rather than squeeze() keeps a
    # 1-D tensor even when only a single class is given.
    logits = (model.logit_scale.exp() * img_features @ txt_features).squeeze(0)
    probs = F.softmax(logits, dim=0).to("cpu").tolist()
    return {cls: prob for cls, prob in zip(classes, probs)}


def hierarchical_predict(img) -> list[str]:
    """
    Predicts from the top of the tree of life down to the species.

    Not implemented yet: a full version would score candidate taxa rank by
    rank (kingdom -> phylum -> ... -> species) using the same image/text
    similarity as predict().
    """
    img = preprocess_img(img).to(device)
    img_features = model.encode_image(img.unsqueeze(0))
    img_features = F.normalize(img_features, dim=-1)

    raise NotImplementedError("Hierarchical prediction is not implemented yet.")


def run(img, cls_str: str) -> dict[str, float] | list[str]:
    if cls_str:
        classes = [cls.strip() for cls in cls_str.split("\n") if cls.strip()]
        return predict(img, classes)
    else:
        return hierarchical_predict(img)


if __name__ == "__main__":
    print("Starting.")
    model = create_model(model_str, output_dict=True, require_pretrained=True)
    model = model.to(device)
    print("Created model.")

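    # torch.compile (PyTorch 2.x) can speed up repeated inference; the first
    # call incurs a one-time compilation cost.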
    model = torch.compile(model)
    print("Compiled model.")

    tokenizer = get_tokenizer(tokenizer_str)

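    # Gradio 3.x API: Image(shape=...) resizes uploads, and manual flags are
    # routed to the Hugging Face dataset writer above.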
    demo = gr.Interface(
        fn=run,
        inputs=[
            gr.Image(shape=(224, 224)),
            gr.Textbox(
                placeholder="dog\ncat\n...",
                lines=3,
                label="Classes",
                show_label=True,
                info="If empty, will predict from the entire tree of life.",
            ),
        ],
        outputs=gr.Label(num_top_classes=20, label="Predictions", show_label=True),
        allow_flagging="manual",
        flagging_options=["Incorrect", "Other"],
        flagging_callback=hf_writer,
    )

    demo.launch()