jamescalam committed on
Commit
88172be
1 Parent(s): 814d271

upgrade to contrastive learning and unsplash lite dataset

Files changed (3)
  1. app.py +120 -53
  2. link-check.py +58 -0
  3. unsplash-25k-clip-indexer.ipynb +775 -0
app.py CHANGED
@@ -11,19 +11,10 @@ import logging
11
  from urllib3.exceptions import ProtocolError
12
 
13
  PINECONE_API_KEY = st.secrets["PINECONE_API_KEY"] # app.pinecone.io
14
- INDEX = "imagenet-query-trainer-clip"
15
  MODEL_ID = "openai/clip-vit-base-patch32"
16
  DIMS = 512
17
 
18
- @st.experimental_singleton(show_spinner=False)
19
- def init_dataset():
20
- return load_dataset(
21
- 'frgfm/imagenette',
22
- 'full_size',
23
- split='train',
24
- ignore_verifications=False # set to True if seeing splits Error
25
- )
26
-
27
  @st.experimental_singleton(show_spinner=False)
28
  def init_clip():
29
  tokenizer = CLIPTokenizerFast.from_pretrained(MODEL_ID)
@@ -39,7 +30,12 @@ def init_db():
39
  meta_field = datetime.now().isoformat()
40
  return meta_field, pinecone.Index(INDEX)
41
 
42
- def query(xq, top_k=10, include_values=True, filter=None):
43
  logging.info(f"Query to Pinecone with '{st.session_state.meta}'")
44
  attempt = 0
45
  while attempt < 3:
@@ -48,14 +44,15 @@ def query(xq, top_k=10, include_values=True, filter=None):
48
  xq,
49
  top_k=top_k,
50
  include_values=include_values,
 
51
  filter=filter
52
  )
53
- matches = {match['id']: match['values'] for match in xc['matches']}
54
  break
55
  except ProtocolError:
56
  attempt += 1
57
- matches = {}
58
- if len(matches.keys()) == 0:
59
  logging.error(f"No matches found for '{st.session_state.meta}'")
60
  return matches
61
 
@@ -108,44 +105,34 @@ def pil_to_bytes(img):
108
  img_bin = base64.b64encode(img_bin).decode('utf-8')
109
  return img_bin
110
 
111
- def card(i):
112
- img = imagenet[int(i)]['image']
113
- img_bin = pil_to_bytes(img)
114
- return f'<img id="img{i}" src="data:image/jpeg;base64,{img_bin}" width="200px;">'
115
 
116
- def get_top_k(xq, top_k=10):
117
  matches = query(
118
- xq, top_k=top_k, include_values=True, filter={st.session_state.meta: {"$ne": 1}}
 
119
  )
120
  return matches
121
 
122
- def tune(matches, inputs, iters=5):
123
- positive_idx = [idx for idx, val in inputs.items() if val == 1]
124
- negatives = [match for match in matches.items() if match[0] not in positive_idx]
125
- negative_idx = [match[0] for match in negatives]
126
- negative_vectors = [match[1] for match in negatives]
127
- positive_vectors = [match[1] for match in matches.items() if match[0] in positive_idx]
128
- # prep training data
129
- y = [1] * len(positive_idx) + [0] * len(negative_idx)
130
- X = positive_vectors + negative_vectors
131
  # train the classifier
 
132
  st.session_state.clf.fit(X, y, iters=iters)
133
  # extract new vector
134
  st.session_state.xq = st.session_state.clf.get_weights()
135
- # update one record at a time
136
- for i in positive_idx + negative_idx:
137
- st.session_state.index.update(str(i), set_metadata={st.session_state.meta: 1})
138
 
139
  def refresh_index():
140
  logging.info(f"Refresh for '{st.session_state.meta}'")
 
141
  xq = st.session_state.xq
142
  if type(xq) is not list:
143
  xq = xq.tolist()
144
  while True:
145
  matches = query(xq, top_k=100, filter={st.session_state.meta: 1})
146
- idx = list(matches.keys())
147
- if len(idx) == 0: break
148
- for i in idx:
149
  st.session_state.index.update(str(i), set_metadata={st.session_state.meta: 0})
150
  # refresh session states
151
  del st.session_state.clf, st.session_state.xq
@@ -156,19 +143,26 @@ def calc_dist():
156
  return np.linalg.norm(xq - orig_xq)
157
 
158
  def submit():
 
159
  matches = st.session_state.matches
160
- velocity = st.session_state.velocity
161
- inputs = {}
162
  states = [
163
  st.session_state[f"input{i}"] for i in range(len(matches))
164
  ]
165
- for i, idx in enumerate(matches.keys()):
166
- inputs[idx] = int(states[i])
167
  states[i] = False
168
  # reset states to unchecked
169
  for i in range(len(matches)):
170
  st.session_state[f"input{i}"] = False
171
- tune(matches, inputs, iters=velocity)
172
 
173
  def delete_element(element):
174
  del element
@@ -180,18 +174,72 @@ st.markdown("""
180
  />
181
  """, unsafe_allow_html=True)
182
183
  with st.spinner("Initializing everything..."):
184
- imagenet = init_dataset()
185
  st.session_state.meta, st.session_state.index = init_db()
186
  if 'xq' not in st.session_state:
187
  tokenizer, clip = init_clip()
188
 
189
  if 'xq' not in st.session_state:
190
- start = [st.empty(), st.empty(), st.empty(), st.empty()]
191
- prompt = start[0].text_input("Prompt:", value="")
192
- prompt_xq = start[1].button("Prompt", disabled=len(prompt) == 0)
193
- random_xq = start[2].button("Random", disabled=len(prompt) != 0)
194
- start[3].markdown('Not sure what to write? Try **"dogs in the snow"**, **"close up of a dog"**, **"sony radio"**, or click **Random**.')
 
195
  if random_xq:
196
  print("r_xq")
197
  xq, orig_xq = init_random_query()
@@ -216,17 +264,36 @@ if 'xq' in st.session_state:
216
  refresh_index()
217
  else:
218
  # if we want to display images we end up here
219
- st.markdown(f"Distance travelled: *{round(calc_dist(), 4)}*")
220
  # first retrieve images from pinecone
221
- st.session_state.matches = get_top_k(st.session_state.xq, top_k=10)
222
  # once retrieved, display them alongside checkboxes in a form
223
  with st.form("my_form", clear_on_submit=False):
224
- velocity = st.slider("Velocity", 0, 20, 5, key="velocity")
 
225
  # we have three columns in the form
226
  cols = st.columns(3)
227
- for i, idx in enumerate(st.session_state.matches.keys()):
228
  # the card shows an image and a checkbox
229
- cols[i%3].markdown(card(idx), unsafe_allow_html=True)
230
  # we access the values of the checkbox via st.session_state[f"input{i}"]
231
- cols[i%3].checkbox("Relevant", key=f"input{i}")
232
- st.form_submit_button("Tune", on_click=submit)
11
  from urllib3.exceptions import ProtocolError
12
 
13
  PINECONE_API_KEY = st.secrets["PINECONE_API_KEY"] # app.pinecone.io
14
+ INDEX = "unsplash-25k-clip"
15
  MODEL_ID = "openai/clip-vit-base-patch32"
16
  DIMS = 512
17
18
  @st.experimental_singleton(show_spinner=False)
19
  def init_clip():
20
  tokenizer = CLIPTokenizerFast.from_pretrained(MODEL_ID)
 
30
  meta_field = datetime.now().isoformat()
31
  return meta_field, pinecone.Index(INDEX)
32
 
33
+ @st.experimental_singleton(show_spinner=False)
34
+ def init_query_num():
35
+ print("init query_num")
36
+ return 0
37
+
38
+ def query(xq, top_k=10, include_values=True, include_metadata=True, filter=None):
39
  logging.info(f"Query to Pinecone with '{st.session_state.meta}'")
40
  attempt = 0
41
  while attempt < 3:
 
44
  xq,
45
  top_k=top_k,
46
  include_values=include_values,
47
+ include_metadata=include_metadata,
48
  filter=filter
49
  )
50
+ matches = xc['matches']
51
  break
52
  except ProtocolError:
53
  attempt += 1
54
+ matches = []
55
+ if len(matches) == 0:
56
  logging.error(f"No matches found for '{st.session_state.meta}'")
57
  return matches
58
 
 
105
  img_bin = base64.b64encode(img_bin).decode('utf-8')
106
  return img_bin
107
 
108
+ def card(i, url):
109
+ return f'<img id="img{i}" src="{url}" width="200px;">'
 
 
110
 
111
+ def get_top_k(xq, top_k=9):
112
  matches = query(
113
+ xq, top_k=top_k, include_values=True, include_metadata=True,
114
+ filter={st.session_state.meta: {"$ne": 1}}
115
  )
116
  return matches
117
 
118
+ def tune(X, y, iters=5):
119
  # train the classifier
120
+ print(y)
121
  st.session_state.clf.fit(X, y, iters=iters)
122
  # extract new vector
123
  st.session_state.xq = st.session_state.clf.get_weights()
 
 
 
124
 
125
  def refresh_index():
126
  logging.info(f"Refresh for '{st.session_state.meta}'")
127
+ st.session_state.query_num = 0
128
  xq = st.session_state.xq
129
  if type(xq) is not list:
130
  xq = xq.tolist()
131
  while True:
132
  matches = query(xq, top_k=100, filter={st.session_state.meta: 1})
133
+ id_vals = [match['id'] for match in matches]
134
+ if len(id_vals) == 0: break
135
+ for i in id_vals:
136
  st.session_state.index.update(str(i), set_metadata={st.session_state.meta: 0})
137
  # refresh session states
138
  del st.session_state.clf, st.session_state.xq
 
143
  return np.linalg.norm(xq - orig_xq)
144
 
145
  def submit():
146
+ st.session_state.query_num += 1
147
  matches = st.session_state.matches
148
+ velocity = 2 #st.session_state.velocity
149
+ scores = {}
150
  states = [
151
  st.session_state[f"input{i}"] for i in range(len(matches))
152
  ]
153
+ for i, match in enumerate(matches):
154
+ scores[match['id']] = float(states[i])
155
  states[i] = False
156
  # reset states to unchecked
157
  for i in range(len(matches)):
158
  st.session_state[f"input{i}"] = False
159
+ # get training data and labels
160
+ X = list([match['values'] for match in matches])
161
+ y = list(scores.values())
162
+ tune(X, y, iters=velocity)
163
+ # update record metadata after training
164
+ for match in matches:
165
+ st.session_state.index.update(str(match['id']), set_metadata={st.session_state.meta: 1})
166
 
167
  def delete_element(element):
168
  del element
 
174
  />
175
  """, unsafe_allow_html=True)
176
 
177
+ messages = [
178
+ f"""
179
+ Welcome to the semantic query trainer app! Here we will demo how to efficiently train
180
+ a classifier to *very accurately* classify images based on their semantic content.
181
+
182
+ First, we need to initialize the classifier with a simple prompt. Try and write something
183
+ similar to what you're looking for, or if you want a challenge, try something completely
184
+ different.
185
+ """,
186
+ f"""
187
+ With the first query we have initialized the classifier weights (they're a 512-d vector)
188
+ and used those weights to perform a *vector search* to find image embeddings (also 512-d
189
+ vectors) that closely match the classifier weights.
190
+
191
+ These are essentially the images that the classifier would currently classify as "positive".
192
+
193
+ Based on your *target class* for the classifier, decide how relevant each of the images below
194
+ is, rating them from -1 (completely irrelevant) to +1 (a perfect match).
195
+ """,
196
+ f"""
197
+ Each of the image embeddings is paired with the *score* that you just gave it. These are
198
+ all fed into the classifier and used to train it. The classifier learns to *move* towards
199
+ the positively scored images, and to *avoid* the negatively scored images.
200
+ """,
201
+ f"""
202
+ As we repeat the process, the classifier rapidly learns the target space of our intended
203
+ class.
204
+
205
+ Typically, we don't train classifiers like this; instead, we label a huge dataset and train
206
+ the classifier across all images and their labels. This is massively inefficient. Here we
207
+ save annotation and compute time by using vector search to identify and focus on the images
208
+ that make the *biggest* difference in classifier performance.
209
+ """,
210
+ f"""
211
+ We shouldn't need to repeat this process many times before our classifier converges on our
212
+ target space. Once we begin returning only relevant images, we can stop training the classifier.
213
+
214
+ *(In this demo, you can try changing your target space and 'traversing' the vector space
215
+ to the new target space)*
216
+ """,
217
+ f"""
218
+ The app uses the [Pinecone vector database](https://pinecone.io/) to store and query images
219
+ using vector search. All images are sourced from the [Unsplash Lite dataset](https://github.com/unsplash/datasets) and encoded
220
+ using [OpenAI's CLIP](https://huggingface.co/openai/clip-vit-base-patch32). We explain how
221
+ it all works [here](https://classifier-train-vector-search--optimistic-curran-b817a8.netlify.app/learn/classifier-train-vector-search/).
222
+ """
223
+ ]
224
+
225
  with st.spinner("Initializing everything..."):
 
226
  st.session_state.meta, st.session_state.index = init_db()
227
  if 'xq' not in st.session_state:
228
  tokenizer, clip = init_clip()
229
+ st.session_state.query_num = 0
230
+
231
+ if st.session_state.query_num+1 < len(messages):
232
+ msg = messages[st.session_state.query_num+1]
233
+ else:
234
+ msg = messages[-1]
235
 
236
  if 'xq' not in st.session_state:
237
+ start = [st.empty(), st.empty(), st.empty(), st.empty(), st.empty()]
238
+ start[0].info(msg, icon="⁉️")
239
+ prompt = start[1].text_input("Prompt:", value="")
240
+ prompt_xq = start[2].button("Prompt", disabled=len(prompt) == 0)
241
+ random_xq = start[3].button("Random", disabled=len(prompt) != 0)
242
+ start[4].markdown('Not sure what to write? Try **"dogs in the snow"**, **"close up of a dog"**, **"sony radio"**, or click **Random**.')
243
  if random_xq:
244
  print("r_xq")
245
  xq, orig_xq = init_random_query()
 
264
  refresh_index()
265
  else:
266
  # if we want to display images we end up here
267
+ st.info(msg, icon="🔎")
268
  # first retrieve images from pinecone
269
+ st.session_state.matches = get_top_k(st.session_state.xq, top_k=9)
270
  # once retrieved, display them alongside checkboxes in a form
271
  with st.form("my_form", clear_on_submit=False):
272
+ st.form_submit_button("Tune", on_click=submit)
273
+ #velocity = st.slider("Velocity", 1, 8, 2, key="velocity")
274
  # we have three columns in the form
275
  cols = st.columns(3)
276
+ for i, match in enumerate(st.session_state.matches):
277
+ # find good url
278
+ loc = match["metadata"].get("good_url")
279
+ if loc:
280
+ url = match["metadata"][loc]
281
+ if loc == "photo_url":
282
+ url += "/download?force=true&w=640"
283
+ disabled = False
284
+ else:
285
+ # will show no image, but not sure what else to place here
286
+ url = match["metadata"]["photo_url"]
287
+ disabled=True
288
  # the card shows an image and a checkbox
289
+ cols[i%3].markdown(card(i, url), unsafe_allow_html=True)
290
  # we access the values of the checkbox via st.session_state[f"input{i}"]
291
+ cols[i%3].slider(
292
+ "Relevance",
293
+ min_value=-1.0,
294
+ max_value=1.0,
295
+ value=0.0,
296
+ step=0.1,
297
+ key=f"input{i}",
298
+ disabled=disabled
299
+ )
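
The `messages` above describe the tuning loop in prose: each displayed image contributes its CLIP embedding plus the relevance score from its slider, `tune()` fits the classifier on those pairs, and the classifier's weights become the next query vector. The classifier behind `st.session_state.clf` is defined elsewhere in app.py and is not part of this diff, so the sketch below is only a hypothetical minimal stand-in (a single linear scorer nudged by a few gradient steps), not the app's actual implementation.

```python
import numpy as np

class LinearScorer:
    """Hypothetical stand-in for st.session_state.clf: a 512-d weight vector
    nudged towards positively scored embeddings and away from negative ones."""

    def __init__(self, xq, lr=0.1):
        self.w = np.asarray(xq, dtype=np.float32)  # start from the prompt/random query vector
        self.lr = lr

    def fit(self, X, y, iters=2):
        X = np.asarray(X, dtype=np.float32)  # (n, 512) image embeddings returned by Pinecone
        y = np.asarray(y, dtype=np.float32)  # (n,) slider scores in [-1, 1]
        for _ in range(iters):
            scores = X @ self.w                              # current dot-product relevance
            grad = ((scores - y)[:, None] * X).mean(axis=0)  # squared-error gradient
            self.w -= self.lr * grad                         # step towards the scored targets
        self.w /= np.linalg.norm(self.w)                     # keep the query on the unit sphere

    def get_weights(self):
        return self.w.tolist()  # used as the next Pinecone query vector
```

Because the indexed image embeddings are unit-normalized and the index uses a dot-product metric, keeping the weights on the unit sphere makes each retraining step behave like moving a cosine-similarity query towards the positively scored images.
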
link-check.py ADDED
@@ -0,0 +1,58 @@
1
+ import pinecone
2
+ import requests
3
+ from tqdm.auto import tqdm
4
+ import logging
5
+
6
+ # we run this to check for broken links
7
+
8
+ PINECONE_API_KEY = "<<API_KEY_HERE>>"
9
+ INDEX = "unsplash-25k-clip"
10
+
11
+ pinecone.init(
12
+ api_key=PINECONE_API_KEY,
13
+ environment="us-west1-gcp"
14
+ )
15
+
16
+ index = pinecone.Index(INDEX)
17
+
18
+ dim = index.describe_index_stats()['dimension']
19
+ total = int(index.describe_index_stats()['totalVectorCount'])
20
+ xq = [0.0] * dim
21
+
22
+ count = 0
23
+ ID_LIST = []
24
+
25
+ logging.info("Checking links...")
26
+
27
+ with tqdm(total=total) as pbar:
28
+ while True:
29
+ xc = index.query(
30
+ xq, top_k=100, include_metadata=True,
31
+ filter={"link_check": {"$ne": True}}
32
+ )
33
+ matches = xc['matches']
34
+ if len(matches) == 0:
35
+ break
36
+ for match in matches:
37
+ photo_url = match['metadata']['photo_url']+"/download?force=true&w=640"
38
+ res = requests.get(photo_url)
39
+ if res.status_code == 200:
40
+ good_url = "photo_url"
41
+ else:
42
+ res = requests.get(match['metadata']['photo_image_url'])
43
+ if res.status_code == 200:
44
+ good_url = "photo_image_url"
45
+ else:
46
+ good_url = "not_found"
47
+ index.update(match['id'], set_metadata={
48
+ 'good_url': good_url,
49
+ 'link_check': True
50
+ })
51
+ ID_LIST.append(match['id'])
52
+ pbar.update(1)
53
+
54
+ logging.info("Refreshing 'link_check' field...")
55
+ for _id in tqdm(ID_LIST):
56
+ index.update(_id, set_metadata={
57
+ 'link_check': False
58
+ })
unsplash-25k-clip-indexer.ipynb ADDED
@@ -0,0 +1,775 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "!pip install transformers pinecone-client tqdm"
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "markdown",
14
+ "metadata": {},
15
+ "source": [
16
+ "The dataset used is the [Unsplash Lite dataset](https://github.com/unsplash/datasets)."
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": 2,
22
+ "metadata": {},
23
+ "outputs": [
24
+ {
25
+ "data": {
26
+ "text/html": [
27
+ "<div>\n",
28
+ "<style scoped>\n",
29
+ " .dataframe tbody tr th:only-of-type {\n",
30
+ " vertical-align: middle;\n",
31
+ " }\n",
32
+ "\n",
33
+ " .dataframe tbody tr th {\n",
34
+ " vertical-align: top;\n",
35
+ " }\n",
36
+ "\n",
37
+ " .dataframe thead th {\n",
38
+ " text-align: right;\n",
39
+ " }\n",
40
+ "</style>\n",
41
+ "<table border=\"1\" class=\"dataframe\">\n",
42
+ " <thead>\n",
43
+ " <tr style=\"text-align: right;\">\n",
44
+ " <th></th>\n",
45
+ " <th>photo_id</th>\n",
46
+ " <th>photo_url</th>\n",
47
+ " <th>photo_image_url</th>\n",
48
+ " <th>photo_submitted_at</th>\n",
49
+ " <th>photo_featured</th>\n",
50
+ " <th>photo_width</th>\n",
51
+ " <th>photo_height</th>\n",
52
+ " <th>photo_aspect_ratio</th>\n",
53
+ " <th>photo_description</th>\n",
54
+ " <th>photographer_username</th>\n",
55
+ " <th>...</th>\n",
56
+ " <th>photo_location_country</th>\n",
57
+ " <th>photo_location_city</th>\n",
58
+ " <th>stats_views</th>\n",
59
+ " <th>stats_downloads</th>\n",
60
+ " <th>ai_description</th>\n",
61
+ " <th>ai_primary_landmark_name</th>\n",
62
+ " <th>ai_primary_landmark_latitude</th>\n",
63
+ " <th>ai_primary_landmark_longitude</th>\n",
64
+ " <th>ai_primary_landmark_confidence</th>\n",
65
+ " <th>blur_hash</th>\n",
66
+ " </tr>\n",
67
+ " </thead>\n",
68
+ " <tbody>\n",
69
+ " <tr>\n",
70
+ " <th>0</th>\n",
71
+ " <td>XMyPniM9LF0</td>\n",
72
+ " <td>https://unsplash.com/photos/XMyPniM9LF0</td>\n",
73
+ " <td>https://images.unsplash.com/uploads/1411949294...</td>\n",
74
+ " <td>2014-09-29 00:08:38.594364</td>\n",
75
+ " <td>t</td>\n",
76
+ " <td>4272</td>\n",
77
+ " <td>2848</td>\n",
78
+ " <td>1.50</td>\n",
79
+ " <td>Woman exploring a forest</td>\n",
80
+ " <td>michellespencer77</td>\n",
81
+ " <td>...</td>\n",
82
+ " <td>NaN</td>\n",
83
+ " <td>NaN</td>\n",
84
+ " <td>2375421</td>\n",
85
+ " <td>6967</td>\n",
86
+ " <td>woman walking in the middle of forest</td>\n",
87
+ " <td>NaN</td>\n",
88
+ " <td>NaN</td>\n",
89
+ " <td>NaN</td>\n",
90
+ " <td>NaN</td>\n",
91
+ " <td>L56bVcRRIWMh.gVunlS4SMbsRRxr</td>\n",
92
+ " </tr>\n",
93
+ " <tr>\n",
94
+ " <th>1</th>\n",
95
+ " <td>rDLBArZUl1c</td>\n",
96
+ " <td>https://unsplash.com/photos/rDLBArZUl1c</td>\n",
97
+ " <td>https://images.unsplash.com/photo-141633941111...</td>\n",
98
+ " <td>2014-11-18 19:36:57.08945</td>\n",
99
+ " <td>t</td>\n",
100
+ " <td>3000</td>\n",
101
+ " <td>4000</td>\n",
102
+ " <td>0.75</td>\n",
103
+ " <td>Succulents in a terrarium</td>\n",
104
+ " <td>ugmonk</td>\n",
105
+ " <td>...</td>\n",
106
+ " <td>NaN</td>\n",
107
+ " <td>NaN</td>\n",
108
+ " <td>13784815</td>\n",
109
+ " <td>82141</td>\n",
110
+ " <td>succulent plants in clear glass terrarium</td>\n",
111
+ " <td>NaN</td>\n",
112
+ " <td>NaN</td>\n",
113
+ " <td>NaN</td>\n",
114
+ " <td>NaN</td>\n",
115
+ " <td>LvI$4txu%2s:_4t6WUj]xat7RPoe</td>\n",
116
+ " </tr>\n",
117
+ " <tr>\n",
118
+ " <th>2</th>\n",
119
+ " <td>cNDGZ2sQ3Bo</td>\n",
120
+ " <td>https://unsplash.com/photos/cNDGZ2sQ3Bo</td>\n",
121
+ " <td>https://images.unsplash.com/photo-142014251503...</td>\n",
122
+ " <td>2015-01-01 20:02:02.097036</td>\n",
123
+ " <td>t</td>\n",
124
+ " <td>2564</td>\n",
125
+ " <td>1710</td>\n",
126
+ " <td>1.50</td>\n",
127
+ " <td>Rural winter mountainside</td>\n",
128
+ " <td>johnprice</td>\n",
129
+ " <td>...</td>\n",
130
+ " <td>NaN</td>\n",
131
+ " <td>NaN</td>\n",
132
+ " <td>1302461</td>\n",
133
+ " <td>3428</td>\n",
134
+ " <td>rocky mountain under gray sky at daytime</td>\n",
135
+ " <td>NaN</td>\n",
136
+ " <td>NaN</td>\n",
137
+ " <td>NaN</td>\n",
138
+ " <td>NaN</td>\n",
139
+ " <td>LhMj%NxvM{t7_4t7aeoM%2M{ozj[</td>\n",
140
+ " </tr>\n",
141
+ " <tr>\n",
142
+ " <th>3</th>\n",
143
+ " <td>iuZ_D1eoq9k</td>\n",
144
+ " <td>https://unsplash.com/photos/iuZ_D1eoq9k</td>\n",
145
+ " <td>https://images.unsplash.com/photo-141487280988...</td>\n",
146
+ " <td>2014-11-01 20:15:13.410073</td>\n",
147
+ " <td>t</td>\n",
148
+ " <td>2912</td>\n",
149
+ " <td>4368</td>\n",
150
+ " <td>0.67</td>\n",
151
+ " <td>Poppy seeds and flowers</td>\n",
152
+ " <td>krisatomic</td>\n",
153
+ " <td>...</td>\n",
154
+ " <td>NaN</td>\n",
155
+ " <td>NaN</td>\n",
156
+ " <td>2890238</td>\n",
157
+ " <td>33704</td>\n",
158
+ " <td>red common poppy flower selective focus phography</td>\n",
159
+ " <td>NaN</td>\n",
160
+ " <td>NaN</td>\n",
161
+ " <td>NaN</td>\n",
162
+ " <td>NaN</td>\n",
163
+ " <td>LSC7DirZAsX7}Br@GEWWmnoLWCnj</td>\n",
164
+ " </tr>\n",
165
+ " <tr>\n",
166
+ " <th>4</th>\n",
167
+ " <td>BeD3vjQ8SI0</td>\n",
168
+ " <td>https://unsplash.com/photos/BeD3vjQ8SI0</td>\n",
169
+ " <td>https://images.unsplash.com/photo-141700759404...</td>\n",
170
+ " <td>2014-11-26 13:13:50.134383</td>\n",
171
+ " <td>t</td>\n",
172
+ " <td>4896</td>\n",
173
+ " <td>3264</td>\n",
174
+ " <td>1.50</td>\n",
175
+ " <td>Silhouette near dark trees</td>\n",
176
+ " <td>jonaseriksson</td>\n",
177
+ " <td>...</td>\n",
178
+ " <td>NaN</td>\n",
179
+ " <td>NaN</td>\n",
180
+ " <td>8704860</td>\n",
181
+ " <td>49662</td>\n",
182
+ " <td>trees during night time</td>\n",
183
+ " <td>NaN</td>\n",
184
+ " <td>NaN</td>\n",
185
+ " <td>NaN</td>\n",
186
+ " <td>NaN</td>\n",
187
+ " <td>L25|_:V@0hxtI=W;odae0ht6=^NG</td>\n",
188
+ " </tr>\n",
189
+ " </tbody>\n",
190
+ "</table>\n",
191
+ "<p>5 rows × 31 columns</p>\n",
192
+ "</div>"
193
+ ],
194
+ "text/plain": [
195
+ " photo_id photo_url \\\n",
196
+ "0 XMyPniM9LF0 https://unsplash.com/photos/XMyPniM9LF0 \n",
197
+ "1 rDLBArZUl1c https://unsplash.com/photos/rDLBArZUl1c \n",
198
+ "2 cNDGZ2sQ3Bo https://unsplash.com/photos/cNDGZ2sQ3Bo \n",
199
+ "3 iuZ_D1eoq9k https://unsplash.com/photos/iuZ_D1eoq9k \n",
200
+ "4 BeD3vjQ8SI0 https://unsplash.com/photos/BeD3vjQ8SI0 \n",
201
+ "\n",
202
+ " photo_image_url \\\n",
203
+ "0 https://images.unsplash.com/uploads/1411949294... \n",
204
+ "1 https://images.unsplash.com/photo-141633941111... \n",
205
+ "2 https://images.unsplash.com/photo-142014251503... \n",
206
+ "3 https://images.unsplash.com/photo-141487280988... \n",
207
+ "4 https://images.unsplash.com/photo-141700759404... \n",
208
+ "\n",
209
+ " photo_submitted_at photo_featured photo_width photo_height \\\n",
210
+ "0 2014-09-29 00:08:38.594364 t 4272 2848 \n",
211
+ "1 2014-11-18 19:36:57.08945 t 3000 4000 \n",
212
+ "2 2015-01-01 20:02:02.097036 t 2564 1710 \n",
213
+ "3 2014-11-01 20:15:13.410073 t 2912 4368 \n",
214
+ "4 2014-11-26 13:13:50.134383 t 4896 3264 \n",
215
+ "\n",
216
+ " photo_aspect_ratio photo_description photographer_username ... \\\n",
217
+ "0 1.50 Woman exploring a forest michellespencer77 ... \n",
218
+ "1 0.75 Succulents in a terrarium ugmonk ... \n",
219
+ "2 1.50 Rural winter mountainside johnprice ... \n",
220
+ "3 0.67 Poppy seeds and flowers krisatomic ... \n",
221
+ "4 1.50 Silhouette near dark trees jonaseriksson ... \n",
222
+ "\n",
223
+ " photo_location_country photo_location_city stats_views stats_downloads \\\n",
224
+ "0 NaN NaN 2375421 6967 \n",
225
+ "1 NaN NaN 13784815 82141 \n",
226
+ "2 NaN NaN 1302461 3428 \n",
227
+ "3 NaN NaN 2890238 33704 \n",
228
+ "4 NaN NaN 8704860 49662 \n",
229
+ "\n",
230
+ " ai_description ai_primary_landmark_name \\\n",
231
+ "0 woman walking in the middle of forest NaN \n",
232
+ "1 succulent plants in clear glass terrarium NaN \n",
233
+ "2 rocky mountain under gray sky at daytime NaN \n",
234
+ "3 red common poppy flower selective focus phography NaN \n",
235
+ "4 trees during night time NaN \n",
236
+ "\n",
237
+ " ai_primary_landmark_latitude ai_primary_landmark_longitude \\\n",
238
+ "0 NaN NaN \n",
239
+ "1 NaN NaN \n",
240
+ "2 NaN NaN \n",
241
+ "3 NaN NaN \n",
242
+ "4 NaN NaN \n",
243
+ "\n",
244
+ " ai_primary_landmark_confidence blur_hash \n",
245
+ "0 NaN L56bVcRRIWMh.gVunlS4SMbsRRxr \n",
246
+ "1 NaN LvI$4txu%2s:_4t6WUj]xat7RPoe \n",
247
+ "2 NaN LhMj%NxvM{t7_4t7aeoM%2M{ozj[ \n",
248
+ "3 NaN LSC7DirZAsX7}Br@GEWWmnoLWCnj \n",
249
+ "4 NaN L25|_:V@0hxtI=W;odae0ht6=^NG \n",
250
+ "\n",
251
+ "[5 rows x 31 columns]"
252
+ ]
253
+ },
254
+ "execution_count": 2,
255
+ "metadata": {},
256
+ "output_type": "execute_result"
257
+ }
258
+ ],
259
+ "source": [
260
+ "import pandas as pd\n",
261
+ "\n",
262
+ "images = pd.read_csv('photos.tsv000', delimiter='\\t')\n",
263
+ "images.head()"
264
+ ]
265
+ },
266
+ {
267
+ "cell_type": "markdown",
268
+ "metadata": {},
269
+ "source": [
270
+ "We download using the `photo_image_url` field."
271
+ ]
272
+ },
273
+ {
274
+ "cell_type": "code",
275
+ "execution_count": 4,
276
+ "metadata": {},
277
+ "outputs": [],
278
+ "source": [
279
+ "from PIL import Image\n",
280
+ "import requests\n",
281
+ "from io import BytesIO\n",
282
+ "\n",
283
+ "url = images['photo_image_url'].iloc[0]\n",
284
+ "\n",
285
+ "response = requests.get(url)\n",
286
+ "img = Image.open(BytesIO(response.content))\n",
287
+ "img"
288
+ ]
289
+ },
290
+ {
291
+ "cell_type": "markdown",
292
+ "metadata": {},
293
+ "source": [
294
+ "We need to use these images to create vector embeddings, to do this we will use OpenAI's CLIP from the `transformers` library.\n",
295
+ "\n",
296
+ "```\n",
297
+ "!pip install transformers\n",
298
+ "```"
299
+ ]
300
+ },
301
+ {
302
+ "cell_type": "code",
303
+ "execution_count": 5,
304
+ "metadata": {},
305
+ "outputs": [
306
+ {
307
+ "name": "stderr",
308
+ "output_type": "stream",
309
+ "text": [
310
+ "2022-08-12 14:07:47.935784: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0\n",
311
+ "ftfy or spacy is not installed using BERT BasicTokenizer instead of ftfy.\n"
312
+ ]
313
+ }
314
+ ],
315
+ "source": [
316
+ "from transformers import CLIPProcessor, CLIPModel\n",
317
+ "import torch\n",
318
+ "\n",
319
+ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
320
+ "model_name = \"openai/clip-vit-base-patch32\"\n",
321
+ "\n",
322
+ "model = CLIPModel.from_pretrained(model_name).to(device)\n",
323
+ "processor = CLIPProcessor.from_pretrained(model_name)"
324
+ ]
325
+ },
326
+ {
327
+ "cell_type": "markdown",
328
+ "metadata": {},
329
+ "source": [
330
+ "Now we're ready to use the vision transformer (ViT) portion of CLIP to create feature vectors (embedding representations) from the image."
331
+ ]
332
+ },
333
+ {
334
+ "cell_type": "code",
335
+ "execution_count": 6,
336
+ "metadata": {},
337
+ "outputs": [],
338
+ "source": [
339
+ "img = processor(\n",
340
+ " text=None,\n",
341
+ " images=img,\n",
342
+ " return_tensors='pt',\n",
343
+ " padding=True\n",
344
+ ")['pixel_values'].to(device)"
345
+ ]
346
+ },
347
+ {
348
+ "cell_type": "code",
349
+ "execution_count": 7,
350
+ "metadata": {},
351
+ "outputs": [
352
+ {
353
+ "data": {
354
+ "text/plain": [
355
+ "torch.Size([1, 512])"
356
+ ]
357
+ },
358
+ "execution_count": 7,
359
+ "metadata": {},
360
+ "output_type": "execute_result"
361
+ }
362
+ ],
363
+ "source": [
364
+ "out = model.get_image_features(pixel_values=img)\n",
365
+ "out.shape"
366
+ ]
367
+ },
368
+ {
369
+ "cell_type": "code",
370
+ "execution_count": 8,
371
+ "metadata": {},
372
+ "outputs": [
373
+ {
374
+ "data": {
375
+ "text/plain": [
376
+ "torch.Size([512])"
377
+ ]
378
+ },
379
+ "execution_count": 8,
380
+ "metadata": {},
381
+ "output_type": "execute_result"
382
+ }
383
+ ],
384
+ "source": [
385
+ "out = out.squeeze(0)\n",
386
+ "out.shape"
387
+ ]
388
+ },
389
+ {
390
+ "cell_type": "code",
391
+ "execution_count": 9,
392
+ "metadata": {},
393
+ "outputs": [
394
+ {
395
+ "data": {
396
+ "text/plain": [
397
+ "(512,)"
398
+ ]
399
+ },
400
+ "execution_count": 9,
401
+ "metadata": {},
402
+ "output_type": "execute_result"
403
+ }
404
+ ],
405
+ "source": [
406
+ "emb = out.cpu().detach().numpy()\n",
407
+ "emb.shape"
408
+ ]
409
+ },
410
+ {
411
+ "cell_type": "code",
412
+ "execution_count": 10,
413
+ "metadata": {},
414
+ "outputs": [
415
+ {
416
+ "data": {
417
+ "text/plain": [
418
+ "(-7.985501, 2.0108054)"
419
+ ]
420
+ },
421
+ "execution_count": 10,
422
+ "metadata": {},
423
+ "output_type": "execute_result"
424
+ }
425
+ ],
426
+ "source": [
427
+ "emb.min(), emb.max()"
428
+ ]
429
+ },
430
+ {
431
+ "cell_type": "markdown",
432
+ "metadata": {},
433
+ "source": [
434
+ "Now we have a single `512` dimensional vector that represents the *meaning* behind the image. As we will be using dot product similarity we should also normalize these vectors."
435
+ ]
436
+ },
437
+ {
438
+ "cell_type": "code",
439
+ "execution_count": 10,
440
+ "metadata": {},
441
+ "outputs": [],
442
+ "source": [
443
+ "import numpy as np\n",
444
+ "\n",
445
+ "emb = emb / np.linalg.norm(emb)"
446
+ ]
447
+ },
448
+ {
449
+ "cell_type": "code",
450
+ "execution_count": 11,
451
+ "metadata": {},
452
+ "outputs": [
453
+ {
454
+ "data": {
455
+ "text/plain": [
456
+ "(-0.56626415, 0.13343191)"
457
+ ]
458
+ },
459
+ "execution_count": 11,
460
+ "metadata": {},
461
+ "output_type": "execute_result"
462
+ }
463
+ ],
464
+ "source": [
465
+ "emb.min(), emb.max()"
466
+ ]
467
+ },
468
+ {
469
+ "cell_type": "markdown",
470
+ "metadata": {},
471
+ "source": [
472
+ "## Indexing\n",
473
+ "\n",
474
+ "To index this image in Pinecone we first install the Pinecone client:\n",
475
+ "\n",
476
+ "```\n",
477
+ "!pip install pinecone-client\n",
478
+ "```\n",
479
+ "\n",
480
+ "And then initialize our connection to Pinecone, this requires a [free API key](https://app.pinecone.io/)."
481
+ ]
482
+ },
483
+ {
484
+ "cell_type": "code",
485
+ "execution_count": 11,
486
+ "metadata": {},
487
+ "outputs": [],
488
+ "source": [
489
+ "import pinecone\n",
490
+ "\n",
491
+ "index_name = \"unsplash-25k-clip\"\n",
492
+ "\n",
493
+ "pinecone.init(\n",
494
+ " api_key=\"<<API_KEY_HERE>>\",\n",
495
+ " environment=\"us-west1-gcp\"\n",
496
+ ")\n",
497
+ "\n",
498
+ "if index_name not in pinecone.list_indexes():\n",
499
+ " pinecone.create_index(\n",
500
+ " index_name,\n",
501
+ " emb.shape[0],\n",
502
+ " metric=\"dotproduct\"\n",
503
+ " )\n",
504
+ "# connect to the index\n",
505
+ "index = pinecone.Index(index_name)"
506
+ ]
507
+ },
508
+ {
509
+ "cell_type": "markdown",
510
+ "metadata": {},
511
+ "source": [
512
+ "To upsert the single feature embedding we have created, we use `upsert`. There are also some possibly relevant metadata info we might want to add."
513
+ ]
514
+ },
515
+ {
516
+ "cell_type": "code",
517
+ "execution_count": 13,
518
+ "metadata": {},
519
+ "outputs": [
520
+ {
521
+ "data": {
522
+ "text/plain": [
523
+ "{'photo_url': 'https://unsplash.com/photos/XMyPniM9LF0',\n",
524
+ " 'photo_image_url': 'https://images.unsplash.com/uploads/14119492946973137ce46/f1f2ebf3',\n",
525
+ " 'photo_submitted_at': '2014-09-29 00:08:38.594364',\n",
526
+ " 'photo_description': 'Woman exploring a forest',\n",
527
+ " 'photographer_username': 'michellespencer77',\n",
528
+ " 'ai_description': 'woman walking in the middle of forest'}"
529
+ ]
530
+ },
531
+ "execution_count": 13,
532
+ "metadata": {},
533
+ "output_type": "execute_result"
534
+ }
535
+ ],
536
+ "source": [
537
+ "row = images.iloc[0]\n",
538
+ "\n",
539
+ "_id = row['photo_id']\n",
540
+ "meta = {\n",
541
+ " \"photo_url\": row[\"photo_url\"],\n",
542
+ " \"photo_image_url\": row[\"photo_image_url\"],\n",
543
+ " \"photo_submitted_at\": row[\"photo_submitted_at\"],\n",
544
+ " \"photo_description\": row[\"photo_description\"],\n",
545
+ " \"photographer_username\": row[\"photographer_username\"],\n",
546
+ " \"ai_description\": row[\"ai_description\"]\n",
547
+ "}\n",
548
+ "\n",
549
+ "meta"
550
+ ]
551
+ },
552
+ {
553
+ "cell_type": "code",
554
+ "execution_count": 14,
555
+ "metadata": {},
556
+ "outputs": [
557
+ {
558
+ "data": {
559
+ "text/plain": [
560
+ "{'upserted_count': 1}"
561
+ ]
562
+ },
563
+ "execution_count": 14,
564
+ "metadata": {},
565
+ "output_type": "execute_result"
566
+ }
567
+ ],
568
+ "source": [
569
+ "to_upsert = [(_id, emb.tolist(), meta)]\n",
570
+ "\n",
571
+ "index.upsert(to_upsert)"
572
+ ]
573
+ },
574
+ {
575
+ "cell_type": "code",
576
+ "execution_count": 15,
577
+ "metadata": {},
578
+ "outputs": [
579
+ {
580
+ "data": {
581
+ "text/plain": [
582
+ "'XMyPniM9LF0'"
583
+ ]
584
+ },
585
+ "execution_count": 15,
586
+ "metadata": {},
587
+ "output_type": "execute_result"
588
+ }
589
+ ],
590
+ "source": [
591
+ "_id"
592
+ ]
593
+ },
594
+ {
595
+ "cell_type": "markdown",
596
+ "metadata": {
597
+ "tags": []
598
+ },
599
+ "source": [
600
+ "Note that we added a string ID value `\"XMyPniM9LF0\"` and also converted the feature embedding tensor to a flat list before adding to our Pinecone index.\n",
601
+ "\n",
602
+ "## Indexing Everything\n",
603
+ "\n",
604
+ "So far we've built one feature embedding and indexed it in Pinecone, now let's repeat the process for a lot of images.\n",
605
+ "\n",
606
+ "We will do this in batches, taking `32` images at a time, embedding them with Resnet-50, and indexing them in Pinecone."
607
+ ]
608
+ },
609
+ {
610
+ "cell_type": "code",
611
+ "execution_count": 23,
612
+ "metadata": {},
613
+ "outputs": [],
614
+ "source": [
615
+ "import numpy as np"
616
+ ]
617
+ },
618
+ {
619
+ "cell_type": "code",
620
+ "execution_count": 79,
621
+ "metadata": {},
622
+ "outputs": [
623
+ {
624
+ "data": {
625
+ "application/vnd.jupyter.widget-view+json": {
626
+ "model_id": "6726c0eb47de4cd780f3e1096a2d743f",
627
+ "version_major": 2,
628
+ "version_minor": 0
629
+ },
630
+ "text/plain": [
631
+ " 0%| | 0/1370 [00:00<?, ?it/s]"
632
+ ]
633
+ },
634
+ "metadata": {},
635
+ "output_type": "display_data"
636
+ },
637
+ {
638
+ "name": "stderr",
639
+ "output_type": "stream",
640
+ "text": [
641
+ "/opt/conda/lib/python3.7/site-packages/PIL/Image.py:2899: DecompressionBombWarning: Image size (99996755 pixels) exceeds limit of 89478485 pixels, could be decompression bomb DOS attack.\n",
642
+ " DecompressionBombWarning,\n",
643
+ "/opt/conda/lib/python3.7/site-packages/PIL/Image.py:2899: DecompressionBombWarning: Image size (96768910 pixels) exceeds limit of 89478485 pixels, could be decompression bomb DOS attack.\n",
644
+ " DecompressionBombWarning,\n",
645
+ "/opt/conda/lib/python3.7/site-packages/PIL/Image.py:2899: DecompressionBombWarning: Image size (99991727 pixels) exceeds limit of 89478485 pixels, could be decompression bomb DOS attack.\n",
646
+ " DecompressionBombWarning,\n",
647
+ "/opt/conda/lib/python3.7/site-packages/PIL/Image.py:2899: DecompressionBombWarning: Image size (143040000 pixels) exceeds limit of 89478485 pixels, could be decompression bomb DOS attack.\n",
648
+ " DecompressionBombWarning,\n",
649
+ "/opt/conda/lib/python3.7/site-packages/PIL/Image.py:2899: DecompressionBombWarning: Image size (94212096 pixels) exceeds limit of 89478485 pixels, could be decompression bomb DOS attack.\n",
650
+ " DecompressionBombWarning,\n",
651
+ "/opt/conda/lib/python3.7/site-packages/PIL/Image.py:2899: DecompressionBombWarning: Image size (121500000 pixels) exceeds limit of 89478485 pixels, could be decompression bomb DOS attack.\n",
652
+ " DecompressionBombWarning,\n",
653
+ "/opt/conda/lib/python3.7/site-packages/PIL/Image.py:2899: DecompressionBombWarning: Image size (107424768 pixels) exceeds limit of 89478485 pixels, could be decompression bomb DOS attack.\n",
654
+ " DecompressionBombWarning,\n",
655
+ "/opt/conda/lib/python3.7/site-packages/PIL/Image.py:2899: DecompressionBombWarning: Image size (147015000 pixels) exceeds limit of 89478485 pixels, could be decompression bomb DOS attack.\n",
656
+ " DecompressionBombWarning,\n",
657
+ "/opt/conda/lib/python3.7/site-packages/PIL/Image.py:2899: DecompressionBombWarning: Image size (107184040 pixels) exceeds limit of 89478485 pixels, could be decompression bomb DOS attack.\n",
658
+ " DecompressionBombWarning,\n",
659
+ "/opt/conda/lib/python3.7/site-packages/PIL/Image.py:2899: DecompressionBombWarning: Image size (146784000 pixels) exceeds limit of 89478485 pixels, could be decompression bomb DOS attack.\n",
660
+ " DecompressionBombWarning,\n",
661
+ "/opt/conda/lib/python3.7/site-packages/PIL/Image.py:2899: DecompressionBombWarning: Image size (90671520 pixels) exceeds limit of 89478485 pixels, could be decompression bomb DOS attack.\n",
662
+ " DecompressionBombWarning,\n",
663
+ "/opt/conda/lib/python3.7/site-packages/PIL/Image.py:2899: DecompressionBombWarning: Image size (99992815 pixels) exceeds limit of 89478485 pixels, could be decompression bomb DOS attack.\n",
664
+ " DecompressionBombWarning,\n",
665
+ "/opt/conda/lib/python3.7/site-packages/PIL/Image.py:2899: DecompressionBombWarning: Image size (95808000 pixels) exceeds limit of 89478485 pixels, could be decompression bomb DOS attack.\n",
666
+ " DecompressionBombWarning,\n",
667
+ "/opt/conda/lib/python3.7/site-packages/PIL/Image.py:2899: DecompressionBombWarning: Image size (121554000 pixels) exceeds limit of 89478485 pixels, could be decompression bomb DOS attack.\n",
668
+ " DecompressionBombWarning,\n",
669
+ "/opt/conda/lib/python3.7/site-packages/PIL/Image.py:2899: DecompressionBombWarning: Image size (91177320 pixels) exceeds limit of 89478485 pixels, could be decompression bomb DOS attack.\n",
670
+ " DecompressionBombWarning,\n",
671
+ "/opt/conda/lib/python3.7/site-packages/PIL/Image.py:2899: DecompressionBombWarning: Image size (99996120 pixels) exceeds limit of 89478485 pixels, could be decompression bomb DOS attack.\n",
672
+ " DecompressionBombWarning,\n",
673
+ "/opt/conda/lib/python3.7/site-packages/PIL/Image.py:2899: DecompressionBombWarning: Image size (96000000 pixels) exceeds limit of 89478485 pixels, could be decompression bomb DOS attack.\n",
674
+ " DecompressionBombWarning,\n",
675
+ "/opt/conda/lib/python3.7/site-packages/PIL/Image.py:2899: DecompressionBombWarning: Image size (98058240 pixels) exceeds limit of 89478485 pixels, could be decompression bomb DOS attack.\n",
676
+ " DecompressionBombWarning,\n"
677
+ ]
678
+ }
679
+ ],
680
+ "source": [
681
+ "from tqdm.auto import tqdm\n",
682
+ "batch_size = 16\n",
683
+ "images_len = len(images)\n",
684
+ "\n",
685
+ "exceptions = []\n",
686
+ "\n",
687
+ "for i in tqdm(range(3088, images_len, batch_size)):\n",
688
+ " # select the batch start and end\n",
689
+ " i_end = min(i + batch_size, images_len)\n",
690
+ " # get batch\n",
691
+ " batch = images.iloc[i:i_end]\n",
692
+ " # retrieve URLs\n",
693
+ " url_batch = batch['photo_image_url']\n",
694
+ " # get images\n",
695
+ " image_batch = []\n",
696
+ " for url in url_batch:\n",
697
+ " try:\n",
698
+ " response = requests.get(url)\n",
699
+ " img = Image.open(BytesIO(response.content))\n",
700
+ " if img.mode in ['L', 'CMYK', 'RGBA']:\n",
701
+ " # L is grayscale, CMYK uses alternative color channels\n",
702
+ " img = img.convert('RGB')\n",
703
+ " image_batch.append(img)\n",
704
+ " except Exception as e:\n",
705
+ " exceptions.append((\"url\", e))\n",
706
+ " # process images and extract pytorch tensor pixel values\n",
707
+ " try:\n",
708
+ " image_batch = processor(\n",
709
+ " text=None,\n",
710
+ " images=image_batch,\n",
711
+ " return_tensors='pt'\n",
712
+ " )['pixel_values'].to(device)\n",
713
+ " # feed tensors to model and extract last state\n",
714
+ " out = model.get_image_features(pixel_values=image_batch)\n",
715
+ " out = out.squeeze(0)\n",
716
+ " # take the mean across each dimension to create a single vector embedding\n",
717
+ " embeds = out.cpu().detach().numpy()\n",
718
+ " # normalize and convert to list\n",
719
+ " embeds = embeds / np.linalg.norm(embeds, axis=0)\n",
720
+ " embeds = embeds.tolist()\n",
721
+ " # get ID values\n",
722
+ " ids = batch['photo_id']\n",
723
+ " # prep metadata\n",
724
+ " metadata = batch[[\n",
725
+ " \"photo_url\", \"photo_image_url\", \"photo_submitted_at\",\n",
726
+ " \"photo_description\", \"photographer_username\", \"ai_description\"\n",
727
+ " ]].fillna(\"\").to_dict(orient=\"records\")\n",
728
+ " # zip all data together and upsert\n",
729
+ " to_upsert = zip(ids, embeds, metadata)\n",
730
+ " index.upsert(to_upsert)\n",
731
+ " except Exception as e:\n",
732
+ " exceptions.append((\"process\", e))"
733
+ ]
734
+ },
735
+ {
736
+ "cell_type": "markdown",
737
+ "metadata": {},
738
+ "source": [
739
+ "---"
740
+ ]
741
+ }
742
+ ],
743
+ "metadata": {
744
+ "environment": {
745
+ "kernel": "python3",
746
+ "name": "common-cu110.m91",
747
+ "type": "gcloud",
748
+ "uri": "gcr.io/deeplearning-platform-release/base-cu110:m91"
749
+ },
750
+ "kernelspec": {
751
+ "display_name": "Python 3",
752
+ "language": "python",
753
+ "name": "python3"
754
+ },
755
+ "language_info": {
756
+ "codemirror_mode": {
757
+ "name": "ipython",
758
+ "version": 3
759
+ },
760
+ "file_extension": ".py",
761
+ "mimetype": "text/x-python",
762
+ "name": "python",
763
+ "nbconvert_exporter": "python",
764
+ "pygments_lexer": "ipython3",
765
+ "version": "3.8.13"
766
+ },
767
+ "vscode": {
768
+ "interpreter": {
769
+ "hash": "9ec8fc8fb845fc3e050bf8bf651a355c069bbfeddee31167baf4bc42b6050476"
770
+ }
771
+ }
772
+ },
773
+ "nbformat": 4,
774
+ "nbformat_minor": 4
775
+ }
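
The notebook stores unit-normalized CLIP image embeddings behind a dot-product metric, while the app queries that index with a CLIP *text* embedding built from the user's prompt. The snippet below is a rough sketch of that query path using the same model ID and index name as above; the example prompt, the normalization placement, and the printed fields are illustrative assumptions rather than code taken from app.py.

```python
import numpy as np
import pinecone
import torch
from transformers import CLIPModel, CLIPTokenizerFast

MODEL_ID = "openai/clip-vit-base-patch32"
INDEX = "unsplash-25k-clip"

tokenizer = CLIPTokenizerFast.from_pretrained(MODEL_ID)
model = CLIPModel.from_pretrained(MODEL_ID)

pinecone.init(api_key="<<API_KEY_HERE>>", environment="us-west1-gcp")
index = pinecone.Index(INDEX)

# encode a text prompt into the same 512-d space as the indexed images
inputs = tokenizer("dogs in the snow", return_tensors="pt")
with torch.no_grad():
    xq = model.get_text_features(**inputs)[0].numpy()
xq = (xq / np.linalg.norm(xq)).tolist()  # normalize for the dot-product metric

# retrieve the nearest image embeddings and print their Unsplash page URLs
xc = index.query(xq, top_k=9, include_metadata=True)
for match in xc["matches"]:
    print(match["id"], match["metadata"]["photo_url"])
```
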