# oguzakif's picture
# init repo
# d4b77ac
from trainer import Trainer
from importlib import import_module
import math
import torch
from torch import optim
from torch.optim import lr_scheduler
import numpy as np
import os
from shutil import copyfile
import glob
from models.utils.flow_losses import smoothness_loss, second_order_loss
from models.utils.fbConsistencyCheck import image_warp
from models.utils.fbConsistencyCheck import ternary_loss2
import torch.nn.functional as F
import cv2
import cvbase
from data.util.flow_utils import region_fill as rf
import imageio
import torch.nn as nn
from skimage.feature import canny
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from models.utils.bce_edge_loss import edgeLoss, EdgeAcc
class Network(Trainer):
def init_model(self):
    """Create the network, its Adam optimizer and a step LR scheduler.

    The model class is looked up dynamically as ``models.<opt['model']>.Model``.
    Also installs ``self.edgeMeasure`` (an ``EdgeAcc`` instance) for later
    precision / recall computation.

    Returns:
        A ``(model, optimizer, scheduler)`` tuple.
    """
    self.edgeMeasure = EdgeAcc()
    train_cfg = self.opt['train']
    is_main = self.rank <= 0

    network = import_module('models.{}'.format(self.opt['model'])).Model(self.opt)

    # Config values are cast explicitly with float() before being handed to Adam.
    learning_rate = float(train_cfg['lr'])
    beta1 = float(train_cfg['BETA1'])
    beta2 = float(train_cfg['BETA2'])
    adam = optim.Adam(network.parameters(), lr=learning_rate, betas=(beta1, beta2))
    if is_main:
        self.logger.info('Optimizer is Adam, BETA1: {}, BETA2: {}'.format(beta1, beta2))

    # UPDATE_INTERVAL is divided by self.trainSize to get the decay period in epochs.
    step_size = int(math.ceil(train_cfg['UPDATE_INTERVAL'] / self.trainSize))
    if is_main:
        self.logger.info('Step size for optimizer is {} epoch'.format(step_size))
    step_scheduler = lr_scheduler.StepLR(adam, step_size=step_size, gamma=train_cfg['lr_decay'])
    return network, adam, step_scheduler
def resume_training(self):
    """Restore model (and, unless fine-tuning, optimizer/scheduler) state.

    Checkpoints are read from ``opt['path']['gen_state']`` and
    ``opt['path']['opt_state']`` and mapped onto ``opt['device']``.

    Returns:
        ``(start_epoch, current_step)`` - the stored values when resuming,
        or ``(0, 0)`` when ``opt['finetune']`` is set.
    """

    def _to_device(storage, loc):
        # Every stored tensor is moved onto the configured CUDA device.
        return storage.cuda(self.opt['device'])

    paths = self.opt['path']
    gen_state = torch.load(paths['gen_state'], map_location=_to_device)
    opt_state = torch.load(paths['opt_state'], map_location=_to_device)
    is_main = self.rank <= 0

    if is_main:
        self.logger.info('Resume state is activated')
        self.logger.info('Resume training from epoch: {}, iter: {}'.format(
            opt_state['epoch'], opt_state['iteration']))

    start_epoch, current_step = 0, 0
    # Explicit comparison on purpose: a value of None must take the fine-tune path.
    if self.opt['finetune'] == False:  # noqa: E712
        start_epoch = opt_state['epoch']
        current_step = opt_state['iteration']
        self.optimizer.load_state_dict(opt_state['optimizer_state_dict'])
        self.scheduler.load_state_dict(opt_state['scheduler_state_dict'])

    self.model.load_state_dict(gen_state['model_state_dict'])
    if is_main:
        self.logger.info('Resume training mode, optimizer, scheduler and model have been uploaded')
    return start_epoch, current_step
def _trainEpoch(self, epoch):
    """Train the model for one pass over ``self.trainLoader``.

    For every batch this optionally warms up the learning rate, runs the model
    on the diffused flows, combines the weighted reconstruction, smoothness,
    ternary and edge losses, back-propagates and steps the optimizer.  Rank 0
    additionally writes the individual loss terms to TensorBoard every 8 steps
    and reports running metrics every ``opt['logger']['PRINT_FREQ']`` steps.

    Args:
        epoch: Index of the current epoch; only used for logging.
    """
    device = self.opt['device']
    train_opt = self.opt['train']
    for train_data in self.trainLoader:
        self.currentStep += 1
        if self.currentStep > self.totalIterations:
            if self.rank <= 0:
                self.logger.info('Train process has been finished')
            break

        warmup = train_opt['WARMUP']
        if warmup is not None and self.currentStep <= warmup // self.opt['world_size']:
            # init_model casts the configured lr with float(); do the same here so
            # that a string value such as '1e-4' is not multiplied as a string.
            target_lr = float(train_opt['lr']) * self.currentStep / warmup
            self.adjust_learning_rate(self.optimizer, target_lr)

        flows = train_data['flows'].to(device)
        diffused_flows = train_data['diffused_flows'].to(device)
        target_edge = train_data['edges'].to(device)
        current_frame = train_data['current_frame'].to(device)
        shift_frame = train_data['shift_frame'].to(device)
        masks = train_data['masks'].to(device)

        if len(masks.shape) == 5:
            # 5-D input: supervise the middle slice along dim 2.
            t = masks.shape[2]
            target_flow = flows[:, :, t // 2]
            target_mask = masks[:, :, t // 2]
        else:
            assert len(masks.shape) == 4 and len(flows.shape) == 4
            target_flow = flows
            target_mask = masks

        filled_flow, filled_edge = self.model(diffused_flows, masks)

        # Keep the ground truth where the mask is 0, the prediction where it is 1.
        combined_flow = target_flow * (1 - target_mask) + filled_flow * target_mask
        combined_edge = target_edge * (1 - target_mask) + filled_edge * target_mask
        edge_loss = (edgeLoss(filled_edge, target_edge) + 5 * edgeLoss(combined_edge, target_edge))

        # loss calculations
        # NOTE(review): an all-zero (or all-one) mask makes one of these divisors 0.
        L1Loss_masked = self.maskedLoss(combined_flow * target_mask,
                                        target_flow * target_mask) / torch.mean(target_mask)
        L1Loss_valid = self.validLoss(filled_flow * (1 - target_mask),
                                      target_flow * (1 - target_mask)) / torch.mean(1 - target_mask)
        smoothLoss = smoothness_loss(combined_flow, target_mask)
        smoothLoss2 = second_order_loss(combined_flow, target_mask)
        ternary_loss = self.ternary_loss(combined_flow, target_flow, target_mask, current_frame, shift_frame,
                                         scale_factor=1)

        m_losses = (L1Loss_masked + L1Loss_valid) * self.opt['L1M']
        sm1_loss = smoothLoss * self.opt['sm']
        sm2_loss = smoothLoss2 * self.opt['sm2']
        t_loss = self.opt['ternary'] * ternary_loss
        e_loss = edge_loss * self.opt['edge_loss']
        loss = m_losses + sm1_loss + sm2_loss + t_loss + e_loss

        self.optimizer.zero_grad()
        loss.backward()
        if self.opt['gc']:  # gradient clip
            nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10,
                                     norm_type=2)
        self.optimizer.step()

        if self.opt['use_tb_logger'] and self.rank <= 0 and self.currentStep % 8 == 0:
            # All five terms use the same three-decimal format.
            print('Mask: {:.03f}, sm: {:.03f}, sm2: {:.03f}, ternary: {:.03f}, edge: {:.03f}'.format(
                m_losses.item(),
                sm1_loss.item(),
                sm2_loss.item(),
                t_loss.item(),
                e_loss.item()
            ))
            self.tb_logger.add_scalar('{}/recon'.format('train'), m_losses.item(), self.currentStep)
            self.tb_logger.add_scalar('{}/sm'.format('train'), sm1_loss.item(), self.currentStep)
            self.tb_logger.add_scalar('{}/sm2'.format('train'), sm2_loss.item(), self.currentStep)
            self.tb_logger.add_scalar('{}/ternary'.format('train'), t_loss.item(), self.currentStep)
            self.tb_logger.add_scalar('{}/edge'.format('train'), e_loss.item(), self.currentStep)

        if self.currentStep % self.opt['logger']['PRINT_FREQ'] == 0 and self.rank <= 0:
            # Metrics are computed on channel-last numpy copies of the flows.
            compLog = np.array(combined_flow.detach().permute(0, 2, 3, 1).cpu())
            flowsLog = np.array(target_flow.detach().permute(0, 2, 3, 1).cpu())
            logs = self.calculate_metrics(compLog, flowsLog)
            prec, recall = self.edgeMeasure(filled_edge.detach(), target_edge.detach())
            logs['prec'] = prec
            logs['recall'] = recall
            self._printLog(logs, epoch, loss)
def ternary_loss(self, comp, flow, mask, current_frame, shift_frame, scale_factor):
    """Compute the ternary loss between the current frame and a warped frame.

    ``shift_frame`` is warped once with the reference ``flow`` to build a
    soft weighting mask from the photometric error, and once with the
    completed flow ``comp`` to produce the image that is compared with
    ``current_frame``.

    Args:
        comp: Completed flow.
        flow: Reference flow.
        mask: Mask forwarded to ``ternary_loss2``.
        current_frame: Frame the warped image is compared against.
        shift_frame: Frame that gets warped by the flows.
        scale_factor: When not 1, both frames are resized by ``1 / scale_factor``.

    Returns:
        The value returned by ``ternary_loss2``.
    """
    if scale_factor != 1:
        inverse_scale = 1 / scale_factor
        current_frame = F.interpolate(current_frame, scale_factor=inverse_scale, mode='bilinear')
        shift_frame = F.interpolate(shift_frame, scale_factor=inverse_scale, mode='bilinear')

    warped_by_reference = image_warp(shift_frame, flow)
    # Channel-summed absolute error, squared and turned into an exponential weight.
    photometric_error = torch.abs(current_frame - warped_by_reference).sum(dim=1)
    noc_mask = torch.exp(-50. * photometric_error.pow(2)).unsqueeze(1)
    warped_by_completion = image_warp(shift_frame, comp)
    return ternary_loss2(current_frame, warped_by_completion, noc_mask, mask)
def calculate_metrics(self, results, gts):
    """Average PSNR, SSIM, L1 and L2 errors over a batch of flows.

    PSNR and SSIM are measured on the RGB visualisation of each flow
    (``cvbase.flow2rgb``); L1 and L2 are measured on the raw flow values.

    Args:
        results: Predicted flows, indexed as ``[batch, h, w, c]``.
        gts: Ground-truth flows with the same layout.

    Returns:
        Dict with the batch means under the keys ``l1``, ``l2``, ``psnr``
        and ``ssim``.
    """
    batch, height, width, channels = results.shape
    element_count = height * width * channels
    collected = {'l1': [], 'l2': [], 'psnr': [], 'ssim': []}

    for idx in range(batch):
        pred, truth = results[idx], gts[idx]
        pred_rgb = cvbase.flow2rgb(pred)
        truth_rgb = cvbase.flow2rgb(truth)
        collected['psnr'].append(psnr(pred_rgb, truth_rgb))
        # NOTE(review): `multichannel` is deprecated in newer scikit-image in
        # favour of `channel_axis` - confirm against the installed version.
        collected['ssim'].append(ssim(pred_rgb, truth_rgb, multichannel=True))

        diff = pred - truth
        collected['l1'].append(np.mean(np.abs(diff)))
        # Euclidean norm of the residual divided by the number of elements.
        collected['l2'].append(np.sum(diff ** 2) ** 0.5 / element_count)

    return {name: np.mean(values) for name, values in collected.items()}
def _printLog(self, logs, epoch, loss):
    """Accumulate running training metrics, log their means and maybe checkpoint.

    The running totals are reset whenever ``self.countDown`` is a multiple of
    ``opt['record_iter']``.  A checkpoint is written when ``self.currentStep``
    is a multiple of ``opt['logger']['SAVE_CHECKPOINT_FREQ']``.

    Args:
        logs: Dict with ``psnr``, ``ssim``, ``l1``, ``l2`` (numbers) and
            ``prec``, ``recall`` (objects exposing ``.item()``).
        epoch: Current epoch index.
        loss: Total loss of the step; read through ``.item()``.
    """
    tracked = ('psnr', 'ssim', 'l1', 'l2', 'loss', 'prec', 'recall')

    if self.countDown % self.opt['record_iter'] == 0:
        # Start a fresh averaging window.
        for name in tracked:
            setattr(self, 'total_' + name, 0)
        self.countDown = 0
    self.countDown += 1

    pieces = ['[epoch:{:3d}, iter:{:7d}, lr:('.format(epoch, self.currentStep)]
    pieces.extend('{:.3e}, '.format(v) for v in self.get_lr())
    pieces.append(')] ')

    latest = {
        'psnr': logs['psnr'],
        'ssim': logs['ssim'],
        'l1': logs['l1'],
        'l2': logs['l2'],
        'prec': logs['prec'].item(),
        'recall': logs['recall'].item(),
        'loss': loss.item(),
    }
    means = {}
    for name in tracked:
        running_total = getattr(self, 'total_' + name) + latest[name]
        setattr(self, 'total_' + name, running_total)
        means[name] = running_total / self.countDown

    pieces.append('{:s}: {:.4e} '.format('mean_loss', means['loss']))
    for name in ('psnr', 'ssim', 'l1', 'l2', 'prec', 'recall'):
        pieces.append('{:s}: {:} '.format('mean_' + name, means[name]))
    message = ''.join(pieces)

    if self.opt['use_tb_logger']:
        for name in tracked:
            self.tb_logger.add_scalar('train/mean_' + name, means[name], self.currentStep)
    self.logger.info(message)

    if self.currentStep % self.opt['logger']['SAVE_CHECKPOINT_FREQ'] == 0:
        self.save_checkpoint(epoch, 'l1', logs['l1'])
def save_checkpoint(self, epoch, metric, number):
    """Write the model weights and the optimizer/scheduler state to disk.

    Two files are created in ``opt['path']['TRAINING_STATE']``:
    ``gen_<epoch>_<step>.pth.tar`` (model weights) and
    ``opt_<epoch>_<step>.pth.tar`` (epoch, iteration, optimizer and scheduler).

    Args:
        epoch: Epoch index stored in the checkpoint and used in the file names.
        metric: Not used by this implementation.
        number: Not used by this implementation.
    """
    parallel_wrappers = (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)
    # Unwrap (Distributed)DataParallel so the saved keys carry no wrapper prefix.
    network = self.model.module if isinstance(self.model, parallel_wrappers) else self.model

    gen_state = {'model_state_dict': network.state_dict()}
    opt_state = {
        'epoch': epoch,
        'iteration': self.currentStep,
        'optimizer_state_dict': self.optimizer.state_dict(),
        'scheduler_state_dict': self.scheduler.state_dict(),
    }

    state_dir = self.opt['path']['TRAINING_STATE']
    gen_name = os.path.join(state_dir, 'gen_{}_{}.pth.tar'.format(epoch, self.currentStep))
    opt_name = os.path.join(state_dir, 'opt_{}_{}.pth.tar'.format(epoch, self.currentStep))
    torch.save(gen_state, gen_name)
    torch.save(opt_state, opt_name)
def _validate(self, epoch):
    """Evaluate flow completion on up to 10 validation videos, then checkpoint.

    For every selected video and for both flow directions, a window of flows
    and masks around frame 20 is read, the masked regions are pre-filled by
    diffusion, the model completes the centre flow, and PSNR / SSIM / L1 / L2
    plus edge precision / recall are recorded.  Rank 0 logs the results,
    saves a visual comparison and finally stores a checkpoint.  The model is
    put back into training mode before returning.

    Each (video, direction) pair is tracked under its own key, so forward and
    backward results are reported separately and both contribute to the means.

    Args:
        epoch: Current epoch index, used for logging and output paths.
    """
    data_path = self.valInfo['data_root']
    mask_path = self.valInfo['mask_root']
    self.model.eval()
    # os.listdir returns entries in arbitrary order; sort so that the same
    # 10 videos are evaluated on every run.
    test_list = sorted(os.listdir(data_path))[:10]  # only inference 10 videos
    width, height = self.valInfo['flow_width'], self.valInfo['flow_height']
    flow_interval = self.opt['flow_interval']  # The sampling interval for flow completion
    device = self.opt['device']
    is_main = self.rank <= 0
    metric_names = ('psnr', 'ssim', 'l1', 'l2', 'prec', 'recall')
    scores = {name: {} for name in metric_names}
    pivot, sequenceLen = 20, self.opt['num_flows']
    center = sequenceLen // 2

    def _current_means():
        # Mean of every metric over all samples recorded so far.
        return {name: np.mean(list(scores[name].values())) for name in metric_names}

    for i, videoName in enumerate(test_list):
        if is_main:
            self.logger.info(f'Video {videoName} is being processed')
        for direction in ('forward_flo', 'backward_flo'):
            flow_dir = os.path.join(data_path, videoName, direction)
            mask_dir = os.path.join(mask_path, videoName)
            flows = self.read_flows(flow_dir, width, height, pivot, sequenceLen, flow_interval)
            masks = self.read_masks(mask_dir, width, height, pivot, sequenceLen, flow_interval)
            if not flows or not masks:
                if is_main:
                    print('Video {} doesn\'t have enough {} flows'.format(videoName, direction))
                continue
            if is_main:
                self.logger.info('Flows have been read')

            diffused_flows = self.diffusion_filling(flows, masks)
            flows = np.stack(flows, axis=0)
            masks = np.stack(masks, axis=0)
            diffused_flows = np.stack(diffused_flows, axis=0)
            target_flow = flows[center]
            target_edge = self.load_edge(target_flow)[:, :, np.newaxis]

            # Stacked sequences: [t, h, w, c] -> [1, c, t, h, w];
            # single targets: [h, w, c] -> [1, c, h, w].
            diffused_flows = torch.from_numpy(
                np.transpose(diffused_flows, (3, 0, 1, 2))).unsqueeze(0).float().to(device)
            masks = torch.from_numpy(np.transpose(masks, (3, 0, 1, 2))).unsqueeze(0).float().to(device)
            target_flow = torch.from_numpy(np.transpose(target_flow, (2, 0, 1))).unsqueeze(0).float().to(device)
            target_edge = torch.from_numpy(np.transpose(target_edge, (2, 0, 1))).unsqueeze(0).float().to(device)

            target_mask = masks[:, :, center]
            if diffused_flows.shape[2] == 1 and len(diffused_flows.shape) == 5:
                # A one-element sequence is fed to the model as a 4-D tensor.
                assert masks.shape[2] == 1
                diffused_flows = diffused_flows.squeeze(2)
                masks = masks.squeeze(2)

            with torch.no_grad():
                filled_flow, filled_edge = self.model(diffused_flows, masks, None)

            if len(diffused_flows.shape) == 5:
                target_diffused_flow = diffused_flows[:, :, center]
            else:
                target_diffused_flow = diffused_flows
            combined_flow = target_flow * (1 - target_mask) + filled_flow * target_mask

            # calculate metrics
            psnr_avg, ssim_avg, l1_avg, l2_avg = self.metrics_calc(combined_flow, target_flow)
            prec_avg, recall_avg = self.edgeMeasure(filled_edge, target_edge)

            # One entry per (video, direction) so neither direction hides the other.
            sample_key = '{}/{}'.format(videoName, direction)
            scores['psnr'][sample_key] = psnr_avg
            scores['ssim'][sample_key] = ssim_avg
            scores['l1'][sample_key] = l1_avg
            scores['l2'][sample_key] = l2_avg
            scores['prec'][sample_key] = prec_avg.item()
            scores['recall'][sample_key] = recall_avg.item()

            # visualize frames and report the phase performance
            if is_main:
                if self.opt['use_tb_logger']:
                    self.tb_logger.add_scalar('test/{}/l1'.format(sample_key), l1_avg, self.currentStep)
                    self.tb_logger.add_scalar('test/{}/l2'.format(sample_key), l2_avg, self.currentStep)
                    self.tb_logger.add_scalar('test/{}/psnr'.format(sample_key), psnr_avg, self.currentStep)
                    self.tb_logger.add_scalar('test/{}/ssim'.format(sample_key), ssim_avg, self.currentStep)
                    self.tb_logger.add_scalar('test/{}/prec'.format(sample_key), prec_avg, self.currentStep)
                    self.tb_logger.add_scalar('test/{}/recall'.format(sample_key), recall_avg, self.currentStep)
                # view the difference between diffused flows and the completed flows
                self.vis_flows(combined_flow, target_flow, target_diffused_flow,
                               os.path.join(videoName, direction), epoch)
                means = _current_means()
                self.logger.info(
                    '[epoch:{:3d}, vid:{}/{}], mean_l1: {:.4e}, mean_l2: {:.4e}, mean_psnr: {:}, mean_ssim: {:}, prec: {:}, recall: {:}'.format(
                        epoch, i, len(test_list), means['l1'], means['l2'], means['psnr'], means['ssim'],
                        means['prec'], means['recall']))

    # give the overall performance
    if is_main:
        means = _current_means()
        self.logger.info(
            '[epoch:{:3d}], mean_l1: {:.4e} mean_l2: {:.4e} mean_psnr: {:} mean_ssim: {:}, prec: {:}, recall: {:}'.format(
                epoch, means['l1'], means['l2'], means['psnr'], means['ssim'], means['prec'], means['recall']))
        valid_l1 = means['l1'] + 100
        self.save_checkpoint(epoch, 'l1', valid_l1)
    self.model.train()
def load_edge(self, flow):
    """Extract a Canny edge map from a single flow field.

    The flow is converted to its RGB visualisation, then to gray scale, and
    passed through ``canny`` with the thresholds configured under
    ``opt['datasets']['dataInfo']['edge']``.

    Args:
        flow: Flow field accepted by ``cvbase.flow2rgb``.

    Returns:
        The edge map as a ``float64`` array.
    """
    edge_cfg = self.opt['datasets']['dataInfo']['edge']
    flow_rgb = cvbase.flow2rgb(flow)
    flow_gray = cv2.cvtColor(flow_rgb, cv2.COLOR_RGB2GRAY)
    edge = canny(flow_gray, sigma=edge_cfg['sigma'], mask=None,
                 low_threshold=edge_cfg['low_threshold'],
                 high_threshold=edge_cfg['high_threshold'])
    # np.float64 is the type the removed `np.float` alias resolved to.
    return edge.astype(np.float64)
def read_flows(self, flow_dir, width, height, pivot, sequenceLen, sample_interval):
    """Read a window of ``.flo`` files centred on ``pivot`` and resize them.

    Indices are ``pivot + sample_interval * k`` for ``k`` in
    ``[-sequenceLen // 2, sequenceLen // 2]``, clamped to the number of flow
    files found in ``flow_dir``.  Files are expected to be named
    ``00000.flo``, ``00001.flo`` and so on.  Flow vectors are rescaled to
    match the new resolution.

    Args:
        flow_dir: Directory holding the ``.flo`` files.
        width: Target width.
        height: Target height.
        pivot: Index of the centre flow.
        sequenceLen: Nominal window length.
        sample_interval: Stride between sampled flows.

    Returns:
        List of resized flows, or an empty list when ``flow_dir`` contains no
        ``.flo`` files.
    """
    flow_count = len(glob.glob(os.path.join(flow_dir, '*.flo')))
    if flow_count == 0:
        # Nothing to clamp against; let the caller skip this sequence.
        return []

    flows = []
    half_seq = sequenceLen // 2
    for offset in range(-half_seq, half_seq + 1):
        index = min(max(pivot + sample_interval * offset, 0), flow_count - 1)
        flow_path = os.path.join(flow_dir, '{:05d}.flo'.format(index))
        flow = cvbase.read_flow(flow_path)
        pre_height, pre_width = flow.shape[:2]
        # The third positional argument of cv2.resize is `dst`, so the
        # interpolation flag has to be passed by keyword.
        flow = cv2.resize(flow, (width, height), interpolation=cv2.INTER_LINEAR)
        # Displacements scale together with the image size.
        flow[:, :, 0] = flow[:, :, 0] / pre_width * width
        flow[:, :, 1] = flow[:, :, 1] / pre_height * height
        flows.append(flow)
    return flows
def metrics_calc(self, result, frames):
    """Compute batch metrics for tensors by delegating to ``calculate_metrics``.

    Args:
        result: Predicted tensor, permuted from ``[b, c, h, w]``.
        frames: Ground-truth tensor with the same layout.

    Returns:
        Tuple ``(psnr, ssim, l1, l2)``.
    """
    prediction = np.array(result.permute(0, 2, 3, 1).cpu())  # [b, h, w, c]
    reference = np.array(frames.permute(0, 2, 3, 1).cpu())  # [b, h, w, c]
    stats = self.calculate_metrics(prediction, reference)
    return stats['psnr'], stats['ssim'], stats['l1'], stats['l2']
def read_frames(self, frame_dir, width, height, pivot, sequenceLen):
    """Read ``sequenceLen`` consecutive JPEG frames starting at ``pivot``.

    Frames are expected to be named ``00000.jpg``, ``00001.jpg`` and so on.
    Directories with 30 or fewer ``.jpg`` files are skipped.

    Args:
        frame_dir: Directory holding the frames.
        width: Target width.
        height: Target height.
        pivot: Index of the first frame to read.
        sequenceLen: Number of frames to read.

    Returns:
        List of resized frames, or an empty list for short videos.
    """
    frame_paths = sorted(glob.glob(os.path.join(frame_dir, '*.jpg')))
    if len(frame_paths) <= 30:
        return []

    frames = []
    for index in range(pivot, pivot + sequenceLen):
        frame = imageio.imread(os.path.join(frame_dir, '{:05d}.jpg'.format(index)))
        # The third positional argument of cv2.resize is `dst`, so the
        # interpolation flag has to be passed by keyword.
        frames.append(cv2.resize(frame, (width, height), interpolation=cv2.INTER_LINEAR))
    return frames
def load_edges(self, frames, width, height):
    """Compute a Canny edge tensor for every RGB frame.

    Thresholds and sigma come from ``self.valInfo``.  Each edge map is turned
    into a tensor with ``to_tensor`` using nearest-neighbour resizing.

    Args:
        frames: Iterable of RGB frames.
        width: Target width forwarded to ``to_tensor``.
        height: Target height forwarded to ``to_tensor``.

    Returns:
        List with one edge tensor per frame.
    """
    edges = []
    for frame in frames:
        frame_gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        # np.float64 is the type the removed `np.float` alias resolved to.
        edge = canny(frame_gray, sigma=self.valInfo['sigma'], mask=None,
                     low_threshold=self.valInfo['low_threshold'],
                     high_threshold=self.valInfo['high_threshold']).astype(np.float64)
        edges.append(self.to_tensor(edge, width, height, mode='nearest'))
    return edges
def to_tensor(self, frame, width, height, mode='bilinear'):
    """Convert an ``[h, w]`` or ``[h, w, c]`` array into a ``[1, c, h, w]`` float tensor.

    Args:
        frame: Numpy array; a 2-D input gets a trailing channel axis.
        width: Target width; resizing is skipped when it is 0.
        height: Target height; resizing is skipped when it is 0.
        mode: Interpolation mode handed to ``F.interpolate``.

    Returns:
        The float tensor, resized to ``(height, width)`` when both are non-zero.
    """
    if frame.ndim == 2:
        frame = np.expand_dims(frame, axis=-1)
    # [h, w, c] -> [c, h, w] -> [1, c, h, w]
    frame_t = torch.from_numpy(frame).permute(2, 0, 1).unsqueeze(0).float()
    if width == 0 or height == 0:
        return frame_t
    return F.interpolate(frame_t, size=(height, width), mode=mode)
def to_numpy(self, tensor):
    """Return the first batch element of a ``[b, c, h, w]`` tensor as an ``[h, w, c]`` array.

    Args:
        tensor: Batched tensor; only index 0 along the first dimension is used.

    Returns:
        Numpy array with the channel axis moved last.
    """
    first_sample = tensor.cpu()[0]
    return np.array(first_sample.permute(1, 2, 0))
def read_masks(self, mask_dir, width, height, pivot, sequenceLen, sample_interval):
    """Read a window of binary masks centred on ``pivot``.

    Indices are ``pivot + sample_interval * k`` for ``k`` in
    ``[-sequenceLen // 2, sequenceLen // 2]``, clamped to the sorted list of
    ``.png`` files in ``mask_dir``.  Every mask is read as gray scale, resized
    with nearest-neighbour sampling and binarised (any non-zero pixel -> 1).

    Args:
        mask_dir: Directory holding the ``.png`` masks.
        width: Target width.
        height: Target height.
        pivot: Index of the centre mask.
        sequenceLen: Nominal window length.
        sample_interval: Stride between sampled masks.

    Returns:
        List of ``[h, w, 1]`` masks, or an empty list when ``mask_dir``
        contains no ``.png`` files.
    """
    mask_path = sorted(glob.glob(os.path.join(mask_dir, '*.png')))
    if not mask_path:
        # Nothing to clamp against; let the caller skip this sequence.
        return []

    masks = []
    half_seq = sequenceLen // 2
    last_index = len(mask_path) - 1
    for offset in range(-half_seq, half_seq + 1):
        index = min(max(pivot + offset * sample_interval, 0), last_index)
        mask = cv2.imread(mask_path[index], 0)
        mask = mask / 255.
        # The third positional argument of cv2.resize is `dst`; pass the flag
        # by keyword so nearest-neighbour sampling is actually applied.
        mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
        mask[mask > 0] = 1
        if len(mask.shape) == 2:
            mask = mask[:, :, np.newaxis]
        assert len(mask.shape) == 3, 'Invalid mask shape: {}'.format(mask.shape)
        masks.append(mask)
    return masks
def diffusion_filling(self, flows, masks):
    """Fill the masked region of every flow with ``rf.regionfill``.

    Both flow channels are filled independently using channel 0 of the
    matching mask.

    Args:
        flows: Sequence of ``[h, w, 2]`` flow arrays.
        masks: Sequence of ``[h, w, 1]`` masks, aligned with ``flows``.

    Returns:
        List of filled flows, one new array per input flow.
    """
    filled_flows = []
    for idx, flow in enumerate(flows):
        hole = masks[idx][:, :, 0]
        filled = np.zeros(flow.shape)
        for channel in (0, 1):
            filled[:, :, channel] = rf.regionfill(flow[:, :, channel], hole)
        filled_flows.append(filled)
    return filled_flows
def vis_flows(self, result, target_flow, diffused_flow, video_name, epoch):
    """
    Save a side-by-side comparison of three flows, laid out as
    |      |    |                |
    | Ours | GT | diffused_flows |
    |      |    |                |
    with a 20 pixel black column between neighbouring panels.
    Args:
        result: generated flow tensor with shape [1, 2, h, w]
        target_flow: GT flow tensor with shape [1, 2, h, w]
        diffused_flow: diffused flow tensor with shape [1, 2, h, w]
        video_name: video name, used as the output sub-directory
        epoch: epoch, used as the output sub-directory above the video name
    Returns: nothing; writes ``result_compare.png`` under
        ``opt['path']['VAL_IMAGES']/<epoch>/<video_name>``
    """
    out_dir = os.path.join(self.opt['path']['VAL_IMAGES'], str(epoch), video_name)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    gap = 20
    arrays = [self.to_numpy(flow_t) for flow_t in (result, target_flow, diffused_flow)]
    panels = [cvbase.flow2rgb(flow_np) for flow_np in arrays]

    height, width = panels[0].shape[:2]
    canvas = np.zeros((height, width * 3 + gap * 2, 3))
    for slot, panel in enumerate(panels):
        # Panel k starts after k panels and k gaps.
        left = slot * (width + gap)
        canvas[:, left:left + width, :] = panel
    imageio.imwrite(os.path.join(out_dir, 'result_compare.png'), canvas)