Spaces:
Runtime error
Runtime error
File size: 2,564 Bytes
8781c55 9899d06 b821924 8781c55 b821924 8781c55 b821924 8781c55 b821924 8781c55 b821924 8781c55 a3426b0 b821924 a3426b0 b821924 8781c55 b821924 6e828f7 b821924 6e828f7 b821924 6e828f7 b821924 6e828f7 b821924 8e9828f b821924 5238f71 4502646 b99373a 4502646 b99373a 4502646 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
import os
import random
import zipfile

import clip
import numpy as np
import pandas as pd
import streamlit as st
import torch
import torchvision.transforms as transforms
from PIL import Image

from get_similiarty import get_similiarity
# Run the CLIP models on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): Streamlit re-executes this whole script on every widget
# interaction, so both models below are reloaded on each rerun. Wrapping the
# loading in st.cache_resource (Streamlit >= 1.18) would avoid that — confirm
# the installed Streamlit version before switching.

# Load model - ResNet50. clip.load returns (model, preprocessing transform).
model_resnet, preprocess_resnet = clip.load("RN50", device=device)
# Load model - ViT-B/32.
model_vit, preprocess = clip.load('ViT-B/32', device=device)

# Unpack the ZIP archive with the photos. The script reruns on every
# interaction, so skip the extraction once the target folder already exists
# instead of rewriting every file each time.
zip_file_path = "sample.zip"
target_folder = "sample/"
if not os.path.isdir(target_folder):
    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
        zip_ref.extractall(target_folder)

st.title('Find my pic!')
def _pick_description(img_path, df):
    """Return one random caption for ``img_path`` with every '.' removed.

    The file name (the part after the last '/') is matched against
    ``df['image_name']`` and one of the matching ``comment`` values is chosen
    at random. Missing (NaN) comments are skipped. If nothing usable matches,
    an empty string is returned instead of letting ``random.choice`` raise
    ``IndexError`` on an empty selection.
    """
    file_name = img_path.split('/')[-1]
    comments = df.loc[df['image_name'] == file_name, 'comment'].dropna().values
    if len(comments) == 0:
        return ''
    # str() guards against non-string cells (e.g. numeric-looking captions).
    return str(random.choice(comments)).replace('.', '')


def find_image_disc(prompt, df, top_k):
    """Find images matching ``prompt`` with both CLIP models and caption them.

    Parameters
    ----------
    prompt : str
        Free-text description typed by the user.
    df : pandas.DataFrame
        Caption table with at least ``image_name`` and ``comment`` columns.
    top_k : int
        Number of images requested from ``get_similiarity``.

    Returns
    -------
    tuple
        ``(resnet_paths, resnet_descs, vit_paths, vit_descs)``: the two path
        lists returned by ``get_similiarity`` (in the order it returns them,
        ResNet50 first, then ViT-B/32), each paired with a caption list of
        the same length.
    """
    list_images_names, list_images_names_vit = get_similiarity(
        prompt, model_resnet, model_vit, top_k)
    img_descs = [_pick_description(img, df) for img in list_images_names]
    img_descs_vit = [_pick_description(img, df) for img in list_images_names_vit]
    return list_images_names, img_descs, list_images_names_vit, img_descs_vit
txt = st.text_area("Describe the picture you'd like to see")
top_k = st.slider('Number of images', 1, 5, 3)
df = pd.read_csv('results.csv',
                 sep='|',
                 names=['image_name', 'comment_number', 'comment'],
                 header=0)

# Streamlit re-runs this script on every widget interaction, and st.button is
# True only on the run triggered by the click itself. The last search results
# are therefore kept in session_state. Otherwise they would disappear (together
# with the multiselect below) as soon as the user picks an image to download.
if st.button('Find!'):
    # st.text_area returns '' (never None) when empty, so test the content.
    if txt.strip():
        st.session_state['search_results'] = find_image_disc(txt, df, top_k)
    else:
        st.session_state.pop('search_results', None)
        st.warning('Please describe the picture first.')

results = st.session_state.get('search_results')
if results is not None:
    list_images, img_desc, list_images_vit, img_descs_vit = results
    col1, col2 = st.columns(2)
    col1.header('ResNet50')
    col2.header('ViT 32')
    # zip stops at the shorter result list, so both columns show the same
    # number of rows.
    for pic_resnet, desc_resnet, pic_vit, desc_vit in zip(
            list_images, img_desc, list_images_vit, img_descs_vit):
        with col1:
            st.image(pic_resnet)
            st.write(desc_resnet)
        with col2:
            st.image(pic_vit)
            st.write(desc_vit)
    with col1:
        # Options are the 1-based positions of the ResNet50 results shown
        # above. They come from the stored results rather than the current
        # slider value, which may have changed since the search.
        options = st.multiselect(
            'Which image you want download?',
            list(range(1, len(list_images) + 1)))
        img_for_download = [list_images[i - 1] for i in options]
        st.write('You selected:', img_for_download)
        # TODO: offer the selected images through st.download_button. It
        # takes the bytes of one file per button (e.g. read each path in
        # binary mode), not a list of PIL images.