Spaces:
Sleeping
Sleeping
| import pandas as pd | |
| import torch | |
| from PIL import Image | |
| import gradio as gr | |
| from pathlib import Path | |
| import numpy as np | |
| import open_clip # Importa la libreria OpenCLIP | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| def load_openclip_model(device): | |
| model, _, preprocess = open_clip.create_model_and_transforms("ViT-g-14", pretrained="laion2b_s12b_b42k") | |
| model = model.to(device) | |
| model.eval() | |
| return model, preprocess | |
| def load_embeddings(embedding_file): | |
| df = pd.read_csv(embedding_file) | |
| # Correggi i percorsi delle immagini per Linux | |
| df['filename'] = df['filename'].str.replace("\\", "/") | |
| embeddings = df.iloc[:, 1:].values # Escludi la prima colonna (filename) | |
| image_paths = df['filename'].tolist() # Salva i nomi dei file | |
| return embeddings, image_paths | |
| def query_images(text, model, preprocess, image_embeddings, image_paths, device, num_images): | |
| # Genera l'embedding per il testo con OpenCLIP | |
| with torch.no_grad(): | |
| text_embedding = model.encode_text(open_clip.tokenize([text]).to(device)).cpu().numpy().flatten() | |
| # Calcola la similarità coseno tra l'embedding del testo e gli embeddings delle immagini | |
| similarities = cosine_similarity([text_embedding], image_embeddings)[0] | |
| # Ordina le immagini per similarità e prendi il numero di immagini specificato da num_images | |
| top_indices = similarities.argsort()[-num_images:][::-1] | |
| # Restituisci i percorsi delle immagini più simili e i loro punteggi | |
| return [(Path("img") / image_paths[i], similarities[i]) for i in top_indices] | |
| def predict(query_text, num_images): | |
| # Ottieni il numero di immagini simili specificato e i loro punteggi | |
| similar_images = query_images(query_text, model, preprocess, embeddings, image_paths, device, num_images) | |
| image_outputs = [] | |
| scores = [] | |
| # Crea una lista di percorsi e punteggi per generare il CSV e la tabella | |
| for img_path, score in similar_images: | |
| img = Image.open(img_path) | |
| image_outputs.append(img) | |
| scores.append([img_path.name, score]) | |
| # Salva il CSV con le immagini selezionate | |
| df = pd.DataFrame(scores, columns=["filename", "similarity_score"]) | |
| df.to_csv("filtered_images.csv", index=False) | |
| # Restituisci le immagini, il link al CSV e la tabella dei punteggi | |
| download_link = gr.File(label="Download CSV of filtered images", value="filtered_images.csv") | |
| return image_outputs, download_link, pd.DataFrame(scores, columns=["Image Name", "Similarity Score"]) | |
| if __name__ == "__main__": | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| model, preprocess = load_openclip_model(device) | |
| # Carica gli embeddings dal file CSV | |
| embedding_file = "embeddings.csv" # Sostituisci con il percorso corretto | |
| embeddings, image_paths = load_embeddings(embedding_file) | |
| # Crea l'interfaccia Gradio | |
| interface = gr.Interface( | |
| fn=predict, | |
| inputs=[ | |
| gr.Textbox(label="Enter your query text"), | |
| gr.Slider(label="Number of Similar Images", minimum=3, maximum=20, step=1, value=3) | |
| ], | |
| outputs=[ | |
| gr.Gallery(label="Similar Images", elem_id="image_gallery"), | |
| gr.File(label="Download CSV"), | |
| gr.Dataframe(label="Similarity Scores Table") | |
| ], | |
| title="Find Similar Images", | |
| description="Insert text to find similar images and choose the number of images to display." | |
| ) | |
| interface.launch() | |