|
from .segmenter import CRIS |
|
|
|
from .segmenter_verbonly import CRIS_PosOnly |
|
from .segmenter_verbonly_fin import CRIS_PosOnly_rev |
|
from .segmenter_verbonly_hardneg import CRIS_VerbOnly |
|
from loguru import logger |
|
|
|
def _build_with_param_groups(model_cls, args):
    """Instantiate ``model_cls(args)`` and split its parameters into two groups.

    Parameters are routed by name:

    * names that start with ``'backbone'`` and do not contain
      ``'positional_embedding'`` go into the first (backbone) group, whose
      ``initial_lr`` is ``args.lr_multi * args.base_lr``;
    * every other parameter, including backbone positional embeddings, goes
      into the second (head) group, whose ``initial_lr`` is ``args.base_lr``.

    Args:
        model_cls: Model class, called as ``model_cls(args)``. The instance
            must provide ``named_parameters()``; presumably it is a
            ``torch.nn.Module`` (confirm).
        args: Config object. Must expose ``base_lr`` and ``lr_multi``, in
            addition to whatever ``model_cls`` itself reads.

    Returns:
        ``(model, param_list)``. ``param_list`` is a two-element list of dicts
        with keys ``'params'`` and ``'initial_lr'``: backbone group first,
        head group second.
    """
    model = model_cls(args)

    backbone = []
    head = []
    for name, param in model.named_parameters():
        # Backbone parameters whose name contains 'positional_embedding' are
        # routed to the head group (base LR), not the backbone group.
        if name.startswith('backbone') and 'positional_embedding' not in name:
            backbone.append(param)
        else:
            head.append(param)

    logger.info('Backbone with decay={}, Head={}'.format(len(backbone), len(head)))

    # Each group carries its own 'initial_lr'.
    # NOTE(review): presumably read by an LR scheduler (e.g. when resuming
    # with last_epoch != -1) -- confirm against the training loop.
    param_list = [{
        'params': backbone,
        'initial_lr': args.lr_multi * args.base_lr
    }, {
        'params': head,
        'initial_lr': args.base_lr
    }]

    return model, param_list


def build_segmenter_pos_rev(args):
    """Build a ``CRIS_PosOnly_rev`` model and its two optimizer param groups.

    Returns ``(model, param_list)``. Backbone parameters (excluding positional
    embeddings) get ``args.lr_multi * args.base_lr``; all others get
    ``args.base_lr``.
    """
    return _build_with_param_groups(CRIS_PosOnly_rev, args)


def build_segmenter_pos(args):
    """Build a ``CRIS_PosOnly`` model and its two optimizer param groups.

    Returns ``(model, param_list)``. Backbone parameters (excluding positional
    embeddings) get ``args.lr_multi * args.base_lr``; all others get
    ``args.base_lr``.
    """
    return _build_with_param_groups(CRIS_PosOnly, args)


def build_segmenter(args):
    """Build a ``CRIS_VerbOnly`` model and its two optimizer param groups.

    ``CRIS_VerbOnly`` is imported from ``.segmenter_verbonly_hardneg``; use
    ``build_segmenter_original`` for the plain ``CRIS`` model.

    Returns ``(model, param_list)``. Backbone parameters (excluding positional
    embeddings) get ``args.lr_multi * args.base_lr``; all others get
    ``args.base_lr``.
    """
    return _build_with_param_groups(CRIS_VerbOnly, args)


def build_segmenter_original(args):
    """Build a plain ``CRIS`` model (from ``.segmenter``) and its param groups.

    Returns ``(model, param_list)``. Backbone parameters (excluding positional
    embeddings) get ``args.lr_multi * args.base_lr``; all others get
    ``args.base_lr``.
    """
    return _build_with_param_groups(CRIS, args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|