# search-demo/utils/embedding_generation.py
import torch
from PIL import Image
from utils.load_models import fclip_model, fclip_processor
from utils.load_models import siglip_model, siglip_preprocess_train, siglip_preprocess_val, siglip_tokenizer

def get_info(catalog, column):
    # Collect the image path and text description for every row in the catalog
    image_paths = []
    text_descriptions = []
    for index, row in catalog.iterrows():
        path = "/home/user/app/images/" + str(row["Id"]) + ".jpg"
        image_paths.append(path)
        text_descriptions.append(row[column])
    return image_paths, text_descriptions
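
# A minimal usage sketch for get_info, assuming the catalog is a pandas
# DataFrame with an "Id" column; the file name and the "description" column
# below are hypothetical placeholders.
#
#   import pandas as pd
#   catalog = pd.read_csv("catalog.csv")
#   image_paths, text_descriptions = get_info(catalog, "description")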

def normalize_embedding(embedding):
    # L2-normalize a single embedding along its last dimension
    norm = torch.norm(embedding, p=2, dim=-1, keepdim=True)
    embedding = embedding / norm
    return embedding.detach().cpu().numpy()

def normalize_embeddings(embeddings):
    # L2-normalize a batch of embeddings along the last dimension
    norm = torch.norm(embeddings, p=2, dim=-1, keepdim=True)
    normalized_embeddings = embeddings / norm
    return normalized_embeddings

def generate_fclip_embeddings(image_paths, texts, batch_size, alpha):
    image_embeds_list = []
    text_embeds_list = []

    # Batch processing loop
    for i in range(0, len(image_paths), batch_size):
        batch_image_paths = image_paths[i:i + batch_size]
        batch_texts = texts[i:i + batch_size]

        # Load and preprocess a batch of images and texts
        images = [Image.open(path).convert("RGB") for path in batch_image_paths]

        # Cap the sequence length at 77 tokens to match the position embeddings
        inputs = fclip_processor(text=batch_texts, images=images, return_tensors="pt", padding=True, truncation=True, max_length=77)

        # Move inputs to the GPU if one is available
        if torch.cuda.is_available():
            inputs = {k: v.to("cuda") for k, v in inputs.items()}

        # Generate embeddings
        with torch.no_grad():
            outputs = fclip_model(**inputs)
            image_embeds_list.append(outputs.image_embeds)
            text_embeds_list.append(outputs.text_embeds)

    # Concatenate all batch embeddings
    image_embeds = torch.cat(image_embeds_list, dim=0)
    text_embeds = torch.cat(text_embeds_list, dim=0)

    # Normalize embeddings
    image_embeds = normalize_embeddings(image_embeds)
    text_embeds = normalize_embeddings(text_embeds)

    # Combine the modalities: plain average and alpha-weighted average
    avg_embeds = (image_embeds + text_embeds) / 2
    weighted_avg_embeds = alpha * image_embeds + (1 - alpha) * text_embeds
    avg_embeds = normalize_embeddings(avg_embeds)
    weighted_avg_embeds = normalize_embeddings(weighted_avg_embeds)

    return image_embeds.cpu().numpy(), text_embeds.cpu().numpy(), avg_embeds.cpu().numpy(), weighted_avg_embeds.cpu().numpy()
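
# A minimal usage sketch for the batch generator above. The catalog file and
# "description" column are hypothetical, and batch_size / alpha are example
# values, not tuned settings.
#
#   import pandas as pd
#   catalog = pd.read_csv("catalog.csv")
#   image_paths, texts = get_info(catalog, "description")
#   img, txt, avg, weighted = generate_fclip_embeddings(image_paths, texts, batch_size=32, alpha=0.5)
#   # Each returned array has shape (len(image_paths), embedding_dim).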

def generate_siglip_embeddings(image_paths, texts, batch_size, alpha):
    image_embeds_list = []
    text_embeds_list = []

    # Batch processing loop
    for i in range(0, len(image_paths), batch_size):
        batch_image_paths = image_paths[i:i + batch_size]
        batch_texts = texts[i:i + batch_size]

        # Load and preprocess a batch of images and texts
        images = [siglip_preprocess_val(Image.open(image_path).convert("RGB")).unsqueeze(0) for image_path in batch_image_paths]
        images = torch.cat(images)
        tokens = siglip_tokenizer(batch_texts)

        # Move inputs to the same device as the model weights (GPU if available)
        if torch.cuda.is_available():
            images = images.cuda()
            tokens = tokens.cuda()

        # Generate embeddings
        with torch.no_grad():
            image_embeddings_batch = siglip_model.encode_image(images)
            text_embeddings_batch = siglip_model.encode_text(tokens)

        # Store batch embeddings
        image_embeds_list.append(image_embeddings_batch)
        text_embeds_list.append(text_embeddings_batch)

    # Concatenate all batch embeddings
    image_embeds = torch.cat(image_embeds_list, dim=0)
    text_embeds = torch.cat(text_embeds_list, dim=0)

    # Normalize embeddings
    image_embeds = normalize_embeddings(image_embeds)
    text_embeds = normalize_embeddings(text_embeds)

    # Combine the modalities: plain average and alpha-weighted average
    avg_embeds = (image_embeds + text_embeds) / 2
    weighted_avg_embeds = alpha * image_embeds + (1 - alpha) * text_embeds
    avg_embeds = normalize_embeddings(avg_embeds)
    weighted_avg_embeds = normalize_embeddings(weighted_avg_embeds)

    return image_embeds.cpu().numpy(), text_embeds.cpu().numpy(), avg_embeds.cpu().numpy(), weighted_avg_embeds.cpu().numpy()
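
# The SigLIP variant is called the same way; only the preprocessing and model
# differ. A sketch, assuming the same hypothetical catalog as above:
#
#   img, txt, avg, weighted = generate_siglip_embeddings(image_paths, texts, batch_size=32, alpha=0.5)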

# Generate a text embedding for either model
def generate_text_embedding(model, tokenizer, query, model_type):
    if model_type == "fashionCLIP":
        # Tokenize the text and move it to the GPU if one is available
        inputs = tokenizer(text=query, return_tensors="pt", padding=True, truncation=True, max_length=77)
        if torch.cuda.is_available():
            inputs = {k: v.to("cuda") for k, v in inputs.items()}

        # Get the text embedding from the model
        with torch.no_grad():
            text_embed = model.get_text_features(**inputs)
    elif model_type == "fashionSigLIP":
        # Tokenize the text and move it to the GPU if one is available
        tokens = tokenizer(query)
        if torch.cuda.is_available():
            tokens = tokens.to("cuda")

        # Get the text embedding from the model
        with torch.no_grad():
            text_embed = model.encode_text(tokens)
    else:
        raise ValueError("Invalid model type. Choose 'fashionCLIP' or 'fashionSigLIP'.")

    return normalize_embedding(text_embed)

# Generate an image embedding for either model
def generate_image_embedding(model, processor, image_path, model_type):
    image = Image.open(image_path).convert("RGB")

    if model_type == "fashionCLIP":
        # Preprocess the image for FashionCLIP and move it to the GPU if available
        inputs = processor(images=image, return_tensors="pt")
        if torch.cuda.is_available():
            inputs = {k: v.to("cuda") for k, v in inputs.items()}

        # Get the image embedding from the model
        with torch.no_grad():
            image_embed = model.get_image_features(**inputs)
    elif model_type == "fashionSigLIP":
        # Preprocess the image for SigLIP and move it to the GPU if available
        image_tensor = processor(image).unsqueeze(0)
        if torch.cuda.is_available():
            image_tensor = image_tensor.to("cuda")

        # Get the image embedding from the model
        with torch.no_grad():
            image_embed = model.encode_image(image_tensor)
    else:
        raise ValueError("Invalid model type. Choose 'fashionCLIP' or 'fashionSigLIP'.")

    return normalize_embedding(image_embed)

# Unified entry point that generates an embedding for either model and query type
def generate_query_embedding(query, query_type, model, processor, tokenizer, model_type):
    if query_type == "text":
        return generate_text_embedding(model, tokenizer, query, model_type)
    elif query_type == "image":
        return generate_image_embedding(model, processor, query, model_type)
    else:
        raise ValueError("Invalid query type. Choose 'text' or 'image'.")
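
# A minimal smoke test for the unified query interface, guarded so it only runs
# when the module is executed directly. The query string and image path are
# hypothetical placeholders; the models, processors, and tokenizers come from
# utils.load_models as imported at the top. For FashionCLIP, the processor is
# assumed to double as the tokenizer (it accepts text=... as used above).
if __name__ == "__main__":
    # Text query through FashionCLIP
    text_vec = generate_query_embedding("red floral dress", "text", fclip_model, fclip_processor, fclip_processor, "fashionCLIP")
    print("FashionCLIP text embedding:", text_vec.shape)

    # Image query through FashionSigLIP (path is a placeholder)
    image_vec = generate_query_embedding("/home/user/app/images/123.jpg", "image", siglip_model, siglip_preprocess_val, siglip_tokenizer, "fashionSigLIP")
    print("FashionSigLIP image embedding:", image_vec.shape)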