Spaces:
Sleeping
Sleeping
# to install transformers | |
!pip install -qq transformers | |
# import and precomputed clips | |
import pickle | |
precomputed_filename = 'precomputed_clips' | |
def load_precomputed(precomputed_filename):
    """Load the object stored in ``<precomputed_filename>.pickle``.

    Args:
        precomputed_filename: Path of the pickle file *without* its
            ``.pickle`` suffix. A ``str`` or any object whose ``str()`` form
            is the path (for example ``pathlib.Path``).

    Returns:
        The object that was pickled into the file.

    Raises:
        OSError: If the file cannot be opened (e.g. ``FileNotFoundError``).
        pickle.UnpicklingError: If the file content is not a valid pickle.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only point this at files produced by a trusted source.
    pickle_path = f"{precomputed_filename}.pickle"
    with open(pickle_path, 'rb') as f:
        return pickle.load(f)
# Read the precomputed data once, at import time. find_most_similar later
# looks up its 'image_clips' and 'image_paths' entries.
precomputed_dict = load_precomputed(precomputed_filename=precomputed_filename)
# Embedding computation and similar-image search
import torch | |
from PIL import Image | |
from transformers import CLIPProcessor, CLIPModel | |
import os | |
import numpy as np | |
from sklearn.metrics.pairwise import cosine_similarity | |
# Checkpoint shared by the CLIP model and its processor.
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"

# Lazily filled cache holding a single ("pair" -> (model, processor)) entry.
_CLIP_CACHE = {}


def _load_clip():
    """Return the (model, processor) pair, loading it on first use only."""
    pair = _CLIP_CACHE.get("pair")
    if pair is None:
        pair = (
            CLIPModel.from_pretrained(_CLIP_CHECKPOINT),
            CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT),
        )
        # Stored as one tuple so a reader never sees a half-filled cache.
        _CLIP_CACHE["pair"] = pair
    return pair


def get_clip_embeddings(input_data, input_type='text'):
    """Compute CLIP embeddings for a text or an image input.

    The model and processor are loaded once and reused on later calls;
    reloading them per call made every query pay the full load cost.

    Args:
        input_data: For ``input_type='text'``, the value handed to the
            processor as ``text=``. For ``input_type='image'``, either a file
            path (``str``) or a ``PIL.Image.Image`` instance.
        input_type: ``'text'`` (default) or ``'image'``.

    Returns:
        The embeddings converted to a NumPy array via ``Tensor.numpy()``.

    Raises:
        ValueError: If ``input_type`` is neither ``'text'`` nor ``'image'``,
            or if an image input is neither a path nor a PIL image.
    """
    # Validate the arguments first so bad calls fail fast, before the
    # (expensive) model load is attempted.
    if input_type not in ('text', 'image'):
        raise ValueError("Invalid input_type. Choose 'text' or 'image'")
    if input_type == 'image' and not isinstance(input_data, (str, Image.Image)):
        raise ValueError("For image input, provide either a file path or a PIL Image object")

    model, processor = _load_clip()

    if input_type == 'text':
        inputs = processor(text=input_data, return_tensors="pt", padding=True, truncation=True)
        extract_features = model.get_text_features
    else:
        # Only close images opened here; a caller-supplied image stays theirs.
        opened_here = isinstance(input_data, str)
        image = Image.open(input_data) if opened_here else input_data
        try:
            inputs = processor(images=image, return_tensors="pt")
        finally:
            if opened_here:
                image.close()
        extract_features = model.get_image_features

    # Inference only: no gradients are needed.
    with torch.no_grad():
        embeddings = extract_features(**inputs)
    return embeddings.numpy()
def find_similar_images(text_input, image_embeddings, all_images, take_best=4):
    """Return the images whose embeddings are most similar to a text prompt.

    Args:
        text_input: Text prompt; embedded with ``get_clip_embeddings``.
        image_embeddings: Image embeddings compared against the text
            embedding with cosine similarity (presumably one row per image,
            in the same order as ``all_images`` — confirm with the producer
            of the precomputed data).
        all_images: Indexable collection of image locations accepted by
            ``PIL.Image.open``; indexed with the positions of the best rows.
        take_best: Maximum number of images to return (default 4). ``None``
            returns every image; zero or a negative value returns ``[]``.

    Returns:
        A list of ``PIL.Image.Image`` objects, most similar first.
    """
    # A negative count would turn into an "all but the last k" slice, which
    # is never what a caller means; treat non-positive counts as "nothing".
    if take_best is not None and take_best <= 0:
        return []

    # Embed the text prompt.
    text_embedding = get_clip_embeddings(text_input, input_type='text')

    # Cosine similarity between the text and every image embedding.
    similarities = cosine_similarity(text_embedding, image_embeddings)

    # Indices sorted by descending similarity, truncated to the requested count.
    best_indices = np.argsort(similarities[0])[::-1][:take_best]

    best_images = []
    for index in best_indices:
        picture = Image.open(all_images[index])
        # Image.open is lazy; decoding now lets Pillow release the underlying
        # file handle instead of keeping it open for the image's lifetime.
        picture.load()
        best_images.append(picture)
    return best_images
def find_most_similar(text_input):
    """Find the precomputed images that best match ``text_input``.

    Reads the 'image_clips' and 'image_paths' entries of the module-level
    ``precomputed_dict`` and delegates to ``find_similar_images`` with its
    default result count.
    """
    clip_vectors = precomputed_dict['image_clips']
    image_files = precomputed_dict['image_paths']
    return find_similar_images(text_input, clip_vectors, image_files)
# Gradio web UI
import gradio as gr  # toolkit used to build the web interface

# Output component: a gallery that displays the returned images.
_gallery = gr.Gallery(label="Most Similar Images")

# Build the interface: one text input, passed to find_most_similar, whose
# result is shown in the gallery.
interface = gr.Interface(
    fn=find_most_similar,
    inputs="text",
    outputs=_gallery,
    title="Find Similar Images with CLIP",
    description="Enter a text prompt to find the most similar images.",
)

# Start the app.
interface.launch()