"""Iteratively refine a CLIP image embedding into a search "template".

Each step queries the LAION knn service with the current template, embeds the
returned captions and images, clusters them, and subtracts the off-target
clusters from the image embedding before searching again.
"""
from concurrent.futures import ThreadPoolExecutor, as_completed
import json
import os
import time

import numpy as np
import requests
import torch

from clip_app_client import ClipAppClient
from clip_retrieval.clip_client import ClipClient, Modality

clip_retrieval_service_url = "https://knn.laion.ai/knn-service"

# map the local CLIP model name to the matching laion5B index
map_clip_to_clip_retreval = {
    "ViT-L/14": "laion5B-L-14",
    "open_clip:ViT-H-14": "laion5B-H-14",
    "open_clip:ViT-L-14": "laion5B-L-14",
}


def safe_url(url):
    import urllib.parse
    url = urllib.parse.quote(url, safe=':/')
    # if url has two .jpg filenames, take the first one
    if url.count('.jpg') > 0:
        url = url.split('.jpg')[0] + '.jpg'
    return url


def _safe_image_url_to_embedding(url, safe_return):
    # images that fail to download get the provided fallback (a zero vector)
    try:
        return app_client.image_url_to_embedding(url)
    except Exception:
        return safe_return


def mean_template(embeddings):
    template = torch.mean(embeddings, dim=0, keepdim=True)
    return template


def principal_component_analysis_template(embeddings):
    mean = torch.mean(embeddings, dim=0)
    embeddings_centered = embeddings - mean  # Subtract the mean
    u, s, v = torch.svd(embeddings_centered)  # Perform SVD
    template = v[:, 0]  # The first column of v is the first principal direction in embedding space
    return template


def clustering_templates(embeddings, n_clusters=5):
    from sklearn.cluster import KMeans

    kmeans = KMeans(n_clusters=n_clusters)
    embeddings_np = embeddings.numpy()  # Convert to numpy
    clusters = kmeans.fit_predict(embeddings_np)

    templates = []
    for cluster in np.unique(clusters):
        cluster_mean = np.mean(embeddings_np[clusters == cluster], axis=0)
        templates.append(torch.from_numpy(cluster_mean))  # Convert back to tensor
    return templates


# test_image_path = os.path.join(os.getcwd(), "images", "plant-001.png")
# test_image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "images", "plant-001.jpeg")
test_image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "images", "plant-002.jpeg")
# test_image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "images", "car-002.jpeg")

app_client = ClipAppClient()
clip_retrieval_client = ClipClient(
    url=clip_retrieval_service_url,
    indice_name=map_clip_to_clip_retreval[app_client.clip_model],
    # use_safety_model=False,
    # use_violence_detector=False,
    # use_mclip=False,
    # num_images=300,
    # modality=Modality.TEXT,
)

preprocessed_image = app_client.preprocess_image(test_image_path)
preprocessed_image_embeddings = app_client.preprocessed_image_to_embedding(preprocessed_image)
print(f"embeddings: {preprocessed_image_embeddings.shape}")

# start from the (normalized) image embedding itself
template = preprocessed_image_embeddings
template = template / template.norm()

for step_num in range(3):
    print(f"\n\n---- Step {step_num} ----")
    embedding_as_list = template[0].tolist()
    results = clip_retrieval_client.query(embedding_input=embedding_as_list)

    # get best matching labels
    image_labels = [r['caption'] for r in results]
    image_label_vectors = [app_client.text_to_embedding(label) for label in image_labels]
    image_label_vectors = torch.cat(image_label_vectors, dim=0)
    dot_product = torch.mm(image_label_vectors, preprocessed_image_embeddings.T)
    similarity_image_label = [(float("{:.4f}".format(dot_product[i][0])), image_labels[i]) for i in range(len(image_labels))]
    similarity_image_label.sort(reverse=True)
    for similarity, image_label in similarity_image_label:
        print(f"{similarity} {image_label}")

    # now do the same for images
    image_urls = [safe_url(r['url']) for r in results]
    image_vectors = [_safe_image_url_to_embedding(url, preprocessed_image_embeddings * 0) for url in image_urls]
    image_vectors = torch.cat(image_vectors, dim=0)
    dot_product = torch.mm(image_vectors, preprocessed_image_embeddings.T)
    similarity_image = [(float("{:.4f}".format(dot_product[i][0])), image_labels[i]) for i in range(len(image_labels))]
    similarity_image.sort(reverse=True)
    for similarity, image_label in similarity_image:
        print(f"{similarity} {image_label}")

    # remove images with low similarity as these will be images that did not load
    # (filter on the unsorted dot products so indices still line up with image_vectors)
    image_vectors = torch.stack([image_vectors[i] for i in range(len(image_vectors)) if dot_product[i][0] > 0.001], dim=0)

    # create templates using clustering
    print("create templates using clustering")
    merged_embeddings = torch.cat([image_label_vectors, image_vectors], dim=0)
    # merged_embeddings = image_label_vectors  # only use labels
    # merged_embeddings = image_vectors  # only use images
    clusters = clustering_templates(merged_embeddings, n_clusters=5)
    # convert from list to 2d matrix
    clusters = torch.stack(clusters, dim=0)
    dot_product = torch.mm(clusters, preprocessed_image_embeddings.T)
    cluster_similarity = [(float("{:.4f}".format(dot_product[i][0])), i) for i in range(len(clusters))]
    cluster_similarity.sort(reverse=True)
    for similarity, idx in cluster_similarity:
        print(f"{similarity} {idx}")

    # template = highest scoring cluster
    # template = clusters[cluster_similarity[0][1]]
    # scale up the image embedding and subtract every cluster except the best-matching one
    template = preprocessed_image_embeddings * (len(clusters) - 1)
    for i in range(1, len(clusters)):
        cluster = clusters[cluster_similarity[i][1]]
        normalized_cluster = cluster / cluster.norm()
        template -= normalized_cluster
    template = template / template.norm()

    print("---")
    print("searching based on template")
    results = clip_retrieval_client.query(embedding_input=template[0].tolist())
    hints = ""
    for result in results:
        url = safe_url(result["url"])
        similarity = float("{:.4f}".format(result["similarity"]))
        title = result["caption"]
        print(f"{similarity} \"{title}\" {url}")
        if len(hints) > 0:
            hints += f", \"{title}\""
        else:
            hints += f"\"{title}\""
    print(hints)

    # cluster_num = 1
    # for template in clusters:
    #     print("---")
    #     print(f"cluster {cluster_num} of {len(clusters)}")
    #     results = clip_retrieval_client.query(embedding_input=template.tolist())
    #     hints = ""
    #     for result in results:
    #         url = safe_url(result["url"])
    #         similarity = float("{:.4f}".format(result["similarity"]))
    #         title = result["caption"]
    #         print(f"{similarity} \"{title}\" {url}")
    #         if len(hints) > 0:
    #             hints += f", \"{title}\""
    #         else:
    #             hints += f"\"{title}\""
    #     print(hints)
    #     cluster_num += 1

    # create a template
    # mean
    # image_label_template = mean_template(image_label_vectors)
    # image_template = mean_template(image_vectors)

    # pca
    # image_label_template = principal_component_analysis_template(image_label_vectors)
    # image_template = principal_component_analysis_template(image_vectors)

    # clustering
    # image_label_template = clustering_templates(image_label_vectors)
    # image_template = clustering_templates(image_vectors)

    # take the embedding and subtract the template
    # image_label_template = preprocessed_image_embeddings - image_label_template
    # image_template = preprocessed_image_embeddings - image_template
    # image_label_template = image_label_template - preprocessed_image_embeddings
    # image_template = image_template - preprocessed_image_embeddings

    # normalize
    # image_label_template = image_label_template / image_label_template.norm()
    # image_template = image_template / image_template.norm()

    # results = clip_retrieval_client.query(embedding_input=image_label_template[0].tolist())
    # hints = ""
    # print("---")
    # print("average of image labels")
    # for result in results:
    #     url = safe_url(result["url"])
    #     similarity = float("{:.4f}".format(result["similarity"]))
    #     title = result["caption"]
    #     print(f"{similarity} \"{title}\" {url}")
    #     if len(hints) > 0:
    #         hints += f", \"{title}\""
    #     else:
    #         hints += f"\"{title}\""
    # print(hints)

    # print("---")
    # print("average of images")
    # results = clip_retrieval_client.query(embedding_input=image_template[0].tolist())
    # hints = ""
    # for result in results:
    #     url = safe_url(result["url"])
    #     similarity = float("{:.4f}".format(result["similarity"]))
    #     title = result["caption"]
    #     print(f"{similarity} \"{title}\" {url}")
    #     if len(hints) > 0:
    #         hints += f", \"{title}\""
    #     else:
    #         hints += f"\"{title}\""
    # print(hints)