import os |
import sys |
import time |
import random |
import string |
import argparse |
from tqdm import tqdm |
import torch |
import torch.backends.cudnn as cudnn |
import torch.nn.init as init |
import torch.optim as optim |
import torch.utils.data |
import numpy as np |
from utils import CTCLabelConverter, CTCLabelConverterForBaiduWarpctc, AttnLabelConverter, Averager |
from dataset import hierarchical_dataset, AlignCollate, Batch_Balanced_Dataset |
from model import Model |
from test import validation |
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def train(opt):
    """Train a text-recognition model, validating and checkpointing periodically.

    Builds the training and validation data pipelines, the label converter,
    the model, the loss and the optimizer from ``opt``, then runs the
    optimisation loop until ``opt.num_iter`` iterations have been reached.
    Logs and checkpoints are written under ``./saved_models/{opt.exp_name}/``;
    that directory is opened for appending and is not created here.

    Args:
        opt: Namespace of options (as produced by the argument parser).  It is
            mutated in place: ``select_data`` and ``batch_ratio`` are split on
            ``'-'`` into lists, ``num_class`` is set from the converter, and
            ``input_channel`` is set to 3 when ``opt.rgb`` is true.

    Raises:
        SystemExit: Raised via ``sys.exit()`` once training has finished.
    """
    # --- dataset preparation ---
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    # The context manager guarantees the dataset log is closed even if
    # building the validation pipeline raises.
    with open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.val_imgW, keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset, batch_size=opt.val_batch_size,
            shuffle=True,  # validation batches are drawn in random order
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid, pin_memory=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # --- model configuration ---
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            converter = CTCLabelConverterForBaiduWarpctc(opt.character)
        else:
            converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    # Weight initialisation: zero biases, Kaiming-normal weights.
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:
            # Deliberate best-effort fallback: when the initialiser rejects a
            # parameter (presumably 1-D weights such as normalisation layers
            # -- TODO confirm), weights are filled with 1 instead.
            if 'weight' in name:
                param.data.fill_(1)

    # Wrap for (multi-)GPU data parallelism and move to the selected device.
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        # map_location lets a checkpoint saved on a GPU be loaded on a
        # CPU-only host instead of failing during deserialisation.
        state_dict = torch.load(opt.saved_model, map_location=device)
        # Fine-tuning (--FT) tolerates missing/unexpected keys.
        model.load_state_dict(state_dict, strict=not opt.FT)
    print("Model:")
    print(model)

    # --- setup loss ---
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            # Imported lazily so warpctc_pytorch is only required when requested.
            from warpctc_pytorch import CTCLoss
            criterion = CTCLoss()
        else:
            criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        # Target index 0 is ignored by the loss.
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)
    loss_avg = Averager()

    # Only parameters that require gradients are optimised.
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    params_num = [np.prod(p.size()) for p in filtered_parameters]
    print('Trainable params num : ', sum(params_num))

    # --- setup optimizer ---
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # --- final options: print them and append them to opt.txt ---
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        opt_log += ''.join(f'{str(k)}: {str(v)}\n' for k, v in vars(opt).items())
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # --- start training ---
    start_iter = 0
    if opt.saved_model != '':
        # Checkpoints named "..._<iteration>.<ext>" resume from that iteration;
        # any other file name starts counting from 0.
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
        except ValueError:
            pass
        else:
            print(f'continue to train, start_iter: {start_iter}')

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter
    save_interval = 100000  # unconditional checkpoint every 1e5 iterations
    bar = tqdm(total=opt.valInterval)

    # Bounded by "<" rather than an equality test inside the loop, so resuming
    # from a checkpoint at or beyond num_iter terminates instead of looping
    # forever.
    while iteration < opt.num_iter:
        # ---- train part ----
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            if opt.baiduCTC:
                preds = preds.permute(1, 0, 2)
                cost = criterion(preds, text, preds_size, length) / batch_size
            else:
                preds = preds.log_softmax(2).permute(1, 0, 2)
                cost = criterion(preds, text, preds_size, length)
        else:
            # Feed all but the last token; predict all but the first token
            # (presumably the start symbol -- TODO confirm against the converter).
            preds = model(image, text[:, :-1])
            target = text[:, 1:]
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
        optimizer.step()

        loss_avg.add(cost)
        bar.update(1)

        # ---- validation part ----
        if (iteration + 1) % opt.valInterval == 0:
            bar.refresh()
            elapsed_time = time.time() - start_time
            with open(f'./saved_models/{opt.exp_name}/log_train.txt', 'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # Training loss (averaged since the last validation) and validation loss.
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # Keep the best models seen so far on the validation set.
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # Show the first five validation predictions.
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    if 'Attn' in opt.Prediction:
                        # Cut both strings at the first '[s]' marker.
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')
            # Start a fresh progress bar for the next validation interval.
            bar.reset()

        # ---- periodic checkpoint ----
        if (iteration + 1) % save_interval == 0:
            torch.save(
                model.state_dict(), f'./saved_models/{opt.exp_name}/iter_{iteration+1}.pth')

        iteration += 1

    bar.close()
    print('end the training')
    sys.exit()
if __name__ == '__main__':
    # Script entry point: parse options, prepare the output directory, seed
    # the random generators, scale settings for multi-GPU runs, then train.
    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name', help='Where to store logs and models')
    parser.add_argument('--db_type', choices=['lmdb', 'xmlmdb', 'raw'], help='type of database')
    parser.add_argument('--train_data', required=True, help='path to training dataset')
    parser.add_argument('--valid_data', required=True, help='path to validation dataset')
    parser.add_argument('--manualSeed', type=int, default=1111, help='for random seed setting')
    parser.add_argument('--workers', type=int, help='number of data loading workers', default=4)
    parser.add_argument('--batch_size', type=int, default=192, help='input batch size')
    parser.add_argument('--val_batch_size', type=int, default=192, help='input batch size')
    parser.add_argument('--num_iter', type=int, default=300000, help='number of iterations to train for')
    parser.add_argument('--valInterval', type=int, default=2000, help='Interval between each validation')
    parser.add_argument('--saved_model', default='', help="path to model to continue training")
    parser.add_argument('--FT', action='store_true', help='whether to do fine-tuning')
    parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is Adadelta)')
    parser.add_argument('--lr', type=float, default=1, help='learning rate, default=1.0 for Adadelta')
    parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for adam. default=0.9')
    parser.add_argument('--rho', type=float, default=0.95, help='decay rate rho for Adadelta. default=0.95')
    parser.add_argument('--eps', type=float, default=1e-8, help='eps for Adadelta. default=1e-8')
    parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping value. default=5')
    # The previous help text was a copy of the --data_filtering_off help.
    parser.add_argument('--baiduCTC', action='store_true',
                        help='use the Baidu warp-ctc loss (warpctc_pytorch) instead of torch.nn.CTCLoss')

    # --- Data processing ---
    parser.add_argument('--select_data', type=str, default='MJ-ST',
                        help='select training data (default is MJ-ST, which means MJ and ST used as training data)')
    parser.add_argument('--batch_ratio', type=str, default='0.5-0.5',
                        help='assign ratio for each selected data in the batch')
    parser.add_argument('--total_data_usage_ratio', type=str, default='1.0',
                        help='total data usage ratio, this ratio is multiplied to total number of data.')
    parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length')
    parser.add_argument('--imgH', type=int, default=32, help='the height of the input image')
    parser.add_argument('--imgW', type=int, default=100, help='the width of the input image')
    parser.add_argument('--val_imgW', type=int, default=100, help='the width of the input image')
    parser.add_argument('--rgb', action='store_true', help='use rgb input')
    parser.add_argument('--character', type=str,
                        default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label')
    parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode')
    parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize')
    parser.add_argument('--data_filtering_off', action='store_true', help='for data_filtering_off mode')

    # --- Model architecture ---
    parser.add_argument('--Transformation', type=str, required=True, help='Transformation stage. None|TPS')
    parser.add_argument('--FeatureExtraction', type=str, required=True,
                        help='FeatureExtraction stage. VGG|RCNN|ResNet')
    parser.add_argument('--SequenceModeling', type=str, required=True, help='SequenceModeling stage. None|BiLSTM')
    parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. CTC|Attn')
    parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN')
    parser.add_argument('--input_channel', type=int, default=1,
                        help='the number of input channel of Feature extractor')
    parser.add_argument('--output_channel', type=int, default=512,
                        help='the number of output channel of Feature extractor')
    parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state')

    opt = parser.parse_args()

    # Derive an experiment name from the architecture and seed when none is given.
    if not opt.exp_name:
        opt.exp_name = f'{opt.Transformation}-{opt.FeatureExtraction}-{opt.SequenceModeling}-{opt.Prediction}'
        opt.exp_name += f'-Seed{opt.manualSeed}'

    os.makedirs(f'./saved_models/{opt.exp_name}', exist_ok=True)

    # --- vocab / character number configuration ---
    if opt.sensitive:
        # string.printable without its six trailing whitespace characters:
        # digits, ASCII letters and punctuation.
        opt.character = string.printable[:-6]

    # --- Seed and GPU setting ---
    random.seed(opt.manualSeed)
    np.random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)
    torch.cuda.manual_seed(opt.manualSeed)

    # NOTE(review): benchmark=True enables cuDNN auto-tuning, which may work
    # against deterministic=True -- confirm this combination is intended.
    cudnn.benchmark = True
    cudnn.deterministic = True
    opt.num_gpu = torch.cuda.device_count()

    if opt.num_gpu > 1:
        print('------ Use multi-GPU setting ------')
        print('if you stuck too long time with multi-GPU setting, try to set --workers 0')
        # Scale loader workers and both batch sizes by the number of GPUs.
        opt.workers = opt.workers * opt.num_gpu
        opt.batch_size = opt.batch_size * opt.num_gpu
        # Scale the validation batch size from its own value; deriving it
        # from the already-scaled training batch size discarded
        # --val_batch_size and multiplied by num_gpu twice.
        opt.val_batch_size = opt.val_batch_size * opt.num_gpu

        # An earlier version additionally divided num_iter by num_gpu to keep
        # the number of epochs equal to the 1-GPU setting; that is no longer
        # done.

    train(opt)