clfegg's picture
Update handler.py
97be00a verified
raw
history blame
No virus
1.07 kB
from typing import Dict, List, Any
import pickle
import numpy as np
import pandas as pd
import os
class EndpointHandler:
    """Serve top-k book recommendations from a pickled model.

    The model is unpickled once, at construction time, from
    ``<path>/model.pkl``. It is assumed to expose ``predict(user_id, k=...)``
    returning an iterable of recommendations (NumPy array, pandas Series,
    list, ...) -- NOTE(review): confirm against the model that ships here.
    """

    def __init__(self, path=""):
        """Load the pickled model from ``path``.

        Args:
            path: Directory containing ``model.pkl``. The default ``""``
                resolves ``model.pkl`` relative to the current working
                directory.

        Raises:
            Any exception from ``open()`` (e.g. ``FileNotFoundError``) or
            ``pickle.load()`` propagates unchanged.
        """
        model_path = os.path.join(path, "model.pkl")
        # SECURITY: unpickling can execute arbitrary code embedded in the
        # file. Only load model.pkl from a trusted source (the artifact this
        # handler ships with), never from a user-supplied path.
        with open(model_path, "rb") as f:
            self.model = pickle.load(f)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Return recommendations for a single user.

        Args:
            data: Request payload. Recognised keys:
                ``user_id`` (str or int, required) -- the user to recommend for.
                ``k`` (int or integer string, optional) -- number of items to
                request from the model; defaults to 10 when absent or null.
                The dict is only read, never modified.

        Returns:
            A one-element list that is serialized as the response body:
            ``[{"recommended_books": [...]}]`` on success, or
            ``[{"error": "<message>"}]`` when the input is invalid or the
            model raises.
        """
        # Read (not pop) so the caller's payload is left intact.
        user_id = data.get("user_id")
        if user_id is None:
            return [{"error": "user_id is required"}]

        k = data.get("k")
        if k is None:
            k = 10  # default number of recommendations
        try:
            # JSON clients may send k as a string such as "5".
            k = int(k)
        except (TypeError, ValueError):
            return [{"error": "k must be an integer"}]

        try:
            recommended_books = self.model.predict(user_id, k=k)
            # NumPy arrays / pandas Series expose .tolist(), which also
            # converts NumPy scalars to plain Python types for JSON; fall
            # back to list() for plain Python iterables (lists, tuples, ...).
            if hasattr(recommended_books, "tolist"):
                books = recommended_books.tolist()
            else:
                books = list(recommended_books)
        except Exception as e:
            # Endpoint boundary: report model failures to the client as an
            # error payload instead of failing the whole request.
            return [{"error": str(e)}]
        return [{"recommended_books": books}]
def load_model(model_path):
    """Construct and return an ``EndpointHandler`` for ``model_path``.

    ``model_path`` is the directory that holds ``model.pkl`` (the handler
    joins that filename onto it); this is equivalent to calling
    ``EndpointHandler(model_path)`` directly.
    """
    return EndpointHandler(model_path)