File size: 1,877 Bytes
8781c55
 
b821924
 
8781c55
 
b821924
 
 
8781c55
b821924
8781c55
b821924
 
8781c55
b821924
8781c55
 
b821924
8781c55
b821924
8781c55
 
 
 
 
b821924
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8781c55
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import zipfile
import random

import pandas as pd
import numpy as np
import streamlit as st

import clip
import torch
import torchvision.transforms as transforms

from get_similiarty import get_similiarity


# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load both serialized models onto the selected device.
# "model.pt" is the resnet50-based model, "model_vit.pt" the ViT-B/32 one.
model_resnet = torch.load("model.pt", map_location=device)
model_vit = torch.load("model_vit.pt", map_location=device)

# Unpack the ZIP archive with the photos into the target folder.
zip_file_path = "sample.zip"
target_folder = "sample/"
with zipfile.ZipFile(zip_file_path, mode='r') as archive:
    archive.extractall(path=target_folder)

# Page heading.
st.title('Find my pic!')

def _pick_random_captions(image_paths, df):
    """Return one randomly chosen caption from ``df`` per path in ``image_paths``.

    Each path is reduced to its final '/'-separated component and matched
    against the ``image_name`` column; one of the matching ``comment`` values
    is picked at random and stripped of '.' characters. The result has the
    same length and order as ``image_paths``.
    """
    captions = []
    for path in image_paths:
        file_name = path.split('/')[-1]
        # Drop missing comments (NaN has no .replace) and hand random.choice a
        # plain list rather than a NumPy array.
        candidates = df.loc[df['image_name'] == file_name, 'comment'].dropna().tolist()
        # random.choice raises IndexError on an empty sequence; fall back to
        # an empty caption so the result stays aligned with image_paths.
        caption = random.choice(candidates) if candidates else ''
        captions.append(caption.replace('.', ''))
    return captions


def find_image_disc(prompt, df):
    """Find images matching ``prompt`` and pick a caption for each of them.

    Calls ``get_similiarity`` with the prompt, both loaded models and the
    constant 3 (presumably the number of images to return per model - confirm
    against ``get_similiarity``), then looks up one random caption per
    returned image in ``df``.

    Args:
        prompt: Text description entered by the user.
        df: DataFrame with at least the columns ``image_name`` and ``comment``.

    Returns:
        A 4-tuple ``(images, captions, images_vit, captions_vit)``: the two
        image lists exactly as returned by ``get_similiarity`` and, for each,
        a caption list of the same length. A caption is an empty string when
        ``df`` holds no usable comment for that image.
    """
    list_images_names, list_images_names_vit = get_similiarity(prompt, model_resnet, model_vit, 3)
    # Captions are drawn for the first result list before the second one, so
    # the random stream is consumed in the same order as before.
    img_descs = _pick_random_captions(list_images_names, df)
    img_descs_vit = _pick_random_captions(list_images_names_vit, df)

    return list_images_names, img_descs, list_images_names_vit, img_descs_vit

# Free-text prompt describing the picture the user wants to find.
txt = st.text_area("Describe the picture you'd like to see")

# Caption table: pipe-separated file whose own header row (header=0) is
# replaced by the explicit column names given here.
df = pd.read_csv(
    'results.csv',
    sep='|',
    names=['image_name', 'comment_number', 'comment'],
    header=0,
)

# NOTE(review): st.text_area presumably returns a string even when nothing was
# typed, so this check may always pass - confirm whether an empty prompt
# should be rejected here.
if txt is not None:
    if st.button('Find!'):
        list_images, img_desc, list_images_vit, img_descs_vit = find_image_disc(txt, df)
        col1, col2 = st.columns(2)
        # Show the two result sets side by side, each image with its own
        # caption: first result list on the left, "vit" list on the right.
        for pic, caption, pic_vit, caption_vit in zip(
            list_images, img_desc, list_images_vit, img_descs_vit
        ):
            with col1:
                st.image(pic)
                st.write(caption)
            with col2:
                st.image(pic_vit)
                # Previously the left-column caption was repeated here.
                st.write(caption_vit)