import os

# Point the Hugging Face caches at a local, writable directory
# (avoids permission errors on hosted platforms). These must be set
# before transformers / sentence_transformers are imported, because the
# libraries read them at import time.
os.environ["HF_HOME"] = "./cache"
os.environ["TRANSFORMERS_CACHE"] = "./cache"
os.environ["SENTENCE_TRANSFORMERS_HOME"] = "./cache"

import io
import json

import faiss
import numpy as np
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from sentence_transformers import SentenceTransformer

app = FastAPI()

# Enable CORS so the browser frontend can call the API from another origin.
# Wildcard origins are convenient for a demo; restrict them in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load product metadata
with open("id_mapping.json", "r", encoding="utf-8") as f:
    products = json.load(f)

# Load FAISS index
index = faiss.read_index("products.index")
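
# Sanity check (an assumption made explicit): the lookups below require
# id_mapping.json to be a list whose positions line up row-for-row with
# the vectors stored in products.index.
if index.ntotal != len(products):
    raise RuntimeError(
        f"Index/metadata mismatch: {index.ntotal} vectors vs {len(products)} products"
    )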

# Load the CLIP model (encodes both text and images into the same vector space)
print("🧠 Loading CLIP model...")
model = SentenceTransformer("sentence-transformers/clip-ViT-B-32", cache_folder="./cache")


@app.get("/")
def root():
    return {"message": "🚀 Visual Product Matcher API is running!"}


@app.post("/search_text")
def search_text(query: str = Form(...), top_k: int = 5):
    """
    Search products using text query.
    """
    query_emb = model.encode([query], convert_to_numpy=True)
    distances, indices = index.search(query_emb, top_k)
    results = [products[i] for i in indices[0]]
    return {"query": query, "results": results}


@app.post("/search_image")
async def search_image(file: UploadFile = File(...), top_k: int = 5):
    """
    Search products using image query.
    """
    image_bytes = await file.read()
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    image_emb = model.encode([image], convert_to_numpy=True)
    distances, indices = index.search(image_emb, top_k)
    results = [products[i] for i in indices[0]]
    return {"results": results}