"""Streamlit app that retrieves pictures matching a text description.

Two CLIP backbones (ResNet-50 and ViT-B/32) are queried through
``get_similiarity`` and their hits are shown side by side, each image
captioned with one randomly chosen reference comment from ``results.csv``.
"""
import zipfile
import random
from PIL import Image
import pandas as pd
import numpy as np
import streamlit as st
import clip
import torch
import torchvision.transforms as transforms
from get_similiarty import get_similiarity


# Streamlit re-executes this whole script on every widget interaction
# (keystroke, slider move, button click).  The expensive setup is wrapped in
# cached helpers so it runs once per server process instead of on every rerun.

@st.cache_resource
def _load_clip_model(name, device):
    """Load CLIP model *name* on *device*; returns ``(model, preprocess)``."""
    return clip.load(name, device=device)


@st.cache_resource
def _extract_samples(zip_path, folder):
    """Unpack the sample photo archive *zip_path* into *folder* (once per process)."""
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(folder)
    return folder


@st.cache_data
def _load_captions(csv_path):
    """Read the pipe-separated caption table into a DataFrame.

    Columns: ``image_name``, ``comment_number``, ``comment``; the file's own
    header row is skipped (``header=0``) and replaced by these names.
    """
    return pd.read_csv(csv_path, sep='|',
                       names=['image_name', 'comment_number', 'comment'],
                       header=0)


device = "cuda" if torch.cuda.is_available() else "cpu"

# load model - ResNet-50
# NOTE: `prerocess` (sic) holds this model's preprocessing transform; it is
# not used in this file and the name is left unchanged.
model_resnet, prerocess = _load_clip_model("RN50", device)
# load model - ViT-B/32
model_vit, preprocess = _load_clip_model('ViT-B/32', device)

# Unpack the ZIP archive with the photos
zip_file_path = "sample.zip"
target_folder = "sample/"
_extract_samples(zip_file_path, target_folder)

df = _load_captions('results.csv')


def _pick_description(df, image_path):
    """Return one random caption for *image_path*, or '' if it has none.

    The image is matched on the last '/'-separated component of the path.
    Missing (NaN) comments are ignored; surrounding whitespace and trailing
    periods are stripped (periods inside the text are kept).
    """
    image_name = image_path.split('/')[-1]
    comments = df.loc[df['image_name'] == image_name, 'comment'].dropna()
    if comments.empty:
        # Previously random.choice() raised IndexError here.
        return ''
    return str(random.choice(comments.tolist())).strip().rstrip('. ')


def find_image_disc(prompt, df, top_k):
    """Find the *top_k* images for *prompt* with both CLIP models.

    Args:
        prompt: Free-text description of the wanted picture.
        df: Caption table with ``image_name`` and ``comment`` columns.
        top_k: Number of images requested from each model.

    Returns:
        ``(resnet_images, resnet_descriptions, vit_images, vit_descriptions)``
        where each description list is aligned with its image list.
    """
    # get_similiarity presumably returns image paths for each model
    # (ResNet first, ViT second) — see get_similiarty module to confirm.
    list_images_names, list_images_names_vit = get_similiarity(
        prompt, model_resnet, model_vit, top_k)
    img_descs = [_pick_description(df, img) for img in list_images_names]
    img_descs_vit = [_pick_description(df, img) for img in list_images_names_vit]
    return list_images_names, img_descs, list_images_names_vit, img_descs_vit


def _show_results(column, title, image_paths, descriptions):
    """Render one model's hits (image followed by its caption) in *column*."""
    column.header(title)
    for path, description in zip(image_paths, descriptions):
        column.image(path)
        column.write(description)


st.image('image.png')
col3, col4 = st.columns(2)
with col3:
    st.image('3bd0e1e6-6b8a-4aa6-828a-c1756c6d38b2.jpeg')
with col4:
    txt = st.text_area("Describe the picture you'd like to see")
    top_k = st.slider('Number of images', 1, 5, 3)

if st.button('Find!'):
    # st.text_area returns '' (not None) when left empty, so test the content.
    if not txt or not txt.strip():
        st.warning("Please describe the picture you'd like to see first.")
    else:
        list_images, img_desc, list_images_vit, img_descs_vit = find_image_disc(txt, df, top_k)
        col1, col2 = st.columns(2)
        # Fill each column independently so a shorter result list from one
        # model does not truncate the other model's results.
        _show_results(col1, 'ResNet50', list_images, img_desc)
        _show_results(col2, 'ViT 32', list_images_vit, img_descs_vit)