import pandas as pd import torch from PIL import Image import gradio as gr from pathlib import Path import numpy as np import open_clip # Importa la libreria OpenCLIP from sklearn.metrics.pairwise import cosine_similarity def load_openclip_model(device): model, _, preprocess = open_clip.create_model_and_transforms("ViT-g-14", pretrained="laion2b_s12b_b42k") model = model.to(device) model.eval() return model, preprocess def load_embeddings(embedding_file): df = pd.read_csv(embedding_file) # Correggi i percorsi delle immagini per Linux df['filename'] = df['filename'].str.replace("\\", "/") embeddings = df.iloc[:, 1:].values # Escludi la prima colonna (filename) image_paths = df['filename'].tolist() # Salva i nomi dei file return embeddings, image_paths def query_images(text, model, preprocess, image_embeddings, image_paths, device, num_images): # Genera l'embedding per il testo con OpenCLIP with torch.no_grad(): text_embedding = model.encode_text(open_clip.tokenize([text]).to(device)).cpu().numpy().flatten() # Calcola la similarità coseno tra l'embedding del testo e gli embeddings delle immagini similarities = cosine_similarity([text_embedding], image_embeddings)[0] # Ordina le immagini per similarità e prendi il numero di immagini specificato da num_images top_indices = similarities.argsort()[-num_images:][::-1] # Restituisci i percorsi delle immagini più simili e i loro punteggi return [(Path("img") / image_paths[i], similarities[i]) for i in top_indices] def predict(query_text, num_images): # Ottieni il numero di immagini simili specificato e i loro punteggi similar_images = query_images(query_text, model, preprocess, embeddings, image_paths, device, num_images) image_outputs = [] scores = [] # Crea una lista di percorsi e punteggi per generare il CSV e la tabella for img_path, score in similar_images: img = Image.open(img_path) image_outputs.append(img) scores.append([img_path.name, score]) # Salva il CSV con le immagini selezionate df = pd.DataFrame(scores, columns=["filename", "similarity_score"]) df.to_csv("filtered_images.csv", index=False) # Restituisci le immagini, il link al CSV e la tabella dei punteggi download_link = gr.File(label="Download CSV of filtered images", value="filtered_images.csv") return image_outputs, download_link, pd.DataFrame(scores, columns=["Image Name", "Similarity Score"]) if __name__ == "__main__": device = "cuda" if torch.cuda.is_available() else "cpu" model, preprocess = load_openclip_model(device) # Carica gli embeddings dal file CSV embedding_file = "embeddings.csv" # Sostituisci con il percorso corretto embeddings, image_paths = load_embeddings(embedding_file) # Crea l'interfaccia Gradio interface = gr.Interface( fn=predict, inputs=[ gr.Textbox(label="Enter your query text"), gr.Slider(label="Number of Similar Images", minimum=3, maximum=20, step=1, value=3) ], outputs=[ gr.Gallery(label="Similar Images", elem_id="image_gallery"), gr.File(label="Download CSV"), gr.Dataframe(label="Similarity Scores Table") ], title="Find Similar Images", description="Insert text to find similar images and choose the number of images to display." ) interface.launch()