File size: 1,813 Bytes
e98653e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
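"""Main entry point.

Builds the data loaders, model, optimizer, loss and trainer from the options in
``option.args``, then runs an interactive session, a demo evaluation, or the
train / validate / test epoch loop.
"""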
from utils import interact
from option import args, setup, cleanup
from data import Data
from model import Model
from loss import Loss
from optim import Optimizer
from train import Trainer
def main_worker(rank, args):
    """Build every training component for one worker and run the requested mode.

    Steps:

    1. Record ``rank`` on ``args`` and pass ``args`` through ``setup``; the
       object returned by ``setup`` is used from then on.
    2. Construct the data loaders, the model (then ``model.parallelize()``),
       the optimizer, the loss and the ``Trainer``.
    3. Run exactly one of these modes:

       * ``args.stay``: open ``interact`` with this function's local
         variables, then return without training.
       * ``args.demo``: call ``trainer.evaluate`` in ``'demo'`` mode for
         ``args.start_epoch``, then return.
       * otherwise: call ``fill_evaluation`` for the epochs before
         ``args.start_epoch``, then run the epoch loop from
         ``args.start_epoch`` through ``args.end_epoch`` inclusive.

    Only the last mode calls ``trainer.imsaver.join_background()`` and
    ``cleanup(args)``; the two early-return modes skip both.

    Args:
        rank: Worker index, stored as ``args.rank`` before ``setup`` runs.
        args: Options namespace (the one imported from ``option`` when called
            via ``main``).
    """
    args.rank = rank
    args = setup(args)

    loaders = Data(args).get_loader()
    model = Model(args)
    model.parallelize()
    optimizer = Optimizer(args, model)
    criterion = Loss(args, model=model, optimizer=optimizer)
    trainer = Trainer(args, model, criterion, optimizer, loaders)

    # Use ``return`` rather than the site-provided ``exit()``: that helper is
    # meant for the interactive prompt, is missing under ``python -S`` and
    # closes ``sys.stdin`` before raising SystemExit. Returning here lets
    # ``main`` and then the script finish normally with exit status 0.
    if args.stay:
        # ``locals()`` is captured here so the session sees the objects above.
        interact(local=locals())
        return

    if args.demo:
        trainer.evaluate(epoch=args.start_epoch, mode='demo')
        return

    # Epochs before start_epoch are not re-run; fill_evaluation is called for
    # each one that falls on the validation/test cadence (presumably so a
    # resumed run keeps earlier results -- TODO confirm against Trainer).
    for epoch in range(1, args.start_epoch):
        if args.do_validate and epoch % args.validate_every == 0:
            trainer.fill_evaluation(epoch, 'val')
        if args.do_test and epoch % args.test_every == 0:
            trainer.fill_evaluation(epoch, 'test')

    for epoch in range(args.start_epoch, args.end_epoch + 1):
        if args.do_train:
            trainer.train(epoch)

        if args.do_validate and epoch % args.validate_every == 0:
            # Bring the trainer to this epoch before evaluating it.
            if trainer.epoch != epoch:
                trainer.load(epoch)
            trainer.validate(epoch)

        if args.do_test and epoch % args.test_every == 0:
            if trainer.epoch != epoch:
                trainer.load(epoch)
            trainer.test(epoch)

        # Blank separator line after each epoch, printed when rank is 0 or the
        # run was not launched (presumably to avoid duplicated output from
        # several processes).
        if args.rank == 0 or not args.launched:
            print()

    trainer.imsaver.join_background()
    cleanup(args)
def main():
    """Run ``main_worker`` in this process with the rank stored on ``args``."""
    worker_rank = args.rank
    main_worker(worker_rank, args)
if __name__ == "__main__":
    # Start the program only when executed as a script, not when imported.
    main()