# PatchFusion: estimator/models/baseline_pretrain.py
# MIT License
# Copyright (c) 2022 Intelligent Systems Lab Org
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# File author: Zhenyu Li

import itertools
import math
import copy
import random
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from mmengine import print_log
from estimator.registry import MODELS
from estimator.models import build_model
from zoedepth.models.zoedepth import ZoeDepth
import matplotlib.pyplot as plt
from estimator.models.utils import get_activation, generatemask, RunningAverageMap
from zoedepth.models.base_models.midas import Resize as ResizeZoe
from depth_anything.transform import Resize as ResizeDA


@MODELS.register_module()
class BaselinePretrain(nn.Module):
def __init__(self,
coarse_branch,
fine_branch,
sigloss,
min_depth,
max_depth,
image_raw_shape=(2160, 3840),
patch_process_shape=(384, 512),
patch_split_num=(4, 4),
target='coarse',
coarse_branch_zoe=None):
"""ZoeDepth model
"""
super().__init__()
self.patch_process_shape = patch_process_shape
self.tile_cfg = self.prepare_tile_cfg(image_raw_shape, patch_split_num)
self.min_depth = min_depth
self.max_depth = max_depth
self.coarse_branch_cfg = coarse_branch
self.fine_branch_cfg = fine_branch
if target == 'coarse':
if self.coarse_branch_cfg.type == 'ZoeDepth':
self.coarse_branch = ZoeDepth.build(**coarse_branch)
print_log("Current zoedepth.core.prep.resizer is {}".format(type(self.coarse_branch.core.prep.resizer)), logger='current')
self.resizer = ResizeZoe(patch_process_shape[1], patch_process_shape[0], keep_aspect_ratio=False, ensure_multiple_of=32, resize_method="minimal")
elif self.coarse_branch_cfg.type == 'DA-ZoeDepth':
self.coarse_branch = ZoeDepth.build(**coarse_branch)
print_log("Current zoedepth.core.prep.resizer is {}".format(type(self.coarse_branch.core.prep.resizer)), logger='current')
self.resizer = ResizeDA(patch_process_shape[1], patch_process_shape[0], keep_aspect_ratio=False, ensure_multiple_of=14, resize_method="minimal")
if target == 'fine':
if self.fine_branch_cfg.type == 'ZoeDepth':
self.fine_branch = ZoeDepth.build(**fine_branch)
print_log("Current zoedepth.core.prep.resizer is {}".format(type(self.fine_branch.core.prep.resizer)), logger='current')
self.resizer = ResizeZoe(patch_process_shape[1], patch_process_shape[0], keep_aspect_ratio=False, ensure_multiple_of=32, resize_method="minimal")
elif self.fine_branch_cfg.type == 'DA-ZoeDepth':
self.fine_branch = ZoeDepth.build(**fine_branch)
print_log("Current zoedepth.core.prep.resizer is {}".format(type(self.fine_branch.core.prep.resizer)), logger='current')
self.resizer = ResizeDA(patch_process_shape[1], patch_process_shape[0], keep_aspect_ratio=False, ensure_multiple_of=14, resize_method="minimal")
self.sigloss = build_model(sigloss)
self.target = target

    def prepare_tile_cfg(self, image_raw_shape, patch_split_num):
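        """Precompute the tiling layout used for patch-wise inference.

        For example, image_raw_shape=(2160, 3840) with patch_split_num=(4, 4) and
        patch_process_shape=(384, 512) gives patch_raw_shape=(540, 960) and
        patch_reensemble_shape=(1536, 2048).
        """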
        # derive the tiling layout from the raw image shape and the split counts
        patch_reensemble_shape = (self.patch_process_shape[0] * patch_split_num[0], self.patch_process_shape[1] * patch_split_num[1])
        patch_raw_shape = (image_raw_shape[0] // patch_split_num[0], image_raw_shape[1] // patch_split_num[1])
raw_h_split_point = []
raw_w_split_point = []
for i in range(patch_split_num[0]):
raw_h_split_point.append(int(patch_raw_shape[0] * i))
for i in range(patch_split_num[1]):
raw_w_split_point.append(int(patch_raw_shape[1] * i))
tile_cfg = {
'patch_split_num': patch_split_num,
'patch_reensemble_shape': patch_reensemble_shape,
'patch_raw_shape': patch_raw_shape,
'image_raw_shape': image_raw_shape,
'raw_h_split_point': raw_h_split_point,
'raw_w_split_point': raw_w_split_point}
return tile_cfg

    def load_dict(self, state_dict):
        # load weights into whichever single branch this model owns
        if hasattr(self, 'coarse_branch') and not hasattr(self, 'fine_branch'):
            return self.coarse_branch.load_state_dict(state_dict, strict=True)
        elif hasattr(self, 'fine_branch') and not hasattr(self, 'coarse_branch'):
            return self.fine_branch.load_state_dict(state_dict, strict=True)
        else:
            raise NotImplementedError('Loading the coarse and fine branches together is not supported')

    def get_save_dict(self):
        model_state_dict = {}
        if hasattr(self, 'coarse_branch') and not hasattr(self, 'fine_branch'):
            model_state_dict.update(self.coarse_branch.state_dict())
        elif hasattr(self, 'fine_branch') and not hasattr(self, 'coarse_branch'):
            model_state_dict.update(self.fine_branch.state_dict())
        else:
            raise NotImplementedError('Training the coarse and fine branches together is not supported')
        return model_state_dict

    def infer_forward(self, imgs_crop):
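        # NOTE: models that consume the cached coarse outputs (e.g. the full
        # PatchFusion model) are expected to override infer_forward with extra
        # arguments; this baseline version only needs the image crops.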
output_dict = self.fine_branch(imgs_crop)
return output_dict['metric_depth']

    @torch.no_grad()
def random_tile(
self,
image_hr,
tile_temp=None,
blur_mask=None,
avg_depth_map=None,
tile_cfg=None,
process_num=4,):
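        """Predict depth on randomly placed patches and fuse them into avg_depth_map.

        Each patch of size patch_raw_shape is cropped from the raw image, resized to
        patch_process_shape for the forward pass, and the blur-masked prediction is
        pasted back at raw resolution into the running average map.
        """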
        ## setting
        height, width = tile_cfg['patch_raw_shape'][0], tile_cfg['patch_raw_shape'][1]
        # sample process_num random patch locations (paired h/w starting points)
        h_start_list = [random.randint(0, tile_cfg['image_raw_shape'][0] - height - 1) for _ in range(process_num)]
        w_start_list = [random.randint(0, tile_cfg['image_raw_shape'][1] - width - 1) for _ in range(process_num)]
## prepare data
imgs_crop = []
bboxs = []
        for h_start, w_start in zip(h_start_list, w_start_list):
            crop_image = image_hr[:, h_start: h_start+height, w_start: w_start+width]
            crop_image_resized = self.resizer(crop_image.unsqueeze(dim=0)).squeeze(dim=0) # resize to patch_process_shape
            bbox = torch.tensor([w_start, h_start, w_start+width, h_start+height])
            imgs_crop.append(crop_image_resized)
            bboxs.append(bbox)
imgs_crop = torch.stack(imgs_crop, dim=0)
bboxs = torch.stack(bboxs, dim=0)
imgs_crop = imgs_crop.to(image_hr.device)
bboxs = bboxs.to(image_hr.device).int()
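        # scale (x1, y1, x2, y2) boxes from raw-image pixels to patch_process_shape
        # coordinates; the batch indices prepended below turn them into
        # ROI-align-style boxes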
bboxs_feat_factor = torch.tensor([
1 / tile_cfg['image_raw_shape'][1] * self.patch_process_shape[1],
1 / tile_cfg['image_raw_shape'][0] * self.patch_process_shape[0],
1 / tile_cfg['image_raw_shape'][1] * self.patch_process_shape[1],
1 / tile_cfg['image_raw_shape'][0] * self.patch_process_shape[0]], device=bboxs.device).unsqueeze(dim=0)
bboxs_feat = bboxs * bboxs_feat_factor
inds = torch.arange(bboxs.shape[0]).to(bboxs.device).unsqueeze(dim=-1)
bboxs_feat = torch.cat((inds, bboxs_feat), dim=-1)
if tile_temp is not None:
coarse_postprocess_dict = self.coarse_postprocess_test(bboxs=bboxs, bboxs_feat=bboxs_feat, **tile_temp)
prediction_list = []
if tile_temp is not None:
coarse_temp_dict = {}
for k, v in coarse_postprocess_dict.items():
if k == 'coarse_feats_roi':
coarse_temp_dict[k] = [f for f in v]
else:
coarse_temp_dict[k] = v
            # zero the batch-index column before the forward pass (clone to avoid
            # mutating bboxs_feat through the shared reference)
            bbox_feat_forward = bboxs_feat.clone()
            bbox_feat_forward[:, 0] = 0
prediction = self.infer_forward(imgs_crop, bbox_feat_forward, tile_temp, coarse_temp_dict)
else:
prediction = self.infer_forward(imgs_crop)
prediction_list.append(prediction)
predictions = torch.cat(prediction_list, dim=0)
predictions = F.interpolate(predictions, tile_cfg['patch_raw_shape'])
        # paste each patch prediction back into the full-resolution running average
        for patch_select_idx, (h_start, w_start) in enumerate(zip(h_start_list, w_start_list)):
            temp_depth = predictions[patch_select_idx]
            count_map = torch.zeros(tile_cfg['image_raw_shape'], device=temp_depth.device)
            pred_depth = torch.zeros(tile_cfg['image_raw_shape'], device=temp_depth.device)
            count_map[h_start: h_start+tile_cfg['patch_raw_shape'][0], w_start: w_start+tile_cfg['patch_raw_shape'][1]] = blur_mask
            pred_depth[h_start: h_start+tile_cfg['patch_raw_shape'][0], w_start: w_start+tile_cfg['patch_raw_shape'][1]] = temp_depth * blur_mask
            avg_depth_map.update(pred_depth, count_map)
        return avg_depth_map

    @torch.no_grad()
def regular_tile(
self,
offset,
offset_process,
image_hr,
init_flag=False,
tile_temp=None,
blur_mask=None,
avg_depth_map=None,
tile_cfg=None,
process_num=4,):
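        """Predict depth on a regular grid of patches, optionally shifted by offset.

        With init_flag=True the per-patch results are accumulated into fresh maps and
        a new RunningAverageMap is returned; otherwise each patch updates the given
        avg_depth_map.
        """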
## setting
height, width = tile_cfg['patch_raw_shape'][0], tile_cfg['patch_raw_shape'][1]
offset_h, offset_w = offset[0], offset[1]
assert offset_w >= 0 and offset_h >= 0
tile_num_h = (tile_cfg['image_raw_shape'][0] - offset_h) // height
tile_num_w = (tile_cfg['image_raw_shape'][1] - offset_w) // width
h_start_list = [height * h + offset_h for h in range(tile_num_h)]
w_start_list = [width * w + offset_w for w in range(tile_num_w)]
height_process, width_process = self.patch_process_shape[0], self.patch_process_shape[1]
offset_h_process, offset_w_process = offset_process[0], offset_process[1]
assert offset_h_process >= 0 and offset_w_process >= 0
tile_num_h_process = (tile_cfg['patch_reensemble_shape'][0] - offset_h_process) // height_process
tile_num_w_process = (tile_cfg['patch_reensemble_shape'][1] - offset_w_process) // width_process
h_start_list_process = [height_process * h + offset_h_process for h in range(tile_num_h_process)]
w_start_list_process = [width_process * w + offset_w_process for w in range(tile_num_w_process)]
## prepare data
imgs_crop = []
bboxs = []
iter_priors = []
for h_start in h_start_list:
for w_start in w_start_list:
crop_image = image_hr[:, h_start: h_start+height, w_start: w_start+width]
crop_image_resized = self.resizer(crop_image.unsqueeze(dim=0)).squeeze(dim=0) # resize to patch_process_shape
bbox = torch.tensor([w_start, h_start, w_start+width, h_start+height])
imgs_crop.append(crop_image_resized)
bboxs.append(bbox)
imgs_crop = torch.stack(imgs_crop, dim=0)
bboxs = torch.stack(bboxs, dim=0)
imgs_crop = imgs_crop.to(image_hr.device)
bboxs = bboxs.to(image_hr.device).int()
bboxs = bboxs.squeeze() # HACK: during inference, 1, 16, 4 -> 16, 4
if len(bboxs.shape) == 1:
bboxs = bboxs.unsqueeze(dim=0)
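        # same raw-pixel -> patch_process_shape box scaling as in random_tile above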
bboxs_feat_factor = torch.tensor([
1 / tile_cfg['image_raw_shape'][1] * self.patch_process_shape[1],
1 / tile_cfg['image_raw_shape'][0] * self.patch_process_shape[0],
1 / tile_cfg['image_raw_shape'][1] * self.patch_process_shape[1],
1 / tile_cfg['image_raw_shape'][0] * self.patch_process_shape[0]], device=bboxs.device).unsqueeze(dim=0)
bboxs_feat = bboxs * bboxs_feat_factor
inds = torch.arange(bboxs.shape[0]).to(bboxs.device).unsqueeze(dim=-1)
bboxs_feat = torch.cat((inds, bboxs_feat), dim=-1)
        # post-process: crop the cached coarse outputs to each patch ROI
        # (only reached by subclasses that pass tile_temp and implement
        # coarse_postprocess_test)
        if tile_temp is not None:
            coarse_postprocess_dict = self.coarse_postprocess_test(bboxs=bboxs, bboxs_feat=bboxs_feat, **tile_temp)
count_map = torch.zeros(tile_cfg['patch_reensemble_shape'], device=image_hr.device)
pred_depth = torch.zeros(tile_cfg['patch_reensemble_shape'], device=image_hr.device)
prediction_list = []
split_rebatch_image = torch.split(imgs_crop, process_num, dim=0)
for idx, rebatch_image in enumerate(split_rebatch_image):
if tile_temp is not None:
coarse_temp_dict = {}
for k, v in coarse_postprocess_dict.items():
if k == 'coarse_feats_roi':
coarse_temp_dict[k] = [f[idx*process_num:(idx+1)*process_num, :, :, :] for f in v]
else:
coarse_temp_dict[k] = v[idx*process_num:(idx+1)*process_num, :, :, :]
                # zero the batch-index column for this sub-batch (clone to avoid
                # writing into bboxs_feat through a view)
                bbox_feat_forward = bboxs_feat[idx*process_num:(idx+1)*process_num, :].clone()
                bbox_feat_forward[:, 0] = 0
prediction = self.infer_forward(rebatch_image, bbox_feat_forward, tile_temp, coarse_temp_dict)
else:
prediction = self.infer_forward(rebatch_image)
prediction_list.append(prediction)
predictions = torch.cat(prediction_list, dim=0)
patch_select_idx = 0
for h_start in h_start_list_process:
for w_start in w_start_list_process:
                temp_depth = predictions[patch_select_idx]
                if not init_flag:
                    # later (shifted) passes paste each patch into fresh maps and
                    # fold them into the existing running average
                    count_map = torch.zeros(tile_cfg['patch_reensemble_shape'], device=temp_depth.device)
                    pred_depth = torch.zeros(tile_cfg['patch_reensemble_shape'], device=temp_depth.device)
                count_map[h_start: h_start+self.patch_process_shape[0], w_start: w_start+self.patch_process_shape[1]] = blur_mask
                pred_depth[h_start: h_start+self.patch_process_shape[0], w_start: w_start+self.patch_process_shape[1]] = temp_depth * blur_mask
                if not init_flag:
                    avg_depth_map.update(pred_depth, count_map)
                patch_select_idx += 1
if init_flag:
avg_depth_map = RunningAverageMap(pred_depth, count_map)
return avg_depth_map

    def forward(
self,
mode,
image_lr,
image_hr,
depth_gt,
crop_depths=None,
crops_image_hr=None,
bboxs=None,
tile_cfg=None,
cai_mode='m1',
process_num=4,
**kwargs):
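        """Run a training step (mode == 'train') or tiled inference (any other mode).

        During inference with the fine branch, cai_mode selects the tiling strategy:
        'm1' uses the regular patch_split_num grid, 'm2' adds three half-patch
        shifted grids, and 'rN' (e.g. 'r256') additionally accumulates roughly N
        randomly placed patches on top of 'm2'.
        """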
if mode == 'train':
loss_dict = {}
if self.target == 'coarse':
model_output_dict = self.coarse_branch(image_lr)
depth_prediction = model_output_dict['metric_depth']
loss_dict['coarse_loss'] = self.sigloss(depth_prediction, depth_gt, self.min_depth, self.max_depth)
loss_dict['total_loss'] = loss_dict['coarse_loss']
return loss_dict, {'rgb': image_lr, 'depth_pred': depth_prediction, 'depth_gt': depth_gt}
elif self.target == 'fine':
model_output_dict = self.fine_branch(crops_image_hr) # 1/2 res, 1/4 res, 1/8 res, 1/16 res
depth_prediction = model_output_dict['metric_depth']
loss_dict['fine_loss'] = self.sigloss(depth_prediction, crop_depths, self.min_depth, self.max_depth)
loss_dict['total_loss'] = loss_dict['fine_loss']
return loss_dict, {'rgb': image_lr, 'depth_pred': depth_prediction, 'depth_gt': crop_depths}
else:
raise NotImplementedError
else:
if self.target == 'coarse':
model_output_dict = self.coarse_branch(image_lr)
depth_prediction = model_output_dict['metric_depth']
elif self.target == 'fine':
if tile_cfg is None:
tile_cfg = self.tile_cfg
else:
tile_cfg = self.prepare_tile_cfg(tile_cfg['image_raw_shape'], tile_cfg['patch_split_num'])
assert image_hr.shape[0] == 1
blur_mask = generatemask((self.patch_process_shape[0], self.patch_process_shape[1])) + 1e-3
blur_mask = torch.tensor(blur_mask, device=image_hr.device)
avg_depth_map = self.regular_tile(
offset=[0, 0],
offset_process=[0, 0],
image_hr=image_hr[0],
init_flag=True,
tile_temp=None,
blur_mask=blur_mask,
tile_cfg=tile_cfg,
process_num=process_num)
if cai_mode == 'm2' or cai_mode[0] == 'r':
avg_depth_map = self.regular_tile(
offset=[0, tile_cfg['patch_raw_shape'][1]//2],
offset_process=[0, self.patch_process_shape[1]//2],
image_hr=image_hr[0], init_flag=False, tile_temp=None, blur_mask=blur_mask, avg_depth_map=avg_depth_map, tile_cfg=tile_cfg, process_num=process_num)
avg_depth_map = self.regular_tile(
offset=[tile_cfg['patch_raw_shape'][0]//2, 0],
offset_process=[self.patch_process_shape[0]//2, 0],
image_hr=image_hr[0], init_flag=False, tile_temp=None, blur_mask=blur_mask, avg_depth_map=avg_depth_map, tile_cfg=tile_cfg, process_num=process_num)
avg_depth_map = self.regular_tile(
offset=[tile_cfg['patch_raw_shape'][0]//2, tile_cfg['patch_raw_shape'][1]//2],
offset_process=[self.patch_process_shape[0]//2, self.patch_process_shape[1]//2],
init_flag=False, image_hr=image_hr[0], tile_temp=None, blur_mask=blur_mask, avg_depth_map=avg_depth_map, tile_cfg=tile_cfg, process_num=process_num)
if cai_mode[0] == 'r':
blur_mask = generatemask((tile_cfg['patch_raw_shape'][0], tile_cfg['patch_raw_shape'][1])) + 1e-3
blur_mask = torch.tensor(blur_mask, device=image_hr.device)
avg_depth_map.resize(tile_cfg['image_raw_shape'])
                        # each random_tile call adds process_num patches, so 'rN' yields roughly N patches
                        patch_num = int(cai_mode[1:]) // process_num
for i in range(patch_num):
avg_depth_map = self.random_tile(
image_hr=image_hr[0], tile_temp=None, blur_mask=blur_mask, avg_depth_map=avg_depth_map, tile_cfg=tile_cfg, process_num=process_num)
depth = avg_depth_map.average_map
depth = depth.unsqueeze(dim=0).unsqueeze(dim=0)
return depth, {}
else:
raise NotImplementedError
return depth_prediction, {'rgb': image_lr, 'depth_pred': depth_prediction, 'depth_gt': depth_gt}
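
# Minimal usage sketch (hypothetical config dicts; the real ones live in this
# project's config files). Assumes coarse_cfg, fine_cfg and sigloss_cfg are
# valid registry configs:
#
#     model = MODELS.build(dict(
#         type='BaselinePretrain',
#         coarse_branch=coarse_cfg, fine_branch=fine_cfg, sigloss=sigloss_cfg,
#         min_depth=1e-3, max_depth=80.0, target='coarse'))
#     loss_dict, vis_dict = model('train', image_lr, image_hr, depth_gt)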