# Ask-FashionDB / search_fashionDB.py
import os

import numpy as np
import pandas as pd
import torch
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from transformers import CLIPModel, CLIPProcessor, pipeline
# Set device: use CUDA if available, otherwise MPS if available, otherwise CPU
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
# Load Fashion-CLIP model and processor
model_name = "patrickjohncyh/fashion-clip"
#model_name = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(model_name).to(device)
processor = CLIPProcessor.from_pretrained(model_name)
# Initialize segmentation pipeline
segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes", device = device)
def segment_clothing_white(img, clothes=["Background"]):
    """Whiten every segment whose label is in `clothes` (by default, the background)."""
    segments = segmenter(img)
    # Collect the masks of the segments to whiten
    mask_list = []
    for s in segments:
        if s['label'] in clothes:
            mask_list.append(s['mask'])
    if not mask_list:
        print("No matching segments found in image.")
        return img  # Return the original image if no segments are found
    # Combine all masks into a single mask
    final_mask = np.array(mask_list[0])
    for mask in mask_list[1:]:
        final_mask = np.maximum(final_mask, np.array(mask))  # Combine masks using max
    # Apply the mask to the image
    img_array = np.array(img)  # Convert image to numpy array
    final_mask = final_mask.astype(bool)  # Convert mask to boolean
    img_array[final_mask] = [255, 255, 255]  # Set masked (e.g. background) regions to white
    # Convert back to PIL image
    segmented_img = Image.fromarray(img_array)
    return segmented_img
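# Example (illustrative sketch, not part of the original repo): whiten the background of a
# garment photo before encoding it. The file paths are hypothetical, and the label names
# ("Background", "Upper-clothes", "Pants", "Dress", ...) are the ones the
# mattmdjaga/segformer_b2_clothes checkpoint is generally documented with; treat them as
# an assumption here.
def _example_segment_background():
    img = Image.open("example_look.jpg").convert("RGB")  # hypothetical input image
    # Whiten everything labelled "Background" so only the person/clothing remains visible
    clean = segment_clothing_white(img, clothes=["Background"])
    clean.save("example_look_white_bg.jpg")  # hypothetical output path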
def encode_image(image):
    """Encode an image into an L2-normalised Fashion-CLIP embedding."""
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        embedding = model.get_image_features(**inputs)
    # L2-normalise so cosine similarity reduces to a dot product
    embedding = embedding / torch.linalg.norm(embedding, ord=2, dim=-1, keepdim=True)
    return embedding.cpu().numpy().astype(np.float32).flatten()
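# Example (illustrative sketch): pre-compute the embedding matrix that the search functions
# below expect. The DataFrame columns "image_path" and "url" are assumptions about how the
# catalogue might be stored, not part of the original code.
def _example_build_embeddings(catalogue_df):
    embeddings, embeddings_urls = [], []
    for _, row in catalogue_df.iterrows():
        img = Image.open(row["image_path"]).convert("RGB")  # hypothetical column
        embeddings.append(encode_image(img))
        embeddings_urls.append(row["url"])
    # Stack the 1-D embeddings into an (n_images, embedding_dim) matrix
    return np.vstack(embeddings), embeddings_urls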
def search_images_by_image(uploaded_image, df, embeddings, embeddings_urls, top_k=30):
    # Convert the query to a PIL image
    if isinstance(uploaded_image, str):
        uploaded_image = Image.open(uploaded_image).convert("RGB")
    elif isinstance(uploaded_image, np.ndarray):
        uploaded_image = Image.fromarray(uploaded_image).convert("RGB")
    # Encode with Fashion-CLIP
    image_emb = encode_image(uploaded_image)
    # Similarity against ALL embeddings
    sims = cosine_similarity([image_emb], embeddings)[0]
    top_indices = np.argsort(sims)[::-1][:top_k]
    top_urls = [embeddings_urls[i] for i in top_indices]
    # Build metadata in the same order as top_urls
    df_subset = df[df["url"].isin(top_urls)].copy()
    records = df_subset.to_dict(orient="records")
    by_url = {}
    for r in records:
        u = r.get("url")
        if u is not None and u not in by_url:
            by_url[u] = r
    ordered_metadata = [by_url[u] for u in top_urls if u in by_url]
    return top_urls, ordered_metadata
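# Example (illustrative sketch): image-to-image search against a pre-computed embedding
# matrix. `df`, `embeddings` and `embeddings_urls` are assumed to come from something like
# _example_build_embeddings above; the query path and the "designer" metadata column are
# hypothetical.
def _example_image_search(df, embeddings, embeddings_urls):
    urls, metadata = search_images_by_image("query_photo.jpg", df, embeddings,
                                             embeddings_urls, top_k=5)
    for url, meta in zip(urls, metadata):
        print(url, meta.get("designer", ""))  # "designer" column is an assumption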
def search_images_by_text(text, df, embeddings, embeddings_urls, top_k=30):
    inputs = processor(text=[text], return_tensors="pt", padding=True, truncation=True, max_length=77).to(device)
    with torch.no_grad():
        text_emb = model.get_text_features(**inputs).cpu().numpy()
    # Build URL -> index map once per call
    url_to_index = {str(url): idx for idx, url in enumerate(embeddings_urls)}
    # Collect indices of embeddings corresponding to filtered df URLs
    filtered_urls = df["url"].astype(str).tolist()
    filtered_indices = [url_to_index[u] for u in filtered_urls if u in url_to_index]
    if not filtered_indices:
        return [], []
    embeddings_filtered = embeddings[filtered_indices]
    sims = cosine_similarity(text_emb, embeddings_filtered)[0]
    sims = np.asarray(sims).flatten()
    # Rank within the filtered set
    top_indices_local = np.argsort(sims)[::-1][:top_k]
    # Map local ranks back to URLs in the same order, dedupe while preserving order
    ranked_urls = [embeddings_urls[filtered_indices[i]] for i in top_indices_local]
    seen = set()
    top_urls = []
    for u in ranked_urls:
        if u not in seen:
            seen.add(u)
            top_urls.append(u)
    # Build metadata in the same order as top_urls
    df_subset = df[df["url"].isin(top_urls)].copy()
    records = df_subset.to_dict(orient="records")
    by_url = {}
    for r in records:
        u = r.get("url")
        if u is not None and u not in by_url:
            by_url[u] = r
    ordered_metadata = [by_url[u] for u in top_urls if u in by_url]
    return top_urls, ordered_metadata
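# Example (illustrative sketch): text-to-image search restricted to a filtered subset of the
# catalogue. The "season" column used for filtering and its value are assumptions about the
# DataFrame schema, not part of the original code.
def _example_text_search(df, embeddings, embeddings_urls):
    df_filtered = df[df["season"] == "spring 2020"]  # hypothetical metadata filter
    urls, metadata = search_images_by_text("red silk evening gown", df_filtered,
                                           embeddings, embeddings_urls, top_k=10)
    return urls, metadata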
def get_similar_images(df, image_key, embeddings, embedding_map, embeddings_urls, top_k=5):
    if image_key not in embedding_map:
        return [], []  # fallback: no match found, keep the (urls, metadata) return shape
    index = embedding_map[image_key]
    query_emb = embeddings[index]
    sims = cosine_similarity([query_emb], embeddings)[0]
    top_indices = np.argsort(sims)[::-1][1:top_k + 1]  # skip the query item itself
    top_urls = [embeddings_urls[i] for i in top_indices]
    metadata = df[df["url"].isin(top_urls)].copy().to_dict(orient="records")
    return top_urls, metadata
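# Example (illustrative sketch): "more like this" lookup for an item already in the catalogue.
# Building embedding_map as URL -> row index is an assumption, but it is consistent with how
# get_similar_images indexes into `embeddings`.
def _example_similar_images(df, embeddings, embeddings_urls):
    embedding_map = {url: i for i, url in enumerate(embeddings_urls)}
    query_url = embeddings_urls[0]  # any catalogue item
    urls, metadata = get_similar_images(df, query_url, embeddings, embedding_map,
                                        embeddings_urls, top_k=5)
    return urls, metadata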