# sklearntest / mlserve.py
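
# Overview: this script downloads a joblib-serialized scikit-learn model from
# a URL given on the command line, logs it to the MLflow tracking server,
# re-loads it from the resulting MLflow run, and serves predictions over a
# Flask /predict endpoint.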
import sys

import requests
import mlflow.sklearn
from flask import Flask, request, jsonify
from joblib import load

def download_model(model_url, local_path):
    """Downloads a model from the specified URL and saves it locally."""
    try:
        response = requests.get(model_url, timeout=60)
        response.raise_for_status()  # Fail fast on HTTP errors (4xx/5xx)
        with open(local_path, 'wb') as f:
            f.write(response.content)
    except Exception as e:
        print(f"Error downloading model: {e}")
        raise  # Re-raise the exception for proper handling
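
# Note: for large model files, passing stream=True to requests.get() and
# writing chunks via response.iter_content() would avoid buffering the whole
# payload in memory; the simple form above is kept for brevity.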

def load_model_and_log(model_path):
    """Loads a pre-trained scikit-learn model from the specified path and logs it with MLflow."""
    try:
        # Load the serialized model using joblib
        model = load(model_path)
        # Log the model inside an explicit MLflow run so the run ID is
        # reliably available; relying on an implicitly started run via
        # mlflow.active_run() is fragile.
        with mlflow.start_run() as active_run:
            mlflow.sklearn.log_model(model, "loaded_model")
            run_id = active_run.info.run_id
        print("Run ID:", run_id)
        return model, run_id
    except Exception as e:
        print(f"Error loading model: {e}")
        raise  # Re-raise the exception for proper handling

def load_model_from_mlflow(run_id):
    """Loads a model back from MLflow based on the provided run ID."""
    try:
        # Load the model from the run's "loaded_model" artifact path
        loaded_model = mlflow.sklearn.load_model(f"runs:/{run_id}/loaded_model")
        return loaded_model
    except Exception as e:
        print(f"Error loading model from MLflow: {e}")
        raise  # Re-raise the exception for proper handling
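
# Note: besides "runs:/<run_id>/<artifact_path>" URIs, mlflow.sklearn.load_model()
# also accepts Model Registry URIs such as "models:/<name>/<version>" once a
# model has been registered.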

def predict(model):
    """Makes predictions using the provided model."""
    try:
        # Expects a JSON array of feature rows, e.g. [[x1, x2, ...], ...]
        input_data = request.get_json()
        prediction = model.predict(input_data)
        return jsonify({'prediction': prediction.tolist()})
    except Exception as e:
        print(f"Error making prediction: {e}")
        return jsonify({'error': str(e)}), 500  # Internal Server Error
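
# Example request against the /predict route defined below (payload is
# illustrative; the feature count must match the model that was trained):
#   curl -X POST http://localhost:5000/predict \
#        -H "Content-Type: application/json" \
#        -d '[[5.1, 3.5, 1.4, 0.2]]'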

app = Flask(__name__)

# Get the model URL from the command-line argument
if len(sys.argv) < 2:
    sys.exit("Usage: python mlserve.py <model_url>")
model_url = sys.argv[1]

# Path where the downloaded model is saved locally
local_model_path = "/tmp/model.joblib"

# Download the model from the specified URL
download_model(model_url, local_model_path)

# Load the model on app startup and log it with MLflow
model, run_id = load_model_and_log(local_model_path)

# Load the model back from MLflow based on the run ID
loaded_model = load_model_from_mlflow(run_id)

@app.route('/predict', methods=['POST'])
def inference():
    return predict(loaded_model)


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)
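
# Example invocation (the URL is illustrative):
#   python mlserve.py https://example.com/models/model.joblib
# The server then listens on 0.0.0.0:5000.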