jamescalam committed
Commit • 88172be
1 parent: 814d271

upgrade to contrastive learning and unsplash lite dataset

Files changed:
- app.py (+120 -53)
- link-check.py (+58 -0)
- unsplash-25k-clip-indexer.ipynb (+775 -0)
app.py CHANGED

```
@@ -11,19 +11,10 @@ import logging
 from urllib3.exceptions import ProtocolError
 
 PINECONE_API_KEY = st.secrets["PINECONE_API_KEY"]  # app.pinecone.io
-INDEX = "
+INDEX = "unsplash-25k-clip"
 MODEL_ID = "openai/clip-vit-base-patch32"
 DIMS = 512
 
-@st.experimental_singleton(show_spinner=False)
-def init_dataset():
-    return load_dataset(
-        'frgfm/imagenette',
-        'full_size',
-        split='train',
-        ignore_verifications=False  # set to True if seeing splits Error
-    )
-
 @st.experimental_singleton(show_spinner=False)
 def init_clip():
     tokenizer = CLIPTokenizerFast.from_pretrained(MODEL_ID)
@@ -39,7 +30,12 @@ def init_db():
     meta_field = datetime.now().isoformat()
     return meta_field, pinecone.Index(INDEX)
 
-def query(xq, top_k=10, include_values=True, filter=None):
+@st.experimental_singleton(show_spinner=False)
+def init_query_num():
+    print("init query_num")
+    return 0
+
+def query(xq, top_k=10, include_values=True, include_metadata=True, filter=None):
     logging.info(f"Query to Pinecone with '{st.session_state.meta}'")
     attempt = 0
     while attempt < 3:
@@ -48,14 +44,15 @@ def query(xq, top_k=10, include_values=True, filter=None):
                 xq,
                 top_k=top_k,
                 include_values=include_values,
+                include_metadata=include_metadata,
                 filter=filter
             )
-            matches =
+            matches = xc['matches']
             break
         except ProtocolError:
             attempt += 1
-            matches =
-    if len(matches
+            matches = []
+    if len(matches) == 0:
         logging.error(f"No matches found for '{st.session_state.meta}'")
     return matches
@@ -108,44 +105,34 @@ def pil_to_bytes(img):
     img_bin = base64.b64encode(img_bin).decode('utf-8')
     return img_bin
 
-def card(i):
-    img =
-    img_bin = pil_to_bytes(img)
-    return f'<img id="img{i}" src="data:image/jpeg;base64,{img_bin}" width="200px;">'
+def card(i, url):
+    return f'<img id="img{i}" src="{url}" width="200px;">'
 
-def get_top_k(xq, top_k=
+def get_top_k(xq, top_k=9):
     matches = query(
-        xq, top_k=top_k, include_values=True,
+        xq, top_k=top_k, include_values=True, include_metadata=True,
+        filter={st.session_state.meta: {"$ne": 1}}
     )
     return matches
 
-def tune(
-    positive_idx = [idx for idx, val in inputs.items() if val == 1]
-    negatives = [match for match in matches.items() if match[0] not in positive_idx]
-    negative_idx = [match[0] for match in negatives]
-    negative_vectors = [match[1] for match in negatives]
-    positive_vectors = [match[1] for match in matches.items() if match[0] in positive_idx]
-    # prep training data
-    y = [1] * len(positive_idx) + [0] * len(negative_idx)
-    X = positive_vectors + negative_vectors
+def tune(X, y, iters=5):
     # train the classifier
+    print(y)
     st.session_state.clf.fit(X, y, iters=iters)
     # extract new vector
     st.session_state.xq = st.session_state.clf.get_weights()
-    # update one record at a time
-    for i in positive_idx + negative_idx:
-        st.session_state.index.update(str(i), set_metadata={st.session_state.meta: 1})
 
 def refresh_index():
     logging.info(f"Refresh for '{st.session_state.meta}'")
+    st.session_state.query_num = 0
     xq = st.session_state.xq
     if type(xq) is not list:
         xq = xq.tolist()
     while True:
         matches = query(xq, top_k=100, filter={st.session_state.meta: 1})
-
-        if len(
-        for i in
+        id_vals = [match['id'] for match in matches]
+        if len(id_vals) == 0: break
+        for i in id_vals:
            st.session_state.index.update(str(i), set_metadata={st.session_state.meta: 0})
     # refresh session states
     del st.session_state.clf, st.session_state.xq
@@ -156,19 +143,26 @@ def calc_dist():
     return np.linalg.norm(xq - orig_xq)
 
 def submit():
+    st.session_state.query_num += 1
     matches = st.session_state.matches
-    velocity = st.session_state.velocity
-
+    velocity = 2  # st.session_state.velocity
+    scores = {}
     states = [
         st.session_state[f"input{i}"] for i in range(len(matches))
     ]
-    for i,
-
+    for i, match in enumerate(matches):
+        scores[match['id']] = float(states[i])
         states[i] = False
     # reset states to unchecked
     for i in range(len(matches)):
         st.session_state[f"input{i}"] = False
-
+    # get training data and labels
+    X = list([match['values'] for match in matches])
+    y = list(scores.values())
+    tune(X, y, iters=velocity)
+    # update record metadata after training
+    for match in matches:
+        st.session_state.index.update(str(match['id']), set_metadata={st.session_state.meta: 1})
 
 def delete_element(element):
     del element
@@ -180,18 +174,72 @@ st.markdown("""
     />
 """, unsafe_allow_html=True)
 
+messages = [
+    f"""
+    Welcome to the semantic query trainer app! Here we will demo how to efficiently train
+    a classifier to *very accurately* classify images based on their semantic content.
+
+    First, we need to initialize the classifier with a simple prompt. Try and write something
+    similar to what you're looking for, or if you want a challenge, try something completely
+    different.
+    """,
+    f"""
+    With the first query we have initialized the classifier weights (they're a 512-d vector)
+    and used those weights to perform a *vector search* to find image embeddings (also 512-d
+    vectors) that closely match the classifier weights.
+
+    These are essentially the images that the classifier would currently classify as "positive".
+
+    Based on your *target class* for the classifier, decide how relevant each of the images
+    below is, rating them from -1 (completely irrelevant) to +1 (a perfect match).
+    """,
+    f"""
+    Each of the image embeddings is paired with the *score* that you just gave it. These are
+    all fed into the classifier and used to train it. The classifier learns to *move* towards
+    the positively scored images, and to *avoid* the negatively scored images.
+    """,
+    f"""
+    As we repeat the process, the classifier rapidly learns the target space of our intended
+    class.
+
+    Typically, we don't train classifiers like this; instead, we label a huge dataset and train
+    the classifier across all images and their labels. This is massively inefficient. Here we
+    save annotation and compute time by using vector search to identify and focus on the images
+    that make the *biggest* difference in classifier performance.
+    """,
+    f"""
+    We shouldn't need to repeat this process many times before our classifier converges on our
+    target space. Once we begin returning only relevant images, we can stop training the classifier.
+
+    *(In this demo, you can try changing your target space and 'traversing' the vector space
+    to the new target space)*
+    """,
+    f"""
+    The app uses the [Pinecone vector database](https://pinecone.io/) to store and query images
+    using vector search. All images are sourced from the [Unsplash Lite dataset](https://github.com/unsplash/datasets) and encoded
+    using [OpenAI's CLIP](https://huggingface.co/openai/clip-vit-base-patch32). We explain how
+    it all works [here](https://classifier-train-vector-search--optimistic-curran-b817a8.netlify.app/learn/classifier-train-vector-search/).
+    """
+]
+
 with st.spinner("Initializing everything..."):
-    imagenet = init_dataset()
     st.session_state.meta, st.session_state.index = init_db()
     if 'xq' not in st.session_state:
         tokenizer, clip = init_clip()
+        st.session_state.query_num = 0
+
+if st.session_state.query_num+1 < len(messages):
+    msg = messages[st.session_state.query_num+1]
+else:
+    msg = messages[-1]
 
 if 'xq' not in st.session_state:
-    start = [st.empty(), st.empty(), st.empty(), st.empty()]
-
-
-
-    start[3].
+    start = [st.empty(), st.empty(), st.empty(), st.empty(), st.empty()]
+    start[0].info(msg, icon="⁉️")
+    prompt = start[1].text_input("Prompt:", value="")
+    prompt_xq = start[2].button("Prompt", disabled=len(prompt) == 0)
+    random_xq = start[3].button("Random", disabled=len(prompt) != 0)
+    start[4].markdown('Not sure what to write? Try **"dogs in the snow"**, **"close up of a dog"**, **"sony radio"**, or click **Random**.')
     if random_xq:
         print("r_xq")
         xq, orig_xq = init_random_query()
@@ -216,17 +264,36 @@ if 'xq' in st.session_state:
         refresh_index()
     else:
         # if we want to display images we end up here
-        st.
+        st.info(msg, icon="🔎")
         # first retrieve images from pinecone
-        st.session_state.matches = get_top_k(st.session_state.xq, top_k=
+        st.session_state.matches = get_top_k(st.session_state.xq, top_k=9)
         # once retrieved, display them alongside checkboxes in a form
         with st.form("my_form", clear_on_submit=False):
-
+            st.form_submit_button("Tune", on_click=submit)
+            # velocity = st.slider("Velocity", 1, 8, 2, key="velocity")
             # we have three columns in the form
             cols = st.columns(3)
-            for i,
+            for i, match in enumerate(st.session_state.matches):
+                # find good url
+                loc = match["metadata"].get("good_url")
+                if loc:
+                    url = match["metadata"][loc]
+                    if loc == "photo_url":
+                        url += "/download?force=true&w=640"
+                    disabled = False
+                else:
+                    # will show no image, but not sure what else to place here
+                    url = match["metadata"]["photo_url"]
+                    disabled = True
                 # the card shows an image and a checkbox
-                cols[i%3].markdown(card(
+                cols[i%3].markdown(card(i, url), unsafe_allow_html=True)
                 # we access the values of the checkbox via st.session_state[f"input{i}"]
-                cols[i%3].
+                cols[i%3].slider(
+                    "Relevance",
+                    min_value=-1.0,
+                    max_value=1.0,
+                    value=0.0,
+                    step=0.1,
+                    key=f"input{i}",
+                    disabled=disabled
+                )
```
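The `tune()` and `submit()` functions above rely on `st.session_state.clf`, an object exposing `fit(X, y, iters=...)` and `get_weights()`; the classifier class itself is defined elsewhere in the app and is not part of this commit. As a minimal sketch of how such a query trainer can work (the class name, learning rate, and update rule here are assumptions for illustration, not this repo's implementation), the key idea is that the classifier's weight vector doubles as the 512-d query vector:

```
import numpy as np

class QueryTrainer:
    """Hypothetical sketch: a linear 'classifier' whose 512-d weight
    vector is also the query vector sent to Pinecone."""
    def __init__(self, xq, learning_rate=0.1):
        self.weights = np.asarray(xq, dtype=np.float32)
        self.lr = learning_rate

    def fit(self, X, y, iters=5):
        X = np.asarray(X, dtype=np.float32)  # (n, 512) image embeddings
        y = np.asarray(y, dtype=np.float32)  # slider scores in [-1, 1]
        for _ in range(iters):
            # step toward positively scored embeddings, away from negative ones
            self.weights += self.lr * (y[:, None] * X).mean(axis=0)
            # keep unit norm so dot-product search stays well behaved
            self.weights /= np.linalg.norm(self.weights)

    def get_weights(self):
        return self.weights.tolist()
```

Each round of slider feedback then nudges the query toward the user's target region of CLIP space, which is the "traversal" described in the in-app messages.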
link-check.py ADDED

```
import pinecone
import requests
from tqdm.auto import tqdm
import logging

# we run this to check for broken links

PINECONE_API_KEY = "<<API_KEY_HERE>>"
INDEX = "unsplash-25k-clip"

pinecone.init(
    api_key=PINECONE_API_KEY,
    environment="us-west1-gcp"
)

index = pinecone.Index(INDEX)

dim = index.describe_index_stats()['dimension']
total = int(index.describe_index_stats()['totalVectorCount'])
xq = [0.0] * dim

count = 0
ID_LIST = []

logging.info("Checking links...")

with tqdm(total=total) as pbar:
    while True:
        xc = index.query(
            xq, top_k=100, include_metadata=True,
            filter={"link_check": {"$ne": True}}
        )
        matches = xc['matches']
        if len(matches) == 0:
            break
        for match in matches:
            photo_url = match['metadata']['photo_url'] + "/download?force=true&w=640"
            res = requests.get(photo_url)
            if res.status_code == 200:
                good_url = "photo_url"
            else:
                res = requests.get(match['metadata']['photo_image_url'])
                if res.status_code == 200:
                    good_url = "photo_image_url"
                else:
                    good_url = "not_found"
            index.update(match['id'], set_metadata={
                'good_url': good_url,
                'link_check': True
            })
            ID_LIST.append(match['id'])
            pbar.update(1)

logging.info("Refreshing 'link_check' field...")
for _id in tqdm(ID_LIST):
    index.update(_id, set_metadata={
        'link_check': False
    })
```
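One note on the pattern above: the zero-vector query combined with the `{"link_check": {"$ne": True}}` filter works as a crude cursor over the whole index, since each processed batch is flagged and therefore excluded from the next query; the final loop resets the flag (while keeping `good_url`) so the script can be re-run later. A minimal sketch of that scan idiom in isolation, with a hypothetical `processed` flag:

```
def scan_unprocessed(index, dim, batch=100):
    """Yield batches of records not yet marked 'processed' (flag name
    hypothetical), using a zero-vector query plus a metadata filter
    as a cursor over the whole index."""
    while True:
        xc = index.query(
            [0.0] * dim, top_k=batch, include_metadata=True,
            filter={"processed": {"$ne": True}}
        )
        matches = xc['matches']
        if not matches:
            break
        yield matches
        # flag the batch so the next query skips it
        for match in matches:
            index.update(match['id'], set_metadata={"processed": True})
```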
unsplash-25k-clip-indexer.ipynb ADDED

```
!pip install transformers pinecone-client tqdm
```

The dataset used is the [Unsplash Lite dataset](https://github.com/unsplash/datasets).

```
import pandas as pd

images = pd.read_csv('photos.tsv000', delimiter='\t')
images.head()
```

[output: the first five rows of the dataframe, 31 columns per photo, including photo_id, photo_url, photo_image_url, photo_submitted_at, photo_width, photo_height, photo_aspect_ratio, photo_description, photographer_username, photo_location_country/city, stats_views, stats_downloads, ai_description, and blur_hash]

We download using the `photo_image_url` field.

```
from PIL import Image
import requests
from io import BytesIO

url = images['photo_image_url'].iloc[0]

response = requests.get(url)
img = Image.open(BytesIO(response.content))
img
```

We need to use these images to create vector embeddings. To do this we will use OpenAI's CLIP from the `transformers` library.

```
!pip install transformers
```

```
from transformers import CLIPProcessor, CLIPModel
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "openai/clip-vit-base-patch32"

model = CLIPModel.from_pretrained(model_name).to(device)
processor = CLIPProcessor.from_pretrained(model_name)
```

[stderr: "ftfy or spacy is not installed using BERT BasicTokenizer instead of ftfy."]
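A hedged aside, not part of the original notebook: the same `model` and `processor` loaded above can also embed text into the identical 512-d space, which is how the accompanying app turns a text prompt into an initial query vector. Reusing the names from the cell above:

```
# embed a text prompt into the same space as the image vectors
text_inputs = processor(text="dogs in the snow", images=None,
                        return_tensors='pt', padding=True)
text_emb = model.get_text_features(
    input_ids=text_inputs['input_ids'].to(device),
    attention_mask=text_inputs['attention_mask'].to(device)
)
print(text_emb.shape)  # expected: torch.Size([1, 512])
```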
Now we're ready to use the vision transformer (ViT) portion of CLIP to create feature vectors (embedding representations) from the image.

```
img = processor(
    text=None,
    images=img,
    return_tensors='pt',
    padding=True
)['pixel_values'].to(device)
```

```
out = model.get_image_features(pixel_values=img)
out.shape
```

[output: torch.Size([1, 512])]

```
out = out.squeeze(0)
out.shape
```

[output: torch.Size([512])]

```
emb = out.cpu().detach().numpy()
emb.shape
```

[output: (512,)]

```
emb.min(), emb.max()
```

[output: (-7.985501, 2.0108054)]

Now we have a single `512` dimensional vector that represents the *meaning* behind the image. As we will be using dot product similarity we should also normalize these vectors.

```
import numpy as np

emb = emb / np.linalg.norm(emb)
```

```
emb.min(), emb.max()
```

[output: (-0.56626415, 0.13343191)]
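A brief note on why normalization matters here (an addition, not from the notebook): for unit-length vectors the dot product equals cosine similarity, because u·v = ||u|| ||v|| cos θ. Without normalization, longer vectors would rank higher under the "dotproduct" metric regardless of direction. A quick sanity check:

```
import numpy as np

u = np.random.randn(512)
v = np.random.randn(512)
u = u / np.linalg.norm(u)  # unit length
v = v / np.linalg.norm(v)

cosine = (u @ v) / (np.linalg.norm(u) * np.linalg.norm(v))
assert np.isclose(u @ v, cosine)  # dot product == cosine similarity
```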
## Indexing

To index this image in Pinecone we first install the Pinecone client:

```
!pip install pinecone-client
```

And then initialize our connection to Pinecone; this requires a [free API key](https://app.pinecone.io/).

```
import pinecone

index_name = "unsplash-25k-clip"

pinecone.init(
    api_key="<<API_KEY_HERE>>",
    environment="us-west1-gcp"
)

if index_name not in pinecone.list_indexes():
    pinecone.create_index(
        index_name,
        emb.shape[0],
        metric="dotproduct"
    )
# connect to the index
index = pinecone.Index(index_name)
```

To upsert the single feature embedding we have created, we use `upsert`. There is also some possibly relevant metadata we might want to add.

```
row = images.iloc[0]

_id = row['photo_id']
meta = {
    "photo_url": row["photo_url"],
    "photo_image_url": row["photo_image_url"],
    "photo_submitted_at": row["photo_submitted_at"],
    "photo_description": row["photo_description"],
    "photographer_username": row["photographer_username"],
    "ai_description": row["ai_description"]
}

meta
```

[output:
{'photo_url': 'https://unsplash.com/photos/XMyPniM9LF0',
 'photo_image_url': 'https://images.unsplash.com/uploads/14119492946973137ce46/f1f2ebf3',
 'photo_submitted_at': '2014-09-29 00:08:38.594364',
 'photo_description': 'Woman exploring a forest',
 'photographer_username': 'michellespencer77',
 'ai_description': 'woman walking in the middle of forest'}]

```
to_upsert = [(_id, emb.tolist(), meta)]

index.upsert(to_upsert)
```

[output: {'upserted_count': 1}]

```
_id
```

[output: 'XMyPniM9LF0']

Note that we added a string ID value `"XMyPniM9LF0"` and also converted the feature embedding tensor to a flat list before adding to our Pinecone index.

## Indexing Everything

So far we've built one feature embedding and indexed it in Pinecone; now let's repeat the process for a lot of images.

We will do this in batches, taking 16 images at a time, embedding them with CLIP, and indexing them in Pinecone.

```
import numpy as np
```

```
from tqdm.auto import tqdm
batch_size = 16
images_len = len(images)

exceptions = []

# note: the range starts at 3088, resuming a partially completed run
for i in tqdm(range(3088, images_len, batch_size)):
    # select the batch start and end
    i_end = min(i + batch_size, images_len)
    # get batch
    batch = images.iloc[i:i_end]
    # retrieve URLs
    url_batch = batch['photo_image_url']
    # get images
    image_batch = []
    for url in url_batch:
        try:
            response = requests.get(url)
            img = Image.open(BytesIO(response.content))
            if img.mode in ['L', 'CMYK', 'RGBA']:
                # L is grayscale, CMYK uses alternative color channels
                img = img.convert('RGB')
            image_batch.append(img)
        except Exception as e:
            exceptions.append(("url", e))
    # process images and extract pytorch tensor pixel values
    try:
        image_batch = processor(
            text=None,
            images=image_batch,
            return_tensors='pt'
        )['pixel_values'].to(device)
        # feed tensors to model and extract the image features
        out = model.get_image_features(pixel_values=image_batch)
        out = out.squeeze(0)
        # move embeddings to CPU and convert to a numpy array
        embeds = out.cpu().detach().numpy()
        # normalize each vector to unit length and convert to list
        embeds = embeds / np.linalg.norm(embeds, axis=1, keepdims=True)
        embeds = embeds.tolist()
        # get ID values
        ids = batch['photo_id']
        # prep metadata
        metadata = batch[[
            "photo_url", "photo_image_url", "photo_submitted_at",
            "photo_description", "photographer_username", "ai_description"
        ]].fillna("").to_dict(orient="records")
        # zip all data together and upsert
        to_upsert = zip(ids, embeds, metadata)
        index.upsert(to_upsert)
    except Exception as e:
        exceptions.append(("process", e))
```

[output: progress bar (0/1370 batches) and a number of PIL DecompressionBombWarning messages for unusually large images]
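Because the loop records failures instead of raising, it helps to summarize `exceptions` after the run, for example to see how many were URL fetch failures versus batch processing failures:

```
from collections import Counter

# each entry is a ("url" | "process", exception) tuple from the loop above
print(Counter(kind for kind, _ in exceptions))
```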
---