File size: 3,484 Bytes
1a6b747
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import pandas as pd
import torch
from PIL import Image
import gradio as gr
from pathlib import Path
import numpy as np
import open_clip  # Importa la libreria OpenCLIP
from sklearn.metrics.pairwise import cosine_similarity


def load_openclip_model(device):
    model, _, preprocess = open_clip.create_model_and_transforms("ViT-g-14", pretrained="laion2b_s12b_b42k")
    model = model.to(device)
    model.eval()
    return model, preprocess

def load_embeddings(embedding_file):
    df = pd.read_csv(embedding_file)
    embeddings = df.iloc[:, 1:].values  # Escludi la prima colonna (filename)
    image_paths = df['filename'].tolist()  # Salva i nomi dei file
    return embeddings, image_paths

def query_images(text, model, preprocess, image_embeddings, image_paths, device, num_images):
    # Genera l'embedding per il testo con OpenCLIP
    with torch.no_grad():
        text_embedding = model.encode_text(open_clip.tokenize([text]).to(device)).cpu().numpy().flatten()
    
    # Calcola la similarità coseno tra l'embedding del testo e gli embeddings delle immagini
    similarities = cosine_similarity([text_embedding], image_embeddings)[0]
    
    # Ordina le immagini per similarità e prendi il numero di immagini specificato da num_images
    top_indices = similarities.argsort()[-num_images:][::-1]
    
    # Restituisci i percorsi delle immagini più simili e i loro punteggi
    return [(Path(image_paths[i]), similarities[i]) for i in top_indices]  # Rimuove Path("img") /

def predict(query_text, num_images):
    # Ottieni il numero di immagini simili specificato e i loro punteggi
    similar_images = query_images(query_text, model, preprocess, embeddings, image_paths, device, num_images)
    image_outputs = []
    scores = []
    
    # Crea una lista di percorsi e punteggi per generare il CSV e la tabella
    for img_path, score in similar_images:
        img = Image.open(img_path)
        image_outputs.append(img)
        scores.append([img_path.name, score])
    
    # Salva il CSV con le immagini selezionate
    df = pd.DataFrame(scores, columns=["filename", "similarity_score"])
    df.to_csv("filtered_images.csv", index=False)
    
    # Restituisci le immagini, il link al CSV e la tabella dei punteggi
    download_link = gr.File(label="Download CSV of filtered images", value="filtered_images.csv")
    return image_outputs, download_link, pd.DataFrame(scores, columns=["Image Name", "Similarity Score"])

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = load_openclip_model(device)

    # Carica gli embeddings dal file CSV
    embedding_file = "embeddings.csv"  # Sostituisci con il percorso corretto
    embeddings, image_paths = load_embeddings(embedding_file)

    # Crea l'interfaccia Gradio
    interface = gr.Interface(
        fn=predict,
        inputs=[
            gr.Textbox(label="Enter your query text"),
            gr.Slider(label="Number of Similar Images", minimum=3, maximum=20, step=1, value=3)
        ],
        outputs=[
            gr.Gallery(label="Similar Images", elem_id="image_gallery"), 
            gr.File(label="Download CSV"),
            gr.Dataframe(label="Similarity Scores Table")
        ],
        title="Find Similar Images",
        description="Insert text to find similar images and choose the number of images to display."
    )

    interface.launch()