# find_my_pic / get_similiarty.py
# Last updated by IvaElen in commit b9e8251 ("Update get_similiarty.py"), 1.54 kB.
import torchvision.datasets as datasets
import numpy as np
import clip
import torch
def _top_k_image_files(model, tokens, embeddings_path, dataset, top_k):
    """Rank ``dataset`` images against ``tokens`` using one model's stored embeddings.

    Args:
        model: CLIP-style model exposing ``encode_text``.
        tokens: Tokenized prompt tensor, already on the device ``model`` expects.
        embeddings_path: CSV of precomputed image embeddings, one row per image.
        dataset: ImageFolder whose ``imgs[i][0]`` is the file path for CSV row ``i``
            (NOTE(review): the row/image ordering is assumed, not checked here).
        top_k: Number of file paths to return.

    Returns:
        List of file paths, most similar first.
    """
    # ndmin=2 keeps a single-row CSV as shape (1, dim) rather than (dim,), so the
    # score matrix below is always (n_prompts, n_images).
    image_arr = np.loadtxt(embeddings_path, delimiter=",", ndmin=2)
    # Inference only: the result is converted to NumPy anyway, so skip autograd.
    with torch.no_grad():
        text_emb = model.encode_text(tokens)
    text_emb = text_emb.cpu().numpy()
    scores = np.dot(text_emb, image_arr.T)
    # Only the first prompt's row is ranked; negate for descending order.
    idx = np.argsort(-scores[0])[:top_k]
    return [dataset.imgs[i][0] for i in idx]


def get_similiarity(prompt, model_resnet, model_vit, top_k=3):
    """Return the image files that best match ``prompt`` under two CLIP models.

    The prompt is encoded with each model's text encoder and scored by dot
    product against precomputed image embeddings: "embeddings.csv" for
    ``model_resnet`` and "embeddings_vit.csv" for ``model_vit``. Result indices
    are mapped to file paths via the ImageFolder at 'sample/sample/data'.

    Args:
        prompt: Text passed to ``clip.tokenize``. If several prompts are given,
            only the first one is ranked.
        model_resnet: Model whose text embeddings are compared with embeddings.csv.
        model_vit: Model whose text embeddings are compared with embeddings_vit.csv.
        top_k: Number of best-matching image paths to return per model.

    Returns:
        Tuple ``(image_files, image_files_vit)``. Each element is a list of file
        paths ordered from most to least similar.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    data_dir = 'sample/sample/data'
    raw_dataset = datasets.ImageFolder(data_dir)
    # Tokenize once; both models consume the same token tensor.
    tokens = clip.tokenize(prompt).to(device)

    image_files = _top_k_image_files(
        model_resnet, tokens, "embeddings.csv", raw_dataset, top_k
    )
    image_files_vit = _top_k_image_files(
        model_vit, tokens, 'embeddings_vit.csv', raw_dataset, top_k
    )
    return image_files, image_files_vit
# def get_text_enc(input_text: str):
# text = clip.tokenize([input_text]).to(device)
# text_features = model.encode_text(text).cpu()
# text_features = text_features.cpu().detach().numpy()
# return text_features