# Lazy-config description of a DINO detector built on a ResNet-50 backbone.
#
# Nothing is instantiated in this file: every `L(Cls)(...)` call only records
# the target class and its keyword arguments.  The module exposes `model`, and
# the statements after the `model` declaration expand the criterion's loss
# weights so that auxiliary outputs get their own weight entries.
import copy
import torch.nn as nn

from detectron2.modeling.backbone import ResNet, BasicStem
from detectron2.layers import ShapeSpec
from detectron2.config import LazyCall as L

from detrex.modeling.matcher import HungarianMatcher
from detrex.modeling.neck import ChannelMapper
from detrex.layers import PositionEmbeddingSine

from projects.dino.modeling import (
    DINO,
    DINOTransformerEncoder,
    DINOTransformerDecoder,
    DINOTransformer,
    DINOCriterion,
)

model = L(DINO)(
    # Backbone: depth-50 ResNet with frozen batch-norm, exposing the res3/res4/res5
    # feature maps to the neck.
    backbone=L(ResNet)(
        stem=L(BasicStem)(in_channels=3, out_channels=64, norm="FrozenBN"),
        stages=L(ResNet.make_default_stages)(
            depth=50,
            stride_in_1x1=False,
            norm="FrozenBN",
        ),
        out_features=["res3", "res4", "res5"],
        # NOTE(review): presumably freezes only the stem -- confirm against the
        # detectron2 ResNet documentation.
        freeze_at=1,
    ),
    # num_pos_feats=128 is half of embed_dim=256.
    # NOTE(review): presumably x and y encodings are concatenated -- confirm.
    position_embedding=L(PositionEmbeddingSine)(
        num_pos_feats=128,
        temperature=10000,
        normalize=True,
        offset=-0.5,
    ),
    # Neck: maps the three backbone outputs to 256 channels.  It is asked for
    # four output maps (num_outs=4), matching num_feature_levels=4 on the
    # transformer below.
    neck=L(ChannelMapper)(
        # Channel counts declared for each backbone feature consumed here.
        input_shapes={
            "res3": ShapeSpec(channels=512),
            "res4": ShapeSpec(channels=1024),
            "res5": ShapeSpec(channels=2048),
        },
        in_features=["res3", "res4", "res5"],
        out_channels=256,
        num_outs=4,
        kernel_size=1,
        # num_channels must stay equal to out_channels above.
        norm_layer=L(nn.GroupNorm)(num_groups=32, num_channels=256),
    ),
    transformer=L(DINOTransformer)(
        encoder=L(DINOTransformerEncoder)(
            embed_dim=256,
            num_heads=8,
            feedforward_dim=2048,
            attn_dropout=0.0,
            ffn_dropout=0.0,
            num_layers=6,
            post_norm=False,
            # Relative interpolation: `..` points at the enclosing transformer
            # node, so this follows `num_feature_levels=4` set there.
            num_feature_levels="${..num_feature_levels}",
            use_checkpoint=False
        ),
        decoder=L(DINOTransformerDecoder)(
            embed_dim=256,
            num_heads=8,
            feedforward_dim=2048,
            attn_dropout=0.0,
            ffn_dropout=0.0,
            # Also read by the aux-loss expansion at the bottom of this file.
            num_layers=6,
            return_intermediate=True,
            # Same relative interpolation as in the encoder.
            num_feature_levels="${..num_feature_levels}",
            use_checkpoint=False,
        ),
        # Single source of truth for the encoder/decoder interpolations above.
        num_feature_levels=4,
        # Follows the top-level `num_queries` of the DINO node (900).
        two_stage_num_proposals="${..num_queries}",
    ),
    embed_dim=256,
    num_classes=80,
    num_queries=900,
    # When true, the statements after this declaration add per-layer and
    # encoder entries to the criterion's weight_dict.
    aux_loss=True,
    criterion=L(DINOCriterion)(
        # Follows the top-level `num_classes` of the DINO node (80).
        num_classes="${..num_classes}",
        # Matching costs; note the class cost (2.0) differs from the class loss
        # weight (1) below, while bbox/giou use the same numbers in both places.
        matcher=L(HungarianMatcher)(
            cost_class=2.0,
            cost_bbox=5.0,
            cost_giou=2.0,
            cost_class_type="focal_loss_cost",
            alpha=0.25,
            gamma=2.0,
        ),
        # Base loss weights.  The `_dn` entries carry the same values as their
        # plain counterparts.
        # NOTE(review): `_dn` presumably denotes the denoising branch
        # configured by dn_number / label_noise_ratio / box_noise_scale -- confirm.
        weight_dict={
            "loss_class": 1,
            "loss_bbox": 5.0,
            "loss_giou": 2.0,
            "loss_class_dn": 1,
            "loss_bbox_dn": 5.0,
            "loss_giou_dn": 2.0,
        },
        loss_class_type="focal_loss",
        alpha=0.25,
        gamma=2.0,
        two_stage_binary_cls=False,
    ),
    dn_number=100,
    label_noise_ratio=0.5,
    box_noise_scale=1.0,
    # NOTE(review): presumably the usual ImageNet RGB mean/std, in the channel
    # order named by input_format -- confirm.
    pixel_mean=[123.675, 116.280, 103.530],
    pixel_std=[58.395, 57.120, 57.375],
    vis_period=0,
    input_format="RGB",
    device="cuda",
)

# Expand the criterion's weight_dict for auxiliary losses.
#
# A deep copy of the six base weights is taken first so that the suffixed keys
# are derived from the original entries only, not from keys added during the
# expansion.  For every base key `k` this adds:
#   * `k_enc`  -- one extra entry with the "_enc" suffix;
#   * `k_{i}`  -- one entry per i in range(decoder.num_layers - 1), i.e.
#                 `_0` .. `_4` with the 6 decoder layers configured above.
# Values are copied unchanged from the base entry.
#
# NOTE(review): base_weight_dict / weight_dict / aux_weight_dict are
# module-level names; a loader that collects module globals may see them --
# confirm before renaming or removing.
base_weight_dict = copy.deepcopy(model.criterion.weight_dict)
if model.aux_loss:
    # Alias of the config node itself, so update() below mutates it in place.
    weight_dict = model.criterion.weight_dict
    aux_weight_dict = {}
    aux_weight_dict.update({k + "_enc": v for k, v in base_weight_dict.items()})
    for i in range(model.transformer.decoder.num_layers - 1):
        aux_weight_dict.update({k + f"_{i}": v for k, v in base_weight_dict.items()})
    weight_dict.update(aux_weight_dict)
    model.criterion.weight_dict = weight_dict