| r""" Logging during training/testing """ |
| import datetime |
| import logging |
| import os |
|
|
| from tensorboardX import SummaryWriter |
| import torch |
|
|
class AverageMeter:
    r""" Stores loss, evaluation results

    Accumulates per-class intersection/union areas and per-batch losses, and
    formats them into progress / summary log lines.
    """

    # Number of class slots allocated in the buffers for each supported benchmark.
    _NCLASS = {
        'pascal': 20,
        'fss': 1000,
        'deepglobe': 6,
        'isic': 3,
        'lung': 1,
        'suim': 7,
    }

    def __init__(self, dataset, device='cuda'):
        r""" Allocates zeroed accumulation buffers for the dataset's benchmark

        Args:
            dataset: object exposing ``benchmark`` (str) and ``class_ids``
                (indices of the classes to evaluate on).
            device: device on which the buffers are stored.

        Raises:
            ValueError: if ``dataset.benchmark`` is not a supported benchmark.
        """
        self.benchmark = dataset.benchmark
        if self.benchmark not in self._NCLASS:
            # Report the benchmark name itself rather than the dataset object.
            raise ValueError('Unknown dataset: %s' % (self.benchmark,))

        self.class_ids_interest = torch.tensor(dataset.class_ids).to(device)
        self.nclass = self._NCLASS[self.benchmark]

        # Buffers have two rows and one column per class.
        # NOTE(review): rows presumably hold background / foreground areas - confirm with caller.
        self.intersection_buf = torch.zeros([2, self.nclass]).float().to(device)
        self.union_buf = torch.zeros([2, self.nclass]).float().to(device)
        self.ones = torch.ones_like(self.union_buf)
        self.loss_buf = []

    def update(self, inter_b, union_b, class_id, loss):
        r""" Accumulates one batch of intersection/union areas and its loss

        Args:
            inter_b: intersection areas, added into the columns given by ``class_id``.
            union_b: union areas, added into the columns given by ``class_id``.
            class_id: 1-D index tensor selecting the class column of every sample.
            loss: scalar loss tensor of the batch, or None (recorded as 0).
        """
        self.intersection_buf.index_add_(1, class_id, inter_b.float())
        self.union_buf.index_add_(1, class_id, union_b.float())
        if loss is None:
            # Keep the placeholder on the buffers' device so that it can be
            # stacked together with real losses later on.
            loss = torch.zeros((), device=self.intersection_buf.device)
        # Only the value is needed for logging; detaching avoids keeping the
        # autograd graph of every batch alive.
        self.loss_buf.append(loss.detach())

    def compute_iou(self):
        r""" Computes mIoU and FB-IoU (both in percent) from the accumulated areas

        Returns:
            (miou, fb_iou): mean of the second-row IoU over the classes of
            interest, and mean over both rows of the summed-intersection /
            summed-union ratio.
        """
        # Unions below one are raised to one to avoid dividing by zero for unseen classes.
        iou = self.intersection_buf / torch.max(self.union_buf, self.ones)
        iou = iou.index_select(1, self.class_ids_interest)
        miou = iou[1].mean() * 100

        # NOTE: yields NaN while nothing has been accumulated yet (0 / 0).
        inter_sum = self.intersection_buf.index_select(1, self.class_ids_interest).sum(dim=1)
        union_sum = self.union_buf.index_select(1, self.class_ids_interest).sum(dim=1)
        fb_iou = (inter_sum / union_sum).mean() * 100

        return miou, fb_iou

    def write_result(self, split, epoch):
        r""" Logs the summary (average loss, mIoU, FB-IoU) of a finished epoch """
        iou, fb_iou = self.compute_iou()

        loss_buf = torch.stack(self.loss_buf)
        msg = '\n*** %s ' % split
        msg += '[@Epoch %02d] ' % epoch
        msg += 'Avg L: %6.5f ' % loss_buf.mean()
        msg += 'mIoU: %5.2f ' % iou
        msg += 'FB-IoU: %5.2f ' % fb_iou

        msg += '***\n'
        Logger.info(msg)

    def write_process(self, batch_idx, datalen, epoch, write_batch_idx=20):
        r""" Logs running statistics every ``write_batch_idx`` batches

        Passing ``epoch == -1`` omits the epoch tag and the loss fields.
        """
        if batch_idx % write_batch_idx != 0:
            return

        msg = '[Epoch: %02d] ' % epoch if epoch != -1 else ''
        msg += '[Batch: %04d/%04d] ' % (batch_idx + 1, datalen)
        iou, fb_iou = self.compute_iou()
        if epoch != -1:
            loss_buf = torch.stack(self.loss_buf)
            msg += 'L: %6.5f ' % loss_buf[-1]
            msg += 'Avg L: %6.5f ' % loss_buf.mean()
        msg += 'mIoU: %5.2f | ' % iou
        msg += 'FB-IoU: %5.2f' % fb_iou
        Logger.info(msg)
|
|
|
|
class Logger:
    r""" Writes evaluation results of training/testing """

    @classmethod
    def initialize(cls, args, training):
        r""" Creates the log directory and sets up file, console and tensorboard logging

        Args:
            args: namespace providing ``logpath`` and ``benchmark``; every
                attribute of it is written to the log header.
            training: if False, '_TEST_' and a timestamp are appended to the log path.
        """
        logtime = datetime.datetime.now().strftime('_%m%d_%H%M%S')
        logpath = args.logpath if training else args.logpath + '_TEST_' + logtime
        if logpath == '':
            logpath = logtime

        cls.logpath = os.path.join('logs', logpath + '.log')
        cls.benchmark = args.benchmark
        print("logdir: ", cls.logpath)
        # Raises FileExistsError when the directory already exists.
        # NOTE(review): possibly intended to protect a previous run's logs - confirm.
        os.makedirs(cls.logpath)

        logging.basicConfig(filemode='w',
                            filename=os.path.join(cls.logpath, 'log.txt'),
                            level=logging.INFO,
                            format='%(message)s',
                            datefmt='%m-%d %H:%M:%S')

        # Console log config
        console = logging.StreamHandler()
        console.setLevel(logging.INFO)
        console.setFormatter(logging.Formatter('%(message)s'))
        logging.getLogger('').addHandler(console)

        # Tensorboard writer
        cls.tbd_writer = SummaryWriter(os.path.join(cls.logpath, 'tbd/runs'))

        # Log arguments
        logging.info('\n:=========== Adapt Before Comparison - A New Perspective on Cross-Domain Few-Shot Segmentation ===========')
        for arg_key, arg_value in vars(args).items():
            logging.info('| %20s: %-24s' % (arg_key, str(arg_value)))
        logging.info(':================================================\n')

    @classmethod
    def info(cls, msg):
        r""" Writes log message to log.txt """
        logging.info(msg)

    @classmethod
    def save_model_miou(cls, model, epoch, val_miou):
        r""" Saves the model weights as 'best_model.pt' in the log directory and logs the event """
        torch.save(model.state_dict(), os.path.join(cls.logpath, 'best_model.pt'))
        cls.info('Model saved @%d w/ val. mIoU: %5.2f.\n' % (epoch, val_miou))

    @classmethod
    def log_params(cls, model):
        r""" Logs the number of backbone, learnable and total parameters

        State-dict entries under the top-level name 'backbone' are counted as
        backbone parameters, except its 'classifier' / 'fc' sub-modules, which
        are not counted at all. Every other entry is counted as learnable.
        """
        backbone_param = 0
        learner_param = 0
        # Build the state dict once instead of once per entry.
        for key, value in model.state_dict().items():
            n_param = value.numel()
            parts = key.split('.')
            # Exact match: a substring test (`in 'backbone'`) would also
            # accept prefixes such as 'back' or 'bone'.
            if parts[0] == 'backbone':
                if len(parts) > 1 and parts[1] in ('classifier', 'fc'):
                    continue
                backbone_param += n_param
            else:
                learner_param += n_param
        cls.info('Backbone # param.: %d' % backbone_param)
        cls.info('Learnable # param.: %d' % learner_param)
        cls.info('Total # param.: %d' % (backbone_param + learner_param))
|
|