2nzi commited on
Commit
15a293c
1 Parent(s): 9d069d4

update main

Browse files
Files changed (2) hide show
  1. best_model_XGBoost.pkl +2 -2
  2. main.py +43 -13
best_model_XGBoost.pkl CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:6528807e265ed511558c91021f0959beca72f818c8c0328034556a2e091938fb
3
- size 1174163
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7af777eeb0d05a44845a8dcecaebdfaaf3101c93b318a1ba5f723e569e8c8b5f
3
+ size 1174186
main.py CHANGED
@@ -1,9 +1,11 @@
1
  import uvicorn
2
- import pandas as pd
3
  from pydantic import BaseModel
4
  from typing import List, Union
5
  from fastapi import FastAPI
6
  import joblib
 
 
7
 
8
  description = """
9
  Welcome to the GetAround Car Value Prediction API. This app provides an endpoint to predict car values based on various features! Try it out 🕹️
@@ -30,42 +32,70 @@ app = FastAPI(
30
  version="0.1",
31
  contact={
32
  "name": "Antoine VERDON",
33
- "email": "antoineverdon.pro@gmail.com", # Replace with actual email
34
  },
35
  openapi_tags=tags_metadata
36
  )
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  class PredictionFeatures(BaseModel):
40
- CarData: List[Union[str, int, bool]] = ["Renault", 193231, 85, "diesel", "black", "estate", False, True, False, False, False, False, True]
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
- @app.get("/", tags=["Introduction Endpoints"])
43
  async def index():
44
  return (
45
- "Hello world! This `/` is the most simple and default endpoint. If you want to learn more, check out documentation of the API at https://2nzi-getaroundapi.hf.space/docs."
 
 
 
46
  )
47
- # return "Hello world! This `/` is the most simple and default endpoint. If you want to learn more, check out documentation of the API at `/docs https://2nzi-getaroundapi.hf.space/docs`"
48
 
49
  @app.post("/predict", tags=["Machine Learning"])
50
  async def predict(predictionFeatures: PredictionFeatures):
51
  columns = [
52
- 'model_key', 'mileage', 'engine_power', 'fuel', 'paint_color',
53
  'car_type', 'private_parking_available', 'has_gps',
54
  'has_air_conditioning', 'automatic_car', 'has_getaround_connect',
55
  'has_speed_regulator', 'winter_tires'
56
  ]
57
 
58
- car_data_dict = {col: [val] for col, val in zip(columns, predictionFeatures.CarData)}
59
  car_data = pd.DataFrame(car_data_dict)
60
 
61
- # model_file = hf_hub_download(repo_id="2nzi/GetAround-CarPrediction", filename="best_model_XGBoost.pkl")
62
- # with open(model_file, 'rb') as f:
63
- # model = pickle.load(f)
64
-
65
  model = joblib.load('best_model_XGBoost.pkl')
66
  prediction = model.predict(car_data)
67
  response = {"prediction": prediction.tolist()[0]}
68
  return response
69
 
70
- if __name__=="__main__":
71
  uvicorn.run(app, host="0.0.0.0", port=4000)
 
1
  import uvicorn
2
+ import pandas as pd
3
  from pydantic import BaseModel
4
  from typing import List, Union
5
  from fastapi import FastAPI
6
  import joblib
7
+ from enum import Enum
8
+ from fastapi.responses import HTMLResponse
9
 
10
  description = """
11
  Welcome to the GetAround Car Value Prediction API. This app provides an endpoint to predict car values based on various features! Try it out 🕹️
 
32
  version="0.1",
33
  contact={
34
  "name": "Antoine VERDON",
35
+ "email": "antoineverdon.pro@gmail.com",
36
  },
37
  openapi_tags=tags_metadata
38
  )
39
 
40
+ class CarBrand(str, Enum):
41
+ citroen = "Citroën"
42
+ peugeot = "Peugeot"
43
+ pgo = "PGO"
44
+ renault = "Renault"
45
+ audi = "Audi"
46
+ bmw = "BMW"
47
+ other = "other"
48
+ mercedes = "Mercedes"
49
+ opel = "Opel"
50
+ volkswagen = "Volkswagen"
51
+ ferrari = "Ferrari"
52
+ maserati = "Maserati"
53
+ mitsubishi = "Mitsubishi"
54
+ nissan = "Nissan"
55
+ seat = "SEAT"
56
+ subaru = "Subaru"
57
+ toyota = "Toyota"
58
 
59
  class PredictionFeatures(BaseModel):
60
+ brand: CarBrand
61
+ mileage: int
62
+ engine_power: int
63
+ fuel: str
64
+ paint_color: str
65
+ car_type: str
66
+ private_parking_available: bool
67
+ has_gps: bool
68
+ has_air_conditioning: bool
69
+ automatic_car: bool
70
+ has_getaround_connect: bool
71
+ has_speed_regulator: bool
72
+ winter_tires: bool
73
 
74
+ @app.get("/", response_class=HTMLResponse, tags=["Introduction Endpoints"])
75
  async def index():
76
  return (
77
+ "Hello world! This `/` is the most simple and default endpoint. "
78
+ "If you want to learn more, check out documentation of the API at "
79
+ "<a href='/docs'>/docs</a> or "
80
+ "<a href='https://2nzi-getaroundapi.hf.space/docs' target='_blank'>external docs</a>."
81
  )
 
82
 
83
  @app.post("/predict", tags=["Machine Learning"])
84
  async def predict(predictionFeatures: PredictionFeatures):
85
  columns = [
86
+ 'brand', 'mileage', 'engine_power', 'fuel', 'paint_color',
87
  'car_type', 'private_parking_available', 'has_gps',
88
  'has_air_conditioning', 'automatic_car', 'has_getaround_connect',
89
  'has_speed_regulator', 'winter_tires'
90
  ]
91
 
92
+ car_data_dict = {col: [getattr(predictionFeatures, col)] for col in columns}
93
  car_data = pd.DataFrame(car_data_dict)
94
 
 
 
 
 
95
  model = joblib.load('best_model_XGBoost.pkl')
96
  prediction = model.predict(car_data)
97
  response = {"prediction": prediction.tolist()[0]}
98
  return response
99
 
100
+ if __name__ == "__main__":
101
  uvicorn.run(app, host="0.0.0.0", port=4000)