Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	
		Samuel Stevens
		
	committed on
		
		
					Commit 
							
							·
						
						2cfb891
	
1
								Parent(s):
							
							290c238
								
v0.1
Browse files
 - README.md +2 -2
 - app.py +169 -31
 - lib.py +11 -7
 - make_txt_embedding.py +46 -16
 - txt_emb.npy +3 -0
 
    	
        README.md
    CHANGED
    
    | 
         @@ -1,11 +1,11 @@ 
     | 
|
| 1 | 
         
             
            ---
         
     | 
| 2 | 
         
             
            title: Bioclip Demo
         
     | 
| 3 | 
         
            -
            emoji:  
     | 
| 4 | 
         
             
            colorFrom: indigo
         
     | 
| 5 | 
         
             
            colorTo: purple
         
     | 
| 6 | 
         
             
            sdk: gradio
         
     | 
| 7 | 
         
             
            sdk_version: 4.7.1
         
     | 
| 8 | 
         
             
            app_file: app.py
         
     | 
| 9 | 
         
            -
            pinned:  
     | 
| 10 | 
         
             
            license: mit
         
     | 
| 11 | 
         
             
            ---
         
     | 
| 
         | 
|
| 1 | 
         
             
            ---
         
     | 
| 2 | 
         
             
            title: Bioclip Demo
         
     | 
| 3 | 
         
            +
            emoji: 🐘
         
     | 
| 4 | 
         
             
            colorFrom: indigo
         
     | 
| 5 | 
         
             
            colorTo: purple
         
     | 
| 6 | 
         
             
            sdk: gradio
         
     | 
| 7 | 
         
             
            sdk_version: 4.7.1
         
     | 
| 8 | 
         
             
            app_file: app.py
         
     | 
| 9 | 
         
            +
            pinned: true
         
     | 
| 10 | 
         
             
            license: mit
         
     | 
| 11 | 
         
             
            ---
         
     | 
    	
        app.py
    CHANGED
    
    | 
         @@ -1,24 +1,29 @@ 
     | 
|
| 
         | 
|
| 1 | 
         
             
            import os
         
     | 
| 2 | 
         | 
| 3 | 
         
             
            import gradio as gr
         
     | 
| 
         | 
|
| 4 | 
         
             
            import torch
         
     | 
| 5 | 
         
             
            import torch.nn.functional as F
         
     | 
| 6 | 
         
             
            from open_clip import create_model, get_tokenizer
         
     | 
| 7 | 
         
             
            from torchvision import transforms
         
     | 
| 8 | 
         | 
| 
         | 
|
| 9 | 
         
             
            from templates import openai_imagenet_template
         
     | 
| 10 | 
         | 
| 11 | 
         
             
            hf_token = os.getenv("HF_TOKEN")
         
     | 
| 12 | 
         
            -
            hf_writer = gr.HuggingFaceDatasetSaver(hf_token, "bioclip-demo")
         
     | 
| 13 | 
         | 
| 14 | 
         
             
            model_str = "hf-hub:imageomics/bioclip"
         
     | 
| 15 | 
         
             
            tokenizer_str = "ViT-B-16"
         
     | 
| 
         | 
|
| 
         | 
|
| 16 | 
         | 
| 17 | 
         
             
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         
     | 
| 18 | 
         | 
| 19 | 
         
             
            preprocess_img = transforms.Compose(
         
     | 
| 20 | 
         
             
                [
         
     | 
| 21 | 
         
             
                    transforms.ToTensor(),
         
     | 
| 
         | 
|
| 22 | 
         
             
                    transforms.Normalize(
         
     | 
| 23 | 
         
             
                        mean=(0.48145466, 0.4578275, 0.40821073),
         
     | 
| 24 | 
         
             
                        std=(0.26862954, 0.26130258, 0.27577711),
         
     | 
| 
         @@ -26,6 +31,28 @@ preprocess_img = transforms.Compose( 
     | 
|
| 26 | 
         
             
                ]
         
     | 
| 27 | 
         
             
            )
         
     | 
| 28 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 29 | 
         | 
| 30 | 
         
             
            @torch.no_grad()
         
     | 
| 31 | 
         
             
            def get_txt_features(classnames, templates):
         
     | 
| 
         @@ -42,8 +69,8 @@ def get_txt_features(classnames, templates): 
     | 
|
| 42 | 
         | 
| 43 | 
         | 
| 44 | 
         
             
            @torch.no_grad()
         
     | 
| 45 | 
         
            -
            def  
     | 
| 46 | 
         
            -
                classes = [cls.strip() for cls in  
     | 
| 47 | 
         
             
                txt_features = get_txt_features(classes, openai_imagenet_template)
         
     | 
| 48 | 
         | 
| 49 | 
         
             
                img = preprocess_img(img).to(device)
         
     | 
| 
         @@ -55,7 +82,8 @@ def predict(img, classes: list[str]) -> dict[str, float]: 
     | 
|
| 55 | 
         
             
                return {cls: prob for cls, prob in zip(classes, probs)}
         
     | 
| 56 | 
         | 
| 57 | 
         | 
| 58 | 
         
            -
             
     | 
| 
         | 
|
| 59 | 
         
             
                """
         
     | 
| 60 | 
         
             
                Predicts from the top of the tree of life down to the species.
         
     | 
| 61 | 
         
             
                """
         
     | 
| 
         @@ -63,16 +91,44 @@ def hierarchical_predict(img) -> list[str]: 
     | 
|
| 63 | 
         
             
                img_features = model.encode_image(img.unsqueeze(0))
         
     | 
| 64 | 
         
             
                img_features = F.normalize(img_features, dim=-1)
         
     | 
| 65 | 
         | 
| 66 | 
         
            -
                 
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 67 | 
         | 
| 68 | 
         | 
| 69 | 
         
            -
            def  
     | 
| 70 | 
         
            -
                 
     | 
| 71 | 
         
            -
             
     | 
| 72 | 
         
            -
                    classes = [cls.strip() for cls in cls_str.split("\n") if cls.strip()]
         
     | 
| 73 | 
         
            -
                    return predict(img, classes)
         
     | 
| 74 | 
         
            -
                else:
         
     | 
| 75 | 
         
            -
                    return hierarchical_predict(img)
         
     | 
| 76 | 
         | 
| 77 | 
         | 
| 78 | 
         
             
            if __name__ == "__main__":
         
     | 
| 
         @@ -86,22 +142,104 @@ if __name__ == "__main__": 
     | 
|
| 86 | 
         | 
| 87 | 
         
             
                tokenizer = get_tokenizer(tokenizer_str)
         
     | 
| 88 | 
         | 
| 89 | 
         
            -
                 
     | 
| 90 | 
         
            -
             
     | 
| 91 | 
         
            -
             
     | 
| 92 | 
         
            -
             
     | 
| 93 | 
         
            -
             
     | 
| 94 | 
         
            -
             
     | 
| 95 | 
         
            -
             
     | 
| 96 | 
         
            -
             
     | 
| 97 | 
         
            -
             
     | 
| 98 | 
         
            -
             
     | 
| 99 | 
         
            -
             
     | 
| 100 | 
         
            -
             
     | 
| 101 | 
         
            -
                     
     | 
| 102 | 
         
            -
             
     | 
| 103 | 
         
            -
             
     | 
| 104 | 
         
            -
             
     | 
| 105 | 
         
            -
             
     | 
| 106 | 
         
            -
             
     | 
| 107 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import json
         
     | 
| 2 | 
         
             
            import os
         
     | 
| 3 | 
         | 
| 4 | 
         
             
            import gradio as gr
         
     | 
| 5 | 
         
            +
            import numpy as np
         
     | 
| 6 | 
         
             
            import torch
         
     | 
| 7 | 
         
             
            import torch.nn.functional as F
         
     | 
| 8 | 
         
             
            from open_clip import create_model, get_tokenizer
         
     | 
| 9 | 
         
             
            from torchvision import transforms
         
     | 
| 10 | 
         | 
| 11 | 
         
            +
            import lib
         
     | 
| 12 | 
         
             
            from templates import openai_imagenet_template
         
     | 
| 13 | 
         | 
| 14 | 
         
             
            hf_token = os.getenv("HF_TOKEN")
         
     | 
| 
         | 
|
| 15 | 
         | 
| 16 | 
         
             
            model_str = "hf-hub:imageomics/bioclip"
         
     | 
| 17 | 
         
             
            tokenizer_str = "ViT-B-16"
         
     | 
| 18 | 
         
            +
            name_lookup_json = "name_lookup.json"
         
     | 
| 19 | 
         
            +
            txt_emb_npy = "txt_emb.npy"
         
     | 
| 20 | 
         | 
| 21 | 
         
             
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         
     | 
| 22 | 
         | 
| 23 | 
         
             
            preprocess_img = transforms.Compose(
         
     | 
| 24 | 
         
             
                [
         
     | 
| 25 | 
         
             
                    transforms.ToTensor(),
         
     | 
| 26 | 
         
            +
                    transforms.Resize((224, 224), antialias=True),
         
     | 
| 27 | 
         
             
                    transforms.Normalize(
         
     | 
| 28 | 
         
             
                        mean=(0.48145466, 0.4578275, 0.40821073),
         
     | 
| 29 | 
         
             
                        std=(0.26862954, 0.26130258, 0.27577711),
         
     | 
| 
         | 
|
| 31 | 
         
             
                ]
         
     | 
| 32 | 
         
             
            )
         
     | 
| 33 | 
         | 
| 34 | 
         
            +
            ranks = ("Kingdom", "Phylum", "Class", "Order", "Family", "Genus", "Species")
         
     | 
| 35 | 
         
            +
             
     | 
| 36 | 
         
            +
            open_domain_examples = [
         
     | 
| 37 | 
         
            +
                ["examples/Ursus-arctos.jpeg", "Species"],
         
     | 
| 38 | 
         
            +
                ["examples/Phoca-vitulina.png", "Species"],
         
     | 
| 39 | 
         
            +
                ["examples/Felis-catus.jpeg", "Genus"],
         
     | 
| 40 | 
         
            +
            ]
         
     | 
| 41 | 
         
            +
            zero_shot_examples = [
         
     | 
| 42 | 
         
            +
                [
         
     | 
| 43 | 
         
            +
                    "examples/Carnegiea-gigantea.png",
         
     | 
| 44 | 
         
            +
                    "Carnegiea gigantea\nSchlumbergera opuntioides\nMammillaria albicoma",
         
     | 
| 45 | 
         
            +
                ],
         
     | 
| 46 | 
         
            +
                [
         
     | 
| 47 | 
         
            +
                    "examples/Amanita-muscaria.jpeg",
         
     | 
| 48 | 
         
            +
                    "Amanita fulva\nAmanita vaginata (grisette)\nAmanita calyptrata (coccoli)\nAmanita crocea\nAmanita rubescens (blusher)\nAmanita caesarea (Caesar's mushroom)\nAmanita jacksonii (American Caesar's mushroom)\nAmanita muscaria (fly agaric)\nAmanita pantherina (panther cap)",
         
     | 
| 49 | 
         
            +
                ],
         
     | 
| 50 | 
         
            +
                [
         
     | 
| 51 | 
         
            +
                    "examples/Actinostola-abyssorum.png",
         
     | 
| 52 | 
         
            +
                    "Animalia Cnidaria Hexacorallia Actiniaria Actinostolidae Actinostola abyssorum\nAnimalia Cnidaria Hexacorallia Actiniaria Actinostolidae Actinostola bulbosa\nAnimalia Cnidaria Hexacorallia Actiniaria Actinostolidae Actinostola callosa\nAnimalia Cnidaria Hexacorallia Actiniaria Actinostolidae Actinostola capensis\nAnimalia Cnidaria Hexacorallia Actiniaria Actinostolidae Actinostola carlgreni",
         
     | 
| 53 | 
         
            +
                ],
         
     | 
| 54 | 
         
            +
            ]
         
     | 
| 55 | 
         
            +
             
     | 
| 56 | 
         | 
| 57 | 
         
             
            @torch.no_grad()
         
     | 
| 58 | 
         
             
            def get_txt_features(classnames, templates):
         
     | 
| 
         | 
|
| 69 | 
         | 
| 70 | 
         | 
| 71 | 
         
             
            @torch.no_grad()
         
     | 
| 72 | 
         
            +
            def zero_shot_classification(img, cls_str: str) -> dict[str, float]:
         
     | 
| 73 | 
         
            +
                classes = [cls.strip() for cls in cls_str.split("\n") if cls.strip()]
         
     | 
| 74 | 
         
             
                txt_features = get_txt_features(classes, openai_imagenet_template)
         
     | 
| 75 | 
         | 
| 76 | 
         
             
                img = preprocess_img(img).to(device)
         
     | 
| 
         | 
|
| 82 | 
         
             
                return {cls: prob for cls, prob in zip(classes, probs)}
         
     | 
| 83 | 
         | 
| 84 | 
         | 
| 85 | 
         
            +
            @torch.no_grad()
         
     | 
| 86 | 
         
            +
            def open_domain_classification(img, rank: int) -> list[dict[str, float]]:
         
     | 
| 87 | 
         
             
                """
         
     | 
| 88 | 
         
             
                Predicts from the top of the tree of life down to the species.
         
     | 
| 89 | 
         
             
                """
         
     | 
| 
         | 
|
| 91 | 
         
             
                img_features = model.encode_image(img.unsqueeze(0))
         
     | 
| 92 | 
         
             
                img_features = F.normalize(img_features, dim=-1)
         
     | 
| 93 | 
         | 
| 94 | 
         
            +
                outputs = []
         
     | 
| 95 | 
         
            +
             
     | 
| 96 | 
         
            +
                name = []
         
     | 
| 97 | 
         
            +
                for _ in range(rank + 1):
         
     | 
| 98 | 
         
            +
                    children = tuple(zip(*name_lookup.children(name)))
         
     | 
| 99 | 
         
            +
                    if not children:
         
     | 
| 100 | 
         
            +
                        break
         
     | 
| 101 | 
         
            +
                    values, indices = children
         
     | 
| 102 | 
         
            +
                    txt_features = txt_emb[:, indices].to(device)
         
     | 
| 103 | 
         
            +
                    logits = (model.logit_scale.exp() * img_features @ txt_features).view(-1)
         
     | 
| 104 | 
         
            +
             
     | 
| 105 | 
         
            +
                    probs = F.softmax(logits, dim=0).to("cpu").tolist()
         
     | 
| 106 | 
         
            +
                    parent = " ".join(name)
         
     | 
| 107 | 
         
            +
                    outputs.append(
         
     | 
| 108 | 
         
            +
                        {f"{parent} {value}": prob for value, prob in zip(values, probs)}
         
     | 
| 109 | 
         
            +
                    )
         
     | 
| 110 | 
         
            +
             
     | 
| 111 | 
         
            +
                    top = values[logits.argmax()]
         
     | 
| 112 | 
         
            +
                    name.append(top)
         
     | 
| 113 | 
         
            +
             
     | 
| 114 | 
         
            +
                while len(outputs) < 7:
         
     | 
| 115 | 
         
            +
                    outputs.append({})
         
     | 
| 116 | 
         
            +
             
     | 
| 117 | 
         
            +
                return list(reversed(outputs))
         
     | 
| 118 | 
         
            +
             
     | 
| 119 | 
         
            +
             
     | 
| 120 | 
         
            +
            def change_output(choice):
         
     | 
| 121 | 
         
            +
                return [
         
     | 
| 122 | 
         
            +
                    gr.Label(
         
     | 
| 123 | 
         
            +
                        num_top_classes=5, label=rank, show_label=True, visible=(6 - i <= choice)
         
     | 
| 124 | 
         
            +
                    )
         
     | 
| 125 | 
         
            +
                    for i, rank in enumerate(reversed(ranks))
         
     | 
| 126 | 
         
            +
                ]
         
     | 
| 127 | 
         | 
| 128 | 
         | 
| 129 | 
         
            +
            def get_name_lookup(path):
         
     | 
| 130 | 
         
            +
                with open(path) as fd:
         
     | 
| 131 | 
         
            +
                    return lib.TaxonomicTree.from_dict(json.load(fd))
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 132 | 
         | 
| 133 | 
         | 
| 134 | 
         
             
            if __name__ == "__main__":
         
     | 
| 
         | 
|
| 142 | 
         | 
| 143 | 
         
             
                tokenizer = get_tokenizer(tokenizer_str)
         
     | 
| 144 | 
         | 
| 145 | 
         
            +
                name_lookup = get_name_lookup(name_lookup_json)
         
     | 
| 146 | 
         
            +
                txt_emb = torch.from_numpy(np.load(txt_emb_npy, mmap_mode="r"))
         
     | 
| 147 | 
         
            +
             
     | 
| 148 | 
         
            +
                done = txt_emb.any(axis=0).sum().item()
         
     | 
| 149 | 
         
            +
                total = txt_emb.shape[1]
         
     | 
| 150 | 
         
            +
                status_msg = ""
         
     | 
| 151 | 
         
            +
                if done != total:
         
     | 
| 152 | 
         
            +
                    status_msg = f"{done}/{total} ({done / total * 100:.1f}%) indexed"
         
     | 
| 153 | 
         
            +
             
     | 
| 154 | 
         
            +
                with gr.Blocks() as app:
         
     | 
| 155 | 
         
            +
                    img_input = gr.Image()
         
     | 
| 156 | 
         
            +
             
     | 
| 157 | 
         
            +
                    with gr.Tab("Open-Ended"):
         
     | 
| 158 | 
         
            +
                        with gr.Row():
         
     | 
| 159 | 
         
            +
                            with gr.Column():
         
     | 
| 160 | 
         
            +
                                rank_dropdown = gr.Dropdown(
         
     | 
| 161 | 
         
            +
                                    label="Taxonomic Rank",
         
     | 
| 162 | 
         
            +
                                    info="Which taxonomic rank to predict. Fine-grained ranks (genus, species) are more challenging.",
         
     | 
| 163 | 
         
            +
                                    choices=ranks,
         
     | 
| 164 | 
         
            +
                                    value="Species",
         
     | 
| 165 | 
         
            +
                                    type="index",
         
     | 
| 166 | 
         
            +
                                )
         
     | 
| 167 | 
         
            +
                                open_domain_btn = gr.Button("Submit", variant="primary")
         
     | 
| 168 | 
         
            +
                                gr.Examples(
         
     | 
| 169 | 
         
            +
                                    examples=open_domain_examples,
         
     | 
| 170 | 
         
            +
                                    inputs=[img_input, rank_dropdown],
         
     | 
| 171 | 
         
            +
                                )
         
     | 
| 172 | 
         
            +
             
     | 
| 173 | 
         
            +
                            with gr.Column():
         
     | 
| 174 | 
         
            +
                                open_domain_outputs = [
         
     | 
| 175 | 
         
            +
                                    gr.Label(num_top_classes=5, label=rank, show_label=True)
         
     | 
| 176 | 
         
            +
                                    for rank in reversed(ranks)
         
     | 
| 177 | 
         
            +
                                ]
         
     | 
| 178 | 
         
            +
                                open_domain_flag_btn = gr.Button("Flag Mistake", variant="primary")
         
     | 
| 179 | 
         
            +
             
     | 
| 180 | 
         
            +
                        open_domain_callback = gr.HuggingFaceDatasetSaver(
         
     | 
| 181 | 
         
            +
                            hf_token, "imageomics/bioclip-demo-open-domain-mistakes", private=True
         
     | 
| 182 | 
         
            +
                        )
         
     | 
| 183 | 
         
            +
                        open_domain_callback.setup(
         
     | 
| 184 | 
         
            +
                            [img_input, *open_domain_outputs], flagging_dir="logs/flagged"
         
     | 
| 185 | 
         
            +
                        )
         
     | 
| 186 | 
         
            +
                        open_domain_flag_btn.click(
         
     | 
| 187 | 
         
            +
                            lambda *args: open_domain_callback.flag(args),
         
     | 
| 188 | 
         
            +
                            [img_input, *open_domain_outputs],
         
     | 
| 189 | 
         
            +
                            None,
         
     | 
| 190 | 
         
            +
                            preprocess=False,
         
     | 
| 191 | 
         
            +
                        )
         
     | 
| 192 | 
         
            +
             
     | 
| 193 | 
         
            +
                    with gr.Tab("Zero-Shot"):
         
     | 
| 194 | 
         
            +
                        with gr.Row():
         
     | 
| 195 | 
         
            +
                            with gr.Column():
         
     | 
| 196 | 
         
            +
                                classes_txt = gr.Textbox(
         
     | 
| 197 | 
         
            +
                                    placeholder="Canis familiaris (dog)\nFelis catus (cat)\n...",
         
     | 
| 198 | 
         
            +
                                    lines=3,
         
     | 
| 199 | 
         
            +
                                    label="Classes",
         
     | 
| 200 | 
         
            +
                                    show_label=True,
         
     | 
| 201 | 
         
            +
                                    info="Use taxonomic names where possible; include common names if possible.",
         
     | 
| 202 | 
         
            +
                                )
         
     | 
| 203 | 
         
            +
                                zero_shot_btn = gr.Button("Submit", variant="primary")
         
     | 
| 204 | 
         
            +
                                gr.Examples(
         
     | 
| 205 | 
         
            +
                                    examples=zero_shot_examples,
         
     | 
| 206 | 
         
            +
                                    inputs=[img_input, classes_txt],
         
     | 
| 207 | 
         
            +
                                )
         
     | 
| 208 | 
         
            +
             
     | 
| 209 | 
         
            +
                            with gr.Column():
         
     | 
| 210 | 
         
            +
                                zero_shot_output = gr.Label(
         
     | 
| 211 | 
         
            +
                                    num_top_classes=5, label="Prediction", show_label=True
         
     | 
| 212 | 
         
            +
                                )
         
     | 
| 213 | 
         
            +
                                zero_shot_flag_btn = gr.Button("Flag Mistake", variant="primary")
         
     | 
| 214 | 
         
            +
             
     | 
| 215 | 
         
            +
                        zero_shot_callback = gr.HuggingFaceDatasetSaver(
         
     | 
| 216 | 
         
            +
                            hf_token, "imageomics/bioclip-demo-zero-shot-mistakes", private=True
         
     | 
| 217 | 
         
            +
                        )
         
     | 
| 218 | 
         
            +
                        zero_shot_callback.setup(
         
     | 
| 219 | 
         
            +
                            [img_input, zero_shot_output], flagging_dir="logs/flagged"
         
     | 
| 220 | 
         
            +
                        )
         
     | 
| 221 | 
         
            +
                        zero_shot_flag_btn.click(
         
     | 
| 222 | 
         
            +
                            lambda *args: zero_shot_callback.flag(args),
         
     | 
| 223 | 
         
            +
                            [img_input, zero_shot_output],
         
     | 
| 224 | 
         
            +
                            None,
         
     | 
| 225 | 
         
            +
                            preprocess=False,
         
     | 
| 226 | 
         
            +
                        )
         
     | 
| 227 | 
         
            +
             
     | 
| 228 | 
         
            +
                    rank_dropdown.change(
         
     | 
| 229 | 
         
            +
                        fn=change_output, inputs=rank_dropdown, outputs=open_domain_outputs
         
     | 
| 230 | 
         
            +
                    )
         
     | 
| 231 | 
         
            +
             
     | 
| 232 | 
         
            +
                    open_domain_btn.click(
         
     | 
| 233 | 
         
            +
                        fn=open_domain_classification,
         
     | 
| 234 | 
         
            +
                        inputs=[img_input, rank_dropdown],
         
     | 
| 235 | 
         
            +
                        outputs=open_domain_outputs,
         
     | 
| 236 | 
         
            +
                    )
         
     | 
| 237 | 
         
            +
             
     | 
| 238 | 
         
            +
                    zero_shot_btn.click(
         
     | 
| 239 | 
         
            +
                        fn=zero_shot_classification,
         
     | 
| 240 | 
         
            +
                        inputs=[img_input, classes_txt],
         
     | 
| 241 | 
         
            +
                        outputs=zero_shot_output,
         
     | 
| 242 | 
         
            +
                    )
         
     | 
| 243 | 
         
            +
             
     | 
| 244 | 
         
            +
                app.queue(max_size=20)
         
     | 
| 245 | 
         
            +
                app.launch()
         
     | 
    	
        lib.py
    CHANGED
    
    | 
         @@ -1,5 +1,5 @@ 
     | 
|
| 1 | 
         
            -
            import json
         
     | 
| 2 | 
         
             
            import itertools
         
     | 
| 
         | 
|
| 3 | 
         | 
| 4 | 
         | 
| 5 | 
         
             
            class TaxonomicNode:
         
     | 
| 
         @@ -43,11 +43,12 @@ class TaxonomicNode: 
     | 
|
| 43 | 
         
             
                @classmethod
         
     | 
| 44 | 
         
             
                def from_dict(cls, dct, root):
         
     | 
| 45 | 
         
             
                    node = cls(dct["name"], dct["index"], root)
         
     | 
| 46 | 
         
            -
                    node._children = { 
     | 
| 
         | 
|
| 
         | 
|
| 47 | 
         
             
                    return node
         
     | 
| 48 | 
         | 
| 49 | 
         | 
| 50 | 
         
            -
             
     | 
| 51 | 
         
             
            class TaxonomicTree:
         
     | 
| 52 | 
         
             
                """
         
     | 
| 53 | 
         
             
                Efficient structure for finding taxonomic names and their descendants.
         
     | 
| 
         @@ -85,11 +86,15 @@ class TaxonomicTree: 
     | 
|
| 85 | 
         
             
                    for kingdom in self.kingdoms.values():
         
     | 
| 86 | 
         
             
                        yield from kingdom
         
     | 
| 87 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 88 | 
         
             
                @classmethod
         
     | 
| 89 | 
         
             
                def from_dict(cls, dct):
         
     | 
| 90 | 
         
             
                    tree = cls()
         
     | 
| 91 | 
         
             
                    tree.kingdoms = {
         
     | 
| 92 | 
         
            -
                        kingdom["name"]: TaxonomicNode.from_dict(kingdom, tree) 
     | 
| 
         | 
|
| 93 | 
         
             
                    }
         
     | 
| 94 | 
         
             
                    tree.size = dct["size"]
         
     | 
| 95 | 
         
             
                    return tree
         
     | 
| 
         @@ -112,11 +117,10 @@ class TaxonomicJsonEncoder(json.JSONEncoder): 
     | 
|
| 112 | 
         
             
                        super().default(self, obj)
         
     | 
| 113 | 
         | 
| 114 | 
         | 
| 115 | 
         
            -
             
     | 
| 116 | 
         
             
            def batched(iterable, n):
         
     | 
| 117 | 
         
             
                # batched('ABCDEFG', 3) --> ABC DEF G
         
     | 
| 118 | 
         
             
                if n < 1:
         
     | 
| 119 | 
         
            -
                    raise ValueError( 
     | 
| 120 | 
         
             
                it = iter(iterable)
         
     | 
| 121 | 
         
             
                while batch := tuple(itertools.islice(it, n)):
         
     | 
| 122 | 
         
            -
                    yield zip(*batch)
         
     | 
| 
         | 
|
| 
         | 
|
| 1 | 
         
             
            import itertools
         
     | 
| 2 | 
         
            +
            import json
         
     | 
| 3 | 
         | 
| 4 | 
         | 
| 5 | 
         
             
            class TaxonomicNode:
         
     | 
| 
         | 
|
| 43 | 
         
             
                @classmethod
         
     | 
| 44 | 
         
             
                def from_dict(cls, dct, root):
         
     | 
| 45 | 
         
             
                    node = cls(dct["name"], dct["index"], root)
         
     | 
| 46 | 
         
            +
                    node._children = {
         
     | 
| 47 | 
         
            +
                        child["name"]: cls.from_dict(child, root) for child in dct["children"]
         
     | 
| 48 | 
         
            +
                    }
         
     | 
| 49 | 
         
             
                    return node
         
     | 
| 50 | 
         | 
| 51 | 
         | 
| 
         | 
|
| 52 | 
         
             
            class TaxonomicTree:
         
     | 
| 53 | 
         
             
                """
         
     | 
| 54 | 
         
             
                Efficient structure for finding taxonomic names and their descendants.
         
     | 
| 
         | 
|
| 86 | 
         
             
                    for kingdom in self.kingdoms.values():
         
     | 
| 87 | 
         
             
                        yield from kingdom
         
     | 
| 88 | 
         | 
| 89 | 
         
            +
                def __len__(self):
         
     | 
| 90 | 
         
            +
                    return self.size
         
     | 
| 91 | 
         
            +
             
     | 
| 92 | 
         
             
                @classmethod
         
     | 
| 93 | 
         
             
                def from_dict(cls, dct):
         
     | 
| 94 | 
         
             
                    tree = cls()
         
     | 
| 95 | 
         
             
                    tree.kingdoms = {
         
     | 
| 96 | 
         
            +
                        kingdom["name"]: TaxonomicNode.from_dict(kingdom, tree)
         
     | 
| 97 | 
         
            +
                        for kingdom in dct["kingdoms"]
         
     | 
| 98 | 
         
             
                    }
         
     | 
| 99 | 
         
             
                    tree.size = dct["size"]
         
     | 
| 100 | 
         
             
                    return tree
         
     | 
| 
         | 
|
| 117 | 
         
             
                        super().default(self, obj)
         
     | 
| 118 | 
         | 
| 119 | 
         | 
| 
         | 
|
| 120 | 
         
             
            def batched(iterable, n):
         
     | 
| 121 | 
         
             
                # batched('ABCDEFG', 3) --> ABC DEF G
         
     | 
| 122 | 
         
             
                if n < 1:
         
     | 
| 123 | 
         
            +
                    raise ValueError("n must be at least one")
         
     | 
| 124 | 
         
             
                it = iter(iterable)
         
     | 
| 125 | 
         
             
                while batch := tuple(itertools.islice(it, n)):
         
     | 
| 126 | 
         
            +
                    yield zip(*batch)
         
     | 
    	
        make_txt_embedding.py
    CHANGED
    
    | 
         @@ -5,6 +5,7 @@ Uses the catalog.csv file from TreeOfLife-10M. 
     | 
|
| 5 | 
         
             
            import argparse
         
     | 
| 6 | 
         
             
            import csv
         
     | 
| 7 | 
         
             
            import json
         
     | 
| 
         | 
|
| 8 | 
         | 
| 9 | 
         
             
            import numpy as np
         
     | 
| 10 | 
         
             
            import torch
         
     | 
| 
         @@ -22,29 +23,53 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 
     | 
|
| 22 | 
         | 
| 23 | 
         
             
            @torch.no_grad()
         
     | 
| 24 | 
         
             
            def write_txt_features(name_lookup):
         
     | 
| 25 | 
         
            -
                 
     | 
| 26 | 
         
            -
                     
     | 
| 27 | 
         
            -
                 
     | 
| 
         | 
|
| 28 | 
         | 
| 29 | 
         
             
                batch_size = args.batch_size // len(openai_imagenet_template)
         
     | 
| 30 | 
         
            -
                for names, indices in  
     | 
| 31 | 
         
            -
                     
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 32 | 
         
             
                    txts = tokenizer(txts).to(device)
         
     | 
| 33 | 
         
             
                    txt_features = model.encode_text(txts)
         
     | 
| 34 | 
         
            -
                    txt_features = torch.reshape( 
     | 
| 
         | 
|
| 
         | 
|
| 35 | 
         
             
                    txt_features = F.normalize(txt_features, dim=2).mean(dim=1)
         
     | 
| 36 | 
         
             
                    txt_features /= txt_features.norm(dim=1, keepdim=True)
         
     | 
| 37 | 
         
            -
                    all_features[:, indices] = txt_features.cpu().numpy() 
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 38 | 
         | 
| 39 | 
         
            -
                 
     | 
| 40 | 
         | 
| 41 | 
         | 
| 42 | 
         
            -
            def get_name_lookup(catalog_path):
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 43 | 
         
             
                lookup = lib.TaxonomicTree()
         
     | 
| 44 | 
         | 
| 45 | 
         
             
                with open(catalog_path) as fd:
         
     | 
| 46 | 
         
             
                    reader = csv.DictReader(fd)
         
     | 
| 47 | 
         
            -
                    for row in tqdm(reader):
         
     | 
| 48 | 
         
             
                        name = [
         
     | 
| 49 | 
         
             
                            row["kingdom"],
         
     | 
| 50 | 
         
             
                            row["phylum"],
         
     | 
| 
         @@ -58,6 +83,9 @@ def get_name_lookup(catalog_path): 
     | 
|
| 58 | 
         
             
                            name = name[: name.index("")]
         
     | 
| 59 | 
         
             
                        lookup.add(name)
         
     | 
| 60 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 61 | 
         
             
                return lookup
         
     | 
| 62 | 
         | 
| 63 | 
         | 
| 
         @@ -69,15 +97,17 @@ if __name__ == "__main__": 
     | 
|
| 69 | 
         
             
                    required=True,
         
     | 
| 70 | 
         
             
                )
         
     | 
| 71 | 
         
             
                parser.add_argument("--out-path", help="Path to the output file.", required=True)
         
     | 
| 72 | 
         
            -
                parser.add_argument( 
     | 
| 73 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 74 | 
         
             
                args = parser.parse_args()
         
     | 
| 75 | 
         | 
| 76 | 
         
            -
                name_lookup = get_name_lookup(args.catalog_path)
         
     | 
| 77 | 
         
            -
                 
     | 
| 78 | 
         
            -
                    json.dump(name_lookup, fd, cls=lib.TaxonomicJsonEncoder)
         
     | 
| 79 | 
         | 
| 80 | 
         
            -
                print("Starting.")
         
     | 
| 81 | 
         
             
                model = create_model(model_str, output_dict=True, require_pretrained=True)
         
     | 
| 82 | 
         
             
                model = model.to(device)
         
     | 
| 83 | 
         
             
                print("Created model.")
         
     | 
| 
         | 
|
| 5 | 
         
             
            import argparse
         
     | 
| 6 | 
         
             
            import csv
         
     | 
| 7 | 
         
             
            import json
         
     | 
| 8 | 
         
            +
            import os
         
     | 
| 9 | 
         | 
| 10 | 
         
             
            import numpy as np
         
     | 
| 11 | 
         
             
            import torch
         
     | 
| 
         | 
|
| 23 | 
         | 
| 24 | 
         
             
            @torch.no_grad()
         
     | 
| 25 | 
         
             
            def write_txt_features(name_lookup):
         
     | 
| 26 | 
         
            +
                if os.path.isfile(args.out_path):
         
     | 
| 27 | 
         
            +
                    all_features = np.load(args.out_path)
         
     | 
| 28 | 
         
            +
                else:
         
     | 
| 29 | 
         
            +
                    all_features = np.zeros((512, len(name_lookup)), dtype=np.float32)
         
     | 
| 30 | 
         | 
| 31 | 
         
             
                batch_size = args.batch_size // len(openai_imagenet_template)
         
     | 
| 32 | 
         
            +
                for batch, (names, indices) in enumerate(
         
     | 
| 33 | 
         
            +
                    tqdm(
         
     | 
| 34 | 
         
            +
                        lib.batched(name_lookup, batch_size),
         
     | 
| 35 | 
         
            +
                        desc="txt feats",
         
     | 
| 36 | 
         
            +
                        total=len(name_lookup) // batch_size,
         
     | 
| 37 | 
         
            +
                    )
         
     | 
| 38 | 
         
            +
                ):
         
     | 
| 39 | 
         
            +
                    # Skip if any non-zero elements
         
     | 
| 40 | 
         
            +
                    if all_features[:, indices].any():
         
     | 
| 41 | 
         
            +
                        print(f"Skipping batch {batch}")
         
     | 
| 42 | 
         
            +
                        continue
         
     | 
| 43 | 
         
            +
             
     | 
| 44 | 
         
            +
                    txts = [
         
     | 
| 45 | 
         
            +
                        template(name) for name in names for template in openai_imagenet_template
         
     | 
| 46 | 
         
            +
                    ]
         
     | 
| 47 | 
         
             
                    txts = tokenizer(txts).to(device)
         
     | 
| 48 | 
         
             
                    txt_features = model.encode_text(txts)
         
     | 
| 49 | 
         
            +
                    txt_features = torch.reshape(
         
     | 
| 50 | 
         
            +
                        txt_features, (len(names), len(openai_imagenet_template), 512)
         
     | 
| 51 | 
         
            +
                    )
         
     | 
| 52 | 
         
             
                    txt_features = F.normalize(txt_features, dim=2).mean(dim=1)
         
     | 
| 53 | 
         
             
                    txt_features /= txt_features.norm(dim=1, keepdim=True)
         
     | 
| 54 | 
         
            +
                    all_features[:, indices] = txt_features.T.cpu().numpy()
         
     | 
| 55 | 
         
            +
             
     | 
| 56 | 
         
            +
                    if batch % 100 == 0:
         
     | 
| 57 | 
         
            +
                        np.save(args.out_path, all_features)
         
     | 
| 58 | 
         | 
| 59 | 
         
            +
                np.save(args.out_path, all_features)
         
     | 
| 60 | 
         | 
| 61 | 
         | 
| 62 | 
         
            +
            def get_name_lookup(catalog_path, cache_path):
         
     | 
| 63 | 
         
            +
                if os.path.isfile(cache_path):
         
     | 
| 64 | 
         
            +
                    with open(cache_path) as fd:
         
     | 
| 65 | 
         
            +
                        lookup = lib.TaxonomicTree.from_dict(json.load(fd))
         
     | 
| 66 | 
         
            +
                    return lookup
         
     | 
| 67 | 
         
            +
             
     | 
| 68 | 
         
             
                lookup = lib.TaxonomicTree()
         
     | 
| 69 | 
         | 
| 70 | 
         
             
                with open(catalog_path) as fd:
         
     | 
| 71 | 
         
             
                    reader = csv.DictReader(fd)
         
     | 
| 72 | 
         
            +
                    for row in tqdm(reader, desc="catalog"):
         
     | 
| 73 | 
         
             
                        name = [
         
     | 
| 74 | 
         
             
                            row["kingdom"],
         
     | 
| 75 | 
         
             
                            row["phylum"],
         
     | 
| 
         | 
|
| 83 | 
         
             
                            name = name[: name.index("")]
         
     | 
| 84 | 
         
             
                        lookup.add(name)
         
     | 
| 85 | 
         | 
| 86 | 
         
            +
                with open(args.name_cache_path, "w") as fd:
         
     | 
| 87 | 
         
            +
                    json.dump(lookup, fd, cls=lib.TaxonomicJsonEncoder)
         
     | 
| 88 | 
         
            +
             
     | 
| 89 | 
         
             
                return lookup
         
     | 
| 90 | 
         | 
| 91 | 
         | 
| 
         | 
|
| 97 | 
         
             
                    required=True,
         
     | 
| 98 | 
         
             
                )
         
     | 
| 99 | 
         
             
                parser.add_argument("--out-path", help="Path to the output file.", required=True)
         
     | 
| 100 | 
         
            +
                parser.add_argument(
         
     | 
| 101 | 
         
            +
                    "--name-cache-path",
         
     | 
| 102 | 
         
            +
                    help="Path to the name cache file.",
         
     | 
| 103 | 
         
            +
                    default="name_lookup.json",
         
     | 
| 104 | 
         
            +
                )
         
     | 
| 105 | 
         
            +
                parser.add_argument("--batch-size", help="Batch size.", default=2**15, type=int)
         
     | 
| 106 | 
         
             
                args = parser.parse_args()
         
     | 
| 107 | 
         | 
| 108 | 
         
            +
                name_lookup = get_name_lookup(args.catalog_path, cache_path=args.name_cache_path)
         
     | 
| 109 | 
         
            +
                print("Got name lookup.")
         
     | 
| 
         | 
|
| 110 | 
         | 
| 
         | 
|
| 111 | 
         
             
                model = create_model(model_str, output_dict=True, require_pretrained=True)
         
     | 
| 112 | 
         
             
                model = model.to(device)
         
     | 
| 113 | 
         
             
                print("Created model.")
         
     | 
    	
        txt_emb.npy
    ADDED
    
    | 
         @@ -0,0 +1,3 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            version https://git-lfs.github.com/spec/v1
         
     | 
| 2 | 
         
            +
            oid sha256:b4a3c3412c3dae49cf92cc760aba5ee84227362adf1eb08f04dd50ee2a756e43
         
     | 
| 3 | 
         
            +
            size 969818240
         
     |