# ABINet-OCR / main.py
# tomofi's picture
# Add application file
# cb433d6
# raw
# history blame
# 10.7 kB
import argparse
import logging
import os
import random
import torch
from fastai.callbacks.general_sched import GeneralScheduler, TrainingPhase
from fastai.distributed import *
from fastai.vision import *
from torch.backends import cudnn
from callbacks import DumpPrediction, IterationCallback, TextAccuracy, TopKTextAccuracy
from dataset import ImageDataset, TextDataset
from losses import MultiLosses
from utils import Config, Logger, MyDataParallel, MyConcatDataset
def _set_random_seed(seed):
if seed is not None:
random.seed(seed)
torch.manual_seed(seed)
cudnn.deterministic = True
logging.warning('You have chosen to seed training. '
'This will slow down your training!')
def _get_training_phases(config, n):
    """Build the step-decay learning-rate schedule as fastai training phases.

    Phase ``i`` spans ``n * config.optimizer_scheduler_periods[i]`` steps
    (``n`` is the per-epoch batch count supplied by the caller) and uses the
    learning rate ``config.optimizer_lr * config.optimizer_scheduler_gamma ** i``.

    Args:
        config: experiment ``Config``.
        n: number of iterations per epoch.

    Returns:
        A list of ``TrainingPhase`` objects, one per scheduler period.
    """
    base_lr = np.array(config.optimizer_lr)
    durations = config.optimizer_scheduler_periods
    return [
        TrainingPhase(n * duration).schedule_hp(
            'lr', base_lr * config.optimizer_scheduler_gamma ** step)
        for step, duration in enumerate(durations)
    ]
def _get_dataset(ds_type, paths, is_training, config, **kwargs):
kwargs.update({
'img_h': config.dataset_image_height,
'img_w': config.dataset_image_width,
'max_length': config.dataset_max_length,
'case_sensitive': config.dataset_case_sensitive,
'charset_path': config.dataset_charset_path,
'data_aug': config.dataset_data_aug,
'deteriorate_ratio': config.dataset_deteriorate_ratio,
'is_training': is_training,
'multiscales': config.dataset_multiscales,
'one_hot_y': config.dataset_one_hot_y,
})
datasets = [ds_type(p, **kwargs) for p in paths]
if len(datasets) > 1: return MyConcatDataset(datasets)
else: return datasets[0]
def _get_language_databaunch(config):
    """Build the text-only DataBunch used for language-model pretraining.

    Only the first root of ``config.dataset_train_roots`` and of
    ``config.dataset_test_roots`` is used; any further roots are ignored.

    Args:
        config: experiment ``Config`` supplying the ``dataset_*`` settings.

    Returns:
        A fastai ``DataBunch`` over a training and a validation ``TextDataset``.
    """
    text_options = dict(
        max_length=config.dataset_max_length,
        case_sensitive=config.dataset_case_sensitive,
        charset_path=config.dataset_charset_path,
        smooth_label=config.dataset_smooth_label,
        smooth_factor=config.dataset_smooth_factor,
        one_hot_y=config.dataset_one_hot_y,
        use_sm=config.dataset_use_sm,
    )
    train_root = config.dataset_train_roots[0]
    train_ds = TextDataset(train_root, is_training=True, **text_options)
    valid_root = config.dataset_test_roots[0]
    valid_ds = TextDataset(valid_root, is_training=False, **text_options)

    loader_options = dict(
        bs=config.dataset_train_batch_size,
        val_bs=config.dataset_test_batch_size,
        num_workers=config.dataset_num_workers,
        pin_memory=config.dataset_pin_memory,
    )
    data = DataBunch.create(path=train_ds.path, train_ds=train_ds,
                            valid_ds=valid_ds, **loader_options)

    logging.info(f'{len(data.train_ds)} training items found.')
    if not data.empty_val:
        logging.info(f'{len(data.valid_ds)} valid items found.')
    return data
def _get_databaunch(config):
    """Build the image DataBunch used by the vision / full-model stages.

    Training and validation ``ImageDataset``s come from
    ``config.dataset_train_roots`` and ``config.dataset_test_roots``. The
    bunch is normalized with ``imagenet_stats`` and given a transform that
    maps each ``x`` to ``((x[0], x[1]), x[1])``.

    Note: in the ``'test'`` phase this mutates ``config`` so that
    ``dataset_train_roots`` is replaced by ``dataset_test_roots``.

    Args:
        config: experiment ``Config``.

    Returns:
        The normalized ``ImageDataBunch``.
    """
    if config.global_phase == 'test':
        # An awkward way to reduce loading data time during test: the training
        # split is simply built from the test roots.
        config.dataset_train_roots = config.dataset_test_roots

    train_ds = _get_dataset(ImageDataset, config.dataset_train_roots, True, config)
    valid_ds = _get_dataset(ImageDataset, config.dataset_test_roots, False, config)
    bunch = ImageDataBunch.create(
        train_ds=train_ds,
        valid_ds=valid_ds,
        bs=config.dataset_train_batch_size,
        val_bs=config.dataset_test_batch_size,
        num_workers=config.dataset_num_workers,
        pin_memory=config.dataset_pin_memory,
    )
    data = bunch.normalize(imagenet_stats)

    def _feed_label_as_input(batch):
        # auto-regression only for dtd: the label is passed in alongside the input
        return (batch[0], batch[1]), batch[1]

    data.add_tfm(_feed_label_as_input)

    logging.info(f'{len(data.train_ds)} training items found.')
    if not data.empty_val:
        logging.info(f'{len(data.valid_ds)} valid items found.')
    return data
def _get_model(config):
import importlib
names = config.model_name.split('.')
module_name, class_name = '.'.join(names[:-1]), names[-1]
cls = getattr(importlib.import_module(module_name), class_name)
model = cls(config)
logging.info(model)
return model
def _get_learner(config, data, model, local_rank=None):
    """Assemble the fastai ``Learner`` used for training or evaluation.

    Builds the accuracy metric (``TopKTextAccuracy`` for the
    ``'pretrain-language'`` stage, ``TextAccuracy`` otherwise), an optimizer
    factory ``torch.optim.<config.optimizer_type>(**config.optimizer_args)``
    and the ``MultiLosses`` loss. It then:

    * in the ``'train'`` phase, registers the LR scheduler, gradient clipping
      and ``IterationCallback`` as callback factories;
    * otherwise, attaches a ``DumpPrediction`` callback instance;
    * wraps the model for distributed training (``local_rank`` given) or, with
      several visible GPUs and no ``local_rank``, in ``MyDataParallel``;
    * loads weights from ``config.model_checkpoint`` if set, or, in the
      ``'test'`` phase only, from ``best-<config.global_name>``.

    Args:
        config: experiment ``Config``.
        data: DataBunch built by ``_get_databaunch`` / ``_get_language_databaunch``.
        model: the network to wrap.
        local_rank: device index of this process for distributed training,
            or ``None`` for single-process runs.

    Returns:
        The configured learner. In distributed mode this is the object
        returned by ``learner.to_distributed``.
    """
    # A missing (None) config value means strict state-dict loading.
    strict = ifnone(config.model_strict, True)
    if config.global_stage == 'pretrain-language':
        metrics = [TopKTextAccuracy(
            k=ifnone(config.model_k, 5),
            charset_path=config.dataset_charset_path,
            # +1 presumably leaves room for an end-of-sequence token — TODO confirm
            max_length=config.dataset_max_length + 1,
            # NOTE(review): 'sensisitves' is misspelled, but it must match the
            # key in the config files, so do not rename it here alone.
            case_sensitive=config.dataset_eval_case_sensisitves,
            model_eval=config.model_eval)]
    else:
        metrics = [TextAccuracy(
            charset_path=config.dataset_charset_path,
            max_length=config.dataset_max_length + 1,
            case_sensitive=config.dataset_eval_case_sensisitves,
            model_eval=config.model_eval)]
    opt_type = getattr(torch.optim, config.optimizer_type)
    learner = Learner(data, model, silent=True, model_dir='.',
                      true_wd=config.optimizer_true_wd,
                      wd=config.optimizer_wd,
                      bn_wd=config.optimizer_bn_wd,
                      path=config.global_workdir,
                      metrics=metrics,
                      opt_func=partial(opt_type, **config.optimizer_args or dict()),
                      loss_func=MultiLosses(one_hot=config.dataset_one_hot_y))
    # One parameter group per top-level child module of the model.
    learner.split(lambda m: children(m))
    if config.global_phase == 'train':
        # The per-epoch step count is divided by the world size, presumably
        # because each replica sees only its shard of the training loader.
        num_replicas = 1 if local_rank is None else torch.distributed.get_world_size()
        phases = _get_training_phases(config, len(learner.data.train_dl)//num_replicas)
        # callback_fns holds factories (partials) rather than instances.
        learner.callback_fns += [
            partial(GeneralScheduler, phases=phases),
            partial(GradientClipping, clip=config.optimizer_clip_grad),
            partial(IterationCallback, name=config.global_name,
                    show_iters=config.training_show_iters,
                    eval_iters=config.training_eval_iters,
                    save_iters=config.training_save_iters,
                    start_iters=config.training_start_iters,
                    stats_iters=config.training_stats_iters)]
    else:
        # Evaluation: attach an already-constructed prediction-dumping callback,
        # tagged with the names of the test roots joined by '-'.
        learner.callbacks += [
            DumpPrediction(learn=learner,
                           dataset='-'.join([Path(p).name for p in config.dataset_test_roots]),
                           charset_path=config.dataset_charset_path,
                           model_eval=config.model_eval,
                           debug=config.global_debug,
                           image_only=config.global_image_only)]
    # Exposed on the learner, presumably for callbacks that must know the rank.
    learner.rank = local_rank
    if local_rank is not None:
        logging.info(f'Set model to distributed with rank {local_rank}.')
        # Convert BatchNorm layers to SyncBatchNorm before moving the model to
        # its device and wrapping it for distributed training.
        learner.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(learner.model)
        learner.model.to(local_rank)
        learner = learner.to_distributed(local_rank)

    # Single-process multi-GPU: fall back to data parallelism.
    if torch.cuda.device_count() > 1 and local_rank is None:
        logging.info(f'Use {torch.cuda.device_count()} GPUs.')
        learner.model = MyDataParallel(learner.model)

    if config.model_checkpoint:
        if Path(config.model_checkpoint).exists():
            # Read the checkpoint file into memory and hand load() a file-like
            # object. This presumably bypasses resolution of the name relative to
            # learner.path / model_dir. `io` is not imported explicitly above and is
            # presumably provided by a star import — TODO confirm.
            with open(config.model_checkpoint, 'rb') as f:
                buffer = io.BytesIO(f.read())
            learner.load(buffer, strict=strict)
        else:
            # NOTE(review): hard-coded, environment-specific fallback. It copies
            # /data/fangsc/model/<global_name> to /output/<global_name> when the
            # source exists, then loads the checkpoint by name. `distutils` was
            # removed from the standard library in Python 3.12.
            from distutils.dir_util import copy_tree
            src = Path('/data/fangsc/model')/config.global_name
            trg = Path('/output')/config.global_name
            if src.exists(): copy_tree(str(src), str(trg))
            learner.load(config.model_checkpoint, strict=strict)
        logging.info(f'Read model from {config.model_checkpoint}')
    elif config.global_phase == 'test':
        learner.load(f'best-{config.global_name}', strict=strict)
        logging.info(f'Read model from best-{config.global_name}')

    if learner.opt_func.func.__name__ == 'Adadelta':    # fastai bug, fix after 1.0.60
        # A zero-epoch fit presumably just creates learner.opt, so that its
        # momentum can then be forced to 0 — TODO confirm.
        learner.fit(epochs=0, lr=config.optimizer_lr)
        learner.opt.mom = 0.

    return learner
def main():
    """Command-line entry point: train or evaluate an ABINet model.

    Loads the config file given by ``--config`` and lets explicitly passed
    flags override its values. It then sets up logging, seeding and
    (with ``--local_rank`` / ``--local-rank``) NCCL process-group initialisation,
    and builds the data, model and learner. Finally it either trains
    (``global_phase == 'train'``) or runs one validation pass and logs the
    metrics.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True,
                        help='path to config file')
    parser.add_argument('--phase', type=str, default=None, choices=['train', 'test'])
    parser.add_argument('--name', type=str, default=None)
    parser.add_argument('--checkpoint', type=str, default=None)
    parser.add_argument('--test_root', type=str, default=None)
    # torch.distributed.launch in PyTorch >= 2.0 passes the hyphenated
    # `--local-rank`; accept both spellings (both are stored in args.local_rank).
    parser.add_argument('--local_rank', '--local-rank', type=int, default=None)
    # Flags default to None so that "not given" keeps the config file's value.
    parser.add_argument('--debug', action='store_true', default=None)
    parser.add_argument('--image_only', action='store_true', default=None)
    # store_false: passing --model_strict turns strict checkpoint loading OFF.
    parser.add_argument('--model_strict', action='store_false', default=None)
    parser.add_argument('--model_eval', type=str, default=None,
                        choices=['alignment', 'vision', 'language'])
    args = parser.parse_args()

    config = Config(args.config)
    # Command-line overrides, applied in order. None means "not given on the CLI".
    overrides = (
        ('global_name', args.name),
        ('global_phase', args.phase),
        ('dataset_test_roots', None if args.test_root is None else [args.test_root]),
        ('model_checkpoint', args.checkpoint),
        ('global_debug', args.debug),
        ('global_image_only', args.image_only),
        ('model_eval', args.model_eval),
        ('model_strict', args.model_strict),
    )
    for attr, value in overrides:
        if value is not None:
            setattr(config, attr, value)

    Logger.init(config.global_workdir, config.global_name, config.global_phase)
    Logger.enable_file()
    _set_random_seed(config.global_seed)
    logging.info(config)

    if args.local_rank is not None:
        logging.info(f'Init distribution training at device {args.local_rank}.')
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    logging.info('Construct dataset.')
    if config.global_stage == 'pretrain-language':
        data = _get_language_databaunch(config)
    else:
        data = _get_databaunch(config)

    logging.info('Construct model.')
    model = _get_model(config)

    logging.info('Construct learner.')
    learner = _get_learner(config, data, model, args.local_rank)

    if config.global_phase == 'train':
        logging.info('Start training.')
        learner.fit(epochs=config.training_epochs,
                    lr=config.optimizer_lr)
    else:
        logging.info('Start validate')
        last_metrics = learner.validate()
        log_str = f'eval loss = {last_metrics[0]:6.3f}, ' \
                  f'ccr = {last_metrics[1]:6.3f}, cwr = {last_metrics[2]:6.3f}, ' \
                  f'ted = {last_metrics[3]:6.3f}, ned = {last_metrics[4]:6.0f}, ' \
                  f'ted/w = {last_metrics[5]:6.3f}, '
        logging.info(log_str)


if __name__ == '__main__':
    main()