import os
import time
import numpy as np
from PIL import Image
from pathlib import Path
# Disable tensorflow warnings
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from tensorflow import keras
from flask import Flask, jsonify, request, render_template
load_type = 'remote_hub_from_pretrained'
"""
Supported load types:
  local
  remote_hub_download
  remote_hub_from_pretrained
  remote_hub_pipeline - needs a config.json, which is not straightforward to produce for custom models:
    https://discuss.huggingface.co/t/how-to-create-a-config-json-after-saving-a-model/10459/4
"""
REPO_ID = "1vash/mnist_demo_model"
MODEL_DIR = "./artifacts/models"
# Load the saved model into memory
if load_type == 'local':
    model = keras.models.load_model(f'{MODEL_DIR}/mnist_model.h5')
elif load_type == 'remote_hub_download':
    from huggingface_hub import hf_hub_download
    model = keras.models.load_model(hf_hub_download(repo_id=REPO_ID, filename="saved_model.pb"))
elif load_type == 'remote_hub_from_pretrained':
    # https://huggingface.co/docs/hub/keras
    os.environ['TRANSFORMERS_CACHE'] = str(Path(MODEL_DIR).absolute())
    from huggingface_hub import from_pretrained_keras
    model = from_pretrained_keras(REPO_ID, cache_dir=MODEL_DIR)
elif load_type == 'remote_hub_pipeline':
    from transformers import pipeline
    model = pipeline("image-classification", model=REPO_ID)
else:
    raise ValueError(f'Unknown load_type: {load_type!r}')
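
# For reference, a minimal sketch of how a Keras model could have been published to the Hub so the
# 'remote_hub_from_pretrained' path above works (assumes huggingface_hub's push_to_hub_keras helper;
# not part of this app's runtime, shown only as an example):
#   from huggingface_hub import push_to_hub_keras
#   push_to_hub_keras(model, repo_id=REPO_ID)
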
# Initialize the Flask application
app = Flask(__name__)
# API route for prediction
@app.route('/predict', methods=['POST'])
def predict():
"""
Predicts the class label of an input image.
Request format:
{
"image": [[pixel_values_gray]]
}
Response format:
{
"label": predicted_label,
"pred_proba" prediction class probability
"ml-latency-ms": latency_in_milliseconds
(Measures time only for ML operations preprocessing with predict)
}
"""
if 'image' not in request.files:
# Handle if no file is selected
return 'No file selected'
    start_time = time.time()
    file = request.files['image']

    # Read pixel data from the uploaded file
    image_data = Image.open(file)

    # Check the image shape
    if image_data.size != (28, 28):
        return "Invalid image shape. Expected (28, 28); use an image from the 'demo images' folder.", 400

    # Preprocess the image
    processed_image = preprocess_image(image_data)

    # Make a prediction; verbose=0 disables the progress bar in logs
    prediction = model.predict(processed_image, verbose=0)

    # Get the predicted class label and its probability
    predicted_label = np.argmax(prediction)
    proba = prediction[0][predicted_label]

    # Calculate latency in milliseconds
    latency_ms = (time.time() - start_time) * 1000

    # Return the prediction result and latency as a dictionary response
    response = {
        'label': int(predicted_label),
        'pred_proba': float(proba),
        'ml-latency-ms': round(latency_ms, 4)
    }

    # A dict is not JSON: https://www.quora.com/What-is-the-difference-between-JSON-and-a-dictionary
    # flask.jsonify vs json.dumps: https://sentry.io/answers/difference-between-json-dumps-and-flask-jsonify/
    # flask.jsonify() returns a Response object with serialized JSON and content_type=application/json.
    return jsonify(response)
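
# Example client call for /predict, as a sketch: 'demo_images/digit_7.png' is a hypothetical path to a
# 28x28 grayscale image, and the response values below are illustrative, not actual model output.
#   curl -X POST -F "image=@demo_images/digit_7.png" http://127.0.0.1:5000/predict
#   -> {"label": 7, "ml-latency-ms": 12.3, "pred_proba": 0.99}
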
# Helper function to preprocess the image
def preprocess_image(image_data):
    """Preprocess an image for model inference.

    :param image_data: Raw PIL image
    :return: Preprocessed image array of shape (1, 28, 28)
    """
    # Ensure a single grayscale channel so the array is (28, 28)
    image_data = image_data.convert('L')

    # Reshape to the model's expected input: a batch of one 28x28 image
    image = np.array(image_data).reshape(1, 28, 28)

    # Normalize the pixel values to [0, 1]
    image = image.astype('float32') / 255.0
    return image
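
# Quick shape check for preprocess_image (illustrative doctest-style example, not executed by the app):
#   >>> preprocess_image(Image.new('L', (28, 28))).shape
#   (1, 28, 28)
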
# API route for health check
@app.route('/health', methods=['GET'])
def health():
    """
    Health check API to ensure the application is running.
    Returns "OK" if the application is healthy.
    Demo usage: "curl http://localhost:5000/health" or, equivalently, "curl http://127.0.0.1:5000/health"
    """
    return 'OK'
# API route for version
@app.route('/version', methods=['GET'])
def version():
    """
    Returns the version of the application.
    Demo usage: "curl http://localhost:5000/version" or, equivalently, "curl http://127.0.0.1:5000/version"
    """
    return '1.0'
@app.route("/")
def hello_world():
return render_template("index.html")
# return "<p>Hello, Team!</p>"
# Start the Flask application
if __name__ == '__main__':
    app.run(debug=True)
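
# For serving outside the Flask development server, one common option (an assumption, not part of this
# repo) is gunicorn pointing at this module's `app` object, assuming this file is named app.py:
#   gunicorn -w 2 -b 0.0.0.0:5000 app:app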