Spaces:
Sleeping
Sleeping
import pandas as pd | |
import torch | |
from PIL import Image | |
import gradio as gr | |
from pathlib import Path | |
import numpy as np | |
import open_clip # Importa la libreria OpenCLIP | |
from sklearn.metrics.pairwise import cosine_similarity | |
def load_openclip_model(device): | |
model, _, preprocess = open_clip.create_model_and_transforms("ViT-g-14", pretrained="laion2b_s12b_b42k") | |
model = model.to(device) | |
model.eval() | |
return model, preprocess | |
def load_embeddings(embedding_file): | |
df = pd.read_csv(embedding_file) | |
# Correggi i percorsi delle immagini per Linux | |
df['filename'] = df['filename'].str.replace("\\", "/") | |
embeddings = df.iloc[:, 1:].values # Escludi la prima colonna (filename) | |
image_paths = df['filename'].tolist() # Salva i nomi dei file | |
return embeddings, image_paths | |
def query_images(text, model, preprocess, image_embeddings, image_paths, device, num_images): | |
# Genera l'embedding per il testo con OpenCLIP | |
with torch.no_grad(): | |
text_embedding = model.encode_text(open_clip.tokenize([text]).to(device)).cpu().numpy().flatten() | |
# Calcola la similarità coseno tra l'embedding del testo e gli embeddings delle immagini | |
similarities = cosine_similarity([text_embedding], image_embeddings)[0] | |
# Ordina le immagini per similarità e prendi il numero di immagini specificato da num_images | |
top_indices = similarities.argsort()[-num_images:][::-1] | |
# Restituisci i percorsi delle immagini più simili e i loro punteggi | |
return [(Path("img") / image_paths[i], similarities[i]) for i in top_indices] | |
def predict(query_text, num_images): | |
# Ottieni il numero di immagini simili specificato e i loro punteggi | |
similar_images = query_images(query_text, model, preprocess, embeddings, image_paths, device, num_images) | |
image_outputs = [] | |
scores = [] | |
# Crea una lista di percorsi e punteggi per generare il CSV e la tabella | |
for img_path, score in similar_images: | |
img = Image.open(img_path) | |
image_outputs.append(img) | |
scores.append([img_path.name, score]) | |
# Salva il CSV con le immagini selezionate | |
df = pd.DataFrame(scores, columns=["filename", "similarity_score"]) | |
df.to_csv("filtered_images.csv", index=False) | |
# Restituisci le immagini, il link al CSV e la tabella dei punteggi | |
download_link = gr.File(label="Download CSV of filtered images", value="filtered_images.csv") | |
return image_outputs, download_link, pd.DataFrame(scores, columns=["Image Name", "Similarity Score"]) | |
if __name__ == "__main__": | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
model, preprocess = load_openclip_model(device) | |
# Carica gli embeddings dal file CSV | |
embedding_file = "embeddings.csv" # Sostituisci con il percorso corretto | |
embeddings, image_paths = load_embeddings(embedding_file) | |
# Crea l'interfaccia Gradio | |
interface = gr.Interface( | |
fn=predict, | |
inputs=[ | |
gr.Textbox(label="Enter your query text"), | |
gr.Slider(label="Number of Similar Images", minimum=3, maximum=20, step=1, value=3) | |
], | |
outputs=[ | |
gr.Gallery(label="Similar Images", elem_id="image_gallery"), | |
gr.File(label="Download CSV"), | |
gr.Dataframe(label="Similarity Scores Table") | |
], | |
title="Find Similar Images", | |
description="Insert text to find similar images and choose the number of images to display." | |
) | |
interface.launch() | |