Spaces:
Sleeping
Sleeping
Upload app.py
Browse files
app.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import torch
|
3 |
+
from PIL import Image
|
4 |
+
import gradio as gr
|
5 |
+
from pathlib import Path
|
6 |
+
import numpy as np
|
7 |
+
import open_clip # Importa la libreria OpenCLIP
|
8 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
9 |
+
|
10 |
+
|
11 |
+
def load_openclip_model(device):
    """Create the OpenCLIP ViT-g-14 model on ``device`` and put it in eval mode.

    Args:
        device: Target handed to ``model.to`` (the script entry point passes
            ``"cuda"`` or ``"cpu"``).

    Returns:
        A ``(model, preprocess)`` tuple, where ``preprocess`` is the third
        value returned by ``open_clip.create_model_and_transforms``.
    """
    clip_model, _unused_transform, preprocess = open_clip.create_model_and_transforms(
        "ViT-g-14",
        pretrained="laion2b_s12b_b42k",
    )
    clip_model = clip_model.to(device)
    # Inference only: switch the model to evaluation mode.
    clip_model.eval()
    return clip_model, preprocess
|
16 |
+
|
17 |
+
def load_embeddings(embedding_file):
    """Read precomputed image embeddings and their file names from a CSV.

    The CSV must have a ``filename`` column. Every column after the first one
    is returned as embedding data, so the first column is expected to be the
    file name.

    Args:
        embedding_file: Anything ``pandas.read_csv`` accepts (path or
            file-like object).

    Returns:
        A ``(embeddings, image_paths)`` tuple: the array of all columns but
        the first, and the list of values from the ``filename`` column.
    """
    table = pd.read_csv(embedding_file)
    # File names are looked up by column name ...
    names = table["filename"].tolist()
    # ... while the vectors are taken by position: skip the first column.
    vectors = table.iloc[:, 1:].values
    return vectors, names
|
22 |
+
|
23 |
+
def query_images(text, model, preprocess, image_embeddings, image_paths, device, num_images):
    """Rank the stored images by cosine similarity to a text query.

    Args:
        text: Query string; it is tokenized with ``open_clip.tokenize``.
        model: Object exposing ``encode_text`` (the OpenCLIP model).
        preprocess: Not used here; kept so existing callers keep working.
        image_embeddings: Image embedding matrix passed to
            ``cosine_similarity`` (presumably one row per image, in the same
            order as ``image_paths`` -- confirm against the embeddings file).
        image_paths: Sequence of image paths indexed like the rows of
            ``image_embeddings``.
        device: Device the token tensor is moved to before encoding.
        num_images: Maximum number of results. Coerced to ``int``; values
            less than or equal to zero yield an empty list.

    Returns:
        A list of ``(Path, similarity)`` tuples, most similar first.
    """
    # UI sliders may deliver a float such as 3.0, which is not a valid slice
    # bound, so normalise to int first.
    limit = int(num_images)
    if limit <= 0:
        # Without this guard the slice below would be ``[-0:]``, i.e. the
        # whole array, and a request for zero images would return all of them.
        return []

    # Compute the text embedding with OpenCLIP (no gradients needed).
    with torch.no_grad():
        tokens = open_clip.tokenize([text]).to(device)
        text_embedding = model.encode_text(tokens).cpu().numpy().flatten()

    # Cosine similarity between the text embedding and every image embedding.
    similarities = cosine_similarity([text_embedding], image_embeddings)[0]

    # argsort is ascending: the best matches are the last ``limit`` entries,
    # reversed so the most similar image comes first.
    top_indices = similarities.argsort()[-limit:][::-1]

    # Return the paths of the most similar images together with their scores.
    return [(Path(image_paths[i]), similarities[i]) for i in top_indices]
|
36 |
+
|
37 |
+
def predict(query_text, num_images):
    """Gradio callback: find the images most similar to ``query_text``.

    Relies on the module-level ``model``, ``preprocess``, ``embeddings``,
    ``image_paths`` and ``device`` set up by the script entry point.

    Args:
        query_text: Text query typed by the user.
        num_images: Number of similar images requested.

    Returns:
        A tuple ``(images, csv_file, score_table)``: the loaded PIL images, a
        ``gr.File`` pointing at ``filtered_images.csv``, and a DataFrame with
        the columns ``Image Name`` and ``Similarity Score``.
    """
    # Fetch the requested number of similar images and their scores.
    similar_images = query_images(query_text, model, preprocess, embeddings, image_paths, device, num_images)
    image_outputs = []
    scores = []

    # Collect the images and the (name, score) rows for the CSV and the table.
    for img_path, score in similar_images:
        # Image.open is lazy and keeps the file open until the pixels are
        # read. Loading inside the context manager releases the handle right
        # away while the decoded image stays usable afterwards.
        with Image.open(img_path) as img:
            img.load()
        image_outputs.append(img)
        scores.append([img_path.name, score])

    # Save the CSV listing the selected images.
    # NOTE(review): the output path is fixed, so concurrent requests overwrite
    # each other's file -- consider a per-request temporary file.
    df = pd.DataFrame(scores, columns=["filename", "similarity_score"])
    df.to_csv("filtered_images.csv", index=False)

    # Return the images, the CSV download and the score table.
    download_link = gr.File(label="Download CSV of filtered images", value="filtered_images.csv")
    return image_outputs, download_link, pd.DataFrame(scores, columns=["Image Name", "Similarity Score"])
|
56 |
+
|
57 |
+
if __name__ == "__main__":
    # Run on the GPU when torch can see one, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    model, preprocess = load_openclip_model(device)

    # Load the embeddings from the CSV file.
    embedding_file = "embeddings.csv"  # Replace with the correct path
    embeddings, image_paths = load_embeddings(embedding_file)

    # Build the Gradio interface: a text query and a result count go in,
    # a gallery, a CSV download and a score table come out.
    query_inputs = [
        gr.Textbox(label="Enter your query text"),
        gr.Slider(label="Number of Similar Images", minimum=3, maximum=20, step=1, value=3),
    ]
    result_outputs = [
        gr.Gallery(label="Similar Images", elem_id="image_gallery"),
        gr.File(label="Download CSV"),
        gr.Dataframe(label="Similarity Scores Table"),
    ]
    interface = gr.Interface(
        fn=predict,
        inputs=query_inputs,
        outputs=result_outputs,
        title="Find Similar Images",
        description="Insert text to find similar images and choose the number of images to display.",
    )

    interface.launch()
|