zhigangjiang's picture
no message
88b0dcb
raw
history blame
No virus
8.75 kB
import torch.nn
import torch
import torch.nn as nn
import models.modules as modules
import numpy as np
from models.base_model import BaseModule
from models.modules.horizon_net_feature_extractor import HorizonNetFeatureExtractor
from models.modules.patch_feature_extractor import PatchFeatureExtractor
from utils.conversion import uv2depth, get_u, lonlat2depth, get_lon, lonlat2uv
from utils.height import calc_ceil_ratio
from utils.misc import tensor2np
class LGT_Net(BaseModule):
    """Panorama layout network: feature extractor -> sequence decoder -> output head.

    ``forward`` runs, in order:

    1. ``self.feature_extractor``: image ``[b, 3, 512, 1024]`` to column
       features (shape comment in ``forward``: ``[b, 1024(d), 256(w)]``).
    2. A decoder over the ``patch_num`` sequence. It is either a transformer
       class looked up by name in ``models.modules`` (any ``decoder_name``
       containing ``'Transformer'``) or a 2-layer bidirectional LSTM
       (``decoder_name == 'LSTM'``).
    3. The output head chosen by ``output_name``: ``'LGT'``, ``'LED'`` or
       ``'Horizon'``.
    4. Optionally, a per-column corner heat map.
    """

    def __init__(self, ckpt_dir=None, backbone='resnet50', dropout=0.0, output_name='LGT',
                 decoder_name='Transformer', win_size=8, depth=6,
                 ape=None, rpe=None, corner_heat_map=False, rpe_pos=1):
        """Build the feature extractor, decoder and output heads.

        :param ckpt_dir: forwarded unchanged to ``BaseModule.__init__``.
        :param backbone: ``'patch'`` selects ``PatchFeatureExtractor``; any other
            value is passed to ``HorizonNetFeatureExtractor``.
        :param dropout: dropout probability used by the decoder.
        :param output_name: ``'LGT'``, ``'LED'`` or ``'Horizon'``.
        :param decoder_name: name of a transformer class in ``models.modules``
            (must contain ``'Transformer'``) or ``'LSTM'``.
        :param win_size: forwarded to the transformer.
        :param depth: number of transformer layers.
        :param ape: forwarded to the transformer unchanged.
        :param rpe: forwarded to the transformer. The string ``'None'`` is mapped
            to ``None`` (presumably because it arrives from a config/CLI string;
            TODO confirm).
        :param corner_heat_map: if true, also predict a per-column corner heat map.
        :param rpe_pos: forwarded to the transformer.
        :raises NotImplementedError: for an unsupported ``decoder_name`` or ``output_name``.
        """
        super().__init__(ckpt_dir)
        self.patch_num = 256
        self.patch_dim = 1024
        self.decoder_name = decoder_name
        self.output_name = output_name
        self.corner_heat_map = corner_heat_map
        self.dropout_d = dropout
        if backbone == 'patch':
            self.feature_extractor = PatchFeatureExtractor(patch_num=self.patch_num, input_shape=[3, 512, 1024])
        else:
            # feature extractor
            self.feature_extractor = HorizonNetFeatureExtractor(backbone)
        if 'Transformer' in self.decoder_name:
            # transformer decoder; the concrete class is looked up by name in models.modules
            transformer_dim = self.patch_dim
            transformer_layers = depth
            transformer_heads = 8
            transformer_head_dim = transformer_dim // transformer_heads
            transformer_ff_dim = 2048
            rpe = None if rpe == 'None' else rpe
            self.transformer = getattr(modules, decoder_name)(dim=transformer_dim, depth=transformer_layers,
                                                              heads=transformer_heads, dim_head=transformer_head_dim,
                                                              mlp_dim=transformer_ff_dim, win_size=win_size,
                                                              dropout=self.dropout_d, patch_num=self.patch_num,
                                                              ape=ape, rpe=rpe, rpe_pos=rpe_pos)
        elif self.decoder_name == 'LSTM':
            # Bidirectional: the two directions of hidden_size patch_dim // 2 are
            # concatenated back to patch_dim features per column.
            self.bi_rnn = nn.LSTM(input_size=self.feature_extractor.c_last,
                                  hidden_size=self.patch_dim // 2,
                                  num_layers=2,
                                  dropout=self.dropout_d,
                                  batch_first=False,
                                  bidirectional=True)
            self.drop_out = nn.Dropout(self.dropout_d)
        else:
            raise NotImplementedError("Only support *Transformer and LSTM")
        if self.output_name == 'LGT':
            # omnidirectional-geometry aware output
            self.linear_depth_output = nn.Linear(in_features=self.patch_dim, out_features=1)
            self.linear_ratio = nn.Linear(in_features=self.patch_dim, out_features=1)
            self.linear_ratio_output = nn.Linear(in_features=self.patch_num, out_features=1)
        elif self.output_name == 'LED' or self.output_name == 'Horizon':
            # horizon-depth or latitude output: two values per column
            self.linear = nn.Linear(in_features=self.patch_dim, out_features=2)
        else:
            raise NotImplementedError("Unknown output")
        if self.corner_heat_map:
            # corners heat map output
            self.linear_corner_heat_map_output = nn.Linear(in_features=self.patch_dim, out_features=1)
        self.name = f"{self.decoder_name}_{self.output_name}_Net"

    def lgt_output(self, x):
        """LGT head: per-column depth plus one scalar ratio per sample.

        :param x: [b, 256(patch_num), 1024(d)]
        :return: {
            'depth': [b, 256(patch_num)] raw linear outputs (no activation),
            'ratio': [b, 1] obtained by a second linear layer over the patch axis
        }
        """
        depth = self.linear_depth_output(x)  # [b, 256(patch_num), 1(d)]
        depth = depth.view(-1, self.patch_num)  # [b, 256(patch_num & d)]
        # ratio represents room height
        ratio = self.linear_ratio(x)  # [b, 256(patch_num), 1(d)]
        ratio = ratio.view(-1, self.patch_num)  # [b, 256(patch_num & d)]
        ratio = self.linear_ratio_output(ratio)  # [b, 1(d)]
        output = {
            'depth': depth,
            'ratio': ratio
        }
        return output

    def led_output(self, x):
        """LED head: sigmoid ceiling/floor boundaries converted to depths.

        Channel 0 is the ceiling boundary and channel 1 is the floor boundary.
        After the sigmoid, the ceiling v is mapped into [0, 0.5] and the floor v
        into [0.5, 1].

        :param x: [b, 256(patch_num), 1024(d)]
        :return: {
            'depth': [b, 256(patch_num)] from the floor boundary,
            'ceil_depth': [b, 256(patch_num)] from the ceiling boundary,
            'ratio': [b, 1]; only in eval mode. It is computed by ``calc_ceil_ratio``
                on numpy copies, so it is presumably a numpy array (TODO confirm).
        }
        """
        bon = self.linear(x)  # [b, 256(patch_num), 2(d)]
        bon = bon.permute(0, 2, 1)  # [b, 2(d), 256(patch_num)]
        bon = torch.sigmoid(bon)
        ceil_v = bon[:, 0, :] * -0.5 + 0.5  # [b, 256(patch_num)]
        floor_v = bon[:, 1, :] * 0.5 + 0.5  # [b, 256(patch_num)]
        u = get_u(w=self.patch_num, is_np=False, b=ceil_v.shape[0]).to(ceil_v.device)
        ceil_boundary = torch.stack((u, ceil_v), dim=-1)  # [b, 256(patch_num), 2]
        floor_boundary = torch.stack((u, floor_v), dim=-1)  # [b, 256(patch_num), 2]
        output = {
            'depth': uv2depth(floor_boundary),  # [b, 256(patch_num)]
            'ceil_depth': uv2depth(ceil_boundary),  # [b, 256(patch_num)]
        }
        if not self.training:
            # [b, 1(d)]
            output['ratio'] = calc_ceil_ratio([tensor2np(ceil_boundary), tensor2np(floor_boundary)], mode='lsq').reshape(-1, 1)
        return output

    def horizon_output(self, x):
        """Horizon head: raw two-channel boundary regression.

        NOTE: unlike ``led_output``, channel 0 is treated here as the floor
        latitude (clipped to [1e-4, pi/2]) and channel 1 as the ceiling latitude
        (clipped to [-pi/2, -1e-4]).

        :param x: [b, 256(patch_num), 1024(d)]
        :return: {
            'boundary': [b, 2, 256(patch_num)] raw linear outputs (always present),
            'depth': floor depth from ``lonlat2depth``; eval mode only,
            'ratio': [b, 1] from ``calc_ceil_ratio(mode='mean')``; eval mode only
        }
        """
        bon = self.linear(x)  # [b, 256(patch_num), 2(d)]
        bon = bon.permute(0, 2, 1)  # [b, 2(d), 256(patch_num)]
        output = {
            'boundary': bon
        }
        if not self.training:
            lon = get_lon(w=self.patch_num, is_np=False, b=bon.shape[0]).to(bon.device)
            floor_lat = torch.clip(bon[:, 0, :], 1e-4, np.pi / 2)
            ceil_lat = torch.clip(bon[:, 1, :], -np.pi / 2, -1e-4)
            floor_lonlat = torch.stack((lon, floor_lat), dim=-1)  # [b, 256(patch_num), 2]
            ceil_lonlat = torch.stack((lon, ceil_lat), dim=-1)  # [b, 256(patch_num), 2]
            output['depth'] = lonlat2depth(floor_lonlat)
            output['ratio'] = calc_ceil_ratio([tensor2np(lonlat2uv(ceil_lonlat)),
                                               tensor2np(lonlat2uv(floor_lonlat))], mode='mean').reshape(-1, 1)
        return output

    def forward(self, x):
        """Run extractor, decoder and the configured output head.

        :param x: [b, 3(d), 512(h), 1024(w)]
        :return: the dict produced by the head selected by ``output_name``
            (see ``lgt_output`` / ``led_output`` / ``horizon_output``). When
            ``corner_heat_map`` is enabled it also contains
            'corner_heat_map': [b, 256(patch_num)] sigmoid scores.
        :raises NotImplementedError: if ``output_name`` is not a known head.
        """
        # feature extractor
        x = self.feature_extractor(x)  # [b 1024(d) 256(w)]
        if 'Transformer' in self.decoder_name:
            # transformer decoder
            x = x.permute(0, 2, 1)  # [b 256(patch_num) 1024(d)]
            x = self.transformer(x)  # [b 256(patch_num) 1024(d)]
        elif self.decoder_name == 'LSTM':
            # lstm decoder; nn.LSTM is built with batch_first=False, so the sequence axis goes first
            x = x.permute(2, 0, 1)  # [256(patch_num), b, 1024(d)]
            self.bi_rnn.flatten_parameters()
            x, _ = self.bi_rnn(x)  # [256(patch_num & seq_len), b, 1024(d)]
            x = x.permute(1, 0, 2)  # [b, 256(patch_num), 1024(d)]
            x = self.drop_out(x)
        if self.output_name == 'LGT':
            # LGT output
            output = self.lgt_output(x)
        elif self.output_name == 'LED':
            # LED output
            output = self.led_output(x)
        elif self.output_name == 'Horizon':
            # Horizon output
            output = self.horizon_output(x)
        else:
            # Same error as __init__, instead of a TypeError on a None output below.
            raise NotImplementedError("Unknown output")
        if self.corner_heat_map:
            corner_heat_map = self.linear_corner_heat_map_output(x)  # [b, 256(patch_num), 1]
            corner_heat_map = corner_heat_map.view(-1, self.patch_num)
            corner_heat_map = torch.sigmoid(corner_heat_map)
            output['corner_heat_map'] = corner_heat_map
        return output
if __name__ == '__main__':
    # Smoke test: build the default model, report its parameter count and run
    # one demo image through it. The image path is relative, so run this from a
    # directory where ../src/demo.png resolves.
    from PIL import Image
    from models.other.init_env import init_env

    init_env(0, deterministic=True)
    net = LGT_Net()
    total = sum(p.numel() for p in net.parameters())
    trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print('parameter total:{:,}, trainable:{:,}'.format(total, trainable))

    # Close the file handle deterministically. Force 3 channels, because a PNG
    # may carry an alpha channel and the network expects RGB. Then HWC -> CHW.
    with Image.open("../src/demo.png") as image:
        img = np.array(image.convert('RGB')).transpose((2, 0, 1))
    # Build the float batch straight from the ndarray. torch.Tensor([ndarray])
    # goes through a slow list conversion. Pixel values stay in 0..255, as before.
    batch = torch.from_numpy(img).float().unsqueeze(0)  # 1 3 512 1024

    # Inference mode: eval() makes the LED/Horizon heads also emit 'ratio', and
    # no_grad() avoids building an autograd graph for a forward-only demo.
    net.eval()
    with torch.no_grad():
        output = net(batch)
    print(output['depth'].shape)  # 1 256
    print(output['ratio'].shape)  # 1 1