import string
import argparse

import torch
import torch.backends.cudnn as cudnn
import torch.utils.data
import torch.nn.functional as F

from utils import CTCLabelConverter, AttnLabelConverter
from dataset import RawDataset, AlignCollate
from model import Model

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def demo(opt):
    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)
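    # NOTE: num_class counts the converter's special tokens on top of
    # opt.character: the CTC blank, or the attention converter's [GO]/[s].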

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)
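    # DataParallel prefixes parameter names with 'module.', which is what
    # checkpoints saved from a DataParallel-wrapped model expect.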

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))
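    # map_location=device lets a GPU-trained checkpoint load on a CPU-only host.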

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_demo, pin_memory=True)
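    # AlignCollate (from dataset.py) resizes each image in the batch to
    # imgH x imgW; keep_ratio_with_pad preserves the aspect ratio and pads
    # instead of stretching.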

    # predict
    model.eval()
    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)
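            # text_for_pred is filled with 0 (the [GO] token index); at inference
            # the decoder generates step by step, so this mainly fixes the shape.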

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred)

                # Select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                # preds_index = preds_index.view(-1)
                preds_str = converter.decode(preds_index, preds_size)
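                # For reference: CTC greedy decoding first collapses repeated
                # indices, then drops blanks (index 0), e.g. with 7 -> 'a' and
                # 2 -> 'b', [0, 7, 7, 0, 7, 2] -> [7, 7, 2] -> "aab".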

            else:
                preds = model(image, text_for_pred, is_train=False)

                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            log = open('./log_demo_result.txt', 'a')
            dashed_line = '-' * 80
            head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score'

            print(f'{dashed_line}\n{head}\n{dashed_line}')
            log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')

            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)
            for img_name, pred, pred_max_prob in zip(image_path_list, preds_str, preds_max_prob):
                if 'Attn' in opt.Prediction:
                    pred_EOS = pred.find('[s]')
                    pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                    pred_max_prob = pred_max_prob[:pred_EOS]

                # calculate confidence score (= product of the per-step max probabilities);
                # guard against an empty prediction, for which cumprod()[-1] would raise
                if pred_max_prob.numel() > 0:
                    confidence_score = pred_max_prob.cumprod(dim=0)[-1]
                else:
                    confidence_score = torch.tensor(0.0)

                print(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}')
                log.write(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}\n')

            log.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_folder', required=True, help='path to image_folder which contains text images')
    parser.add_argument('--workers', type=int, help='number of data loading workers', default=4)
    parser.add_argument('--batch_size', type=int, default=192, help='input batch size')
    parser.add_argument('--saved_model', required=True, help='path to saved_model for evaluation')
    """ Data processing """
    parser.add_argument('--batch_max_length', type=int, default=25, help='maximum label length')
    parser.add_argument('--imgH', type=int, default=32, help='the height of the input image')
    parser.add_argument('--imgW', type=int, default=100, help='the width of the input image')
    parser.add_argument('--rgb', action='store_true', help='use rgb input')
    parser.add_argument('--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label')
    parser.add_argument('--sensitive', action='store_true', help='use case-sensitive character mode')
    parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize')
    """ Model Architecture """
    parser.add_argument('--Transformation', type=str, required=True, help='Transformation stage. None|TPS')
    parser.add_argument('--FeatureExtraction', type=str, required=True, help='FeatureExtraction stage. VGG|RCNN|ResNet')
    parser.add_argument('--SequenceModeling', type=str, required=True, help='SequenceModeling stage. None|BiLSTM')
    parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. CTC|Attn')
    parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN')
    parser.add_argument('--input_channel', type=int, default=1, help='the number of input channels of the feature extractor')
    parser.add_argument('--output_channel', type=int, default=512,
                        help='the number of output channels of the feature extractor')
    parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state')

    opt = parser.parse_args()
""" vocab / character number configuration """ | |
if opt.sensitive: | |
opt.character = string.printable[:-6] # same with ASTER setting (use 94 char). | |
cudnn.benchmark = True | |
cudnn.deterministic = True | |
opt.num_gpu = torch.cuda.device_count() | |
demo(opt) | |
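
# Example invocation (the checkpoint path and image folder are illustrative,
# not fixed names):
#
#   python demo.py \
#       --Transformation TPS --FeatureExtraction ResNet \
#       --SequenceModeling BiLSTM --Prediction Attn \
#       --image_folder demo_image/ \
#       --saved_model TPS-ResNet-BiLSTM-Attn.pth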