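# NOTE: the lines below are repository web-page residue that was captured
# together with the file; they are kept as comments so the module parses.
#   Vincentqyw / fix: roma / c74a070 / raw history blame / No virus / 15.7 kB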
import os
import cv2
import time
import yaml
import torch
import datetime
from tensorboardX import SummaryWriter
import torchvision.transforms as tvf
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from nets.geom import getK, getWarp, _grid_positions, getWarpNoValidate
from nets.loss import make_detector_loss
from nets.score import extract_kpts
from nets.sampler import NghSampler2
from nets.reliability_loss import ReliabilityLoss
from datasets.noise_simulator import NoiseSimulator
from nets.l2net import Quad_L2Net
class SingleTrainer:
    """Train one network with a detector loss plus a reliability loss.

    The trainer builds the model named in ``config["network"]["model"]``,
    attaches a 1x1 convolution (``model.clf``) that turns descriptors into a
    two-class confidence map, and optimises the sum of
    ``make_detector_loss`` and ``ReliabilityLoss``.  Scalars, key-point
    overlays and score/confidence maps are written to TensorBoard under
    ``runs/<job_name>`` (or a timestamped directory when ``job_name`` is
    empty), and checkpoints are stored in the same directory.
    """

    # Number of input channels the network is constructed with for each
    # supported ``config["network"]["input_type"]``.
    _INPUT_CHANNELS = {
        "gray": 1,
        "raw-gray": 1,
        "rgb": 3,
        "raw-demosaic": 3,
        "raw": 4,
    }

    def __init__(self, config, device, loader, job_name, start_cnt):
        """Set up logging, model, losses, optimizer and LR schedule.

        Args:
            config: Nested configuration mapping; the ``"network"`` and
                ``"training"`` sections are read here.
            device: Device every module is moved to.
            loader: Iterable of batches consumed by :meth:`train`.
            job_name: Name of the run directory below ``runs/``; an empty
                string selects a timestamped directory.
            start_cnt: Iteration to resume from.  When non-zero the
                checkpoint ``model_<start_cnt>.pth`` in the run directory is
                loaded and counting continues at ``start_cnt + 1``.

        Raises:
            NotImplementedError: For an unknown input type or optimizer.
        """
        self.config = config
        self.device = device
        self.loader = loader

        # tensorboard writer construction
        os.makedirs("./runs/", exist_ok=True)
        if job_name != "":
            self.log_dir = f"runs/{job_name}"
        else:
            self.log_dir = f'runs/{datetime.datetime.now().strftime("%m-%d-%H%M%S")}'
        self.writer = SummaryWriter(self.log_dir)
        with open(f"{self.log_dir}/config.yaml", "w") as f:
            yaml.dump(config, f)

        input_type = config["network"]["input_type"]
        if input_type not in self._INPUT_CHANNELS:
            raise NotImplementedError()
        inchan = self._INPUT_CHANNELS[input_type]
        # NOTE(review): ``eval`` executes whatever string the config holds.
        # Only use configs from trusted sources, or replace this with an
        # explicit name -> class registry.
        self.model = eval(f'{config["network"]["model"]}(inchan={inchan})').to(device)

        # noise maker
        self.noise_maker = NoiseSimulator(device)

        # load model
        self.cnt = 0
        if start_cnt != 0:
            # map_location lets a checkpoint written on another device be
            # restored onto the device this trainer was given.
            # NOTE(review): ``clf`` is attached only after this load, while
            # ``save`` stores the state dict including ``clf``.  If the
            # network class does not define ``clf`` itself, a strict load of
            # such a checkpoint would report unexpected keys -- confirm.
            self.model.load_state_dict(
                torch.load(
                    f"{self.log_dir}/model_{start_cnt:06d}.pth",
                    map_location=device,
                )
            )
            self.cnt = start_cnt + 1

        # sampler
        sampler = NghSampler2(
            ngh=7,
            subq=-8,
            subd=1,
            pos_d=3,
            neg_d=5,
            border=16,
            subd_neg=-8,
            maxpool_pos=True,
        ).to(device)
        self.reliability_loss = ReliabilityLoss(sampler, base=0.3, nq=20).to(device)

        # reliability map conv; placed on the requested device instead of
        # unconditionally on CUDA so it matches the rest of the model.
        self.model.clf = nn.Conv2d(128, 2, kernel_size=1).to(device)

        # optimizer and scheduler; created after ``clf`` is attached so its
        # parameters are part of ``self.model.parameters()``.
        training_cfg = self.config["training"]
        param_groups = [
            {
                "params": self.model.parameters(),
                # Needed because the scheduler is built with an explicit
                # ``last_epoch`` below.
                "initial_lr": training_cfg["lr"],
            }
        ]
        if training_cfg["optimizer"] == "SGD":
            self.optimizer = torch.optim.SGD(
                param_groups,
                lr=training_cfg["lr"],
                momentum=training_cfg["momentum"],
                weight_decay=training_cfg["weight_decay"],
            )
        elif training_cfg["optimizer"] == "Adam":
            self.optimizer = torch.optim.Adam(
                param_groups,
                lr=training_cfg["lr"],
                weight_decay=training_cfg["weight_decay"],
            )
        else:
            raise NotImplementedError()

        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer,
            step_size=training_cfg["lr_step"],
            gamma=training_cfg["lr_gamma"],
            last_epoch=start_cnt,
        )

        # Build the state dict once instead of once per parameter.
        for name, tensor in self.model.state_dict().items():
            print(name, "\t", tensor.size())

    def save(self, iter_num):
        """Write the model state dict to ``model_<iter_num>.pth`` in the run directory."""
        torch.save(self.model.state_dict(), f"{self.log_dir}/model_{iter_num:06d}.pth")

    def load(self, path):
        """Load a model state dict from ``path`` onto the trainer's device."""
        self.model.load_state_dict(torch.load(path, map_location=self.device))

    def train(self, num_epochs=2):
        """Run the training loop over ``self.loader``.

        Args:
            num_epochs: Number of passes over the loader (default 2, the
                value that used to be hard-coded).

        Every iteration logs the losses and accuracy, every 100th iteration
        additionally logs images, and every 10000th iteration writes a
        checkpoint.
        """
        self.model.train()

        for _ in range(num_epochs):
            for inputs in self.loader:
                self.optimizer.zero_grad()
                tick = time.time()

                # preprocess; the noisy counterpart returned alongside the
                # clean image is not consumed by any loss in this trainer.
                img0_ori, _ = self.preprocess_noise_pair(inputs["img0"], self.cnt)
                img1_ori, _ = self.preprocess_noise_pair(inputs["img1"], self.cnt)

                # Move the last axis to position 1 before feeding the model.
                img0 = img0_ori.permute(0, 3, 1, 2).float().to(self.device)
                img1 = img1_ori.permute(0, 3, 1, 2).float().to(self.device)
                img0, img1 = self._normalize_pair(img0, img1)
                batch_size = img0.shape[0]

                desc0, score_map0, _, _ = self.model(img0)
                desc1, score_map1, _, _ = self.model(img1)
                cur_feat_size0 = torch.tensor(score_map0.shape[2:])
                cur_feat_size1 = torch.tensor(score_map1.shape[2:])

                # Second softmax channel of the classifier is the confidence.
                conf0 = F.softmax(self.model.clf(torch.abs(desc0) ** 2.0), dim=1)[
                    :, 1:2
                ]
                conf1 = F.softmax(self.model.clf(torch.abs(desc1) ** 2.0), dim=1)[
                    :, 1:2
                ]

                desc0 = desc0.permute(0, 2, 3, 1)
                desc1 = desc1.permute(0, 2, 3, 1)
                score_map0 = score_map0.permute(0, 2, 3, 1)
                score_map1 = score_map1.permute(0, 2, 3, 1)
                conf0 = conf0.permute(0, 2, 3, 1)
                conf1 = conf1.permute(0, 2, 3, 1)

                r_K0 = getK(inputs["ori_img_size0"], cur_feat_size0, inputs["K0"]).to(
                    self.device
                )
                r_K1 = getK(inputs["ori_img_size1"], cur_feat_size1, inputs["K1"]).to(
                    self.device
                )

                pos0 = _grid_positions(
                    cur_feat_size0[0], cur_feat_size0[1], batch_size
                ).to(self.device)

                # Transfer the geometry inputs once; both warps use them.
                rel_pose = inputs["rel_pose"].to(self.device)
                depth0 = inputs["depth0"].to(self.device)
                depth1 = inputs["depth1"].to(self.device)

                pos0_for_rel, pos1_for_rel, _ = getWarpNoValidate(
                    pos0, rel_pose, depth0, r_K0, depth1, r_K1, batch_size
                )
                pos0, pos1, _ = getWarp(
                    pos0, rel_pose, depth0, r_K0, depth1, r_K1, batch_size
                )

                reliab_loss = self.reliability_loss(
                    desc0,
                    desc1,
                    conf0,
                    conf1,
                    pos0_for_rel,
                    pos1_for_rel,
                    batch_size,
                    img0.shape[2],
                    img0.shape[3],
                )

                det_structured_loss, det_accuracy = make_detector_loss(
                    pos0,
                    pos1,
                    desc0,
                    desc1,
                    score_map0,
                    score_map1,
                    batch_size,
                    self.config["network"]["use_corr_n"],
                    self.config["network"]["loss_type"],
                    self.config,
                )

                # Out-of-place sum: ``total_loss += ...`` on an alias would
                # modify ``det_structured_loss`` itself when it is a tensor,
                # and the zero check below would then test the sum.
                total_loss = det_structured_loss + reliab_loss

                self.writer.add_scalar(
                    "loss/det_loss_normal", det_structured_loss, self.cnt
                )
                self.writer.add_scalar("acc/normal_acc", det_accuracy, self.cnt)
                self.writer.add_scalar("loss/total_loss", total_loss, self.cnt)
                self.writer.add_scalar("loss/reliab_loss", reliab_loss, self.cnt)
                print(
                    "iter={},\tloss={:.4f},\tacc={:.4f},\t{:.4f}s/iter".format(
                        self.cnt, total_loss, det_accuracy, time.time() - tick
                    )
                )

                # Skip the update when the detector loss is exactly zero.
                if det_structured_loss != 0:
                    total_loss.backward()
                    self.optimizer.step()
                # Stepped every iteration so the schedule stays aligned with
                # ``self.cnt``, matching ``last_epoch=start_cnt`` on resume.
                self.lr_scheduler.step()

                if self.cnt % 100 == 0:
                    self._log_images(
                        img0_ori, img1_ori, score_map0, score_map1, conf0, conf1
                    )

                if self.cnt % 10000 == 0:
                    self.save(self.cnt)

                self.cnt += 1

    def _normalize_pair(self, img0, img1):
        """Normalise two image batches according to the configured input type."""
        input_type = self.config["network"]["input_type"]

        if input_type == "rgb":
            # 3-channel rgb, fixed per-channel statistics
            norm_rgb = tvf.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            )
            return norm_rgb(img0), norm_rgb(img1)

        if input_type == "gray":
            # 1-channel: average over dim 1, then standardise each batch
            # with its own mean and standard deviation.
            img0 = torch.mean(img0, dim=1, keepdim=True)
            img1 = torch.mean(img1, dim=1, keepdim=True)
            img0 = tvf.Normalize(mean=img0.mean(), std=img0.std())(img0)
            img1 = tvf.Normalize(mean=img1.mean(), std=img1.std())(img1)
            return img0, img1

        if input_type in ("raw", "raw-demosaic", "raw-gray"):
            # Raw-derived inputs are used as produced by the noise simulator.
            # "raw-gray" is accepted here because the constructor builds a
            # 1-channel model for it and preprocess_noise_pair produces it.
            return img0, img1

        raise NotImplementedError()

    def _log_images(self, img0_ori, img1_ori, score_map0, score_map1, conf0, conf1):
        """Write key-point overlays, score maps and confidence maps for sample 0."""
        det_cfg = self.config["network"]["det"]
        kpt_kwargs = dict(
            k=det_cfg["kpt_n"],
            score_thld=det_cfg["score_thld"],
            nms_size=det_cfg["nms_size"],
            eof_size=det_cfg["eof_size"],
            edge_thld=det_cfg["edge_thld"],
        )
        indices0, scores0 = extract_kpts(score_map0.permute(0, 3, 1, 2), **kpt_kwargs)
        indices1, scores1 = extract_kpts(score_map1.permute(0, 3, 1, 2), **kpt_kwargs)

        if self.config["network"]["input_type"] == "raw":
            # Only the first three of the four raw channels are drawn.
            canvas0 = img0_ori[0][..., :3] * 255.0
            canvas1 = img1_ori[0][..., :3] * 255.0
        else:
            canvas0 = img0_ori[0] * 255.0
            canvas1 = img1_ori[0] * 255.0
        kpt_img0 = self.showKeyPoints(canvas0, indices0[0])
        kpt_img1 = self.showKeyPoints(canvas1, indices1[0])

        images = (
            ("img0/kpts", kpt_img0),
            ("img1/kpts", kpt_img1),
            ("img0/score_map", score_map0[0]),
            ("img1/score_map", score_map1[0]),
            ("img0/conf", conf0[0]),
            ("img1/conf", conf1[0]),
        )
        for tag, image in images:
            self.writer.add_image(tag, image, self.cnt, dataformats="HWC")

    def showKeyPoints(self, img, indices):
        """Draw key points in green on ``img`` and return the result as a uint8 array.

        ``indices`` has its two columns swapped before conversion to OpenCV
        key points (presumably (row, col) -> (x, y) -- confirm against
        ``extract_kpts``).
        """
        key_points = cv2.KeyPoint_convert(indices.cpu().float().numpy()[:, ::-1])
        img = img.numpy().astype("uint8")
        img = cv2.drawKeypoints(img, key_points, None, color=(0, 255, 0))
        return img

    def preprocess(self, img, iter_idx):
        """Convert an RGB batch to the configured input type, optionally adding noise.

        Returns ``img`` unchanged when noise is disabled and the input type
        does not contain ``"raw"``.  Otherwise the batch goes through the
        noise simulator's raw pipeline; the noise strength ratio grows
        linearly with ``iter_idx`` until ``noise_maxstep``.

        Raises:
            NotImplementedError: If the input type is not one of "raw",
                "raw-demosaic", "rgb" or "gray".
        """
        net_cfg = self.config["network"]
        input_type = net_cfg["input_type"]

        if not net_cfg["noise"] and "raw" not in input_type:
            return img

        raw = self.noise_maker.rgb2raw(img, batched=True)

        if net_cfg["noise"]:
            ratio_dec = min(net_cfg["noise_maxstep"], iter_idx) / net_cfg["noise_maxstep"]
            raw = self.noise_maker.raw2noisyRaw(raw, ratio_dec=ratio_dec, batched=True)

        if input_type == "raw":
            return torch.tensor(self.noise_maker.raw2packedRaw(raw, batched=True))

        if input_type == "raw-demosaic":
            return torch.tensor(self.noise_maker.raw2demosaicRaw(raw, batched=True))

        rgb = self.noise_maker.raw2rgb(raw, batched=True)
        if input_type == "rgb" or input_type == "gray":
            return torch.tensor(rgb)

        raise NotImplementedError()

    def preprocess_noise_pair(self, img, iter_idx):
        """Return a ``(clean, noisy)`` pair in the configured input type.

        Requires ``config["network"]["noise"]`` to be truthy.  For "rgb" and
        "gray" the clean element is ``img`` itself; for the raw variants both
        elements come from the noise simulator's raw representation.

        Raises:
            AssertionError: If noise is disabled in the configuration.
            NotImplementedError: If the input type is not supported.
        """
        net_cfg = self.config["network"]
        assert net_cfg["noise"]
        input_type = net_cfg["input_type"]

        raw = self.noise_maker.rgb2raw(img, batched=True)

        ratio_dec = min(net_cfg["noise_maxstep"], iter_idx) / net_cfg["noise_maxstep"]
        noise_raw = self.noise_maker.raw2noisyRaw(
            raw, ratio_dec=ratio_dec, batched=True
        )

        if input_type == "raw":
            clean = torch.tensor(self.noise_maker.raw2packedRaw(raw, batched=True))
            noisy = torch.tensor(
                self.noise_maker.raw2packedRaw(noise_raw, batched=True)
            )
            return clean, noisy

        if input_type == "raw-demosaic":
            clean = torch.tensor(self.noise_maker.raw2demosaicRaw(raw, batched=True))
            noisy = torch.tensor(
                self.noise_maker.raw2demosaicRaw(noise_raw, batched=True)
            )
            return clean, noisy

        if input_type == "raw-gray":
            # Weighted sum over the last axis of the demosaiced image, with a
            # trailing singleton axis restored afterwards.
            factor = torch.tensor([0.299, 0.587, 0.114]).double()
            clean = torch.matmul(
                torch.tensor(self.noise_maker.raw2demosaicRaw(raw, batched=True)),
                factor,
            ).unsqueeze(-1)
            noisy = torch.matmul(
                torch.tensor(self.noise_maker.raw2demosaicRaw(noise_raw, batched=True)),
                factor,
            ).unsqueeze(-1)
            return clean, noisy

        noise_rgb = self.noise_maker.raw2rgb(noise_raw, batched=True)
        if input_type == "rgb" or input_type == "gray":
            return img, torch.tensor(noise_rgb)

        raise NotImplementedError()