jasonwuyl92 commited on
Commit
2ab45c8
0 Parent(s):

initial commit after cleanup

Browse files
Files changed (12) hide show
  1. .gitattributes +36 -0
  2. .gitignore +6 -0
  3. README.md +14 -0
  4. app.py +69 -0
  5. app_old.py +38 -0
  6. get_embeddings.ipynb +1047 -0
  7. misc.py +24 -0
  8. requirements.txt +13 -0
  9. run.py +51 -0
  10. streamlit_app.py +39 -0
  11. utils.py +170 -0
  12. vector_db.py +37 -0
.gitattributes ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ *.pq filter=lfs diff=lfs merge=lfs -text
36
+ *.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ .DS_Store
2
+ .idea/
3
+ .python-version
4
+ .ipynb_checkpoints/
5
+ __pycache__
6
+ flagged
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Image Search Playground
3
+ emoji: 📈
4
+ colorFrom: red
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 3.30.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ python_version: 3.10.0
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from functools import partial
3
+
4
+ import gradio as gr
5
+ import pandas as pd
6
+
7
+ import utils
8
+ import vector_db
9
+ from utils import get_image_embedding, \
10
+ get_image_path, model_names, download_images, generate_and_save_embeddings, get_metadata_path, url_to_image
11
+
12
+ NUM_OUTPUTS = 4
13
+
14
+
15
+ def search(input_img, model_name):
16
+ query_embedding = get_image_embedding(model_name, input_img).tolist()
17
+ top_results = vector_db.query_embeddings_db(query_embedding=query_embedding,
18
+ dataset_name=utils.cur_dataset, model_name=model_name)
19
+ print (top_results)
20
+ return [utils.url_to_image(hit['metadata']['mainphotourl']) for hit in top_results['matches']]
21
+
22
+
23
+ def read_tsv_temporary_file(temp_file_wrapper):
24
+ dataset_name = os.path.splitext(os.path.basename(temp_file_wrapper.name))[0]
25
+ utils.set_cur_dataset(dataset_name)
26
+ df = pd.read_csv(temp_file_wrapper.name, sep='\t') # Read the TSV content into a pandas DataFrame
27
+ df.to_csv(os.path.join(get_metadata_path(), dataset_name + '.tsv'), sep='\t', index=False)
28
+ print('start downloading')
29
+ download_images(df, get_image_path())
30
+ generate_and_save_embeddings()
31
+ utils.refresh_all_datasets()
32
+ utils.set_cur_dataset(dataset_name)
33
+ return gr.update(choices=utils.all_datasets, value=dataset_name)
34
+
35
+
36
+ def update_dataset_dropdown():
37
+ utils.refresh_all_datasets()
38
+ utils.set_cur_dataset(utils.all_datasets[0])
39
+ return gr.update(choices=utils.all_datasets, value=utils.cur_dataset)
40
+
41
+
42
+ def gen_image_blocks(num_outputs):
43
+ with gr.Row():
44
+ row = [gr.outputs.Image(label=model_name, type='filepath') for i in range(int(num_outputs))]
45
+ return row
46
+
47
+
48
+ with gr.Blocks() as demo:
49
+ galleries = dict()
50
+ with gr.Row():
51
+ with gr.Column(scale=1):
52
+ file_upload = gr.File(label="Upload TSV File", file_types=[".tsv"])
53
+ image_input = gr.inputs.Image(type="pil", label="Input Image")
54
+ dataset_dropdown = gr.Dropdown(label='Datasets', choices=utils.all_datasets)
55
+ b1 = gr.Button("Find Similar Images")
56
+ b2 = gr.Button("Refresh Datasets")
57
+
58
+ dataset_dropdown.select(utils.set_cur_dataset, inputs=dataset_dropdown)
59
+ file_upload.upload(read_tsv_temporary_file, inputs=file_upload, outputs=dataset_dropdown)
60
+ b2.click(update_dataset_dropdown, outputs=dataset_dropdown)
61
+ with gr.Column(scale=3):
62
+ for model_name in model_names:
63
+ galleries[model_name] = gen_image_blocks(NUM_OUTPUTS)
64
+ for model_name in model_names:
65
+ b1.click(partial(search, model_name=model_name), inputs=[image_input],
66
+ outputs=galleries[model_name])
67
+ b2.click(utils.refresh_all_datasets, outputs=dataset_dropdown)
68
+
69
+ demo.launch()
app_old.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import gradio as gr
3
+ from sentence_transformers import util as st_util
4
+ import pandas as pd
5
+ import os
6
+ from utils import load_models, get_image_embedding, img_folder, model_name_to_ids, data_path, model_names
7
+
8
+
9
+ def search(input_img, num_outputs):
10
+ results = []
11
+ for model_name in model_names:
12
+ query_embedding = get_image_embedding(model_name, input_img)
13
+ top_results = st_util.semantic_search(query_embedding,
14
+ np.vstack(list(corpus_embeddings[model_name + '-embedding'])),
15
+ top_k=int(num_outputs))[0]
16
+ results.append([os.path.join(img_folder,
17
+ corpus_embeddings.iloc[hit['corpus_id']]['name']) for hit in top_results])
18
+ return results
19
+
20
+
21
+ load_models()
22
+ corpus_embeddings = pd.read_parquet(
23
+ os.path.join(data_path, 'metadata/patagonia_losGatos_embeddings.pq'))
24
+
25
+
26
+
27
+ # Create the Gradio interface
28
+ iface = gr.Interface(
29
+ fn=search,
30
+ inputs=[gr.Image(type="pil"),
31
+ gr.inputs.Number(label="Number of results", default=3)],
32
+ outputs=[gr.Gallery(label=model_name, type='filepath') for model_name in model_names],
33
+ title="Search Similar Images",
34
+ description="Upload an image and find similar images",
35
+ )
36
+
37
+ # Launch the Gradio interface
38
+ iface.launch(debug=True)
get_embeddings.ipynb ADDED
@@ -0,0 +1,1047 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 28,
6
+ "metadata": {
7
+ "tags": []
8
+ },
9
+ "outputs": [
10
+ {
11
+ "ename": "ImportError",
12
+ "evalue": "cannot import name 'data_path' from 'utils' (/Users/yonglinwu/dev/image-search-playground/utils.py)",
13
+ "output_type": "error",
14
+ "traceback": [
15
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
16
+ "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
17
+ "Cell \u001b[0;32mIn[28], line 9\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mtorch\u001b[39;00m\n\u001b[1;32m 7\u001b[0m torch\u001b[39m.\u001b[39mset_printoptions(precision\u001b[39m=\u001b[39m\u001b[39m10\u001b[39m)\n\u001b[0;32m----> 9\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mutils\u001b[39;00m \u001b[39mimport\u001b[39;00m get_image_embeddings, model_name_to_ids, load_models, model_dict, data_path\n\u001b[1;32m 11\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mwarnings\u001b[39;00m\n\u001b[1;32m 12\u001b[0m warnings\u001b[39m.\u001b[39msimplefilter(action\u001b[39m=\u001b[39m\u001b[39m'\u001b[39m\u001b[39mignore\u001b[39m\u001b[39m'\u001b[39m, category\u001b[39m=\u001b[39m\u001b[39mFutureWarning\u001b[39;00m)\n",
18
+ "\u001b[0;31mImportError\u001b[0m: cannot import name 'data_path' from 'utils' (/Users/yonglinwu/dev/image-search-playground/utils.py)"
19
+ ]
20
+ }
21
+ ],
22
+ "source": [
23
+ "from sentence_transformers import SentenceTransformer, util\n",
24
+ "from PIL import Image\n",
25
+ "import pandas as pd\n",
26
+ "import os\n",
27
+ "import numpy as np\n",
28
+ "import torch\n",
29
+ "torch.set_printoptions(precision=10)\n",
30
+ "\n",
31
+ "from utils import get_image_embeddings, model_name_to_ids, load_models, model_dict, data_path\n",
32
+ "\n",
33
+ "import warnings\n",
34
+ "warnings.simplefilter(action='ignore', category=FutureWarning)\n",
35
+ "\n",
36
+ "%load_ext autoreload\n",
37
+ "%autoreload 2\n"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "code",
42
+ "execution_count": null,
43
+ "metadata": {},
44
+ "outputs": [],
45
+ "source": []
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": 3,
50
+ "metadata": {
51
+ "tags": []
52
+ },
53
+ "outputs": [],
54
+ "source": [
55
+ "patagonia_df = pd.read_csv(data_path + 'metadata/patagonia_losGatos.tsv', sep='\\t')"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": 4,
61
+ "metadata": {
62
+ "tags": []
63
+ },
64
+ "outputs": [
65
+ {
66
+ "data": {
67
+ "text/html": [
68
+ "<div>\n",
69
+ "<style scoped>\n",
70
+ " .dataframe tbody tr th:only-of-type {\n",
71
+ " vertical-align: middle;\n",
72
+ " }\n",
73
+ "\n",
74
+ " .dataframe tbody tr th {\n",
75
+ " vertical-align: top;\n",
76
+ " }\n",
77
+ "\n",
78
+ " .dataframe thead th {\n",
79
+ " text-align: right;\n",
80
+ " }\n",
81
+ "</style>\n",
82
+ "<table border=\"1\" class=\"dataframe\">\n",
83
+ " <thead>\n",
84
+ " <tr style=\"text-align: right;\">\n",
85
+ " <th></th>\n",
86
+ " <th>brand</th>\n",
87
+ " <th>title</th>\n",
88
+ " <th>product_url</th>\n",
89
+ " <th>price</th>\n",
90
+ " <th>description</th>\n",
91
+ " <th>size</th>\n",
92
+ " <th>category</th>\n",
93
+ " <th>colors</th>\n",
94
+ " <th>Poshmark</th>\n",
95
+ " <th>Unnamed: 9</th>\n",
96
+ " <th>...</th>\n",
97
+ " <th>Unnamed: 38</th>\n",
98
+ " <th>Unnamed: 39</th>\n",
99
+ " <th>Unnamed: 40</th>\n",
100
+ " <th>Unnamed: 41</th>\n",
101
+ " <th>Unnamed: 42</th>\n",
102
+ " <th>Unnamed: 43</th>\n",
103
+ " <th>Unnamed: 44</th>\n",
104
+ " <th>Unnamed: 45</th>\n",
105
+ " <th>Unnamed: 46</th>\n",
106
+ " <th>Unnamed: 47</th>\n",
107
+ " </tr>\n",
108
+ " </thead>\n",
109
+ " <tbody>\n",
110
+ " <tr>\n",
111
+ " <th>0</th>\n",
112
+ " <td>Patagonia</td>\n",
113
+ " <td>Patagonia Women's Los Gatos Fleece 1/4-Zip Smo...</td>\n",
114
+ " <td>https://poshmark.com/listing/63d4821f2fbf1afe8...</td>\n",
115
+ " <td>$36.00</td>\n",
116
+ " <td>A soft, warm and versatile quarter-zip pullove...</td>\n",
117
+ " <td>M</td>\n",
118
+ " <td>Tops</td>\n",
119
+ " <td>[{'name': 'Gray', 'rgb': '#929292', 'message_i...</td>\n",
120
+ " <td>Poshmark</td>\n",
121
+ " <td>False</td>\n",
122
+ " <td>...</td>\n",
123
+ " <td>NaN</td>\n",
124
+ " <td>NaN</td>\n",
125
+ " <td>NaN</td>\n",
126
+ " <td>NaN</td>\n",
127
+ " <td>NaN</td>\n",
128
+ " <td>NaN</td>\n",
129
+ " <td>NaN</td>\n",
130
+ " <td>NaN</td>\n",
131
+ " <td>NaN</td>\n",
132
+ " <td>NaN</td>\n",
133
+ " </tr>\n",
134
+ " <tr>\n",
135
+ " <th>1</th>\n",
136
+ " <td>Patagonia</td>\n",
137
+ " <td>Patagonia Los Gatos 1/4 Zip Pullover M Beech B...</td>\n",
138
+ " <td>https://poshmark.com/listing/63fcd7709f212bd48...</td>\n",
139
+ " <td>$59.00</td>\n",
140
+ " <td>High pile, quarter zip pulllover\\nMeasurements...</td>\n",
141
+ " <td>M</td>\n",
142
+ " <td>Tops</td>\n",
143
+ " <td>[{'name': 'Brown', 'rgb': '#663509', 'message_...</td>\n",
144
+ " <td>Poshmark</td>\n",
145
+ " <td>False</td>\n",
146
+ " <td>...</td>\n",
147
+ " <td>NaN</td>\n",
148
+ " <td>NaN</td>\n",
149
+ " <td>NaN</td>\n",
150
+ " <td>NaN</td>\n",
151
+ " <td>NaN</td>\n",
152
+ " <td>NaN</td>\n",
153
+ " <td>NaN</td>\n",
154
+ " <td>NaN</td>\n",
155
+ " <td>NaN</td>\n",
156
+ " <td>NaN</td>\n",
157
+ " </tr>\n",
158
+ " <tr>\n",
159
+ " <th>2</th>\n",
160
+ " <td>Patagonia</td>\n",
161
+ " <td>PATAGONIA Women's Los Gatos Fleece 1/4-Zip Pul...</td>\n",
162
+ " <td>https://poshmark.com/listing/642b9bbcfed51f812...</td>\n",
163
+ " <td>$59.00</td>\n",
164
+ " <td>PATAGONIA Women's Los Gatos Fleece 1/4-Zip Pul...</td>\n",
165
+ " <td>S</td>\n",
166
+ " <td>Tops</td>\n",
167
+ " <td>[{'name': 'White', 'rgb': '#FFFFFF', 'message_...</td>\n",
168
+ " <td>Poshmark</td>\n",
169
+ " <td>False</td>\n",
170
+ " <td>...</td>\n",
171
+ " <td>NaN</td>\n",
172
+ " <td>NaN</td>\n",
173
+ " <td>NaN</td>\n",
174
+ " <td>NaN</td>\n",
175
+ " <td>NaN</td>\n",
176
+ " <td>NaN</td>\n",
177
+ " <td>NaN</td>\n",
178
+ " <td>NaN</td>\n",
179
+ " <td>NaN</td>\n",
180
+ " <td>NaN</td>\n",
181
+ " </tr>\n",
182
+ " <tr>\n",
183
+ " <th>3</th>\n",
184
+ " <td>Patagonia</td>\n",
185
+ " <td>Girl’s Patagonia Los Gatos Fleece 1/4 Zip XS</td>\n",
186
+ " <td>https://poshmark.com/listing/63f4f459c5df6c7f8...</td>\n",
187
+ " <td>$30.00</td>\n",
188
+ " <td>Girl’s Patagonia Los Gatos 1/4 Zip Fleece\\n\\n-...</td>\n",
189
+ " <td>XSG</td>\n",
190
+ " <td>Other</td>\n",
191
+ " <td>[{'name': 'Tan', 'rgb': '#d1b48e', 'message_id...</td>\n",
192
+ " <td>Poshmark</td>\n",
193
+ " <td>False</td>\n",
194
+ " <td>...</td>\n",
195
+ " <td>NaN</td>\n",
196
+ " <td>NaN</td>\n",
197
+ " <td>NaN</td>\n",
198
+ " <td>NaN</td>\n",
199
+ " <td>NaN</td>\n",
200
+ " <td>NaN</td>\n",
201
+ " <td>NaN</td>\n",
202
+ " <td>NaN</td>\n",
203
+ " <td>NaN</td>\n",
204
+ " <td>NaN</td>\n",
205
+ " </tr>\n",
206
+ " <tr>\n",
207
+ " <th>4</th>\n",
208
+ " <td>Patagonia</td>\n",
209
+ " <td>Patagonia Los Gatos Quarter Zip Grey</td>\n",
210
+ " <td>https://poshmark.com/listing/622cc43d3a0db900b...</td>\n",
211
+ " <td>$59.00</td>\n",
212
+ " <td>Patagonia Los Gatos Quarter Zip Grey \\nWomen’s...</td>\n",
213
+ " <td>M</td>\n",
214
+ " <td>Tops</td>\n",
215
+ " <td>[{'name': 'Gray', 'rgb': '#929292', 'message_i...</td>\n",
216
+ " <td>Poshmark</td>\n",
217
+ " <td>False</td>\n",
218
+ " <td>...</td>\n",
219
+ " <td>NaN</td>\n",
220
+ " <td>NaN</td>\n",
221
+ " <td>NaN</td>\n",
222
+ " <td>NaN</td>\n",
223
+ " <td>NaN</td>\n",
224
+ " <td>NaN</td>\n",
225
+ " <td>NaN</td>\n",
226
+ " <td>NaN</td>\n",
227
+ " <td>NaN</td>\n",
228
+ " <td>NaN</td>\n",
229
+ " </tr>\n",
230
+ " </tbody>\n",
231
+ "</table>\n",
232
+ "<p>5 rows × 48 columns</p>\n",
233
+ "</div>"
234
+ ],
235
+ "text/plain": [
236
+ " brand title \\\n",
237
+ "0 Patagonia Patagonia Women's Los Gatos Fleece 1/4-Zip Smo... \n",
238
+ "1 Patagonia Patagonia Los Gatos 1/4 Zip Pullover M Beech B... \n",
239
+ "2 Patagonia PATAGONIA Women's Los Gatos Fleece 1/4-Zip Pul... \n",
240
+ "3 Patagonia Girl’s Patagonia Los Gatos Fleece 1/4 Zip XS \n",
241
+ "4 Patagonia Patagonia Los Gatos Quarter Zip Grey \n",
242
+ "\n",
243
+ " product_url price \\\n",
244
+ "0 https://poshmark.com/listing/63d4821f2fbf1afe8... $36.00 \n",
245
+ "1 https://poshmark.com/listing/63fcd7709f212bd48... $59.00 \n",
246
+ "2 https://poshmark.com/listing/642b9bbcfed51f812... $59.00 \n",
247
+ "3 https://poshmark.com/listing/63f4f459c5df6c7f8... $30.00 \n",
248
+ "4 https://poshmark.com/listing/622cc43d3a0db900b... $59.00 \n",
249
+ "\n",
250
+ " description size category \\\n",
251
+ "0 A soft, warm and versatile quarter-zip pullove... M Tops \n",
252
+ "1 High pile, quarter zip pulllover\\nMeasurements... M Tops \n",
253
+ "2 PATAGONIA Women's Los Gatos Fleece 1/4-Zip Pul... S Tops \n",
254
+ "3 Girl’s Patagonia Los Gatos 1/4 Zip Fleece\\n\\n-... XSG Other \n",
255
+ "4 Patagonia Los Gatos Quarter Zip Grey \\nWomen’s... M Tops \n",
256
+ "\n",
257
+ " colors Poshmark Unnamed: 9 \\\n",
258
+ "0 [{'name': 'Gray', 'rgb': '#929292', 'message_i... Poshmark False \n",
259
+ "1 [{'name': 'Brown', 'rgb': '#663509', 'message_... Poshmark False \n",
260
+ "2 [{'name': 'White', 'rgb': '#FFFFFF', 'message_... Poshmark False \n",
261
+ "3 [{'name': 'Tan', 'rgb': '#d1b48e', 'message_id... Poshmark False \n",
262
+ "4 [{'name': 'Gray', 'rgb': '#929292', 'message_i... Poshmark False \n",
263
+ "\n",
264
+ " ... Unnamed: 38 Unnamed: 39 Unnamed: 40 Unnamed: 41 Unnamed: 42 \\\n",
265
+ "0 ... NaN NaN NaN NaN NaN \n",
266
+ "1 ... NaN NaN NaN NaN NaN \n",
267
+ "2 ... NaN NaN NaN NaN NaN \n",
268
+ "3 ... NaN NaN NaN NaN NaN \n",
269
+ "4 ... NaN NaN NaN NaN NaN \n",
270
+ "\n",
271
+ " Unnamed: 43 Unnamed: 44 Unnamed: 45 Unnamed: 46 Unnamed: 47 \n",
272
+ "0 NaN NaN NaN NaN NaN \n",
273
+ "1 NaN NaN NaN NaN NaN \n",
274
+ "2 NaN NaN NaN NaN NaN \n",
275
+ "3 NaN NaN NaN NaN NaN \n",
276
+ "4 NaN NaN NaN NaN NaN \n",
277
+ "\n",
278
+ "[5 rows x 48 columns]"
279
+ ]
280
+ },
281
+ "execution_count": 4,
282
+ "metadata": {},
283
+ "output_type": "execute_result"
284
+ }
285
+ ],
286
+ "source": [
287
+ "patagonia_df.head()"
288
+ ]
289
+ },
290
+ {
291
+ "cell_type": "code",
292
+ "execution_count": null,
293
+ "metadata": {},
294
+ "outputs": [],
295
+ "source": [
296
+ "#download_images(patagonia_df, data_path)"
297
+ ]
298
+ },
299
+ {
300
+ "cell_type": "code",
301
+ "execution_count": 56,
302
+ "metadata": {
303
+ "tags": []
304
+ },
305
+ "outputs": [],
306
+ "source": [
307
+ "load_models()"
308
+ ]
309
+ },
310
+ {
311
+ "cell_type": "code",
312
+ "execution_count": 54,
313
+ "metadata": {
314
+ "tags": []
315
+ },
316
+ "outputs": [],
317
+ "source": [
318
+ "def generate_embeddings():\n",
319
+ " embeddings_df = pd.DataFrame()\n",
320
+ "\n",
321
+ " # Get image embeddings\n",
322
+ " with torch.no_grad():\n",
323
+ " for fp in os.listdir(data_path + 'images/'):\n",
324
+ " if fp.endswith('.jpg'):\n",
325
+ " new_row = {'name': fp}\n",
326
+ " for model_name in model_name_to_ids.keys():\n",
327
+ " new_row[f'{model_name}-embedding'] = get_image_embeddings(model_name, Image.open(data_path + 'images/' + fp))\n",
328
+ " embeddings_df = embeddings_df.append(new_row, ignore_index=True)\n",
329
+ " return embeddings_df"
330
+ ]
331
+ },
332
+ {
333
+ "cell_type": "code",
334
+ "execution_count": 26,
335
+ "metadata": {
336
+ "tags": []
337
+ },
338
+ "outputs": [],
339
+ "source": [
340
+ "fp = os.listdir(data_path + 'images/')[0]"
341
+ ]
342
+ },
343
+ {
344
+ "cell_type": "code",
345
+ "execution_count": 28,
346
+ "metadata": {
347
+ "tags": []
348
+ },
349
+ "outputs": [],
350
+ "source": [
351
+ "model_name = 'fashion'"
352
+ ]
353
+ },
354
+ {
355
+ "cell_type": "code",
356
+ "execution_count": 29,
357
+ "metadata": {
358
+ "tags": []
359
+ },
360
+ "outputs": [],
361
+ "source": [
362
+ "new_row = {'name': fp, f'{model_name}-embedding': get_image_embeddings(model_name, Image.open(data_path + 'images/' + fp))}\n",
363
+ " "
364
+ ]
365
+ },
366
+ {
367
+ "cell_type": "code",
368
+ "execution_count": 57,
369
+ "metadata": {
370
+ "tags": []
371
+ },
372
+ "outputs": [],
373
+ "source": [
374
+ "embeddings_df = generate_embeddings()"
375
+ ]
376
+ },
377
+ {
378
+ "cell_type": "code",
379
+ "execution_count": 58,
380
+ "metadata": {
381
+ "tags": []
382
+ },
383
+ "outputs": [
384
+ {
385
+ "data": {
386
+ "text/html": [
387
+ "<div>\n",
388
+ "<style scoped>\n",
389
+ " .dataframe tbody tr th:only-of-type {\n",
390
+ " vertical-align: middle;\n",
391
+ " }\n",
392
+ "\n",
393
+ " .dataframe tbody tr th {\n",
394
+ " vertical-align: top;\n",
395
+ " }\n",
396
+ "\n",
397
+ " .dataframe thead th {\n",
398
+ " text-align: right;\n",
399
+ " }\n",
400
+ "</style>\n",
401
+ "<table border=\"1\" class=\"dataframe\">\n",
402
+ " <thead>\n",
403
+ " <tr style=\"text-align: right;\">\n",
404
+ " <th></th>\n",
405
+ " <th>name</th>\n",
406
+ " <th>sentence-transformer-clip-ViT-L-14-embedding</th>\n",
407
+ " <th>fashion-embedding</th>\n",
408
+ " <th>openai-clip-embedding</th>\n",
409
+ " </tr>\n",
410
+ " </thead>\n",
411
+ " <tbody>\n",
412
+ " <tr>\n",
413
+ " <th>0</th>\n",
414
+ " <td>Women's Under Armour Hustle Fleece Hoodie pull...</td>\n",
415
+ " <td>[1.0734258, 0.99022365, 0.32032806, 0.2895219,...</td>\n",
416
+ " <td>[0.23177437, -1.9268938, 0.273342, -0.02474568...</td>\n",
417
+ " <td>[-0.32902592, -0.09434131, 0.3055967, 0.229937...</td>\n",
418
+ " </tr>\n",
419
+ " <tr>\n",
420
+ " <th>1</th>\n",
421
+ " <td>Patagonia Los Gatos Fleece Grey Pullover.jpg</td>\n",
422
+ " <td>[0.6227796, 0.026531212, 0.45240527, -0.488214...</td>\n",
423
+ " <td>[0.38133767, -1.3040155, 1.1697398, -0.3085520...</td>\n",
424
+ " <td>[-0.1695469, 0.5067289, 0.31120676, -0.0083701...</td>\n",
425
+ " </tr>\n",
426
+ " <tr>\n",
427
+ " <th>2</th>\n",
428
+ " <td>REI Women's Down With It Quilted Hooded Parka ...</td>\n",
429
+ " <td>[0.8497103, 1.2925782, -0.21685322, 0.24116844...</td>\n",
430
+ " <td>[-0.30043703, -1.3144073, -0.33848628, 0.24008...</td>\n",
431
+ " <td>[-0.24841668, 0.4876942, 0.39810008, -0.141552...</td>\n",
432
+ " </tr>\n",
433
+ " <tr>\n",
434
+ " <th>3</th>\n",
435
+ " <td>Chanel Haute Couture Navy Blue Dress Semi Shee...</td>\n",
436
+ " <td>[0.536018, 0.60787296, -0.2751825, 1.0325747, ...</td>\n",
437
+ " <td>[-0.101031125, 0.033914, -0.44531134, -0.64656...</td>\n",
438
+ " <td>[-0.08328074, 0.19443086, 0.14361368, 0.259305...</td>\n",
439
+ " </tr>\n",
440
+ " <tr>\n",
441
+ " <th>4</th>\n",
442
+ " <td>Patagonia Women’s S Los Gatos Quarter-Zip Flee...</td>\n",
443
+ " <td>[0.79398394, 1.3899276, -0.21383175, 0.0109823...</td>\n",
444
+ " <td>[0.60070944, -1.1051046, 1.0572466, 0.47092092...</td>\n",
445
+ " <td>[-0.27894062, -0.09589732, 0.5556799, -0.13458...</td>\n",
446
+ " </tr>\n",
447
+ " <tr>\n",
448
+ " <th>...</th>\n",
449
+ " <td>...</td>\n",
450
+ " <td>...</td>\n",
451
+ " <td>...</td>\n",
452
+ " <td>...</td>\n",
453
+ " </tr>\n",
454
+ " <tr>\n",
455
+ " <th>326</th>\n",
456
+ " <td>Women's REI Elements Jacket Size M.jpg</td>\n",
457
+ " <td>[0.6310029, 0.9942212, 0.009293936, 0.7862729,...</td>\n",
458
+ " <td>[0.19858713, -1.8665266, -0.3323754, 0.0465058...</td>\n",
459
+ " <td>[-0.0952643, 0.8016211, 0.08129032, 0.15187423...</td>\n",
460
+ " </tr>\n",
461
+ " <tr>\n",
462
+ " <th>327</th>\n",
463
+ " <td>CHANEL Black cotton bodycon tank dress with zi...</td>\n",
464
+ " <td>[1.0761135, 0.18927886, -0.007131472, 0.625682...</td>\n",
465
+ " <td>[0.07516122, -0.1886161, 0.1334078, -0.2829321...</td>\n",
466
+ " <td>[-0.12297699, 0.026368856, 0.04415588, 0.26031...</td>\n",
467
+ " </tr>\n",
468
+ " <tr>\n",
469
+ " <th>328</th>\n",
470
+ " <td>Reformation X Veda Women's Bad Leather Jacket ...</td>\n",
471
+ " <td>[0.79690784, 1.2895226, 0.22802149, -0.2736021...</td>\n",
472
+ " <td>[-0.12224964, -0.38734418, 0.35824925, 0.95855...</td>\n",
473
+ " <td>[0.6507246, 0.27751687, 0.36114892, -0.0831387...</td>\n",
474
+ " </tr>\n",
475
+ " <tr>\n",
476
+ " <th>329</th>\n",
477
+ " <td>DISNEY HER UNIVERSE LILO AND STICH Rainbow Qua...</td>\n",
478
+ " <td>[1.1617887, 0.19193622, 0.046035454, 0.4334900...</td>\n",
479
+ " <td>[-0.20762922, 0.1754938, -0.7334341, -0.106492...</td>\n",
480
+ " <td>[-0.31946087, 0.19534132, 0.37351555, -0.09741...</td>\n",
481
+ " </tr>\n",
482
+ " <tr>\n",
483
+ " <th>330</th>\n",
484
+ " <td>PATAGONIA Nano Puff Jacket Zip Primaloft Insul...</td>\n",
485
+ " <td>[0.2912089, 0.72192264, -0.01620815, 0.0022971...</td>\n",
486
+ " <td>[0.0026952028, -1.6660439, 0.03839147, -0.2164...</td>\n",
487
+ " <td>[0.12799336, 0.75828236, 0.10943861, -0.036647...</td>\n",
488
+ " </tr>\n",
489
+ " </tbody>\n",
490
+ "</table>\n",
491
+ "<p>331 rows × 4 columns</p>\n",
492
+ "</div>"
493
+ ],
494
+ "text/plain": [
495
+ " name \\\n",
496
+ "0 Women's Under Armour Hustle Fleece Hoodie pull... \n",
497
+ "1 Patagonia Los Gatos Fleece Grey Pullover.jpg \n",
498
+ "2 REI Women's Down With It Quilted Hooded Parka ... \n",
499
+ "3 Chanel Haute Couture Navy Blue Dress Semi Shee... \n",
500
+ "4 Patagonia Women’s S Los Gatos Quarter-Zip Flee... \n",
501
+ ".. ... \n",
502
+ "326 Women's REI Elements Jacket Size M.jpg \n",
503
+ "327 CHANEL Black cotton bodycon tank dress with zi... \n",
504
+ "328 Reformation X Veda Women's Bad Leather Jacket ... \n",
505
+ "329 DISNEY HER UNIVERSE LILO AND STICH Rainbow Qua... \n",
506
+ "330 PATAGONIA Nano Puff Jacket Zip Primaloft Insul... \n",
507
+ "\n",
508
+ " sentence-transformer-clip-ViT-L-14-embedding \\\n",
509
+ "0 [1.0734258, 0.99022365, 0.32032806, 0.2895219,... \n",
510
+ "1 [0.6227796, 0.026531212, 0.45240527, -0.488214... \n",
511
+ "2 [0.8497103, 1.2925782, -0.21685322, 0.24116844... \n",
512
+ "3 [0.536018, 0.60787296, -0.2751825, 1.0325747, ... \n",
513
+ "4 [0.79398394, 1.3899276, -0.21383175, 0.0109823... \n",
514
+ ".. ... \n",
515
+ "326 [0.6310029, 0.9942212, 0.009293936, 0.7862729,... \n",
516
+ "327 [1.0761135, 0.18927886, -0.007131472, 0.625682... \n",
517
+ "328 [0.79690784, 1.2895226, 0.22802149, -0.2736021... \n",
518
+ "329 [1.1617887, 0.19193622, 0.046035454, 0.4334900... \n",
519
+ "330 [0.2912089, 0.72192264, -0.01620815, 0.0022971... \n",
520
+ "\n",
521
+ " fashion-embedding \\\n",
522
+ "0 [0.23177437, -1.9268938, 0.273342, -0.02474568... \n",
523
+ "1 [0.38133767, -1.3040155, 1.1697398, -0.3085520... \n",
524
+ "2 [-0.30043703, -1.3144073, -0.33848628, 0.24008... \n",
525
+ "3 [-0.101031125, 0.033914, -0.44531134, -0.64656... \n",
526
+ "4 [0.60070944, -1.1051046, 1.0572466, 0.47092092... \n",
527
+ ".. ... \n",
528
+ "326 [0.19858713, -1.8665266, -0.3323754, 0.0465058... \n",
529
+ "327 [0.07516122, -0.1886161, 0.1334078, -0.2829321... \n",
530
+ "328 [-0.12224964, -0.38734418, 0.35824925, 0.95855... \n",
531
+ "329 [-0.20762922, 0.1754938, -0.7334341, -0.106492... \n",
532
+ "330 [0.0026952028, -1.6660439, 0.03839147, -0.2164... \n",
533
+ "\n",
534
+ " openai-clip-embedding \n",
535
+ "0 [-0.32902592, -0.09434131, 0.3055967, 0.229937... \n",
536
+ "1 [-0.1695469, 0.5067289, 0.31120676, -0.0083701... \n",
537
+ "2 [-0.24841668, 0.4876942, 0.39810008, -0.141552... \n",
538
+ "3 [-0.08328074, 0.19443086, 0.14361368, 0.259305... \n",
539
+ "4 [-0.27894062, -0.09589732, 0.5556799, -0.13458... \n",
540
+ ".. ... \n",
541
+ "326 [-0.0952643, 0.8016211, 0.08129032, 0.15187423... \n",
542
+ "327 [-0.12297699, 0.026368856, 0.04415588, 0.26031... \n",
543
+ "328 [0.6507246, 0.27751687, 0.36114892, -0.0831387... \n",
544
+ "329 [-0.31946087, 0.19534132, 0.37351555, -0.09741... \n",
545
+ "330 [0.12799336, 0.75828236, 0.10943861, -0.036647... \n",
546
+ "\n",
547
+ "[331 rows x 4 columns]"
548
+ ]
549
+ },
550
+ "execution_count": 58,
551
+ "metadata": {},
552
+ "output_type": "execute_result"
553
+ }
554
+ ],
555
+ "source": [
556
+ "embeddings_df"
557
+ ]
558
+ },
559
+ {
560
+ "cell_type": "code",
561
+ "execution_count": 65,
562
+ "metadata": {
563
+ "tags": []
564
+ },
565
+ "outputs": [],
566
+ "source": [
567
+ "embeddings_path = os.path.join(data_path, 'metadata/patagonia_losGatos_embeddings.pq')\n",
568
+ "embeddings_df.to_parquet(embeddings_path)"
569
+ ]
570
+ },
571
+ {
572
+ "cell_type": "code",
573
+ "execution_count": 66,
574
+ "metadata": {
575
+ "tags": []
576
+ },
577
+ "outputs": [],
578
+ "source": [
579
+ "embeddings_df = pd.read_parquet(embeddings_path)"
580
+ ]
581
+ },
582
+ {
583
+ "cell_type": "code",
584
+ "execution_count": 67,
585
+ "metadata": {
586
+ "tags": []
587
+ },
588
+ "outputs": [],
589
+ "source": [
590
+ "for i, row in embeddings_df.iterrows():\n",
591
+ " if '\\n' in row['name']:\n",
592
+ " print(row['name'])\n",
593
+ " embeddings_df = embeddings_df.drop(i)"
594
+ ]
595
+ },
596
+ {
597
+ "cell_type": "code",
598
+ "execution_count": 68,
599
+ "metadata": {
600
+ "tags": []
601
+ },
602
+ "outputs": [
603
+ {
604
+ "data": {
605
+ "text/html": [
606
+ "<div>\n",
607
+ "<style scoped>\n",
608
+ " .dataframe tbody tr th:only-of-type {\n",
609
+ " vertical-align: middle;\n",
610
+ " }\n",
611
+ "\n",
612
+ " .dataframe tbody tr th {\n",
613
+ " vertical-align: top;\n",
614
+ " }\n",
615
+ "\n",
616
+ " .dataframe thead th {\n",
617
+ " text-align: right;\n",
618
+ " }\n",
619
+ "</style>\n",
620
+ "<table border=\"1\" class=\"dataframe\">\n",
621
+ " <thead>\n",
622
+ " <tr style=\"text-align: right;\">\n",
623
+ " <th></th>\n",
624
+ " <th>name</th>\n",
625
+ " <th>sentence-transformer-clip-ViT-L-14-embedding</th>\n",
626
+ " <th>fashion-embedding</th>\n",
627
+ " <th>openai-clip-embedding</th>\n",
628
+ " </tr>\n",
629
+ " </thead>\n",
630
+ " <tbody>\n",
631
+ " <tr>\n",
632
+ " <th>0</th>\n",
633
+ " <td>Women's Under Armour Hustle Fleece Hoodie pull...</td>\n",
634
+ " <td>[1.0734258, 0.99022365, 0.32032806, 0.2895219,...</td>\n",
635
+ " <td>[0.23177437, -1.9268938, 0.273342, -0.02474568...</td>\n",
636
+ " <td>[-0.32902592, -0.09434131, 0.3055967, 0.229937...</td>\n",
637
+ " </tr>\n",
638
+ " <tr>\n",
639
+ " <th>1</th>\n",
640
+ " <td>Patagonia Los Gatos Fleece Grey Pullover.jpg</td>\n",
641
+ " <td>[0.6227796, 0.026531212, 0.45240527, -0.488214...</td>\n",
642
+ " <td>[0.38133767, -1.3040155, 1.1697398, -0.3085520...</td>\n",
643
+ " <td>[-0.1695469, 0.5067289, 0.31120676, -0.0083701...</td>\n",
644
+ " </tr>\n",
645
+ " <tr>\n",
646
+ " <th>2</th>\n",
647
+ " <td>REI Women's Down With It Quilted Hooded Parka ...</td>\n",
648
+ " <td>[0.8497103, 1.2925782, -0.21685322, 0.24116844...</td>\n",
649
+ " <td>[-0.30043703, -1.3144073, -0.33848628, 0.24008...</td>\n",
650
+ " <td>[-0.24841668, 0.4876942, 0.39810008, -0.141552...</td>\n",
651
+ " </tr>\n",
652
+ " <tr>\n",
653
+ " <th>3</th>\n",
654
+ " <td>Chanel Haute Couture Navy Blue Dress Semi Shee...</td>\n",
655
+ " <td>[0.536018, 0.60787296, -0.2751825, 1.0325747, ...</td>\n",
656
+ " <td>[-0.101031125, 0.033914, -0.44531134, -0.64656...</td>\n",
657
+ " <td>[-0.08328074, 0.19443086, 0.14361368, 0.259305...</td>\n",
658
+ " </tr>\n",
659
+ " <tr>\n",
660
+ " <th>4</th>\n",
661
+ " <td>Patagonia Women’s S Los Gatos Quarter-Zip Flee...</td>\n",
662
+ " <td>[0.79398394, 1.3899276, -0.21383175, 0.0109823...</td>\n",
663
+ " <td>[0.60070944, -1.1051046, 1.0572466, 0.47092092...</td>\n",
664
+ " <td>[-0.27894062, -0.09589732, 0.5556799, -0.13458...</td>\n",
665
+ " </tr>\n",
666
+ " <tr>\n",
667
+ " <th>...</th>\n",
668
+ " <td>...</td>\n",
669
+ " <td>...</td>\n",
670
+ " <td>...</td>\n",
671
+ " <td>...</td>\n",
672
+ " </tr>\n",
673
+ " <tr>\n",
674
+ " <th>326</th>\n",
675
+ " <td>Women's REI Elements Jacket Size M.jpg</td>\n",
676
+ " <td>[0.6310029, 0.9942212, 0.009293936, 0.7862729,...</td>\n",
677
+ " <td>[0.19858713, -1.8665266, -0.3323754, 0.0465058...</td>\n",
678
+ " <td>[-0.0952643, 0.8016211, 0.08129032, 0.15187423...</td>\n",
679
+ " </tr>\n",
680
+ " <tr>\n",
681
+ " <th>327</th>\n",
682
+ " <td>CHANEL Black cotton bodycon tank dress with zi...</td>\n",
683
+ " <td>[1.0761135, 0.18927886, -0.007131472, 0.625682...</td>\n",
684
+ " <td>[0.07516122, -0.1886161, 0.1334078, -0.2829321...</td>\n",
685
+ " <td>[-0.12297699, 0.026368856, 0.04415588, 0.26031...</td>\n",
686
+ " </tr>\n",
687
+ " <tr>\n",
688
+ " <th>328</th>\n",
689
+ " <td>Reformation X Veda Women's Bad Leather Jacket ...</td>\n",
690
+ " <td>[0.79690784, 1.2895226, 0.22802149, -0.2736021...</td>\n",
691
+ " <td>[-0.12224964, -0.38734418, 0.35824925, 0.95855...</td>\n",
692
+ " <td>[0.6507246, 0.27751687, 0.36114892, -0.0831387...</td>\n",
693
+ " </tr>\n",
694
+ " <tr>\n",
695
+ " <th>329</th>\n",
696
+ " <td>DISNEY HER UNIVERSE LILO AND STICH Rainbow Qua...</td>\n",
697
+ " <td>[1.1617887, 0.19193622, 0.046035454, 0.4334900...</td>\n",
698
+ " <td>[-0.20762922, 0.1754938, -0.7334341, -0.106492...</td>\n",
699
+ " <td>[-0.31946087, 0.19534132, 0.37351555, -0.09741...</td>\n",
700
+ " </tr>\n",
701
+ " <tr>\n",
702
+ " <th>330</th>\n",
703
+ " <td>PATAGONIA Nano Puff Jacket Zip Primaloft Insul...</td>\n",
704
+ " <td>[0.2912089, 0.72192264, -0.01620815, 0.0022971...</td>\n",
705
+ " <td>[0.0026952028, -1.6660439, 0.03839147, -0.2164...</td>\n",
706
+ " <td>[0.12799336, 0.75828236, 0.10943861, -0.036647...</td>\n",
707
+ " </tr>\n",
708
+ " </tbody>\n",
709
+ "</table>\n",
710
+ "<p>331 rows × 4 columns</p>\n",
711
+ "</div>"
712
+ ],
713
+ "text/plain": [
714
+ " name \\\n",
715
+ "0 Women's Under Armour Hustle Fleece Hoodie pull... \n",
716
+ "1 Patagonia Los Gatos Fleece Grey Pullover.jpg \n",
717
+ "2 REI Women's Down With It Quilted Hooded Parka ... \n",
718
+ "3 Chanel Haute Couture Navy Blue Dress Semi Shee... \n",
719
+ "4 Patagonia Women’s S Los Gatos Quarter-Zip Flee... \n",
720
+ ".. ... \n",
721
+ "326 Women's REI Elements Jacket Size M.jpg \n",
722
+ "327 CHANEL Black cotton bodycon tank dress with zi... \n",
723
+ "328 Reformation X Veda Women's Bad Leather Jacket ... \n",
724
+ "329 DISNEY HER UNIVERSE LILO AND STICH Rainbow Qua... \n",
725
+ "330 PATAGONIA Nano Puff Jacket Zip Primaloft Insul... \n",
726
+ "\n",
727
+ " sentence-transformer-clip-ViT-L-14-embedding \\\n",
728
+ "0 [1.0734258, 0.99022365, 0.32032806, 0.2895219,... \n",
729
+ "1 [0.6227796, 0.026531212, 0.45240527, -0.488214... \n",
730
+ "2 [0.8497103, 1.2925782, -0.21685322, 0.24116844... \n",
731
+ "3 [0.536018, 0.60787296, -0.2751825, 1.0325747, ... \n",
732
+ "4 [0.79398394, 1.3899276, -0.21383175, 0.0109823... \n",
733
+ ".. ... \n",
734
+ "326 [0.6310029, 0.9942212, 0.009293936, 0.7862729,... \n",
735
+ "327 [1.0761135, 0.18927886, -0.007131472, 0.625682... \n",
736
+ "328 [0.79690784, 1.2895226, 0.22802149, -0.2736021... \n",
737
+ "329 [1.1617887, 0.19193622, 0.046035454, 0.4334900... \n",
738
+ "330 [0.2912089, 0.72192264, -0.01620815, 0.0022971... \n",
739
+ "\n",
740
+ " fashion-embedding \\\n",
741
+ "0 [0.23177437, -1.9268938, 0.273342, -0.02474568... \n",
742
+ "1 [0.38133767, -1.3040155, 1.1697398, -0.3085520... \n",
743
+ "2 [-0.30043703, -1.3144073, -0.33848628, 0.24008... \n",
744
+ "3 [-0.101031125, 0.033914, -0.44531134, -0.64656... \n",
745
+ "4 [0.60070944, -1.1051046, 1.0572466, 0.47092092... \n",
746
+ ".. ... \n",
747
+ "326 [0.19858713, -1.8665266, -0.3323754, 0.0465058... \n",
748
+ "327 [0.07516122, -0.1886161, 0.1334078, -0.2829321... \n",
749
+ "328 [-0.12224964, -0.38734418, 0.35824925, 0.95855... \n",
750
+ "329 [-0.20762922, 0.1754938, -0.7334341, -0.106492... \n",
751
+ "330 [0.0026952028, -1.6660439, 0.03839147, -0.2164... \n",
752
+ "\n",
753
+ " openai-clip-embedding \n",
754
+ "0 [-0.32902592, -0.09434131, 0.3055967, 0.229937... \n",
755
+ "1 [-0.1695469, 0.5067289, 0.31120676, -0.0083701... \n",
756
+ "2 [-0.24841668, 0.4876942, 0.39810008, -0.141552... \n",
757
+ "3 [-0.08328074, 0.19443086, 0.14361368, 0.259305... \n",
758
+ "4 [-0.27894062, -0.09589732, 0.5556799, -0.13458... \n",
759
+ ".. ... \n",
760
+ "326 [-0.0952643, 0.8016211, 0.08129032, 0.15187423... \n",
761
+ "327 [-0.12297699, 0.026368856, 0.04415588, 0.26031... \n",
762
+ "328 [0.6507246, 0.27751687, 0.36114892, -0.0831387... \n",
763
+ "329 [-0.31946087, 0.19534132, 0.37351555, -0.09741... \n",
764
+ "330 [0.12799336, 0.75828236, 0.10943861, -0.036647... \n",
765
+ "\n",
766
+ "[331 rows x 4 columns]"
767
+ ]
768
+ },
769
+ "execution_count": 68,
770
+ "metadata": {},
771
+ "output_type": "execute_result"
772
+ }
773
+ ],
774
+ "source": [
775
+ "embeddings_df"
776
+ ]
777
+ },
778
+ {
779
+ "cell_type": "code",
780
+ "execution_count": 8,
781
+ "metadata": {},
782
+ "outputs": [],
783
+ "source": [
784
+ "import os\n",
785
+ "\n",
786
+ "for fp in os.listdir(data_path + 'images/'):\n",
787
+ " if '?' in fp:\n",
788
+ " print(fp)"
789
+ ]
790
+ },
791
+ {
792
+ "cell_type": "code",
793
+ "execution_count": 7,
794
+ "metadata": {
795
+ "tags": []
796
+ },
797
+ "outputs": [
798
+ {
799
+ "data": {
800
+ "text/plain": [
801
+ "2"
802
+ ]
803
+ },
804
+ "execution_count": 7,
805
+ "metadata": {},
806
+ "output_type": "execute_result"
807
+ }
808
+ ],
809
+ "source": [
810
+ "1+1"
811
+ ]
812
+ },
813
+ {
814
+ "cell_type": "code",
815
+ "execution_count": 2,
816
+ "metadata": {
817
+ "tags": []
818
+ },
819
+ "outputs": [],
820
+ "source": [
821
+ "%reload_ext autoreload\n",
822
+ "%autoreload 2"
823
+ ]
824
+ },
825
+ {
826
+ "cell_type": "code",
827
+ "execution_count": 7,
828
+ "metadata": {
829
+ "tags": []
830
+ },
831
+ "outputs": [],
832
+ "source": [
833
+ "df.to_csv('random.tsv', sep='\\t')"
834
+ ]
835
+ },
836
+ {
837
+ "cell_type": "code",
838
+ "execution_count": 1,
839
+ "metadata": {
840
+ "tags": []
841
+ },
842
+ "outputs": [
843
+ {
844
+ "name": "stdout",
845
+ "output_type": "stream",
846
+ "text": [
847
+ "disco-io/data\n"
848
+ ]
849
+ }
850
+ ],
851
+ "source": [
852
+ "import utils\n"
853
+ ]
854
+ },
855
+ {
856
+ "cell_type": "code",
857
+ "execution_count": 4,
858
+ "metadata": {
859
+ "tags": []
860
+ },
861
+ "outputs": [],
862
+ "source": [
863
+ "from utils import get_immediate_subdirectories"
864
+ ]
865
+ },
866
+ {
867
+ "cell_type": "code",
868
+ "execution_count": 10,
869
+ "metadata": {
870
+ "tags": []
871
+ },
872
+ "outputs": [
873
+ {
874
+ "name": "stdout",
875
+ "output_type": "stream",
876
+ "text": [
877
+ "disco-io/data\n",
878
+ "Refreshing all datasets: ['test']\n"
879
+ ]
880
+ }
881
+ ],
882
+ "source": [
883
+ "utils.refresh_all_datasets()"
884
+ ]
885
+ },
886
+ {
887
+ "cell_type": "code",
888
+ "execution_count": 3,
889
+ "metadata": {
890
+ "tags": []
891
+ },
892
+ "outputs": [
893
+ {
894
+ "data": {
895
+ "text/plain": [
896
+ "'test'"
897
+ ]
898
+ },
899
+ "execution_count": 3,
900
+ "metadata": {},
901
+ "output_type": "execute_result"
902
+ }
903
+ ],
904
+ "source": [
905
+ "utils.cur_dataset"
906
+ ]
907
+ },
908
+ {
909
+ "cell_type": "code",
910
+ "execution_count": 2,
911
+ "metadata": {
912
+ "tags": []
913
+ },
914
+ "outputs": [
915
+ {
916
+ "name": "stdout",
917
+ "output_type": "stream",
918
+ "text": [
919
+ "disco-io/data\n"
920
+ ]
921
+ },
922
+ {
923
+ "data": {
924
+ "text/plain": [
925
+ "['test']"
926
+ ]
927
+ },
928
+ "execution_count": 2,
929
+ "metadata": {},
930
+ "output_type": "execute_result"
931
+ }
932
+ ],
933
+ "source": [
934
+ "get_immediate_subdirectories('data')\n"
935
+ ]
936
+ },
937
+ {
938
+ "cell_type": "code",
939
+ "execution_count": 20,
940
+ "metadata": {},
941
+ "outputs": [],
942
+ "source": [
943
+ "import utils"
944
+ ]
945
+ },
946
+ {
947
+ "cell_type": "code",
948
+ "execution_count": 21,
949
+ "metadata": {},
950
+ "outputs": [],
951
+ "source": [
952
+ "from utils import fs"
953
+ ]
954
+ },
955
+ {
956
+ "cell_type": "code",
957
+ "execution_count": 22,
958
+ "metadata": {},
959
+ "outputs": [],
960
+ "source": [
961
+ "s3_path = 'data'"
962
+ ]
963
+ },
964
+ {
965
+ "cell_type": "code",
966
+ "execution_count": 23,
967
+ "metadata": {},
968
+ "outputs": [],
969
+ "source": [
970
+ "s3_full_path = f\"{utils.S3_BUCKET}/{s3_path}\""
971
+ ]
972
+ },
973
+ {
974
+ "cell_type": "code",
975
+ "execution_count": 24,
976
+ "metadata": {},
977
+ "outputs": [
978
+ {
979
+ "data": {
980
+ "text/plain": [
981
+ "['disco-io/data/Cvlsntdjgrnuyrlf.jpg', 'disco-io/data/test']"
982
+ ]
983
+ },
984
+ "execution_count": 24,
985
+ "metadata": {},
986
+ "output_type": "execute_result"
987
+ }
988
+ ],
989
+ "source": [
990
+ "fs.glob(f\"{s3_full_path}/*\")"
991
+ ]
992
+ },
993
+ {
994
+ "cell_type": "code",
995
+ "execution_count": 25,
996
+ "metadata": {},
997
+ "outputs": [
998
+ {
999
+ "data": {
1000
+ "text/plain": [
1001
+ "True"
1002
+ ]
1003
+ },
1004
+ "execution_count": 25,
1005
+ "metadata": {},
1006
+ "output_type": "execute_result"
1007
+ }
1008
+ ],
1009
+ "source": [
1010
+ "fs.isdir('disco-io/data/test')"
1011
+ ]
1012
+ },
1013
+ {
1014
+ "cell_type": "code",
1015
+ "execution_count": null,
1016
+ "metadata": {},
1017
+ "outputs": [],
1018
+ "source": []
1019
+ }
1020
+ ],
1021
+ "metadata": {
1022
+ "kernelspec": {
1023
+ "display_name": "Python 3 (ipykernel)",
1024
+ "language": "python",
1025
+ "name": "python3"
1026
+ },
1027
+ "language_info": {
1028
+ "codemirror_mode": {
1029
+ "name": "ipython",
1030
+ "version": 3
1031
+ },
1032
+ "file_extension": ".py",
1033
+ "mimetype": "text/x-python",
1034
+ "name": "python",
1035
+ "nbconvert_exporter": "python",
1036
+ "pygments_lexer": "ipython3",
1037
+ "version": "3.10.0"
1038
+ },
1039
+ "vscode": {
1040
+ "interpreter": {
1041
+ "hash": "e85fcd8d0dbb45c39d3e544566c77318961c8114425a16ff4cb5c14067743b34"
1042
+ }
1043
+ }
1044
+ },
1045
+ "nbformat": 4,
1046
+ "nbformat_minor": 4
1047
+ }
misc.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import random
3
+
4
+ # Function to generate random text for titles
5
+
6
+ def generate_random_images_df(filename):
7
+ def generate_title():
8
+ title_length = random.randint(5, 20)
9
+ title = ''.join(random.choices('abcdefghijklmnopqrstuvwxyz', k=title_length))
10
+ return title.capitalize()
11
+
12
+ # Function to generate random image URLs
13
+ def generate_image_url():
14
+ url = "https://picsum.photos/200/300" # Change the size of the image as per your requirement
15
+ return url
16
+
17
+ # Create a list of dictionaries with random titles and image URLs
18
+ data = []
19
+ for i in range(10):
20
+ data.append({'title': generate_title(), 'IMG_URL': generate_image_url()})
21
+
22
+ # Convert the list of dictionaries to a Pandas DataFrame
23
+ df = pd.DataFrame(data)
24
+ df.to_csv(filename, sep='\t', index=False)
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio==3.30.0
2
+ numpy==1.23.5
3
+ pandas==1.5.3
4
+ pandas_stubs==1.2.0.35
5
+ Pillow==9.5.0
6
+ sentence_transformers==2.2.2
7
+ pyarrow
8
+ transformers~=4.26.1
9
+ tqdm
10
+ streamlit
11
+ s3fs
12
+ requests
13
+ pinecone-client
run.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ with gr.Blocks() as demo:
4
+ gr.Markdown(
5
+ """
6
+ # Animal Generator
7
+ Once you select a species, the detail panel should be visible.
8
+ """
9
+ )
10
+
11
+ species = gr.Radio(label="Animal Class", choices=["Mammal", "Fish", "Bird"])
12
+ animal = gr.Dropdown(label="Animal", choices=[])
13
+
14
+ with gr.Column(visible=False) as details_col:
15
+ weight = gr.Slider(0, 20)
16
+ details = gr.Textbox(label="Extra Details")
17
+ generate_btn = gr.Button("Generate")
18
+ output = gr.Textbox(label="Output")
19
+
20
+ species_map = {
21
+ "Mammal": ["Elephant", "Giraffe", "Hamster"],
22
+ "Fish": ["Shark", "Salmon", "Tuna"],
23
+ "Bird": ["Chicken", "Eagle", "Hawk"],
24
+ }
25
+
26
+ def filter_species(species):
27
+ return gr.Dropdown.update(
28
+ choices=species_map[species], value=species_map[species][1]
29
+ ), gr.update(visible=True)
30
+
31
+ species.change(filter_species, species, [animal, details_col])
32
+
33
+ def filter_weight(animal):
34
+ if animal in ("Elephant", "Shark", "Giraffe"):
35
+ return gr.update(maximum=100)
36
+ else:
37
+ return gr.update(maximum=20)
38
+
39
+ animal.change(filter_weight, animal, weight)
40
+ weight.change(lambda w: gr.update(lines=int(w / 10) + 1), weight, details)
41
+
42
+ generate_btn.click(lambda x: x, details, output)
43
+
44
+
45
+ if __name__ == "__main__":
46
+
47
+ from tqdm import tqdm
48
+
49
+ for i in tqdm(range(int(9e6))):
50
+ pass
51
+ #demo.launch()
streamlit_app.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ from PIL import Image
4
+
5
+ def process_image(input_image):
6
+ # Your image processing function goes here
7
+ output_image = input_image.copy()
8
+ return output_image
9
+
10
+ # Set the title of the web application
11
+ st.title('Multiple Input and Output Images Interface')
12
+
13
+ # Create a sidebar for image inputs
14
+ st.sidebar.title('Input Images')
15
+
16
+ # Set up a file uploader in the sidebar for each input image
17
+ uploaded_images = []
18
+ num_images = 3 # The number of input images
19
+ for i in range(num_images):
20
+ uploaded_image = st.sidebar.file_uploader(f'Upload Image {i+1}', type=['png', 'jpg', 'jpeg'])
21
+ if uploaded_image is not None:
22
+ uploaded_images.append(uploaded_image)
23
+
24
+ # Display input images and process them
25
+ if uploaded_images:
26
+ st.header('Input Images')
27
+ input_images = []
28
+ for img in uploaded_images:
29
+ input_img = Image.open(img)
30
+ input_images.append(input_img)
31
+ st.image(input_img, width=200, caption='Uploaded Image')
32
+
33
+ # Process input images and display output images
34
+ st.header('Output Images')
35
+ for input_img in input_images:
36
+ output_img = process_image(input_img)
37
+ st.image(output_img, width=200, caption='Processed Image')
38
+ else:
39
+ st.warning('Please upload images in the sidebar.')
utils.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sentence_transformers import SentenceTransformer, util as st_util
2
+ from transformers import CLIPModel, CLIPProcessor
3
+
4
+ from PIL import Image
5
+ import requests
6
+ import os
7
+ import torch
8
+ torch.set_printoptions(precision=10)
9
+ from tqdm import tqdm
10
+ import s3fs
11
+ from io import BytesIO
12
+ import vector_db
13
+
14
+ "sentence-transformer-clip-ViT-L-14"
15
+ "openai-clip"
16
+ model_names = ["fashion"]
17
+
18
+ model_name_to_ids = {
19
+ "sentence-transformer-clip-ViT-L-14": "clip-ViT-L-14",
20
+ "fashion": "patrickjohncyh/fashion-clip",
21
+ "openai-clip": "openai/clip-vit-base-patch32",
22
+ }
23
+
24
+ AWS_ACCESS_KEY_ID = os.environ["AWS_ACCESS_KEY_ID"]
25
+ AWS_SECRET_ACCESS_KEY = os.environ["AWS_SECRET_ACCESS_KEY"]
26
+
27
+ # Define your bucket and dataset name.
28
+ S3_BUCKET = "s3://disco-io"
29
+
30
+ fs = s3fs.S3FileSystem(
31
+ key=AWS_ACCESS_KEY_ID,
32
+ secret=AWS_SECRET_ACCESS_KEY,
33
+ )
34
+
35
+ ROOT_DATA_PATH = os.path.join(S3_BUCKET, 'data')
36
+
37
+ def get_data_path():
38
+ return os.path.join(ROOT_DATA_PATH, cur_dataset)
39
+
40
+ def get_image_path():
41
+ return os.path.join(get_data_path(), 'images')
42
+
43
+ def get_metadata_path():
44
+ return os.path.join(get_data_path(), 'metadata')
45
+
46
+ def get_embeddings_path():
47
+ return os.path.join(get_metadata_path(), cur_dataset + '_embeddings.pq')
48
+
49
+ model_dict = dict()
50
+
51
+
52
+ def download_to_s3(url, s3_path):
53
+ # Download the file from the URL
54
+ response = requests.get(url, stream=True)
55
+ response.raise_for_status()
56
+
57
+ # Upload the file to the S3 path
58
+ with fs.open(s3_path, "wb") as s3_file:
59
+ for chunk in response.iter_content(chunk_size=8192):
60
+ s3_file.write(chunk)
61
+
62
+
63
+ def remove_all_files_from_s3_directory(s3_directory):
64
+ # List all objects in the S3 directory
65
+ objects = fs.ls(s3_directory)
66
+
67
+ # Remove each object
68
+ for obj in objects:
69
+ try:
70
+ fs.rm(obj)
71
+ except:
72
+ print('Error removing file: ' + obj)
73
+
74
+ def download_images(df, img_folder):
75
+ remove_all_files_from_s3_directory(img_folder)
76
+ for index, row in df.iterrows():
77
+ try:
78
+ download_to_s3(row['IMG_URL'], os.path.join(img_folder,
79
+ row['title'].replace('/', '_').replace('\n', '') + '.jpg'))
80
+ except:
81
+ print('Error downloading image: ' + str(index) + row['title'])
82
+
83
+
84
+ def load_models():
85
+ for model_name in model_name_to_ids:
86
+ if model_name not in model_dict:
87
+ model_dict[model_name] = dict()
88
+ if model_name.startswith('sentence-transformer'):
89
+ model_dict[model_name]['model'] = SentenceTransformer(model_name_to_ids[model_name])
90
+ else:
91
+ model_dict[model_name]['hf_dir'] = model_name_to_ids[model_name]
92
+ model_dict[model_name]['model'] = CLIPModel.from_pretrained(model_name_to_ids[model_name])
93
+ model_dict[model_name]['processor'] = CLIPProcessor.from_pretrained(model_name_to_ids[model_name])
94
+
95
+
96
+ if len(model_dict) == 0:
97
+ print('Loading models...')
98
+ load_models()
99
+
100
+
101
+ def get_image_embedding(model_name, image):
102
+ """
103
+ Takes an image as input and returns an embedding vector.
104
+ """
105
+ model = model_dict[model_name]['model']
106
+ if model_name.startswith('sentence-transformer'):
107
+ return model.encode(image)
108
+ else:
109
+ inputs = model_dict[model_name]['processor'](images=image, return_tensors="pt")
110
+ image_features = model.get_image_features(**inputs).detach().numpy()[0]
111
+ return image_features
112
+
113
+ def s3_path_to_image(fs, s3_path):
114
+ """
115
+ Takes an S3 path as input and returns a PIL Image object.
116
+
117
+ Args:
118
+ s3_path (str): The path to the image in the S3 bucket, including the bucket name (e.g., "bucket_name/path/to/image.jpg").
119
+
120
+ Returns:
121
+ Image: A PIL Image object.
122
+ """
123
+ with fs.open(s3_path, "rb") as f:
124
+ image_data = BytesIO(f.read())
125
+ img = Image.open(image_data)
126
+ return img
127
+
128
+ def generate_and_save_embeddings():
129
+ # Get image embeddings
130
+ with torch.no_grad():
131
+ for fp in tqdm(fs.ls(get_image_path()), desc="Generate embeddings for Images"):
132
+ if fp.endswith('.jpg'):
133
+ name = fp.split('/')[-1]
134
+ for model_name in model_name_to_ids.keys():
135
+ s3_path = 's3://' + fp
136
+ vector_db.add_image_embedding_to_db(
137
+ embedding=get_image_embedding(model_name, s3_path_to_image(fs, s3_path)),
138
+ model_name=model_name,
139
+ dataset_name=cur_dataset,
140
+ path_to_image=s3_path,
141
+ image_name=name,
142
+ )
143
+
144
+
145
+ def get_immediate_subdirectories(s3_path):
146
+ return [obj.split('/')[-1] for obj in fs.glob(f"{s3_path}/*") if fs.isdir(obj)]
147
+
148
+ all_datasets = get_immediate_subdirectories(ROOT_DATA_PATH)
149
+ cur_dataset = all_datasets[0]
150
+
151
+ def set_cur_dataset(dataset):
152
+ refresh_all_datasets()
153
+ print(f"Setting current dataset to {dataset}")
154
+ global cur_dataset
155
+ cur_dataset = dataset
156
+
157
+ def refresh_all_datasets():
158
+ global all_datasets
159
+ all_datasets = get_immediate_subdirectories(ROOT_DATA_PATH)
160
+ print(f"Refreshing all datasets: {all_datasets}")
161
+
162
+ def url_to_image(url):
163
+ try:
164
+ response = requests.get(url)
165
+ response.raise_for_status()
166
+ img = Image.open(BytesIO(response.content))
167
+ return img
168
+ except requests.exceptions.RequestException as e:
169
+ print(f"Error fetching image from URL: {url}")
170
+ return None
vector_db.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pinecone
2
+ import os
3
+ import uuid
4
+
5
+ pinecone.init(api_key=os.environ["PINECONE_API_KEY"], environment="us-west1-gcp")
6
+
7
+ INDEX_512_NAME = "images-512"
8
+ INDEX_768_NAME = "images-768"
9
+
10
+ index_512 = pinecone.Index(INDEX_512_NAME)
11
+ index_768 = pinecone.Index(INDEX_768_NAME)
12
+
13
+ DEV_NAMESPACE = 'disco-web-app-search-dev'
14
+ PROD_NAMESPACE = 'disco-web-app-search-prod'
15
+
16
+
17
+ def add_image_embedding_to_db(embedding, model_name, dataset_name, path_to_image, image_name):
18
+ index = {
19
+ 512: index_512,
20
+ 768: index_768
21
+ }[embedding.shape[0]]
22
+ print (embedding.shape)
23
+ index.upsert([(str(uuid.uuid4()), embedding.tolist(), {'model': model_name,
24
+ 'dataset': dataset_name,
25
+ 'path': path_to_image,
26
+ 'image_name': image_name})])
27
+
28
+
29
+ def query_embeddings_db(query_embedding, dataset_name, model_name, top_k=4):
30
+ index = {
31
+ 512: index_512,
32
+ 768: index_768
33
+ }[len(query_embedding)]
34
+ return index.query(vector=query_embedding,
35
+ top_k=top_k,
36
+ namespace=DEV_NAMESPACE,
37
+ include_metadata=True)