File size: 2,188 Bytes
8781c55
 
9899d06
b821924
 
8781c55
 
b821924
 
 
8781c55
b821924
8781c55
b821924
 
8781c55
b821924
8781c55
a3426b0
b821924
a3426b0
b821924
8781c55
 
 
 
 
b821924
f37d563
 
 
 
b821924
6e828f7
b821924
 
6e828f7
b821924
 
 
 
 
 
 
 
f37d563
 
d822f4c
f37d563
 
 
 
 
6e828f7
b821924
 
 
6e828f7
b821924
8e9828f
 
b821924
 
 
 
 
 
c287bd0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import zipfile
import random
from PIL import Image

import pandas as pd
import numpy as np
import streamlit as st

import clip
import torch
import torchvision.transforms as transforms

from get_similiarty import get_similiarity


# Run the models on the GPU when torch reports one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# CLIP with the ResNet-50 backbone. The second returned value is bound to
# `prerocess` (spelling kept as-is: it is a module-level name) and is not
# used in this part of the file.
model_resnet, prerocess = clip.load("RN50", device=device)
# CLIP with the ViT-B/32 backbone.
model_vit, preprocess = clip.load("ViT-B/32", device=device)

# Unpack the ZIP archive with the photos into the working folder.
zip_file_path = "sample.zip"
target_folder = "sample/"
with zipfile.ZipFile(zip_file_path, mode="r") as zip_ref:
    zip_ref.extractall(path=target_folder)

# Captions table: '|'-separated, one row per (image, caption). The file's
# own header row is consumed (header=0) and replaced by the names below.
df = pd.read_csv(
    "results.csv",
    sep="|",
    header=0,
    names=["image_name", "comment_number", "comment"],
)

def _pick_caption(df, image_path):
    """Return one randomly chosen caption for *image_path*, or '' if none.

    Rows whose caption is missing (NaN) are ignored, and an image with no
    usable caption yields an empty string instead of raising, so a single
    bad or absent row in the captions table cannot break a whole search.
    """
    # The captions table is keyed by bare file name, while the search yields
    # paths, so only the last '/'-separated component is compared.
    file_name = image_path.split('/')[-1]
    captions = df.loc[df['image_name'] == file_name, 'comment'].dropna()
    if captions.empty:
        return ''
    # Periods are removed from the caption before it is returned.
    return random.choice(captions.tolist()).replace('.', '')


def find_image_disc(prompt, df, top_k):
    """Find images matching *prompt* with both CLIP models and caption them.

    Args:
        prompt: Text description passed to ``get_similiarity``.
        df: DataFrame with at least the ``image_name`` and ``comment``
            columns; one caption per image is drawn from it at random.
        top_k: Number of images requested from ``get_similiarity``.

    Returns:
        A 4-tuple ``(resnet_images, resnet_captions, vit_images,
        vit_captions)``. Each captions list is aligned index-by-index with
        its images list; an image without a usable caption gets ``''``.
    """
    # get_similiarity is expected to return one list of image paths per
    # model (ResNet first, then ViT) -- see its own module for details.
    list_images_names, list_images_names_vit = get_similiarity(
        prompt, model_resnet, model_vit, top_k
    )
    img_descs = [_pick_caption(df, img) for img in list_images_names]
    img_descs_vit = [_pick_caption(df, img) for img in list_images_names_vit]

    return list_images_names, img_descs, list_images_names_vit, img_descs_vit

# ---- Page layout: banner, then illustration next to the prompt box ----
st.image('image.png')
# st.title('Find my pic!')
col3, col4 = st.columns(2)
with col3:
    st.image('3bd0e1e6-6b8a-4aa6-828a-c1756c6d38b2.jpeg')
with col4:
    txt = st.text_area("Describe the picture you'd like to see")

# How many images each model should return (1 to 5, default 3).
top_k = st.slider('Number of images', 1, 5, 3)

if st.button('Find!'):
    # An untouched text area yields an empty string, not None, so the prompt
    # has to be checked for emptiness; a blank prompt is rejected here rather
    # than being sent to the models.
    if not txt or not txt.strip():
        st.warning('Please describe the picture first.')
    else:
        list_images, img_desc, list_images_vit, img_descs_vit = find_image_disc(txt, df, top_k)

        # Side-by-side comparison: one column per model.
        col1, col2 = st.columns(2)
        col1.header('ResNet50')
        col2.header('ViT 32')
        results = zip(list_images, img_desc, list_images_vit, img_descs_vit)
        for image, caption, image_vit, caption_vit in results:
            with col1:
                st.image(image)
                st.write(caption)
            with col2:
                st.image(image_vit)
                st.write(caption_vit)