"""Flask inference service for the cassava leaf disease classifier.

On import this module builds a ``ResNetImageClassififer``, loads its weights
from a local artifact path (falling back to a container path), loads the
class-index -> label mapping from ``label_mapping.json`` and registers a
``POST /predict`` endpoint that classifies an uploaded image sent in the
``image_file`` form field.
"""

import os
import cv2
import json
import torch
import logging
import numpy as np
from flask import Flask, jsonify, request
from flask.wrappers import Response
from cassava_leaf_disease.models import ResNetImageClassififer

app = Flask(__name__)
logging.basicConfig(level=logging.INFO)

file_model_local = "artifacts/cassava_resnet_34.pt"
file_model_cont = "/data/models/cassava_resnet_34.pt"
device = "cpu"
num_classes = 5
# Uploaded images are resized to a square of this side length before inference.
image_height = 320
# Per-channel (R, G, B) mean/std used for input normalisation; shape (1, 3)
# broadcasts over the trailing channel axis of an (N, H, W, 3) batch.
mean_arr = np.array([[0.485, 0.456, 0.406]])
std_arr = np.array([[0.229, 0.224, 0.225]])

cassava_model = ResNetImageClassififer(num_classes, pretrained=False)
try:
    logging.info(f"loading model from {file_model_local}")
    cassava_model.load_state_dict(torch.load(file_model_local, map_location=device))
except (OSError, RuntimeError) as err:
    # Local artifact missing or unusable (e.g. when running inside the
    # container): fall back to the container path. The catch is narrow so
    # KeyboardInterrupt/SystemExit and programming errors are not swallowed,
    # and the reason for the fallback is logged instead of being hidden.
    logging.warning(f"could not load {file_model_local} ({err}); falling back")
    logging.info(f"loading model from {file_model_cont}")
    cassava_model.load_state_dict(torch.load(file_model_cont, map_location=device))
cassava_model.to(device)
cassava_model.eval()

file_json = "label_mapping.json"
# Context manager guarantees the handle is closed once the mapping is parsed.
with open(file_json, encoding="utf-8") as file_desc_json:
    label_mapping = json.load(file_desc_json)
logging.info(label_mapping)


def predict_cassava_disease(img_test: np.ndarray) -> str:
    """Classify a single RGB image and return its label string.

    Args:
        img_test: image of shape (H, W, 3) in RGB channel order; pixel values
            are assumed to lie in [0, 255] (the ``/predict`` route passes the
            uint8 output of ``cv2.imdecode`` resized to 320x320).

    Returns:
        The entry of ``label_mapping`` keyed by the string form of the
        predicted (argmax) class index.
    """
    # Add a batch axis and scale pixel values to [0, 1].
    img_batch = np.expand_dims(img_test, 0).astype(np.float32) / 255.0
    # Per-channel normalisation, then NHWC -> NCHW as expected by torch convs.
    img_batch = (img_batch - mean_arr) / std_arr
    img_batch = np.transpose(img_batch, axes=[0, 3, 1, 2])
    img_tensor = torch.tensor(img_batch, dtype=torch.float32, device=device)

    # Inference only: disable autograd so no computation graph is recorded.
    with torch.no_grad():
        pred_logits = cassava_model(img_tensor)
    # Batch size is 1, so argmax yields a single class index.
    pred_label = int(torch.argmax(pred_logits, dim=1).item())
    return label_mapping[str(pred_label)]


def _error_response(message: str, status: int) -> Response:
    """Build a JSON ``{"error": message}`` response with the given HTTP status."""
    resp = jsonify({"error": message})
    resp.status_code = status
    return resp


@app.route("/predict", methods=["POST"])
def predict() -> Response:
    """Run disease prediction on an uploaded image.

    Expects a multipart upload with the image in the ``image_file`` field.
    Returns JSON ``{"file_name": ..., "prediction": ...}`` on success, or
    ``{"error": ...}`` with HTTP 400 when the upload is empty or cannot be
    decoded as an image.
    """
    logging.info("Cassava leaf disease prediction app")
    img_file = request.files["image_file"]
    # np.frombuffer replaces np.fromstring, which is deprecated for binary data.
    img_buf = np.frombuffer(img_file.read(), dtype=np.uint8)
    if img_buf.size == 0:
        # cv2.imdecode asserts on an empty buffer; report a client error instead.
        return _error_response("uploaded file is empty", 400)
    img_dec = cv2.imdecode(img_buf, cv2.IMREAD_COLOR)
    if img_dec is None:
        # imdecode returns None for corrupt or unsupported image data.
        return _error_response("uploaded file is not a decodable image", 400)
    # OpenCV decodes to BGR; convert to RGB to match the normalisation order.
    img_dec = cv2.cvtColor(img_dec, cv2.COLOR_BGR2RGB)
    img_dec = cv2.resize(
        img_dec, (image_height, image_height), interpolation=cv2.INTER_LINEAR
    )
    str_pred = predict_cassava_disease(img_dec)
    dict_pred = {
        "file_name": img_file.filename,
        "prediction": str_pred,
    }
    logging.info(dict_pred)
    try:
        json_pred = jsonify(dict_pred)
    except TypeError as e:
        json_pred = jsonify({"error": str(e)})
    return json_pred


if __name__ == "__main__":
    # The Werkzeug debugger allows arbitrary code execution from the browser;
    # never enable it by default on a server bound to all interfaces.
    # Opt in explicitly with FLASK_DEBUG=1 for local development.
    app.run(
        host="0.0.0.0",
        debug=os.environ.get("FLASK_DEBUG", "0") == "1",
        port=7860,
    )