# Source: handwriting-recognition/src/image_processing.py
# Uploaded by pmkhanh7890 in commit 6610027 ("1st push").
import cv2
import numpy as np
from itertools import groupby
def process_image(image, recognition_input_layer):
    """Fit a single-line grayscale text image to the recognition model input.

    The image is scaled to the network input height with its aspect ratio
    preserved, then padded on the right (replicating the last column) up to
    the network input width. An image that is still wider than the network
    input after scaling is resized straight to the input size instead, which
    compresses it horizontally rather than failing inside ``np.pad`` with a
    negative pad width.

    IMPORTANT: the model reads only one line of text at a time.

    Args:
        image: 2-D ``(height, width)`` array. Colour images must be converted
            to grayscale by the caller before being passed in.
        recognition_input_layer: Object whose ``shape`` unpacks to
            ``(batch, channels, height, width)``.

    Returns:
        Array of shape ``(1, 1, H, W)`` ready to be fed to the network.

    Raises:
        ValueError: If ``image`` is not 2-D.
    """
    if len(image.shape) != 2:
        raise ValueError(
            "process_image expects a 2-D grayscale image, got shape "
            f"{tuple(image.shape)}"
        )
    image_height, _ = image.shape

    # B, C, H, W = batch size, number of channels, height, width.
    _, _, H, W = recognition_input_layer.shape

    if image_height == H:
        # Already at the target height: scaling by 1.0 would be a no-op.
        resized_image = image
    else:
        # Scale both axes by the same ratio so the aspect ratio is kept.
        scale_ratio = H / image_height
        resized_image = cv2.resize(
            image, None, fx=scale_ratio, fy=scale_ratio, interpolation=cv2.INTER_AREA
        )

    if resized_image.shape[1] > W:
        # Too wide to pad: resize the original directly to the network input
        # size (cv2 takes dsize as (width, height)). This distorts the aspect
        # ratio but keeps the whole line visible to the model.
        resized_image = cv2.resize(image, (W, H), interpolation=cv2.INTER_AREA)

    # Pad on the right to the network width by repeating the last column.
    missing_columns = W - resized_image.shape[1]
    resized_image = np.pad(
        resized_image, ((0, 0), (0, missing_columns)), mode="edge"
    )

    # Add the batch and channel axes expected by the network.
    return resized_image[None, None, :, :]
def recognize(
    image,
    compiled_model,
    recognition_input_layer,
    recognition_output_layer,
    letters,
):
    """Recognize one line of text in ``image`` with CTC greedy decoding.

    Args:
        image: Single-line text image; forwarded unchanged to
            ``process_image``.
        compiled_model: Callable invoked as ``compiled_model([input_image])``;
            its result is indexed with ``recognition_output_layer``.
        recognition_input_layer: Input layer description forwarded to
            ``process_image``.
        recognition_output_layer: Key selecting the prediction array from the
            model result.
        letters: Indexable collection mapping a class index to its symbol.
            Index 0 is treated as the CTC blank and never emitted.

    Returns:
        List of decoded symbols, in reading order (empty if nothing but
        blanks was predicted).
    """
    input_image = process_image(image, recognition_input_layer)

    # Run inference on the model.
    predictions = compiled_model([input_image])[recognition_output_layer]

    # Drop singleton (batch) axes. `atleast_2d` keeps a one-step sequence as
    # (1, classes) so the per-step argmax below still has an axis 1.
    scores = np.atleast_2d(np.squeeze(predictions))

    # Most probable class index at every time step.
    best_indexes = np.argmax(scores, axis=1)

    # CTC greedy decoding: collapse runs of repeated indexes, then drop the
    # blank symbol (index 0). Only the group keys are needed, so the grouper
    # iterators are discarded directly instead of via an object array.
    return [
        letters[index]
        for index, _ in groupby(best_indexes)
        if index != 0
    ]