|
import os |
|
import cv2 |
|
import json |
|
import torch |
|
import logging |
|
|
|
import numpy as np |
|
from flask import Flask, jsonify, request |
|
from flask.wrappers import Response |
|
|
|
from iam_line_recognition.model_main import CRNN |
|
from iam_line_recognition.utils import ctc_decode |
|
from iam_line_recognition.dataset import HWRecogIAMDataset |
|
|
|
# Flask application object; the route handlers below are registered on it.
app = Flask("IAM_Handwriting_Recognition")
logging.basicConfig(level=logging.INFO)

# Checkpoint locations, tried in this order: a path relative to the working
# directory first, then an absolute path (presumably the in-container
# location -- confirm against the deployment setup).
file_model_local = "artifacts/crnn_H_32_W_768_E_196.pth"
file_model_cont = "/data/models/crnn_H_32_W_768_E_196.pth"

# Inference device; the weights are mapped to it while loading.
device = "cpu"

# One more class than the dataset's label table
# (presumably the CTC blank symbol -- TODO confirm against the CRNN/ctc_decode code).
num_classes = len(HWRecogIAMDataset.LABEL_2_CHAR) + 1
image_height = 32

# Per-channel normalisation constants, shape (1, 3) so they broadcast over the
# trailing channel axis of an image array.
mean_arr = np.array([[0.485, 0.456, 0.406]])
std_arr = np.array([[0.229, 0.224, 0.225]])

hw_recog_model = CRNN(num_classes, image_height)

try:
    logging.info(f"loading model from {file_model_local}")
    hw_recog_model.load_state_dict(
        torch.load(file_model_local, map_location=device)
    )
except Exception as load_err:
    # `Exception` (not a bare except) keeps the best-effort fallback while
    # letting KeyboardInterrupt / SystemExit propagate; the reason for the
    # fallback is logged instead of being silently discarded.
    logging.warning(
        "could not load model from %s (%s)", file_model_local, load_err
    )
    logging.info(f"loading model from {file_model_cont}")
    hw_recog_model.load_state_dict(
        torch.load(file_model_cont, map_location=device)
    )

hw_recog_model.to(device)
# Switch to evaluation mode before serving predictions.
hw_recog_model.eval()
|
|
|
|
|
def predict_hw(img_test: np.ndarray) -> str:
    """Run the recognition model on one image and return the decoded text.

    The image is given a leading batch axis, divided by 255, normalised with
    the module-level ``mean_arr`` / ``std_arr``, moved to channels-first
    layout and fed to ``hw_recog_model``. The model output is decoded with
    ``ctc_decode`` and the first decoded sequence is mapped to characters
    through ``HWRecogIAMDataset.LABEL_2_CHAR``.

    Args:
        img_test: A 3-D image array with the channel axis last. The caller in
            this file passes a 32x768 RGB image.

    Returns:
        The predicted string for the image.
    """
    batch = np.expand_dims(img_test, 0)
    batch = batch.astype(np.float32) / 255.0
    batch = (batch - mean_arr) / std_arr
    # (N, H, W, C) -> (N, C, H, W)
    batch = np.transpose(batch, axes=[0, 3, 1, 2])

    # A single conversion to a float32 tensor on the target device.
    img_tensor = torch.tensor(batch, dtype=torch.float, device=device)

    # Inference only: skip autograd bookkeeping instead of building a graph
    # and detaching the result afterwards.
    with torch.no_grad():
        log_probs = hw_recog_model(img_tensor)

    pred_labels = ctc_decode(log_probs)
    return "".join(HWRecogIAMDataset.LABEL_2_CHAR[i] for i in pred_labels[0])
|
|
|
@app.route("/predict", methods=["POST"])
def predict() -> Response:
    """Handle ``POST /predict``: recognise the text in an uploaded image.

    Reads the multipart upload named ``image_file``, decodes it, converts it
    from BGR to RGB, resizes it to 768x32 (width x height) and passes it to
    ``predict_hw``.

    Returns:
        A JSON response with ``file_name`` and ``prediction`` on success, or
        a JSON response with an ``error`` message and status 400 when the
        upload is empty or cannot be decoded as an image.

    Raises:
        KeyError: If the request carries no ``image_file`` part (raised by
            the ``request.files`` lookup).
    """
    logging.info("IAM Handwriting recognition app")
    img_file = request.files["image_file"]

    try:
        raw_bytes = img_file.read()
    except (AttributeError, OSError):
        # Fallback for file-like objects that expose getvalue() instead.
        raw_bytes = img_file.getvalue()

    # np.fromstring is deprecated for binary input; frombuffer is the
    # documented replacement.
    img_arr = np.frombuffer(raw_bytes, np.uint8)

    # imdecode rejects an empty buffer and returns None for bytes that are
    # not a valid image; report both as a client error instead of crashing.
    img_dec = cv2.imdecode(img_arr, cv2.IMREAD_COLOR) if img_arr.size else None
    if img_dec is None:
        json_err = jsonify({"error": "image_file could not be decoded as an image"})
        json_err.status_code = 400
        return json_err

    img_dec = cv2.cvtColor(img_dec, cv2.COLOR_BGR2RGB)
    # cv2.resize takes (width, height).
    img_dec = cv2.resize(img_dec, (768, 32), interpolation=cv2.INTER_LINEAR)

    str_pred = predict_hw(img_dec)

    dict_pred = {
        "file_name": img_file.filename,
        "prediction": str_pred,
    }
    try:
        json_pred = jsonify(dict_pred)
    except TypeError as e:
        dict_pred = {"error": str(e)}
        json_pred = jsonify(dict_pred)

    # Log the payload itself; the Response object's repr carries no content.
    logging.info("%s", dict_pred)
    return json_pred
|
|
|
if __name__ == "__main__":
    # The server listens on all interfaces, and Flask's debug mode enables an
    # interactive debugger that can execute arbitrary code. Debug mode is
    # therefore opt-in through the FLASK_DEBUG environment variable instead of
    # being hard-coded on.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").lower() in ("1", "true", "yes")
    app.run(host="0.0.0.0", debug=debug_enabled, port=7860)
|
|