import os
from argparse import ArgumentParser
from itertools import groupby

import cv2
import torch
import torch.nn as nn
from torchvision import transforms
class BidirectionalLSTM(nn.Module):
    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = self.rnn(input)  # [T, b, nHidden * 2]
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)
        output = self.embedding(t_rec)  # [T * b, nOut]
        output = output.view(T, b, -1)  # [T, b, nOut]
        return output
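
# Note: because forward returns [T, b, nOut], two of these layers stack
# cleanly inside nn.Sequential (as in CRNN.rnn below), with the second one
# projecting down to the number of character classes.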
class CRNN(nn.Module):
    def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False):
        super(CRNN, self).__init__()
        assert imgH % 16 == 0, "imgH has to be a multiple of 16"

        ks = [3, 3, 3, 3, 3, 3, 2]  # kernel sizes
        ps = [1, 1, 1, 1, 1, 1, 0]  # paddings
        ss = [1, 1, 1, 1, 1, 1, 1]  # strides
        nm = [64, 128, 256, 256, 512, 512, 512]  # output channels

        cnn = nn.Sequential()

        def convRelu(i, batchNormalization=False):
            nIn = nc if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module(
                "conv{0}".format(i), nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i])
            )
            if batchNormalization:
                cnn.add_module("batchnorm{0}".format(i), nn.BatchNorm2d(nOut))
            if leakyRelu:
                cnn.add_module("relu{0}".format(i), nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module("relu{0}".format(i), nn.ReLU(True))

        # Feature-map sizes below assume the 1x32x128 input used in this script.
        convRelu(0)
        cnn.add_module("pooling{0}".format(0), nn.MaxPool2d(2, 2))  # 64x16x64
        convRelu(1)
        cnn.add_module("pooling{0}".format(1), nn.MaxPool2d(2, 2))  # 128x8x32
        convRelu(2, True)
        convRelu(3)
        cnn.add_module(
            "pooling{0}".format(2), nn.MaxPool2d((2, 2), (2, 1), (0, 1))
        )  # 256x4x33
        convRelu(4, True)
        convRelu(5)
        cnn.add_module(
            "pooling{0}".format(3), nn.MaxPool2d((2, 2), (2, 1), (0, 1))
        )  # 512x2x34
        convRelu(6, True)  # 512x1x33
        self.cnn = cnn
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh), BidirectionalLSTM(nh, nh, nclass)
        )

    def forward(self, input):
        # conv features
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)  # [b, c, w]
        conv = conv.permute(2, 0, 1)  # [w, b, c]
        # rnn features
        output = self.rnn(conv)  # [w, b, nclass]
        return output
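
# Illustrative shape check (added for exposition, not part of the pipeline):
# with the 1x32x128 inputs produced by `transform` below, the CNN collapses
# the height to 1 and leaves a width of 33, so the model emits 33 time steps.
def _example_output_shape():
    dummy = torch.zeros(1, 1, 32, 128)  # [batch, channels, H, W]
    with torch.no_grad():
        out = CRNN(32, 1, 37, 512)(dummy)
    return out.shape  # torch.Size([33, 1, 37]) -> [seq_len, batch, nclass]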
# Index 0 is the CTC "BLANK" symbol; the remaining order must match the
# label mapping used during training.
VOCAB = [
    "BLANK",
    "Z", "B", "4", "X", "R", "2", "U", "D", "G", "Q", "S", "A",
    "N", "K", "0", "C", "J", "P", "Y", "H", "7", "W", "V", "5",
    "F", "L", "8", "1", "I", "T", "M", "3", "O", "9", "E", "6",
]
def add_text(image, text, pos):
    # Draw `text` just above the (xmin, ymin, xmax, ymax) box on `image`.
    # Not called in this script's __main__ path.
    xmin, ymin, xmax, ymax = pos
    image = cv2.putText(
        image,
        text,
        (xmin, ymin - 15),
        cv2.FONT_HERSHEY_COMPLEX,
        0.85,
        (0, 0, 255),
        2,
        cv2.LINE_AA,
    )
    return image
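
# Usage sketch for add_text (hypothetical box coordinates and text; handy
# when pairing this recognizer with a text/plate detector):
#   annotated = add_text(img.copy(), "ABC123", (40, 60, 220, 110))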
def greedy_decode(preds):
    # CTC best-path decoding: collapse consecutive repeats with
    # itertools.groupby, drop the blanks, then join into a string.
    best_chars_collapsed = [k for k, _ in groupby(preds) if k != "BLANK"]
    res = "".join(best_chars_collapsed)
    return res
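
# Worked example of the collapse rule (illustrative):
#   greedy_decode(["H", "H", "BLANK", "E", "BLANK", "E"]) == "HEE"
# Repeats within a run merge into one character; a BLANK between two runs of
# the same character is what preserves genuine doubles.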
def read_image(file):
    img = cv2.imread(file)
    assert img is not None, f"could not read image {file}"
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR
    return img
def idx2char(preds):
    return [VOCAB[idx] for idx in preds]
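
# e.g. idx2char([20, 35, 0]) == ["H", "E", "BLANK"]  (indices into VOCAB above)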
def post_process(preds):
    # preds shape: (seq_len, num_class); take the argmax class per time step
    _, preds = torch.max(preds, dim=1)
    return idx2char(preds.tolist())
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Grayscale(),
        transforms.Resize((32, 128)),
        transforms.Normalize(0.5, 0.5),
    ]
)
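
# Note: ToTensor runs first so Grayscale/Resize/Normalize operate on tensors;
# Normalize(0.5, 0.5) maps pixel values from [0, 1] to [-1, 1]. A quick check
# (illustrative; "some.jpg" is a hypothetical path):
#   transform(read_image("some.jpg")).shape  ->  torch.Size([1, 32, 128])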
model = CRNN(32, 1, 37, 512)  # imgH=32, 1 channel, 37 classes (len(VOCAB)), nh=512
state = torch.load("./out/ocr_point08.pt", map_location="cpu")
model.load_state_dict(state["model"])
def recognize(image):
    model.eval()
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        preds = model(transform(image).unsqueeze(0))  # [seq_len, 1, nclass]
    text = post_process(preds[:, 0, :])  # argmax indices -> characters
    text = greedy_decode(text)  # CTC collapse -> final string
    return text
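
# Usage sketch (assumes the checkpoint above has been loaded; "crop.jpg" is a
# hypothetical cropped text image):
#   img = read_image("crop.jpg")
#   print(recognize(img))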
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument(
        "--image",
        required=True,
        type=str,
        help="path to image on which prediction will be made",
    )
    args = parser.parse_args()
    assert os.path.exists(args.image), f"given path {args.image} does not exist"
    im = read_image(args.image)
    text = recognize(im)
    print(text)