|
from .segmenter import CRIS, CISEN, Clip_hash_model, zh_clip, poi_clip, Clip_model, CISEN_vit, CISEN_rsvit, CISEN_new, CISEN_rsvit_classification, CISEN_lclip
|
|
from .segmenter import *
|
|
from loguru import logger
|
|
from transformers import AlignProcessor, AlignModel
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_CISEN(args, stage):
    """Build a ``CISEN_new`` model and the parameter groups for one training stage.

    Parameters are sorted into groups by the prefix of their name:

    * names starting with ``backbone`` (except those containing
      ``backbone.positional_embedding``) get ``requires_grad = False`` and
      are placed in no group;
    * names starting with ``ADP`` form the adapter group;
    * names starting with ``FPN`` or ``gap`` form the fusion group;
    * every other parameter forms the head group.

    Args:
        args: Configuration object handed to ``CISEN_new``. ``args.base_lr``
            is stored as the ``'initial_lr'`` of every returned group.
        stage: Training stage. ``'1st'`` selects the adapter and head groups;
            ``'2nd'``, ``'4th'`` and ``'5th'`` select the fusion group only.

    Returns:
        A ``(model, param_list)`` tuple, where ``param_list`` is a list of
        dicts with ``'params'`` and ``'initial_lr'`` keys.

    Raises:
        ValueError: If ``stage`` is not one of the supported stages.
    """
    fusion_stages = ('2nd', '4th', '5th')
    # Reject a bad stage before the model is constructed. Previously this case
    # only printed a message and then failed with an UnboundLocalError on the
    # final return.
    if stage != '1st' and stage not in fusion_stages:
        raise ValueError(
            "stage must be one of {}, got {!r}".format(
                ('1st',) + fusion_stages, stage))

    model = CISEN_new(args)
    head = []
    ADP = []
    fuse = []
    for k, v in model.named_parameters():
        if k.startswith('backbone') and 'backbone.positional_embedding' not in k:
            # Frozen: excluded from every parameter group.
            v.requires_grad = False
        elif k.startswith('ADP'):
            # NOTE: this prefix also matches names starting with 'ADP_t', so a
            # separate 'ADP_t' branch placed after it could never run and has
            # been dropped. Such parameters keep landing in this group.
            ADP.append(v)
        elif k.startswith(('FPN', 'gap')):
            fuse.append(v)
        else:
            head.append(v)

    if stage == '1st':
        param_list = [{
            'params': ADP,
            'initial_lr': args.base_lr
        }, {
            'params': head,
            'initial_lr': args.base_lr
        }]
    else:
        param_list = [{
            'params': fuse,
            'initial_lr': args.base_lr
        }]
    return model, param_list
|
|
|
|
def build_CISEN_lclip(args, stage):
    """Build a ``CISEN_lclip`` model and the parameter groups for one training stage.

    Parameters are sorted into groups by the prefix of their name:

    * names starting with ``backbone`` get ``requires_grad = False`` and are
      placed in no group;
    * names starting with ``ADP`` form the adapter group;
    * names starting with ``FPN`` or ``gap`` form the fusion group;
    * every other parameter forms the head group.

    Args:
        args: Configuration object handed to ``CISEN_lclip``. ``args.base_lr``
            is stored as the ``'initial_lr'`` of every returned group.
        stage: Training stage. ``'1st'`` selects the adapter and head groups;
            ``'2nd'``, ``'4th'`` and ``'5th'`` select the fusion group only.

    Returns:
        A ``(model, param_list)`` tuple, where ``param_list`` is a list of
        dicts with ``'params'`` and ``'initial_lr'`` keys.

    Raises:
        ValueError: If ``stage`` is not one of the supported stages.
    """
    fusion_stages = ('2nd', '4th', '5th')
    # Reject a bad stage before the model is constructed. Previously this case
    # only printed a message and then failed with an UnboundLocalError on the
    # final return.
    if stage != '1st' and stage not in fusion_stages:
        raise ValueError(
            "stage must be one of {}, got {!r}".format(
                ('1st',) + fusion_stages, stage))

    model = CISEN_lclip(args)
    head = []
    ADP = []
    fuse = []
    for k, v in model.named_parameters():
        if k.startswith('backbone'):
            # Frozen: excluded from every parameter group.
            v.requires_grad = False
        elif k.startswith('ADP'):
            # NOTE: this prefix also matches names starting with 'ADP_t', so a
            # separate 'ADP_t' branch placed after it could never run and has
            # been dropped. Such parameters keep landing in this group.
            ADP.append(v)
        elif k.startswith(('FPN', 'gap')):
            fuse.append(v)
        else:
            head.append(v)

    if stage == '1st':
        param_list = [{
            'params': ADP,
            'initial_lr': args.base_lr
        }, {
            'params': head,
            'initial_lr': args.base_lr
        }]
    else:
        param_list = [{
            'params': fuse,
            'initial_lr': args.base_lr
        }]
    return model, param_list
|
|
|
|
def build_CISEN_vit(args, stage):
    """Build a ``CISEN_rsvit`` model and the parameter groups for one training stage.

    Parameters are sorted into groups by the prefix of their name:

    * names starting with ``backbone`` get ``requires_grad = False`` and are
      placed in no group;
    * names starting with ``ADP`` also get ``requires_grad = False`` but are
      still collected into the adapter group;
    * names starting with ``FPN`` or ``ms_adaptor`` form the fusion group;
    * every other parameter forms the head group.

    Args:
        args: Configuration object handed to ``CISEN_rsvit``. ``args.base_lr``
            is stored as the ``'initial_lr'`` of every returned group.
        stage: Training stage. ``'1st'`` selects the adapter and head groups;
            ``'2nd'``, ``'4th'`` and ``'5th'`` select the fusion group only.

    Returns:
        A ``(model, param_list)`` tuple, where ``param_list`` is a list of
        dicts with ``'params'`` and ``'initial_lr'`` keys.

    Raises:
        ValueError: If ``stage`` is not one of the supported stages.
    """
    fusion_stages = ('2nd', '4th', '5th')
    # Reject a bad stage before the model is constructed. Previously this case
    # only printed a message and then failed with an UnboundLocalError on the
    # final return.
    if stage != '1st' and stage not in fusion_stages:
        raise ValueError(
            "stage must be one of {}, got {!r}".format(
                ('1st',) + fusion_stages, stage))

    model = CISEN_rsvit(args)
    head = []
    ADP = []
    fuse = []
    for k, v in model.named_parameters():
        if k.startswith('backbone'):
            # Frozen: excluded from every parameter group.
            v.requires_grad = False
        elif k.startswith('ADP'):
            # Frozen as well, yet kept in the stage-1 group exactly as before.
            v.requires_grad = False
            ADP.append(v)
        elif k.startswith(('FPN', 'ms_adaptor')):
            fuse.append(v)
        else:
            head.append(v)

    if stage == '1st':
        param_list = [{
            'params': ADP,
            'initial_lr': args.base_lr
        }, {
            'params': head,
            'initial_lr': args.base_lr
        }]
    else:
        param_list = [{
            'params': fuse,
            'initial_lr': args.base_lr
        }]
    return model, param_list
|
|
|
|
def build_CISEN_vit_classification(args, stage):
    """Instantiate and return a ``CISEN_rsvit_classification`` model.

    Args:
        args: Configuration object handed to the model constructor.
        stage: Accepted for signature compatibility; not read here.

    Returns:
        The constructed model only (no parameter groups are assembled).
    """
    return CISEN_rsvit_classification(args)
|
|
|
|
def build_segmenter(args):
    """Create a ``CRIS`` model together with its two parameter groups.

    Grouping is by parameter name:

    * names starting with ``backbone`` that do not contain
      ``positional_embedding`` form the backbone group;
    * names starting with ``Label_encoder`` that do not contain
      ``token_embedding`` get ``requires_grad = False`` and join no group;
    * all remaining parameters form the head group.

    Args:
        args: Configuration object handed to ``CRIS``; ``args.lr_multi`` and
            ``args.base_lr`` are read to fill the ``'initial_lr'`` entries.

    Returns:
        ``(model, param_list)`` with the backbone group first and the head
        group second.
    """
    net = CRIS(args)
    backbone_params, head_params = [], []
    for param_name, param in net.named_parameters():
        if param_name.startswith('backbone') and 'positional_embedding' not in param_name:
            backbone_params.append(param)
            continue
        if param_name.startswith('Label_encoder') and "token_embedding" not in param_name:
            # Frozen and deliberately left out of both groups.
            param.requires_grad = False
            continue
        head_params.append(param)

    logger.info('Backbone with decay={}, Head={}'.format(len(backbone_params), len(head_params)))

    # Only the backbone learning rate goes through float(); the head entry
    # stores args.base_lr as given.
    backbone_group = {
        'params': backbone_params,
        'initial_lr': args.lr_multi * float(args.base_lr),
    }
    head_group = {
        'params': head_params,
        'initial_lr': args.base_lr,
    }
    return net, [backbone_group, head_group]
|
|
|
|
def build_hash(args):
    """Create a ``Clip_hash_model`` together with its two parameter groups.

    Parameters whose name starts with ``backbone`` and does not contain
    ``positional_embedding`` form the backbone group; all others form the
    head group.

    Args:
        args: Configuration object handed to ``Clip_hash_model``;
            ``args.lr_multi`` and ``args.base_lr`` are read to fill the
            ``'initial_lr'`` entries.

    Returns:
        ``(model, param_list)`` with the backbone group first and the head
        group second.
    """
    net = Clip_hash_model(args)
    backbone_params, head_params = [], []
    for param_name, param in net.named_parameters():
        in_backbone = (param_name.startswith('backbone')
                       and 'positional_embedding' not in param_name)
        bucket = backbone_params if in_backbone else head_params
        bucket.append(param)

    logger.info('Backbone with decay={}, Head={}'.format(len(backbone_params), len(head_params)))

    backbone_group = {
        'params': backbone_params,
        'initial_lr': args.lr_multi * args.base_lr,
    }
    head_group = {
        'params': head_params,
        'initial_lr': args.base_lr,
    }
    return net, [backbone_group, head_group]
|
|
|
|
def build_zh_segmenter(args):
    """Create a ``zh_clip`` model together with its two parameter groups.

    Parameters whose name starts with ``backbone`` and does not contain
    ``positional_embedding`` form the backbone group; all others form the
    head group.

    Args:
        args: Configuration object handed to ``zh_clip``; ``args.lr_multi``
            and ``args.base_lr`` are read to fill the ``'initial_lr'``
            entries.

    Returns:
        ``(model, param_list)`` with the backbone group first and the head
        group second.
    """
    net = zh_clip(args)
    backbone_params, head_params = [], []
    for param_name, param in net.named_parameters():
        in_backbone = (param_name.startswith('backbone')
                       and 'positional_embedding' not in param_name)
        bucket = backbone_params if in_backbone else head_params
        bucket.append(param)

    logger.info('Backbone with decay={}, Head={}'.format(len(backbone_params), len(head_params)))

    backbone_group = {
        'params': backbone_params,
        'initial_lr': args.lr_multi * args.base_lr,
    }
    head_group = {
        'params': head_params,
        'initial_lr': args.base_lr,
    }
    return net, [backbone_group, head_group]
|
|
|
|
def build_poi_segmenter(args):
    """Create a ``poi_clip`` model together with its two parameter groups.

    Parameters whose name starts with ``backbone`` and does not contain
    ``positional_embedding`` form the backbone group; all others form the
    head group.

    Args:
        args: Configuration object handed to ``poi_clip``; ``args.lr_multi``
            and ``args.base_lr`` are read to fill the ``'initial_lr'``
            entries.

    Returns:
        ``(model, param_list)`` with the backbone group first and the head
        group second.
    """
    net = poi_clip(args)
    backbone_params, head_params = [], []
    for param_name, param in net.named_parameters():
        in_backbone = (param_name.startswith('backbone')
                       and 'positional_embedding' not in param_name)
        bucket = backbone_params if in_backbone else head_params
        bucket.append(param)

    logger.info('Backbone with decay={}, Head={}'.format(len(backbone_params), len(head_params)))

    backbone_group = {
        'params': backbone_params,
        'initial_lr': args.lr_multi * args.base_lr,
    }
    head_group = {
        'params': head_params,
        'initial_lr': args.base_lr,
    }
    return net, [backbone_group, head_group]
|
|
|
|
def build_clip(args):
    """Create a ``Clip_model`` together with its two parameter groups.

    Parameters whose name starts with ``backbone`` and does not contain
    ``positional_embedding`` form the backbone group; all others form the
    head group.

    Args:
        args: Configuration object handed to ``Clip_model``;
            ``args.lr_multi`` and ``args.base_lr`` are read to fill the
            ``'initial_lr'`` entries.

    Returns:
        ``(model, param_list)`` with the backbone group first and the head
        group second.
    """
    net = Clip_model(args)
    backbone_params, head_params = [], []
    for param_name, param in net.named_parameters():
        in_backbone = (param_name.startswith('backbone')
                       and 'positional_embedding' not in param_name)
        bucket = backbone_params if in_backbone else head_params
        bucket.append(param)

    logger.info('Backbone with decay={}, Head={}'.format(len(backbone_params), len(head_params)))

    backbone_group = {
        'params': backbone_params,
        'initial_lr': args.lr_multi * args.base_lr,
    }
    head_group = {
        'params': head_params,
        'initial_lr': args.base_lr,
    }
    return net, [backbone_group, head_group]