# OMG_Seg/app/configs/m2_convl.py
from torch.nn import GroupNorm, ReLU
from mmdet.models import MSDeformAttnPixelDecoder, CrossEntropyLoss, DiceLoss, FocalLoss
from mmdet.models.task_modules.assigners import HungarianAssigner, ClassificationCost, CrossEntropyLossCost, DiceCost
from mmdet.models.task_modules.samplers import MaskPseudoSampler
from seg.models.detectors import Mask2formerVideo
from seg.models.fusion_head import OMGFusionHead
from seg.models.heads import Mask2FormerVideoHead
from seg.models.backbones import OpenCLIPBackbone
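
# COCO panoptic taxonomy: 80 "thing" + 53 "stuff" classes.
# ov_model_name / ov_datasets_name form the lookup key for the pre-computed
# open-vocabulary classifier (CLIP text embeddings) used by the head below.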
num_things_classes = 80
num_stuff_classes = 53
ov_model_name = 'convnext_large_d_320'
ov_datasets_name = 'CocoPanopticOVDataset'
model = dict(
    type=Mask2formerVideo,
    data_preprocessor=None,  # to fill
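    # Frozen OpenCLIP ConvNeXt-Large backbone; fix=True keeps the CLIP weights
    # unchanged so the visual features stay aligned with the text embeddings.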
    backbone=dict(
        type=OpenCLIPBackbone,
        model_name='convnext_large_d_320',
        fix=True,
        init_cfg=dict(
            type='clip_pretrain',
            checkpoint='laion2b_s29b_b131k_ft_soup'
        )
    ),
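    # Mask2Former-style video head. sphere_cls=True swaps the learned linear
    # classifier for similarity scoring against the OV class embeddings.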
    panoptic_head=dict(
        init_cfg=dict(
            type='Pretrained',
            checkpoint='./models/omg_seg_convl.pth',
            prefix='panoptic_head.'
        ),
        type=Mask2FormerVideoHead,
        sphere_cls=True,
        ov_classifier_name=f'{ov_model_name}_{ov_datasets_name}',
        logit=None,
        enable_box_query=True,
        in_channels=[192, 384, 768, 1536],  # pass to pixel_decoder inside
        strides=[4, 8, 16, 32],
        feat_channels=256,
        out_channels=256,
        num_things_classes=num_things_classes,
        num_stuff_classes=num_stuff_classes,
        num_queries=300,
        num_transformer_feat_level=3,
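        # Multi-scale deformable-attention pixel decoder (standard Mask2Former):
        # fuses the four backbone stages into num_outs=3 feature levels.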
        pixel_decoder=dict(
            type=MSDeformAttnPixelDecoder,
            num_outs=3,
            norm_cfg=dict(type=GroupNorm, num_groups=32),
            act_cfg=dict(type=ReLU),
            encoder=dict(  # DeformableDetrTransformerEncoder
                num_layers=6,
                layer_cfg=dict(  # DeformableDetrTransformerEncoderLayer
                    self_attn_cfg=dict(  # MultiScaleDeformableAttention
                        embed_dims=256,
                        num_heads=8,
                        num_levels=3,
                        num_points=4,
                        dropout=0.0,
                        batch_first=True),
                    ffn_cfg=dict(
                        embed_dims=256,
                        feedforward_channels=1024,
                        num_fcs=2,
                        ffn_drop=0.0,
                        act_cfg=dict(type=ReLU, inplace=True)))),
            positional_encoding=dict(num_feats=128, normalize=True)),
        enforce_decoder_input_project=False,
        positional_encoding=dict(num_feats=128, normalize=True),
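        # 9-layer masked-attention decoder, i.e. three rounds over the three
        # feature levels, refining the 300 object queries.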
        transformer_decoder=dict(  # Mask2FormerTransformerDecoder
            return_intermediate=True,
            num_layers=9,
            layer_cfg=dict(  # Mask2FormerTransformerDecoderLayer
                self_attn_cfg=dict(  # MultiheadAttention
                    embed_dims=256,
                    num_heads=8,
                    dropout=0.0,
                    batch_first=True),
                cross_attn_cfg=dict(  # MultiheadAttention
                    embed_dims=256,
                    num_heads=8,
                    dropout=0.0,
                    batch_first=True),
                ffn_cfg=dict(
                    embed_dims=256,
                    feedforward_channels=2048,
                    num_fcs=2,
                    ffn_drop=0.0,
                    act_cfg=dict(type=ReLU, inplace=True))),
            init_cfg=None),
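        # Loss weights follow Mask2Former (cls 2.0, mask BCE 5.0, dice 5.0);
        # loss_iou is an extra focal term, presumably for the interactive
        # box/point-query branch enabled via enable_box_query above.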
        loss_cls=dict(
            type=CrossEntropyLoss,
            use_sigmoid=False,
            loss_weight=2.0,
            reduction='mean',
            class_weight=None  # [1.0] * num_classes + [0.1]
        ),
        loss_mask=dict(
            type=CrossEntropyLoss,
            use_sigmoid=True,
            reduction='mean',
            loss_weight=5.0),
        loss_dice=dict(
            type=DiceLoss,
            use_sigmoid=True,
            activate=True,
            reduction='mean',
            naive_dice=True,
            eps=1.0,
            loss_weight=5.0),
        loss_iou=dict(
            type=FocalLoss,
            use_sigmoid=True,
            loss_weight=2.0,
            reduction='mean'
        )
    ),
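    # The fusion head only post-processes: it merges per-query masks and labels
    # into the final results, so it carries no loss (loss_panoptic=None).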
    panoptic_fusion_head=dict(
        type=OMGFusionHead,
        num_things_classes=num_things_classes,
        num_stuff_classes=num_stuff_classes,
        loss_panoptic=None,
        init_cfg=None
    ),
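    # Mask losses and matching costs are evaluated on 12544 (= 112 * 112)
    # importance-sampled points per mask (PointRend-style), not full masks.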
    train_cfg=dict(
        num_points=12544,
        oversample_ratio=3.0,
        importance_sample_ratio=0.75,
        assigner=dict(
            type=HungarianAssigner,
            match_costs=[
                dict(type=ClassificationCost, weight=2.0),
                dict(type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True),
                dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0)
            ]),
        sampler=dict(type=MaskPseudoSampler)),
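    # Inference emits panoptic and instance results (no standalone semantic
    # map); instance output keeps at most 100 detections per image.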
    test_cfg=dict(
        panoptic_on=True,
        semantic_on=False,
        instance_on=True,
        # max_per_image applies to instance segmentation only.
        max_per_image=100,
        iou_thr=0.8,
        # In Mask2Former's panoptic post-processing, mask regions whose
        # score is below 0.5 are filtered out.
        filter_low_score=True,
        object_mask_thr=0.,
    ),
    init_cfg=None
)
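
# A minimal loading sketch (an assumption, not part of this config: it relies
# on mmengine's lazy-import config support and on the checkpoints referenced
# above being available locally; the data_preprocessor fill is hypothetical):
#
#   from mmengine.config import Config
#   from mmdet.registry import MODELS
#
#   cfg = Config.fromfile('app/configs/m2_convl.py')
#   cfg.model.data_preprocessor = dict(type='DetDataPreprocessor')  # hypothetical fill
#   model = MODELS.build(cfg.model)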