|
import os
import sys
import tempfile

import mlflow.sklearn
import requests
from flask import Flask, request, jsonify
from joblib import load
|
|
|
def download_model(model_url, local_path, timeout=30):
    """Download the file at *model_url* and save it to *local_path*.

    Args:
        model_url: URL of the serialized model file.
        local_path: Destination path; an existing file is overwritten.
        timeout: Seconds passed to ``requests.get`` as its timeout (applies to
            connecting and to waiting between received bytes). The default
            prevents the call from blocking forever on an unresponsive server.

    Raises:
        requests.RequestException: on connection failure, timeout, or a
            non-2xx HTTP status. Any error is printed first and then re-raised.
        OSError: if *local_path* cannot be written.
    """
    try:
        # stream=True writes the body to disk in chunks instead of holding the
        # whole model in memory as ``response.content``.
        with requests.get(model_url, stream=True, timeout=timeout) as response:
            # Without this check an HTTP error page (404/500 HTML) would be
            # saved as the "model" and only fail later inside joblib.load.
            response.raise_for_status()
            with open(local_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
    except Exception as e:
        print(f"Error downloading model: {e}")
        raise
|
|
|
def load_model_and_log(model_path):
    """Load a scikit-learn model with joblib and log it to MLflow as ``loaded_model``.

    If an MLflow run is already active, the model is logged into that run and
    the run is left open for the caller. Otherwise a new run is started for the
    logging and closed afterwards (MLflow marks it FAILED if logging raises).

    Args:
        model_path: Path to a joblib-serialized model file.

    Returns:
        ``(model, run_id)``: the object loaded from *model_path* and the ID of
        the MLflow run the model was logged under.

    Raises:
        Any exception from loading or logging is printed and then re-raised.
    """
    try:
        # SECURITY: joblib.load unpickles the file, which can execute arbitrary
        # code. Only pass files from a trusted source. In this script the file
        # is downloaded from a URL given on the command line.
        model = load(model_path)

        run = mlflow.active_run()
        if run is None:
            # Start the run explicitly instead of relying on log_model to
            # auto-start one that is never closed. The ActiveRun object still
            # exposes run.info after the with-block ends.
            with mlflow.start_run() as run:
                mlflow.sklearn.log_model(model, "loaded_model")
        else:
            mlflow.sklearn.log_model(model, "loaded_model")

        run_id = run.info.run_id
        print("Run ID:", run_id)
        return model, run_id
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
|
|
|
def load_model_from_mlflow(run_id):
    """Reload the scikit-learn model saved as the ``loaded_model`` artifact of an MLflow run.

    Args:
        run_id: ID of the MLflow run whose ``loaded_model`` artifact is fetched.

    Returns:
        The model object returned by ``mlflow.sklearn.load_model``.

    Raises:
        Any exception from MLflow is printed and then re-raised unchanged.
    """
    try:
        model_uri = f"runs:/{run_id}/loaded_model"
        return mlflow.sklearn.load_model(model_uri)
    except Exception as err:
        print(f"Error loading model from MLflow: {err}")
        raise
|
|
|
def predict(model):
    """Run *model* on the JSON body of the current Flask request.

    Must be called inside a Flask request context, because it reads
    ``flask.request``. The decoded JSON is passed unchanged to
    ``model.predict``, so it must be an input the model accepts. For a
    scikit-learn estimator this is presumably a list of feature rows (TODO
    confirm against the deployed model).

    Returns:
        On success, a JSON response ``{"prediction": [...]}``.
        ``(JSON {"error": ...}, 400)`` if the body is missing or is not valid JSON.
        ``(JSON {"error": ...}, 500)`` if the model raises during prediction.
    """
    # silent=True yields None for a missing or invalid JSON body instead of
    # raising, so a bad client request maps to 400 rather than a generic 500.
    input_data = request.get_json(silent=True)
    if input_data is None:
        return jsonify({'error': 'Request body must be valid JSON'}), 400

    try:
        prediction = model.predict(input_data)
        # model.predict typically returns a NumPy array, which jsonify cannot
        # serialize directly. Fall back to list() for plain sequences.
        if hasattr(prediction, 'tolist'):
            result = prediction.tolist()
        else:
            result = list(prediction)
        return jsonify({'prediction': result})
    except Exception as e:
        print(f"Error making prediction: {e}")
        return jsonify({'error': str(e)}), 500
|
|
|
app = Flask(__name__)

# The model URL is the first command-line argument. Exit with a usage message
# instead of an unexplained IndexError when it is missing.
if len(sys.argv) < 2:
    sys.exit("Usage: python <script> MODEL_URL")
model_url = sys.argv[1]

# Download into a freshly created private temp file (mode 0600, unpredictable
# name) rather than the fixed path /tmp/model.joblib. Another local user could
# pre-create or swap that fixed path, and the file is later unpickled by
# joblib.load.
_fd, local_model_path = tempfile.mkstemp(suffix=".joblib")
os.close(_fd)
try:
    download_model(model_url, local_model_path)
    # `model` is the object read straight from disk. The server uses the copy
    # reloaded from MLflow below.
    model, run_id = load_model_and_log(local_model_path)
finally:
    # The model is now in memory and logged to MLflow, so the scratch file is
    # no longer needed. Delete it even if download/load failed.
    os.remove(local_model_path)

loaded_model = load_model_from_mlflow(run_id)
|
|
|
@app.route('/predict', methods=['POST'])
def inference():
    """Handle ``POST /predict`` by delegating to ``predict`` with the model reloaded from MLflow."""
    return predict(model=loaded_model)
|
|
|
if __name__ == '__main__':
    # Start Flask's built-in development server on every network interface,
    # port 5000.
    app.run(port=5000, host='0.0.0.0')