# Page banner captured along with the source (not code): "Spaces: Runtime error"
from argparse import ArgumentParser | |
from itertools import groupby | |
import os | |
import cv2 | |
import torch | |
import torch.nn as nn | |
from torchvision import transforms | |
import utils_ | |
class BidirectionalLSTM(nn.Module):
    """Bidirectional LSTM followed by a linear projection of every timestep.

    Args:
        nIn: Feature size of each input timestep.
        nHidden: Hidden size of each LSTM direction.
        nOut: Feature size of each projected output timestep.
    """

    def __init__(self, nIn, nHidden, nOut):
        super().__init__()
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        # Both directions are concatenated, so the projection reads 2 * nHidden.
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        """Run the LSTM over ``input`` and project each timestep.

        Args:
            input: 3-D tensor laid out as [T, b, nIn].

        Returns:
            Tensor of shape [T, b, nOut].
        """
        hidden_seq, _state = self.rnn(input)
        seq_len, batch, width = hidden_seq.size()
        # Fold time and batch together so the linear layer sees a 2-D matrix.
        flat = hidden_seq.view(seq_len * batch, width)
        projected = self.embedding(flat)  # [T * b, nOut]
        return projected.view(seq_len, batch, -1)
class CRNN(nn.Module):
    """CNN feature extractor feeding two stacked BidirectionalLSTM layers.

    Args:
        imgH: Input image height; must be a multiple of 16.
        nc: Number of channels of the input image.
        nclass: Number of output classes per timestep.
        nh: Hidden size used by the recurrent layers.
        n_rnn: Accepted for interface compatibility; not used in this class.
        leakyRelu: If true, use ``LeakyReLU(0.2)`` instead of ``ReLU`` after
            each convolution.
    """

    def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False):
        super().__init__()
        assert imgH % 16 == 0, "imgH has to be a multiple of 16"

        # Per-convolution hyper-parameters, indexed by convolution number.
        kernel_sizes = [3, 3, 3, 3, 3, 3, 2]
        paddings = [1, 1, 1, 1, 1, 1, 0]
        strides = [1, 1, 1, 1, 1, 1, 1]
        channels = [64, 128, 256, 256, 512, 512, 512]

        # Convolutions that are followed by batch normalization.
        batchnorm_after = (2, 4, 6)
        # Pooling inserted after a convolution: index -> (name, MaxPool2d args).
        # The (2, 2) pools halve height and width; the later pools halve the
        # height only (stride 1 and padding 1 along the width).
        pool_after = {
            0: ("pooling0", (2, 2)),
            1: ("pooling1", (2, 2)),
            3: ("pooling2", ((2, 2), (2, 1), (0, 1))),
            5: ("pooling3", ((2, 2), (2, 1), (0, 1))),
        }

        # Module names and registration order determine the state_dict keys,
        # so they follow the conv -> [batchnorm] -> relu -> [pooling] pattern.
        layers = nn.Sequential()
        for idx, out_ch in enumerate(channels):
            in_ch = nc if idx == 0 else channels[idx - 1]
            layers.add_module(
                f"conv{idx}",
                nn.Conv2d(in_ch, out_ch, kernel_sizes[idx], strides[idx], paddings[idx]),
            )
            if idx in batchnorm_after:
                layers.add_module(f"batchnorm{idx}", nn.BatchNorm2d(out_ch))
            if leakyRelu:
                activation = nn.LeakyReLU(0.2, inplace=True)
            else:
                activation = nn.ReLU(True)
            layers.add_module(f"relu{idx}", activation)
            if idx in pool_after:
                pool_name, pool_args = pool_after[idx]
                layers.add_module(pool_name, nn.MaxPool2d(*pool_args))

        self.cnn = layers
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh), BidirectionalLSTM(nh, nh, nclass)
        )

    def forward(self, input):
        """Extract convolutional features and run them through the LSTMs.

        Args:
            input: 4-D image batch fed to the convolutional stack.

        Returns:
            Tensor of shape [w, b, nclass], where ``w`` is the width of the
            convolutional feature map.
        """
        features = self.cnn(input)
        _batch, _channels, height, _width = features.size()
        assert height == 1, "the height of conv must be 1"
        # [b, c, 1, w] -> [b, c, w] -> [w, b, c]: width becomes the time axis.
        sequence = features.squeeze(2).permute(2, 0, 1)
        return self.rnn(sequence)
# Output alphabet of the recognizer: idx2char maps a class index to the entry
# at that position, and greedy_decode drops the "BLANK" marker stored at index 0.
VOCAB = ["BLANK", *"ZB4XR2UDGQSANK0CJPYH7WV5FL81ITM3O9E6"]
def add_text(image, text, pos):
    """Draw ``text`` on ``image`` 15 pixels above the top-left corner of a box.

    Args:
        image: Image handed to ``cv2.putText``.
        text: String to draw.
        pos: Four-element box ``(xmin, ymin, xmax, ymax)``; only ``xmin`` and
            ``ymin`` influence where the text is placed.

    Returns:
        The image returned by ``cv2.putText``.
    """
    left, top, _right, _bottom = pos
    anchor = (left, top - 15)
    font_scale = 0.85
    color = (0, 0, 255)
    thickness = 2
    return cv2.putText(
        image,
        text,
        anchor,
        cv2.FONT_HERSHEY_COMPLEX,
        font_scale,
        color,
        thickness,
        cv2.LINE_AA,
    )
def greedy_decode(preds):
    """Collapse a best-path symbol sequence into the final string.

    Consecutive repeats are merged into a single symbol (via
    ``itertools.groupby``) and every ``"BLANK"`` marker is then dropped.

    Args:
        preds: Iterable of string symbols, one per timestep.

    Returns:
        The remaining symbols concatenated into one string.
    """
    kept = []
    for symbol, _run in groupby(preds):
        if symbol == "BLANK":
            continue
        kept.append(symbol)
    return "".join(kept)
def read_image(file):
    """Load an image from disk and return it in RGB channel order.

    Args:
        file: Path of the image file to read.

    Returns:
        The array produced by ``cv2.imread`` after BGR -> RGB conversion.

    Raises:
        OSError: If OpenCV could not read or decode ``file``.
    """
    bgr = cv2.imread(file)
    # cv2.imread reports failure by returning None rather than raising; without
    # this check the problem surfaces as an opaque cv2.error inside cvtColor.
    if bgr is None:
        raise OSError(f"could not read image: {file!r}")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
def idx2char(preds):
    """Translate class indices into their ``VOCAB`` entries, preserving order.

    Args:
        preds: Iterable of integer indices into ``VOCAB``.

    Returns:
        List with the ``VOCAB`` entry for each index.
    """
    return list(map(VOCAB.__getitem__, preds))
def post_process(preds):
    """Pick the highest-scoring class of every row and map it to its symbol.

    Args:
        preds: 2-D score tensor; rows are timesteps and columns are classes,
            i.e. (seq_len, num_class).

    Returns:
        List of ``VOCAB`` entries, one per row of ``preds``.
    """
    best_indices = torch.max(preds, dim=1).indices
    return idx2char(best_indices.tolist())
# Preprocessing applied to an image before it reaches the model: tensor
# conversion, grayscale, resize to 32x128 (height x width), then normalization
# with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Grayscale(),
        transforms.Resize((32, 128)),
        transforms.Normalize(0.5, 0.5),
    ]
)

# imgH=32 matches the resize above, nc=1 is the channel count, and the network
# emits one class per VOCAB entry.
model = CRNN(32, 1, len(VOCAB), 512)
# map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
# host; the model is never moved off the CPU in this file.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted
# source.
state = torch.load("./out/ocr_point08.pt", map_location="cpu")
model.load_state_dict(state["model"])
def recognize(image):
    """Run the OCR model on a single image and return the decoded text.

    Args:
        image: Image accepted by the module-level ``transform`` pipeline.

    Returns:
        The string produced by greedy decoding of the model's per-timestep
        predictions.
    """
    model.eval()
    # Add the batch dimension the model expects.
    batch = transform(image).unsqueeze(0)
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        preds = model(batch)
    # The model output is [T, b, nclass]; keep the only batch element.
    symbols = post_process(preds[:, 0, :])
    return greedy_decode(symbols)
if __name__ == "__main__":
    # Command-line entry point: recognize the text in the image given by --image.
    parser = ArgumentParser()
    parser.add_argument(
        "--image",
        required=True,  # a missing path is reported by argparse, not as a TypeError
        type=str,
        help="path to image on which prediction will be made",
    )
    args = parser.parse_args()
    # Explicit check instead of assert so it still runs under `python -O`.
    if not os.path.exists(args.image):
        parser.error(f"given path {args.image} does not exists")
    im = read_image(args.image)
    # NOTE(review): the recognized text is not printed or returned in the
    # visible source; confirm whether output handling follows.
    text = recognize(im)