jacopoteneggi committed: Update

Files changed:
- app_lib/config.py +159 -0
- app_lib/defaults.py +15 -10
- app_lib/multimodal.py +186 -0
- app_lib/test.py +28 -56
- app_lib/user_input.py +12 -12
- app_lib/utils.py +2 -10
- assets/results/bowl_ace.npy +1 -1
- assets/results/gardener_ace.npy +1 -1
- assets/results/gentleman_ace.npy +1 -1
- assets/results/mathematician_ace.npy +1 -1
- precompute_results.py +7 -7
app_lib/config.py
ADDED
@@ -0,0 +1,159 @@
+# SOURCE: https://github.com/Sulam-Group/IBYDMT/blob/main/ibydmt/utils/config.py
+
+import os
+from dataclasses import dataclass
+from enum import Enum
+from itertools import product
+from typing import Any, Iterable, Mapping, Optional, Union
+
+import torch
+from ml_collections import ConfigDict
+from numpy import ndarray
+
+Array = Union[ndarray, torch.Tensor]
+
+
+class TestType(Enum):
+    GLOBAL = "global"
+    GLOBAL_COND = "global_cond"
+    LOCAL_COND = "local_cond"
+
+
+class ConceptType(Enum):
+    DATASET = "dataset"
+    CLASS = "class"
+    IMAGE = "image"
+
+
+@dataclass
+class Constants:
+    WORKDIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class DataConfig(ConfigDict):
+    def __init__(self, config_dict: Optional[Mapping[str, Any]] = None):
+        super().__init__()
+        if config_dict is None:
+            config_dict = {}
+
+        self.dataset: str = config_dict.get("dataset", None)
+        self.backbone: str = config_dict.get("backbone", None)
+        self.bottleneck: str = config_dict.get("bottleneck", None)
+        self.classifier: str = config_dict.get("classifier", None)
+        self.sampler: str = config_dict.get("sampler", None)
+        self.num_concepts: int = config_dict.get("num_concepts", None)
+
+
+class SpliceConfig(ConfigDict):
+    def __init__(self, config_dict: Optional[Mapping[str, Any]] = None):
+        super().__init__()
+        if config_dict is None:
+            config_dict = {}
+
+        self.vocab: str = config_dict.get("vocab", None)
+        self.vocab_size: int = config_dict.get("vocab_size", None)
+        self.l1_penalty: float = config_dict.get("l1_penalty", None)
+
+
+class PCBMConfig(ConfigDict):
+    def __init__(self, config_dict: Optional[Mapping[str, Any]] = None):
+        super().__init__()
+        if config_dict is None:
+            config_dict = {}
+
+        self.alpha: float = config_dict.get("alpha", None)
+        self.l1_ratio: float = config_dict.get("l1_ratio", None)
+
+
+class cKDEConfig(ConfigDict):
+    def __init__(self, config_dict: Optional[Mapping[str, Any]] = None):
+        super().__init__()
+        if config_dict is None:
+            config_dict = {}
+
+        self.metric: str = config_dict.get("metric", None)
+        self.scale_method: str = config_dict.get("scale_method", None)
+        self.scale: float = config_dict.get("scale", None)
+
+
+class TestingConfig(ConfigDict):
+    def __init__(self, config_dict: Optional[Mapping[str, Any]] = None):
+        super().__init__()
+        if config_dict is None:
+            config_dict = {}
+
+        self.significance_level: float = config_dict.get("significance_level", None)
+        self.wealth: str = config_dict.get("wealth", None)
+        self.bet: str = config_dict.get("bet", None)
+        self.kernel: str = config_dict.get("kernel", None)
+        self.kernel_scale_method: str = config_dict.get("kernel_scale_method", None)
+        self.kernel_scale: float = config_dict.get("kernel_scale", None)
+        self.tau_max: int = config_dict.get("tau_max", None)
+        self.images_per_class: int = config_dict.get("images_per_class", None)
+        self.r: int = config_dict.get("r", None)
+        self.cardinality: Iterable[int] = config_dict.get("cardinality", None)
+
+
+class Config(ConfigDict):
+    def __init__(self, config_dict: Optional[Mapping[str, Any]] = None):
+        super().__init__()
+        if config_dict is None:
+            config_dict = {}
+
+        self.name: str = config_dict.get("name", None)
+        self.data = DataConfig(config_dict.get("data", None))
+        self.splice = SpliceConfig(config_dict.get("splice", None))
+        self.pcbm = PCBMConfig(config_dict.get("pcbm", None))
+        self.ckde = cKDEConfig(config_dict.get("ckde", None))
+        self.testing = TestingConfig(config_dict.get("testing", None))
+
+    def backbone_name(self):
+        backbone = self.data.backbone.strip().lower()
+        return backbone.replace("/", "_").replace(":", "_")
+
+    def sweep(self, keys: Iterable[str]):
+        def _get(dict, key):
+            keys = key.split(".")
+            if len(keys) == 1:
+                return dict[keys[0]]
+            else:
+                return _get(dict[keys[0]], ".".join(keys[1:]))
+
+        def _set(dict, key, value):
+            keys = key.split(".")
+            if len(keys) == 1:
+                dict[keys[0]] = value
+            else:
+                _set(dict[keys[0]], ".".join(keys[1:]), value)
+
+        to_iterable = lambda v: v if isinstance(v, list) else [v]
+
+        config_dict = self.to_dict()
+        sweep_values = [_get(config_dict, key) for key in keys]
+        sweep = list(product(*map(to_iterable, sweep_values)))
+
+        configs: Iterable[Config] = []
+        for _sweep in sweep:
+            _config_dict = config_dict.copy()
+            for key, value in zip(keys, _sweep):
+                _set(_config_dict, key, value)
+
+            configs.append(Config(_config_dict))
+        return configs
+
+
+def register_config(name: str):
+    def register(cls: Config):
+        if name in configs:
+            raise ValueError(f"Config {name} is already registered")
+        configs[name] = cls
+
+    return register
+
+
+def get_config(name: str) -> Config:
+    return configs[name]()
+
+
+configs: Mapping[str, Config] = {}
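For orientation (not part of the commit), a minimal sketch of how this Config class could be driven. The dictionary values below are hypothetical; only Config, backbone_name, and sweep come from the file above.

from app_lib.config import Config

# Hypothetical values, chosen only to illustrate the API.
config = Config(
    {
        "name": "demo",
        "data": {"dataset": "imagenette", "backbone": "open_clip:ViT-B-32"},
        "testing": {"significance_level": 0.05, "tau_max": [100, 200]},
    }
)

# backbone_name() normalizes the backbone string for use in file names.
print(config.backbone_name())  # open_clip_vit-b-32

# sweep() expands list-valued entries into one Config per combination:
# here, two configs (tau_max=100 and tau_max=200).
for cfg in config.sweep(["testing.tau_max"]):
    print(cfg.testing.tau_max)
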
app_lib/defaults.py
CHANGED
@@ -1,14 +1,19 @@
-MODEL_NAME = "open_clip:ViT-B-32"
-SIGNIFICANCE_LEVEL_VALUE = 0.05
-SIGNIFICANCE_LEVEL_STEP = 0.01
+from dataclasses import dataclass
+
+
+@dataclass
+class Defaults:
+    DATASET_NAME = "imagenette"
+    MODEL_NAME = "open_clip:ViT-B-32"
+
+    SIGNIFICANCE_LEVEL_VALUE = 0.05
+    SIGNIFICANCE_LEVEL_STEP = 0.01
+
+    TAU_MAX_VALUE = 200
+    TAU_MAX_STEP = 50
+
+    R_VALUE = 20
+    R_STEP = 5
+
+    CARDINALITY_VALUE = 1
+    CARDINALITY_STEP = 1
app_lib/multimodal.py
ADDED
@@ -0,0 +1,186 @@
+# SOURCE: https://github.com/Sulam-Group/IBYDMT/blob/main/ibydmt/multimodal.py
+
+from abc import abstractmethod
+from typing import Mapping, Optional
+
+import clip
+import open_clip
+from transformers import (
+    AlignModel,
+    AlignProcessor,
+    BlipForImageTextRetrieval,
+    BlipProcessor,
+    FlavaModel,
+    FlavaProcessor,
+)
+
+from app_lib.config import Config
+from app_lib.config import Constants as c
+
+
+class VisionLanguageModel:
+    def __init__(self, backbone: Optional[str] = None, device=c.DEVICE):
+        pass
+
+    @abstractmethod
+    def encode_text(self, text):
+        pass
+
+    @abstractmethod
+    def encode_image(self, image):
+        pass
+
+
+models: Mapping[str, VisionLanguageModel] = {}
+
+
+def register_model(name):
+    def register(cls: VisionLanguageModel):
+        if name in models:
+            raise ValueError(f"Model {name} is already registered")
+        models[name] = cls
+
+    return register
+
+
+def get_model_name_and_backbone(config: Config):
+    backbone = config.data.backbone.split(":")
+    if len(backbone) == 1:
+        backbone.append(None)
+    return backbone
+
+
+def get_model(config: Config, device=c.DEVICE) -> VisionLanguageModel:
+    model_name, backbone = get_model_name_and_backbone(config)
+    return models[model_name](backbone, device=device)
+
+
+def get_text_encoder(config: Config, device=c.DEVICE):
+    model = get_model(config, device=device)
+    return model.encode_text
+
+
+def get_image_encoder(config: Config, device=c.DEVICE):
+    model = get_model(config, device=device)
+    return model.encode_image
+
+
+@register_model(name="clip")
+class CLIPModel(VisionLanguageModel):
+    def __init__(self, backbone: str, device=c.DEVICE):
+        self.model, self.preprocess = clip.load(backbone, device=device)
+        self.tokenize = clip.tokenize
+
+        self.device = device
+
+    def encode_text(self, text):
+        text = self.tokenize(text).to(self.device)
+        return self.model.encode_text(text)
+
+    def encode_image(self, image):
+        image = self.preprocess(image).unsqueeze(0).to(self.device)
+        return self.model.encode_image(image)
+
+
+@register_model(name="open_clip")
+class OpenClipModel(VisionLanguageModel):
+    OPENCLIP_WEIGHTS = {
+        "ViT-B-32": "laion2b_s34b_b79k",
+        "ViT-L-14": "laion2b_s32b_b82k",
+    }
+
+    def __init__(self, backbone: str, device=c.DEVICE):
+        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
+            backbone, pretrained=self.OPENCLIP_WEIGHTS[backbone], device=device
+        )
+        self.tokenize = open_clip.get_tokenizer(backbone)
+
+        self.device = device
+
+    def encode_text(self, text):
+        text = self.tokenize(text).to(self.device)
+        return self.model.encode_text(text)
+
+    def encode_image(self, image):
+        image = self.preprocess(image).unsqueeze(0).to(self.device)
+        return self.model.encode_image(image)
+
+
+@register_model(name="flava")
+class FLAVAModel(VisionLanguageModel):
+    HF_MODEL = "facebook/flava-full"
+
+    def __init__(self, backbone: Optional[str] = None, device=c.DEVICE):
+        if backbone is None:
+            backbone = self.HF_MODEL
+
+        self.model = FlavaModel.from_pretrained(backbone).to(device)
+        self.processor = FlavaProcessor.from_pretrained(backbone)
+
+        self.device = device
+
+    def encode_text(self, text):
+        text_inputs = self.processor(
+            text=text, return_tensors="pt", padding="max_length", max_length=77
+        )
+        text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
+        return self.model.get_text_features(**text_inputs)[:, 0, :]
+
+    def encode_image(self, image):
+        image_inputs = self.processor(images=image, return_tensors="pt")
+        image_inputs = {k: v.to(self.device) for k, v in image_inputs.items()}
+        return self.model.get_image_features(**image_inputs)[:, 0, :]
+
+
+@register_model(name="align")
+class ALIGNModel(VisionLanguageModel):
+    HF_MODEL = "kakaobrain/align-base"
+
+    def __init__(self, backbone: Optional[str] = None, device=c.DEVICE):
+        if backbone is None:
+            backbone = self.HF_MODEL
+
+        self.model = AlignModel.from_pretrained(backbone).to(device)
+        self.processor = AlignProcessor.from_pretrained(backbone)
+
+        self.device = device
+
+    def encode_text(self, text):
+        text_inputs = self.processor(
+            text=text, return_tensors="pt", padding="max_length", max_length=77
+        )
+        text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
+        return self.model.get_text_features(**text_inputs)
+
+    def encode_image(self, image):
+        image_inputs = self.processor(images=image, return_tensors="pt")
+        image_inputs = {k: v.to(self.device) for k, v in image_inputs.items()}
+        return self.model.get_image_features(**image_inputs)
+
+
+@register_model(name="blip")
+class BLIPModel(VisionLanguageModel):
+    HF_MODEL = "Salesforce/blip-itm-base-coco"
+
+    def __init__(self, backbone: Optional[str] = None, device=c.DEVICE):
+        if backbone is None:
+            backbone = self.HF_MODEL
+
+        self.model = BlipForImageTextRetrieval.from_pretrained(backbone).to(device)
+        self.processor = BlipProcessor.from_pretrained(backbone)
+
+        self.device = device
+
+    def encode_text(self, text):
+        text_inputs = self.processor(
+            text=text, return_tensors="pt", padding="max_length", max_length=77
+        )
+        text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
+        question_embeds = self.model.text_encoder(**text_inputs)[0]
+        return self.model.text_proj(question_embeds[:, 0, :])
+
+    def encode_image(self, image):
+        image_inputs = self.processor(images=image, return_tensors="pt")
+        image_inputs = {k: v.to(self.device) for k, v in image_inputs.items()}
+        image_embeds = self.model.vision_model(**image_inputs)[0]
+        return self.model.vision_proj(image_embeds[:, 0, :])
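A minimal usage sketch of the registry above (not part of the commit). The backbone string "open_clip:ViT-B-32" follows the "model:backbone" convention parsed by get_model_name_and_backbone; the image path is hypothetical, and constructing the model downloads pretrained weights.

import torch
from PIL import Image

import app_lib.multimodal as multimodal
from app_lib.config import Config

# Hypothetical configuration; "open_clip" selects the OpenClipModel wrapper
# registered above, and "ViT-B-32" is passed to it as the backbone.
config = Config({"data": {"dataset": "imagenette", "backbone": "open_clip:ViT-B-32"}})

model = multimodal.get_model(config, device=torch.device("cpu"))

with torch.no_grad():
    text_features = model.encode_text(["a photo of a dog", "a photo of a cat"])
    image_features = model.encode_image(Image.open("example.jpg"))  # hypothetical path
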
app_lib/test.py
CHANGED
@@ -1,65 +1,33 @@
 import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
 
-import clip
-import h5py
 import ml_collections
 import numpy as np
-import open_clip
+import pandas as pd
 import streamlit as st
 import torch
 from huggingface_hub import hf_hub_download
 
+import app_lib.multimodal as multimodal
 from app_lib.ckde import cKDE
-from app_lib.
+from app_lib.config import Config
+from app_lib.config import Constants as c
 from ibydmt.test import xSKIT
 
 rng = np.random.default_rng()
 
 
-def _get_open_clip_model(model_name, device):
-    backbone = model_name.split(":")[-1]
-
-    model, _, preprocess = open_clip.create_model_and_transforms(
-        SUPPORTED_MODELS[model_name], device=device
-    )
-    model.eval()
-    tokenizer = open_clip.get_tokenizer(backbone)
-    return model, preprocess, tokenizer
-
-
-def _get_clip_model(model_name, device):
-    backbone = model_name.split(":")[-1]
-    model, preprocess = clip.load(backbone, device=device)
-    tokenizer = clip.tokenize
-    return model, preprocess, tokenizer
-
-
-def _load_model(model_name, device):
-    if "open_clip" in model_name:
-        model, preprocess, tokenizer = _get_open_clip_model(model_name, device)
-    elif "clip" in model_name:
-        model, preprocess, tokenizer = _get_clip_model(model_name, device)
-    return model, preprocess, tokenizer
-
-
 @torch.no_grad()
 @torch.cuda.amp.autocast()
-def _encode_concepts(
-    concept_features = model.encode_text(concepts_text)
+def _encode_concepts(model, concepts):
+    concept_features = model.encode_text(concepts)
     concept_features /= torch.linalg.norm(concept_features, dim=-1, keepdim=True)
     return concept_features.cpu().numpy()
 
 
 @torch.no_grad()
 @torch.cuda.amp.autocast()
-def _encode_image(model, preprocess, image, device):
-    image = preprocess(image)
-    image = image.unsqueeze(0)
-    image = image.to(device)
-
+def _encode_image(model, image):
     image_features = model.encode_image(image)
     image_features /= image_features.norm(dim=-1, keepdim=True)
     return image_features.cpu().numpy()
@@ -67,24 +35,24 @@ def _encode_image(model, preprocess, image, device):
 
 @torch.no_grad()
 @torch.cuda.amp.autocast()
-def _encode_class_name(
-    class_text =
+def _encode_class_name(model, class_name):
+    class_text = [f"A photo of a {class_name}"]
     class_features = model.encode_text(class_text)
     class_features /= torch.linalg.norm(class_features, dim=-1, keepdim=True)
     return class_features.cpu().numpy()
 
 
-def
+def _load_embedding(config):
     dataset_path = hf_hub_download(
         repo_id="jacopoteneggi/IBYDMT",
-        filename=
+        filename=(
+            f"{config.data.dataset.lower()}_train_{config.backbone_name()}.parquet"
+        ),
         repo_type="dataset",
     )
 
-    return embedding
+    dataset = pd.read_parquet(dataset_path)
+    return np.array(dataset["embedding"].values.tolist())
 
 
 def _sample_random_subset(concept_idx, concepts, cardinality):
@@ -162,26 +130,30 @@ def test(
     cardinality,
     dataset_name,
     model_name,
-    device=
+    device=c.DEVICE,
     with_streamlit=True,
 ):
+    config = Config()
+    config.data.dataset = dataset_name
+    config.data.backbone = model_name
+
     if with_streamlit:
         with st.spinner("Loading model"):
-            model
+            model = multimodal.get_model(config, device=device)
     else:
-        model
+        model = multimodal.get_model(config, device=device)
 
     if with_streamlit:
         with st.spinner("Encoding concepts"):
-            cbm = _encode_concepts(
+            cbm = _encode_concepts(model, concepts)
    else:
-        cbm = _encode_concepts(
+        cbm = _encode_concepts(model, concepts)
 
     if with_streamlit:
         with st.spinner("Encoding image"):
-            h = _encode_image(model,
+            h = _encode_image(model, image)
     else:
-        h = _encode_image(model,
+        h = _encode_image(model, image)
     z = h @ cbm.T
     z = z.squeeze()
 
@@ -201,11 +173,11 @@ def test(
         ),
     )
 
-    embedding =
+    embedding = _load_embedding(config)
     semantics = embedding @ cbm.T
     sampler = cKDE(embedding, semantics)
 
-    classifier = _encode_class_name(
+    classifier = _encode_class_name(model, class_name)
 
     with ThreadPoolExecutor() as executor:
         futures = [
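To make the linear algebra in test() easier to follow, a shape sketch with made-up dimensions (not part of the commit; the actual embedding width depends on the chosen backbone):

import numpy as np

d = 512          # hypothetical embedding width
n_images = 100   # hypothetical number of training images in the parquet
n_concepts = 5   # hypothetical number of user-provided concepts

cbm = np.random.randn(n_concepts, d)       # stands in for _encode_concepts output
cbm /= np.linalg.norm(cbm, axis=-1, keepdims=True)

h = np.random.randn(1, d)                  # stands in for _encode_image output
z = (h @ cbm.T).squeeze()                  # per-concept scores of the test image, shape (n_concepts,)

embedding = np.random.randn(n_images, d)   # stands in for _load_embedding output
semantics = embedding @ cbm.T              # shape (n_images, n_concepts), fed to cKDE(embedding, semantics)
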
app_lib/user_input.py
CHANGED
@@ -5,7 +5,7 @@ import streamlit as st
 from PIL import Image
 from streamlit_image_select import image_select
 
+from app_lib.defaults import Defaults as d
 from app_lib.utils import SUPPORTED_DATASETS, SUPPORTED_MODELS
 
 IMAGE_DIR = os.path.join("assets", "images")
@@ -31,8 +31,8 @@ def _validate_concepts(concepts):
 
 
 def _get_significance_level():
-    default =
-    step =
+    default = d.SIGNIFICANCE_LEVEL_VALUE
+    step = d.SIGNIFICANCE_LEVEL_STEP
     return st.slider(
         "Significance level",
         help=f"The level of significance of the tests. Defaults to {default:.2F}.",
@@ -45,8 +45,8 @@ def _get_significance_level():
 
 
 def _get_tau_max():
-    default =
-    step =
+    default = d.TAU_MAX_VALUE
+    step = d.TAU_MAX_STEP
     return int(
         st.slider(
             "Length of test",
@@ -61,8 +61,8 @@ def _get_tau_max():
 
 
 def _get_number_of_tests():
-    default =
-    step =
+    default = d.R_VALUE
+    step = d.R_STEP
     return int(
         st.slider(
             "Number of tests per concept",
@@ -80,8 +80,8 @@ def _get_number_of_tests():
 
 
 def _get_cardinality(concepts, concepts_ready):
-    default =
-    step =
+    default = d.CARDINALITY_VALUE
+    step = d.CARDINALITY_STEP
     return st.slider(
         "Size of conditioning set",
         help=(
@@ -98,7 +98,7 @@ def _get_cardinality(concepts, concepts_ready):
 
 def _get_dataset_name():
     options = SUPPORTED_DATASETS
-    default_idx = options.index(
+    default_idx = options.index(d.DATASET_NAME)
     return st.selectbox(
         "Dataset",
         options=options,
@@ -112,8 +112,8 @@ def _get_dataset_name():
 
 
 def get_model_name():
-    options = list(SUPPORTED_MODELS
-    default_idx = options.index(
+    options = list(SUPPORTED_MODELS)
+    default_idx = options.index(d.MODEL_NAME)
     return st.selectbox(
         "Model to test",
         options=options,
app_lib/utils.py
CHANGED
@@ -11,16 +11,8 @@ supported_datasets_path = hf_hub_download(
     repo_type="dataset",
 )
 
-SUPPORTED_MODELS = {}
 with open(supported_models_path, "r") as f:
-    for line in f:
-        line = line.strip()
-        model_name, model_url = line.split(",")
-        SUPPORTED_MODELS[model_name] = model_url
+    SUPPORTED_MODELS = f.read().splitlines()
 
-
-SUPPORTED_DATASETS = []
 with open(supported_datasets_path, "r") as f:
-    for line in f:
-        dataset_name = line.strip()
-        SUPPORTED_DATASETS.append(dataset_name)
+    SUPPORTED_DATASETS = f.read().splitlines()
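Note the behavioral change: SUPPORTED_MODELS was a dict mapping model name to URL and is now a plain list of the file's lines, which is what lets user_input.py above call list(SUPPORTED_MODELS) and options.index(...). A small sketch with hypothetical file contents (not part of the commit):

# Hypothetical contents of the downloaded supported-models file, one name per line:
#   open_clip:ViT-B-32
#   clip:ViT-B/32
SUPPORTED_MODELS = ["open_clip:ViT-B-32", "clip:ViT-B/32"]

options = list(SUPPORTED_MODELS)
default_idx = options.index("open_clip:ViT-B-32")  # 0, matches Defaults.MODEL_NAME
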
assets/results/bowl_ace.npy
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:adaeeda897451c5548b3119fb917214b82590becfe3138158cfcf1055bcb714d
 size 226871
assets/results/gardener_ace.npy
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:6a7e6d291b8e7226da6af5990094501a741a99c363f04647688da6f8e71746c6
 size 226873
assets/results/gentleman_ace.npy
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:fb55900dbb2870ffccc56aac8de6fde7960c3d817035bd9d461419ef0bee6b3b
 size 226874
assets/results/mathematician_ace.npy
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:ab4807601def20a04c66a2ec993a90e929172aee047e66fb8d052d8eaee438b0
 size 226873
precompute_results.py
CHANGED
@@ -4,7 +4,7 @@ import os
 import numpy as np
 from PIL import Image
 
+from app_lib.defaults import Defaults as d
 from app_lib.test import get_testing_config, test
 
 assets_dir = "assets"
@@ -13,9 +13,9 @@ results_dir = os.path.join(assets_dir, "results")
 os.makedirs(results_dir, exist_ok=True)
 
 testing_config = get_testing_config(
-    significance_level=
-    tau_max=
-    r=
+    significance_level=d.SIGNIFICANCE_LEVEL_VALUE,
+    tau_max=d.TAU_MAX_VALUE,
+    r=d.R_VALUE,
 )
 
 image_presets = json.load(open(os.path.join(assets_dir, "image_presets.json")))
@@ -26,7 +26,7 @@ for _image_name, _image_presets in image_presets.items():
     _image = Image.open(_image_path)
     _class_name = _image_presets["class_name"]
     _concepts = _image_presets["concepts"]
-    _cardinality =
+    _cardinality = d.CARDINALITY_VALUE
 
     _results = test(
         testing_config,
@@ -34,8 +34,8 @@ for _image_name, _image_presets in image_presets.items():
         _class_name,
         _concepts,
         _cardinality,
+        d.DATASET_NAME,
+        d.MODEL_NAME,
         with_streamlit=False,
     )
 