import os
import time
from pathlib import Path

import numpy as np
from PIL import Image

# Silence TensorFlow's C++ logging. This must be set *before* tensorflow is
# imported, which is why the remaining imports follow this line.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

from tensorflow import keras  # noqa: E402
from flask import Flask, jsonify, request, render_template  # noqa: E402

# How the model is loaded. Supported values:
#   'local'                      - load f'{MODEL_DIR}/mnist_model.h5' from disk
#   'remote_hub_download'        - download a single file via hf_hub_download
#   'remote_hub_from_pretrained' - huggingface_hub.from_pretrained_keras
#   'remote_hub_pipeline'        - transformers.pipeline; needs a config.json, and
#                                  it is not easy to see how to create one for a
#                                  custom model:
#                                  https://discuss.huggingface.co/t/how-to-create-a-config-json-after-saving-a-model/10459/4
load_type = 'remote_hub_from_pretrained'

REPO_ID = "1vash/mnist_demo_model"
MODEL_DIR = "./artifacts/models"

# Load the saved model into memory once, at import time.
if load_type == 'local':
    model = keras.models.load_model(f'{MODEL_DIR}/mnist_model.h5')
elif load_type == 'remote_hub_download':
    from huggingface_hub import hf_hub_download
    # NOTE(review): keras.models.load_model expects a SavedModel *directory*
    # (or an .h5/.keras file); passing the bare saved_model.pb path is likely
    # to fail. Consider huggingface_hub.snapshot_download + load_model(dir).
    model = keras.models.load_model(hf_hub_download(repo_id=REPO_ID, filename="saved_model.pb"))
elif load_type == 'remote_hub_from_pretrained':
    # https://huggingface.co/docs/hub/keras
    os.environ['TRANSFORMERS_CACHE'] = str(Path(MODEL_DIR).absolute())
    from huggingface_hub import from_pretrained_keras
    model = from_pretrained_keras(REPO_ID, cache_dir=MODEL_DIR)
elif load_type == 'remote_hub_pipeline':
    from transformers import pipeline
    # NOTE(review): /predict calls model.predict(x, verbose=0), a Keras-style
    # signature; confirm a transformers pipeline accepts it before using this.
    model = pipeline("image-classification", model=REPO_ID)
else:
    raise ValueError(
        f"Unknown load_type {load_type!r}; expected one of 'local', "
        "'remote_hub_download', 'remote_hub_from_pretrained', 'remote_hub_pipeline'."
    )

# Initialize the Flask application
app = Flask(__name__)


# API route for prediction
@app.route('/predict', methods=['POST'])
def predict():
    """Predict the digit class of an uploaded 28x28 image.

    Request format (multipart/form-data):
        image: an image file of exactly 28x28 pixels (take one from the
        'demo images' folder). Non-grayscale images are converted to
        grayscale before inference.

    Response format (JSON):
        {
            "label": predicted_label,           # int
            "pred_proba": score_for_label,      # float, model output for "label"
            "ml-latency-ms": latency_in_ms      # preprocessing + predict only
        }

    Validation errors are returned as plain text with HTTP status 400.
    """
    file = request.files.get('image')
    # Browsers submit an empty file part (filename '') when nothing is chosen.
    if file is None or file.filename == '':
        return 'No file selected', 400

    try:
        image_data = Image.open(file)
        # Image.open is lazy; force decoding so corrupt files fail here
        # instead of deeper in preprocessing.
        image_data.load()
    except OSError:  # also covers PIL.UnidentifiedImageError
        return 'Invalid image file.', 400

    # PIL reports size as (width, height).
    if image_data.size != (28, 28):
        return "Invalid image shape. Expected (28, 28), take from 'demo images' folder.", 400

    # Time only the ML operations (preprocessing + predict), as documented.
    # perf_counter is monotonic, unlike time.time().
    start_time = time.perf_counter()

    processed_image = preprocess_image(image_data)
    # verbose=0 disables the Keras progress bar in the logs.
    prediction = model.predict(processed_image, verbose=0)

    # The input is a batch of one, so use the first (only) row of scores.
    scores = prediction[0]
    predicted_label = int(np.argmax(scores))
    proba = float(scores[predicted_label])

    latency_ms = (time.perf_counter() - start_time) * 1000

    response = {
        'label': predicted_label,
        'pred_proba': proba,
        'ml-latency-ms': round(latency_ms, 4),
    }
    # A dict is not JSON text until serialized; flask.jsonify returns a
    # Response with a JSON body and Content-Type application/json.
    return jsonify(response)


# Helper function to preprocess the image
def preprocess_image(image_data):
    """Preprocess an image for model inference.

    :param image_data: 28x28 image, either a PIL image or an array-like
        holding 28*28 grayscale pixel values (assumed to be in [0, 255]).
    :return: float32 numpy array of shape (1, 28, 28), pixel values divided
        by 255.
    """
    # Collapse colour/alpha/palette/bilevel modes to 8-bit grayscale; otherwise
    # e.g. an RGB image yields 28*28*3 values and the reshape below fails.
    if isinstance(image_data, Image.Image) and image_data.mode != 'L':
        image_data = image_data.convert('L')
    # Reshape (not resize): add a leading batch dimension of 1.
    image = np.asarray(image_data, dtype=np.float32).reshape(1, 28, 28)
    # Normalize the pixel values from [0, 255] to [0, 1].
    return image / 255.0


# API route for health check
@app.route('/health', methods=['GET'])
def health():
    """Health check: return "OK" while the application is running.

    Demo usage:
        curl http://localhost:5000/health
        curl http://127.0.0.1:5000/health
    """
    return 'OK'


# API route for version
@app.route('/version', methods=['GET'])
def version():
    """Return the application version string.

    Demo usage:
        curl http://localhost:5000/version
        curl http://127.0.0.1:5000/version
    """
    return '1.0'


@app.route("/")
def hello_world():
    """Serve the demo front-end page rendered from the index.html template."""
    return render_template("index.html")


# Start the Flask application
if __name__ == '__main__':
    # debug=True enables Flask's debug mode (interactive debugger and
    # reloader); do not expose this server to untrusted networks.
    app.run(debug=True)