# MonocularDepth/BTS.py
# PyTorch implementation of BTS ("From Big to Small: Multi-Scale Local Planar
# Guidance for Monocular Depth Estimation", Lee et al.).
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import numpy as np
from torch import optim
import os
import math
import cv2
import albumentations as A
# tensorboard is needed for run_train_step() which is commented out here
# from torch.utils.tensorboard import SummaryWriter
activation_fn = nn.ELU()
MAX_DEPTH = 81
DEPTH_OFFSET = 0.1 # keeps the predicted depth strictly positive
USE_APEX = False # Enable if you have a GPU with Tensor Cores; otherwise it doesn't bring much benefit.
APEX_OPT_LEVEL = "O2"
BATCH_NORM_MOMENTUM = 0.005
ENABLE_BIAS = True
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda")
print(f'--- BTS will use device: {device}')
if USE_APEX:
import apex
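# Decoder building block: nearest-neighbour 2x upsample, then 3x3 conv + ELU + BN.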
class UpscaleLayer(nn.Module):
def __init__(self, in_channels, out_channels):
super(UpscaleLayer, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=ENABLE_BIAS)
self.bn = nn.BatchNorm2d(out_channels, momentum=BATCH_NORM_MOMENTUM)
def forward(self, input):
        input = F.interpolate(input, scale_factor=2, mode="nearest")
input = activation_fn(self.conv(input))
input = self.bn(input)
return input
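# Upsamples the input, concatenates an encoder skip connection, and fuses the
# result with a 3x3 conv; returns both the fused features and the raw concat.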
class UpscaleBlock(nn.Module):
def __init__(self, in_channels, skip_channels, out_channels):
super(UpscaleBlock, self).__init__()
self.uplayer = UpscaleLayer(in_channels, out_channels)
self.conv = nn.Conv2d(out_channels+skip_channels, out_channels, 3, padding=1, bias=ENABLE_BIAS)
        self.bn2 = nn.BatchNorm2d(out_channels, momentum=BATCH_NORM_MOMENTUM)
def forward(self, input_j):
input, skip = input_j
input = self.uplayer(input)
cat = torch.cat((input, skip), 1)
input = activation_fn(self.conv(cat))
input = self.bn2(input)
return input, cat
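# First two decoder stages: upscales the dense features from H/32 to H/8 while
# merging the H/16 and H/8 encoder skip connections.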
class UpscaleNetwork(nn.Module):
def __init__(self, filters=[512, 256]):
        super(UpscaleNetwork, self).__init__()
self.upscale_block1 = UpscaleBlock(2208, 384, filters[0]) # H16
self.upscale_block2 = UpscaleBlock(filters[0], 192, filters[1]) # H8
def forward(self, raw_input):
input, h2, h4, h8, h16 = raw_input
input, _ = self.upscale_block1((input, h16))
input, cat = self.upscale_block2((input, h8))
return input, cat
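# 1x1 bottleneck followed by a 3x3 dilated (atrous) conv; the dilation rate
# widens the receptive field without reducing resolution.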
class AtrousBlock(nn.Module):
def __init__(self, input_filters, filters, dilation, apply_initial_bn=True):
super(AtrousBlock, self).__init__()
        self.initial_bn = nn.BatchNorm2d(input_filters, momentum=BATCH_NORM_MOMENTUM)
        self.apply_initial_bn = apply_initial_bn
        self.conv1 = nn.Conv2d(input_filters, filters*2, 1, 1, 0, bias=False)
        self.norm1 = nn.BatchNorm2d(filters*2, momentum=BATCH_NORM_MOMENTUM)
        self.atrous_conv = nn.Conv2d(filters*2, filters, 3, 1, dilation, dilation, bias=False)
        self.norm2 = nn.BatchNorm2d(filters, momentum=BATCH_NORM_MOMENTUM)
def forward(self, input):
if self.apply_initial_bn:
input = self.initial_bn(input)
input = self.conv1(input.relu())
input = self.norm1(input)
input = self.atrous_conv(input.relu())
input = self.norm2(input)
return input
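# Dense ASPP (atrous spatial pyramid pooling) with rates 3, 6, 12, 18 and 24;
# each branch sees the concatenation of all previous outputs. The original
# "ASSP" spelling is kept so existing checkpoints keep loading.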
class ASSPBlock(nn.Module):
def __init__(self, input_filters=256, cat_filters=448, atrous_filters=128):
super(ASSPBlock, self).__init__()
self.atrous_conv_r3 = AtrousBlock(input_filters, atrous_filters, 3, apply_initial_bn=False)
self.atrous_conv_r6 = AtrousBlock(cat_filters + atrous_filters, atrous_filters, 6)
self.atrous_conv_r12 = AtrousBlock(cat_filters + atrous_filters*2, atrous_filters, 12)
self.atrous_conv_r18 = AtrousBlock(cat_filters + atrous_filters*3, atrous_filters, 18)
self.atrous_conv_r24 = AtrousBlock(cat_filters + atrous_filters*4, atrous_filters, 24)
self.conv = nn.Conv2d(5 * atrous_filters + cat_filters, atrous_filters, 3, 1, 1, bias=ENABLE_BIAS)
def forward(self, input):
input, cat = input
layer1_out = self.atrous_conv_r3(input)
concat1 = torch.cat((cat, layer1_out), 1)
layer2_out = self.atrous_conv_r6(concat1)
concat2 = torch.cat((concat1, layer2_out), 1)
layer3_out = self.atrous_conv_r12(concat2)
concat3 = torch.cat((concat2, layer3_out), 1)
layer4_out = self.atrous_conv_r18(concat3)
concat4 = torch.cat((concat3, layer4_out), 1)
layer5_out = self.atrous_conv_r24(concat4)
concat5 = torch.cat((concat4, layer5_out), 1)
features = activation_fn(self.conv(concat5))
return features
# Code of this layer is taken from the official PyTorch implementation of BTS.
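# Local Planar Guidance: each coarse cell stores plane coefficients
# (n1, n2, n3, n4); the layer expands them scale x scale and evaluates
# d = n4 / (n1*u + n2*v + n3) at every sub-pixel (u, v) to get dense depth.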
class LPGLayer(nn.Module):
def __init__(self, scale):
super(LPGLayer, self).__init__()
self.scale = scale
self.u = torch.arange(self.scale).reshape([1, 1, self.scale]).float()
        self.v = torch.arange(self.scale).reshape([1, self.scale, 1]).float()
def forward(self, plane_eq):
plane_eq_expanded = torch.repeat_interleave(plane_eq, int(self.scale), 2)
plane_eq_expanded = torch.repeat_interleave(plane_eq_expanded, int(self.scale), 3)
n1 = plane_eq_expanded[:, 0, :, :]
n2 = plane_eq_expanded[:, 1, :, :]
n3 = plane_eq_expanded[:, 2, :, :]
n4 = plane_eq_expanded[:, 3, :, :]
u = self.u.repeat(plane_eq.size(0), plane_eq.size(2) * int(self.scale), plane_eq.size(3)).to(device)
u = (u - (self.scale - 1) * 0.5) / self.scale
v = self.v.repeat(plane_eq.size(0), plane_eq.size(2), plane_eq.size(3) * int(self.scale)).to(device)
v = (v - (self.scale - 1) * 0.5) / self.scale
d = n4 / (n1 * u + n2 * v + n3)
d = d.unsqueeze(1)
return d
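# Chain of 1x1 convs that halves the channel count at every step, down to
# 4 channels (or to a single sigmoid-activated channel when is_final=True).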
class Reduction(nn.Module):
def __init__(self, scale, input_filters, is_final=False):
super(Reduction, self).__init__()
reduction_count = int(math.log(input_filters, 2)) - 2
self.reductions = torch.nn.Sequential()
for i in range(reduction_count):
if i != reduction_count-1:
self.reductions.add_module("1x1_reduc_%d_%d" % (scale, i), nn.Sequential(
nn.Conv2d(int(input_filters / math.pow(2, i)), int(input_filters / math.pow(2, i + 1)), 1, 1, 0, bias=ENABLE_BIAS),
activation_fn))
else:
if not is_final:
self.reductions.add_module("1x1_reduc_%d_%d" % (scale, i), nn.Sequential(
nn.Conv2d(int(input_filters / math.pow(2, i)), int(input_filters / math.pow(2, i + 1)), 1, 1, 0, bias=ENABLE_BIAS)))
else:
self.reductions.add_module("1x1_reduc_%d_%d" % (scale, i), nn.Sequential(
nn.Conv2d(int(input_filters / math.pow(2, i)), 1, 1, 1, 0, bias=ENABLE_BIAS), nn.Sigmoid()))
def forward(self, ip):
return self.reductions(ip)
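# Predicts plane parameters from the features: a Reduction stack shrinks the
# map to 4 channels, a 1x1 conv maps them to (theta, phi, distance), and the
# resulting unit plane normals are decoded to depth at this scale via LPGLayer.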
class LPGBlock(nn.Module):
def __init__(self, scale, input_filters=128):
super(LPGBlock, self).__init__()
self.scale = scale
self.reduction = Reduction(scale, input_filters)
self.conv = nn.Conv2d(4, 3, 1, 1, 0)
self.LPGLayer = LPGLayer(scale)
def forward(self, input):
input = self.reduction(input)
plane_parameters = torch.zeros_like(input)
input = self.conv(input)
        theta = input[:, 0, :, :].sigmoid() * math.pi / 6
        phi = input[:, 1, :, :].sigmoid() * math.pi * 2
dist = input[:, 2, :, :].sigmoid() * MAX_DEPTH
plane_parameters[:, 0, :, :] = torch.sin(theta) * torch.cos(phi)
plane_parameters[:, 1, :, :] = torch.sin(theta) * torch.sin(phi)
plane_parameters[:, 2, :, :] = torch.cos(theta)
plane_parameters[:, 3, :, :] = dist
plane_parameters[:, 0:3, :, :] = F.normalize(plane_parameters.clone()[:, 0:3, :, :], 2, 1)
depth = self.LPGLayer(plane_parameters.float())
return depth
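# DenseNet-161 backbone; forward hooks capture intermediate activations at
# H/2, H/4, H/8 and H/16 to serve as decoder skip connections.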
class bts_encoder(nn.Module):
def __init__(self):
super(bts_encoder, self).__init__()
self.dense_op_h2 = None
self.dense_op_h4 = None
self.dense_op_h8 = None
self.dense_op_h16 = None
self.dense_features = None
self.dense_feature_extractor = self.initialize_dense_feature_extractor()
self.freeze_batch_norm()
self.initialize_hooks()
    def freeze_batch_norm(self):
        # Keep the pretrained encoder's BatchNorm layers in eval mode so they
        # use their stored running statistics; their affine weight/bias
        # parameters stay trainable.
        for module in self.dense_feature_extractor.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eval()
    def initialize_dense_feature_extractor(self):
        # ImageNet-pretrained DenseNet-161; freeze the stem conv and the first
        # two dense blocks (requires_grad must be set on the parameters
        # themselves, not on the modules, to have any effect).
        dfe = torchvision.models.densenet161(pretrained=True, progress=True)
        for block in (dfe.features.conv0, dfe.features.denseblock1, dfe.features.denseblock2):
            for param in block.parameters():
                param.requires_grad = False
        return dfe
def set_h2(self, module, input_, output):
self.dense_op_h2 = output
def set_h4(self, module, input_, output):
self.dense_op_h4 = output
def set_h8(self, module, input_, output):
self.dense_op_h8 = output
def set_h16(self, module, input_, output):
self.dense_op_h16 = output
def set_dense_features(self, module, input_, output):
self.dense_features = output
def initialize_hooks(self):
self.dense_feature_extractor.features.relu0.register_forward_hook(self.set_h2)
self.dense_feature_extractor.features.pool0.register_forward_hook(self.set_h4)
self.dense_feature_extractor.features.transition1.register_forward_hook(self.set_h8)
self.dense_feature_extractor.features.transition2.register_forward_hook(self.set_h16)
self.dense_feature_extractor.features.norm5.register_forward_hook(self.set_dense_features)
def forward(self, ip):
_ = self.dense_feature_extractor(ip)
joint_input = (self.dense_features.relu(), self.dense_op_h2, self.dense_op_h4, self.dense_op_h8, self.dense_op_h16)
return joint_input
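# Decoder: upscaling + dense ASPP, with LPG depth estimates at 8x, 4x and 2x
# feeding the next-finer stage. The final depth is rescaled by the ratio of
# the input focal length to 715.0873 (the KITTI focal length the network is
# trained with).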
class bts_decoder(nn.Module):
def __init__(self):
super(bts_decoder, self).__init__()
self.UpscaleNet = UpscaleNetwork()
self.DenseASSPNet = ASSPBlock()
self.upscale_block3 = UpscaleBlock(64, 96, 128) # H4
self.upscale_block4 = UpscaleBlock(128, 96, 128) # H2
self.LPGBlock8 = LPGBlock(8, 128)
self.LPGBlock4 = LPGBlock(4, 64) # 64 Filter
self.LPGBlock2 = LPGBlock(2, 64) # 64 Filter
self.upconv_h4 = UpscaleLayer(128, 64)
self.upconv_h2 = UpscaleLayer(64, 32) # 64 Filter
self.upconv_h = UpscaleLayer(64, 32) # 32 filter
self.conv_h4 = nn.Conv2d(161, 64, 3, 1, 1, bias=ENABLE_BIAS) # 64 Filter
self.conv_h2 = nn.Conv2d(129, 64, 3, 1, 1, bias=ENABLE_BIAS) # 64 Filter
self.conv_h1 = nn.Conv2d(36, 32, 3, 1, 1, bias=ENABLE_BIAS)
self.reduction1x1 = Reduction(1, 32, True)
self.final_conv = nn.Conv2d(32, 1, 3, 1, 1, bias=ENABLE_BIAS)
def forward(self, joint_input, focal):
(dense_features, dense_op_h2, dense_op_h4, dense_op_h8, dense_op_h16) = joint_input
upscaled_out = self.UpscaleNet(joint_input)
dense_assp_out = self.DenseASSPNet(upscaled_out)
upconv_h4 = self.upconv_h4(dense_assp_out)
depth_8x8 = self.LPGBlock8(dense_assp_out) / MAX_DEPTH
        depth_8x8_ds = F.interpolate(depth_8x8, scale_factor=1 / 4, mode="nearest")
depth_concat_4x4 = torch.cat((depth_8x8_ds, dense_op_h4, upconv_h4), 1)
conv_h4 = activation_fn(self.conv_h4(depth_concat_4x4))
upconv_h2 = self.upconv_h2(conv_h4)
depth_4x4 = self.LPGBlock4(conv_h4) / MAX_DEPTH
        depth_4x4_ds = F.interpolate(depth_4x4, scale_factor=1 / 2, mode="nearest")
depth_concat_2x2 = torch.cat((depth_4x4_ds, dense_op_h2, upconv_h2), 1)
conv_h2 = activation_fn(self.conv_h2(depth_concat_2x2))
upconv_h = self.upconv_h(conv_h2)
depth_1x1 = self.reduction1x1(upconv_h)
depth_2x2 = self.LPGBlock2(conv_h2) / MAX_DEPTH
depth_concat = torch.cat((upconv_h, depth_1x1, depth_2x2, depth_4x4, depth_8x8), 1)
depth = activation_fn(self.conv_h1(depth_concat))
depth = self.final_conv(depth).sigmoid() * MAX_DEPTH + DEPTH_OFFSET
depth *= focal.view(-1, 1, 1, 1) / 715.0873
return depth, depth_2x2, depth_4x4, depth_8x8
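# Full BTS network: the encoder's hooked activations feed the decoder, which
# returns the final depth map plus the intermediate 2x/4x/8x depth estimates.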
class bts_model(nn.Module):
def __init__(self):
super(bts_model, self).__init__()
self.encoder = bts_encoder()
self.decoder = bts_decoder()
    def forward(self, input, focal=715.0873):
        # Accept either a tensor or a plain number for the focal length.
        if not torch.is_tensor(focal):
            focal = torch.tensor(focal, device=input.device)
        joint_input = self.encoder(input)
        return self.decoder(joint_input, focal)
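# Scale-invariant log (SILog) loss (Eigen et al.), as used in the BTS paper:
# with d = log(pred) - log(gt) over valid pixels (1 < gt < MAX_DEPTH),
#     loss = ratio * sqrt(mean(d^2) - ratio2 * mean(d)^2)
# The "* ratio" factors inside the logs cancel in the difference.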
class SilogLoss(nn.Module):
def __init__(self):
super(SilogLoss, self).__init__()
def forward(self, ip, target, ratio=10, ratio2=0.85):
ip = ip.reshape(-1)
target = target.reshape(-1)
        mask = (target > 1) & (target < MAX_DEPTH)
masked_ip = torch.masked_select(ip.float(), mask)
masked_op = torch.masked_select(target, mask)
log_diff = torch.log(masked_ip * ratio) - torch.log(masked_op * ratio)
        silog1 = torch.mean(log_diff ** 2)
        silog2 = ratio2 * (torch.mean(log_diff) ** 2)
silog_loss = torch.sqrt(silog1 - silog2) * ratio
return silog_loss
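# Convenience wrapper around bts_model: owns the optimizer, LR scheduler and
# SILog criterion, and exposes predict/save_model/load_model helpers (the
# tensorboard training step is kept commented out below).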
class BtsController:
def __init__(self, log_directory='run_1', logs_folder='tensorboard', backprop_frequency=1):
self.bts = bts_model().float().to(device)
self.optimizer = torch.optim.AdamW([{'params': self.bts.encoder.parameters(), 'weight_decay': 1e-2},
{'params': self.bts.decoder.parameters(), 'weight_decay': 0}],
lr=1e-4, eps=1e-6)
if USE_APEX:
self.bts, self.optimizer = apex.amp.initialize(self.bts, self.optimizer, opt_level=APEX_OPT_LEVEL)
self.bts = torch.nn.DataParallel(self.bts)
self.backprop_frequency = backprop_frequency
log_path = os.path.join(logs_folder, log_directory)
# self.writer = SummaryWriter(log_path)
self.criterion = SilogLoss()
self.learning_rate_scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, 0.95)
self.current_epoch = 0
self.last_loss = 0
self.current_step = 0
def eval(self):
self.bts = self.bts.eval()
def train(self):
self.bts = self.bts.train()
    def predict(self, input, is_channels_first=True, focal=715.0873, normalize=False):
        if normalize:
            input = A.Compose([A.Normalize()])(**{"image": input})["image"]
        if is_channels_first:
            # input is a (C, H, W) array: just add the batch dimension.
            tensor_input = torch.tensor(input).unsqueeze(0).to(device).float()
        else:
            # input is an (H, W, C) array: move channels first, then add the batch dimension.
            tensor_input = torch.tensor(input).permute(2, 0, 1).unsqueeze(0).to(device).float()
shape_changed = False
org_shape = tensor_input.shape[2:]
        if org_shape[0] % 32 != 0 or org_shape[1] % 32 != 0:
            # The encoder downsamples by a factor of 32, so resize to the nearest multiple.
            shape_changed = True
            new_shape_y = round(org_shape[0] / 32) * 32
            new_shape_x = round(org_shape[1] / 32) * 32
            tensor_input = F.interpolate(tensor_input, (new_shape_y, new_shape_x), mode="bilinear", align_corners=False)
        with torch.no_grad():
            model_output = self.bts(tensor_input, torch.tensor(focal).unsqueeze(0))[0][0].unsqueeze(0)
if shape_changed:
model_output = F.interpolate(model_output, (org_shape[0], org_shape[1]), mode="nearest")
return model_output.cpu().squeeze()
    @staticmethod
    def depth_map_to_rgbimg(depth_map):
        # Same inverted-intensity mapping as depth_map_to_grayimg, expanded to 3 channels.
        gray = BtsController.depth_map_to_grayimg(depth_map)
        return np.asarray(cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB), np.uint8)
    @staticmethod
    def depth_map_to_grayimg(depth_map):
        # Scale depth by 4, clamp to 250 and invert, so near pixels appear bright.
        depth_map = np.asarray(np.squeeze((255 - torch.clamp_max(depth_map * 4, 250)).byte().numpy()), np.uint8)
        return depth_map
@staticmethod
def normalize_img(image):
transformation = A.Compose([
A.Normalize()
])
return transformation(**{"image": image})["image"]
# def run_train_step(self, tensor_input, tensor_output, tensor_focal):
# tensor_input, tensor_output = tensor_input.to(device), tensor_output.to(device)
# # Get Models prediction and calculate loss
# model_output, depth2, depth4, depth8 = self.bts(tensor_input, tensor_focal)
#
# loss = self.criterion(model_output, tensor_output) * 1/self.backprop_frequency
#
# if USE_APEX:
# with apex.amp.scale_loss(loss, self.optimizer) as scaled_loss:
# scaled_loss.backward()
# else:
# loss.backward()
#
# if self.current_step % self.backprop_frequency == 0: # Make update once every x steps
# torch.nn.utils.clip_grad_norm_(self.bts.parameters(), 5)
# self.optimizer.step()
# self.optimizer.zero_grad()
#
# if self.current_step % 100 == 0:
# self.writer.add_scalar("Loss", loss.item() * self.backprop_frequency / tensor_input.shape[0], self.current_step)
#
# if self.current_step % 1000 == 0:
# img = tensor_input[0].detach().transpose(0, 2).transpose(0, 1).cpu().numpy().astype(np.uint8)
# self.writer.add_image("Input", img, self.current_step, None, "HWC")
#
# visual_result = (255-torch.clamp_max(torchvision.utils.make_grid([tensor_output[0], model_output[0]]) * 5, 250)).byte()
#
# self.writer.add_image("Output/Prediction", visual_result, self.current_step)
# depths = [depth2[0], depth4[0], depth8[0]]
# depths = [depth*MAX_DEPTH for depth in depths]
# depth_visual = (255-torch.clamp_max(torchvision.utils.make_grid(depths) * 5, 250)).byte()
#
# self.writer.add_image("Depths", depth_visual, self.current_step)
#
# self.current_step += 1
def save_model(self, path):
save_dict = {
'epoch': self.current_epoch,
'model_state_dict': self.bts.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
"scheduler_state_dict": self.learning_rate_scheduler.state_dict(),
'loss': self.last_loss,
"last_step": self.current_step
}
if USE_APEX:
save_dict["amp"] = apex.amp.state_dict()
save_dict["opt_level"] = APEX_OPT_LEVEL
torch.save(save_dict, path)
    def load_model(self, path):
        checkpoint = torch.load(path, map_location=device)
        if USE_APEX:
            saved_opt_level = checkpoint["opt_level"]
            self.bts, self.optimizer = apex.amp.initialize(self.bts, self.optimizer, opt_level=saved_opt_level)
            apex.amp.load_state_dict(checkpoint["amp"])
        self.current_epoch = checkpoint["epoch"]
        self.bts.load_state_dict(checkpoint["model_state_dict"])
        self.bts = self.bts.float().to(device)
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.learning_rate_scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        self.last_loss = checkpoint["loss"]
        self.current_step = checkpoint["last_step"]
        return checkpoint
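
# Minimal usage sketch. Assumptions: "bts_model.pt" is a hypothetical checkpoint
# produced by save_model() and "input.jpg" a hypothetical RGB image on disk;
# substitute your own paths.
if __name__ == "__main__":
    controller = BtsController()
    controller.load_model("bts_model.pt")  # hypothetical checkpoint path
    controller.eval()
    bgr = cv2.imread("input.jpg")  # hypothetical image path; OpenCV loads BGR
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    rgb = BtsController.normalize_img(rgb)  # ImageNet mean/std normalization
    depth = controller.predict(rgb, is_channels_first=False)
    cv2.imwrite("depth.png", BtsController.depth_map_to_grayimg(depth))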