|
from os import path |
|
from IPython.display import display |
|
from umap import UMAP |
|
from sklearn.preprocessing import MinMaxScaler |
|
import pandas as pd |
|
from tqdm import tqdm |
|
import matplotlib.pyplot as plt |
|
import seaborn as sns |
|
from s3_data_to_vector_embedding import bt_embeddings_from_local |
|
import random |
|
import numpy as np |
|
import torch |
|
from sklearn.model_selection import train_test_split |
|
from datasets import load_dataset |
|
|
|
|
|
# Caption templates; each one contains a single `{}` placeholder.
templates = [
    "a picture of {}",
    "an image of {}",
    "a nice {}",
    "a beautiful {}",
]
|
|
|
def data_prep(hf_dataset_name, templates=templates, test_size=1000, class_name=None):
    """Build caption/image pairs from a random sample of a Hugging Face dataset.

    Loads ``hf_dataset_name``, takes ``test_size`` rows out of its ``'train'``
    split via ``train_test_split`` and pairs the ``'image'`` field of every
    sampled row with a randomly chosen caption template.

    Args:
        hf_dataset_name: Dataset identifier passed to ``load_dataset``.
        templates: Non-empty sequence of caption templates.
        test_size: Forwarded unchanged to ``train_test_split``.
        class_name: Optional text substituted for the ``{}`` placeholder of
            the chosen template. When ``None`` (the default) the template is
            used verbatim, i.e. the placeholder stays in the caption.

    Returns:
        A list of dicts with the keys ``'caption'`` (str) and ``'pil_img'``
        (the row's ``'image'`` value; presumably a PIL image, depending on the
        dataset's feature type).

    Raises:
        ValueError: If ``templates`` is empty.
    """
    if not templates:
        raise ValueError("templates must contain at least one caption template")

    dataset = load_dataset(hf_dataset_name)

    # train_test_split serves only as a random sampler here: the 'test' part
    # is the sample, the remaining 'train' part is discarded.
    test_dataset = dataset['train'].train_test_split(test_size=test_size)['test']
    print(test_dataset)

    img_txt_pairs = []
    for row in test_dataset:
        caption = random.choice(templates)
        if class_name is not None:
            caption = caption.format(class_name)
        img_txt_pairs.append({'caption': caption, 'pil_img': row['image']})
    return img_txt_pairs


def load_all_dataset():
    """Sample 50 caption/image pairs each from the car and the cat dataset.

    The captions have their ``{}`` placeholder filled with ``"car"`` and
    ``"cat"`` respectively.

    Returns:
        A tuple ``(cat_img_txt_pairs, car_img_txt_pairs)`` -- cats first, even
        though the car dataset is loaded first.
    """
    car_img_txt_pairs = data_prep(
        "tanganke/stanford_cars", test_size=50, class_name="car"
    )
    cat_img_txt_pairs = data_prep(
        "yashikota/cat-image-dataset", test_size=50, class_name="cat"
    )
    return cat_img_txt_pairs, car_img_txt_pairs
|
|
|
def load_cat_and_car_embeddings():
    """Compute cross-modal embeddings for the sampled cat and car pairs.

    Every caption/image pair returned by ``load_all_dataset`` is passed to
    ``bt_embeddings_from_local``; the first entry of the result's
    ``'cross_modal_embeddings'`` field is converted to NumPy and all entries of
    one class are stacked along a new leading axis.

    NOTE: embeddings are recomputed on every call; nothing is cached on disk.

    Returns:
        A tuple ``(cat_embeddings, car_embeddings)`` of NumPy arrays whose
        first axis indexes the caption/image pairs.

    Raises:
        ValueError: Raised by ``np.stack`` when a class yields no pairs or the
            per-pair embeddings differ in shape.
    """
    cat_img_txt_pairs, car_img_txt_pairs = load_all_dataset()

    def embed_pairs(img_txt_pairs):
        # One embedding per pair, collected and stacked into a single array.
        vectors = []
        for pair in tqdm(img_txt_pairs, total=len(img_txt_pairs)):
            result = bt_embeddings_from_local(pair['caption'], pair['pil_img'])
            # Presumably a torch tensor with a leading batch axis of size 1;
            # .cpu() keeps the conversion working if it lives on an accelerator.
            tensor = result['cross_modal_embeddings'][0]
            vectors.append(tensor.detach().cpu().numpy())
        return np.stack(vectors)

    cat_embeddings = embed_pairs(cat_img_txt_pairs)
    car_embeddings = embed_pairs(car_img_txt_pairs)
    return cat_embeddings, car_embeddings


def dimensionality_reduction(embeddings, labels):
    """Project embeddings to two dimensions with UMAP and attach their labels.

    Args:
        embeddings: Array-like whose first axis indexes the samples. Any
            trailing axes are flattened into one feature vector per sample, so
            a 1-D input is treated as samples with a single feature each.
        labels: One label per sample (or a scalar, which pandas broadcasts).

    Returns:
        A ``pandas.DataFrame`` with the UMAP coordinates in columns ``"X"`` and
        ``"Y"`` and the given labels in column ``"label"``.

    Raises:
        ValueError: If ``labels`` is a sequence whose length differs from the
            number of samples.
    """
    embeddings = np.asarray(embeddings)
    n_samples = len(embeddings)

    # Fail before the comparatively expensive UMAP fit instead of after it.
    if np.ndim(labels) > 0 and len(labels) != n_samples:
        raise ValueError(
            f"got {len(labels)} labels for {n_samples} embeddings"
        )

    # Rows are samples, columns are features; each feature is scaled to [0, 1].
    features = embeddings.reshape(n_samples, -1)
    X_scaled = MinMaxScaler().fit_transform(features)

    mapper = UMAP(n_components=2, metric="cosine").fit(X_scaled)
    df_emb = pd.DataFrame(mapper.embedding_, columns=["X", "Y"])
    df_emb["label"] = labels
    print(df_emb)
    return df_emb
|
|
|
def show_umap_visualization():
    """Embed the cat and car samples, reduce them with UMAP and plot the result.

    Calls ``load_cat_and_car_embeddings``, concatenates both classes along the
    first axis, labels every entry ``'cat'`` or ``'car'`` and shows a scatter
    plot of the two-dimensional projection coloured by label.
    """
    cat_embeddings, car_embeddings = load_cat_and_car_embeddings()
    all_embeddings = np.concatenate([cat_embeddings, car_embeddings])
    # Same order as the concatenation above: cats first, then cars.
    labels = ['cat'] * len(cat_embeddings) + ['car'] * len(car_embeddings)
    reduced_dim_emb = dimensionality_reduction(all_embeddings, labels)

    # set_style only updates matplotlib's rcParams, which are read when axes
    # are created, so it has to run before plt.subplots to affect this figure.
    sns.set_style("whitegrid", {'axes.grid': False})
    _, ax = plt.subplots(figsize=(8, 6))

    sns.scatterplot(
        data=reduced_dim_emb,
        x='X',
        y='Y',
        hue='label',
        palette='bright',
        ax=ax,
    )
    # Anchor the legend's upper-left corner to the upper-right corner of the
    # axes, i.e. place it outside the plotting area.
    sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))
    ax.set_title('Scatter plot of images of cats and cars using UMAP')
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    plt.show()
|
|
|
def an_example_of_cat_and_car_pair_data():
    """Display the first cat pair and then the first car pair.

    For each of the two lists returned by ``load_all_dataset`` the caption of
    the first entry is displayed, followed by its image.
    """
    cat_img_txt_pairs, car_img_txt_pairs = load_all_dataset()

    for pairs in (cat_img_txt_pairs, car_img_txt_pairs):
        first_pair = pairs[0]
        for key in ('caption', 'pil_img'):
            display(first_pair[key])
|
|
|
|
|
# Script entry point: render the UMAP scatter plot when run directly.
if __name__ == "__main__":
    show_umap_visualization()
|
|