# prodrec/main.py: live inference API
# Ishaan Shah, commit 6e7c9db
import uuid
from typing import Dict, List

import joblib
from fastapi import BackgroundTasks, FastAPI, HTTPException
from pydantic import BaseModel
# Load the trained clustering model and vectorizer from disk. The attribute
# access below assumes model.pkl is a fitted scikit-learn KMeans (it exposes
# cluster_centers_) and vectorizer.pkl a fitted text vectorizer.
model = joblib.load("./model.pkl")
vectorizer = joblib.load("./vectorizer.pkl")

# Rank each cluster's feature indices by centroid weight, highest first,
# so the top terms per cluster can be looked up cheaply.
order_centroids = model.cluster_centers_.argsort()[:, ::-1]
terms = vectorizer.get_feature_names_out()
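
# A rough sketch of how the artifacts above might have been produced; this is
# an assumption for illustration, the training code is not part of this file:
#
#   from sklearn.cluster import KMeans
#   from sklearn.feature_extraction.text import TfidfVectorizer
#
#   vectorizer = TfidfVectorizer(stop_words="english")
#   X = vectorizer.fit_transform(product_descriptions)  # hypothetical corpus
#   model = KMeans(n_clusters=10, random_state=42).fit(X)  # n_clusters assumed
#   joblib.dump(model, "./model.pkl")
#   joblib.dump(vectorizer, "./vectorizer.pkl")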
# Predict which cluster a product description falls into.
def show_recommendations(product):
    Y = vectorizer.transform([product])
    prediction = model.predict(Y)
    return int(prediction[0])  # cast numpy integer to a native Python int
# Get the ten highest-weighted terms associated with a cluster.
def get_cluster_terms(cluster_index):
    return [terms[ind] for ind in order_centroids[cluster_index, :10]]
app = FastAPI()
# In-memory store for inference batches (process-local; cleared on restart)
inferences: Dict[str, Dict] = {}
class BatchRequest(BaseModel):
    products: List[str]


class BatchResponse(BaseModel):
    inferenceId: str


class BatchInferenceResponse(BaseModel):
    inferenceId: str
    status: str
    results: List[Dict]
def process_batch(inferenceId, products):
    # Runs as a FastAPI background task after the response has been sent.
    results = []
    for product in products:
        cluster_index = show_recommendations(product)
        cluster_terms = get_cluster_terms(cluster_index)
        results.append({"product": product, "cluster": cluster_index, "top_terms": cluster_terms})
    inferences[inferenceId]["status"] = "completed"
    inferences[inferenceId]["result"] = results
@app.post("/inference/batch", response_model=BatchResponse)
async def start_get_recommendations_batch(batch_request: BatchRequest, background_tasks: BackgroundTasks):
    # Register the batch, kick off processing in the background, and return
    # the ID immediately so the client can poll for results.
    inferenceId = str(uuid.uuid4())
    inferences[inferenceId] = {"status": "in_progress", "result": []}
    background_tasks.add_task(process_batch, inferenceId, batch_request.products)
    return BatchResponse(inferenceId=inferenceId)
@app.get("/inference/batch/{inferenceId}", response_model=BatchInferenceResponse)
async def get_recommendations_batch(inferenceId: str):
    # Poll the status and, once completed, the results of a batch.
    if inferenceId not in inferences:
        raise HTTPException(status_code=404, detail="Inference ID not found")
    inference = inferences[inferenceId]
    return BatchInferenceResponse(inferenceId=inferenceId, status=inference["status"], results=inference["result"])
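
# Example polling response once a batch has finished (shape only; the cluster
# indices and terms depend on the trained model):
#   {"inferenceId": "<uuid>", "status": "completed",
#    "results": [{"product": "usb hub", "cluster": 3,
#                 "top_terms": ["usb", "hub", "..."]}]}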
class InferenceRequest(BaseModel):
    product: str


class InferenceResponse(BaseModel):
    cluster: int
    top_terms: List[str]
# Single-product inference endpoint (a plain `def` handler; FastAPI runs it
# in its threadpool, which is fine for this short CPU-bound call)
@app.post("/inference", response_model=InferenceResponse)
def get_recommendations(inference_request: InferenceRequest):
    cluster_index = show_recommendations(inference_request.product)
    cluster_terms = get_cluster_terms(cluster_index)
    return InferenceResponse(cluster=cluster_index, top_terms=cluster_terms)
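
# A minimal sketch for running the service locally, assuming uvicorn is
# installed (it is pulled in by fastapi's "standard" extra). Example requests,
# matching the Pydantic models above:
#
#   curl -X POST localhost:8000/inference \
#        -H "Content-Type: application/json" \
#        -d '{"product": "wireless mouse"}'
#
#   curl -X POST localhost:8000/inference/batch \
#        -H "Content-Type: application/json" \
#        -d '{"products": ["usb hub", "hdmi cable"]}'
#
#   curl localhost:8000/inference/batch/<inferenceId>
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)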