import os

import gradio as gr
import torch
import torch.nn.functional as F
from open_clip import create_model, get_tokenizer
from torchvision import transforms

from templates import openai_imagenet_template

# Flagged examples are written to a Hugging Face dataset using this token.
hf_token = os.getenv("HF_TOKEN")
hf_writer = gr.HuggingFaceDatasetSaver(hf_token, "bioclip-demo")

model_str = "hf-hub:imageomics/bioclip"
tokenizer_str = "ViT-B-16"

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Only tensor conversion and normalization happen here; the Gradio image
# input below is configured with shape=(224, 224).
preprocess_img = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)

# NOTE: `model` and `tokenizer` are module-level globals that are assigned in
# the `__main__` block at the bottom of this file; the functions below rely on
# them being set before they are called.


@torch.no_grad()
def get_txt_features(classnames, templates):
    """Build one text embedding per class name.

    Each class name is rendered through every callable in ``templates``, the
    resulting prompts are tokenized and encoded, and the normalized prompt
    embeddings are averaged and re-normalized into a single vector.

    Args:
        classnames: Iterable of class-name strings.
        templates: Iterable of callables mapping a class name to a prompt.

    Returns:
        A tensor with the per-class vectors stacked along ``dim=1`` (one
        column per class), ready to be right-multiplied by image features.

    Raises:
        RuntimeError: If ``classnames`` is empty (``torch.stack`` needs at
            least one tensor).
    """
    all_features = []
    for classname in classnames:
        prompts = [template(classname) for template in templates]
        tokens = tokenizer(prompts).to(device)
        txt_features = model.encode_text(tokens)
        # Average the unit-length prompt embeddings, then rescale the mean
        # back to unit length.
        txt_features = F.normalize(txt_features, dim=-1).mean(dim=0)
        txt_features /= txt_features.norm()
        all_features.append(txt_features)
    return torch.stack(all_features, dim=1)


@torch.no_grad()
def predict(img, classes: list[str]) -> dict[str, float]:
    """Classify ``img`` among the given class names.

    Args:
        img: Image accepted by ``preprocess_img`` (as delivered by the Gradio
            image input).
        classes: Candidate class names; blank entries are ignored.

    Returns:
        Mapping from class name to softmax probability. Empty when no
        non-blank class name is supplied.
    """
    classes = [name.strip() for name in classes if name.strip()]
    if not classes:
        # Nothing to score against; avoid torch.stack([]) in get_txt_features.
        return {}

    txt_features = get_txt_features(classes, openai_imagenet_template)

    img = preprocess_img(img).to(device)
    img_features = model.encode_image(img.unsqueeze(0))
    img_features = F.normalize(img_features, dim=-1)

    # Drop only the batch dimension. A bare squeeze() would also collapse the
    # class dimension when a single class is given, which turns the result of
    # tolist() into a float and breaks the zip below.
    logits = ((model.logit_scale.exp() * img_features) @ txt_features).squeeze(0)
    probs = F.softmax(logits, dim=0).to("cpu").tolist()
    return dict(zip(classes, probs))


def hierarchical_predict(img) -> list[str]:
    """Predict from the top of the tree of life down to the species.

    The taxonomic traversal is not implemented yet, so this currently always
    raises instead of returning a prediction.

    Args:
        img: Image as delivered by the Gradio image input (currently unused).

    Raises:
        gr.Error: Always, so the UI shows an explanatory message.
    """
    # TODO: embed the image (as in `predict`) and walk the taxonomy from the
    # top rank down to species.
    raise gr.Error(
        "Prediction over the entire tree of life is not implemented yet. "
        "Please enter one class name per line."
    )


def run(img, cls_str: str) -> dict[str, float]:
    """Gradio entry point: dispatch to the appropriate prediction mode.

    Args:
        img: Image from the Gradio image input.
        cls_str: Newline-separated class names; may be empty or ``None``.

    Returns:
        The result of ``predict`` when at least one non-blank class name is
        given, otherwise the result of ``hierarchical_predict``.
    """
    # Decide on the parsed list rather than the raw string so that
    # whitespace-only input is treated the same as empty input.
    classes = [line.strip() for line in (cls_str or "").split("\n") if line.strip()]
    if classes:
        return predict(img, classes)
    return hierarchical_predict(img)


if __name__ == "__main__":
    print("Starting.")
    model = create_model(model_str, output_dict=True, require_pretrained=True)
    model = model.to(device)
    print("Created model.")

    model = torch.compile(model)
    print("Compiled model.")

    tokenizer = get_tokenizer(tokenizer_str)

    demo = gr.Interface(
        fn=run,
        inputs=[
            gr.Image(shape=(224, 224)),
            gr.Textbox(
                placeholder="dog\ncat\n...",
                lines=3,
                label="Classes",
                show_label=True,
                info="If empty, will predict from the entire tree of life.",
            ),
        ],
        outputs=gr.Label(num_top_classes=20, label="Predictions", show_label=True),
        allow_flagging="manual",
        flagging_options=["Incorrect", "Other"],
        flagging_callback=hf_writer,
    )

    demo.launch()