sohojoe commited on
Commit
0441b41
·
1 Parent(s): 7bef5db

vision experiments

Browse files
experimental/images/plant-001.jpeg ADDED
experimental/images/plant-002.jpeg ADDED
experimental/vision001.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from concurrent.futures import ThreadPoolExecutor, as_completed
2
+ import json
3
+ import os
4
+ import time
5
+
6
+ import numpy as np
7
+ import requests
8
+ import torch
9
+
10
+ from clip_app_client import ClipAppClient
11
+ from clip_retrieval.clip_client import ClipClient, Modality
12
+ clip_retrieval_service_url = "https://knn.laion.ai/knn-service"
13
+ map_clip_to_clip_retreval = {
14
+ "ViT-L/14": "laion5B-L-14",
15
+ }
16
+
17
+
18
+ def safe_url(url):
19
+ import urllib.parse
20
+ url = urllib.parse.quote(url, safe=':/')
21
+ # if url has two .jpg filenames, take the first one
22
+ if url.count('.jpg') > 0:
23
+ url = url.split('.jpg')[0] + '.jpg'
24
+ return url
25
+
26
+ # test_image_path = os.path.join(os.getcwd(), "images", "plant-001.png")
27
+ test_image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "images", "plant-001.jpeg")
28
+ # test_image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "images", "plant-002.jpeg")
29
+ # test_image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "images", "plant-002.jpeg")
30
+ # test_image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "images", "car-002.jpeg")
31
+
32
+ app_client = ClipAppClient()
33
+ clip_retrieval_client = ClipClient(
34
+ url=clip_retrieval_service_url,
35
+ indice_name=map_clip_to_clip_retreval[app_client.clip_model],
36
+ # use_safety_model = False,
37
+ # use_violence_detector = False,
38
+ # use_mclip = False,
39
+ num_images = 300,
40
+ # modality = Modality.TEXT,
41
+ # modality = Modality.TEXT,
42
+ )
43
+ preprocessed_image = app_client.preprocess_image(test_image_path)
44
+ preprocessed_image_embeddings = app_client.preprocessed_image_to_embedding(preprocessed_image)
45
+ print (f"embeddings: {preprocessed_image_embeddings.shape}")
46
+
47
+ embedding_as_list = preprocessed_image_embeddings[0].tolist()
48
+ results = clip_retrieval_client.query(embedding_input=embedding_as_list)
49
+
50
+ # hints = ""
51
+ # for result in results:
52
+ # url = safe_url(result["url"])
53
+ # similarty = float("{:.4f}".format(result["similarity"]))
54
+ # title = result["caption"]
55
+ # print (f"{similarty} \"{title}\" {url}")
56
+ # if len(hints) > 0:
57
+ # hints += f", \"{title}\""
58
+ # else:
59
+ # hints += f"\"{title}\""
60
+ # print("---")
61
+ # print(hints)
62
+
63
+ image_labels = [r['caption'] for r in results]
64
+ image_label_vectors = [app_client.text_to_embedding(label) for label in image_labels]
65
+ image_label_vectors = torch.cat(image_label_vectors, dim=0)
66
+ dot_product = torch.mm(image_label_vectors, preprocessed_image_embeddings.T)
67
+ similarity_image_label = [(float("{:.4f}".format(dot_product[i][0])), image_labels[i]) for i in range(len(image_labels))]
68
+ similarity_image_label.sort(reverse=True)
69
+ for similarity, image_label in similarity_image_label:
70
+ print (f"{similarity} {image_label}")
71
+
72
+ print (f"----\n")
73
+
74
+ # now do the same for images
75
+ def _safe_image_url_to_embedding(url, safe_return):
76
+ try:
77
+ return app_client.image_url_to_embedding(url)
78
+ except:
79
+ return safe_return
80
+ image_urls = [safe_url(r['url']) for r in results]
81
+ image_vectors = [_safe_image_url_to_embedding(url, preprocessed_image_embeddings * 0) for url in image_urls]
82
+ image_vectors = torch.cat(image_vectors, dim=0)
83
+ dot_product = torch.mm(image_vectors, preprocessed_image_embeddings.T)
84
+ similarity_image = [(float("{:.4f}".format(dot_product[i][0])), image_labels[i]) for i in range(len(image_labels))]
85
+ similarity_image.sort(reverse=True)
86
+ for similarity, image_label in similarity_image:
87
+ print (f"{similarity} {image_label}")
88
+
89
+ def mean_template(embeddings):
90
+ template = torch.mean(embeddings, dim=0, keepdim=True)
91
+ return template
92
+
93
+ def principal_component_analysis_template(embeddings):
94
+ mean = torch.mean(embeddings, dim=0)
95
+ embeddings_centered = embeddings - mean # Subtract the mean
96
+ u, s, v = torch.svd(embeddings_centered) # Perform SVD
97
+ template = u[:, 0] # The first column of u gives the first principal component
98
+ return template
99
+
100
+ def clustering_templates(embeddings, n_clusters=5):
101
+ from sklearn.cluster import KMeans
102
+ import numpy as np
103
+
104
+ kmeans = KMeans(n_clusters=n_clusters)
105
+ embeddings_np = embeddings.numpy() # Convert to numpy
106
+ clusters = kmeans.fit_predict(embeddings_np)
107
+
108
+ templates = []
109
+ for cluster in np.unique(clusters):
110
+ cluster_mean = np.mean(embeddings_np[clusters == cluster], axis=0)
111
+ templates.append(torch.from_numpy(cluster_mean)) # Convert back to tensor
112
+ return templates
113
+
114
+ # create a templates using clustering
115
+ print(f"create a templates using clustering")
116
+ merged_embeddings = torch.cat([image_label_vectors, image_vectors], dim=0)
117
+ clusters = clustering_templates(merged_embeddings, n_clusters=5)
118
+ # convert from list to 2d matrix
119
+ clusters = torch.stack(clusters, dim=0)
120
+ dot_product = torch.mm(clusters, preprocessed_image_embeddings.T)
121
+ cluster_similarity = [(float("{:.4f}".format(dot_product[i][0])), i) for i in range(len(clusters))]
122
+ cluster_similarity.sort(reverse=True)
123
+ for similarity, idx in cluster_similarity:
124
+ print (f"{similarity} {idx}")
125
+ # template = highest scoring cluster
126
+ # template = clusters[cluster_similarity[0][1]]
127
+ template = preprocessed_image_embeddings * (len(clusters)-1)
128
+ for i in range(1, len(clusters)):
129
+ template -= clusters[cluster_similarity[i][1]]
130
+ print("---")
131
+ print(f"seaching based on template")
132
+ results = clip_retrieval_client.query(embedding_input=template[0].tolist())
133
+ hints = ""
134
+ for result in results:
135
+ url = safe_url(result["url"])
136
+ similarty = float("{:.4f}".format(result["similarity"]))
137
+ title = result["caption"]
138
+ print (f"{similarty} \"{title}\" {url}")
139
+ if len(hints) > 0:
140
+ hints += f", \"{title}\""
141
+ else:
142
+ hints += f"\"{title}\""
143
+ print(hints)
144
+
145
+
146
+ # cluster_num = 1
147
+ # for template in clusters:
148
+ # print("---")
149
+ # print(f"cluster {cluster_num} of {len(clusters)}")
150
+ # results = clip_retrieval_client.query(embedding_input=template.tolist())
151
+ # hints = ""
152
+ # for result in results:
153
+ # url = safe_url(result["url"])
154
+ # similarty = float("{:.4f}".format(result["similarity"]))
155
+ # title = result["caption"]
156
+ # print (f"{similarty} \"{title}\" {url}")
157
+ # if len(hints) > 0:
158
+ # hints += f", \"{title}\""
159
+ # else:
160
+ # hints += f"\"{title}\""
161
+ # print(hints)
162
+ # cluster_num += 1
163
+
164
+
165
+ # create a template
166
+ # mean
167
+ # image_label_template = mean_template(image_label_vectors)
168
+ # image_template = mean_template(image_vectors)
169
+ # pca
170
+ # image_label_template = principal_component_analysis_template(image_label_vectors)
171
+ # image_template = principal_component_analysis_template(image_vectors)
172
+ # clustering
173
+ # image_label_template = clustering_template(image_label_vectors)
174
+ # image_template = clustering_template(image_vectors)
175
+
176
+ # take the embedding and subtract the template
177
+ # image_label_template = preprocessed_image_embeddings - image_label_template
178
+ # image_template = preprocessed_image_embeddings - image_template
179
+ # image_label_template = image_label_template - preprocessed_image_embeddings
180
+ # image_template = image_template - preprocessed_image_embeddings
181
+ # normalize
182
+ # image_label_template = image_label_template / image_label_template.norm()
183
+ # image_template = image_template / image_template.norm()
184
+
185
+ # results = clip_retrieval_client.query(embedding_input=image_label_template[0].tolist())
186
+ # hints = ""
187
+ # print("---")
188
+ # print("average of image labels")
189
+ # for result in results:
190
+ # url = safe_url(result["url"])
191
+ # similarty = float("{:.4f}".format(result["similarity"]))
192
+ # title = result["caption"]
193
+ # print (f"{similarty} \"{title}\" {url}")
194
+ # if len(hints) > 0:
195
+ # hints += f", \"{title}\""
196
+ # else:
197
+ # hints += f"\"{title}\""
198
+ # print(hints)
199
+
200
+ # print("---")
201
+ # print("average of images")
202
+ # results = clip_retrieval_client.query(embedding_input=image_template[0].tolist())
203
+ # hints = ""
204
+ # for result in results:
205
+ # url = safe_url(result["url"])
206
+ # similarty = float("{:.4f}".format(result["similarity"]))
207
+ # title = result["caption"]
208
+ # print (f"{similarty} \"{title}\" {url}")
209
+ # if len(hints) > 0:
210
+ # hints += f", \"{title}\""
211
+ # else:
212
+ # hints += f"\"{title}\""
213
+ # print(hints)
experimental/vision002.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from concurrent.futures import ThreadPoolExecutor, as_completed
2
+ import json
3
+ import os
4
+ import time
5
+
6
+ import numpy as np
7
+ import requests
8
+ import torch
9
+
10
+ from clip_app_client import ClipAppClient
11
+ from clip_retrieval.clip_client import ClipClient, Modality
12
+ clip_retrieval_service_url = "https://knn.laion.ai/knn-service"
13
+ map_clip_to_clip_retreval = {
14
+ "ViT-L/14": "laion5B-L-14",
15
+ }
16
+
17
+
18
+ def safe_url(url):
19
+ import urllib.parse
20
+ url = urllib.parse.quote(url, safe=':/')
21
+ # if url has two .jpg filenames, take the first one
22
+ if url.count('.jpg') > 0:
23
+ url = url.split('.jpg')[0] + '.jpg'
24
+ return url
25
+
26
+ def _safe_image_url_to_embedding(url, safe_return):
27
+ try:
28
+ return app_client.image_url_to_embedding(url)
29
+ except:
30
+ return safe_return
31
+
32
+ def mean_template(embeddings):
33
+ template = torch.mean(embeddings, dim=0, keepdim=True)
34
+ return template
35
+
36
+ def principal_component_analysis_template(embeddings):
37
+ mean = torch.mean(embeddings, dim=0)
38
+ embeddings_centered = embeddings - mean # Subtract the mean
39
+ u, s, v = torch.svd(embeddings_centered) # Perform SVD
40
+ template = u[:, 0] # The first column of u gives the first principal component
41
+ return template
42
+
43
+ def clustering_templates(embeddings, n_clusters=5):
44
+ from sklearn.cluster import KMeans
45
+ import numpy as np
46
+
47
+ kmeans = KMeans(n_clusters=n_clusters)
48
+ embeddings_np = embeddings.numpy() # Convert to numpy
49
+ clusters = kmeans.fit_predict(embeddings_np)
50
+
51
+ templates = []
52
+ for cluster in np.unique(clusters):
53
+ cluster_mean = np.mean(embeddings_np[clusters == cluster], axis=0)
54
+ templates.append(torch.from_numpy(cluster_mean)) # Convert back to tensor
55
+ return templates
56
+
57
+ # test_image_path = os.path.join(os.getcwd(), "images", "plant-001.png")
58
+ test_image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "images", "plant-001.jpeg")
59
+ # test_image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "images", "plant-002.jpeg")
60
+ # test_image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "images", "plant-002.jpeg")
61
+ # test_image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "images", "car-002.jpeg")
62
+
63
+ app_client = ClipAppClient()
64
+ clip_retrieval_client = ClipClient(
65
+ url=clip_retrieval_service_url,
66
+ indice_name=map_clip_to_clip_retreval[app_client.clip_model],
67
+ # use_safety_model = False,
68
+ # use_violence_detector = False,
69
+ # use_mclip = False,
70
+ # num_images = 300,
71
+ # modality = Modality.TEXT,
72
+ # modality = Modality.TEXT,
73
+ )
74
+ preprocessed_image = app_client.preprocess_image(test_image_path)
75
+ preprocessed_image_embeddings = app_client.preprocessed_image_to_embedding(preprocessed_image)
76
+
77
+ print (f"embeddings: {preprocessed_image_embeddings.shape}")
78
+
79
+
80
+ template = preprocessed_image_embeddings
81
+ for step_num in range(3):
82
+ print (f"\n\n---- Step {step_num} ----")
83
+
84
+ embedding_as_list = template[0].tolist()
85
+ results = clip_retrieval_client.query(embedding_input=embedding_as_list)
86
+
87
+ # get best matching labels
88
+ image_labels = [r['caption'] for r in results]
89
+ image_label_vectors = [app_client.text_to_embedding(label) for label in image_labels]
90
+ image_label_vectors = torch.cat(image_label_vectors, dim=0)
91
+ dot_product = torch.mm(image_label_vectors, preprocessed_image_embeddings.T)
92
+ similarity_image_label = [(float("{:.4f}".format(dot_product[i][0])), image_labels[i]) for i in range(len(image_labels))]
93
+ similarity_image_label.sort(reverse=True)
94
+ for similarity, image_label in similarity_image_label:
95
+ print (f"{similarity} {image_label}")
96
+
97
+ # now do the same for images
98
+ image_urls = [safe_url(r['url']) for r in results]
99
+ image_vectors = [_safe_image_url_to_embedding(url, preprocessed_image_embeddings * 0) for url in image_urls]
100
+ image_vectors = torch.cat(image_vectors, dim=0)
101
+ dot_product = torch.mm(image_vectors, preprocessed_image_embeddings.T)
102
+ similarity_image = [(float("{:.4f}".format(dot_product[i][0])), image_labels[i]) for i in range(len(image_labels))]
103
+ similarity_image.sort(reverse=True)
104
+ for similarity, image_label in similarity_image:
105
+ print (f"{similarity} {image_label}")
106
+ # remove images with low similarity as these will be images that did not load
107
+ image_vectors = torch.stack([image_vectors[i] for i in range(len(image_vectors)) if similarity_image[i][0] > 0.001], dim=0)
108
+
109
+ # create a templates using clustering
110
+ print(f"create a templates using clustering")
111
+ merged_embeddings = torch.cat([image_label_vectors, image_vectors], dim=0)
112
+ # merged_embeddings = image_label_vectors # only use labels
113
+ # merged_embeddings = image_vectors # only use images
114
+ clusters = clustering_templates(merged_embeddings, n_clusters=5)
115
+ # convert from list to 2d matrix
116
+ clusters = torch.stack(clusters, dim=0)
117
+ dot_product = torch.mm(clusters, preprocessed_image_embeddings.T)
118
+ cluster_similarity = [(float("{:.4f}".format(dot_product[i][0])), i) for i in range(len(clusters))]
119
+ cluster_similarity.sort(reverse=True)
120
+ for similarity, idx in cluster_similarity:
121
+ print (f"{similarity} {idx}")
122
+ # template = highest scoring cluster
123
+ # template = clusters[cluster_similarity[0][1]]
124
+ template = preprocessed_image_embeddings * (len(clusters)-1)
125
+ for i in range(1, len(clusters)):
126
+ template -= clusters[cluster_similarity[i][1]]
127
+ print("---")
128
+ print(f"seaching based on template")
129
+ results = clip_retrieval_client.query(embedding_input=template[0].tolist())
130
+ hints = ""
131
+ for result in results:
132
+ url = safe_url(result["url"])
133
+ similarty = float("{:.4f}".format(result["similarity"]))
134
+ title = result["caption"]
135
+ print (f"{similarty} \"{title}\" {url}")
136
+ if len(hints) > 0:
137
+ hints += f", \"{title}\""
138
+ else:
139
+ hints += f"\"{title}\""
140
+ print(hints)
141
+
142
+
143
+ # cluster_num = 1
144
+ # for template in clusters:
145
+ # print("---")
146
+ # print(f"cluster {cluster_num} of {len(clusters)}")
147
+ # results = clip_retrieval_client.query(embedding_input=template.tolist())
148
+ # hints = ""
149
+ # for result in results:
150
+ # url = safe_url(result["url"])
151
+ # similarty = float("{:.4f}".format(result["similarity"]))
152
+ # title = result["caption"]
153
+ # print (f"{similarty} \"{title}\" {url}")
154
+ # if len(hints) > 0:
155
+ # hints += f", \"{title}\""
156
+ # else:
157
+ # hints += f"\"{title}\""
158
+ # print(hints)
159
+ # cluster_num += 1
160
+
161
+
162
+ # create a template
163
+ # mean
164
+ # image_label_template = mean_template(image_label_vectors)
165
+ # image_template = mean_template(image_vectors)
166
+ # pca
167
+ # image_label_template = principal_component_analysis_template(image_label_vectors)
168
+ # image_template = principal_component_analysis_template(image_vectors)
169
+ # clustering
170
+ # image_label_template = clustering_template(image_label_vectors)
171
+ # image_template = clustering_template(image_vectors)
172
+
173
+ # take the embedding and subtract the template
174
+ # image_label_template = preprocessed_image_embeddings - image_label_template
175
+ # image_template = preprocessed_image_embeddings - image_template
176
+ # image_label_template = image_label_template - preprocessed_image_embeddings
177
+ # image_template = image_template - preprocessed_image_embeddings
178
+ # normalize
179
+ # image_label_template = image_label_template / image_label_template.norm()
180
+ # image_template = image_template / image_template.norm()
181
+
182
+ # results = clip_retrieval_client.query(embedding_input=image_label_template[0].tolist())
183
+ # hints = ""
184
+ # print("---")
185
+ # print("average of image labels")
186
+ # for result in results:
187
+ # url = safe_url(result["url"])
188
+ # similarty = float("{:.4f}".format(result["similarity"]))
189
+ # title = result["caption"]
190
+ # print (f"{similarty} \"{title}\" {url}")
191
+ # if len(hints) > 0:
192
+ # hints += f", \"{title}\""
193
+ # else:
194
+ # hints += f"\"{title}\""
195
+ # print(hints)
196
+
197
+ # print("---")
198
+ # print("average of images")
199
+ # results = clip_retrieval_client.query(embedding_input=image_template[0].tolist())
200
+ # hints = ""
201
+ # for result in results:
202
+ # url = safe_url(result["url"])
203
+ # similarty = float("{:.4f}".format(result["similarity"]))
204
+ # title = result["caption"]
205
+ # print (f"{similarty} \"{title}\" {url}")
206
+ # if len(hints) > 0:
207
+ # hints += f", \"{title}\""
208
+ # else:
209
+ # hints += f"\"{title}\""
210
+ # print(hints)