GloedenWilhelm / app.py
phalanx80's picture
Upload app.py
1a6b747 verified
raw
history blame
3.48 kB
import pandas as pd
import torch
from PIL import Image
import gradio as gr
from pathlib import Path
import numpy as np
import open_clip # Importa la libreria OpenCLIP
from sklearn.metrics.pairwise import cosine_similarity
def load_openclip_model(device):
model, _, preprocess = open_clip.create_model_and_transforms("ViT-g-14", pretrained="laion2b_s12b_b42k")
model = model.to(device)
model.eval()
return model, preprocess
def load_embeddings(embedding_file):
df = pd.read_csv(embedding_file)
embeddings = df.iloc[:, 1:].values # Escludi la prima colonna (filename)
image_paths = df['filename'].tolist() # Salva i nomi dei file
return embeddings, image_paths
def query_images(text, model, preprocess, image_embeddings, image_paths, device, num_images):
# Genera l'embedding per il testo con OpenCLIP
with torch.no_grad():
text_embedding = model.encode_text(open_clip.tokenize([text]).to(device)).cpu().numpy().flatten()
# Calcola la similarità coseno tra l'embedding del testo e gli embeddings delle immagini
similarities = cosine_similarity([text_embedding], image_embeddings)[0]
# Ordina le immagini per similarità e prendi il numero di immagini specificato da num_images
top_indices = similarities.argsort()[-num_images:][::-1]
# Restituisci i percorsi delle immagini più simili e i loro punteggi
return [(Path(image_paths[i]), similarities[i]) for i in top_indices] # Rimuove Path("img") /
def predict(query_text, num_images):
# Ottieni il numero di immagini simili specificato e i loro punteggi
similar_images = query_images(query_text, model, preprocess, embeddings, image_paths, device, num_images)
image_outputs = []
scores = []
# Crea una lista di percorsi e punteggi per generare il CSV e la tabella
for img_path, score in similar_images:
img = Image.open(img_path)
image_outputs.append(img)
scores.append([img_path.name, score])
# Salva il CSV con le immagini selezionate
df = pd.DataFrame(scores, columns=["filename", "similarity_score"])
df.to_csv("filtered_images.csv", index=False)
# Restituisci le immagini, il link al CSV e la tabella dei punteggi
download_link = gr.File(label="Download CSV of filtered images", value="filtered_images.csv")
return image_outputs, download_link, pd.DataFrame(scores, columns=["Image Name", "Similarity Score"])
if __name__ == "__main__":
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = load_openclip_model(device)
# Carica gli embeddings dal file CSV
embedding_file = "embeddings.csv" # Sostituisci con il percorso corretto
embeddings, image_paths = load_embeddings(embedding_file)
# Crea l'interfaccia Gradio
interface = gr.Interface(
fn=predict,
inputs=[
gr.Textbox(label="Enter your query text"),
gr.Slider(label="Number of Similar Images", minimum=3, maximum=20, step=1, value=3)
],
outputs=[
gr.Gallery(label="Similar Images", elem_id="image_gallery"),
gr.File(label="Download CSV"),
gr.Dataframe(label="Similarity Scores Table")
],
title="Find Similar Images",
description="Insert text to find similar images and choose the number of images to display."
)
interface.launch()