# Sensei / predict_impulse.py
# Author: obx0x3 — last commit: "Update predict_impulse.py" (aeda93a, verified)
import joblib
import pandas as pd
import os
from huggingface_hub import hf_hub_download
from opik import Opik, track
# -------------------------------------------------
# Opik API key (HF Spaces / Expo safe)
# -------------------------------------------------
def ensure_opik_api_key(env=None):
    """Populate ``OPIK_API_KEY`` from ``EXPO_PUBLIC_OPIK_API_KEY`` if unset.

    An already-present ``OPIK_API_KEY`` always wins. The Expo-style fallback
    is copied only when it is non-empty: writing an empty string would make
    the key look configured (``"OPIK_API_KEY" in env`` becomes True) even
    though no credential is actually available.

    Args:
        env: Mutable mapping to read and update. Defaults to ``os.environ``.
    """
    if env is None:
        env = os.environ
    if "OPIK_API_KEY" in env:
        return
    fallback = env.get("EXPO_PUBLIC_OPIK_API_KEY", "")
    if fallback:
        env["OPIK_API_KEY"] = fallback


ensure_opik_api_key()
# -------------------------------------------------
# Load model from HF Model Hub
# -------------------------------------------------
MODEL_REPO = "obx0x3/sensei-model"
MODEL_FILE = "impulse_model.pkl"

# Download the serialized classifier from the Hub and keep its local path.
model_path = hf_hub_download(filename=MODEL_FILE, repo_id=MODEL_REPO)

# SECURITY NOTE(review): joblib.load unpickles the file, and unpickling can
# execute arbitrary code. Only load from a repository you control; consider
# pinning `revision=` to a commit hash in hf_hub_download above.
impulse_model = joblib.load(model_path)
# -------------------------------------------------
# Opik client (event logger)
# -------------------------------------------------
# Best-effort: stays None when the client cannot be constructed, which
# disables event logging in predict_impulse.
opik_client = None
try:
    opik_client = Opik(project_name="budgetbuddy-hackathon")
except Exception as exc:
    print("Opik disabled:", exc)
# -------------------------------------------------
# TRACKED FUNCTION (this creates a TRACK)
# -------------------------------------------------
@track(project_name="budgetbuddy-hackathon")
def predict_impulse(category, amount, payment_method, day):
    """Classify a single transaction as impulsive or normal spending.

    Args:
        category: Spending category, passed to the model unchanged.
        amount: Transaction amount; coerced with ``float()``.
        payment_method: Payment method, passed to the model unchanged.
        day: Day value, passed to the model unchanged (its expected format
            depends on how the model was trained -- TODO confirm).

    Returns:
        dict with keys ``"impulsive"`` (bool), ``"confidence"`` (the highest
        class probability, rounded to 3 decimals) and ``"label"``
        (``"Impulsive"`` or ``"Normal Spend"``).

    Raises:
        ValueError, TypeError: if ``amount`` cannot be converted to float.
    """
    input_data = {
        "category": category,
        "amount": float(amount),
        "payment_method": payment_method,
        "day": day,
    }
    df = pd.DataFrame([input_data])

    pred = impulse_model.predict(df)[0]
    # Probability of the most likely class, reported as the confidence.
    prob = float(impulse_model.predict_proba(df)[0].max())
    impulsive = bool(pred)

    result = {
        "impulsive": impulsive,
        "confidence": round(prob, 3),
        "label": "Impulsive" if impulsive else "Normal Spend",
    }

    # -------------------------------------------------
    # Attach analysis metadata to the span opened by @track.
    # NOTE(review): the Opik client does not provide a `log_event` method;
    # inside a @track-decorated function the supported way to add data is
    # opik_context.update_current_span. @track already records this
    # function's arguments and return value, so only metadata is added.
    # -------------------------------------------------
    if opik_client:
        try:
            from opik import opik_context

            opik_context.update_current_span(
                metadata={
                    "event": "Impulse Analysis Result",
                    "model": "sensei-impulse-model",
                    "ui": "hf-space",
                    "feature": "impulse-detection",
                    "confidence_band": (
                        "high" if prob > 0.75 else
                        "medium" if prob > 0.5 else
                        "low"
                    ),
                }
            )
        except Exception as e:
            # Logging is best-effort; never fail the prediction over it.
            print("Opik logging failed:", e)

    return result