# getaround-api / app.py
# Author: Dreipfelt — commit 5c0f28e ("Update app.py", verified)
import html
import json
import math
import os

import joblib
import pandas as pd
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
# =========================================================
# Load model + features
# =========================================================
# The model pipeline is loaded once, at import time, from "pipeline.pkl"
# (a path relative to the current working directory) and reused by every
# request. joblib.load unpickles the file, so it must come from a trusted
# source — never point this at user-supplied data.
try:
    pipeline = joblib.load("pipeline.pkl")
except Exception as e:
    # Fail fast at startup; chain the original error so the real cause
    # (missing file, version mismatch, ...) shows up as __cause__.
    raise RuntimeError(f"Model loading error: {e}") from e
# Load the ordered list of input column names from feature_names.json
# (with explicit error handling).
try:
    with open("feature_names.json", "r", encoding="utf-8") as f:
        feature_names = json.load(f)
except FileNotFoundError:
    # Missing file: fall back to a built-in default list of 13 features.
    # NOTE(review): this list must match the column order the pipeline was
    # fitted with — confirm it stays in sync with feature_names.json.
    feature_names = [
        "model_key", "mileage", "engine_power", "fuel", "paint_color", "car_type",
        "private_parking_available", "has_gps", "has_air_conditioning",
        "automatic_car", "has_getaround_connect", "has_speed_regulator", "winter_tires",
    ]
except Exception as e:
    # Unreadable file or invalid JSON: fail fast, keeping the original cause.
    raise RuntimeError(f"Feature loading error: {e}") from e

# A file that parses as JSON but is not a non-empty list of strings would
# otherwise silently become the DataFrame column labels used by /predict and
# fix a nonsensical expected row length, so reject it at startup.
if not (
    isinstance(feature_names, list)
    and feature_names
    and all(isinstance(name, str) for name in feature_names)
):
    raise RuntimeError(
        "Feature loading error: feature_names.json must contain "
        "a non-empty JSON list of strings"
    )

# Number of values every input row must contain.
N_FEATURES = len(feature_names)
# =========================================================
# FastAPI app
# =========================================================
# Constructor settings for the application. The built-in Swagger UI and ReDoc
# pages are switched off (docs_url / redoc_url set to None) because the app
# serves its own hand-written page at /documentation instead.
# NOTE(review): openapi_url is not overridden, so the raw OpenAPI schema
# presumably remains at FastAPI's default location.
_app_settings = {
    "title": "GetAround Pricing API",
    "description": "Predicts rental price per day",
    "version": "1.0.0",
    "docs_url": None,
    "redoc_url": None,
}
app = FastAPI(**_app_settings)
# =========================================================
# Input schema
# =========================================================
class PredictInput(BaseModel):
    # Request body for POST /predict: a batch of rows, one row per vehicle.
    # Each inner list holds raw feature values in the order of `feature_names`.
    # The element types are not constrained here (`list[list]`); only the row
    # length is checked, inside `predict`.
    input: list[list]
# =========================================================
# Root endpoint
# =========================================================
# Static landing page markup; it never varies, so it is built once at import.
_ROOT_HTML = """
<html>
<body style="font-family: Arial; text-align: center; padding: 50px;">
<h1>🚗 GetAround API</h1>
<a href="/documentation" style="font-size: 1.2em; color: #007BFF; text-decoration: none;">📄 Documentation</a>
</body>
</html>
"""


@app.get("/", response_class=HTMLResponse)
def root():
    # Serve the landing page, which only links to the /documentation page.
    return _ROOT_HTML
# =========================================================
# Prediction endpoint
# =========================================================
@app.post("/predict")
def predict(data: PredictInput):
    """
    Predict the rental price per day for one or more vehicles.

    Each row of `input` must contain exactly one value per feature, in the
    order listed on `/documentation`.

    **Request Body Example:**
    ```json
    {
      "input": [
        ["Citroën", 50000, 120, "diesel", "black", "sedan", 1, 1, 1, 0, 1, 1, 0]
      ]
    }
    ```

    **Errors:**
    - 400: empty `input`, or a row with the wrong number of values.
    - 500: the model failed on the input, or returned a non-finite value.
    """
    if not data.input:
        raise HTTPException(status_code=400, detail="Empty input")

    # Check the feature count: every row needs exactly N_FEATURES values.
    for i, row in enumerate(data.input):
        if len(row) != N_FEATURES:
            raise HTTPException(
                status_code=400,
                detail=f"Row {i} has {len(row)} features, expected {N_FEATURES}",
            )

    try:
        X = pd.DataFrame(data.input, columns=feature_names)
        # Converting inside the guarded region means an unexpected prediction
        # shape/type is also reported as a prediction error.
        prices = [float(p) for p in pipeline.predict(X)]
    except Exception as e:
        # Chain the cause so server logs keep the original traceback.
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}") from e

    # FastAPI's default JSON response refuses NaN/inf (allow_nan=False), which
    # would otherwise surface as an opaque serialization crash; report it.
    non_finite = [i for i, p in enumerate(prices) if not math.isfinite(p)]
    if non_finite:
        raise HTTPException(
            status_code=500,
            detail=f"Prediction error: non-finite prediction for row(s) {non_finite}",
        )

    return {"prediction": [round(p, 2) for p in prices]}
# =========================================================
# CUSTOM DOCUMENTATION (/documentation)
# =========================================================
@app.get("/documentation", response_class=HTMLResponse)
def documentation():
    """Serve the hand-written HTML documentation page for the API."""
    # One table row per feature, numbered from 1. Names are HTML-escaped so
    # characters such as "<" or "&" in feature_names.json cannot break the page.
    feature_rows = "".join(
        f"<tr><td>{i}</td><td>{html.escape(str(name))}</td></tr>"
        for i, name in enumerate(feature_names, start=1)
    )
    # NOTE: "\\" below is a literal backslash. A bare "\" at the end of a line
    # inside this (non-raw) f-string would be eaten as a Python line
    # continuation, so the shell line-continuations would never be shown.
    return f"""
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>GetAround API Documentation</title>
<style>
body {{
font-family: Arial;
background: #f5f7fa;
padding: 40px;
color: #333;
}}
h1 {{
text-align: center;
}}
.box {{
background: white;
padding: 20px;
border-radius: 10px;
margin-bottom: 20px;
box-shadow: 0 2px 10px rgba(0,0,0,0.05);
}}
table {{
width: 100%;
border-collapse: collapse;
}}
th, td {{
padding: 8px;
border-bottom: 1px solid #eee;
text-align: left;
}}
code {{
background: #eee;
padding: 2px 6px;
border-radius: 4px;
}}
pre {{
background: #f8f9fa;
padding: 12px;
border-radius: 8px;
overflow-x: auto;
}}
</style>
</head>
<body>
<h1>🚗 GetAround Pricing API</h1>
<div class="box">
<h2>📌 Endpoint</h2>
<p><code>POST /predict</code> → returns predicted rental price per day</p>
</div>
<div class="box">
<h2>📊 Features</h2>
<p>Each input row must contain exactly {len(feature_names)} values, in this order:</p>
<table>
<tr><th>#</th><th>Feature</th></tr>
{feature_rows}
</table>
</div>
<div class="box">
<h2>📥 Example request</h2>
<pre>
curl -X POST "https://dreipfelt-getaround-api.hf.space/predict" \\
  -H "Content-Type: application/json" \\
  -d '{{"input": [["Citroën", 150000, 120, "diesel", "black", "sedan", 1, 1, 1, 0, 1, 1, 0]]}}'
</pre>
</div>
<div class="box">
<h2>📤 Example response</h2>
<pre>{{"prediction": [89.5]}}</pre>
</div>
</body>
</html>
"""