|
import torch |
|
|
|
from utils.flow_viz import flow_tensor_to_image |
|
|
|
|
|
class Logger:
    """Accumulate per-step training metrics and periodically write summaries.

    ``push`` is expected to be called once per training step. Every
    ``summary_freq`` steps, the metrics accumulated since the previous summary
    are averaged, printed, written to ``summary_writer`` and reset.
    """

    def __init__(self, lr_scheduler,
                 summary_writer,
                 summary_freq=100,
                 start_step=0,
                 ):
        """
        Args:
            lr_scheduler: object whose ``get_last_lr()[0]`` is logged as ``'lr'``.
            summary_writer: destination for summaries; must provide
                ``add_scalar(tag, value, step)``, ``add_image(tag, img, step)``
                and ``close()``.
            summary_freq: number of steps between summaries; must be positive.
            start_step: initial value of the step counter (e.g. when resuming).

        Raises:
            ValueError: if ``summary_freq`` is not positive.
        """
        if summary_freq <= 0:
            # Fail fast instead of hitting a ZeroDivisionError on the first push().
            raise ValueError('summary_freq must be positive, got %r' % (summary_freq,))

        self.lr_scheduler = lr_scheduler
        self.total_steps = start_step
        self.running_loss = {}
        # Number of push() calls accumulated into running_loss since the last
        # summary. It is the averaging divisor, so a first window shortened by a
        # start_step that is not a multiple of summary_freq is still averaged
        # correctly.
        self._window_steps = 0
        self.summary_writer = summary_writer
        self.summary_freq = summary_freq

    def _reset_window(self):
        """Clear the accumulated metrics and the step count of the current window."""
        self.running_loss = {}
        self._window_steps = 0

    def print_training_status(self, mode='train'):
        """Print and log the averages of the accumulated metrics, then reset them.

        Each metric is written as scalar ``<mode>/<key>`` at ``total_steps``.
        The printed line includes the averaged ``'epe'`` only when that metric
        was pushed in this window.

        Args:
            mode: tag prefix for the logged scalars.
        """
        # max(..., 1) avoids dividing by zero if called before any push().
        num_steps = max(self._window_steps, 1)
        averages = {key: value / num_steps for key, value in self.running_loss.items()}

        if 'epe' in averages:
            print('step: %06d \t epe: %.3f' % (self.total_steps, averages['epe']))
        else:
            print('step: %06d' % self.total_steps)

        for key, avg in averages.items():
            self.summary_writer.add_scalar(mode + '/' + key, avg, self.total_steps)

        self._reset_window()

    def lr_summary(self):
        """Write the scheduler's first reported learning rate as scalar ``'lr'``."""
        lr = self.lr_scheduler.get_last_lr()[0]
        self.summary_writer.add_scalar('lr', lr, self.total_steps)

    def add_image_summary(self, img1, img2, flow_preds, flow_gt, mode='train'):
        """Every ``summary_freq`` steps, log an image grid under ``<mode>/img_pred_gt``.

        The grid's top row is ``img1[0]`` next to ``img2[0]``, cast to uint8.
        The bottom row is ``flow_preds[-1][0]`` (the last prediction) next to
        ``flow_gt[0]``, both converted by ``flow_tensor_to_image``.

        NOTE(review): index 0 presumably selects the first batch element, and
        tensors are assumed channel-first (C, H, W), because rows are joined on
        dim -1 and stacked on dim -2. Images are assumed to already be in
        [0, 255] before the uint8 cast. Confirm against callers.
        """
        if self.total_steps % self.summary_freq != 0:
            return

        img_concat = torch.cat((img1[0].detach().cpu(), img2[0].detach().cpu()), dim=-1)
        img_concat = img_concat.type(torch.uint8)

        # flow_tensor_to_image returns a numpy array here (hence from_numpy);
        # presumably uint8 with the same channel count as the images.
        flow_pred = flow_tensor_to_image(flow_preds[-1][0])
        forward_flow_gt = flow_tensor_to_image(flow_gt[0])
        flow_concat = torch.cat((torch.from_numpy(flow_pred),
                                 torch.from_numpy(forward_flow_gt)), dim=-1)

        # Stack the image row above the flow row.
        concat = torch.cat((img_concat, flow_concat), dim=-2)

        self.summary_writer.add_image(mode + '/img_pred_gt', concat, self.total_steps)

    def push(self, metrics, mode='train'):
        """Advance one step, log the learning rate and accumulate ``metrics``.

        Args:
            metrics: mapping of metric name to a numeric value for this step.
            mode: tag prefix used when the periodic summary is written.
        """
        self.total_steps += 1
        self.lr_summary()

        for key, value in metrics.items():
            self.running_loss[key] = self.running_loss.get(key, 0.0) + value
        self._window_steps += 1

        if self.total_steps % self.summary_freq == 0:
            # Averages, logs and resets the accumulation window.
            self.print_training_status(mode)

    def write_dict(self, results):
        """Write each entry of ``results`` as a scalar at the current step.

        The tag is ``<prefix>/<key>``, where ``prefix`` is the part of the key
        before its first underscore (the whole key if it has none).
        """
        for key, value in results.items():
            tag = key.split('_')[0] + '/' + key
            self.summary_writer.add_scalar(tag, value, self.total_steps)

    def close(self):
        """Close the underlying summary writer."""
        self.summary_writer.close()
|
|