IvaElen committed on
Commit
b821924
1 Parent(s): defff1c

Upload 2 files

Files changed (2)
  1. get_similiarty.py +38 -0
  2. main.py +56 -0
get_similiarty.py ADDED
@@ -0,0 +1,38 @@
+ import torchvision.datasets as datasets
+ import numpy as np
+ import clip
+ import torch
+
+
+ def get_similiarity(prompt, model_resnet, model_vit, top_k=3):
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     data_dir = 'sample/data'
+     # precomputed CLIP (ResNet) image embeddings, one row per image
+     image_arr = np.loadtxt("embeddings.csv", delimiter=",")
+     # get the list of all images, in the same order as the embedding rows
+     raw_dataset = datasets.ImageFolder(data_dir)
+     # create transformer-readable tokens and encode the prompt with the ResNet model
+     inputs = clip.tokenize(prompt).to(device)
+     text_emb = model_resnet.encode_text(inputs)
+     text_emb = text_emb.cpu().detach().numpy()
+     scores = np.dot(text_emb, image_arr.T)
+     # get the top-k indices of the most similar vectors
+     idx = np.argsort(-scores[0])[:top_k]
+     image_files = []
+     for i in idx:
+         image_files.append(raw_dataset.imgs[i][0])
+
+     # repeat the search with the ViT model and its embeddings
+     image_arr_vit = np.loadtxt('embeddings_vit.csv', delimiter=",")
+     inputs_vit = clip.tokenize(prompt).to(device)
+     text_emb_vit = model_vit.encode_text(inputs_vit)
+     text_emb_vit = text_emb_vit.cpu().detach().numpy()
+     scores_vit = np.dot(text_emb_vit, image_arr_vit.T)
+     idx_vit = np.argsort(-scores_vit[0])[:top_k]
+     image_files_vit = []
+     for i in idx_vit:
+         image_files_vit.append(raw_dataset.imgs[i][0])
+
+     return image_files, image_files_vit
+
+
+ # def get_text_enc(input_text: str):
+ #     text = clip.tokenize([input_text]).to(device)
+ #     text_features = model.encode_text(text).cpu()
+ #     text_features = text_features.cpu().detach().numpy()
+ #     return text_features
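Context, not part of this commit: get_similiarity expects embeddings.csv (and embeddings_vit.csv) to hold precomputed CLIP image embeddings, one row per image, in the same order that datasets.ImageFolder('sample/data') lists the files. A minimal sketch of how such a file could be produced, assuming the OpenAI clip package; the "RN50" backbone name and the output path are assumptions, not taken from the repo:

import clip
import numpy as np
import torch
import torchvision.datasets as datasets

device = "cuda" if torch.cuda.is_available() else "cpu"
# RN50 is an assumed backbone; any CLIP model saved as model.pt would be encoded the same way
model, preprocess = clip.load("RN50", device=device)

raw_dataset = datasets.ImageFolder('sample/data')

embeddings = []
with torch.no_grad():
    for path, _ in raw_dataset.imgs:
        # load and preprocess each image in ImageFolder order,
        # so row i of the CSV corresponds to raw_dataset.imgs[i]
        image = preprocess(datasets.folder.default_loader(path)).unsqueeze(0).to(device)
        embeddings.append(model.encode_image(image).cpu().numpy()[0])

np.savetxt("embeddings.csv", np.stack(embeddings), delimiter=",")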
main.py ADDED
@@ -0,0 +1,56 @@
+ import streamlit as st
+
+ import pandas as pd
+
+ import clip
+ import torchvision.transforms as transforms
+ import torchvision.datasets as datasets
+ import torch
+ import numpy as np
+ import random
+ from get_similiarty import get_similiarity
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # load model - ResNet50
+ model_resnet = torch.load("model.pt", device)
+
+ # load model - ViT-B/32
+ model_vit = torch.load("model_vit.pt", device)
+
+
+ st.title('Find my pic!')
+
+
+ def find_image_disc(prompt, df):
+     img_descs = []
+     img_descs_vit = []
+     list_images_names, list_images_names_vit = get_similiarity(prompt, model_resnet, model_vit, 3)
+     # pick a random caption for each ResNet match
+     for img in list_images_names:
+         img_descs.append(random.choice(df[df['image_name'] == img.split('/')[-1]]['comment'].values).replace('.', ''))
+     # the same lookup for the ViT matches
+     for img in list_images_names_vit:
+         img_descs_vit.append(random.choice(df[df['image_name'] == img.split('/')[-1]]['comment'].values).replace('.', ''))
+
+     return list_images_names, img_descs, list_images_names_vit, img_descs_vit
+
+
+ txt = st.text_area("Describe the picture you'd like to see")
+
+ df = pd.read_csv('results.csv',
+                  sep='|',
+                  names=['image_name', 'comment_number', 'comment'],
+                  header=0)
+
+ if txt is not None:
+     if st.button('Find!'):
+         list_images, img_desc, list_images_vit, img_descs_vit = find_image_disc(txt, df)
+         col1, col2 = st.columns(2)
+         for ind, pic in enumerate(zip(list_images, list_images_vit)):
+             with col1:
+                 st.image(pic[0])
+                 st.write(img_desc[ind])
+             with col2:
+                 st.image(pic[1])
+                 # show the caption that belongs to the ViT match, not the ResNet one
+                 st.write(img_descs_vit[ind])
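For reference, assuming model.pt, model_vit.pt, embeddings.csv, embeddings_vit.csv and results.csv are present in the working directory, the app is launched with the standard Streamlit CLI:

streamlit run main.py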