Spaces:
Runtime error
Runtime error
File size: 4,726 Bytes
f7fb909 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 |
from argparse import ArgumentParser
from itertools import groupby
import os
import cv2
import torch
import torch.nn as nn
from torchvision import transforms
import utils_
class BidirectionalLSTM(nn.Module):
    """Bidirectional LSTM followed by a linear projection of every timestep.

    Args:
        nIn: feature size of each input timestep.
        nHidden: hidden size of the LSTM, per direction.
        nOut: feature size produced by the linear projection.
    """

    def __init__(self, nIn, nHidden, nOut):
        super().__init__()
        self.rnn = nn.LSTM(input_size=nIn, hidden_size=nHidden, bidirectional=True)
        # Both directions are concatenated on the feature axis, so the
        # projection consumes twice the hidden size.
        self.embedding = nn.Linear(in_features=2 * nHidden, out_features=nOut)

    def forward(self, input):
        """Run the LSTM over ``input`` and project each timestep.

        Args:
            input: 3-D tensor laid out as (T, b, nIn).

        Returns:
            Tensor of shape (T, b, nOut).
        """
        lstm_out, _state = self.rnn(input)
        seq_len, batch_size, feat_size = lstm_out.size()
        # Flatten time and batch so the linear layer sees a 2-D matrix,
        # then restore the (T, b, ...) layout afterwards.
        flattened = lstm_out.view(seq_len * batch_size, feat_size)
        projected = self.embedding(flattened)  # [T * b, nOut]
        return projected.view(seq_len, batch_size, -1)
class CRNN(nn.Module):
    """CNN feature extractor followed by two stacked BidirectionalLSTM layers.

    Args:
        imgH: input image height; must be a multiple of 16.
        nc: number of channels of the input image.
        nclass: number of output classes per timestep.
        nh: hidden size used by the recurrent layers.
        n_rnn: accepted for interface compatibility; not read by this class.
        leakyRelu: if true, use LeakyReLU(0.2) instead of ReLU after each
            convolution.
    """

    def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False):
        super().__init__()
        assert imgH % 16 == 0, "imgH has to be a multiple of 16"

        # One row per convolution: (out_channels, kernel, stride, padding).
        conv_specs = (
            (64, 3, 1, 1),
            (128, 3, 1, 1),
            (256, 3, 1, 1),
            (256, 3, 1, 1),
            (512, 3, 1, 1),
            (512, 3, 1, 1),
            (512, 2, 1, 0),
        )
        # Convolutions that are followed by BatchNorm2d.
        with_batchnorm = {2, 4, 6}
        # conv index -> (pooling index, MaxPool2d(kernel, stride[, padding])).
        # The last two pools use stride (2, 1): they halve the height while
        # stepping one column at a time along the width.
        pool_after = {
            0: (0, (2, 2)),
            1: (1, (2, 2)),
            3: (2, ((2, 2), (2, 1), (0, 1))),
            5: (3, ((2, 2), (2, 1), (0, 1))),
        }

        layers = nn.Sequential()
        in_channels = nc
        for idx, (out_channels, kernel, stride, padding) in enumerate(conv_specs):
            layers.add_module(
                f"conv{idx}",
                nn.Conv2d(in_channels, out_channels, kernel, stride, padding),
            )
            if idx in with_batchnorm:
                layers.add_module(f"batchnorm{idx}", nn.BatchNorm2d(out_channels))
            if leakyRelu:
                activation = nn.LeakyReLU(0.2, inplace=True)
            else:
                activation = nn.ReLU(True)
            layers.add_module(f"relu{idx}", activation)
            if idx in pool_after:
                pool_idx, pool_args = pool_after[idx]
                layers.add_module(f"pooling{pool_idx}", nn.MaxPool2d(*pool_args))
            in_channels = out_channels

        self.cnn = layers
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass),
        )

    def forward(self, input):
        """Extract convolutional features and run them through the RNN stack.

        Args:
            input: 4-D image batch accepted by the convolutional stack.

        Returns:
            Tensor laid out as (w, b, nclass), where w is the width of the
            final feature map.

        Raises:
            AssertionError: if the feature map height is not 1.
        """
        # conv features
        features = self.cnn(input)
        _, _, height, _ = features.size()
        assert height == 1, "the height of conv must be 1"
        # Drop the unit height axis, then move width first: [w, b, c]
        sequence = features.squeeze(2).permute(2, 0, 1)
        # rnn features
        return self.rnn(sequence)
# Output alphabet: a symbol's position in this list is its class index.
# Index 0 is the "BLANK" symbol that greedy_decode filters out; the
# remaining 36 entries are single characters (A-Z and 0-9) in the order
# given here, which must not be changed.
VOCAB = ["BLANK", *"ZB4XR2UDGQSANK0CJPYH7WV5FL81ITM3O9E6"]
def add_text(image, text, pos):
    """Draw ``text`` on ``image`` slightly above the top-left corner of ``pos``.

    Args:
        image: image passed to ``cv2.putText``.
        text: string to render.
        pos: 4-item box ``(xmin, ymin, xmax, ymax)``; only ``xmin`` and
            ``ymin`` are used.

    Returns:
        The value returned by ``cv2.putText``.
    """
    left, top, _right, _bottom = pos
    # Anchor the text 15 units above the top edge of the box.
    anchor = (left, top - 15)
    font_scale = 0.85
    color = (0, 0, 255)
    thickness = 2
    return cv2.putText(
        image,
        text,
        anchor,
        cv2.FONT_HERSHEY_COMPLEX,
        font_scale,
        color,
        thickness,
        cv2.LINE_AA,
    )
def greedy_decode(preds):
    """Collapse a per-timestep symbol sequence into the decoded string.

    Runs of consecutive identical symbols are merged into one occurrence,
    and the "BLANK" symbol is then dropped.

    Args:
        preds: iterable of symbol strings, one per timestep.

    Returns:
        The remaining symbols joined into a single string.
    """
    kept = []
    previous = object()  # sentinel that compares unequal to every symbol
    for symbol in preds:
        if symbol == previous:
            # Still inside the same run of repeated symbols.
            continue
        previous = symbol
        if symbol != "BLANK":
            kept.append(symbol)
    return "".join(kept)
def read_image(file):
    """Load the image at ``file`` and return it in RGB channel order.

    Args:
        file: path of the image to read.

    Returns:
        The image returned by ``cv2.imread``, converted from BGR to RGB.

    Raises:
        OSError: if OpenCV could not read or decode the file.
    """
    img = cv2.imread(file)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising;
        # fail here with the offending path rather than letting cvtColor
        # raise an opaque assertion error.
        raise OSError(f"could not read image: {file}")
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
def idx2char(preds):
    """Map each class index in ``preds`` to its VOCAB symbol, keeping order.

    Args:
        preds: iterable of integer indices into VOCAB.

    Returns:
        List of the corresponding VOCAB entries.
    """
    return list(map(VOCAB.__getitem__, preds))
def post_process(preds):
    """Pick the highest-scoring class at every timestep and map it to a symbol.

    Args:
        preds: tensor of shape (seq_len, num_class).

    Returns:
        List with one VOCAB symbol per timestep.
    """
    # torch.max over dim=1 yields (values, indices); only the indices matter.
    best = torch.max(preds, dim=1)
    return idx2char(best.indices.tolist())
# Preprocessing applied to an image before it is fed to the model:
# convert to a tensor, reduce to a single grayscale channel, resize to
# 32x128 (height x width), then normalise with mean 0.5 and std 0.5.
transform = transforms.Compose(
    transforms=[
        transforms.ToTensor(),
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize(size=(32, 128)),
        transforms.Normalize(mean=0.5, std=0.5),
    ]
)
# Build the recogniser (imgH=32, 1 input channel, 37 classes, hidden size
# 512) and restore its trained weights from the checkpoint's "model" entry.
model = CRNN(32, 1, 37, 512)
# map_location="cpu": the model is never moved off the CPU in this file, and
# without it a checkpoint saved from a GPU fails to load on a host without
# CUDA.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
state = torch.load("./out/ocr_point08.pt", map_location="cpu")
model.load_state_dict(state["model"])
def recognize(image):
    """Run the OCR model on a single image and return the decoded text.

    Args:
        image: image accepted by the module-level ``transform`` pipeline.

    Returns:
        The recognised string after greedy decoding.
    """
    model.eval()
    batch = transform(image).unsqueeze(0)  # add the batch dimension
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        preds = model(batch)
    # The model output is laid out as (seq_len, batch, num_class); take the
    # only batch item.
    symbols = post_process(preds[:, 0, :])
    return greedy_decode(symbols)
# Command-line entry point: recognise the text in one image and print it.
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument(
        "--image",
        # Required: with the former default of None, os.path.exists(None)
        # raised a TypeError instead of reporting a usage error.
        required=True,
        type=str,
        help="path to image on which prediction will be made",
    )
    args = parser.parse_args()
    # Validate explicitly; an assert would be stripped when run with -O.
    if not os.path.exists(args.image):
        parser.error(f"given path {args.image} does not exist")
    im = read_image(args.image)
    text = recognize(im)
    # The result was previously computed and then discarded.
    print(text)
|