phalanx80 commited on
Commit
1a6b747
·
verified ·
1 Parent(s): d90f7ab

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -0
app.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import torch
3
+ from PIL import Image
4
+ import gradio as gr
5
+ from pathlib import Path
6
+ import numpy as np
7
+ import open_clip # Importa la libreria OpenCLIP
8
+ from sklearn.metrics.pairwise import cosine_similarity
9
+
10
+
11
+ def load_openclip_model(device):
12
+ model, _, preprocess = open_clip.create_model_and_transforms("ViT-g-14", pretrained="laion2b_s12b_b42k")
13
+ model = model.to(device)
14
+ model.eval()
15
+ return model, preprocess
16
+
17
+ def load_embeddings(embedding_file):
18
+ df = pd.read_csv(embedding_file)
19
+ embeddings = df.iloc[:, 1:].values # Escludi la prima colonna (filename)
20
+ image_paths = df['filename'].tolist() # Salva i nomi dei file
21
+ return embeddings, image_paths
22
+
23
+ def query_images(text, model, preprocess, image_embeddings, image_paths, device, num_images):
24
+ # Genera l'embedding per il testo con OpenCLIP
25
+ with torch.no_grad():
26
+ text_embedding = model.encode_text(open_clip.tokenize([text]).to(device)).cpu().numpy().flatten()
27
+
28
+ # Calcola la similarità coseno tra l'embedding del testo e gli embeddings delle immagini
29
+ similarities = cosine_similarity([text_embedding], image_embeddings)[0]
30
+
31
+ # Ordina le immagini per similarità e prendi il numero di immagini specificato da num_images
32
+ top_indices = similarities.argsort()[-num_images:][::-1]
33
+
34
+ # Restituisci i percorsi delle immagini più simili e i loro punteggi
35
+ return [(Path(image_paths[i]), similarities[i]) for i in top_indices] # Rimuove Path("img") /
36
+
37
+ def predict(query_text, num_images):
38
+ # Ottieni il numero di immagini simili specificato e i loro punteggi
39
+ similar_images = query_images(query_text, model, preprocess, embeddings, image_paths, device, num_images)
40
+ image_outputs = []
41
+ scores = []
42
+
43
+ # Crea una lista di percorsi e punteggi per generare il CSV e la tabella
44
+ for img_path, score in similar_images:
45
+ img = Image.open(img_path)
46
+ image_outputs.append(img)
47
+ scores.append([img_path.name, score])
48
+
49
+ # Salva il CSV con le immagini selezionate
50
+ df = pd.DataFrame(scores, columns=["filename", "similarity_score"])
51
+ df.to_csv("filtered_images.csv", index=False)
52
+
53
+ # Restituisci le immagini, il link al CSV e la tabella dei punteggi
54
+ download_link = gr.File(label="Download CSV of filtered images", value="filtered_images.csv")
55
+ return image_outputs, download_link, pd.DataFrame(scores, columns=["Image Name", "Similarity Score"])
56
+
57
+ if __name__ == "__main__":
58
+ device = "cuda" if torch.cuda.is_available() else "cpu"
59
+ model, preprocess = load_openclip_model(device)
60
+
61
+ # Carica gli embeddings dal file CSV
62
+ embedding_file = "embeddings.csv" # Sostituisci con il percorso corretto
63
+ embeddings, image_paths = load_embeddings(embedding_file)
64
+
65
+ # Crea l'interfaccia Gradio
66
+ interface = gr.Interface(
67
+ fn=predict,
68
+ inputs=[
69
+ gr.Textbox(label="Enter your query text"),
70
+ gr.Slider(label="Number of Similar Images", minimum=3, maximum=20, step=1, value=3)
71
+ ],
72
+ outputs=[
73
+ gr.Gallery(label="Similar Images", elem_id="image_gallery"),
74
+ gr.File(label="Download CSV"),
75
+ gr.Dataframe(label="Similarity Scores Table")
76
+ ],
77
+ title="Find Similar Images",
78
+ description="Insert text to find similar images and choose the number of images to display."
79
+ )
80
+
81
+ interface.launch()