# Dependencies: this script needs transformers (plus torch, scikit-learn,
# Pillow, and gradio). In a Hugging Face Space, list them in requirements.txt;
# the notebook-style `!pip install -qq transformers` line that was here is not
# valid Python in a plain app.py.
# Load the precomputed CLIP embeddings and image paths
import pickle

precomputed_filename = 'precomputed_clips'

def load_precomputed(precomputed_filename):
    with open(precomputed_filename + '.pickle', 'rb') as f:
        return pickle.load(f)

precomputed_dict = load_precomputed(precomputed_filename)
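# Expected pickle layout (inferred from how precomputed_dict is used below):
#   precomputed_dict['image_clips'] - np.ndarray of shape (n_images, 512)
#                                     with precomputed CLIP image embeddings
#   precomputed_dict['image_paths'] - list of n_images file paths to the
#                                     corresponding images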
# Embeddings and similar-picture search
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Load the CLIP model and processor once at module level, so that
# get_clip_embeddings() does not reload them on every call
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

def get_clip_embeddings(input_data, input_type='text'):
    # Prepare the input based on the type
    if input_type == 'text':
        inputs = processor(text=input_data, return_tensors="pt", padding=True, truncation=True)
    elif input_type == 'image':
        if isinstance(input_data, str):
            image = Image.open(input_data)
        elif isinstance(input_data, Image.Image):
            image = input_data
        else:
            raise ValueError("For image input, provide either a file path or a PIL Image object")
        inputs = processor(images=image, return_tensors="pt")
    else:
        raise ValueError("Invalid input_type. Choose 'text' or 'image'")
    # Get the embeddings (no gradients needed for inference)
    with torch.no_grad():
        if input_type == 'text':
            embeddings = model.get_text_features(**inputs)
        else:
            embeddings = model.get_image_features(**inputs)
    return embeddings.numpy()
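# Usage sketch: get_clip_embeddings("a photo of a dog") returns a NumPy array
# of shape (1, 512); passing a list of strings returns one row per string, so
# the result can be fed directly to cosine_similarity below.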
def find_similar_images(text_input, image_embeddings, all_images, take_best=4):
    # Get the CLIP embedding for the query text
    text_embedding = get_clip_embeddings(text_input, input_type='text')
    # Compute cosine similarity between the text and every image embedding
    similarities = cosine_similarity(text_embedding, image_embeddings)
    # Sort indices by similarity, descending, and keep the top `take_best`
    best_indices = np.argsort(similarities[0])[::-1][:take_best]
    # Select and open the best-matching images
    best_images = [all_images[i] for i in best_indices]
    return [Image.open(img) for img in best_images]
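# Usage sketch: find_similar_images("a sunset over the sea",
# precomputed_dict['image_clips'], precomputed_dict['image_paths'])
# returns a list of up to 4 PIL.Image objects, best match first.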
# Find the pictures most similar to the text entered in the UI
def find_most_similar(text_input):
    return find_similar_images(text_input, precomputed_dict['image_clips'], precomputed_dict['image_paths'])
# Gradio app
import gradio as gr  # Gradio provides the web interface

# Create the Gradio interface
interface = gr.Interface(
    fn=find_most_similar,
    inputs="text",
    outputs=gr.Gallery(label="Most Similar Images"),
    title="Find Similar Images with CLIP",
    description="Enter a text prompt to find the most similar images."
)
# Launch the app
interface.launch()
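# Note (assumption): on a Hugging Face Space the default launch() is enough;
# when running locally, interface.launch(share=True) would additionally
# create a temporary public link.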