File size: 8,540 Bytes
ff07ed4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
# Modified from DINO (https://github.com/IDEA-Research/DINO)


import torch
from utils.misc import (NestedTensor, nested_tensor_from_tensor_list,
                       accuracy, get_world_size, interpolate,
                       is_dist_avail_and_initialized, inverse_sigmoid)
from utils import box_ops
import torch.nn.functional as F


def prepare_for_cdn(targets, dn_cfg, num_queries, hidden_dim, dn_enc):
    """
        Build contrastive-denoising (CDN) queries and the decoder attention mask.

        Modified from DINO: each ground-truth instance spawns one positive
        (small-noise) and one negative (large-noise) query per dn group, and
        groups are padded to the same length across the batch.

        :param targets: list of per-image dicts with 'boxes' (cxcywh, assumed
            normalized to [0,1] — the noised boxes are clamped to that range)
            and 'labels'; also 'poses' and 'betas' when
            dn_cfg['tgt_embed_type'] == 'params'
        :param dn_cfg: dict with 'dn_number', 'box_noise_scale',
            'tgt_noise_scale', 'tgt_embed_type' and (for the 'labels' type)
            'dn_labelbook_size'
        :param num_queries: number of matching (non-dn) queries
        :param hidden_dim: transformer hidden dim
        :param dn_enc: encoder mapping labels / params to tgt embeddings
        :return: (input_query_tgt, input_query_bbox, attn_mask, dn_meta)
        """
    device = targets[0]['boxes'].device

    dn_number = dn_cfg['dn_number']
    box_noise_scale = dn_cfg['box_noise_scale']
    tgt_noise_scale = dn_cfg['tgt_noise_scale']
    known = [torch.ones_like(t['labels']) for t in targets]
    batch_size = len(known)
    known_num = [sum(k) for k in known]  # GT count per image

    # Rescale the number of dn groups by the max GT count so the total number
    # of dn queries stays roughly constant; always keep at least one group.
    if int(max(known_num)) == 0:
        dn_number = 1
    else:
        if dn_number >= 100:
            dn_number = dn_number // (int(max(known_num) * 2))
        elif dn_number < 1:
            dn_number = 1
    if dn_number == 0:
        dn_number = 1

    boxes = torch.cat([t['boxes'] for t in targets])
    assert boxes.ndim == 2
    batch_idx = torch.cat([torch.full_like(t['labels'].long(), i) for i, t in enumerate(targets)])
    # Batch id of every dn query: all GTs repeated once per pos/neg half of
    # every group.
    known_bid = batch_idx.repeat(2 * dn_number, 1).view(-1)

    single_pad = int(max(known_num))
    pad_size = int(single_pad * 2 * dn_number)
    # Flat index layout per group g: [positives | negatives], each len(boxes) long.
    positive_idx = torch.arange(len(boxes), dtype=torch.long, device=device).unsqueeze(0).repeat(dn_number, 1)
    positive_idx += (torch.arange(dn_number, dtype=torch.long, device=device) * len(boxes) * 2).unsqueeze(1)
    positive_idx = positive_idx.flatten()
    negative_idx = positive_idx + len(boxes)

    # box queries
    known_bboxs = boxes.repeat(2 * dn_number, 1)
    known_bbox_expand = known_bboxs.clone()
    if box_noise_scale > 0:
        # Jitter corners in xyxy space: positives move within the box extent,
        # negatives get an extra unit step (rand_part + 1) pushing them outside.
        known_bbox_ = torch.zeros_like(known_bboxs)
        known_bbox_[:, :2] = known_bboxs[:, :2] - known_bboxs[:, 2:] / 2
        known_bbox_[:, 2:] = known_bboxs[:, :2] + known_bboxs[:, 2:] / 2

        diff = torch.zeros_like(known_bboxs)
        diff[:, :2] = known_bboxs[:, 2:] / 2
        diff[:, 2:] = known_bboxs[:, 2:] / 2

        rand_sign = torch.randint_like(known_bboxs, low=0, high=2, dtype=torch.float32) * 2.0 - 1.0
        rand_part = torch.rand_like(known_bboxs)
        rand_part[negative_idx] += 1.0
        rand_part *= rand_sign
        known_bbox_ = known_bbox_ + torch.mul(rand_part,
                                              diff).to(device=device) * box_noise_scale
        known_bbox_ = known_bbox_.clamp(min=0.0, max=1.0)
        # Convert back to cxcywh.
        known_bbox_expand[:, :2] = (known_bbox_[:, :2] + known_bbox_[:, 2:]) / 2
        known_bbox_expand[:, 2:] = known_bbox_[:, 2:] - known_bbox_[:, :2]
    input_bbox_embed = inverse_sigmoid(known_bbox_expand)

    # tgt queries
    if dn_cfg['tgt_embed_type'] == 'labels':
        labels = torch.cat([t['labels'] for t in targets])
        known_labels = labels.repeat(2 * dn_number, 1).view(-1)
        known_labels_expaned = known_labels.clone()
        if tgt_noise_scale > 0:
            # Flip a random subset of labels to a random class id.
            p = torch.rand_like(known_labels_expaned.float())
            chosen_indice = torch.nonzero(p < tgt_noise_scale).view(-1)
            new_label = torch.randint_like(chosen_indice, 0, dn_cfg['dn_labelbook_size'])  # randomly put a new one here
            known_labels_expaned.scatter_(0, chosen_indice, new_label)
        m = known_labels_expaned.long().to(device=device)
        input_tgt_embed = dn_enc(m)
    elif dn_cfg['tgt_embed_type'] == 'params':
        poses = torch.cat([t['poses'] for t in targets])
        betas = torch.cat([t['betas'] for t in targets])
        params = torch.cat([poses, betas], dim=-1)
        assert params.ndim == 2
        known_params = params.repeat(2 * dn_number, 1)
        known_params_expaned = known_params.clone()
        if tgt_noise_scale > 0:
            # Additive noise on the SMPL params; negatives get the larger offset.
            rand_sign = torch.randint_like(known_params, low=0, high=2, dtype=torch.float32) * 2.0 - 1.0
            rand_part = torch.rand_like(known_params)
            rand_part[negative_idx] += 1.0
            rand_part *= rand_sign
            known_params_expaned = known_params_expaned + rand_part * tgt_noise_scale
        m = known_params_expaned.to(device=device)
        input_tgt_embed = dn_enc(m)
    else:
        # Previously an unknown type fell through and crashed later with an
        # unbound-variable NameError; fail loudly at the decision point.
        raise ValueError("unsupported tgt_embed_type: {}".format(dn_cfg['tgt_embed_type']))

    # Scatter the dn queries into per-image padded slots; images with fewer
    # GTs than single_pad keep zeros in the unused slots.
    padding_tgt = torch.zeros((pad_size, hidden_dim), device=device)
    padding_bbox = torch.zeros((pad_size, 4), device=device)

    input_query_tgt = padding_tgt.repeat(batch_size, 1, 1)
    input_query_bbox = padding_bbox.repeat(batch_size, 1, 1)

    map_known_indice = torch.tensor([]).to(device=device)
    if len(known_num):
        # torch.arange keeps dtype int64 even for images with zero GTs
        # (torch.tensor(range(0)) would be float32 and break the cat).
        map_known_indice = torch.cat([torch.arange(int(num)) for num in known_num])  # [1,2, 1,2,3]
        map_known_indice = torch.cat([map_known_indice + single_pad * i for i in range(2 * dn_number)]).long()
    if len(known_bid):
        input_query_tgt[(known_bid.long(), map_known_indice)] = input_tgt_embed
        input_query_bbox[(known_bid.long(), map_known_indice)] = input_bbox_embed

    # prepare attn_mask: matching queries cannot attend to dn queries, and
    # each dn group can only attend to itself.
    tgt_size = pad_size + num_queries
    attn_mask = torch.zeros((tgt_size, tgt_size), dtype=torch.bool, device=device)
    # match query cannot see the reconstruct
    attn_mask[pad_size:, :pad_size] = True
    # reconstruct groups cannot see each other
    for i in range(dn_number):
        start, end = single_pad * 2 * i, single_pad * 2 * (i + 1)
        attn_mask[start:end, end:pad_size] = True  # groups after i
        attn_mask[start:end, :start] = True        # groups before i

    dn_meta = {
        'pad_size': pad_size,
        'num_dn_group': dn_number,
    }

    return input_query_tgt, input_query_bbox, attn_mask, dn_meta


def dn_post_process(pred_poses, pred_betas,
                    pred_boxes, pred_confs,
                    pred_j3ds, pred_j2ds, pred_depths, 
                    pred_verts, pred_transl,
                    dn_meta, aux_loss, _set_aux_loss):
    """
        Separate denoising (dn) outputs from matching-query outputs.

        The first ``pad_size`` entries along the query axis belong to the dn
        groups; their last-layer predictions (plus aux outputs when
        ``aux_loss`` is set) are stored under ``dn_meta['output_known']``,
        and the remaining matching-query predictions are returned.
    """
    pad_size = dn_meta['pad_size']
    assert pad_size > 0

    def split_layers(t):
        # Decoder-stacked tensor (layers, batch, queries, ...): split on dim 2.
        return t[:, :, :pad_size], t[:, :, pad_size:]

    def split_flat(t):
        # Per-batch tensor (batch, queries, ...): split on dim 1.
        return t[:, :pad_size], t[:, pad_size:]

    known_poses, pred_poses = split_layers(pred_poses)
    known_betas, pred_betas = split_layers(pred_betas)
    known_boxes, pred_boxes = split_layers(pred_boxes)
    known_confs, pred_confs = split_layers(pred_confs)
    known_j3ds, pred_j3ds = split_layers(pred_j3ds)
    known_j2ds, pred_j2ds = split_layers(pred_j2ds)
    known_depths, pred_depths = split_layers(pred_depths)

    known_verts, pred_verts = split_flat(pred_verts)
    known_transl, pred_transl = split_flat(pred_transl)

    # Keep only the last decoder layer for the stashed dn predictions.
    out = {'pred_poses': known_poses[-1], 'pred_betas': known_betas[-1],
           'pred_boxes': known_boxes[-1], 'pred_confs': known_confs[-1],
           'pred_j3ds': known_j3ds[-1], 'pred_j2ds': known_j2ds[-1],
           'pred_depths': known_depths[-1]}

    if aux_loss:
        out['aux_outputs'] = _set_aux_loss(known_poses, known_betas,
                                           known_boxes, known_confs,
                                           known_j3ds, known_j2ds, known_depths)

    dn_meta['output_known'] = out

    return (pred_poses, pred_betas,
            pred_boxes, pred_confs,
            pred_j3ds, pred_j2ds,
            pred_depths, pred_verts,
            pred_transl)