OwenElliott commited on
Commit
b6c64a0
·
verified ·
1 Parent(s): f50de00

Upload 18 files

Browse files
amazon.json ADDED
The diff for this file is too large to render. See raw diff
 
app.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import open_clip
3
+ from PIL import Image
4
+ import requests
5
+ import json
6
+ import gradio as gr
7
+ import pandas as pd
8
+ from io import BytesIO
9
+ import os
10
+
11
+ # Load the Amazon taxonomy from a JSON file
12
+ with open("amazon.json", "r") as f:
13
+ AMAZON_TAXONOMY = json.load(f)
14
+
15
+
16
+ base_model_name = "ViT-B-16"
17
+ model_base, _, preprocess_base = open_clip.create_model_and_transforms(base_model_name)
18
+ tokenizer_base = open_clip.get_tokenizer(base_model_name)
19
+ model_name_B = "hf-hub:Marqo/marqo-ecommerce-embeddings-B"
20
+ model_B, _, preprocess_B = open_clip.create_model_and_transforms(model_name_B)
21
+ tokenizer_B = open_clip.get_tokenizer(model_name_B)
22
+ model_name_L = "hf-hub:Marqo/marqo-ecommerce-embeddings-L"
23
+ model_L, _, preprocess_L = open_clip.create_model_and_transforms(model_name_L)
24
+ tokenizer_L = open_clip.get_tokenizer(model_name_L)
25
+
26
+ models = [base_model_name, model_name_B, model_name_L]
27
+
28
+ taxonomy_cache = {}
29
+ for model in models:
30
+ with open(f'{model.split("/")[-1]}.json', "r") as f:
31
+ taxonomy_cache[model] = json.load(f)
32
+
33
+
34
+ def cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
35
+ numerator = (a * b).sum(dim=-1)
36
+ denominator = torch.linalg.norm(a, ord=2, dim=-1) * torch.linalg.norm(
37
+ b, ord=2, dim=-1
38
+ )
39
+ return 0.5 * (numerator / denominator + 1.0)
40
+
41
+
42
+ class BeamPath:
43
+ def __init__(self, path: list, cumulative_score: float, current_layer: dict | list):
44
+ self.path = path
45
+ self.cumulative_score = cumulative_score
46
+ self.current_layer = current_layer
47
+
48
+ def __repr__(self):
49
+ return f"BeamPath(path={self.path}, cumulative_score={self.cumulative_score})"
50
+
51
+
52
+ def _compute_similarities(classes: list, base_embedding: torch.Tensor, cache_key: str):
53
+ text_features = torch.tensor(
54
+ [taxonomy_cache[cache_key][class_name] for class_name in classes]
55
+ )
56
+
57
+ similarities = cosine_similarity(base_embedding, text_features)
58
+ return similarities.cpu().numpy()
59
+
60
+
61
+ def map_taxonomy(
62
+ base_image: Image.Image,
63
+ taxonomy: dict,
64
+ model,
65
+ tokenizer,
66
+ preprocess_val,
67
+ cache_key,
68
+ beam_width: int = 3,
69
+ ) -> tuple[list[tuple[str, float]], float]:
70
+ image_tensor = preprocess_val(base_image).unsqueeze(0)
71
+ with torch.no_grad(), torch.cuda.amp.autocast():
72
+ base_embedding = model.encode_image(image_tensor, normalize=True)
73
+
74
+ initial_path = BeamPath(path=[], cumulative_score=0.0, current_layer=taxonomy)
75
+ beam = [initial_path]
76
+
77
+ final_paths = []
78
+ is_first = True
79
+ while beam:
80
+ candidates = []
81
+ candidate_entries = []
82
+
83
+ for beam_path in beam:
84
+ layer = beam_path.current_layer
85
+
86
+ if isinstance(layer, dict):
87
+ classes = list(layer.keys())
88
+ elif isinstance(layer, list):
89
+ classes = layer
90
+ if classes == []:
91
+ final_paths.append(beam_path)
92
+ continue
93
+ else:
94
+ final_paths.append(beam_path)
95
+ continue
96
+
97
+ # current_path_class_names = [class_name for class_name, _ in beam_path.path]
98
+
99
+ for class_name in classes:
100
+ candidate_string = class_name
101
+ if isinstance(layer, dict):
102
+ next_layer = layer[class_name]
103
+ else:
104
+ next_layer = None
105
+ candidate_entries.append(
106
+ (candidate_string, class_name, beam_path, next_layer)
107
+ )
108
+
109
+ if not candidate_entries:
110
+ break
111
+
112
+ candidate_strings = [
113
+ candidate_string for candidate_string, _, _, _ in candidate_entries
114
+ ]
115
+
116
+ similarities = _compute_similarities(
117
+ candidate_strings, base_embedding, cache_key
118
+ )
119
+
120
+ for (candidate_string, class_name, beam_path, next_layer), similarity in zip(
121
+ candidate_entries, similarities
122
+ ):
123
+ new_path = beam_path.path + [(class_name, float(similarity))]
124
+ new_cumulative_score = beam_path.cumulative_score + similarity
125
+ candidate = BeamPath(
126
+ path=new_path,
127
+ cumulative_score=new_cumulative_score,
128
+ current_layer=next_layer,
129
+ )
130
+ candidates.append(candidate)
131
+
132
+ from collections import defaultdict
133
+
134
+ by_parents = defaultdict(list)
135
+
136
+ for candidate in candidates:
137
+ by_parents[candidate.path[0][0]].append(candidate)
138
+
139
+ beam = []
140
+ for parent in by_parents:
141
+ children = by_parents[parent]
142
+ children.sort(
143
+ key=lambda x: x.cumulative_score / len(x.path) + x.path[-1][1],
144
+ reverse=True,
145
+ )
146
+ if is_first:
147
+ beam.extend(children)
148
+ else:
149
+ beam.extend(children[:beam_width])
150
+
151
+ is_first = False
152
+
153
+ all_paths = beam + final_paths
154
+
155
+ if all_paths:
156
+ all_paths.sort(key=lambda x: x.cumulative_score / len(x.path), reverse=True)
157
+ best_path = all_paths[0]
158
+ return best_path.path, float(best_path.cumulative_score)
159
+ else:
160
+ return [], 0.0
161
+
162
+
163
+ # Function to classify image and map taxonomy
164
+ def classify_image(
165
+ image_input: Image.Image | None,
166
+ image_url: str | None,
167
+ model_size: str,
168
+ beam_width: int,
169
+ ):
170
+ if image_input is not None:
171
+ image = image_input
172
+ elif image_url:
173
+ # Try to get image from URL
174
+ try:
175
+ response = requests.get(image_url)
176
+ image = Image.open(BytesIO(response.content)).convert("RGB")
177
+ except Exception as e:
178
+ return pd.DataFrame({"Error": [str(e)]})
179
+ else:
180
+ return pd.DataFrame(
181
+ {
182
+ "Error": [
183
+ "Please provide an image, an image URL, or select an example image"
184
+ ]
185
+ }
186
+ )
187
+
188
+ # Select the model, tokenizer, and preprocess
189
+ if model_size == "marqo-ecommerce-embeddings-L":
190
+ key = "hf-hub:Marqo/marqo-ecommerce-embeddings-L"
191
+ model = model_L
192
+ preprocess_val = preprocess_L
193
+ tokenizer = tokenizer_L
194
+ elif model_size == "marqo-ecommerce-embeddings-B":
195
+ key = "hf-hub:Marqo/marqo-ecommerce-embeddings-B"
196
+ model = model_B
197
+ preprocess_val = preprocess_B
198
+ tokenizer = tokenizer_B
199
+ elif model_size == "openai-ViT-B-16":
200
+ key = "ViT-B-16"
201
+ model = model_base
202
+ preprocess_val = preprocess_base
203
+ tokenizer = tokenizer_base
204
+ else:
205
+ return pd.DataFrame({"Error": ["Invalid model size"]})
206
+
207
+ path, cumulative_score = map_taxonomy(
208
+ base_image=image,
209
+ taxonomy=AMAZON_TAXONOMY,
210
+ model=model,
211
+ tokenizer=tokenizer,
212
+ preprocess_val=preprocess_val,
213
+ cache_key=key,
214
+ beam_width=beam_width,
215
+ )
216
+
217
+ output = []
218
+ for idx, (category, score) in enumerate(path):
219
+ level = idx + 1
220
+ output.append({"Level": level, "Category": category, "Score": score})
221
+
222
+ df = pd.DataFrame(output)
223
+ return df
224
+
225
+
226
+ with gr.Blocks() as demo:
227
+ gr.Markdown("# Image Classification with Taxonomy Mapping")
228
+ gr.Markdown(
229
+ "## How to use this app\n\nThis app compares Marqo's E-commerce embeddings to OpenAI's ViT-B-16 CLIP model for E-commerce taxonomy mapping. A beam search is used to find the correct classification in the taxonomy. The original OpenAI CLIP models perform very poorly on E-commerce data."
230
+ )
231
+ gr.Markdown(
232
+ "Upload an image, provide an image URL, or select an example image, select the model size, and get the taxonomy mapping. The taxonomy is based on the Amazon product taxonomy."
233
+ )
234
+
235
+ with gr.Row():
236
+ with gr.Column():
237
+ image_input = gr.Image(type="pil", label="Upload Image", height=300)
238
+ image_url_input = gr.Textbox(
239
+ lines=1, placeholder="Image URL", label="Image URL"
240
+ )
241
+ gr.Markdown("### Or select an example image:")
242
+ # Get example images from 'images' folder
243
+ example_images_folder = "images"
244
+ example_image_paths = [
245
+ os.path.join(example_images_folder, img)
246
+ for img in os.listdir(example_images_folder)
247
+ ]
248
+ gr.Examples(
249
+ examples=[[img_path] for img_path in example_image_paths],
250
+ inputs=image_input,
251
+ label="Example Images",
252
+ examples_per_page=100,
253
+ )
254
+ with gr.Column():
255
+ model_size_input = gr.Radio(
256
+ choices=[
257
+ "marqo-ecommerce-embeddings-L",
258
+ "marqo-ecommerce-embeddings-B",
259
+ "openai-ViT-B-16",
260
+ ],
261
+ label="Model",
262
+ value="marqo-ecommerce-embeddings-L",
263
+ )
264
+ beam_width_input = gr.Number(
265
+ label="Beam Width", value=5, minimum=1, step=1
266
+ )
267
+ classify_button = gr.Button("Classify")
268
+ output_table = gr.Dataframe(headers=["Level", "Category", "Score"])
269
+
270
+ classify_button.click(
271
+ fn=classify_image,
272
+ inputs=[image_input, image_url_input, model_size_input, beam_width_input],
273
+ outputs=output_table,
274
+ )
275
+
276
+ demo.launch()
cache_taxonomy_vectors.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import torch
3
+ import open_clip
4
+ from tqdm import tqdm
5
+
6
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
+ if device == "cpu":
8
+ device = "mps" if torch.backends.mps.is_available() else "cpu"
9
+
10
+
11
+ def generate_cache(texts: list[str], model_name: str, batch_size: int = 16) -> dict:
12
+ model, _, _ = open_clip.create_model_and_transforms(model_name, device=device)
13
+ tokenizer = open_clip.get_tokenizer(model_name)
14
+
15
+ cache = {}
16
+
17
+ for i in tqdm(range(0, len(texts), batch_size)):
18
+ batch = texts[i : i + batch_size]
19
+ tokens = tokenizer(batch).to(device)
20
+ with torch.no_grad(), torch.cuda.amp.autocast():
21
+ embeddings = model.encode_text(tokens, normalize=True).cpu().numpy()
22
+ for text, embedding in zip(batch, embeddings):
23
+ cache[text] = embedding.tolist()
24
+
25
+ return cache
26
+
27
+
28
+ def flatten_taxonomy(taxonomy: dict) -> list[str]:
29
+ classes = []
30
+ for key, value in taxonomy.items():
31
+ classes.append(key)
32
+ if isinstance(value, dict):
33
+ classes.extend(flatten_taxonomy(value))
34
+ if isinstance(value, list):
35
+ classes.extend(value)
36
+ return classes
37
+
38
+
39
+ def main():
40
+ models = [
41
+ "hf-hub:Marqo/marqo-ecommerce-embeddings-B",
42
+ "hf-hub:Marqo/marqo-ecommerce-embeddings-L",
43
+ "ViT-B-16"
44
+ ]
45
+
46
+ with open("amazon.json") as f:
47
+ taxonomy = json.load(f)
48
+ print("Loaded taxonomy")
49
+
50
+ print("Flattening taxonomy")
51
+ texts = flatten_taxonomy(taxonomy)
52
+
53
+ print("Generating cache")
54
+ for model in models:
55
+ cache = generate_cache(texts, model)
56
+ with open(f'{model.split("/")[-1]}.json', "w+") as f:
57
+ json.dump(cache, f)
58
+
59
+
60
+ if __name__ == "__main__":
61
+ main()
images/bike-helmet.png ADDED
images/coffee.png ADDED
images/cooking-book.jpg ADDED
images/cutting-board.png ADDED
images/flip-flops.jpg ADDED
images/grater.png ADDED
images/green-shirt.webp ADDED
images/hoop-earring.jpg ADDED
images/iron.png ADDED
images/laptop.png ADDED
images/notebook.png ADDED
images/red-dress.webp ADDED
images/runners.png ADDED
images/sleeping-bag.png ADDED
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ Pillow
4
+ gradio
5
+ ftfy
6
+ open_clip_torch