abhishekrs4's picture
code formatting
44066b7
raw history blame
No virus
2.52 kB
import os
import cv2
import json
import torch
import logging
import numpy as np
from flask import Flask, jsonify, request
from flask.wrappers import Response
from iam_line_recognition.model_main import CRNN
from iam_line_recognition.utils import ctc_decode
from iam_line_recognition.dataset import HWRecogIAMDataset
# Flask application object; the request handlers in this module register on it.
app = Flask(__name__)
logging.basicConfig(level=logging.INFO)

# Checkpoint locations, tried in this order. The names suggest a local
# checkout path and a container path -- NOTE(review): confirm deployment layout.
file_model_local = "artifacts/crnn_H_32_W_768_E_196.pth"
file_model_cont = "/data/models/crnn_H_32_W_768_E_196.pth"

device = "cpu"
# One class more than the character table holds (presumably the CTC blank
# symbol -- TODO confirm against the training code).
num_classes = len(HWRecogIAMDataset.LABEL_2_CHAR) + 1
image_height = 32

# Per-channel normalization constants, shaped (1, 3) so they broadcast over
# the trailing channel axis of a channel-last image batch.
mean_arr = np.array([[0.485, 0.456, 0.406]])
std_arr = np.array([[0.229, 0.224, 0.225]])

hw_recog_model = CRNN(num_classes, image_height)
try:
    logging.info(f"loading model from {file_model_local}")
    hw_recog_model.load_state_dict(torch.load(file_model_local, map_location=device))
except Exception as exc:
    # Catch Exception rather than using a bare ``except`` so that Ctrl-C and
    # SystemExit still abort start-up, and record why the first path failed
    # before falling back to the second one.
    logging.warning("could not load model from %s (%s)", file_model_local, exc)
    logging.info(f"loading model from {file_model_cont}")
    hw_recog_model.load_state_dict(torch.load(file_model_cont, map_location=device))
hw_recog_model.to(device)
# Inference only: switch the module to evaluation mode.
hw_recog_model.eval()
def predict_hw(img_test: np.ndarray) -> str:
    """Run the recognition model on one image and return the decoded text.

    Args:
        img_test: A single channel-last image with 3 channels (the array is
            given a leading batch axis and then transposed with axes
            ``[0, 3, 1, 2]``). Values are divided by 255 before
            normalization, so a 0-255 scale is assumed.

    Returns:
        The characters looked up in ``HWRecogIAMDataset.LABEL_2_CHAR`` for the
        first (only) decoded sequence, joined into one string.
    """
    # Add the batch axis, scale by 1/255, then normalize per channel with the
    # module-level mean/std arrays.
    batch = np.expand_dims(img_test, 0).astype(np.float32) / 255.0
    batch = (batch - mean_arr) / std_arr
    # Channel-last -> channel-first layout for the model.
    batch = np.transpose(batch, axes=[0, 3, 1, 2])
    # torch.tensor copies the (non-contiguous) transposed array and casts the
    # float64 normalization result to float32 in a single step.
    img_tensor = torch.tensor(batch, dtype=torch.float, device=device)

    # Inference only: skip autograd bookkeeping instead of building a graph
    # that is thrown away immediately afterwards.
    with torch.no_grad():
        log_probs = hw_recog_model(img_tensor)

    pred_labels = ctc_decode(log_probs.detach())
    return "".join(HWRecogIAMDataset.LABEL_2_CHAR[i] for i in pred_labels[0])
@app.route("/predict", methods=["POST"])
def predict() -> Response:
    """Handle ``POST /predict``: recognize the text in an uploaded image.

    The image is read from the multipart form field ``image_file``, decoded,
    converted from BGR to RGB, resized to 768x32 and passed to ``predict_hw``.

    Returns:
        A JSON response with ``file_name`` and ``prediction`` on success, or a
        JSON ``{"error": ...}`` body with status 400 when the upload is empty
        or cannot be decoded as an image.
    """
    logging.info("IAM Handwriting recognition app")
    img_file = request.files["image_file"]

    try:
        img_bytes = img_file.read()
    except Exception:
        # Best-effort fallback kept from the original code for upload objects
        # that expose their content through ``getvalue()`` instead.
        img_bytes = img_file.getvalue()

    # np.frombuffer replaces np.fromstring, whose binary mode is deprecated.
    img_arr = np.frombuffer(img_bytes, dtype=np.uint8)

    # cv2.imdecode raises on an empty buffer and returns None for data it
    # cannot decode; report both cases to the client instead of failing later
    # inside cvtColor with an opaque server error.
    img_dec = cv2.imdecode(img_arr, cv2.IMREAD_COLOR) if img_arr.size else None
    if img_dec is None:
        logging.warning("could not decode uploaded file %s", img_file.filename)
        json_err = jsonify({"error": "uploaded file could not be decoded as an image"})
        json_err.status_code = 400
        return json_err

    img_dec = cv2.cvtColor(img_dec, cv2.COLOR_BGR2RGB)
    # cv2.resize takes (width, height).
    img_dec = cv2.resize(img_dec, (768, 32), interpolation=cv2.INTER_LINEAR)

    str_pred = predict_hw(img_dec)
    dict_pred = {
        "file_name": img_file.filename,
        "prediction": str_pred,
    }
    logging.info(dict_pred)

    try:
        json_pred = jsonify(dict_pred)
    except TypeError as e:
        # Serialization failed; return the reason rather than a bare 500.
        json_pred = jsonify({"error": str(e)})
    return json_pred
if __name__ == "__main__":
    # The server binds to all interfaces, so the Werkzeug interactive debugger
    # (which lets a client execute arbitrary code) must not be on by default.
    # Debug mode is opt-in: set FLASK_DEBUG=1 for local development.
    debug_enabled = os.environ.get("FLASK_DEBUG", "0").lower() in ("1", "true", "yes")
    app.run(host="0.0.0.0", debug=debug_enabled, port=7860)