Removed extraneous print statements
models/GroundingDINO/transformer.py
CHANGED
@@ -237,7 +237,6 @@ class Transformer(nn.Module):
 
         """
         # prepare input for encoder
-        print("inside transformer forward")
         src_flatten = []
         mask_flatten = []
         lvl_pos_embed_flatten = []
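For context, the "# prepare input for encoder" block that follows these lines flattens the per-level feature maps into one token sequence before the encoder runs. A minimal standalone sketch of that flattening pattern, with made-up shapes and variable names (not taken from this file):

import torch

# Two dummy feature levels, each (bs, c, h, w); the shapes are illustrative only.
srcs = [torch.randn(2, 256, 32, 32), torch.randn(2, 256, 16, 16)]

# Each level becomes (bs, h*w, c); levels are concatenated along the token dim.
src_flatten = torch.cat([s.flatten(2).transpose(1, 2) for s in srcs], dim=1)
assert src_flatten.shape == (2, 32 * 32 + 16 * 16, 256)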
@@ -274,7 +273,6 @@ class Transformer(nn.Module):
         #########################################################
         # Begin Encoder
         #########################################################
-        print("begin transformer encoder")
         memory, memory_text = self.encoder(
             src_flatten,
             pos=lvl_pos_embed_flatten,
@@ -288,7 +286,6 @@ class Transformer(nn.Module):
             position_ids=text_dict["position_ids"],
             text_self_attention_masks=text_dict["text_self_attention_masks"],
         )
-        print("got encoder output")
         #########################################################
         # End Encoder
         # - memory: bs, \sum{hw}, c
@@ -303,11 +300,9 @@ class Transformer(nn.Module):
         # import ipdb; ipdb.set_trace()
 
         if self.two_stage_type == "standard":  # use the encoder output as proposals
-            print("standard two stage")
             output_memory, output_proposals = gen_encoder_output_proposals(
                 memory, mask_flatten, spatial_shapes
             )
-            print("got output proposals")
             output_memory = self.enc_output_norm(self.enc_output(output_memory))
 
             if text_dict is not None:
@@ -324,29 +319,22 @@ class Transformer(nn.Module):
             topk = self.num_queries
 
             topk_proposals = torch.topk(topk_logits, topk, dim=1)[1]  # bs, nq
-            print("got topk proposals")
             # gather boxes
-            print("gather 1")
             refpoint_embed_undetach = torch.gather(
                 enc_outputs_coord_unselected,
                 1,
                 topk_proposals.unsqueeze(-1).repeat(1, 1, 4),
             )  # unsigmoid
-            print("gathered 1")
             refpoint_embed_ = refpoint_embed_undetach.detach()
-            print("gather 2")
             init_box_proposal = torch.gather(
                 output_proposals, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
             ).sigmoid()  # sigmoid
-            print("gathered 2")
-            print("gather 3")
             # gather tgt
             tgt_undetach = torch.gather(
                 output_memory,
                 1,
                 topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model),
             )
-            print("gathered 3")
             if self.embed_init_tgt:
                 tgt_ = (
                     self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
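The prints removed in this hunk sat inside the two-stage query selection: torch.topk picks the best num_queries proposals per image, then torch.gather pulls out their boxes and features. A self-contained sketch of that selection pattern with made-up sizes (this is not the module's own code):

import torch

bs, sum_hw, d_model, num_queries = 2, 100, 256, 10           # illustrative sizes
topk_logits = torch.randn(bs, sum_hw)                         # per-proposal scores
enc_outputs_coord_unselected = torch.randn(bs, sum_hw, 4)     # unsigmoided boxes
output_memory = torch.randn(bs, sum_hw, d_model)              # encoder features

# Indices of the best num_queries proposals per image: (bs, nq)
topk_proposals = torch.topk(topk_logits, num_queries, dim=1)[1]

# Gather the matching boxes and features; the index tensor must have the same
# rank as the input, hence the unsqueeze(-1).repeat(...) over the last dim.
refpoint_embed_undetach = torch.gather(
    enc_outputs_coord_unselected, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
)
tgt_undetach = torch.gather(
    output_memory, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, d_model)
)
assert refpoint_embed_undetach.shape == (bs, num_queries, 4)
assert tgt_undetach.shape == (bs, num_queries, d_model)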
@@ -401,7 +389,6 @@ class Transformer(nn.Module):
         # memory torch.Size([2, 16320, 256])
 
         # import pdb;pdb.set_trace()
-        print("going through decoder")
         hs, references = self.decoder(
             tgt=tgt.transpose(0, 1),
             memory=memory.transpose(0, 1),
@@ -416,7 +403,6 @@ class Transformer(nn.Module):
             text_attention_mask=~text_dict["text_token_mask"],
             # we ~ the mask . False means use the token; True means pad the token
         )
-        print("got decoder output")
         #########################################################
         # End Decoder
         # hs: n_dec, bs, nq, d_model
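The ~text_dict["text_token_mask"] in the decoder call above flips the mask convention, as the inline comment notes: the token mask uses True for real tokens, while an attention padding mask expects True for positions to ignore. A tiny sketch of that flip with a made-up mask:

import torch

text_token_mask = torch.tensor([[True, True, False],
                                [True, False, False]])  # True = real token
text_attention_mask = ~text_token_mask                   # True = padded position
assert text_attention_mask.tolist() == [[False, False, True],
                                        [False, True, True]]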
@@ -560,7 +546,6 @@ class TransformerEncoder(nn.Module):
         """
 
         output = src
-        print("inside transformer encoder")
         # preparation and reshape
         if self.num_layers > 0:
             reference_points = self.get_reference_points(
@@ -591,10 +576,8 @@ class TransformerEncoder(nn.Module):
             # if output.isnan().any() or memory_text.isnan().any():
             #     if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
             #         import ipdb; ipdb.set_trace()
-            print("layer_id: " + str(layer_id))
             if self.fusion_layers:
                 if self.use_checkpoint:
-                    print("using checkpoint")
                     output, memory_text = checkpoint.checkpoint(
                         self.fusion_layers[layer_id],
                         output,
@@ -602,30 +585,24 @@ class TransformerEncoder(nn.Module):
                         key_padding_mask,
                         text_attention_mask,
                     )
-                    print("got checkpoint output")
                 else:
-                    print("not using checkpoint")
                     output, memory_text = self.fusion_layers[layer_id](
                         v=output,
                         l=memory_text,
                         attention_mask_v=key_padding_mask,
                         attention_mask_l=text_attention_mask,
                     )
-                    print("got fusion output")
 
             if self.text_layers:
-                print("getting text layers")
                 memory_text = self.text_layers[layer_id](
                     src=memory_text.transpose(0, 1),
                     src_mask=~text_self_attention_masks,  # note we use ~ for mask here
                     src_key_padding_mask=text_attention_mask,
                     pos=(pos_text.transpose(0, 1) if pos_text is not None else None),
                 ).transpose(0, 1)
-                print("got text output")
 
             # main process
             if self.use_transformer_ckpt:
-                print("use transformer ckpt")
                 output = checkpoint.checkpoint(
                     layer,
                     output,
@@ -635,9 +612,7 @@ class TransformerEncoder(nn.Module):
                     level_start_index,
                     key_padding_mask,
                 )
-                print("got output")
             else:
-                print("not use transformer ckpt")
                 output = layer(
                     src=output,
                     pos=pos,
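The checkpoint.checkpoint(...) branches above use PyTorch gradient checkpointing (torch.utils.checkpoint): activations inside the wrapped layer are not stored during the forward pass and are recomputed on backward to save memory. A minimal sketch of the same pattern on a stand-in layer, not this module's code:

import torch
from torch.utils import checkpoint

layer = torch.nn.Linear(256, 256)            # stand-in for an encoder layer
x = torch.randn(4, 256, requires_grad=True)

use_ckpt = True  # stand-in for self.use_transformer_ckpt
if use_ckpt:
    # Activations inside `layer` are recomputed on backward instead of stored.
    output = checkpoint.checkpoint(layer, x)
else:
    output = layer(x)
output.sum().backward()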
@@ -646,7 +621,6 @@ class TransformerEncoder(nn.Module):
                     level_start_index=level_start_index,
                     key_padding_mask=key_padding_mask,
                 )
-                print("got output")
 
         return output, memory_text
 
@@ -847,7 +821,6 @@ class DeformableTransformerEncoderLayer(nn.Module):
     ):
         # self attention
         # import ipdb; ipdb.set_trace()
-        print("deformable self-attention")
         src2 = self.self_attn(
             query=self.with_pos_embed(src, pos),
             reference_points=reference_points,
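If this kind of tracing is needed again later, a common alternative to ad-hoc print calls is the standard logging module, which keeps the messages in the code but silent unless explicitly enabled. A sketch of that approach (not part of this commit; the logger name is hypothetical):

import logging

logger = logging.getLogger("groundingdino.transformer")  # hypothetical logger name

def forward_step():
    # Emitted only when DEBUG logging is enabled for this logger.
    logger.debug("begin transformer encoder")

# e.g. enable while debugging:
logging.basicConfig(level=logging.DEBUG)
forward_step()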