|
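"""Entry point: build data loaders, model, loss, optimizer and trainer from the parsed options, then run training, validation and testing."""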
|
|
import sys

from utils import interact

from option import args, setup, cleanup
from data import Data
from model import Model
from loss import Loss
from optim import Optimizer
from train import Trainer
|
|
|
|
def main_worker(rank, args):
    """Build every training component and run the workflow selected by ``args``.

    Args:
        rank: stored into ``args.rank`` before ``setup`` is called.
        args: option namespace; the object returned by ``setup(args)`` is
            the one used for the rest of the run.

    Depending on the flags in ``args`` this either opens an interactive
    session (``stay``), runs one demo evaluation (``demo``), or loops over
    epochs ``start_epoch..end_epoch`` (inclusive), training, validating
    and testing as requested. Both the ``stay`` and ``demo`` paths
    terminate the process via ``sys.exit()``.
    """
    args.rank = rank
    args = setup(args)

    loaders = Data(args).get_loader()
    model = Model(args)
    model.parallelize()
    optimizer = Optimizer(args, model)
    criterion = Loss(args, model=model, optimizer=optimizer)
    trainer = Trainer(args, model, criterion, optimizer, loaders)

    if args.stay:
        # Expose every component built above to an interactive session.
        interact(local=locals())
        # ``exit`` is an interactive-shell helper installed by ``site`` and
        # is not guaranteed to exist in programs; ``sys.exit`` always is.
        sys.exit()

    if args.demo:
        trainer.evaluate(epoch=args.start_epoch, mode='demo')
        sys.exit()

    # Epochs before ``start_epoch`` are not re-run; presumably they were
    # completed in an earlier run (resume) -- TODO confirm. The epochs that
    # would have been evaluated are handed to ``fill_evaluation`` instead.
    for epoch in range(1, args.start_epoch):
        if args.do_validate and epoch % args.validate_every == 0:
            trainer.fill_evaluation(epoch, 'val')
        if args.do_test and epoch % args.test_every == 0:
            trainer.fill_evaluation(epoch, 'test')

    for epoch in range(args.start_epoch, args.end_epoch + 1):
        if args.do_train:
            trainer.train(epoch)

        if args.do_validate and epoch % args.validate_every == 0:
            # Load this epoch's state when the trainer is not already at it
            # (e.g. when training was skipped for this epoch).
            if trainer.epoch != epoch:
                trainer.load(epoch)
            trainer.validate(epoch)

        if args.do_test and epoch % args.test_every == 0:
            if trainer.epoch != epoch:
                trainer.load(epoch)
            trainer.test(epoch)

        # Blank separator line after each epoch, printed by rank 0 or when
        # not ``launched`` (presumably: not under a distributed launcher).
        if args.rank == 0 or not args.launched:
            print('')

    # Wait for the image saver's background work before tearing down.
    trainer.imsaver.join_background()

    cleanup(args)
|
|
|
|
def main():
    """Run one worker, using the rank already present in the parsed options."""
    worker_rank = args.rank
    main_worker(worker_rank, args)


if __name__ == "__main__":
    # Execute only when run as a script, not when imported as a module.
    main()