"""Flask service that downloads a scikit-learn model, logs it to MLflow, and serves it.

Startup (at import time): read the model URL from ``sys.argv[1]``, download
the file to ``/tmp/model.joblib``, load it with joblib, log it to MLflow,
reload it from MLflow by run ID, and expose it at ``POST /predict``.
"""

import sys

import mlflow.sklearn
import requests
from flask import Flask, request, jsonify
from joblib import load


def download_model(model_url, local_path, timeout=60):
    """Download a model file over HTTP and save it to ``local_path``.

    Args:
        model_url: URL handed to ``requests.get``.
        local_path: Destination file path; opened in binary write mode.
        timeout: Timeout in seconds handed to ``requests.get`` so a stalled
            server cannot hang startup indefinitely.

    Raises:
        Exception: Any error from the request (including a non-2xx HTTP
            status) or from writing the file is printed and re-raised.
    """
    try:
        response = requests.get(model_url, timeout=timeout)
        # Check the status before touching the disk: without this, an HTTP
        # error page would be written out and later fed to joblib as a model.
        response.raise_for_status()
        with open(local_path, 'wb') as f:
            f.write(response.content)
    except Exception as e:
        print(f"Error downloading model: {e}")
        raise  # Re-raise the exception for proper handling


def load_model_and_log(model_path):
    """Load a model with joblib and log it to MLflow as ``loaded_model``.

    Args:
        model_path: Path of the joblib file to load.

    Returns:
        A ``(model, run_id)`` tuple: the loaded object and the ID of the
        MLflow run the model was logged under.

    Raises:
        Exception: Any error from loading or logging is printed and re-raised.
    """
    try:
        # SECURITY NOTE(review): joblib.load is pickle-based and can execute
        # arbitrary code from the file. Only point this at trusted model files.
        model = load(model_path)

        # Make sure a run exists before logging instead of relying on
        # log_model to start one implicitly; this guarantees the run object
        # dereferenced below is never None. The run is not ended here and
        # stays the active run for this process.
        active_run = mlflow.active_run() or mlflow.start_run()

        mlflow.sklearn.log_model(model, "loaded_model")

        run_id = active_run.info.run_id
        print("Run ID:", run_id)

        return model, run_id
    except Exception as e:
        print(f"Error loading model: {e}")
        raise  # Re-raise the exception for proper handling


def load_model_from_mlflow(run_id):
    """Load the ``loaded_model`` artifact of the given MLflow run.

    Args:
        run_id: ID of the MLflow run the model was logged under.

    Returns:
        The model object returned by ``mlflow.sklearn.load_model``.

    Raises:
        Exception: Any error from MLflow is printed and re-raised.
    """
    try:
        return mlflow.sklearn.load_model(f"runs:/{run_id}/loaded_model")
    except Exception as e:
        print(f"Error loading model from MLflow: {e}")
        raise  # Re-raise the exception for proper handling


def predict(model):
    """Run ``model.predict`` on the JSON body of the current Flask request.

    Must be called inside a Flask request context because it reads the
    global ``request`` object.

    Args:
        model: Object exposing ``predict``; the result must support
            ``.tolist()``.

    Returns:
        A JSON response ``{'prediction': [...]}`` on success,
        ``({'error': ...}, 400)`` when the body is missing or not valid JSON,
        or ``({'error': ...}, 500)`` when prediction itself fails.
    """
    try:
        # silent=True makes a missing/malformed body come back as None so it
        # can be reported as a client error rather than a server error.
        input_data = request.get_json(silent=True)
        if input_data is None:
            return jsonify({'error': 'Request body must be valid JSON'}), 400

        prediction = model.predict(input_data)
        return jsonify({'prediction': prediction.tolist()})
    except Exception as e:
        print(f"Error making prediction: {e}")
        return jsonify({'error': str(e)}), 500  # Internal Server Error


app = Flask(__name__)

# Get model URL from command line argument.
# NOTE: this runs at import time and raises IndexError when no argument is given.
model_url = sys.argv[1]

# Path to save the downloaded model locally
local_model_path = "/tmp/model.joblib"

# Download the model from the specified URL
download_model(model_url, local_model_path)

# Load the model on app startup and log it with MLflow
model, run_id = load_model_and_log(local_model_path)

# Load the model from MLflow based on the run ID
loaded_model = load_model_from_mlflow(run_id)


@app.route('/predict', methods=['POST'])
def inference():
    """Handle ``POST /predict`` by delegating to ``predict`` with the MLflow-loaded model."""
    return predict(loaded_model)


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)