Spaces:
Sleeping
Sleeping
from fastapi import FastAPI, Query | |
from pydantic import BaseModel | |
from sentence_transformers import SentenceTransformer, util | |
from datasets import load_dataset | |
from typing import List | |
import numpy as np | |
import base64 | |
from PIL import Image | |
from io import BytesIO | |
# Application object for the product search service.
app = FastAPI()
def root():
    """Return the static welcome message as a JSON-serializable dict.

    NOTE(review): no route decorator is visible on this function in this
    view -- confirm that it is registered on ``app``.
    """
    welcome = {"message": "Welcome to the Product Search API!"}
    return welcome
def encode_image_to_base64(image):
    """Serialize a PIL image to a Base64 text string.

    The image is written to an in-memory buffer as PNG; the resulting bytes
    are Base64-encoded and decoded to a UTF-8 ``str``.

    Returns ``None`` when ``image`` is not a ``PIL.Image.Image`` instance.
    """
    # Guard clause: anything that is not a PIL image cannot be encoded.
    if not isinstance(image, Image.Image):
        return None

    png_buffer = BytesIO()
    image.save(png_buffer, format="PNG")
    png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode("utf-8")
# Load the training split of the fashion product dataset.
dataset = load_dataset("ashraq/fashion-product-images-small", split="train")

# Text fields whose values are concatenated to form each product's
# embedding input (see create_combined_text).
fields_for_embedding = [
    "productDisplayName",
    "usage",
    "season",
    "baseColour",
    "articleType",
    "subCategory",
    "masterCategory",
    "gender",
]

# Copy each dataset row into a plain dict holding only the fields that are
# served: the embedding fields followed by "year" and "image".
data = []
for item in dataset:
    data.append(
        {key: item[key] for key in (*fields_for_embedding, "year", "image")}
    )

# Sentence-embedding model shared by the product texts and incoming queries.
model = SentenceTransformer("sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
# Generate Embeddings | |
def create_combined_text(item):
    """Build the text that is embedded for one product record.

    The ``fields_for_embedding`` keys are looked up in ``item`` in order;
    falsy values are skipped and the remaining ones are converted with
    ``str`` and joined with single spaces.
    """
    parts = []
    for field in fields_for_embedding:
        value = item[field]
        if value:
            parts.append(str(value))
    return " ".join(parts)
# Embed the combined text of every product once, when the module is loaded.
texts = list(map(create_combined_text, data))
embeddings = model.encode(texts, convert_to_tensor=True)
# Response Model | |
class ProductResponse(BaseModel):
    """Schema for a single product record.

    The field names mirror the keys of the per-product dicts built by
    ``search_products``.

    NOTE(review): this model is not referenced anywhere else in the visible
    code -- confirm where it is meant to be applied.
    """

    # Descriptive product attributes.
    productDisplayName: str
    usage: str
    season: str
    baseColour: str
    articleType: str
    subCategory: str
    masterCategory: str
    gender: str
    year: int
    # Product picture as a Base64-encoded string.
    image: str
def search_products(
    query: str = Query("", title="Search Query", description="Search term for products"),
    page: int = Query(1, ge=1, title="Page Number"),
    items_per_page: int = Query(10, ge=1, le=100, title="Items Per Page"),
):
    """Rank the products against ``query`` and return one page of results.

    When ``query`` is non-empty it is encoded with ``model`` and compared by
    cosine similarity against the precomputed ``embeddings``; products are
    ordered from highest to lowest score. When ``query`` is empty the
    products keep their order in ``data``.

    Args:
        query: Search term; an empty string disables ranking.
        page: 1-based page number (declared ``ge=1``).
        items_per_page: Page size (declared ``ge=1``, ``le=100``).

    Returns:
        A dict with the keys ``"status"`` (always 200), ``"data"`` (list of
        product dicts with the image Base64-encoded), ``"totalpages"`` and
        ``"currentpage"``. A page past the last one yields an empty
        ``"data"`` list.

    NOTE(review): no route decorator is visible on this function in this
    view, and the ``Query(...)`` defaults are only resolved to plain values
    when FastAPI invokes the function -- confirm how it is registered.
    """
    if not query:
        # No search term: keep the order of ``data``.
        ranked_indices = np.arange(len(data))
    else:
        query_embedding = model.encode(query, convert_to_tensor=True)
        similarity = util.cos_sim(query_embedding, embeddings)
        scores = similarity.squeeze().tolist()
        # argsort is ascending, so reverse it to put the best match first.
        ranked_indices = np.argsort(scores)[::-1]

    # Ceiling division: a partially filled last page still counts as a page.
    whole_pages, remainder = divmod(len(ranked_indices), items_per_page)
    total_pages = whole_pages + (1 if remainder else 0)

    offset = (page - 1) * items_per_page
    page_indices = ranked_indices[offset:offset + items_per_page]

    # Fields copied through unchanged; the image is re-encoded separately.
    passthrough_fields = (
        "productDisplayName",
        "usage",
        "season",
        "baseColour",
        "articleType",
        "subCategory",
        "masterCategory",
        "gender",
        "year",
    )

    results = []
    for idx in page_indices:
        item = data[idx]
        entry = {name: item[name] for name in passthrough_fields}
        entry["image"] = encode_image_to_base64(item["image"])
        results.append(entry)

    return {
        "status": 200,
        "data": results,
        "totalpages": total_pages,
        "currentpage": page,
    }