import os
import cv2
import time
import yaml
import torch
import datetime
from tensorboardX import SummaryWriter
import torchvision.transforms as tvf
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from nets.geom import getK, getWarp, _grid_positions, getWarpNoValidate
from nets.loss import make_detector_loss
from nets.score import extract_kpts
from nets.sampler import NghSampler2
from nets.reliability_loss import ReliabilityLoss
from datasets.noise_simulator import NoiseSimulator
from nets.l2net import Quad_L2Net


class SingleTrainer:
    """Train a detector/descriptor network on image pairs with simulated sensor noise.

    Besides the network named by ``config["network"]["model"]``, the trainer attaches a
    1x1 convolution ``model.clf`` (128 -> 2 channels). Its softmax over descriptors gives
    the confidence map consumed by ``ReliabilityLoss``. Losses, accuracy and periodic
    visualizations are written to TensorBoard under ``runs/``, and checkpoints are
    written to the same run directory.
    """

    # Number of network input channels for each supported input type.
    _INPUT_CHANNELS = {
        "gray": 1,
        "raw-gray": 1,
        "rgb": 3,
        "raw-demosaic": 3,
        "raw": 4,
    }
    # Per-channel mean/std applied to "rgb" inputs.
    _RGB_MEAN = [0.485, 0.456, 0.406]
    _RGB_STD = [0.229, 0.224, 0.225]

    def __init__(self, config, device, loader, job_name, start_cnt):
        """Build the model, losses, optimizer and LR scheduler.

        Args:
            config: nested config mapping; a copy is dumped to ``<log_dir>/config.yaml``.
            device: torch device the model, head and losses are moved to.
            loader: iterable of batch dicts consumed by :meth:`train`.
            job_name: run directory name under ``runs/``; "" selects a timestamp.
            start_cnt: iteration of the ``<log_dir>/model_XXXXXX.pth`` checkpoint to
                resume from; 0 starts from scratch.

        Raises:
            NotImplementedError: for an unknown input type or optimizer.
        """
        self.config = config
        self.device = device
        self.loader = loader

        # tensorboard writer construction
        os.makedirs("./runs/", exist_ok=True)
        if job_name != "":
            self.log_dir = f"runs/{job_name}"
        else:
            self.log_dir = f'runs/{datetime.datetime.now().strftime("%m-%d-%H%M%S")}'
        self.writer = SummaryWriter(self.log_dir)
        with open(f"{self.log_dir}/config.yaml", "w") as f:
            yaml.dump(config, f)

        input_type = config["network"]["input_type"]
        if input_type not in self._INPUT_CHANNELS:
            raise NotImplementedError()
        inchan = self._INPUT_CHANNELS[input_type]
        # NOTE(review): eval() executes the configured model string as Python code; only
        # use trusted configs. For a plain class name (e.g. "Quad_L2Net") it resolves to
        # the class imported at the top of this module.
        self.model = eval(f'{config["network"]["model"]}(inchan={inchan})').to(device)

        # Reliability head applied to the descriptors. It is attached before any
        # checkpoint is loaded so the "clf.*" entries stored by save() (which saves the
        # whole model state dict) are restored. It follows `device` rather than
        # assuming CUDA.
        self.model.clf = nn.Conv2d(128, 2, kernel_size=1).to(device)

        # noise maker
        self.noise_maker = NoiseSimulator(device)

        # load model
        self.cnt = 0
        if start_cnt != 0:
            self.load(f"{self.log_dir}/model_{start_cnt:06d}.pth")
            self.cnt = start_cnt + 1

        # sampler
        sampler = NghSampler2(
            ngh=7,
            subq=-8,
            subd=1,
            pos_d=3,
            neg_d=5,
            border=16,
            subd_neg=-8,
            maxpool_pos=True,
        ).to(device)
        self.reliability_loss = ReliabilityLoss(sampler, base=0.3, nq=20).to(device)

        # optimizer and scheduler
        train_cfg = self.config["training"]
        # "initial_lr" is needed because the scheduler below is constructed with
        # last_epoch=start_cnt rather than the default -1.
        param_groups = [
            {"params": self.model.parameters(), "initial_lr": train_cfg["lr"]}
        ]
        if train_cfg["optimizer"] == "SGD":
            self.optimizer = torch.optim.SGD(
                param_groups,
                lr=train_cfg["lr"],
                momentum=train_cfg["momentum"],
                weight_decay=train_cfg["weight_decay"],
            )
        elif train_cfg["optimizer"] == "Adam":
            self.optimizer = torch.optim.Adam(
                param_groups,
                lr=train_cfg["lr"],
                weight_decay=train_cfg["weight_decay"],
            )
        else:
            raise NotImplementedError()
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer,
            step_size=train_cfg["lr_step"],
            gamma=train_cfg["lr_gamma"],
            last_epoch=start_cnt,
        )

        state_dict = self.model.state_dict()
        for param_tensor in state_dict:
            print(param_tensor, "\t", state_dict[param_tensor].size())

    def save(self, iter_num):
        """Write the model state dict to ``<log_dir>/model_<iter_num:06d>.pth``."""
        torch.save(self.model.state_dict(), f"{self.log_dir}/model_{iter_num:06d}.pth")

    def load(self, path):
        """Load a state dict saved by :meth:`save`, mapping its tensors onto ``self.device``."""
        self.model.load_state_dict(torch.load(path, map_location=self.device))

    def train(self, num_epochs=2):
        """Run the training loop over ``self.loader`` for ``num_epochs`` passes.

        Each batch dict must provide the keys read below: ``img0``/``img1``,
        ``ori_img_size0``/``ori_img_size1``, ``K0``/``K1``, ``depth0``/``depth1`` and
        ``rel_pose``.

        Per batch the method:
        - builds clean/noisy inputs;
        - computes the reliability loss (positions from ``getWarpNoValidate``) and the
          detector loss (positions from ``getWarp``);
        - steps the optimizer only when the detector loss is non-zero;
        - logs scalars every iteration, images every 100 iterations, and saves a
          checkpoint every 10000 iterations.
        """
        self.model.train()
        net_cfg = self.config["network"]
        for _ in range(num_epochs):
            for inputs in self.loader:
                self.optimizer.zero_grad()
                t = time.time()

                # preprocess and add noise
                img0_ori, noise_img0_ori = self.preprocess_noise_pair(
                    inputs["img0"], self.cnt
                )
                img1_ori, noise_img1_ori = self.preprocess_noise_pair(
                    inputs["img1"], self.cnt
                )
                img0 = self._to_device_nchw(img0_ori)
                img1 = self._to_device_nchw(img1_ori)
                # The noisy pair is prepared and normalized alongside the clean pair;
                # the losses below only use the clean images.
                noise_img0 = self._to_device_nchw(noise_img0_ori)
                noise_img1 = self._to_device_nchw(noise_img1_ori)
                img0, img1, noise_img0, noise_img1 = self._normalize_inputs(
                    img0, img1, noise_img0, noise_img1
                )

                desc0, score_map0, _, _ = self.model(img0)
                desc1, score_map1, _, _ = self.model(img1)
                cur_feat_size0 = torch.tensor(score_map0.shape[2:])
                cur_feat_size1 = torch.tensor(score_map1.shape[2:])
                # Confidence: softmax of the 2-channel head over element-wise squared
                # descriptors; channel 1 is kept (with its channel dim).
                conf0 = F.softmax(self.model.clf(torch.abs(desc0) ** 2.0), dim=1)[:, 1:2]
                conf1 = F.softmax(self.model.clf(torch.abs(desc1) ** 2.0), dim=1)[:, 1:2]

                # NCHW -> NHWC before computing the losses
                desc0 = desc0.permute(0, 2, 3, 1)
                desc1 = desc1.permute(0, 2, 3, 1)
                score_map0 = score_map0.permute(0, 2, 3, 1)
                score_map1 = score_map1.permute(0, 2, 3, 1)
                conf0 = conf0.permute(0, 2, 3, 1)
                conf1 = conf1.permute(0, 2, 3, 1)

                r_K0 = getK(inputs["ori_img_size0"], cur_feat_size0, inputs["K0"]).to(
                    self.device
                )
                r_K1 = getK(inputs["ori_img_size1"], cur_feat_size1, inputs["K1"]).to(
                    self.device
                )

                pos0 = _grid_positions(
                    cur_feat_size0[0], cur_feat_size0[1], img0.shape[0]
                ).to(self.device)
                pos0_for_rel, pos1_for_rel, _ = getWarpNoValidate(
                    pos0,
                    inputs["rel_pose"].to(self.device),
                    inputs["depth0"].to(self.device),
                    r_K0,
                    inputs["depth1"].to(self.device),
                    r_K1,
                    img0.shape[0],
                )
                pos0, pos1, _ = getWarp(
                    pos0,
                    inputs["rel_pose"].to(self.device),
                    inputs["depth0"].to(self.device),
                    r_K0,
                    inputs["depth1"].to(self.device),
                    r_K1,
                    img0.shape[0],
                )

                reliab_loss = self.reliability_loss(
                    desc0,
                    desc1,
                    conf0,
                    conf1,
                    pos0_for_rel,
                    pos1_for_rel,
                    img0.shape[0],
                    img0.shape[2],
                    img0.shape[3],
                )
                det_structured_loss, det_accuracy = make_detector_loss(
                    pos0,
                    pos1,
                    desc0,
                    desc1,
                    score_map0,
                    score_map1,
                    img0.shape[0],
                    net_cfg["use_corr_n"],
                    net_cfg["loss_type"],
                    self.config,
                )

                # Out-of-place sum: an in-place `+=` would overwrite det_structured_loss
                # (same tensor object) and corrupt the `!= 0` check below.
                total_loss = det_structured_loss + reliab_loss

                self.writer.add_scalar(
                    "loss/det_loss_normal", det_structured_loss, self.cnt
                )
                self.writer.add_scalar("acc/normal_acc", det_accuracy, self.cnt)
                self.writer.add_scalar("loss/total_loss", total_loss, self.cnt)
                self.writer.add_scalar("loss/reliab_loss", reliab_loss, self.cnt)
                print(
                    "iter={},\tloss={:.4f},\tacc={:.4f},\t{:.4f}s/iter".format(
                        self.cnt, total_loss, det_accuracy, time.time() - t
                    )
                )

                # Skip the parameter update when the detector loss is exactly zero.
                if det_structured_loss != 0:
                    total_loss.backward()
                    self.optimizer.step()
                self.lr_scheduler.step()

                if self.cnt % 100 == 0:
                    self._log_visuals(
                        img0_ori, img1_ori, score_map0, score_map1, conf0, conf1
                    )

                if self.cnt % 10000 == 0:
                    self.save(self.cnt)

                self.cnt += 1

    def _to_device_nchw(self, img):
        """Convert a channel-last image batch to a float NCHW tensor on ``self.device``."""
        return img.permute(0, 3, 1, 2).float().to(self.device)

    def _normalize_inputs(self, img0, img1, noise_img0, noise_img1):
        """Apply the input-type specific normalization to NCHW clean/noisy images.

        "rgb" uses the fixed per-channel statistics. "gray" averages channels and
        normalizes each noisy image with the statistics of its clean counterpart. Raw
        input types are returned unchanged.

        Raises:
            NotImplementedError: for an unknown input type.
        """
        input_type = self.config["network"]["input_type"]
        if input_type == "rgb":  # 3-channel rgb
            norm_rgb = tvf.Normalize(mean=self._RGB_MEAN, std=self._RGB_STD)
            return (
                norm_rgb(img0),
                norm_rgb(img1),
                norm_rgb(noise_img0),
                norm_rgb(noise_img1),
            )
        if input_type == "gray":  # 1-channel
            img0 = torch.mean(img0, dim=1, keepdim=True)
            img1 = torch.mean(img1, dim=1, keepdim=True)
            noise_img0 = torch.mean(noise_img0, dim=1, keepdim=True)
            noise_img1 = torch.mean(noise_img1, dim=1, keepdim=True)
            norm_gray0 = tvf.Normalize(mean=img0.mean(), std=img0.std())
            norm_gray1 = tvf.Normalize(mean=img1.mean(), std=img1.std())
            return (
                norm_gray0(img0),
                norm_gray1(img1),
                norm_gray0(noise_img0),
                norm_gray1(noise_img1),
            )
        if input_type in ("raw", "raw-demosaic", "raw-gray"):
            # Raw-domain inputs (4-, 3- and 1-channel) are fed to the network as-is.
            return img0, img1, noise_img0, noise_img1
        raise NotImplementedError()

    def _log_visuals(self, img0_ori, img1_ori, score_map0, score_map1, conf0, conf1):
        """Log keypoints, score maps and confidence maps for the first sample of the batch.

        ``score_map*`` and ``conf*`` are expected in NHWC layout, as produced in :meth:`train`.
        """
        det_cfg = self.config["network"]["det"]
        kpt_kwargs = {
            "k": det_cfg["kpt_n"],
            "score_thld": det_cfg["score_thld"],
            "nms_size": det_cfg["nms_size"],
            "eof_size": det_cfg["eof_size"],
            "edge_thld": det_cfg["edge_thld"],
        }
        indices0, _ = extract_kpts(score_map0.permute(0, 3, 1, 2), **kpt_kwargs)
        indices1, _ = extract_kpts(score_map1.permute(0, 3, 1, 2), **kpt_kwargs)

        if self.config["network"]["input_type"] == "raw":
            # 4-channel raw input: only the first three channels are drawn.
            vis0 = img0_ori[0][..., :3]
            vis1 = img1_ori[0][..., :3]
        else:
            vis0 = img0_ori[0]
            vis1 = img1_ori[0]
        kpt_img0 = self.showKeyPoints(vis0 * 255.0, indices0[0])
        kpt_img1 = self.showKeyPoints(vis1 * 255.0, indices1[0])

        self.writer.add_image("img0/kpts", kpt_img0, self.cnt, dataformats="HWC")
        self.writer.add_image("img1/kpts", kpt_img1, self.cnt, dataformats="HWC")
        self.writer.add_image(
            "img0/score_map", score_map0[0], self.cnt, dataformats="HWC"
        )
        self.writer.add_image(
            "img1/score_map", score_map1[0], self.cnt, dataformats="HWC"
        )
        self.writer.add_image("img0/conf", conf0[0], self.cnt, dataformats="HWC")
        self.writer.add_image("img1/conf", conf1[0], self.cnt, dataformats="HWC")

    def showKeyPoints(self, img, indices):
        """Draw ``indices`` as green keypoints on ``img`` and return the drawn image.

        ``img`` is a CPU tensor scaled to [0, 255] and cast to uint8 before drawing.
        The two index columns are swapped before conversion. This presumably turns
        (row, col) into OpenCV's (x, y); TODO confirm against ``extract_kpts``.
        """
        # The reversed-column view has a negative stride; give OpenCV a contiguous copy.
        pts = np.ascontiguousarray(indices.cpu().float().numpy()[:, ::-1])
        key_points = cv2.KeyPoint_convert(pts)
        img = img.numpy().astype("uint8")
        img = cv2.drawKeypoints(img, key_points, None, color=(0, 255, 0))
        return img

    def _noise_ratio(self, iter_idx):
        """Return ``min(noise_maxstep, iter_idx) / noise_maxstep`` (the noise ramp factor)."""
        max_step = self.config["network"]["noise_maxstep"]
        return min(max_step, iter_idx) / max_step

    def preprocess(self, img, iter_idx):
        """Convert an RGB batch to the configured input type, optionally adding noise.

        ``img`` is returned unchanged when noise is disabled and the input type does not
        contain "raw".

        Raises:
            NotImplementedError: for input types without a conversion here.
        """
        net_cfg = self.config["network"]
        input_type = net_cfg["input_type"]
        if not net_cfg["noise"] and "raw" not in input_type:
            return img

        raw = self.noise_maker.rgb2raw(img, batched=True)
        if net_cfg["noise"]:
            raw = self.noise_maker.raw2noisyRaw(
                raw, ratio_dec=self._noise_ratio(iter_idx), batched=True
            )

        if input_type == "raw":
            return torch.tensor(self.noise_maker.raw2packedRaw(raw, batched=True))
        if input_type == "raw-demosaic":
            return torch.tensor(self.noise_maker.raw2demosaicRaw(raw, batched=True))

        rgb = self.noise_maker.raw2rgb(raw, batched=True)
        if input_type in ("rgb", "gray"):
            return torch.tensor(rgb)
        raise NotImplementedError()

    def preprocess_noise_pair(self, img, iter_idx):
        """Return a ``(clean, noisy)`` pair of the RGB batch ``img`` in the configured input type.

        Noise is applied in the raw domain, with strength ramped by ``iter_idx`` (see
        :meth:`_noise_ratio`). For "rgb"/"gray" the clean element is ``img`` itself.
        Requires ``config["network"]["noise"]`` to be truthy.

        Raises:
            NotImplementedError: for an unknown input type.
        """
        net_cfg = self.config["network"]
        assert net_cfg["noise"]
        input_type = net_cfg["input_type"]

        raw = self.noise_maker.rgb2raw(img, batched=True)
        noise_raw = self.noise_maker.raw2noisyRaw(
            raw, ratio_dec=self._noise_ratio(iter_idx), batched=True
        )

        if input_type == "raw":
            clean = torch.tensor(self.noise_maker.raw2packedRaw(raw, batched=True))
            noisy = torch.tensor(self.noise_maker.raw2packedRaw(noise_raw, batched=True))
            return clean, noisy
        if input_type == "raw-demosaic":
            clean = torch.tensor(self.noise_maker.raw2demosaicRaw(raw, batched=True))
            noisy = torch.tensor(
                self.noise_maker.raw2demosaicRaw(noise_raw, batched=True)
            )
            return clean, noisy
        if input_type == "raw-gray":
            # Weighted channel sum (0.299, 0.587, 0.114) to a single trailing channel.
            # Assumes the demosaiced output is channel-last with 3 channels.
            factor = torch.tensor([0.299, 0.587, 0.114]).double()
            clean = torch.matmul(
                torch.tensor(self.noise_maker.raw2demosaicRaw(raw, batched=True)),
                factor,
            ).unsqueeze(-1)
            noisy = torch.matmul(
                torch.tensor(self.noise_maker.raw2demosaicRaw(noise_raw, batched=True)),
                factor,
            ).unsqueeze(-1)
            return clean, noisy

        noise_rgb = self.noise_maker.raw2rgb(noise_raw, batched=True)
        if input_type in ("rgb", "gray"):
            return img, torch.tensor(noise_rgb)
        raise NotImplementedError()