|
from collections import namedtuple |
|
|
|
import numpy as np |
|
import torch |
|
|
|
from .detectors import build_detector |
|
|
|
try: |
|
import kornia |
|
except: |
|
pass |
|
|
|
|
|
|
|
|
|
def build_network(model_cfg, num_class, dataset):
    """Build a detector model by delegating to ``build_detector``.

    Args:
        model_cfg: Model configuration, forwarded unchanged.
        num_class: Number of classes, forwarded unchanged.
        dataset: Dataset object, forwarded unchanged.

    Returns:
        Whatever ``build_detector`` returns for the given arguments.
    """
    detector_kwargs = {
        'model_cfg': model_cfg,
        'num_class': num_class,
        'dataset': dataset,
    }
    return build_detector(**detector_kwargs)
|
|
|
|
|
def load_data_to_gpu(batch_dict):
    """Replace selected entries of ``batch_dict`` with CUDA tensors, in place.

    Rules, applied per entry in this order:

    * ``'camera_imgs'``: the value's own ``.cuda()`` result is stored,
      regardless of the value's type.
    * Values that are not ``np.ndarray`` are left untouched.
    * Keys in the pass-through list are left untouched even when they
      hold arrays.
    * ``'images'``: converted with ``kornia.image_to_tensor`` and stored as
      a contiguous float CUDA tensor.
    * ``'image_shape'``: stored as an int CUDA tensor.
    * Any other array: stored as a float CUDA tensor.

    Args:
        batch_dict: Mapping of batch entries; modified in place.

    Returns:
        None.
    """
    passthrough_keys = (
        'frame_id', 'metadata', 'calib', 'image_paths',
        'ori_shape', 'img_process_infos',
    )

    # Only values of existing keys are reassigned, so iterating the live
    # items view while writing back is safe.
    for name, value in batch_dict.items():
        if name == 'camera_imgs':
            batch_dict[name] = value.cuda()
            continue

        if not isinstance(value, np.ndarray):
            continue
        if name in passthrough_keys:
            continue

        if name == 'images':
            # NOTE: `kornia` comes from a guarded module-level import; if that
            # import failed, reaching this branch raises NameError.
            image_tensor = kornia.image_to_tensor(value)
            batch_dict[name] = image_tensor.float().cuda().contiguous()
        elif name == 'image_shape':
            shape_tensor = torch.from_numpy(value)
            batch_dict[name] = shape_tensor.int().cuda()
        else:
            array_tensor = torch.from_numpy(value)
            batch_dict[name] = array_tensor.float().cuda()
|
|
|
|
|
def model_fn_decorator():
    """Create the per-batch model function.

    Returns:
        A callable ``model_func(model, batch_dict)`` that moves the batch to
        the GPU, runs the model, and returns a ``ModelReturn`` namedtuple
        with fields ``loss``, ``tb_dict`` and ``disp_dict``.
    """
    ModelReturn = namedtuple('ModelReturn', ('loss', 'tb_dict', 'disp_dict'))

    def model_func(model, batch_dict):
        """Run one forward pass and advance the model's global step.

        Args:
            model: Callable returning ``(ret_dict, tb_dict, disp_dict)``
                when invoked with ``batch_dict``.
            batch_dict: Batch entries; modified in place by
                ``load_data_to_gpu``.

        Returns:
            ``ModelReturn`` holding the mean of ``ret_dict['loss']`` plus the
            two dicts returned by the model.
        """
        load_data_to_gpu(batch_dict)
        ret_dict, tb_dict, disp_dict = model(batch_dict)

        mean_loss = ret_dict['loss'].mean()

        # When the model itself lacks `update_global_step`, fall back to
        # `model.module` (presumably a parallel wrapper around the real
        # model -- confirm against the training entry point).
        step_owner = model if hasattr(model, 'update_global_step') else model.module
        step_owner.update_global_step()

        return ModelReturn(loss=mean_loss, tb_dict=tb_dict, disp_dict=disp_dict)

    return model_func
|
|