|
import torch |
|
from .yowo import YOWO |
|
from .loss import build_criterion |
|
|
|
|
|
|
|
def build_yowo(args,
               d_cfg,
               m_cfg,
               device,
               num_classes=3,
               trainable=False,
               resume=None,
               *,
               conf_thresh=0.15,
               nms_thresh=0.5,
               topk=40):
    """Build a YOWO model and, when training, its loss criterion.

    Args:
        args: Run configuration object. ``args.version`` is always printed.
            When ``trainable`` is true, ``args.freeze_backbone_2d`` and
            ``args.freeze_backbone_3d`` are read and ``args`` is forwarded
            to ``build_criterion``.
        d_cfg: Dataset config mapping. ``d_cfg['multi_hot']`` is always read;
            ``d_cfg['train_size']`` is read only when ``trainable`` is true.
        m_cfg: Model config, passed to ``YOWO`` as ``cfg``.
        device: Forwarded unchanged to ``YOWO``.
        num_classes: Number of classes, forwarded to both ``YOWO`` and
            ``build_criterion``.
        trainable: If true, optionally freeze the 2D/3D backbones, optionally
            resume weights from ``resume``, and build a criterion. If false,
            ``resume`` is ignored and no criterion is built.
        resume: Optional checkpoint path. The object loaded from it must
            contain a ``"model"`` entry that ``model.load_state_dict`` accepts.
        conf_thresh: Forwarded to ``YOWO`` (previously hard-coded to 0.15).
        nms_thresh: Forwarded to ``YOWO`` (previously hard-coded to 0.5).
        topk: Forwarded to ``YOWO`` (previously hard-coded to 40).

    Returns:
        Tuple ``(model, criterion)``; ``criterion`` is ``None`` when
        ``trainable`` is false.
    """
    print('==============================')
    print('Build {} ...'.format(args.version.upper()))

    model = YOWO(
        cfg=m_cfg,
        device=device,
        num_classes=num_classes,
        conf_thresh=conf_thresh,
        nms_thresh=nms_thresh,
        topk=topk,
        trainable=trainable,
        multi_hot=d_cfg['multi_hot'],
    )

    # Inference-only build: no freezing, no resume, no criterion.
    if not trainable:
        return model, None

    # Turning off requires_grad keeps these parameters fixed during training.
    if args.freeze_backbone_2d:
        print('Freeze 2D Backbone ...')
        for param in model.backbone_2d.parameters():
            param.requires_grad = False
    if args.freeze_backbone_3d:
        print('Freeze 3D Backbone ...')
        for param in model.backbone_3d.parameters():
            param.requires_grad = False

    if resume is not None:
        print('keep training: ', resume)
        # map_location='cpu' loads tensors onto the CPU regardless of the
        # device they were saved from; load_state_dict then copies them into
        # the model's own parameters.
        # NOTE(review): torch.load unpickles the file -- only resume from
        # trusted checkpoints.
        checkpoint = torch.load(resume, map_location='cpu')
        model.load_state_dict(checkpoint["model"])

    criterion = build_criterion(
        args, d_cfg['train_size'], num_classes, d_cfg['multi_hot'])

    return model, criterion
|
|