Olivier-52 committed on
Commit
fe10113
·
1 Parent(s): 7d09baa

Fix app.py

Browse files
Files changed (2) hide show
  1. app.py +97 -73
  2. requirements.txt +3 -1
app.py CHANGED
@@ -1,90 +1,114 @@
1
  import os
2
- import uvicorn
3
- import pandas as pd
4
- from pydantic import BaseModel
5
- from fastapi import FastAPI, File, UploadFile
6
  import mlflow
7
- from xgboost import XGBClassifier
 
 
8
  from dotenv import load_dotenv
 
9
 
10
- description = """
11
-
12
- # Climate Fake News Detector(https://github.com/Olivier-52/Fake_news_detector.git)
13
-
14
- This API allows you to use a Machine Learning model to detect fake news related to climate change.
15
-
16
- ## Machine-Learning
17
-
18
- Where you can:
19
- * `/predict` : prediction for a single value
20
-
21
- Check out documentation for more information on each endpoint.
22
- """
23
-
24
- tags_metadata = [
25
- {
26
- "name": "Predictions",
27
- "description": "Endpoints that uses our Machine Learning model",
28
- },
29
- ]
30
-
31
  load_dotenv()
32
 
 
 
 
 
 
 
33
  os.environ["AWS_ACCESS_KEY_ID"] = os.getenv("AWS_ACCESS_KEY_ID")
34
  os.environ["AWS_SECRET_ACCESS_KEY"] = os.getenv("AWS_SECRET_ACCESS_KEY")
35
 
36
- MLFLOW_TRACKING_URI = os.environ["MLFLOW_TRACKING_URI"]
37
- mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
38
-
39
- mlflow.set_tracking_uri("https://olivier-52-ml-flow.hf.space")
40
-
41
- model = mlflow.sklearn.load_model("models:/climate-fake-news-detector-model-XGBoost-v1@production")
42
-
43
  app = FastAPI(
44
- title="API for Climate Fake News Detector",
45
- description=description,
46
- version="1.0",
47
- contact={
48
- "name": "Olivier",
49
- "url": "https://github.com/Olivier-52/Fake_news_detector",
50
- },
51
- openapi_tags=tags_metadata,)
52
-
53
- @app.get("/")
54
- def index():
55
- """Return a message to the user.
56
-
57
- This endpoint does not take any parameters and returns a message
58
- to the user. It is used to test the API.
59
 
60
- Returns:
61
- str: A message to the user.
62
- """
63
- return "Hello world! Go to /docs to try the API."
64
-
65
-
66
- class PredictionFeatures(BaseModel):
67
  text: str
68
 
69
- @app.post("/predict", tags=["Predictions"])
70
- def predict(features: PredictionFeatures):
71
- """Predict Climate Fake News.
72
-
73
- This endpoint takes a text as input and returns the predicted class : fake, real, or biased.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
- Args:
76
- features (PredictionFeatures): A PredictionFeatures object
77
- containing the text to analyze.
78
-
79
- Returns:
80
- dict: A dictionary containing the predicted class.
81
- """
82
- df = pd.DataFrame({
83
- "text": [features.text],
84
- })
 
 
 
 
 
 
 
 
 
 
85
 
86
- prediction = model.predict(df)[0]
87
- return {"prediction": float(prediction)}
 
 
 
88
 
89
  if __name__ == "__main__":
90
- uvicorn.run(app, host="localhost", port=8001)
 
 
1
  import os
 
 
 
 
2
  import mlflow
3
+ import pickle
4
+ from fastapi import FastAPI, HTTPException, status
5
+ from pydantic import BaseModel
6
  from dotenv import load_dotenv
7
+ from typing import Optional
8
 
9
+ # Charge les variables d'environnement
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  load_dotenv()
11
 
12
# Environment-driven configuration, with defaults matching the public deployment.
MLFLOW_TRACKING_APP_URI = os.getenv("MLFLOW_TRACKING_APP_URI", "https://olivier-52-ml-flow.hf.space")
MODEL_NAME = os.getenv("MODEL_NAME", "climate-fake-news-detector-model-XGBoost-v1")
STAGE = os.getenv("STAGE", "production")

# Propagate AWS credentials (used by MLflow's S3 artifact store) only when they
# are actually set: `os.environ[k] = None` raises TypeError at import time.
for _cred_name in ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"):
    _cred_value = os.getenv(_cred_name)
    if _cred_value is not None:
        os.environ[_cred_name] = _cred_value
20
 
21
# Create the FastAPI application with its OpenAPI metadata.
app = FastAPI(
    title="Climate Fake News Detector API",
    description="API pour détecter les fake news sur le climat avec un modèle XGBoost.",
    version="1.0.0"
)
 
 
 
 
 
 
 
 
 
 
 
27
 
28
+ # Modèle pour les données d'entrée
29
# Request body schema for /predict; pydantic validates that "text" is a string.
# (No docstring on purpose: pydantic would surface it as the OpenAPI description.)
class TextInput(BaseModel):
    text: str
31
 
32
# Module-level holders for the ML artifacts, populated once at startup by
# load_model() / load_vectorizer() below.
model = None
vectorizer = None
35
+
36
+ # Fonction pour charger le modèle depuis MLflow
37
def load_model():
    """Load the registered XGBoost model from MLflow into the global ``model``.

    Resolves the model through the ``models:/<name>@<alias>`` URI scheme,
    using the version aliased by ``STAGE``.

    Raises:
        HTTPException: 500 if the model cannot be loaded from MLflow.
    """
    global model
    try:
        # Point MLflow at the remote tracking server before resolving the URI.
        mlflow.set_tracking_uri(MLFLOW_TRACKING_APP_URI)

        model_uri = f"models:/{MODEL_NAME}@{STAGE}"
        model = mlflow.sklearn.load_model(model_uri)
        print("Modèle chargé avec succès depuis MLflow.")
    except Exception as e:
        print(f"Erreur lors du chargement du modèle depuis MLflow : {e}")
        # Chain the cause so the original traceback is not lost at startup.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Impossible de charger le modèle depuis MLflow : {e}"
        ) from e
53
+
54
+ # Fonction pour charger le vectorizer depuis MLflow
55
def load_vectorizer():
    """Download and unpickle the vectorizer artifact attached to the model's run.

    Looks up the registered model version aliased by ``STAGE`` to find the run
    that produced it, downloads that run's ``vectorizer.pkl`` artifact, and
    deserializes it.

    Returns:
        The deserialized vectorizer object.

    Raises:
        HTTPException: 500 if the artifact cannot be fetched or unpickled.
    """
    try:
        # Resolve the model version via its alias to find the producing run.
        client = mlflow.MlflowClient(MLFLOW_TRACKING_APP_URI)
        model_info = client.get_model_version_by_alias(MODEL_NAME, STAGE)
        run_id = model_info.run_id

        # Download vectorizer.pkl. Pass tracking_uri explicitly rather than
        # relying on a global tracking URI having been set by load_model().
        local_path = mlflow.artifacts.download_artifacts(
            artifact_path="vectorizer.pkl",
            run_id=run_id,
            tracking_uri=MLFLOW_TRACKING_APP_URI,
        )

        # SECURITY NOTE: pickle.load executes arbitrary code from the file;
        # only load artifacts from a trusted MLflow server / bucket.
        with open(local_path, "rb") as f:
            vectorizer = pickle.load(f)

        return vectorizer
    except Exception as e:
        print(f"Erreur lors du chargement du vectorizer : {e}")
        # Chain the cause so the startup traceback keeps the real failure.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Impossible de charger le vectorizer : {e}"
        ) from e
81
+
82
# Eagerly load both artifacts at import time so the app fails fast on startup
# when MLflow is unreachable, instead of failing on the first request.
load_model()
vectorizer = load_vectorizer()
84
 
85
@app.get("/")
async def read_root():
    """Landing endpoint: greets the caller and points to the interactive docs."""
    payload = {
        "message": "Bienvenue sur l'API Climate Fake News Detector !",
        "documentation": "Consultez la documentation de l'API à l'adresse /docs.",
    }
    return payload
91
+
92
@app.post("/predict")
async def predict(input_data: TextInput):
    """Classify a news text with the XGBoost model.

    The text is vectorized with the fitted vectorizer, then fed to the model;
    the predicted class index is returned as an integer.
    """
    global model, vectorizer

    # Guard clause: both artifacts must have been loaded at startup.
    if model is None or vectorizer is None:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Le modèle ou le vectorizer n'est pas chargé."
        )

    try:
        features = vectorizer.transform([input_data.text]).toarray()
        predicted_class = int(model.predict(features)[0])
        return {"prediction": predicted_class}
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Erreur lors de la prédiction : {e}"
        )
111
 
112
if __name__ == "__main__":
    # Local development entry point.
    # NOTE(review): binding to "localhost" makes the server unreachable from
    # outside a container — confirm whether "0.0.0.0" is needed for deployment.
    import uvicorn
    uvicorn.run(app, host="localhost", port=8000)
requirements.txt CHANGED
@@ -11,4 +11,6 @@ openpyxl
11
  boto3
12
  python-multipart
13
  dotenv
14
- xgboost
 
 
 
11
  boto3
12
  python-multipart
13
  python-dotenv
14
+ xgboost
15
# NOTE: "pickle" and "os" are Python standard-library modules; they must not
# be listed in requirements.txt — pip cannot (and need not) install them.