Spaces:
Running
Running
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import HTMLResponse | |
| from pydantic import BaseModel | |
| import joblib | |
| import json | |
| import pandas as pd | |
| import os | |
| # ========================================================= | |
| # Load model + features | |
| # ========================================================= | |
# Load the fitted prediction pipeline once, at import time.
# NOTE(review): joblib.load deserializes a pickle-based file; "pipeline.pkl"
# must come from a trusted source.
try:
    pipeline = joblib.load("pipeline.pkl")
except Exception as e:
    # Abort the import so the service never starts without a usable model;
    # chain the original error so the traceback keeps the root cause.
    raise RuntimeError(f"Model loading error: {e}") from e

# Load the ordered feature names used as DataFrame columns in predict().
try:
    with open("feature_names.json", "r", encoding="utf-8") as f:
        feature_names = json.load(f)
except FileNotFoundError:
    # The file is missing: fall back to a built-in list of 13 features.
    # NOTE(review): this fallback is silent; confirm that the list below
    # matches the columns the pipeline was trained on.
    feature_names = [
        "model_key", "mileage", "engine_power", "fuel", "paint_color", "car_type",
        "private_parking_available", "has_gps", "has_air_conditioning",
        "automatic_car", "has_getaround_connect", "has_speed_regulator", "winter_tires",
    ]
except Exception as e:
    # Any other failure (unreadable file, invalid JSON, ...) is fatal.
    raise RuntimeError(f"Feature loading error: {e}") from e

# Number of values every input row must contain.
N_FEATURES = len(feature_names)
| # ========================================================= | |
| # FastAPI app | |
| # ========================================================= | |
# The default Swagger UI and ReDoc pages are disabled (docs_url / redoc_url
# set to None); the landing page links to /documentation instead.
app = FastAPI(
    title="GetAround Pricing API",
    version="1.0.0",
    description="Predicts rental price per day",
    redoc_url=None,
    docs_url=None,
)
| # ========================================================= | |
| # Input schema | |
| # ========================================================= | |
class PredictInput(BaseModel):
    """Request body of the prediction endpoint: a batch of feature rows."""

    # One inner list per row to score; predict() checks that each row holds
    # exactly N_FEATURES values before building the DataFrame.
    input: list[list]
| # ========================================================= | |
| # Root endpoint | |
| # ========================================================= | |
@app.get("/", response_class=HTMLResponse)
def root():
    """Serve the landing page: a title and a link to /documentation.

    The function is registered on the app as ``GET /`` and declared with
    ``response_class=HTMLResponse`` so the returned string is sent as HTML
    rather than JSON-encoded.

    Returns:
        The HTML markup of the landing page as a string.
    """
    return """
    <html>
    <body style="font-family: Arial; text-align: center; padding: 50px;">
    <h1>🚗 GetAround API</h1>
    <a href="/documentation" style="font-size: 1.2em; color: #007BFF; text-decoration: none;">📄 Documentation</a>
    </body>
    </html>
    """
| # ========================================================= | |
| # Prediction endpoint | |
| # ========================================================= | |
@app.post("/predict")
def predict(data: PredictInput):
    """
    Predict the rental price per day for a vehicle.

    Every row of `input` must contain exactly one value per expected
    feature, in the order of the loaded feature names. An empty batch or a
    row with the wrong number of values is rejected with HTTP 400; any
    failure while building the DataFrame or running the model is reported
    as HTTP 500. The response is `{"prediction": [...]}` with one price per
    row, rounded to 2 decimals.

    **Request Body Example:**
    ```json
    {
      "input": [
        ["Citroën", 50000, 120, "diesel", "black", "sedan", 1, 1, 1, 0, 1, 1, 0]
      ]
    }
    ```
    """
    rows = data.input
    if not rows:
        raise HTTPException(status_code=400, detail="Empty input")

    # Check the number of features; the whole batch is rejected on the
    # first row whose length differs from N_FEATURES.
    for i, row in enumerate(rows):
        if len(row) != N_FEATURES:
            raise HTTPException(
                status_code=400,
                detail=f"Row {i} has {len(row)} features, expected {N_FEATURES}",
            )

    try:
        X = pd.DataFrame(rows, columns=feature_names)
        preds = pipeline.predict(X)
    except Exception as e:
        # API boundary: translate any model/pandas failure into a 500 and
        # keep the original exception chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=f"Prediction error: {e}") from e

    # Convert each prediction to a plain float before rounding.
    return {"prediction": [round(float(p), 2) for p in preds]}
| # ========================================================= | |
| # CUSTOM DOCUMENTATION (/documentation) | |
| # ========================================================= | |
@app.get("/documentation", response_class=HTMLResponse)
def documentation():
    """Serve the hand-written HTML documentation page.

    The page describes the ``POST /predict`` endpoint, lists the loaded
    feature names in order (numbered from 1), and shows an example request
    and response. The function is registered as ``GET /documentation`` with
    ``response_class=HTMLResponse`` so the string is sent as HTML.

    Returns:
        The full HTML document as a string.
    """
    # Function-scope import so this block brings its own dependency.
    from html import escape

    # Build the rows of the feature table; names are HTML-escaped so that
    # characters such as "<" or "&" cannot break the markup.
    feature_rows = "".join(
        f"<tr><td>{position}</td><td>{escape(str(name))}</td></tr>"
        for position, name in enumerate(feature_names, start=1)
    )

    # In the f-string below, CSS/JSON braces are doubled to emit literal
    # braces, and the curl line continuations are written as "\\" because a
    # bare backslash before a newline would be swallowed by the string
    # literal. The contents of the <pre> blocks are kept flush-left because
    # whitespace is rendered verbatim there.
    return f"""
    <!DOCTYPE html>
    <html lang="en">
    <head>
    <meta charset="UTF-8">
    <title>GetAround API Documentation</title>
    <style>
    body {{
    font-family: Arial;
    background: #f5f7fa;
    padding: 40px;
    color: #333;
    }}
    h1 {{
    text-align: center;
    }}
    .box {{
    background: white;
    padding: 20px;
    border-radius: 10px;
    margin-bottom: 20px;
    box-shadow: 0 2px 10px rgba(0,0,0,0.05);
    }}
    table {{
    width: 100%;
    border-collapse: collapse;
    }}
    th, td {{
    padding: 8px;
    border-bottom: 1px solid #eee;
    text-align: left;
    }}
    code {{
    background: #eee;
    padding: 2px 6px;
    border-radius: 4px;
    }}
    pre {{
    background: #f8f9fa;
    padding: 12px;
    border-radius: 8px;
    overflow-x: auto;
    }}
    </style>
    </head>
    <body>
    <h1>🚗 GetAround Pricing API</h1>
    <div class="box">
    <h2>📌 Endpoint</h2>
    <p><code>POST /predict</code> → returns predicted rental price per day</p>
    </div>
    <div class="box">
    <h2>📊 Features</h2>
    <table>
    <tr><th>#</th><th>Feature</th></tr>
    {feature_rows}
    </table>
    </div>
    <div class="box">
    <h2>📥 Example request</h2>
    <pre>
curl -X POST "https://dreipfelt-getaround-api.hf.space/predict" \\
-H "Content-Type: application/json" \\
-d '{{"input": [["Citroën", 150000, 120, "diesel", "black", "sedan", 1, 1, 1, 0, 1, 1, 0]]}}'
</pre>
    </div>
    <div class="box">
    <h2>📤 Example response</h2>
    <pre>{{"prediction": [89.5]}}</pre>
    </div>
    </body>
    </html>
    """