ammarali32 committed on
Commit
561a469
1 Parent(s): c39ef6e

Upload hf_utils.py

Browse files
Files changed (1) hide show
  1. hf_utils.py +84 -0
hf_utils.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from transformers.models.deformable_detr.modeling_deformable_detr import DeformableDetrMLPPredictionHead
3
+ import torch.nn as nn
4
+ import torch
5
def PairDetr(model, num_queries: int, num_classes: int):
    """Replace the query embeddings and prediction heads of ``model`` in place.

    The hidden width and the number of prediction heads are read from the
    heads the model already carries (``model.class_embed``), so the new
    modules stay consistent with the model instead of assuming a 256-wide,
    6-head configuration.

    Args:
        model: Object exposing ``class_embed`` (an indexable, sized collection
            whose first item has ``in_features``), ``bbox_embed`` and an inner
            ``model`` attribute holding ``query_position_embeddings``.
            Presumably a Deformable DETR detection model -- confirm at the
            call site.
        num_queries: Number of rows in the new query embedding table.
        num_classes: Output width of the new classification head.

    Returns:
        The same ``model`` object, mutated in place.
    """
    hidden_dim = model.class_embed[0].in_features
    # Keep one head per existing entry rather than a fixed count of 6.
    num_heads = len(model.class_embed)

    # The embedding is twice the hidden width (512 for the 256-wide default,
    # matching the previously hard-coded value).
    model.model.query_position_embeddings = nn.Embedding(num_queries, 2 * hidden_dim)

    class_embed = nn.Linear(hidden_dim, num_classes)
    # output_dim=8: presumably two 4-value boxes per query -- TODO confirm.
    bbox_embed = DeformableDetrMLPPredictionHead(
        input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=8, num_layers=3
    )

    # The same module instance is repeated, so every level shares one set of
    # weights for each head.
    model.class_embed = nn.ModuleList([class_embed for _ in range(num_heads)])
    model.bbox_embed = nn.ModuleList([bbox_embed for _ in range(num_heads)])
    return model
15
+
16
def inverse_sigmoid(x, eps=1e-5):
    """Return the logit ``log(x / (1 - x))`` of ``x`` with clamping.

    ``x`` is first limited to ``[0, 1]``; numerator and denominator are then
    each floored at ``eps`` so the logarithm never sees zero or a division
    by zero.

    Args:
        x: Tensor of values to invert.
        eps: Lower bound applied to both ``x`` and ``1 - x``.

    Returns:
        Tensor with the same shape as ``x``.
    """
    unit = torch.clamp(x, min=0, max=1)
    numerator = torch.clamp(unit, min=eps)
    denominator = torch.clamp(1 - unit, min=eps)
    return torch.log(numerator / denominator)
21
+
22
def forward(model,
            pixel_values,
            pixel_mask=None,
            decoder_attention_mask=None,
            encoder_outputs=None,
            inputs_embeds=None,
            decoder_inputs_embeds=None,
            labels=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,) -> dict:
    """Run ``model.model`` and apply the per-level class and box heads.

    For every level along dimension 1 of the intermediate hidden states, the
    class head and the box head stored at that index in ``model.class_embed``
    / ``model.bbox_embed`` are applied. The box head output is offset by the
    inverse-sigmoid of the level's reference points and squashed with a
    sigmoid. Only the last level's predictions are returned.

    Args:
        model: Object exposing ``config.use_return_dict``, a callable inner
            ``model``, and indexable ``class_embed`` / ``bbox_embed`` heads.
        pixel_values: Forwarded positionally to ``model.model``.
        pixel_mask, decoder_attention_mask, encoder_outputs, inputs_embeds,
        decoder_inputs_embeds, output_attentions, output_hidden_states:
            Forwarded unchanged to ``model.model``.
        labels: Accepted for signature compatibility; not used here.
        return_dict: Whether ``model.model`` should return an attribute-style
            output. Falls back to ``model.config.use_return_dict`` when None.

    Returns:
        A dict with keys ``"logits"``, ``"pred_boxes"`` (both taken from the
        last level) and ``"init_reference_points"``.

    Raises:
        ValueError: If the reference points' last dimension is neither 4 nor 2.
    """
    return_dict = return_dict if return_dict is not None else model.config.use_return_dict

    outputs = model.model(
        pixel_values,
        pixel_mask=pixel_mask,
        decoder_attention_mask=decoder_attention_mask,
        encoder_outputs=encoder_outputs,
        inputs_embeds=inputs_embeds,
        decoder_inputs_embeds=decoder_inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    if return_dict:
        hidden_states = outputs.intermediate_hidden_states
        init_reference = outputs.init_reference_points
        inter_references = outputs.intermediate_reference_points
    else:
        # Positional layout used when the inner model returns a plain tuple.
        init_reference = outputs[0]
        hidden_states = outputs[2]
        inter_references = outputs[3]

    outputs_classes = []
    outputs_coords = []
    # Logit-space initial reference points; in the 2-value branch below they
    # are added to channels 4:6 at every level.
    init_reference_logits = inverse_sigmoid(init_reference)
    for level in range(hidden_states.shape[1]):
        # Level 0 starts from the initial points; later levels use the
        # reference points stored for the preceding level.
        reference = init_reference if level == 0 else inter_references[:, level - 1]
        reference = inverse_sigmoid(reference)

        level_hidden = hidden_states[:, level]
        outputs_class = model.class_embed[level](level_hidden)
        delta_bbox = model.bbox_embed[level](level_hidden)

        # The offsets are added in place on the head output.
        if reference.shape[-1] == 4:
            delta_bbox[..., :4] += reference
        elif reference.shape[-1] == 2:
            delta_bbox[..., :2] += reference
            delta_bbox[..., 4:6] += init_reference_logits
        else:
            raise ValueError(f"reference.shape[-1] should be 4 or 2, but got {reference.shape[-1]}")

        outputs_classes.append(outputs_class)
        outputs_coords.append(delta_bbox.sigmoid())

    # Stacking before indexing keeps every level's head in the autograd graph,
    # even though only the last level is returned.
    outputs_class = torch.stack(outputs_classes, dim=1)
    outputs_coord = torch.stack(outputs_coords, dim=1)

    return {
        "logits": outputs_class[:, -1],
        "pred_boxes": outputs_coord[:, -1],
        # Use the extracted value: a tuple output (return_dict=False) has no
        # ``init_reference_points`` attribute.
        "init_reference_points": init_reference,
    }