import os
import sys
import time
import random
import string
import argparse

from tqdm import tqdm

import torch
import torch.backends.cudnn as cudnn
import torch.nn.init as init
import torch.optim as optim
import torch.utils.data
import numpy as np

from utils import CTCLabelConverter, CTCLabelConverterForBaiduWarpctc, AttnLabelConverter, Averager
from dataset import hierarchical_dataset, AlignCollate, Batch_Balanced_Dataset
from model import Model
from test import validation

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def train(opt):
    """ dataset preparation """
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    log = open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a')
    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.val_imgW, keep_ratio_with_pad=opt.PAD)
    valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.val_batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True)
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()

""" model configuration """ |
|
if 'CTC' in opt.Prediction: |
|
if opt.baiduCTC: |
|
converter = CTCLabelConverterForBaiduWarpctc(opt.character) |
|
else: |
|
converter = CTCLabelConverter(opt.character) |
|
else: |
|
converter = AttnLabelConverter(opt.character) |
|
opt.num_class = len(converter.character) |
|
|
|
if opt.rgb: |
|
opt.input_channel = 3 |
|
model = Model(opt) |
|
print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel, |
|
opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction, |
|
opt.SequenceModeling, opt.Prediction) |
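
    # weight initialization: zeros for biases, Kaiming-normal for weights.
    # kaiming_normal_ raises on tensors with fewer than 2 dims (e.g. BatchNorm
    # weights), which are caught below and filled with 1 instead.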
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # typically a 1-D weight such as BatchNorm's
            if 'weight' in name:
                param.data.fill_(1)
            continue
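
    # data parallel for multi-GPU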
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    print("Model:")
    print(model)

    """ setup loss """
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            # requires the separately installed warpctc-pytorch package
            from warpctc_pytorch import CTCLoss
            criterion = CTCLoss()
        else:
            criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0

    loss_avg = Averager()
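
    # collect only the parameters that require gradients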
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))
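
    # setup optimizer: Adadelta by default, Adam when --adam is given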
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    """ final options """
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    """ start training """
    start_iter = 0
    if opt.saved_model != '':
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            pass  # filename does not encode an iteration number; start from 0

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter
    bar = tqdm(total=opt.valInterval)
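
    # training loop: one batch per iteration until opt.num_iter is reached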
    while True:
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)
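
        # CTC: the model outputs (batch, seq_len, num_class); torch.nn.CTCLoss
        # expects log-probs shaped (seq_len, batch, num_class), hence the permute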
        if 'CTC' in opt.Prediction:
            preds = model(image, text)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            if opt.baiduCTC:
                preds = preds.permute(1, 0, 2)  # to use warpctc's CTCLoss format
                cost = criterion(preds, text, preds_size, length) / batch_size
            else:
                preds = preds.log_softmax(2).permute(1, 0, 2)
                cost = criterion(preds, text, preds_size, length)
        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] symbol
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (default)
        optimizer.step()

        loss_avg.add(cost)
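
        # validation part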
        bar.update(1)
        if (iteration + 1) % opt.valInterval == 0:
            bar.refresh()
            elapsed_time = time.time() - start_time

            with open(f'./saved_models/{opt.exp_name}/log_train.txt', 'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep the best accuracy / best norm_ED models (on the valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')
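
                # show some predicted results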
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    if 'Attn' in opt.Prediction:
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]  # prune after the end-of-sentence token ([s])

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')
                bar.reset()

        # save the model every 1e+5 iterations
        if (iteration + 1) % 100000 == 0:
            torch.save(
                model.state_dict(), f'./saved_models/{opt.exp_name}/iter_{iteration+1}.pth')

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
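

# Example invocation (paths and module choices are illustrative, not required values):
#   CUDA_VISIBLE_DEVICES=0 python train.py \
#       --train_data data_lmdb/training --valid_data data_lmdb/validation \
#       --select_data MJ-ST --batch_ratio 0.5-0.5 \
#       --Transformation TPS --FeatureExtraction ResNet --SequenceModeling BiLSTM --Prediction Attn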
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name', help='Where to store logs and models')
    parser.add_argument('--db_type', choices=['lmdb', 'xmlmdb', 'raw'], help='type of database')
    parser.add_argument('--train_data', required=True, help='path to training dataset')
    parser.add_argument('--valid_data', required=True, help='path to validation dataset')
    parser.add_argument('--manualSeed', type=int, default=1111, help='for random seed setting')
    parser.add_argument('--workers', type=int, help='number of data loading workers', default=4)
    parser.add_argument('--batch_size', type=int, default=192, help='input batch size')
    parser.add_argument('--val_batch_size', type=int, default=192, help='validation batch size')
    parser.add_argument('--num_iter', type=int, default=300000, help='number of iterations to train for')
    parser.add_argument('--valInterval', type=int, default=2000, help='Interval between each validation')
    parser.add_argument('--saved_model', default='', help="path to model to continue training")
    parser.add_argument('--FT', action='store_true', help='whether to do fine-tuning')
    parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is Adadelta)')
    parser.add_argument('--lr', type=float, default=1, help='learning rate, default=1.0 for Adadelta')
    parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for adam. default=0.9')
    parser.add_argument('--rho', type=float, default=0.95, help='decay rate rho for Adadelta. default=0.95')
    parser.add_argument('--eps', type=float, default=1e-8, help='eps for Adadelta. default=1e-8')
    parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping value. default=5')
    parser.add_argument('--baiduCTC', action='store_true', help='use Baidu warpctc instead of torch.nn.CTCLoss')
""" Data processing """ |
|
parser.add_argument('--select_data', type=str, default='MJ-ST', |
|
help='select training data (default is MJ-ST, which means MJ and ST used as training data)') |
|
parser.add_argument('--batch_ratio', type=str, default='0.5-0.5', |
|
help='assign ratio for each selected data in the batch') |
|
parser.add_argument('--total_data_usage_ratio', type=str, default='1.0', |
|
help='total data usage ratio, this ratio is multiplied to total number of data.') |
|
parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length') |
|
parser.add_argument('--imgH', type=int, default=32, help='the height of the input image') |
|
parser.add_argument('--imgW', type=int, default=100, help='the width of the input image') |
|
parser.add_argument('--val_imgW', type=int, default=100, help='the width of the input image') |
|
parser.add_argument('--rgb', action='store_true', help='use rgb input') |
|
parser.add_argument('--character', type=str, |
|
default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label') |
|
parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode') |
|
parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize') |
|
parser.add_argument('--data_filtering_off', action='store_true', help='for data_filtering_off mode') |
    """ Model Architecture """
    parser.add_argument('--Transformation', type=str, required=True, help='Transformation stage. None|TPS')
    parser.add_argument('--FeatureExtraction', type=str, required=True,
                        help='FeatureExtraction stage. VGG|RCNN|ResNet')
    parser.add_argument('--SequenceModeling', type=str, required=True, help='SequenceModeling stage. None|BiLSTM')
    parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. CTC|Attn')
    parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN')
    parser.add_argument('--input_channel', type=int, default=1,
                        help='the number of input channels of the feature extractor')
    parser.add_argument('--output_channel', type=int, default=512,
                        help='the number of output channels of the feature extractor')
    parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state')

    opt = parser.parse_args()

    if not opt.exp_name:
        opt.exp_name = f'{opt.Transformation}-{opt.FeatureExtraction}-{opt.SequenceModeling}-{opt.Prediction}'
        opt.exp_name += f'-Seed{opt.manualSeed}'

    os.makedirs(f'./saved_models/{opt.exp_name}', exist_ok=True)

    """ vocab / character number configuration """
    if opt.sensitive:
        opt.character = string.printable[:-6]  # same as the ASTER setting (94 chars).

    """ Seed and GPU setting """
    random.seed(opt.manualSeed)
    np.random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)
    torch.cuda.manual_seed(opt.manualSeed)

    cudnn.benchmark = True
    cudnn.deterministic = True
    opt.num_gpu = torch.cuda.device_count()
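
    # with DataParallel, scale the loader workers and batch sizes by the GPU count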
    if opt.num_gpu > 1:
        print('------ Use multi-GPU setting ------')
        print('if you are stuck for too long with the multi-GPU setting, try --workers 0')
        opt.workers = opt.workers * opt.num_gpu
        opt.batch_size = opt.batch_size * opt.num_gpu
        opt.val_batch_size = opt.val_batch_size * opt.num_gpu

    """ previous version
    print('To equalize batch stats to the 1-GPU setting, batch_size is multiplied by num_gpu. The multiplied batch_size is ', opt.batch_size)
    opt.batch_size = opt.batch_size * opt.num_gpu
    print('To equalize the number of epochs to the 1-GPU setting, num_iter is divided by num_gpu by default.')
    If you don't care about it, just comment out these lines.
    opt.num_iter = int(opt.num_iter / opt.num_gpu)
    """

    train(opt)