jacopoteneggi committed
Commit dffe47c · verified · 1 Parent(s): 67bd8e2
app_lib/config.py ADDED
@@ -0,0 +1,159 @@
+# SOURCE: https://github.com/Sulam-Group/IBYDMT/blob/main/ibydmt/utils/config.py
+
+import os
+from dataclasses import dataclass
+from enum import Enum
+from itertools import product
+from typing import Any, Iterable, Mapping, Optional, Union
+
+import torch
+from ml_collections import ConfigDict
+from numpy import ndarray
+
+Array = Union[ndarray, torch.Tensor]
+
+
+class TestType(Enum):
+    GLOBAL = "global"
+    GLOBAL_COND = "global_cond"
+    LOCAL_COND = "local_cond"
+
+
+class ConceptType(Enum):
+    DATASET = "dataset"
+    CLASS = "class"
+    IMAGE = "image"
+
+
+@dataclass
+class Constants:
+    WORKDIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class DataConfig(ConfigDict):
+    def __init__(self, config_dict: Optional[Mapping[str, Any]] = None):
+        super().__init__()
+        if config_dict is None:
+            config_dict = {}
+
+        self.dataset: str = config_dict.get("dataset", None)
+        self.backbone: str = config_dict.get("backbone", None)
+        self.bottleneck: str = config_dict.get("bottleneck", None)
+        self.classifier: str = config_dict.get("classifier", None)
+        self.sampler: str = config_dict.get("sampler", None)
+        self.num_concepts: int = config_dict.get("num_concepts", None)
+
+
+class SpliceConfig(ConfigDict):
+    def __init__(self, config_dict: Optional[Mapping[str, Any]] = None):
+        super().__init__()
+        if config_dict is None:
+            config_dict = {}
+
+        self.vocab: str = config_dict.get("vocab", None)
+        self.vocab_size: int = config_dict.get("vocab_size", None)
+        self.l1_penalty: float = config_dict.get("l1_penalty", None)
+
+
+class PCBMConfig(ConfigDict):
+    def __init__(self, config_dict: Optional[Mapping[str, Any]] = None):
+        super().__init__()
+        if config_dict is None:
+            config_dict = {}
+
+        self.alpha: float = config_dict.get("alpha", None)
+        self.l1_ratio: float = config_dict.get("l1_ratio", None)
+
+
+class cKDEConfig(ConfigDict):
+    def __init__(self, config_dict: Optional[Mapping[str, Any]] = None):
+        super().__init__()
+        if config_dict is None:
+            config_dict = {}
+
+        self.metric: str = config_dict.get("metric", None)
+        self.scale_method: str = config_dict.get("scale_method", None)
+        self.scale: float = config_dict.get("scale", None)
+
+
+class TestingConfig(ConfigDict):
+    def __init__(self, config_dict: Optional[Mapping[str, Any]] = None):
+        super().__init__()
+        if config_dict is None:
+            config_dict = {}
+
+        self.significance_level: float = config_dict.get("significance_level", None)
+        self.wealth: str = config_dict.get("wealth", None)
+        self.bet: str = config_dict.get("bet", None)
+        self.kernel: str = config_dict.get("kernel", None)
+        self.kernel_scale_method: str = config_dict.get("kernel_scale_method", None)
+        self.kernel_scale: float = config_dict.get("kernel_scale", None)
+        self.tau_max: int = config_dict.get("tau_max", None)
+        self.images_per_class: int = config_dict.get("images_per_class", None)
+        self.r: int = config_dict.get("r", None)
+        self.cardinality: Iterable[int] = config_dict.get("cardinality", None)
+
+
+class Config(ConfigDict):
+    def __init__(self, config_dict: Optional[Mapping[str, Any]] = None):
+        super().__init__()
+        if config_dict is None:
+            config_dict = {}
+
+        self.name: str = config_dict.get("name", None)
+        self.data = DataConfig(config_dict.get("data", None))
+        self.splice = SpliceConfig(config_dict.get("splice", None))
+        self.pcbm = PCBMConfig(config_dict.get("pcbm", None))
+        self.ckde = cKDEConfig(config_dict.get("ckde", None))
+        self.testing = TestingConfig(config_dict.get("testing", None))
+
+    def backbone_name(self):
+        backbone = self.data.backbone.strip().lower()
+        return backbone.replace("/", "_").replace(":", "_")
+
+    def sweep(self, keys: Iterable[str]):
+        def _get(dict, key):
+            keys = key.split(".")
+            if len(keys) == 1:
+                return dict[keys[0]]
+            else:
+                return _get(dict[keys[0]], ".".join(keys[1:]))
+
+        def _set(dict, key, value):
+            keys = key.split(".")
+            if len(keys) == 1:
+                dict[keys[0]] = value
+            else:
+                _set(dict[keys[0]], ".".join(keys[1:]), value)
+
+        to_iterable = lambda v: v if isinstance(v, list) else [v]
+
+        config_dict = self.to_dict()
+        sweep_values = [_get(config_dict, key) for key in keys]
+        sweep = list(product(*map(to_iterable, sweep_values)))
+
+        configs: Iterable[Config] = []
+        for _sweep in sweep:
+            _config_dict = config_dict.copy()
+            for key, value in zip(keys, _sweep):
+                _set(_config_dict, key, value)
+
+            configs.append(Config(_config_dict))
+        return configs
+
+
+def register_config(name: str):
+    def register(cls: Config):
+        if name in configs:
+            raise ValueError(f"Config {name} is already registered")
+        configs[name] = cls
+
+    return register
+
+
+def get_config(name: str) -> Config:
+    return configs[name]()
+
+
+configs: Mapping[str, Config] = {}
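Note: a minimal usage sketch of the Config helpers added above; the "demo" name and the backbone list are illustrative, not part of this commit.

from app_lib.config import Config

# Hypothetical config; a list-valued field marks a sweep dimension.
config = Config(
    {
        "name": "demo",
        "data": {
            "dataset": "imagenette",
            "backbone": ["open_clip:ViT-B-32", "clip:ViT-B/32"],
        },
    }
)
# sweep() expands each listed key into one Config per value.
sweep_configs = config.sweep(["data.backbone"])
print([cfg.data.backbone for cfg in sweep_configs])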
app_lib/defaults.py CHANGED
@@ -1,14 +1,19 @@
-DATASET_NAME = "imagenette"
-MODEL_NAME = "open_clip:ViT-B-32"
+from dataclasses import dataclass
+
+
+@dataclass
+class Defaults:
+    DATASET_NAME = "imagenette"
+    MODEL_NAME = "open_clip:ViT-B-32"
 
-SIGNIFICANCE_LEVEL_VALUE = 0.05
-SIGNIFICANCE_LEVEL_STEP = 0.01
+    SIGNIFICANCE_LEVEL_VALUE = 0.05
+    SIGNIFICANCE_LEVEL_STEP = 0.01
 
-TAU_MAX_VALUE = 200
-TAU_MAX_STEP = 50
-
-R_VALUE = 20
-R_STEP = 5
-
-CARDINALITY_VALUE = 1
-CARDINALITY_STEP = 1
+    TAU_MAX_VALUE = 200
+    TAU_MAX_STEP = 50
+
+    R_VALUE = 20
+    R_STEP = 5
+
+    CARDINALITY_VALUE = 1
+    CARDINALITY_STEP = 1
app_lib/multimodal.py ADDED
@@ -0,0 +1,186 @@
+# SOURCE: https://github.com/Sulam-Group/IBYDMT/blob/main/ibydmt/multimodal.py
+
+from abc import abstractmethod
+from typing import Mapping, Optional
+
+import clip
+import open_clip
+from transformers import (
+    AlignModel,
+    AlignProcessor,
+    BlipForImageTextRetrieval,
+    BlipProcessor,
+    FlavaModel,
+    FlavaProcessor,
+)
+
+from app_lib.config import Config
+from app_lib.config import Constants as c
+
+
+class VisionLanguageModel:
+    def __init__(self, backbone: Optional[str] = None, device=c.DEVICE):
+        pass
+
+    @abstractmethod
+    def encode_text(self, text):
+        pass
+
+    @abstractmethod
+    def encode_image(self, image):
+        pass
+
+
+models: Mapping[str, VisionLanguageModel] = {}
+
+
+def register_model(name):
+    def register(cls: VisionLanguageModel):
+        if name in models:
+            raise ValueError(f"Model {name} is already registered")
+        models[name] = cls
+
+    return register
+
+
+def get_model_name_and_backbone(config: Config):
+    backbone = config.data.backbone.split(":")
+    if len(backbone) == 1:
+        backbone.append(None)
+    return backbone
+
+
+def get_model(config: Config, device=c.DEVICE) -> VisionLanguageModel:
+    model_name, backbone = get_model_name_and_backbone(config)
+    return models[model_name](backbone, device=device)
+
+
+def get_text_encoder(config: Config, device=c.DEVICE):
+    model = get_model(config, device=device)
+    return model.encode_text
+
+
+def get_image_encoder(config: Config, device=c.DEVICE):
+    model = get_model(config, device=device)
+    return model.encode_image
+
+
+@register_model(name="clip")
+class CLIPModel(VisionLanguageModel):
+    def __init__(self, backbone: str, device=c.DEVICE):
+        self.model, self.preprocess = clip.load(backbone, device=device)
+        self.tokenize = clip.tokenize
+
+        self.device = device
+
+    def encode_text(self, text):
+        text = self.tokenize(text).to(self.device)
+        return self.model.encode_text(text)
+
+    def encode_image(self, image):
+        image = self.preprocess(image).unsqueeze(0).to(self.device)
+        return self.model.encode_image(image)
+
+
+@register_model(name="open_clip")
+class OpenClipModel(VisionLanguageModel):
+    OPENCLIP_WEIGHTS = {
+        "ViT-B-32": "laion2b_s34b_b79k",
+        "ViT-L-14": "laion2b_s32b_b82k",
+    }
+
+    def __init__(self, backbone: str, device=c.DEVICE):
+        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
+            backbone, pretrained=self.OPENCLIP_WEIGHTS[backbone], device=device
+        )
+        self.tokenize = open_clip.get_tokenizer(backbone)
+
+        self.device = device
+
+    def encode_text(self, text):
+        text = self.tokenize(text).to(self.device)
+        return self.model.encode_text(text)
+
+    def encode_image(self, image):
+        image = self.preprocess(image).unsqueeze(0).to(self.device)
+        return self.model.encode_image(image)
+
+
+@register_model(name="flava")
+class FLAVAModel(VisionLanguageModel):
+    HF_MODEL = "facebook/flava-full"
+
+    def __init__(self, backbone: Optional[str] = None, device=c.DEVICE):
+        if backbone is None:
+            backbone = self.HF_MODEL
+
+        self.model = FlavaModel.from_pretrained(backbone).to(device)
+        self.processor = FlavaProcessor.from_pretrained(backbone)
+
+        self.device = device
+
+    def encode_text(self, text):
+        text_inputs = self.processor(
+            text=text, return_tensors="pt", padding="max_length", max_length=77
+        )
+        text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
+        return self.model.get_text_features(**text_inputs)[:, 0, :]
+
+    def encode_image(self, image):
+        image_inputs = self.processor(images=image, return_tensors="pt")
+        image_inputs = {k: v.to(self.device) for k, v in image_inputs.items()}
+        return self.model.get_image_features(**image_inputs)[:, 0, :]
+
+
+@register_model(name="align")
+class ALIGNModel(VisionLanguageModel):
+    HF_MODEL = "kakaobrain/align-base"
+
+    def __init__(self, backbone: Optional[str] = None, device=c.DEVICE):
+        if backbone is None:
+            backbone = self.HF_MODEL
+
+        self.model = AlignModel.from_pretrained(backbone).to(device)
+        self.processor = AlignProcessor.from_pretrained(backbone)
+
+        self.device = device
+
+    def encode_text(self, text):
+        text_inputs = self.processor(
+            text=text, return_tensors="pt", padding="max_length", max_length=77
+        )
+        text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
+        return self.model.get_text_features(**text_inputs)
+
+    def encode_image(self, image):
+        image_inputs = self.processor(images=image, return_tensors="pt")
+        image_inputs = {k: v.to(self.device) for k, v in image_inputs.items()}
+        return self.model.get_image_features(**image_inputs)
+
+
+@register_model(name="blip")
+class BLIPModel(VisionLanguageModel):
+    HF_MODEL = "Salesforce/blip-itm-base-coco"
+
+    def __init__(self, backbone: Optional[str] = None, device=c.DEVICE):
+        if backbone is None:
+            backbone = self.HF_MODEL
+
+        self.model = BlipForImageTextRetrieval.from_pretrained(backbone).to(device)
+        self.processor = BlipProcessor.from_pretrained(backbone)
+
+        self.device = device
+
+    def encode_text(self, text):
+        text_inputs = self.processor(
+            text=text, return_tensors="pt", padding="max_length", max_length=77
+        )
+        text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
+        question_embeds = self.model.text_encoder(**text_inputs)[0]
+        return self.model.text_proj(question_embeds[:, 0, :])
+
+    def encode_image(self, image):
+        image_inputs = self.processor(images=image, return_tensors="pt")
+        image_inputs = {k: v.to(self.device) for k, v in image_inputs.items()}
+        image_embeds = self.model.vision_model(**image_inputs)[0]
+        return self.model.vision_proj(image_embeds[:, 0, :])
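Note: a minimal usage sketch of the model registry above. get_model() splits config.data.backbone on ":" into a registered model name and an optional backbone identifier; the image path below is hypothetical.

from PIL import Image

from app_lib.config import Config
from app_lib.multimodal import get_model

config = Config()
config.data.backbone = "open_clip:ViT-B-32"  # "<registered name>:<backbone>"

model = get_model(config)  # resolves to OpenClipModel("ViT-B-32")
text_features = model.encode_text(["a photo of a cassette player"])
image_features = model.encode_image(Image.open("example.jpg"))  # illustrative path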
app_lib/test.py CHANGED
@@ -1,65 +1,33 @@
 import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
 
-import clip
-import h5py
 import ml_collections
 import numpy as np
-import open_clip
+import pandas as pd
 import streamlit as st
 import torch
 from huggingface_hub import hf_hub_download
 
+import app_lib.multimodal as multimodal
 from app_lib.ckde import cKDE
-from app_lib.utils import SUPPORTED_MODELS
+from app_lib.config import Config
+from app_lib.config import Constants as c
 from ibydmt.test import xSKIT
 
 rng = np.random.default_rng()
 
 
-def _get_open_clip_model(model_name, device):
-    backbone = model_name.split(":")[-1]
-
-    model, _, preprocess = open_clip.create_model_and_transforms(
-        SUPPORTED_MODELS[model_name], device=device
-    )
-    model.eval()
-    tokenizer = open_clip.get_tokenizer(backbone)
-    return model, preprocess, tokenizer
-
-
-def _get_clip_model(model_name, device):
-    backbone = model_name.split(":")[-1]
-    model, preprocess = clip.load(backbone, device=device)
-    tokenizer = clip.tokenize
-    return model, preprocess, tokenizer
-
-
-def _load_model(model_name, device):
-    if "open_clip" in model_name:
-        model, preprocess, tokenizer = _get_open_clip_model(model_name, device)
-    elif "clip" in model_name:
-        model, preprocess, tokenizer = _get_clip_model(model_name, device)
-    return model, preprocess, tokenizer
-
-
 @torch.no_grad()
 @torch.cuda.amp.autocast()
-def _encode_concepts(tokenizer, model, concepts, device):
-    concepts_text = tokenizer(concepts).to(device)
-
-    concept_features = model.encode_text(concepts_text)
+def _encode_concepts(model, concepts):
+    concept_features = model.encode_text(concepts)
     concept_features /= torch.linalg.norm(concept_features, dim=-1, keepdim=True)
     return concept_features.cpu().numpy()
 
 
 @torch.no_grad()
 @torch.cuda.amp.autocast()
-def _encode_image(model, preprocess, image, device):
-    image = preprocess(image)
-    image = image.unsqueeze(0)
-    image = image.to(device)
-
+def _encode_image(model, image):
     image_features = model.encode_image(image)
     image_features /= image_features.norm(dim=-1, keepdim=True)
     return image_features.cpu().numpy()
@@ -67,24 +35,24 @@ def _encode_image(model, preprocess, image, device):
 
 @torch.no_grad()
 @torch.cuda.amp.autocast()
-def _encode_class_name(tokenizer, model, class_name, device):
-    class_text = tokenizer([f"A photo of a {class_name}"]).to(device)
-
+def _encode_class_name(model, class_name):
+    class_text = [f"A photo of a {class_name}"]
     class_features = model.encode_text(class_text)
     class_features /= torch.linalg.norm(class_features, dim=-1, keepdim=True)
    return class_features.cpu().numpy()
 
 
-def _load_dataset(dataset_name, model_name):
+def _load_embedding(config):
     dataset_path = hf_hub_download(
         repo_id="jacopoteneggi/IBYDMT",
-        filename=f"{dataset_name}_{model_name}_train.h5",
+        filename=(
+            f"{config.data.dataset.lower()}_train_{config.backbone_name()}.parquet"
+        ),
         repo_type="dataset",
     )
 
-    with h5py.File(dataset_path, "r") as dataset:
-        embedding = dataset["embedding"][:]
-        return embedding
+    dataset = pd.read_parquet(dataset_path)
+    return np.array(dataset["embedding"].values.tolist())
 
 
 def _sample_random_subset(concept_idx, concepts, cardinality):
@@ -162,26 +130,30 @@ def test(
     cardinality,
     dataset_name,
     model_name,
-    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
+    device=c.DEVICE,
     with_streamlit=True,
 ):
+    config = Config()
+    config.data.dataset = dataset_name
+    config.data.backbone = model_name
+
     if with_streamlit:
         with st.spinner("Loading model"):
-            model, preprocess, tokenizer = _load_model(model_name, device)
+            model = multimodal.get_model(config, device=device)
     else:
-        model, preprocess, tokenizer = _load_model(model_name, device)
+        model = multimodal.get_model(config, device=device)
 
     if with_streamlit:
         with st.spinner("Encoding concepts"):
-            cbm = _encode_concepts(tokenizer, model, concepts, device)
+            cbm = _encode_concepts(model, concepts)
     else:
-        cbm = _encode_concepts(tokenizer, model, concepts, device)
+        cbm = _encode_concepts(model, concepts)
 
     if with_streamlit:
         with st.spinner("Encoding image"):
-            h = _encode_image(model, preprocess, image, device)
+            h = _encode_image(model, image)
     else:
-        h = _encode_image(model, preprocess, image, device)
+        h = _encode_image(model, image)
     z = h @ cbm.T
     z = z.squeeze()
 
@@ -201,11 +173,11 @@ def test(
         ),
     )
 
-    embedding = _load_dataset(dataset_name, model_name)
+    embedding = _load_embedding(config)
    semantics = embedding @ cbm.T
    sampler = cKDE(embedding, semantics)
 
-    classifier = _encode_class_name(tokenizer, model, class_name, device)
+    classifier = _encode_class_name(model, class_name)
 
    with ThreadPoolExecutor() as executor:
        futures = [
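Note: with the switch from HDF5 to Parquet, _load_embedding() assumes the dataset file has one row per training image and an "embedding" column holding a list of floats. A toy sketch of that layout (values are made up):

import numpy as np
import pandas as pd

df = pd.DataFrame({"embedding": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]})  # toy rows
embedding = np.array(df["embedding"].values.tolist())
print(embedding.shape)  # (2, 3): (num_images, embedding_dim)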
app_lib/user_input.py CHANGED
@@ -5,7 +5,7 @@ import streamlit as st
 from PIL import Image
 from streamlit_image_select import image_select
 
-import app_lib.defaults as defaults
+from app_lib.defaults import Defaults as d
 from app_lib.utils import SUPPORTED_DATASETS, SUPPORTED_MODELS
 
 IMAGE_DIR = os.path.join("assets", "images")
@@ -31,8 +31,8 @@ def _validate_concepts(concepts):
 
 
 def _get_significance_level():
-    default = defaults.SIGNIFICANCE_LEVEL_VALUE
-    step = defaults.SIGNIFICANCE_LEVEL_STEP
+    default = d.SIGNIFICANCE_LEVEL_VALUE
+    step = d.SIGNIFICANCE_LEVEL_STEP
     return st.slider(
         "Significance level",
         help=f"The level of significance of the tests. Defaults to {default:.2F}.",
@@ -45,8 +45,8 @@ def _get_significance_level():
 
 
 def _get_tau_max():
-    default = defaults.TAU_MAX_VALUE
-    step = defaults.TAU_MAX_STEP
+    default = d.TAU_MAX_VALUE
+    step = d.TAU_MAX_STEP
     return int(
         st.slider(
             "Length of test",
@@ -61,8 +61,8 @@ def _get_tau_max():
 
 
 def _get_number_of_tests():
-    default = defaults.R_VALUE
-    step = defaults.R_STEP
+    default = d.R_VALUE
+    step = d.R_STEP
     return int(
         st.slider(
             "Number of tests per concept",
@@ -80,8 +80,8 @@ def _get_number_of_tests():
 
 
 def _get_cardinality(concepts, concepts_ready):
-    default = defaults.CARDINALITY_VALUE
-    step = defaults.CARDINALITY_STEP
+    default = d.CARDINALITY_VALUE
+    step = d.CARDINALITY_STEP
     return st.slider(
         "Size of conditioning set",
         help=(
@@ -98,7 +98,7 @@ def _get_cardinality(concepts, concepts_ready):
 
 def _get_dataset_name():
     options = SUPPORTED_DATASETS
-    default_idx = options.index(defaults.DATASET_NAME)
+    default_idx = options.index(d.DATASET_NAME)
     return st.selectbox(
         "Dataset",
         options=options,
@@ -112,8 +112,8 @@ def _get_dataset_name():
 
 
 def get_model_name():
-    options = list(SUPPORTED_MODELS.keys())
-    default_idx = options.index(defaults.MODEL_NAME)
+    options = list(SUPPORTED_MODELS)
+    default_idx = options.index(d.MODEL_NAME)
     return st.selectbox(
         "Model to test",
         options=options,
app_lib/utils.py CHANGED
@@ -11,16 +11,8 @@ supported_datasets_path = hf_hub_download(
     repo_type="dataset",
 )
 
-SUPPORTED_MODELS = {}
 with open(supported_models_path, "r") as f:
-    for line in f:
-        line = line.strip()
-        model_name, model_url = line.split(",")
-        SUPPORTED_MODELS[model_name] = model_url
+    SUPPORTED_MODELS = f.read().splitlines()
 
-
-SUPPORTED_DATASETS = []
 with open(supported_datasets_path, "r") as f:
-    for line in f:
-        dataset_name = line.strip()
-        SUPPORTED_DATASETS.append(dataset_name)
+    SUPPORTED_DATASETS = f.read().splitlines()
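Note: SUPPORTED_MODELS changes from a {name: url} mapping parsed from "name,url" lines to a plain list, so the supported-models file is now assumed to hold one bare model name per line; the filename and entries below are illustrative.

# supported_models.txt (hypothetical contents):
#   open_clip:ViT-B-32
#   clip:ViT-B/32
with open("supported_models.txt") as f:
    SUPPORTED_MODELS = f.read().splitlines()
assert "open_clip:ViT-B-32" in SUPPORTED_MODELS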
 
 
assets/results/bowl_ace.npy CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:7f7746293f59199f6872a9570757b2d9d827c2733935cedba2c161181a1cc19c
+oid sha256:adaeeda897451c5548b3119fb917214b82590becfe3138158cfcf1055bcb714d
 size 226871
assets/results/gardener_ace.npy CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:9e3e4a87d960a3591c6c97e10cc40a9e0727048118cc1d5670770bdd74e457ce
+oid sha256:6a7e6d291b8e7226da6af5990094501a741a99c363f04647688da6f8e71746c6
 size 226873
assets/results/gentleman_ace.npy CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:a319ecde9c2299323692edf385274b6939c15a9d6c296aa70629898d8798934f
+oid sha256:fb55900dbb2870ffccc56aac8de6fde7960c3d817035bd9d461419ef0bee6b3b
 size 226874
assets/results/mathematician_ace.npy CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:bf0f277e9f06b957f03c39620a11514a373754b72cf1b5e97088964d07ff7b4a
+oid sha256:ab4807601def20a04c66a2ec993a90e929172aee047e66fb8d052d8eaee438b0
 size 226873
precompute_results.py CHANGED
@@ -4,7 +4,7 @@ import os
 import numpy as np
 from PIL import Image
 
-import app_lib.defaults as defaults
+from app_lib.defaults import Defaults as d
 from app_lib.test import get_testing_config, test
 
 assets_dir = "assets"
@@ -13,9 +13,9 @@ results_dir = os.path.join(assets_dir, "results")
 os.makedirs(results_dir, exist_ok=True)
 
 testing_config = get_testing_config(
-    significance_level=defaults.SIGNIFICANCE_LEVEL_VALUE,
-    tau_max=defaults.TAU_MAX_VALUE,
-    r=defaults.R_VALUE,
+    significance_level=d.SIGNIFICANCE_LEVEL_VALUE,
+    tau_max=d.TAU_MAX_VALUE,
+    r=d.R_VALUE,
 )
 
 image_presets = json.load(open(os.path.join(assets_dir, "image_presets.json")))
@@ -26,7 +26,7 @@ for _image_name, _image_presets in image_presets.items():
     _image = Image.open(_image_path)
     _class_name = _image_presets["class_name"]
     _concepts = _image_presets["concepts"]
-    _cardinality = defaults.CARDINALITY_VALUE
+    _cardinality = d.CARDINALITY_VALUE
 
     _results = test(
         testing_config,
@@ -34,8 +34,8 @@ for _image_name, _image_presets in image_presets.items():
         _class_name,
         _concepts,
         _cardinality,
-        defaults.DATASET_NAME,
-        defaults.MODEL_NAME,
+        d.DATASET_NAME,
+        d.MODEL_NAME,
         with_streamlit=False,
     )
 