File size: 1,543 Bytes
b821924
 
 
 
 
 
b9e8251
b821924
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torchvision.datasets as datasets
import numpy as np
import clip
import torch
def get_similiarity(prompt, model_resnet, model_vit,  top_k=3):
    """Return the file paths of the images most similar to a text prompt.

    The prompt is encoded with two CLIP text encoders (a ResNet-backed and a
    ViT-backed model); each text embedding is scored by dot product against
    precomputed image embeddings loaded from CSV, and the paths of the top-k
    images are returned for each model.

    Args:
        prompt: Text query (a string, or list of strings — only the first
            tokenized row is ranked).
        model_resnet: CLIP model whose image embeddings live in
            ``embeddings.csv``.
        model_vit: CLIP model whose image embeddings live in
            ``embeddings_vit.csv``.
        top_k: Number of best-matching image paths to return per model.

    Returns:
        Tuple ``(image_files, image_files_vit)`` — two lists of file paths,
        each of length ``top_k``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    data_dir = 'sample/sample/data'
    # NOTE(review): ranking assumes row i of each embeddings CSV corresponds
    # to raw_dataset.imgs[i] (ImageFolder's sorted order) — confirm the CSVs
    # were generated by iterating the same ImageFolder.
    raw_dataset = datasets.ImageFolder(data_dir)

    def _top_k_paths(model, embeddings_csv):
        """Rank precomputed image embeddings against the prompt for one model."""
        image_arr = np.loadtxt(embeddings_csv, delimiter=",")
        tokens = clip.tokenize(prompt).to(device)
        # no_grad: inference only — avoids building an autograd graph.
        with torch.no_grad():
            text_emb = model.encode_text(tokens).cpu().numpy()
        # (1, d) @ (d, N) -> (1, N); row 0 is the score per image.
        scores = np.dot(text_emb, image_arr.T)
        # Negate so argsort yields indices in descending-score order.
        idx = np.argsort(-scores[0])[:top_k]
        # imgs[i] is a (path, class_index) pair; keep only the path.
        return [raw_dataset.imgs[i][0] for i in idx]

    image_files = _top_k_paths(model_resnet, "embeddings.csv")
    image_files_vit = _top_k_paths(model_vit, 'embeddings_vit.csv')
    return image_files, image_files_vit
# def get_text_enc(input_text: str):
#     text = clip.tokenize([input_text]).to(device)
#     text_features = model.encode_text(text).cpu()
#     text_features = text_features.cpu().detach().numpy()
#     return text_features