# experiments with open_clip, templates, clustering, recursion
from concurrent.futures import ThreadPoolExecutor, as_completed
import json
import os
import time
import numpy as np
import requests
import torch
from clip_app_client import ClipAppClient
from clip_retrieval.clip_client import ClipClient, Modality
clip_retrieval_service_url = "https://knn.laion.ai/knn-service"
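# map local clip model names to the matching laion5B knn index names on the clip-retrieval service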
map_clip_to_clip_retreval = {
    "ViT-L/14": "laion5B-L-14",
    "open_clip:ViT-H-14": "laion5B-H-14",
    "open_clip:ViT-L-14": "laion5B-L-14",
}
def safe_url(url):
    import urllib.parse
    url = urllib.parse.quote(url, safe=':/')
    # if the url contains '.jpg', drop everything after the first occurrence
    # (handles urls with duplicate filenames or trailing query strings)
    if url.count('.jpg') > 0:
        url = url.split('.jpg')[0] + '.jpg'
    return url
def _safe_image_url_to_embedding(url, safe_return):
    # return safe_return (e.g. a zero embedding) if the image cannot be downloaded or decoded
    try:
        return app_client.image_url_to_embedding(url)
    except Exception:
        return safe_return
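# optional sketch: the ThreadPoolExecutor / as_completed imports above can be used to fetch
# image embeddings in parallel instead of one at a time. the helper name below is new and
# assumes app_client.image_url_to_embedding is safe to call from multiple threads.
def _image_urls_to_embeddings_parallel(urls, safe_return, max_workers=8):
    embeddings = [safe_return] * len(urls)
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_index = {executor.submit(_safe_image_url_to_embedding, url, safe_return): i
                           for i, url in enumerate(urls)}
        for future in as_completed(future_to_index):
            embeddings[future_to_index[future]] = future.result()
    return embeddings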
def mean_template(embeddings):
    template = torch.mean(embeddings, dim=0, keepdim=True)
    return template
def principal_component_analysis_template(embeddings):
    mean = torch.mean(embeddings, dim=0)
    embeddings_centered = embeddings - mean  # subtract the mean
    u, s, v = torch.svd(embeddings_centered)  # perform SVD
    # the first right singular vector (first column of v) is the first principal component;
    # it lives in embedding space, so it can be used as a template
    template = v[:, 0]
    return template
def clustering_templates(embeddings, n_clusters=5):
    from sklearn.cluster import KMeans
    kmeans = KMeans(n_clusters=n_clusters)
    embeddings_np = embeddings.numpy()  # convert to numpy
    clusters = kmeans.fit_predict(embeddings_np)
    templates = []
    for cluster in np.unique(clusters):
        cluster_mean = np.mean(embeddings_np[clusters == cluster], axis=0)
        templates.append(torch.from_numpy(cluster_mean))  # convert back to tensor
    return templates
# test_image_path = os.path.join(os.getcwd(), "images", "plant-001.png")
# test_image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "images", "plant-001.jpeg")
test_image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "images", "plant-002.jpeg")
# test_image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "images", "car-002.jpeg")
app_client = ClipAppClient()
clip_retrieval_client = ClipClient(
    url=clip_retrieval_service_url,
    indice_name=map_clip_to_clip_retreval[app_client.clip_model],
    # use_safety_model = False,
    # use_violence_detector = False,
    # use_mclip = False,
    # num_images = 300,
    # modality = Modality.TEXT,
)
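# embed the test image; this embedding stays fixed and everything retrieved below is scored against it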
preprocessed_image = app_client.preprocess_image(test_image_path)
preprocessed_image_embeddings = app_client.preprocessed_image_to_embedding(preprocessed_image)
print (f"embeddings: {preprocessed_image_embeddings.shape}")
template = preprocessed_image_embeddings
template = template / template.norm()
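# recursive refinement: each step queries the knn index with the current template, scores the
# returned captions and images against the original image embedding, clusters the combined
# embeddings, and builds the next template by pushing the image embedding away from the
# lower-scoring cluster directions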
for step_num in range(3):
print (f"\n\n---- Step {step_num} ----")
embedding_as_list = template[0].tolist()
results = clip_retrieval_client.query(embedding_input=embedding_as_list)
# get best matching labels
image_labels = [r['caption'] for r in results]
image_label_vectors = [app_client.text_to_embedding(label) for label in image_labels]
image_label_vectors = torch.cat(image_label_vectors, dim=0)
dot_product = torch.mm(image_label_vectors, preprocessed_image_embeddings.T)
similarity_image_label = [(float("{:.4f}".format(dot_product[i][0])), image_labels[i]) for i in range(len(image_labels))]
similarity_image_label.sort(reverse=True)
for similarity, image_label in similarity_image_label:
print (f"{similarity} {image_label}")
    # now do the same for images
    image_urls = [safe_url(r['url']) for r in results]
    image_vectors = [_safe_image_url_to_embedding(url, preprocessed_image_embeddings * 0) for url in image_urls]
    image_vectors = torch.cat(image_vectors, dim=0)
    dot_product = torch.mm(image_vectors, preprocessed_image_embeddings.T)
    similarity_image = [(float("{:.4f}".format(dot_product[i][0])), image_labels[i]) for i in range(len(image_labels))]
    similarity_image.sort(reverse=True)
    for similarity, image_label in similarity_image:
        print(f"{similarity} {image_label}")
    # remove images with low similarity, as these are images that failed to load
    # (filter on the unsorted dot products so indices still line up with image_vectors)
    image_vectors = torch.stack([image_vectors[i] for i in range(len(image_vectors)) if float(dot_product[i][0]) > 0.001], dim=0)
    # create templates using clustering
    print(f"create templates using clustering")
    merged_embeddings = torch.cat([image_label_vectors, image_vectors], dim=0)
    # merged_embeddings = image_label_vectors # only use labels
    # merged_embeddings = image_vectors # only use images
    clusters = clustering_templates(merged_embeddings, n_clusters=5)
    # convert from list to 2d matrix
    clusters = torch.stack(clusters, dim=0)
    # rank the cluster centroids by similarity to the original image embedding
    dot_product = torch.mm(clusters, preprocessed_image_embeddings.T)
    cluster_similarity = [(float("{:.4f}".format(dot_product[i][0])), i) for i in range(len(clusters))]
    cluster_similarity.sort(reverse=True)
    for similarity, idx in cluster_similarity:
        print(f"{similarity} {idx}")
    # template = highest scoring cluster
    # template = clusters[cluster_similarity[0][1]]
    # instead, keep the image embedding but push it away from the lower-ranked
    # cluster directions, then renormalize
    template = preprocessed_image_embeddings * (len(clusters) - 1)
    for i in range(1, len(clusters)):
        cluster = clusters[cluster_similarity[i][1]]
        normalized_cluster = cluster / cluster.norm()
        template -= normalized_cluster
    template = template / template.norm()
print("---")
print(f"seaching based on template")
results = clip_retrieval_client.query(embedding_input=template[0].tolist())
hints = ""
for result in results:
url = safe_url(result["url"])
similarty = float("{:.4f}".format(result["similarity"]))
title = result["caption"]
print (f"{similarty} \"{title}\" {url}")
if len(hints) > 0:
hints += f", \"{title}\""
else:
hints += f"\"{title}\""
print(hints)
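# ---- earlier experiments, kept for reference ----
# (per-cluster queries, and mean / PCA / clustering templates built from the label and
#  image embeddings, with the template subtracted from the image embedding and vice versa)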
# cluster_num = 1
# for template in clusters:
# print("---")
# print(f"cluster {cluster_num} of {len(clusters)}")
# results = clip_retrieval_client.query(embedding_input=template.tolist())
# hints = ""
# for result in results:
# url = safe_url(result["url"])
# similarty = float("{:.4f}".format(result["similarity"]))
# title = result["caption"]
# print (f"{similarty} \"{title}\" {url}")
# if len(hints) > 0:
# hints += f", \"{title}\""
# else:
# hints += f"\"{title}\""
# print(hints)
# cluster_num += 1
# create a template
# mean
# image_label_template = mean_template(image_label_vectors)
# image_template = mean_template(image_vectors)
# pca
# image_label_template = principal_component_analysis_template(image_label_vectors)
# image_template = principal_component_analysis_template(image_vectors)
# clustering
# image_label_template = clustering_template(image_label_vectors)
# image_template = clustering_template(image_vectors)
# take the embedding and subtract the template
# image_label_template = preprocessed_image_embeddings - image_label_template
# image_template = preprocessed_image_embeddings - image_template
# image_label_template = image_label_template - preprocessed_image_embeddings
# image_template = image_template - preprocessed_image_embeddings
# normalize
# image_label_template = image_label_template / image_label_template.norm()
# image_template = image_template / image_template.norm()
# results = clip_retrieval_client.query(embedding_input=image_label_template[0].tolist())
# hints = ""
# print("---")
# print("average of image labels")
# for result in results:
# url = safe_url(result["url"])
# similarty = float("{:.4f}".format(result["similarity"]))
# title = result["caption"]
# print (f"{similarty} \"{title}\" {url}")
# if len(hints) > 0:
# hints += f", \"{title}\""
# else:
# hints += f"\"{title}\""
# print(hints)
# print("---")
# print("average of images")
# results = clip_retrieval_client.query(embedding_input=image_template[0].tolist())
# hints = ""
# for result in results:
# url = safe_url(result["url"])
# similarty = float("{:.4f}".format(result["similarity"]))
# title = result["caption"]
# print (f"{similarty} \"{title}\" {url}")
# if len(hints) > 0:
# hints += f", \"{title}\""
# else:
# hints += f"\"{title}\""
# print(hints)