import argparse
import logging
import os
import random

import torch
from fastai.callbacks.general_sched import GeneralScheduler, TrainingPhase
from fastai.distributed import *
from fastai.vision import *
from torch.backends import cudnn

from callbacks import DumpPrediction, IterationCallback, TextAccuracy, TopKTextAccuracy
from dataset import ImageDataset, TextDataset
from losses import MultiLosses
from utils import Config, Logger, MyDataParallel, MyConcatDataset

def _set_random_seed(seed):
    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)
        cudnn.deterministic = True
        logging.warning('You have chosen to seed training. '
                        'This will slow down your training!')
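# Note: this seeds Python's `random` and torch only; numpy's RNG is left
# untouched, so any numpy-based randomness (e.g. in data augmentation) is not
# covered by the seed.
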
def _get_training_phases(config, n):
    lr = np.array(config.optimizer_lr)
    periods = config.optimizer_scheduler_periods
    sigma = [config.optimizer_scheduler_gamma ** i for i in range(len(periods))]
    phases = [TrainingPhase(n * periods[i]).schedule_hp('lr', lr * sigma[i])
              for i in range(len(periods))]
    return phases
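# Illustrative example (hypothetical values, not taken from any shipped config):
# with optimizer_lr=1e-4, scheduler_periods=[6, 4], scheduler_gamma=0.1 and
# n batches per epoch, this builds two phases: 6*n iterations at lr=1e-4,
# then 4*n iterations at lr=1e-5, i.e. a stepwise decay by gamma per period.
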
def _get_dataset(ds_type, paths, is_training, config, **kwargs):
    kwargs.update({
        'img_h': config.dataset_image_height,
        'img_w': config.dataset_image_width,
        'max_length': config.dataset_max_length,
        'case_sensitive': config.dataset_case_sensitive,
        'charset_path': config.dataset_charset_path,
        'data_aug': config.dataset_data_aug,
        'deteriorate_ratio': config.dataset_deteriorate_ratio,
        'is_training': is_training,
        'multiscales': config.dataset_multiscales,
        'one_hot_y': config.dataset_one_hot_y,
    })
    datasets = [ds_type(p, **kwargs) for p in paths]
    if len(datasets) > 1: return MyConcatDataset(datasets)
    else: return datasets[0]

def _get_language_databaunch(config):
    kwargs = {
        'max_length': config.dataset_max_length,
        'case_sensitive': config.dataset_case_sensitive,
        'charset_path': config.dataset_charset_path,
        'smooth_label': config.dataset_smooth_label,
        'smooth_factor': config.dataset_smooth_factor,
        'one_hot_y': config.dataset_one_hot_y,
        'use_sm': config.dataset_use_sm,
    }
    train_ds = TextDataset(config.dataset_train_roots[0], is_training=True, **kwargs)
    valid_ds = TextDataset(config.dataset_test_roots[0], is_training=False, **kwargs)
    data = DataBunch.create(
        path=train_ds.path,
        train_ds=train_ds,
        valid_ds=valid_ds,
        bs=config.dataset_train_batch_size,
        val_bs=config.dataset_test_batch_size,
        num_workers=config.dataset_num_workers,
        pin_memory=config.dataset_pin_memory)
    logging.info(f'{len(data.train_ds)} training items found.')
    if not data.empty_val:
        logging.info(f'{len(data.valid_ds)} valid items found.')
    return data

def _get_databaunch(config):
    # An awkward way to reduce data-loading time during test: reuse the test roots for training data
    if config.global_phase == 'test': config.dataset_train_roots = config.dataset_test_roots
    train_ds = _get_dataset(ImageDataset, config.dataset_train_roots, True, config)
    valid_ds = _get_dataset(ImageDataset, config.dataset_test_roots, False, config)
    data = ImageDataBunch.create(
        train_ds=train_ds,
        valid_ds=valid_ds,
        bs=config.dataset_train_batch_size,
        val_bs=config.dataset_test_batch_size,
        num_workers=config.dataset_num_workers,
        pin_memory=config.dataset_pin_memory).normalize(imagenet_stats)
    ar_tfm = lambda x: ((x[0], x[1]), x[1])  # auto-regression only for dtd
    data.add_tfm(ar_tfm)
    logging.info(f'{len(data.train_ds)} training items found.')
    if not data.empty_val:
        logging.info(f'{len(data.valid_ds)} valid items found.')
    return data

def _get_model(config):
    import importlib
    names = config.model_name.split('.')
    module_name, class_name = '.'.join(names[:-1]), names[-1]
    cls = getattr(importlib.import_module(module_name), class_name)
    model = cls(config)
    logging.info(model)
    return model
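# For example, a model_name of 'package.module.SomeModel' (an illustrative dotted
# path, not a class shipped with this script) imports package.module and
# instantiates SomeModel(config).
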
def _get_learner(config, data, model, local_rank=None):
    strict = ifnone(config.model_strict, True)
    if config.global_stage == 'pretrain-language':
        metrics = [TopKTextAccuracy(
            k=ifnone(config.model_k, 5),
            charset_path=config.dataset_charset_path,
            max_length=config.dataset_max_length + 1,
            case_sensitive=config.dataset_eval_case_sensisitves,
            model_eval=config.model_eval)]
    else:
        metrics = [TextAccuracy(
            charset_path=config.dataset_charset_path,
            max_length=config.dataset_max_length + 1,
            case_sensitive=config.dataset_eval_case_sensisitves,
            model_eval=config.model_eval)]
    opt_type = getattr(torch.optim, config.optimizer_type)
    learner = Learner(data, model, silent=True, model_dir='.',
                      true_wd=config.optimizer_true_wd,
                      wd=config.optimizer_wd,
                      bn_wd=config.optimizer_bn_wd,
                      path=config.global_workdir,
                      metrics=metrics,
                      opt_func=partial(opt_type, **config.optimizer_args or dict()),
                      loss_func=MultiLosses(one_hot=config.dataset_one_hot_y))
    learner.split(lambda m: children(m))

    if config.global_phase == 'train':
        num_replicas = 1 if local_rank is None else torch.distributed.get_world_size()
        phases = _get_training_phases(config, len(learner.data.train_dl) // num_replicas)
        learner.callback_fns += [
            partial(GeneralScheduler, phases=phases),
            partial(GradientClipping, clip=config.optimizer_clip_grad),
            partial(IterationCallback, name=config.global_name,
                    show_iters=config.training_show_iters,
                    eval_iters=config.training_eval_iters,
                    save_iters=config.training_save_iters,
                    start_iters=config.training_start_iters,
                    stats_iters=config.training_stats_iters)]
    else:
        learner.callbacks += [
            DumpPrediction(learn=learner,
                           dataset='-'.join([Path(p).name for p in config.dataset_test_roots]),
                           charset_path=config.dataset_charset_path,
                           model_eval=config.model_eval,
                           debug=config.global_debug,
                           image_only=config.global_image_only)]

    learner.rank = local_rank
    if local_rank is not None:
        logging.info(f'Set model to distributed with rank {local_rank}.')
        learner.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(learner.model)
        learner.model.to(local_rank)
        learner = learner.to_distributed(local_rank)
    if torch.cuda.device_count() > 1 and local_rank is None:
        logging.info(f'Use {torch.cuda.device_count()} GPUs.')
        learner.model = MyDataParallel(learner.model)

    if config.model_checkpoint:
        if Path(config.model_checkpoint).exists():
            with open(config.model_checkpoint, 'rb') as f:
                buffer = io.BytesIO(f.read())
            learner.load(buffer, strict=strict)
        else:
            from distutils.dir_util import copy_tree
            src = Path('/data/fangsc/model')/config.global_name
            trg = Path('/output')/config.global_name
            if src.exists(): copy_tree(str(src), str(trg))
            learner.load(config.model_checkpoint, strict=strict)
        logging.info(f'Read model from {config.model_checkpoint}')
    elif config.global_phase == 'test':
        learner.load(f'best-{config.global_name}', strict=strict)
        logging.info(f'Read model from best-{config.global_name}')

    if learner.opt_func.func.__name__ == 'Adadelta':  # fastai bug, fix after 1.0.60
        learner.fit(epochs=0, lr=config.optimizer_lr)
        learner.opt.mom = 0.
    return learner

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True,
                        help='path to config file')
    parser.add_argument('--phase', type=str, default=None, choices=['train', 'test'])
    parser.add_argument('--name', type=str, default=None)
    parser.add_argument('--checkpoint', type=str, default=None)
    parser.add_argument('--test_root', type=str, default=None)
    parser.add_argument('--local_rank', type=int, default=None)
    parser.add_argument('--debug', action='store_true', default=None)
    parser.add_argument('--image_only', action='store_true', default=None)
    parser.add_argument('--model_strict', action='store_false', default=None)
    parser.add_argument('--model_eval', type=str, default=None,
                        choices=['alignment', 'vision', 'language'])
    args = parser.parse_args()

    config = Config(args.config)
    if args.name is not None: config.global_name = args.name
    if args.phase is not None: config.global_phase = args.phase
    if args.test_root is not None: config.dataset_test_roots = [args.test_root]
    if args.checkpoint is not None: config.model_checkpoint = args.checkpoint
    if args.debug is not None: config.global_debug = args.debug
    if args.image_only is not None: config.global_image_only = args.image_only
    if args.model_eval is not None: config.model_eval = args.model_eval
    if args.model_strict is not None: config.model_strict = args.model_strict

    Logger.init(config.global_workdir, config.global_name, config.global_phase)
    Logger.enable_file()
    _set_random_seed(config.global_seed)
    logging.info(config)

    if args.local_rank is not None:
        logging.info(f'Init distributed training at device {args.local_rank}.')
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    logging.info('Construct dataset.')
    if config.global_stage == 'pretrain-language': data = _get_language_databaunch(config)
    else: data = _get_databaunch(config)

    logging.info('Construct model.')
    model = _get_model(config)

    logging.info('Construct learner.')
    learner = _get_learner(config, data, model, args.local_rank)

    if config.global_phase == 'train':
        logging.info('Start training.')
        learner.fit(epochs=config.training_epochs,
                    lr=config.optimizer_lr)
    else:
        logging.info('Start validation.')
        last_metrics = learner.validate()
        log_str = f'eval loss = {last_metrics[0]:6.3f}, ' \
                  f'ccr = {last_metrics[1]:6.3f}, cwr = {last_metrics[2]:6.3f}, ' \
                  f'ted = {last_metrics[3]:6.3f}, ned = {last_metrics[4]:6.0f}, ' \
                  f'ted/w = {last_metrics[5]:6.3f}'
        logging.info(log_str)


if __name__ == '__main__':
    main()
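
# Example invocations (the config path, checkpoint name, and data directory below
# are purely illustrative, not files shipped with this script):
#
#   # single-GPU training
#   python main.py --config configs/train.yaml --phase train
#
#   # multi-GPU training via torch.distributed.launch, which supplies --local_rank
#   python -m torch.distributed.launch --nproc_per_node=4 main.py --config configs/train.yaml
#
#   # evaluation with a specific checkpoint and test set
#   python main.py --config configs/train.yaml --phase test \
#       --checkpoint workdir/best-model.pth --test_root data/evaluation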