cranky-coder08 committed on
Commit
4dfe3b3
·
verified ·
1 Parent(s): 9207dd1

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full list.
Files changed (50)
  1. .gitattributes +6 -0
  2. phivenv/Lib/site-packages/torch/lib/XNNPACK.lib +3 -0
  3. phivenv/Lib/site-packages/torch/lib/torch_cpu.lib +3 -0
  4. phivenv/Lib/site-packages/torch/lib/torch_python.dll +3 -0
  5. phivenv/Lib/site-packages/torch/lib/torch_python.lib +3 -0
  6. phivenv/Lib/site-packages/torch/lib/uv.dll +3 -0
  7. phivenv/Lib/site-packages/torch/linalg/__pycache__/__init__.cpython-39.pyc +3 -0
  8. phivenv/Lib/site-packages/transformers/models/d_fine/__init__.py +29 -0
  9. phivenv/Lib/site-packages/transformers/models/d_fine/__pycache__/__init__.cpython-39.pyc +0 -0
  10. phivenv/Lib/site-packages/transformers/models/d_fine/__pycache__/configuration_d_fine.cpython-39.pyc +0 -0
  11. phivenv/Lib/site-packages/transformers/models/d_fine/__pycache__/modeling_d_fine.cpython-39.pyc +0 -0
  12. phivenv/Lib/site-packages/transformers/models/d_fine/__pycache__/modular_d_fine.cpython-39.pyc +0 -0
  13. phivenv/Lib/site-packages/transformers/models/d_fine/configuration_d_fine.py +433 -0
  14. phivenv/Lib/site-packages/transformers/models/d_fine/modeling_d_fine.py +0 -0
  15. phivenv/Lib/site-packages/transformers/models/d_fine/modular_d_fine.py +1221 -0
  16. phivenv/Lib/site-packages/transformers/models/depth_pro/__init__.py +29 -0
  17. phivenv/Lib/site-packages/transformers/models/depth_pro/__pycache__/__init__.cpython-39.pyc +0 -0
  18. phivenv/Lib/site-packages/transformers/models/depth_pro/__pycache__/configuration_depth_pro.cpython-39.pyc +0 -0
  19. phivenv/Lib/site-packages/transformers/models/depth_pro/__pycache__/image_processing_depth_pro.cpython-39.pyc +0 -0
  20. phivenv/Lib/site-packages/transformers/models/depth_pro/__pycache__/image_processing_depth_pro_fast.cpython-39.pyc +0 -0
  21. phivenv/Lib/site-packages/transformers/models/depth_pro/__pycache__/modeling_depth_pro.cpython-39.pyc +0 -0
  22. phivenv/Lib/site-packages/transformers/models/depth_pro/configuration_depth_pro.py +205 -0
  23. phivenv/Lib/site-packages/transformers/models/depth_pro/image_processing_depth_pro.py +389 -0
  24. phivenv/Lib/site-packages/transformers/models/depth_pro/image_processing_depth_pro_fast.py +177 -0
  25. phivenv/Lib/site-packages/transformers/models/depth_pro/modeling_depth_pro.py +1132 -0
  26. phivenv/Lib/site-packages/transformers/models/detr/__init__.py +31 -0
  27. phivenv/Lib/site-packages/transformers/models/detr/__pycache__/__init__.cpython-39.pyc +0 -0
  28. phivenv/Lib/site-packages/transformers/models/detr/__pycache__/configuration_detr.cpython-39.pyc +0 -0
  29. phivenv/Lib/site-packages/transformers/models/detr/__pycache__/feature_extraction_detr.cpython-39.pyc +0 -0
  30. phivenv/Lib/site-packages/transformers/models/detr/__pycache__/image_processing_detr.cpython-39.pyc +0 -0
  31. phivenv/Lib/site-packages/transformers/models/detr/__pycache__/image_processing_detr_fast.cpython-39.pyc +0 -0
  32. phivenv/Lib/site-packages/transformers/models/detr/__pycache__/modeling_detr.cpython-39.pyc +0 -0
  33. phivenv/Lib/site-packages/transformers/models/detr/configuration_detr.py +297 -0
  34. phivenv/Lib/site-packages/transformers/models/detr/feature_extraction_detr.py +48 -0
  35. phivenv/Lib/site-packages/transformers/models/detr/image_processing_detr.py +2049 -0
  36. phivenv/Lib/site-packages/transformers/models/detr/image_processing_detr_fast.py +1291 -0
  37. phivenv/Lib/site-packages/transformers/models/detr/modeling_detr.py +1693 -0
  38. phivenv/Lib/site-packages/transformers/models/dia/__init__.py +31 -0
  39. phivenv/Lib/site-packages/transformers/models/dia/__pycache__/__init__.cpython-39.pyc +0 -0
  40. phivenv/Lib/site-packages/transformers/models/dia/__pycache__/configuration_dia.cpython-39.pyc +0 -0
  41. phivenv/Lib/site-packages/transformers/models/dia/__pycache__/feature_extraction_dia.cpython-39.pyc +0 -0
  42. phivenv/Lib/site-packages/transformers/models/dia/__pycache__/generation_dia.cpython-39.pyc +0 -0
  43. phivenv/Lib/site-packages/transformers/models/dia/__pycache__/modeling_dia.cpython-39.pyc +0 -0
  44. phivenv/Lib/site-packages/transformers/models/dia/__pycache__/modular_dia.cpython-39.pyc +0 -0
  45. phivenv/Lib/site-packages/transformers/models/dia/__pycache__/processing_dia.cpython-39.pyc +0 -0
  46. phivenv/Lib/site-packages/transformers/models/dia/__pycache__/tokenization_dia.cpython-39.pyc +0 -0
  47. phivenv/Lib/site-packages/transformers/models/dia/configuration_dia.py +376 -0
  48. phivenv/Lib/site-packages/transformers/models/dia/feature_extraction_dia.py +183 -0
  49. phivenv/Lib/site-packages/transformers/models/dia/generation_dia.py +464 -0
  50. phivenv/Lib/site-packages/transformers/models/dia/modeling_dia.py +958 -0
.gitattributes CHANGED
@@ -122,3 +122,9 @@ phivenv/Lib/site-packages/torch/lib/libprotoc.lib filter=lfs diff=lfs merge=lfs
  phivenv/Lib/site-packages/torch/lib/pthreadpool.lib filter=lfs diff=lfs merge=lfs -text
  phivenv/Lib/site-packages/torch/lib/microkernels-prod.lib filter=lfs diff=lfs merge=lfs -text
  phivenv/Lib/site-packages/torch/lib/sleef.lib filter=lfs diff=lfs merge=lfs -text
+ phivenv/Lib/site-packages/torch/lib/torch_cpu.lib filter=lfs diff=lfs merge=lfs -text
+ phivenv/Lib/site-packages/torch/lib/torch_python.dll filter=lfs diff=lfs merge=lfs -text
+ phivenv/Lib/site-packages/torch/lib/torch_python.lib filter=lfs diff=lfs merge=lfs -text
+ phivenv/Lib/site-packages/torch/lib/uv.dll filter=lfs diff=lfs merge=lfs -text
+ phivenv/Lib/site-packages/torch/lib/XNNPACK.lib filter=lfs diff=lfs merge=lfs -text
+ phivenv/Lib/site-packages/torch/linalg/__pycache__/__init__.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
phivenv/Lib/site-packages/torch/lib/XNNPACK.lib ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3bf5c98f694f4587f5a191739ea8dd565a0696828448828e7491b9c8ca5d6fe2
+ size 14049460
phivenv/Lib/site-packages/torch/lib/torch_cpu.lib ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:08b81393191ac47ebf63e92aad8b65ece890d86dd51eb1e7294f1be3e496f3d7
+ size 29046564
phivenv/Lib/site-packages/torch/lib/torch_python.dll ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2cd2897d163f1029341f3ae67fcccd5f4f4fe7b4d62ecae8ca767128e3140f73
+ size 16310272
phivenv/Lib/site-packages/torch/lib/torch_python.lib ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8910178ebb14175b5c8b9ccd27b38b2360297399273b6dd3312bcc733b779529
+ size 287836
phivenv/Lib/site-packages/torch/lib/uv.dll ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fa569e682fc5fb7a8eb94c6829af9f30a569748dbbc6bce39735d48bc960bcf8
+ size 195072
phivenv/Lib/site-packages/torch/linalg/__pycache__/__init__.cpython-39.pyc ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9aa23740f58a645a003b0df576478da22f00c55c5001053fba842208036eb483
+ size 113386
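Each of the six entries above is a Git LFS pointer rather than the binary it stands for: `oid` is the SHA-256 digest of the tracked file's contents and `size` is its length in bytes. A minimal sketch of how those two fields can be reproduced for a local file (the path used here is illustrative only, not guaranteed to exist in your checkout):

import hashlib
import os

def lfs_pointer_fields(path: str) -> tuple[str, int]:
    # Stream the file so large binaries do not have to fit in memory at once.
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(1 << 20), b""):
            digest.update(chunk)
    # Returns (hex digest, byte size) -- the `oid sha256:` and `size` lines of the pointer.
    return digest.hexdigest(), os.path.getsize(path)

oid, size = lfs_pointer_fields("phivenv/Lib/site-packages/torch/lib/uv.dll")
print(f"oid sha256:{oid}")
print(f"size {size}")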
phivenv/Lib/site-packages/transformers/models/d_fine/__init__.py ADDED
@@ -0,0 +1,29 @@
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from typing import TYPE_CHECKING
+
+ from ...utils import _LazyModule
+ from ...utils.import_utils import define_import_structure
+
+
+ if TYPE_CHECKING:
+     from .configuration_d_fine import *
+     from .modeling_d_fine import *
+ else:
+     import sys
+
+     _file = globals()["__file__"]
+     sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
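The `__init__.py` above only materializes submodules when an attribute is first accessed: at type-checking time the star imports expose the full symbol table to static tools, while at runtime the module object is swapped for a `_LazyModule`. A rough PEP 562-style equivalent of that behaviour, written as a standalone sketch rather than the actual `_LazyModule` implementation:

import importlib
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Static analyzers resolve symbols eagerly; nothing is imported here at runtime.
    from .configuration_d_fine import DFineConfig  # noqa: F401

# Map of public names to the submodule that defines them (illustrative subset).
_LAZY_ATTRS = {"DFineConfig": ".configuration_d_fine"}

def __getattr__(name):
    # PEP 562 hook: runs only when `name` is not already defined on the module.
    if name in _LAZY_ATTRS:
        module = importlib.import_module(_LAZY_ATTRS[name], __package__)
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")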
phivenv/Lib/site-packages/transformers/models/d_fine/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (528 Bytes).
phivenv/Lib/site-packages/transformers/models/d_fine/__pycache__/configuration_d_fine.cpython-39.pyc ADDED
Binary file (17.5 kB).
phivenv/Lib/site-packages/transformers/models/d_fine/__pycache__/modeling_d_fine.cpython-39.pyc ADDED
Binary file (72.6 kB).
phivenv/Lib/site-packages/transformers/models/d_fine/__pycache__/modular_d_fine.cpython-39.pyc ADDED
Binary file (43.9 kB).
phivenv/Lib/site-packages/transformers/models/d_fine/configuration_d_fine.py ADDED
@@ -0,0 +1,433 @@
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+ # This file was automatically generated from src/transformers/models/d_fine/modular_d_fine.py.
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
+ # the file from the modular. If any change should be done, please apply the change to the
+ # modular_d_fine.py file directly. One of our CI checks enforces this.
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+ # coding=utf-8
+ # Copyright 2025 Baidu Inc and The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from ...configuration_utils import PretrainedConfig
+ from ...utils import logging
+ from ...utils.backbone_utils import verify_backbone_config_arguments
+ from ..auto import CONFIG_MAPPING
+
+
+ logger = logging.get_logger(__name__)
+
+
+ # TODO: Attribute map assignment logic should be fixed in modular
+ # as well as super() call parsing because otherwise we cannot re-write args after initialization
+ class DFineConfig(PretrainedConfig):
+     """
+     This is the configuration class to store the configuration of a [`DFineModel`]. It is used to instantiate a D-FINE
+     model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+     defaults will yield a similar configuration to that of D-FINE-X-COCO [ustc-community/dfine-xlarge-coco](https://huggingface.co/ustc-community/dfine-xlarge-coco).
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+     Args:
+         initializer_range (`float`, *optional*, defaults to 0.01):
+             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+         initializer_bias_prior_prob (`float`, *optional*):
+             The prior probability used by the bias initializer to initialize biases for `enc_score_head` and `class_embed`.
+             If `None`, `prior_prob` is computed as `prior_prob = 1 / (num_labels + 1)` while initializing model weights.
+         layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+             The epsilon used by the layer normalization layers.
+         batch_norm_eps (`float`, *optional*, defaults to 1e-05):
+             The epsilon used by the batch normalization layers.
+         backbone_config (`Dict`, *optional*, defaults to `RTDetrResNetConfig()`):
+             The configuration of the backbone model.
+         backbone (`str`, *optional*):
+             Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+             will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+             is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
+         use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
+             Whether to use pretrained weights for the backbone.
+         use_timm_backbone (`bool`, *optional*, defaults to `False`):
+             Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
+             library.
+         freeze_backbone_batch_norms (`bool`, *optional*, defaults to `True`):
+             Whether to freeze the batch normalization layers in the backbone.
+         backbone_kwargs (`dict`, *optional*):
+             Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
+             e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
+         encoder_hidden_dim (`int`, *optional*, defaults to 256):
+             Dimension of the layers in hybrid encoder.
+         encoder_in_channels (`list`, *optional*, defaults to `[512, 1024, 2048]`):
+             Multi level features input for encoder.
+         feat_strides (`list[int]`, *optional*, defaults to `[8, 16, 32]`):
+             Strides used in each feature map.
+         encoder_layers (`int`, *optional*, defaults to 1):
+             Total number of layers to be used by the encoder.
+         encoder_ffn_dim (`int`, *optional*, defaults to 1024):
+             Dimension of the "intermediate" (often named feed-forward) layer in the encoder.
+         encoder_attention_heads (`int`, *optional*, defaults to 8):
+             Number of attention heads for each attention layer in the Transformer encoder.
+         dropout (`float`, *optional*, defaults to 0.0):
+             The ratio for all dropout layers.
+         activation_dropout (`float`, *optional*, defaults to 0.0):
+             The dropout ratio for activations inside the fully connected layer.
+         encode_proj_layers (`list[int]`, *optional*, defaults to `[2]`):
+             Indexes of the projected layers to be used in the encoder.
+         positional_encoding_temperature (`int`, *optional*, defaults to 10000):
+             The temperature parameter used to create the positional encodings.
+         encoder_activation_function (`str`, *optional*, defaults to `"gelu"`):
+             The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+             `"relu"`, `"silu"` and `"gelu_new"` are supported.
+         activation_function (`str`, *optional*, defaults to `"silu"`):
+             The non-linear activation function (function or string) in the general layer. If string, `"gelu"`,
+             `"relu"`, `"silu"` and `"gelu_new"` are supported.
+         eval_size (`tuple[int, int]`, *optional*):
+             Height and width used to compute the effective height and width of the position embeddings after taking
+             into account the stride.
+         normalize_before (`bool`, *optional*, defaults to `False`):
+             Determines whether to apply layer normalization in the transformer encoder layer before self-attention and
+             feed-forward modules.
+         hidden_expansion (`float`, *optional*, defaults to 1.0):
+             Expansion ratio to enlarge the dimension size of RepVGGBlock and CSPRepLayer.
+         d_model (`int`, *optional*, defaults to 256):
+             Dimension of the layers, excluding the hybrid encoder.
+         num_queries (`int`, *optional*, defaults to 300):
+             Number of object queries.
+         decoder_in_channels (`list`, *optional*, defaults to `[256, 256, 256]`):
+             Multi level features dimension for the decoder.
+         decoder_ffn_dim (`int`, *optional*, defaults to 1024):
+             Dimension of the "intermediate" (often named feed-forward) layer in the decoder.
+         num_feature_levels (`int`, *optional*, defaults to 3):
+             The number of input feature levels.
+         decoder_n_points (`int`, *optional*, defaults to 4):
+             The number of sampled keys in each feature level for each attention head in the decoder.
+         decoder_layers (`int`, *optional*, defaults to 6):
+             Number of decoder layers.
+         decoder_attention_heads (`int`, *optional*, defaults to 8):
+             Number of attention heads for each attention layer in the Transformer decoder.
+         decoder_activation_function (`str`, *optional*, defaults to `"relu"`):
+             The non-linear activation function (function or string) in the decoder. If string, `"gelu"`,
+             `"relu"`, `"silu"` and `"gelu_new"` are supported.
+         attention_dropout (`float`, *optional*, defaults to 0.0):
+             The dropout ratio for the attention probabilities.
+         num_denoising (`int`, *optional*, defaults to 100):
+             The total number of denoising tasks or queries to be used for contrastive denoising.
+         label_noise_ratio (`float`, *optional*, defaults to 0.5):
+             The fraction of denoising labels to which random noise should be added.
+         box_noise_scale (`float`, *optional*, defaults to 1.0):
+             Scale or magnitude of noise to be added to the bounding boxes.
+         learn_initial_query (`bool`, *optional*, defaults to `False`):
+             Indicates whether the initial query embeddings for the decoder should be learned during training.
+         anchor_image_size (`tuple[int, int]`, *optional*):
+             Height and width of the input image used during evaluation to generate the bounding box anchors. If `None`, anchors are generated automatically.
+         with_box_refine (`bool`, *optional*, defaults to `True`):
+             Whether to apply iterative bounding box refinement, where each decoder layer refines the bounding boxes
+             based on the predictions from the previous layer.
+         is_encoder_decoder (`bool`, *optional*, defaults to `True`):
+             Whether the architecture has an encoder-decoder structure.
+         matcher_alpha (`float`, *optional*, defaults to 0.25):
+             Parameter alpha used by the Hungarian Matcher.
+         matcher_gamma (`float`, *optional*, defaults to 2.0):
+             Parameter gamma used by the Hungarian Matcher.
+         matcher_class_cost (`float`, *optional*, defaults to 2.0):
+             The relative weight of the class loss used by the Hungarian Matcher.
+         matcher_bbox_cost (`float`, *optional*, defaults to 5.0):
+             The relative weight of the bounding box loss used by the Hungarian Matcher.
+         matcher_giou_cost (`float`, *optional*, defaults to 2.0):
+             The relative weight of the giou loss used by the Hungarian Matcher.
+         use_focal_loss (`bool`, *optional*, defaults to `True`):
+             Parameter informing if focal loss should be used.
+         auxiliary_loss (`bool`, *optional*, defaults to `True`):
+             Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
+         focal_loss_alpha (`float`, *optional*, defaults to 0.75):
+             Parameter alpha used to compute the focal loss.
+         focal_loss_gamma (`float`, *optional*, defaults to 2.0):
+             Parameter gamma used to compute the focal loss.
+         weight_loss_vfl (`float`, *optional*, defaults to 1.0):
+             Relative weight of the varifocal loss in the object detection loss.
+         weight_loss_bbox (`float`, *optional*, defaults to 5.0):
+             Relative weight of the L1 bounding box loss in the object detection loss.
+         weight_loss_giou (`float`, *optional*, defaults to 2.0):
+             Relative weight of the generalized IoU loss in the object detection loss.
+         weight_loss_fgl (`float`, *optional*, defaults to 0.15):
+             Relative weight of the fine-grained localization loss in the object detection loss.
+         weight_loss_ddf (`float`, *optional*, defaults to 1.5):
+             Relative weight of the decoupled distillation focal loss in the object detection loss.
+         eos_coefficient (`float`, *optional*, defaults to 0.0001):
+             Relative classification weight of the 'no-object' class in the object detection loss.
+         eval_idx (`int`, *optional*, defaults to -1):
+             Index of the decoder layer to use for evaluation. If negative, counts from the end
+             (e.g., -1 means use the last layer). This allows for early prediction in the decoder
+             stack while still training later layers.
+         layer_scale (`float`, *optional*, defaults to `1.0`):
+             Scaling factor for the hidden dimension in later decoder layers. Used to adjust the
+             model capacity after the evaluation layer.
+         max_num_bins (`int`, *optional*, defaults to 32):
+             Maximum number of bins for the distribution-guided bounding box refinement.
+             Higher values allow for more fine-grained localization but increase computation.
+         reg_scale (`float`, *optional*, defaults to 4.0):
+             Scale factor for the regression distribution. Controls the range and granularity
+             of the bounding box refinement process.
+         depth_mult (`float`, *optional*, defaults to 1.0):
+             Multiplier for the number of blocks in RepNCSPELAN4 layers. Used to scale the model's
+             depth while maintaining its architecture.
+         top_prob_values (`int`, *optional*, defaults to 4):
+             Number of top probability values to consider from each corner's distribution.
+         lqe_hidden_dim (`int`, *optional*, defaults to 64):
+             Hidden dimension size for the Location Quality Estimator (LQE) network.
+         lqe_layers (`int`, *optional*, defaults to 2):
+             Number of layers in the Location Quality Estimator MLP.
+         decoder_offset_scale (`float`, *optional*, defaults to 0.5):
+             Offset scale used in deformable attention.
+         decoder_method (`str`, *optional*, defaults to `"default"`):
+             The method to use for the decoder: `"default"` or `"discrete"`.
+         up (`float`, *optional*, defaults to 0.5):
+             Controls the upper bounds of the Weighting Function.
+     """
+
+     model_type = "d_fine"
+     layer_types = ["basic", "bottleneck"]
+     attribute_map = {
+         "hidden_size": "d_model",
+         "num_attention_heads": "encoder_attention_heads",
+     }
+
+     def __init__(
+         self,
+         initializer_range=0.01,
+         initializer_bias_prior_prob=None,
+         layer_norm_eps=1e-5,
+         batch_norm_eps=1e-5,
+         # backbone
+         backbone_config=None,
+         backbone=None,
+         use_pretrained_backbone=False,
+         use_timm_backbone=False,
+         freeze_backbone_batch_norms=True,
+         backbone_kwargs=None,
+         # encoder HybridEncoder
+         encoder_hidden_dim=256,
+         encoder_in_channels=[512, 1024, 2048],
+         feat_strides=[8, 16, 32],
+         encoder_layers=1,
+         encoder_ffn_dim=1024,
+         encoder_attention_heads=8,
+         dropout=0.0,
+         activation_dropout=0.0,
+         encode_proj_layers=[2],
+         positional_encoding_temperature=10000,
+         encoder_activation_function="gelu",
+         activation_function="silu",
+         eval_size=None,
+         normalize_before=False,
+         hidden_expansion=1.0,
+         # decoder DFineTransformer
+         d_model=256,
+         num_queries=300,
+         decoder_in_channels=[256, 256, 256],
+         decoder_ffn_dim=1024,
+         num_feature_levels=3,
+         decoder_n_points=4,
+         decoder_layers=6,
+         decoder_attention_heads=8,
+         decoder_activation_function="relu",
+         attention_dropout=0.0,
+         num_denoising=100,
+         label_noise_ratio=0.5,
+         box_noise_scale=1.0,
+         learn_initial_query=False,
+         anchor_image_size=None,
+         with_box_refine=True,
+         is_encoder_decoder=True,
+         # Loss
+         matcher_alpha=0.25,
+         matcher_gamma=2.0,
+         matcher_class_cost=2.0,
+         matcher_bbox_cost=5.0,
+         matcher_giou_cost=2.0,
+         use_focal_loss=True,
+         auxiliary_loss=True,
+         focal_loss_alpha=0.75,
+         focal_loss_gamma=2.0,
+         weight_loss_vfl=1.0,
+         weight_loss_bbox=5.0,
+         weight_loss_giou=2.0,
+         weight_loss_fgl=0.15,
+         weight_loss_ddf=1.5,
+         eos_coefficient=1e-4,
+         eval_idx=-1,
+         layer_scale=1,
+         max_num_bins=32,
+         reg_scale=4.0,
+         depth_mult=1.0,
+         top_prob_values=4,
+         lqe_hidden_dim=64,
+         lqe_layers=2,
+         decoder_offset_scale=0.5,
+         decoder_method="default",
+         up=0.5,
+         **kwargs,
+     ):
+         self.initializer_range = initializer_range
+         self.initializer_bias_prior_prob = initializer_bias_prior_prob
+         self.layer_norm_eps = layer_norm_eps
+         self.batch_norm_eps = batch_norm_eps
+         # backbone
+         if backbone_config is None and backbone is None:
+             logger.info(
+                 "`backbone_config` and `backbone` are `None`. Initializing the config with the default `HGNet-V2` backbone."
+             )
+             backbone_model_type = "hgnet_v2"
+             config_class = CONFIG_MAPPING[backbone_model_type]
+             # this will map it to RTDetrResNetConfig
+             # note: we can instead create HGNetV2Config
+             # and we would need to create HGNetV2Backbone
+             backbone_config = config_class(
+                 num_channels=3,
+                 embedding_size=64,
+                 hidden_sizes=[256, 512, 1024, 2048],
+                 depths=[3, 4, 6, 3],
+                 layer_type="bottleneck",
+                 hidden_act="relu",
+                 downsample_in_first_stage=False,
+                 downsample_in_bottleneck=False,
+                 out_features=None,
+                 out_indices=[2, 3, 4],
+             )
+         elif isinstance(backbone_config, dict):
+             backbone_model_type = backbone_config.pop("model_type")
+             config_class = CONFIG_MAPPING[backbone_model_type]
+             backbone_config = config_class.from_dict(backbone_config)
+
+         verify_backbone_config_arguments(
+             use_timm_backbone=use_timm_backbone,
+             use_pretrained_backbone=use_pretrained_backbone,
+             backbone=backbone,
+             backbone_config=backbone_config,
+             backbone_kwargs=backbone_kwargs,
+         )
+
+         self.backbone_config = backbone_config
+         self.backbone = backbone
+         self.use_pretrained_backbone = use_pretrained_backbone
+         self.use_timm_backbone = use_timm_backbone
+         self.freeze_backbone_batch_norms = freeze_backbone_batch_norms
+         self.backbone_kwargs = backbone_kwargs
+         # encoder
+         self.encoder_hidden_dim = encoder_hidden_dim
+         self.encoder_in_channels = encoder_in_channels
+         self.feat_strides = feat_strides
+         self.encoder_attention_heads = encoder_attention_heads
+         self.encoder_ffn_dim = encoder_ffn_dim
+         self.dropout = dropout
+         self.activation_dropout = activation_dropout
+         self.encode_proj_layers = encode_proj_layers
+         self.encoder_layers = encoder_layers
+         self.positional_encoding_temperature = positional_encoding_temperature
+         self.eval_size = eval_size
+         self.normalize_before = normalize_before
+         self.encoder_activation_function = encoder_activation_function
+         self.activation_function = activation_function
+         self.hidden_expansion = hidden_expansion
+         # decoder
+         self.d_model = d_model
+         self.num_queries = num_queries
+         self.decoder_ffn_dim = decoder_ffn_dim
+         self.decoder_in_channels = decoder_in_channels
+         self.num_feature_levels = num_feature_levels
+         self.decoder_n_points = decoder_n_points
+         self.decoder_layers = decoder_layers
+         self.decoder_attention_heads = decoder_attention_heads
+         self.decoder_activation_function = decoder_activation_function
+         self.attention_dropout = attention_dropout
+         self.num_denoising = num_denoising
+         self.label_noise_ratio = label_noise_ratio
+         self.box_noise_scale = box_noise_scale
+         self.learn_initial_query = learn_initial_query
+         self.anchor_image_size = anchor_image_size
+         self.auxiliary_loss = auxiliary_loss
+         self.with_box_refine = with_box_refine
+         # Loss
+         self.matcher_alpha = matcher_alpha
+         self.matcher_gamma = matcher_gamma
+         self.matcher_class_cost = matcher_class_cost
+         self.matcher_bbox_cost = matcher_bbox_cost
+         self.matcher_giou_cost = matcher_giou_cost
+         self.use_focal_loss = use_focal_loss
+         self.focal_loss_alpha = focal_loss_alpha
+         self.focal_loss_gamma = focal_loss_gamma
+         self.weight_loss_vfl = weight_loss_vfl
+         self.weight_loss_bbox = weight_loss_bbox
+         self.weight_loss_giou = weight_loss_giou
+         self.weight_loss_fgl = weight_loss_fgl
+         self.weight_loss_ddf = weight_loss_ddf
+         self.eos_coefficient = eos_coefficient
+         # add the new attributes with the given values or defaults
+         self.eval_idx = eval_idx
+         self.layer_scale = layer_scale
+         self.max_num_bins = max_num_bins
+         self.reg_scale = reg_scale
+         self.depth_mult = depth_mult
+         self.decoder_offset_scale = decoder_offset_scale
+         self.decoder_method = decoder_method
+         self.top_prob_values = top_prob_values
+         self.lqe_hidden_dim = lqe_hidden_dim
+         self.lqe_layers = lqe_layers
+         self.up = up
+
+         if isinstance(self.decoder_n_points, list):
+             if len(self.decoder_n_points) != self.num_feature_levels:
+                 raise ValueError(
+                     f"Length of decoder_n_points list ({len(self.decoder_n_points)}) must match num_feature_levels ({self.num_feature_levels})."
+                 )
+
+         head_dim = self.d_model // self.decoder_attention_heads
+         if head_dim * self.decoder_attention_heads != self.d_model:
+             raise ValueError(
+                 f"Embedded dimension {self.d_model} must be divisible by decoder_attention_heads {self.decoder_attention_heads}"
+             )
+         super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
+
+     @property
+     def num_attention_heads(self) -> int:
+         return self.encoder_attention_heads
+
+     @property
+     def hidden_size(self) -> int:
+         return self.d_model
+
+     @property
+     def sub_configs(self):
+         return (
+             {"backbone_config": type(self.backbone_config)}
+             if getattr(self, "backbone_config", None) is not None
+             else {}
+         )
+
+     @classmethod
+     def from_backbone_configs(cls, backbone_config: PretrainedConfig, **kwargs):
+         """Instantiate a [`DFineConfig`] (or a derived class) from a pre-trained backbone model configuration and DETR model
+         configuration.
+
+         Args:
+             backbone_config ([`PretrainedConfig`]):
+                 The backbone configuration.
+
+         Returns:
+             [`DFineConfig`]: An instance of a configuration object
+         """
+         return cls(
+             backbone_config=backbone_config,
+             **kwargs,
+         )
+
+
+ __all__ = ["DFineConfig"]
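As the docstring above notes, instantiating the class with no arguments approximates the ustc-community/dfine-xlarge-coco configuration and builds the default HGNet-V2 backbone config automatically. A short usage sketch, assuming the transformers build in this environment exposes `DFineConfig` and `DFineModel` at the top level (the printed values are simply the defaults listed above):

from transformers import DFineConfig, DFineModel

# Default D-FINE configuration; the HGNet-V2 backbone config is created in __init__.
config = DFineConfig()
print(config.d_model, config.num_queries, config.decoder_layers)  # 256 300 6

# Randomly initialized model built from the configuration (no pretrained weights loaded).
model = DFineModel(config)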
phivenv/Lib/site-packages/transformers/models/d_fine/modeling_d_fine.py ADDED
The diff for this file is too large to render. See raw diff
 
phivenv/Lib/site-packages/transformers/models/d_fine/modular_d_fine.py ADDED
@@ -0,0 +1,1221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 Baidu Inc and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import math
16
+ from typing import Any, Optional
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ import torch.nn.init as init
21
+ from torch import nn
22
+
23
+ from ...activations import ACT2CLS
24
+ from ...configuration_utils import PretrainedConfig
25
+ from ...image_transforms import corners_to_center_format
26
+ from ...utils import is_torchdynamo_compiling, logging
27
+ from ...utils.backbone_utils import verify_backbone_config_arguments
28
+ from ..auto import CONFIG_MAPPING
29
+ from ..rt_detr.modeling_rt_detr import (
30
+ RTDetrConvNormLayer,
31
+ RTDetrDecoder,
32
+ RTDetrDecoderLayer,
33
+ RTDetrDecoderOutput,
34
+ RTDetrEncoder,
35
+ RTDetrForObjectDetection,
36
+ RTDetrHybridEncoder,
37
+ RTDetrMLPPredictionHead,
38
+ RTDetrModel,
39
+ RTDetrPreTrainedModel,
40
+ RTDetrRepVggBlock,
41
+ inverse_sigmoid,
42
+ )
43
+ from ..rt_detr_v2.modeling_rt_detr_v2 import multi_scale_deformable_attention_v2
44
+
45
+
46
+ logger = logging.get_logger(__name__)
47
+
48
+
49
+ # TODO: Attribute map assignment logic should be fixed in modular
50
+ # as well as super() call parsing because otherwise we cannot re-write args after initialization
51
+ class DFineConfig(PretrainedConfig):
52
+ """
53
+ This is the configuration class to store the configuration of a [`DFineModel`]. It is used to instantiate a D-FINE
54
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
55
+ defaults will yield a similar configuration to that of D-FINE-X-COCO "[ustc-community/dfine-xlarge-coco"](https://huggingface.co/ustc-community/dfine-xlarge-coco").
56
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
57
+ documentation from [`PretrainedConfig`] for more information.
58
+
59
+ Args:
60
+ initializer_range (`float`, *optional*, defaults to 0.01):
61
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
62
+ initializer_bias_prior_prob (`float`, *optional*):
63
+ The prior probability used by the bias initializer to initialize biases for `enc_score_head` and `class_embed`.
64
+ If `None`, `prior_prob` computed as `prior_prob = 1 / (num_labels + 1)` while initializing model weights.
65
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
66
+ The epsilon used by the layer normalization layers.
67
+ batch_norm_eps (`float`, *optional*, defaults to 1e-05):
68
+ The epsilon used by the batch normalization layers.
69
+ backbone_config (`Dict`, *optional*, defaults to `RTDetrResNetConfig()`):
70
+ The configuration of the backbone model.
71
+ backbone (`str`, *optional*):
72
+ Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
73
+ will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
74
+ is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
75
+ use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
76
+ Whether to use pretrained weights for the backbone.
77
+ use_timm_backbone (`bool`, *optional*, defaults to `False`):
78
+ Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
79
+ library.
80
+ freeze_backbone_batch_norms (`bool`, *optional*, defaults to `True`):
81
+ Whether to freeze the batch normalization layers in the backbone.
82
+ backbone_kwargs (`dict`, *optional*):
83
+ Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
84
+ e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
85
+ encoder_hidden_dim (`int`, *optional*, defaults to 256):
86
+ Dimension of the layers in hybrid encoder.
87
+ encoder_in_channels (`list`, *optional*, defaults to `[512, 1024, 2048]`):
88
+ Multi level features input for encoder.
89
+ feat_strides (`list[int]`, *optional*, defaults to `[8, 16, 32]`):
90
+ Strides used in each feature map.
91
+ encoder_layers (`int`, *optional*, defaults to 1):
92
+ Total of layers to be used by the encoder.
93
+ encoder_ffn_dim (`int`, *optional*, defaults to 1024):
94
+ Dimension of the "intermediate" (often named feed-forward) layer in decoder.
95
+ encoder_attention_heads (`int`, *optional*, defaults to 8):
96
+ Number of attention heads for each attention layer in the Transformer encoder.
97
+ dropout (`float`, *optional*, defaults to 0.0):
98
+ The ratio for all dropout layers.
99
+ activation_dropout (`float`, *optional*, defaults to 0.0):
100
+ The dropout ratio for activations inside the fully connected layer.
101
+ encode_proj_layers (`list[int]`, *optional*, defaults to `[2]`):
102
+ Indexes of the projected layers to be used in the encoder.
103
+ positional_encoding_temperature (`int`, *optional*, defaults to 10000):
104
+ The temperature parameter used to create the positional encodings.
105
+ encoder_activation_function (`str`, *optional*, defaults to `"gelu"`):
106
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
107
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
108
+ activation_function (`str`, *optional*, defaults to `"silu"`):
109
+ The non-linear activation function (function or string) in the general layer. If string, `"gelu"`,
110
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
111
+ eval_size (`tuple[int, int]`, *optional*):
112
+ Height and width used to computes the effective height and width of the position embeddings after taking
113
+ into account the stride.
114
+ normalize_before (`bool`, *optional*, defaults to `False`):
115
+ Determine whether to apply layer normalization in the transformer encoder layer before self-attention and
116
+ feed-forward modules.
117
+ hidden_expansion (`float`, *optional*, defaults to 1.0):
118
+ Expansion ratio to enlarge the dimension size of RepVGGBlock and CSPRepLayer.
119
+ d_model (`int`, *optional*, defaults to 256):
120
+ Dimension of the layers exclude hybrid encoder.
121
+ num_queries (`int`, *optional*, defaults to 300):
122
+ Number of object queries.
123
+ decoder_in_channels (`list`, *optional*, defaults to `[256, 256, 256]`):
124
+ Multi level features dimension for decoder
125
+ decoder_ffn_dim (`int`, *optional*, defaults to 1024):
126
+ Dimension of the "intermediate" (often named feed-forward) layer in decoder.
127
+ num_feature_levels (`int`, *optional*, defaults to 3):
128
+ The number of input feature levels.
129
+ decoder_n_points (`int`, *optional*, defaults to 4):
130
+ The number of sampled keys in each feature level for each attention head in the decoder.
131
+ decoder_layers (`int`, *optional*, defaults to 6):
132
+ Number of decoder layers.
133
+ decoder_attention_heads (`int`, *optional*, defaults to 8):
134
+ Number of attention heads for each attention layer in the Transformer decoder.
135
+ decoder_activation_function (`str`, *optional*, defaults to `"relu"`):
136
+ The non-linear activation function (function or string) in the decoder. If string, `"gelu"`,
137
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
138
+ attention_dropout (`float`, *optional*, defaults to 0.0):
139
+ The dropout ratio for the attention probabilities.
140
+ num_denoising (`int`, *optional*, defaults to 100):
141
+ The total number of denoising tasks or queries to be used for contrastive denoising.
142
+ label_noise_ratio (`float`, *optional*, defaults to 0.5):
143
+ The fraction of denoising labels to which random noise should be added.
144
+ box_noise_scale (`float`, *optional*, defaults to 1.0):
145
+ Scale or magnitude of noise to be added to the bounding boxes.
146
+ learn_initial_query (`bool`, *optional*, defaults to `False`):
147
+ Indicates whether the initial query embeddings for the decoder should be learned during training
148
+ anchor_image_size (`tuple[int, int]`, *optional*):
149
+ Height and width of the input image used during evaluation to generate the bounding box anchors. If None, automatic generate anchor is applied.
150
+ with_box_refine (`bool`, *optional*, defaults to `True`):
151
+ Whether to apply iterative bounding box refinement, where each decoder layer refines the bounding boxes
152
+ based on the predictions from the previous layer.
153
+ is_encoder_decoder (`bool`, *optional*, defaults to `True`):
154
+ Whether the architecture has an encoder decoder structure.
155
+ matcher_alpha (`float`, *optional*, defaults to 0.25):
156
+ Parameter alpha used by the Hungarian Matcher.
157
+ matcher_gamma (`float`, *optional*, defaults to 2.0):
158
+ Parameter gamma used by the Hungarian Matcher.
159
+ matcher_class_cost (`float`, *optional*, defaults to 2.0):
160
+ The relative weight of the class loss used by the Hungarian Matcher.
161
+ matcher_bbox_cost (`float`, *optional*, defaults to 5.0):
162
+ The relative weight of the bounding box loss used by the Hungarian Matcher.
163
+ matcher_giou_cost (`float`, *optional*, defaults to 2.0):
164
+ The relative weight of the giou loss of used by the Hungarian Matcher.
165
+ use_focal_loss (`bool`, *optional*, defaults to `True`):
166
+ Parameter informing if focal focal should be used.
167
+ auxiliary_loss (`bool`, *optional*, defaults to `True`):
168
+ Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
169
+ focal_loss_alpha (`float`, *optional*, defaults to 0.75):
170
+ Parameter alpha used to compute the focal loss.
171
+ focal_loss_gamma (`float`, *optional*, defaults to 2.0):
172
+ Parameter gamma used to compute the focal loss.
173
+ weight_loss_vfl (`float`, *optional*, defaults to 1.0):
174
+ Relative weight of the varifocal loss in the object detection loss.
175
+ weight_loss_bbox (`float`, *optional*, defaults to 5.0):
176
+ Relative weight of the L1 bounding box loss in the object detection loss.
177
+ weight_loss_giou (`float`, *optional*, defaults to 2.0):
178
+ Relative weight of the generalized IoU loss in the object detection loss.
179
+ weight_loss_fgl (`float`, *optional*, defaults to 0.15):
180
+ Relative weight of the fine-grained localization loss in the object detection loss.
181
+ weight_loss_ddf (`float`, *optional*, defaults to 1.5):
182
+ Relative weight of the decoupled distillation focal loss in the object detection loss.
183
+ eos_coefficient (`float`, *optional*, defaults to 0.0001):
184
+ Relative classification weight of the 'no-object' class in the object detection loss.
185
+ eval_idx (`int`, *optional*, defaults to -1):
186
+ Index of the decoder layer to use for evaluation. If negative, counts from the end
187
+ (e.g., -1 means use the last layer). This allows for early prediction in the decoder
188
+ stack while still training later layers.
189
+ layer_scale (`float`, *optional*, defaults to `1.0`):
190
+ Scaling factor for the hidden dimension in later decoder layers. Used to adjust the
191
+ model capacity after the evaluation layer.
192
+ max_num_bins (`int`, *optional*, defaults to 32):
193
+ Maximum number of bins for the distribution-guided bounding box refinement.
194
+ Higher values allow for more fine-grained localization but increase computation.
195
+ reg_scale (`float`, *optional*, defaults to 4.0):
196
+ Scale factor for the regression distribution. Controls the range and granularity
197
+ of the bounding box refinement process.
198
+ depth_mult (`float`, *optional*, defaults to 1.0):
199
+ Multiplier for the number of blocks in RepNCSPELAN4 layers. Used to scale the model's
200
+ depth while maintaining its architecture.
201
+ top_prob_values (`int`, *optional*, defaults to 4):
202
+ Number of top probability values to consider from each corner's distribution.
203
+ lqe_hidden_dim (`int`, *optional*, defaults to 64):
204
+ Hidden dimension size for the Location Quality Estimator (LQE) network.
205
+ lqe_layers (`int`, *optional*, defaults to 2):
206
+ Number of layers in the Location Quality Estimator MLP.
207
+ decoder_offset_scale (`float`, *optional*, defaults to 0.5):
208
+ Offset scale used in deformable attention.
209
+ decoder_method (`str`, *optional*, defaults to `"default"`):
210
+ The method to use for the decoder: `"default"` or `"discrete"`.
211
+ up (`float`, *optional*, defaults to 0.5):
212
+ Controls the upper bounds of the Weighting Function.
213
+ """
214
+
215
+ model_type = "d_fine"
216
+ layer_types = ["basic", "bottleneck"]
217
+ attribute_map = {
218
+ "hidden_size": "d_model",
219
+ "num_attention_heads": "encoder_attention_heads",
220
+ }
221
+
222
+ def __init__(
223
+ self,
224
+ initializer_range=0.01,
225
+ initializer_bias_prior_prob=None,
226
+ layer_norm_eps=1e-5,
227
+ batch_norm_eps=1e-5,
228
+ # backbone
229
+ backbone_config=None,
230
+ backbone=None,
231
+ use_pretrained_backbone=False,
232
+ use_timm_backbone=False,
233
+ freeze_backbone_batch_norms=True,
234
+ backbone_kwargs=None,
235
+ # encoder HybridEncoder
236
+ encoder_hidden_dim=256,
237
+ encoder_in_channels=[512, 1024, 2048],
238
+ feat_strides=[8, 16, 32],
239
+ encoder_layers=1,
240
+ encoder_ffn_dim=1024,
241
+ encoder_attention_heads=8,
242
+ dropout=0.0,
243
+ activation_dropout=0.0,
244
+ encode_proj_layers=[2],
245
+ positional_encoding_temperature=10000,
246
+ encoder_activation_function="gelu",
247
+ activation_function="silu",
248
+ eval_size=None,
249
+ normalize_before=False,
250
+ hidden_expansion=1.0,
251
+ # decoder DFineTransformer
252
+ d_model=256,
253
+ num_queries=300,
254
+ decoder_in_channels=[256, 256, 256],
255
+ decoder_ffn_dim=1024,
256
+ num_feature_levels=3,
257
+ decoder_n_points=4,
258
+ decoder_layers=6,
259
+ decoder_attention_heads=8,
260
+ decoder_activation_function="relu",
261
+ attention_dropout=0.0,
262
+ num_denoising=100,
263
+ label_noise_ratio=0.5,
264
+ box_noise_scale=1.0,
265
+ learn_initial_query=False,
266
+ anchor_image_size=None,
267
+ with_box_refine=True,
268
+ is_encoder_decoder=True,
269
+ # Loss
270
+ matcher_alpha=0.25,
271
+ matcher_gamma=2.0,
272
+ matcher_class_cost=2.0,
273
+ matcher_bbox_cost=5.0,
274
+ matcher_giou_cost=2.0,
275
+ use_focal_loss=True,
276
+ auxiliary_loss=True,
277
+ focal_loss_alpha=0.75,
278
+ focal_loss_gamma=2.0,
279
+ weight_loss_vfl=1.0,
280
+ weight_loss_bbox=5.0,
281
+ weight_loss_giou=2.0,
282
+ weight_loss_fgl=0.15,
283
+ weight_loss_ddf=1.5,
284
+ eos_coefficient=1e-4,
285
+ eval_idx=-1,
286
+ layer_scale=1,
287
+ max_num_bins=32,
288
+ reg_scale=4.0,
289
+ depth_mult=1.0,
290
+ top_prob_values=4,
291
+ lqe_hidden_dim=64,
292
+ lqe_layers=2,
293
+ decoder_offset_scale=0.5,
294
+ decoder_method="default",
295
+ up=0.5,
296
+ **kwargs,
297
+ ):
298
+ self.initializer_range = initializer_range
299
+ self.initializer_bias_prior_prob = initializer_bias_prior_prob
300
+ self.layer_norm_eps = layer_norm_eps
301
+ self.batch_norm_eps = batch_norm_eps
302
+ # backbone
303
+ if backbone_config is None and backbone is None:
304
+ logger.info(
305
+ "`backbone_config` and `backbone` are `None`. Initializing the config with the default `HGNet-V2` backbone."
306
+ )
307
+ backbone_model_type = "hgnet_v2"
308
+ config_class = CONFIG_MAPPING[backbone_model_type]
309
+ # this will map it to RTDetrResNetConfig
310
+ # note: we can instead create HGNetV2Config
311
+ # and we would need to create HGNetV2Backbone
312
+ backbone_config = config_class(
313
+ num_channels=3,
314
+ embedding_size=64,
315
+ hidden_sizes=[256, 512, 1024, 2048],
316
+ depths=[3, 4, 6, 3],
317
+ layer_type="bottleneck",
318
+ hidden_act="relu",
319
+ downsample_in_first_stage=False,
320
+ downsample_in_bottleneck=False,
321
+ out_features=None,
322
+ out_indices=[2, 3, 4],
323
+ )
324
+ elif isinstance(backbone_config, dict):
325
+ backbone_model_type = backbone_config.pop("model_type")
326
+ config_class = CONFIG_MAPPING[backbone_model_type]
327
+ backbone_config = config_class.from_dict(backbone_config)
328
+
329
+ verify_backbone_config_arguments(
330
+ use_timm_backbone=use_timm_backbone,
331
+ use_pretrained_backbone=use_pretrained_backbone,
332
+ backbone=backbone,
333
+ backbone_config=backbone_config,
334
+ backbone_kwargs=backbone_kwargs,
335
+ )
336
+
337
+ self.backbone_config = backbone_config
338
+ self.backbone = backbone
339
+ self.use_pretrained_backbone = use_pretrained_backbone
340
+ self.use_timm_backbone = use_timm_backbone
341
+ self.freeze_backbone_batch_norms = freeze_backbone_batch_norms
342
+ self.backbone_kwargs = backbone_kwargs
343
+ # encoder
344
+ self.encoder_hidden_dim = encoder_hidden_dim
345
+ self.encoder_in_channels = encoder_in_channels
346
+ self.feat_strides = feat_strides
347
+ self.encoder_attention_heads = encoder_attention_heads
348
+ self.encoder_ffn_dim = encoder_ffn_dim
349
+ self.dropout = dropout
350
+ self.activation_dropout = activation_dropout
351
+ self.encode_proj_layers = encode_proj_layers
352
+ self.encoder_layers = encoder_layers
353
+ self.positional_encoding_temperature = positional_encoding_temperature
354
+ self.eval_size = eval_size
355
+ self.normalize_before = normalize_before
356
+ self.encoder_activation_function = encoder_activation_function
357
+ self.activation_function = activation_function
358
+ self.hidden_expansion = hidden_expansion
359
+ # decoder
360
+ self.d_model = d_model
361
+ self.num_queries = num_queries
362
+ self.decoder_ffn_dim = decoder_ffn_dim
363
+ self.decoder_in_channels = decoder_in_channels
364
+ self.num_feature_levels = num_feature_levels
365
+ self.decoder_n_points = decoder_n_points
366
+ self.decoder_layers = decoder_layers
367
+ self.decoder_attention_heads = decoder_attention_heads
368
+ self.decoder_activation_function = decoder_activation_function
369
+ self.attention_dropout = attention_dropout
370
+ self.num_denoising = num_denoising
371
+ self.label_noise_ratio = label_noise_ratio
372
+ self.box_noise_scale = box_noise_scale
373
+ self.learn_initial_query = learn_initial_query
374
+ self.anchor_image_size = anchor_image_size
375
+ self.auxiliary_loss = auxiliary_loss
376
+ self.with_box_refine = with_box_refine
377
+ # Loss
378
+ self.matcher_alpha = matcher_alpha
379
+ self.matcher_gamma = matcher_gamma
380
+ self.matcher_class_cost = matcher_class_cost
381
+ self.matcher_bbox_cost = matcher_bbox_cost
382
+ self.matcher_giou_cost = matcher_giou_cost
383
+ self.use_focal_loss = use_focal_loss
384
+ self.focal_loss_alpha = focal_loss_alpha
385
+ self.focal_loss_gamma = focal_loss_gamma
386
+ self.weight_loss_vfl = weight_loss_vfl
387
+ self.weight_loss_bbox = weight_loss_bbox
388
+ self.weight_loss_giou = weight_loss_giou
389
+ self.weight_loss_fgl = weight_loss_fgl
390
+ self.weight_loss_ddf = weight_loss_ddf
391
+ self.eos_coefficient = eos_coefficient
392
+ # add the new attributes with the given values or defaults
393
+ self.eval_idx = eval_idx
394
+ self.layer_scale = layer_scale
395
+ self.max_num_bins = max_num_bins
396
+ self.reg_scale = reg_scale
397
+ self.depth_mult = depth_mult
398
+ self.decoder_offset_scale = decoder_offset_scale
399
+ self.decoder_method = decoder_method
400
+ self.top_prob_values = top_prob_values
401
+ self.lqe_hidden_dim = lqe_hidden_dim
402
+ self.lqe_layers = lqe_layers
403
+ self.up = up
404
+
405
+ if isinstance(self.decoder_n_points, list):
406
+ if len(self.decoder_n_points) != self.num_feature_levels:
407
+ raise ValueError(
408
+ f"Length of decoder_n_points list ({len(self.decoder_n_points)}) must match num_feature_levels ({self.num_feature_levels})."
409
+ )
410
+
411
+ head_dim = self.d_model // self.decoder_attention_heads
412
+ if head_dim * self.decoder_attention_heads != self.d_model:
413
+ raise ValueError(
414
+ f"Embedded dimension {self.d_model} must be divisible by decoder_attention_heads {self.decoder_attention_heads}"
415
+ )
416
+ super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
417
+
418
+ @property
419
+ def num_attention_heads(self) -> int:
420
+ return self.encoder_attention_heads
421
+
422
+ @property
423
+ def hidden_size(self) -> int:
424
+ return self.d_model
425
+
426
+ @property
427
+ def sub_configs(self):
428
+ return (
429
+ {"backbone_config": type(self.backbone_config)}
430
+ if getattr(self, "backbone_config", None) is not None
431
+ else {}
432
+ )
433
+
434
+ @classmethod
435
+ def from_backbone_configs(cls, backbone_config: PretrainedConfig, **kwargs):
436
+ """Instantiate a [`DFineConfig`] (or a derived class) from a pre-trained backbone model configuration and DETR model
437
+ configuration.
438
+
439
+ Args:
440
+ backbone_config ([`PretrainedConfig`]):
441
+ The backbone configuration.
442
+
443
+ Returns:
444
+ [`DFineConfig`]: An instance of a configuration object
445
+ """
446
+ return cls(
447
+ backbone_config=backbone_config,
448
+ **kwargs,
449
+ )
450
+
451
+
452
+ class DFineMultiscaleDeformableAttention(nn.Module):
453
+ def __init__(self, config: DFineConfig):
454
+ """
455
+ D-Fine version of multiscale deformable attention
456
+ """
457
+ super().__init__()
458
+ self.d_model = config.d_model
459
+ self.n_heads = config.decoder_attention_heads
460
+ self.n_levels = config.num_feature_levels
461
+ self.offset_scale = config.decoder_offset_scale
462
+ self.decoder_method = config.decoder_method
463
+ self.n_points = config.decoder_n_points
464
+
465
+ if isinstance(self.n_points, list):
466
+ num_points_list = self.n_points
467
+ else:
468
+ num_points_list = [self.n_points for _ in range(self.n_levels)]
469
+
470
+ self.num_points_list = num_points_list
471
+ num_points_scale = [1 / n for n in self.num_points_list for _ in range(n)]
472
+ self.register_buffer("num_points_scale", torch.tensor(num_points_scale, dtype=torch.float32))
473
+
474
+ self.total_points = self.n_heads * sum(self.num_points_list)
475
+
476
+ self.sampling_offsets = nn.Linear(self.d_model, self.total_points * 2)
477
+ self.attention_weights = nn.Linear(self.d_model, self.total_points)
478
+
479
+ self.ms_deformable_attn_core = multi_scale_deformable_attention_v2
480
+
481
+ def forward(
482
+ self,
483
+ hidden_states: torch.Tensor,
484
+ attention_mask: Optional[torch.Tensor] = None,
485
+ reference_points=None,
486
+ encoder_hidden_states=None,
487
+ spatial_shapes=None,
488
+ spatial_shapes_list=None,
489
+ ) -> tuple[torch.Tensor, torch.Tensor]:
490
+ batch_size, num_queries, _ = hidden_states.shape
491
+ batch_size, sequence_length, _ = encoder_hidden_states.shape
492
+
493
+ if not is_torchdynamo_compiling() and (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
494
+ raise ValueError(
495
+ "Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
496
+ )
497
+
498
+ # Reshape for multi-head attention
499
+ value = encoder_hidden_states.reshape(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)
500
+ if attention_mask is not None:
501
+ value = value.masked_fill(~attention_mask[..., None], float(0))
502
+
503
+ sampling_offsets: torch.Tensor = self.sampling_offsets(hidden_states)
504
+ sampling_offsets = sampling_offsets.reshape(
505
+ batch_size, num_queries, self.n_heads, sum(self.num_points_list), 2
506
+ )
507
+
508
+ attention_weights = self.attention_weights(hidden_states).reshape(
509
+ batch_size, num_queries, self.n_heads, sum(self.num_points_list)
510
+ )
511
+ attention_weights = F.softmax(attention_weights, dim=-1)
512
+
513
+ if reference_points.shape[-1] == 2:
514
+ offset_normalizer = torch.tensor(spatial_shapes)
515
+ offset_normalizer = offset_normalizer.flip([1]).reshape(1, 1, 1, self.n_levels, 1, 2)
516
+ sampling_locations = (
517
+ reference_points.reshape(batch_size, sequence_length, 1, self.n_levels, 1, 2)
518
+ + sampling_offsets / offset_normalizer
519
+ )
520
+ elif reference_points.shape[-1] == 4:
521
+ # reference_points [8, 480, None, 1, 4]
522
+ # sampling_offsets [8, 480, 8, 12, 2]
523
+ num_points_scale = self.num_points_scale.to(dtype=hidden_states.dtype).unsqueeze(-1)
524
+ offset = sampling_offsets * num_points_scale * reference_points[:, :, None, :, 2:] * self.offset_scale
525
+ sampling_locations = reference_points[:, :, None, :, :2] + offset
526
+ else:
527
+ raise ValueError(
528
+ f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]} instead."
529
+ )
530
+
531
+ output = self.ms_deformable_attn_core(
532
+ value,
533
+ spatial_shapes_list,
534
+ sampling_locations,
535
+ attention_weights,
536
+ self.num_points_list,
537
+ self.decoder_method,
538
+ )
539
+
540
+ return output, attention_weights
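As an aside for readers of this diff, the per-level point bookkeeping above (`num_points_list`, the flattened `num_points_scale` buffer, and `total_points`) is easiest to see with concrete numbers. A standalone sketch with made-up values, not anything taken from a checkpoint:

```python
# Illustrative values only: 3 feature levels, decoder_n_points given as a list.
n_heads = 8
num_points_list = [4, 4, 2]                                    # points sampled per feature level
num_points_scale = [1 / n for n in num_points_list for _ in range(n)]
total_points = n_heads * sum(num_points_list)

print(total_points)      # 80 -> sampling_offsets is a Linear(d_model, 80 * 2)
print(num_points_scale)  # [0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.5]
```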
541
+
542
+
543
+ class DFineGate(nn.Module):
544
+ def __init__(self, d_model: int):
545
+ super().__init__()
546
+ self.gate = nn.Linear(2 * d_model, 2 * d_model)
547
+ self.norm = nn.LayerNorm(d_model)
548
+
549
+ def forward(self, second_residual: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
550
+ gate_input = torch.cat([second_residual, hidden_states], dim=-1)
551
+ gates = torch.sigmoid(self.gate(gate_input))
552
+ gate1, gate2 = gates.chunk(2, dim=-1)
553
+ hidden_states = self.norm(gate1 * second_residual + gate2 * hidden_states)
554
+ return hidden_states
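The gate above fuses the post-self-attention residual with the cross-attention output through two learned sigmoid gates followed by a layer norm. A minimal, self-contained sketch of that computation; `d_model`, the batch shape, and the random inputs are illustrative assumptions:

```python
import torch
import torch.nn as nn

d_model = 4
gate = nn.Linear(2 * d_model, 2 * d_model)
norm = nn.LayerNorm(d_model)

residual = torch.randn(2, 3, d_model)        # (batch, queries, d_model)
cross_attn_out = torch.randn(2, 3, d_model)

gates = torch.sigmoid(gate(torch.cat([residual, cross_attn_out], dim=-1)))
gate1, gate2 = gates.chunk(2, dim=-1)
fused = norm(gate1 * residual + gate2 * cross_attn_out)
print(fused.shape)  # torch.Size([2, 3, 4])
```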
555
+
556
+
557
+ class DFineDecoderLayer(RTDetrDecoderLayer):
558
+ def __init__(self, config: DFineConfig):
559
+ super().__init__(config)
560
+
561
+ # override the encoder attention module with d-fine version
562
+ self.encoder_attn = DFineMultiscaleDeformableAttention(config=config)
563
+ # gate
564
+ self.gateway = DFineGate(config.d_model)
565
+
566
+ del self.encoder_attn_layer_norm
567
+
568
+ def forward(
569
+ self,
570
+ hidden_states: torch.Tensor,
571
+ position_embeddings: Optional[torch.Tensor] = None,
572
+ reference_points=None,
573
+ spatial_shapes=None,
574
+ spatial_shapes_list=None,
575
+ encoder_hidden_states: Optional[torch.Tensor] = None,
576
+ encoder_attention_mask: Optional[torch.Tensor] = None,
577
+ output_attentions: Optional[bool] = False,
578
+ ) -> tuple[torch.Tensor, Any, Any]:
579
+ # Self Attention
580
+ hidden_states_2, self_attn_weights = self.self_attn(
581
+ hidden_states=hidden_states,
582
+ attention_mask=encoder_attention_mask,
583
+ position_embeddings=position_embeddings,
584
+ output_attentions=output_attentions,
585
+ )
586
+
587
+ hidden_states_2 = nn.functional.dropout(hidden_states_2, p=self.dropout, training=self.training)
588
+ hidden_states = hidden_states + hidden_states_2
589
+ hidden_states = self.self_attn_layer_norm(hidden_states)
590
+ residual = hidden_states
591
+
592
+ # Cross-Attention
593
+ cross_attn_weights = None
594
+ hidden_states = hidden_states if position_embeddings is None else hidden_states + position_embeddings
595
+ hidden_states_2, cross_attn_weights = self.encoder_attn(
596
+ hidden_states=hidden_states,
597
+ encoder_hidden_states=encoder_hidden_states,
598
+ reference_points=reference_points,
599
+ spatial_shapes=spatial_shapes,
600
+ spatial_shapes_list=spatial_shapes_list,
601
+ )
602
+
603
+ hidden_states_2 = nn.functional.dropout(hidden_states_2, p=self.dropout, training=self.training)
604
+ hidden_states = self.gateway(residual, hidden_states_2)
605
+
606
+ # Fully Connected
607
+ hidden_states_2 = self.activation_fn(self.fc1(hidden_states))
608
+ hidden_states_2 = nn.functional.dropout(hidden_states_2, p=self.activation_dropout, training=self.training)
609
+ hidden_states_2 = self.fc2(hidden_states_2)
610
+ hidden_states_2 = nn.functional.dropout(hidden_states_2, p=self.dropout, training=self.training)
611
+ hidden_states = hidden_states + hidden_states_2
612
+ hidden_states = self.final_layer_norm(hidden_states.clamp(min=-65504, max=65504))
613
+
614
+ outputs = (hidden_states,)
615
+
616
+ if output_attentions:
617
+ outputs += (self_attn_weights, cross_attn_weights)
618
+
619
+ return outputs
620
+
621
+
622
+ class DFinePreTrainedModel(RTDetrPreTrainedModel):
623
+ def _init_weights(self, module):
624
+ # initialize linear layer bias value according to a given probability value.
625
+ if isinstance(module, (DFineForObjectDetection, DFineDecoder)):
626
+ if module.class_embed is not None:
627
+ for layer in module.class_embed:
628
+ prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
629
+ bias = float(-math.log((1 - prior_prob) / prior_prob))
630
+ nn.init.xavier_uniform_(layer.weight)
631
+ nn.init.constant_(layer.bias, bias)
632
+
633
+ if module.bbox_embed is not None:
634
+ for layer in module.bbox_embed:
635
+ nn.init.constant_(layer.layers[-1].weight, 0)
636
+ nn.init.constant_(layer.layers[-1].bias, 0)
637
+
638
+ if isinstance(module, DFineMultiscaleDeformableAttention):
639
+ nn.init.constant_(module.sampling_offsets.weight.data, 0.0)
640
+ default_dtype = torch.get_default_dtype()
641
+ thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * (
642
+ 2.0 * math.pi / module.n_heads
643
+ )
644
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
645
+ grid_init = grid_init / grid_init.abs().max(-1, keepdim=True).values
646
+ grid_init = grid_init.reshape(module.n_heads, 1, 2).tile([1, sum(module.num_points_list), 1])
647
+ scaling = torch.concat([torch.arange(1, n + 1) for n in module.num_points_list]).reshape(1, -1, 1)
648
+ grid_init *= scaling
649
+ with torch.no_grad():
650
+ module.sampling_offsets.bias.data[...] = grid_init.flatten()
651
+
652
+ nn.init.constant_(module.attention_weights.weight.data, 0.0)
653
+ nn.init.constant_(module.attention_weights.bias.data, 0.0)
654
+
655
+ if isinstance(module, DFineModel):
656
+ prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
657
+ bias = float(-math.log((1 - prior_prob) / prior_prob))
658
+ nn.init.xavier_uniform_(module.enc_score_head.weight)
659
+ nn.init.constant_(module.enc_score_head.bias, bias)
660
+
661
+ if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
662
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
663
+ if module.bias is not None:
664
+ module.bias.data.zero_()
665
+
666
+ if isinstance(module, DFineGate):
667
+ bias = float(-math.log((1 - 0.5) / 0.5))
668
+ init.constant_(module.gate.bias, bias)
669
+ init.constant_(module.gate.weight, 0)
670
+
671
+ if isinstance(module, DFineLQE):
672
+ init.constant_(module.reg_conf.layers[-1].bias, 0)
673
+ init.constant_(module.reg_conf.layers[-1].weight, 0)
674
+
675
+ if hasattr(module, "weight_embedding") and self.config.learn_initial_query:
676
+ nn.init.xavier_uniform_(module.weight_embedding.weight)
677
+ if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0:
678
+ nn.init.xavier_uniform_(module.denoising_class_embed.weight)
679
+
680
+
681
+ class DFineIntegral(nn.Module):
682
+ """
683
+ A static layer that calculates integral results from a distribution.
684
+
685
+ This layer computes the target location using the formula: `sum{Pr(n) * W(n)}`,
686
+ where Pr(n) is the softmax probability vector representing the discrete
687
+ distribution, and W(n) is the non-uniform Weighting Function.
688
+
689
+ Args:
690
+ max_num_bins (int): Max number of the discrete bins. Default is 32.
691
+ It can be adjusted based on the dataset or task requirements.
692
+ """
693
+
694
+ def __init__(self, config: DFineConfig):
695
+ super().__init__()
696
+ self.max_num_bins = config.max_num_bins
697
+
698
+ def forward(self, pred_corners: torch.Tensor, project: torch.Tensor) -> torch.Tensor:
699
+ batch_size, num_queries, _ = pred_corners.shape
700
+ pred_corners = F.softmax(pred_corners.reshape(-1, self.max_num_bins + 1), dim=1)
701
+ pred_corners = F.linear(pred_corners, project.to(pred_corners.device)).reshape(-1, 4)
702
+ pred_corners = pred_corners.reshape(batch_size, num_queries, -1)
703
+ return pred_corners
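The integral head above is just the expectation of the bin values under a softmax distribution, E[W] = sum_n Pr(n) * W(n), computed independently for each of the four box edges. A standalone sketch; the bin count and the `project` vector below are stand-ins for illustration, not the model's actual weighting function:

```python
import torch
import torch.nn.functional as F

max_num_bins = 8
batch, queries = 2, 5
logits = torch.randn(batch, queries, 4 * (max_num_bins + 1))   # 4 edges x (bins + 1) logits
project = torch.linspace(-4.0, 4.0, max_num_bins + 1)          # stand-in for weighting_function output

prob = F.softmax(logits.reshape(-1, max_num_bins + 1), dim=1)  # Pr(n) for every edge
offsets = F.linear(prob, project.unsqueeze(0))                 # sum_n Pr(n) * W(n)
offsets = offsets.reshape(batch, queries, 4)                   # one scalar offset per edge
print(offsets.shape)  # torch.Size([2, 5, 4])
```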
704
+
705
+
706
+ class DFineDecoderOutput(RTDetrDecoderOutput):
707
+ pass
708
+
709
+
710
+ class DFineDecoder(RTDetrDecoder):
711
+ """
712
+ D-FINE Decoder implementing Fine-grained Distribution Refinement (FDR).
713
+
714
+ This decoder refines object detection predictions through iterative updates across multiple layers,
715
+ utilizing attention mechanisms, location quality estimators, and distribution refinement techniques
716
+ to improve bounding box accuracy and robustness.
717
+ """
718
+
719
+ def __init__(self, config: DFineConfig):
720
+ self.eval_idx = config.eval_idx if config.eval_idx >= 0 else config.decoder_layers + config.eval_idx
721
+ super().__init__(config=config)
722
+ self.reg_scale = nn.Parameter(torch.tensor([config.reg_scale]), requires_grad=False)
723
+ self.max_num_bins = config.max_num_bins
724
+ self.d_model = config.d_model
725
+ self.layer_scale = config.layer_scale
726
+ self.pre_bbox_head = DFineMLP(config.hidden_size, config.hidden_size, 4, 3)
727
+ self.integral = DFineIntegral(config)
728
+ self.num_head = config.decoder_attention_heads
729
+ self.up = nn.Parameter(torch.tensor([config.up]), requires_grad=False)
730
+ self.lqe_layers = nn.ModuleList([DFineLQE(config) for _ in range(config.decoder_layers)])
731
+ self.layers = nn.ModuleList(
732
+ [DFineDecoderLayer(config) for _ in range(config.decoder_layers)]
733
+ + [DFineDecoderLayer(config) for _ in range(config.decoder_layers - self.eval_idx - 1)]
734
+ )
735
+
736
+ def forward(
737
+ self,
738
+ encoder_hidden_states: torch.Tensor,
739
+ reference_points: torch.Tensor,
740
+ inputs_embeds: torch.Tensor,
741
+ spatial_shapes,
742
+ level_start_index=None,
743
+ spatial_shapes_list=None,
744
+ output_hidden_states=None,
745
+ encoder_attention_mask=None,
746
+ memory_mask=None,
747
+ output_attentions=None,
748
+ return_dict=None,
749
+ ) -> DFineDecoderOutput:
750
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
751
+ output_hidden_states = (
752
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
753
+ )
754
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
755
+
756
+ if inputs_embeds is not None:
757
+ hidden_states = inputs_embeds
758
+
759
+ # decoder layers
760
+ all_hidden_states = () if output_hidden_states else None
761
+ all_self_attns = () if output_attentions else None
762
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
763
+ intermediate = ()
764
+ intermediate_reference_points = ()
765
+ intermediate_logits = ()
766
+ intermediate_predicted_corners = ()
767
+ initial_reference_points = ()
768
+
769
+ output_detach = pred_corners_undetach = 0
770
+
771
+ project = weighting_function(self.max_num_bins, self.up, self.reg_scale)
772
+ ref_points_detach = F.sigmoid(reference_points)
773
+
774
+ for i, decoder_layer in enumerate(self.layers):
775
+ ref_points_input = ref_points_detach.unsqueeze(2)
776
+ query_pos_embed = self.query_pos_head(ref_points_detach).clamp(min=-10, max=10)
777
+
778
+ if output_hidden_states:
779
+ all_hidden_states += (hidden_states,)
780
+
781
+ output = decoder_layer(
782
+ hidden_states=hidden_states,
783
+ position_embeddings=query_pos_embed,
784
+ reference_points=ref_points_input,
785
+ spatial_shapes=spatial_shapes,
786
+ spatial_shapes_list=spatial_shapes_list,
787
+ encoder_hidden_states=encoder_hidden_states,
788
+ encoder_attention_mask=encoder_attention_mask,
789
+ output_attentions=output_attentions,
790
+ )
791
+
792
+ hidden_states = output[0]
793
+
794
+ if i == 0:
795
+ # Initial bounding box predictions with inverse sigmoid refinement
796
+ new_reference_points = F.sigmoid(self.pre_bbox_head(output[0]) + inverse_sigmoid(ref_points_detach))
797
+ ref_points_initial = new_reference_points.detach()
798
+
799
+ # Refine bounding box corners using FDR, integrating previous layer's corrections
800
+ if self.bbox_embed is not None:
801
+ pred_corners = self.bbox_embed[i](hidden_states + output_detach) + pred_corners_undetach
802
+ inter_ref_bbox = distance2bbox(
803
+ ref_points_initial, self.integral(pred_corners, project), self.reg_scale
804
+ )
805
+ pred_corners_undetach = pred_corners
806
+ ref_points_detach = inter_ref_bbox.detach()
807
+
808
+ output_detach = hidden_states.detach()
809
+
810
+ intermediate += (hidden_states,)
811
+
812
+ if self.class_embed is not None and (self.training or i == self.eval_idx):
813
+ scores = self.class_embed[i](hidden_states)
814
+ # Add initial logits and reference points with pre-bbox head
815
+ if i == 0:
816
+ intermediate_logits += (scores,)
817
+ intermediate_reference_points += (new_reference_points,)
818
+ # Lqe does not affect the performance here.
819
+ scores = self.lqe_layers[i](scores, pred_corners)
820
+ intermediate_logits += (scores,)
821
+ intermediate_reference_points += (inter_ref_bbox,)
822
+ initial_reference_points += (ref_points_initial,)
823
+ intermediate_predicted_corners += (pred_corners,)
824
+
825
+ if output_attentions:
826
+ all_self_attns += (output[1],)
827
+
828
+ if encoder_hidden_states is not None:
829
+ all_cross_attentions += (output[2],)
830
+
831
+ # Keep batch_size as first dimension
832
+ intermediate = torch.stack(intermediate)
833
+ if self.class_embed is not None and self.bbox_embed is not None:
834
+ intermediate_logits = torch.stack(intermediate_logits, dim=1)
835
+ intermediate_predicted_corners = torch.stack(intermediate_predicted_corners, dim=1)
836
+ initial_reference_points = torch.stack(initial_reference_points, dim=1)
837
+ intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1)
838
+
839
+ # add hidden states from the last decoder layer
840
+ if output_hidden_states:
841
+ all_hidden_states += (hidden_states,)
842
+
843
+ if not return_dict:
844
+ return tuple(
845
+ v
846
+ for v in [
847
+ hidden_states,
848
+ intermediate,
849
+ intermediate_logits,
850
+ intermediate_reference_points,
851
+ intermediate_predicted_corners,
852
+ initial_reference_points,
853
+ all_hidden_states,
854
+ all_self_attns,
855
+ all_cross_attentions,
856
+ ]
857
+ if v is not None
858
+ )
859
+
860
+ return DFineDecoderOutput(
861
+ last_hidden_state=hidden_states,
862
+ intermediate_hidden_states=intermediate,
863
+ intermediate_logits=intermediate_logits,
864
+ intermediate_reference_points=intermediate_reference_points,
865
+ intermediate_predicted_corners=intermediate_predicted_corners,
866
+ initial_reference_points=initial_reference_points,
867
+ hidden_states=all_hidden_states,
868
+ attentions=all_self_attns,
869
+ cross_attentions=all_cross_attentions,
870
+ )
871
+
872
+
873
+ class DFineModel(RTDetrModel):
874
+ def __init__(self, config: DFineConfig):
875
+ super().__init__(config)
876
+ del self.decoder_input_proj
877
+ self.encoder = DFineHybridEncoder(config=config)
878
+ num_backbone_outs = len(config.decoder_in_channels)
879
+ decoder_input_proj = []
880
+ in_channels = config.decoder_in_channels[-1]
881
+ for _ in range(num_backbone_outs):
882
+ if config.hidden_size == config.decoder_in_channels[-1]:
883
+ decoder_input_proj.append(nn.Identity())
884
+ else:
885
+ conv = nn.Conv2d(in_channels, config.d_model, kernel_size=1, bias=False)
886
+ batchnorm = nn.BatchNorm2d(config.d_model, config.batch_norm_eps)
887
+ decoder_input_proj.append(nn.Sequential(conv, batchnorm))
888
+ for _ in range(config.num_feature_levels - num_backbone_outs):
889
+ if config.hidden_size == config.decoder_in_channels[-1]:
890
+ decoder_input_proj.append(nn.Identity())
891
+ else:
892
+ conv = nn.Conv2d(in_channels, config.d_model, kernel_size=3, stride=2, padding=1, bias=False)
893
+ batchnorm = nn.BatchNorm2d(config.d_model, config.batch_norm_eps)
894
+ decoder_input_proj.append(nn.Sequential(conv, batchnorm))
895
+ self.decoder_input_proj = nn.ModuleList(decoder_input_proj)
896
+ self.decoder = DFineDecoder(config)
897
+
898
+
899
+ class DFineForObjectDetection(RTDetrForObjectDetection, DFinePreTrainedModel):
900
+ def __init__(self, config: DFineConfig):
901
+ DFinePreTrainedModel.__init__(self, config)
902
+
903
+ # D-FINE encoder-decoder model
904
+ self.eval_idx = config.eval_idx if config.eval_idx >= 0 else config.decoder_layers + config.eval_idx
905
+ self.model = DFineModel(config)
906
+ scaled_dim = round(config.layer_scale * config.hidden_size)
907
+ num_pred = config.decoder_layers
908
+ self.class_embed = nn.ModuleList([nn.Linear(config.d_model, config.num_labels) for _ in range(num_pred)])
909
+ self.bbox_embed = nn.ModuleList(
910
+ [
911
+ DFineMLP(config.hidden_size, config.hidden_size, 4 * (config.max_num_bins + 1), 3)
912
+ for _ in range(self.eval_idx + 1)
913
+ ]
914
+ + [
915
+ DFineMLP(scaled_dim, scaled_dim, 4 * (config.max_num_bins + 1), 3)
916
+ for _ in range(config.decoder_layers - self.eval_idx - 1)
917
+ ]
918
+ )
919
+
920
+ # here self.model.decoder.bbox_embed is None, but not self.bbox_embed
921
+ self.model.decoder.class_embed = self.class_embed
922
+ self.model.decoder.bbox_embed = self.bbox_embed
923
+
924
+ # Initialize weights and apply final processing
925
+ self.post_init()
926
+
927
+ def forward(self, **super_kwargs):
928
+ r"""
929
+ Example:
930
+
931
+ ```python
932
+ >>> import torch
933
+ >>> from transformers.image_utils import load_image
934
+ >>> from transformers import AutoImageProcessor, DFineForObjectDetection
935
+
936
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
937
+ >>> image = load_image(url)
938
+
939
+ >>> image_processor = AutoImageProcessor.from_pretrained("ustc-community/dfine-xlarge-coco")
940
+ >>> model = DFineForObjectDetection.from_pretrained("ustc-community/dfine-xlarge-coco")
941
+
942
+ >>> # prepare image for the model
943
+ >>> inputs = image_processor(images=image, return_tensors="pt")
944
+
945
+ >>> # forward pass
946
+ >>> outputs = model(**inputs)
947
+
948
+ >>> logits = outputs.logits
949
+ >>> list(logits.shape)
950
+ [1, 300, 80]
951
+
952
+ >>> boxes = outputs.pred_boxes
953
+ >>> list(boxes.shape)
954
+ [1, 300, 4]
955
+
956
+ >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
957
+ >>> target_sizes = torch.tensor([image.size[::-1]])
958
+ >>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)
959
+ >>> result = results[0] # first image in batch
960
+
961
+ >>> for score, label, box in zip(result["scores"], result["labels"], result["boxes"]):
962
+ ... box = [round(i, 2) for i in box.tolist()]
963
+ ... print(
964
+ ... f"Detected {model.config.id2label[label.item()]} with confidence "
965
+ ... f"{round(score.item(), 3)} at location {box}"
966
+ ... )
967
+ Detected cat with confidence 0.958 at location [344.49, 23.4, 639.84, 374.27]
968
+ Detected cat with confidence 0.956 at location [11.71, 53.52, 316.64, 472.33]
969
+ Detected remote with confidence 0.947 at location [40.46, 73.7, 175.62, 117.57]
970
+ Detected sofa with confidence 0.918 at location [0.59, 1.88, 640.25, 474.74]
971
+ ```
972
+ """
973
+ super().forward(**super_kwargs)
974
+
975
+
976
+ def weighting_function(max_num_bins: int, up: torch.Tensor, reg_scale: int) -> torch.Tensor:
977
+ """
978
+ Generates the non-uniform Weighting Function W(n) for bounding box regression.
979
+
980
+ Args:
981
+ max_num_bins (int): Max number of the discrete bins.
982
+ up (Tensor): Controls upper bounds of the sequence,
983
+ where maximum offset is ±up * H / W.
984
+ reg_scale (float): Controls the curvature of the Weighting Function.
985
+ Larger values result in flatter weights near the central axis W(max_num_bins/2)=0
986
+ and steeper weights at both ends.
987
+ Returns:
988
+ Tensor: Sequence of Weighting Function.
989
+ """
990
+ upper_bound1 = abs(up[0]) * abs(reg_scale)
991
+ upper_bound2 = abs(up[0]) * abs(reg_scale) * 2
992
+ step = (upper_bound1 + 1) ** (2 / (max_num_bins - 2))
993
+ left_values = [-((step) ** i) + 1 for i in range(max_num_bins // 2 - 1, 0, -1)]
994
+ right_values = [(step) ** i - 1 for i in range(1, max_num_bins // 2)]
995
+ values = [-upper_bound2] + left_values + [torch.zeros_like(up[0][None])] + right_values + [upper_bound2]
996
+ values = torch.cat(values, 0)
997
+ return values
998
+
999
+
1000
+ class DFineMLPPredictionHead(RTDetrMLPPredictionHead):
1001
+ pass
1002
+
1003
+
1004
+ def distance2bbox(points, distance: torch.Tensor, reg_scale: float) -> torch.Tensor:
1005
+ """
1006
+ Decodes edge-distances into bounding box coordinates.
1007
+
1008
+ Args:
1009
+ points (`torch.Tensor`):
1010
+ (batch_size, num_boxes, 4) or (num_boxes, 4) format, representing [x_center, y_center, width, height]
1011
+ distance (`torch.Tensor`):
1012
+ (batch_size, num_boxes, 4) or (num_boxes, 4), representing distances from the point to the left, top, right, and bottom boundaries.
1013
+ reg_scale (`float`):
1014
+ Controls the curvature of the Weighting Function.
1015
+ Returns:
1016
+ `torch.Tensor`: Bounding boxes in (batch_size, num_boxes, 4) or (num_boxes, 4) format, representing [x_center, y_center, width, height]
1017
+ """
1018
+ reg_scale = abs(reg_scale)
1019
+ top_left_x = points[..., 0] - (0.5 * reg_scale + distance[..., 0]) * (points[..., 2] / reg_scale)
1020
+ top_left_y = points[..., 1] - (0.5 * reg_scale + distance[..., 1]) * (points[..., 3] / reg_scale)
1021
+ bottom_right_x = points[..., 0] + (0.5 * reg_scale + distance[..., 2]) * (points[..., 2] / reg_scale)
1022
+ bottom_right_y = points[..., 1] + (0.5 * reg_scale + distance[..., 3]) * (points[..., 3] / reg_scale)
1023
+
1024
+ bboxes = torch.stack([top_left_x, top_left_y, bottom_right_x, bottom_right_y], -1)
1025
+
1026
+ return corners_to_center_format(bboxes)
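A sanity check worth spelling out: with zero predicted distances, the decoding above should return the reference box unchanged. The toy values below (and the inline corners-to-center conversion, written out instead of calling `corners_to_center_format` so the snippet stays self-contained) are purely illustrative:

```python
import torch

reg_scale = 4.0
points = torch.tensor([[0.5, 0.5, 0.2, 0.2]])   # [x_center, y_center, width, height]
distance = torch.zeros(1, 4)                     # no refinement from the distribution head

half = 0.5 * reg_scale
x1 = points[..., 0] - (half + distance[..., 0]) * (points[..., 2] / reg_scale)
y1 = points[..., 1] - (half + distance[..., 1]) * (points[..., 3] / reg_scale)
x2 = points[..., 0] + (half + distance[..., 2]) * (points[..., 2] / reg_scale)
y2 = points[..., 1] + (half + distance[..., 3]) * (points[..., 3] / reg_scale)
corners = torch.stack([x1, y1, x2, y2], -1)      # tensor([[0.4, 0.4, 0.6, 0.6]])

# corners -> center format, mirroring corners_to_center_format
center = torch.cat([(corners[..., :2] + corners[..., 2:]) / 2,
                    corners[..., 2:] - corners[..., :2]], dim=-1)
print(center)  # tensor([[0.5000, 0.5000, 0.2000, 0.2000]]) -- the original box
```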
1027
+
1028
+
1029
+ class DFineMLP(nn.Module):
1030
+ def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, act: str = "relu"):
1031
+ super().__init__()
1032
+ self.num_layers = num_layers
1033
+ hidden_dims = [hidden_dim] * (num_layers - 1)
1034
+ input_dims = [input_dim] + hidden_dims
1035
+ output_dims = hidden_dims + [output_dim]
1036
+ self.layers = nn.ModuleList(nn.Linear(in_dim, out_dim) for in_dim, out_dim in zip(input_dims, output_dims))
1037
+ self.act = ACT2CLS[act]()
1038
+
1039
+ def forward(self, stat_features: torch.Tensor) -> torch.Tensor:
1040
+ for i, layer in enumerate(self.layers):
1041
+ stat_features = self.act(layer(stat_features)) if i < self.num_layers - 1 else layer(stat_features)
1042
+ return stat_features
1043
+
1044
+
1045
+ class DFineLQE(nn.Module):
1046
+ def __init__(self, config: DFineConfig):
1047
+ super().__init__()
1048
+ self.top_prob_values = config.top_prob_values
1049
+ self.max_num_bins = config.max_num_bins
1050
+ self.reg_conf = DFineMLP(4 * (self.top_prob_values + 1), config.lqe_hidden_dim, 1, config.lqe_layers)
1051
+
1052
+ def forward(self, scores: torch.Tensor, pred_corners: torch.Tensor) -> torch.Tensor:
1053
+ batch_size, length, _ = pred_corners.size()
1054
+ prob = F.softmax(pred_corners.reshape(batch_size, length, 4, self.max_num_bins + 1), dim=-1)
1055
+ prob_topk, _ = prob.topk(self.top_prob_values, dim=-1)
1056
+ stat = torch.cat([prob_topk, prob_topk.mean(dim=-1, keepdim=True)], dim=-1)
1057
+ quality_score = self.reg_conf(stat.reshape(batch_size, length, -1))
1058
+ scores = scores + quality_score
1059
+ return scores
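The statistics fed to `reg_conf` are the top-k bin probabilities per box edge plus their mean, which is why its input width is `4 * (top_prob_values + 1)`. A shape-only sketch with illustrative sizes:

```python
import torch
import torch.nn.functional as F

max_num_bins, top_prob_values = 8, 4                      # illustrative sizes only
pred_corners = torch.randn(2, 5, 4 * (max_num_bins + 1))  # (batch, queries, 4 * (bins + 1))

prob = F.softmax(pred_corners.reshape(2, 5, 4, max_num_bins + 1), dim=-1)
prob_topk, _ = prob.topk(top_prob_values, dim=-1)         # (2, 5, 4, top_prob_values)
stat = torch.cat([prob_topk, prob_topk.mean(dim=-1, keepdim=True)], dim=-1)

print(stat.reshape(2, 5, -1).shape)  # torch.Size([2, 5, 20]) == 4 * (top_prob_values + 1)
```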
1060
+
1061
+
1062
+ class DFineConvNormLayer(RTDetrConvNormLayer):
1063
+ def __init__(
1064
+ self,
1065
+ config: DFineConfig,
1066
+ in_channels: int,
1067
+ out_channels: int,
1068
+ kernel_size: int,
1069
+ stride: int,
1070
+ groups: int = 1,
1071
+ padding: Optional[int] = None,
1072
+ activation: Optional[str] = None,
1073
+ ):
1074
+ super().__init__(config, in_channels, out_channels, kernel_size, stride, padding=None, activation=activation)
1075
+ self.conv = nn.Conv2d(
1076
+ in_channels,
1077
+ out_channels,
1078
+ kernel_size,
1079
+ stride,
1080
+ groups=groups,
1081
+ padding=(kernel_size - 1) // 2 if padding is None else padding,
1082
+ bias=False,
1083
+ )
1084
+
1085
+
1086
+ class DFineRepVggBlock(RTDetrRepVggBlock):
1087
+ def __init__(self, config: DFineConfig, in_channels: int, out_channels: int):
1088
+ super().__init__(config)
1089
+ hidden_channels = in_channels
1090
+ self.conv1 = DFineConvNormLayer(config, hidden_channels, out_channels, 3, 1, padding=1)
1091
+ self.conv2 = DFineConvNormLayer(config, hidden_channels, out_channels, 1, 1, padding=0)
1092
+
1093
+
1094
+ class DFineCSPRepLayer(nn.Module):
1095
+ """
1096
+ Cross Stage Partial (CSP) network layer with RepVGG blocks.
1097
+ """
1098
+
1099
+ def __init__(
1100
+ self, config: DFineConfig, in_channels: int, out_channels: int, num_blocks: int, expansion: float = 1.0
1101
+ ):
1102
+ super().__init__()
1103
+ in_channels = in_channels
1104
+ out_channels = out_channels
1105
+ activation = config.activation_function
1106
+
1107
+ hidden_channels = int(out_channels * expansion)
1108
+ self.conv1 = DFineConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation)
1109
+ self.conv2 = DFineConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation)
1110
+ self.bottlenecks = nn.ModuleList(
1111
+ [DFineRepVggBlock(config, hidden_channels, hidden_channels) for _ in range(num_blocks)]
1112
+ )
1113
+ if hidden_channels != out_channels:
1114
+ self.conv3 = DFineConvNormLayer(config, hidden_channels, out_channels, 1, 1, activation=activation)
1115
+ else:
1116
+ self.conv3 = nn.Identity()
1117
+
1118
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
1119
+ hidden_state_1 = self.conv1(hidden_state)
1120
+ for bottleneck in self.bottlenecks:
1121
+ hidden_state_1 = bottleneck(hidden_state_1)
1122
+ hidden_state_2 = self.conv2(hidden_state)
1123
+ hidden_state_3 = self.conv3(hidden_state_1 + hidden_state_2)
1124
+ return hidden_state_3
1125
+
1126
+
1127
+ class DFineRepNCSPELAN4(nn.Module):
1128
+ def __init__(self, config: DFineConfig, act: str = "silu", numb_blocks: int = 3):
1129
+ super().__init__()
1130
+ conv1_dim = config.encoder_hidden_dim * 2
1131
+ conv2_dim = config.encoder_hidden_dim
1132
+ conv3_dim = config.encoder_hidden_dim * 2
1133
+ conv4_dim = round(config.hidden_expansion * config.encoder_hidden_dim // 2)
1134
+ self.conv_dim = conv3_dim // 2
1135
+ self.conv1 = DFineConvNormLayer(config, conv1_dim, conv3_dim, 1, 1, activation=act)
1136
+ self.csp_rep1 = DFineCSPRepLayer(config, conv3_dim // 2, conv4_dim, num_blocks=numb_blocks)
1137
+ self.conv2 = DFineConvNormLayer(config, conv4_dim, conv4_dim, 3, 1, activation=act)
1138
+ self.csp_rep2 = DFineCSPRepLayer(config, conv4_dim, conv4_dim, num_blocks=numb_blocks)
1139
+ self.conv3 = DFineConvNormLayer(config, conv4_dim, conv4_dim, 3, 1, activation=act)
1140
+ self.conv4 = DFineConvNormLayer(config, conv3_dim + (2 * conv4_dim), conv2_dim, 1, 1, activation=act)
1141
+
1142
+ def forward(self, input_features: torch.Tensor) -> torch.Tensor:
1143
+ # Split initial features into two branches after first convolution
1144
+ split_features = list(self.conv1(input_features).split((self.conv_dim, self.conv_dim), 1))
1145
+
1146
+ # Process branches sequentially
1147
+ branch1 = self.csp_rep1(split_features[-1])
1148
+ branch1 = self.conv2(branch1)
1149
+ branch2 = self.csp_rep2(branch1)
1150
+ branch2 = self.conv3(branch2)
1151
+
1152
+ split_features.extend([branch1, branch2])
1153
+ merged_features = torch.cat(split_features, 1)
1154
+ merged_features = self.conv4(merged_features)
1155
+ return merged_features
1156
+
1157
+
1158
+ class DFineSCDown(nn.Module):
1159
+ def __init__(self, config: DFineConfig, kernel_size: int, stride: int):
1160
+ super().__init__()
1161
+ self.conv1 = DFineConvNormLayer(config, config.encoder_hidden_dim, config.encoder_hidden_dim, 1, 1)
1162
+ self.conv2 = DFineConvNormLayer(
1163
+ config,
1164
+ config.encoder_hidden_dim,
1165
+ config.encoder_hidden_dim,
1166
+ kernel_size,
1167
+ stride,
1168
+ config.encoder_hidden_dim,
1169
+ )
1170
+
1171
+ def forward(self, input_features: torch.Tensor) -> torch.Tensor:
1172
+ input_features = self.conv1(input_features)
1173
+ input_features = self.conv2(input_features)
1174
+ return input_features
1175
+
1176
+
1177
+ class DFineEncoder(RTDetrEncoder):
1178
+ pass
1179
+
1180
+
1181
+ class DFineHybridEncoder(RTDetrHybridEncoder):
1182
+ def __init__(self, config: DFineConfig):
1183
+ nn.Module.__init__(self)
1184
+ self.config = config
1185
+ self.in_channels = config.encoder_in_channels
1186
+ self.num_fpn_stages = len(self.in_channels) - 1
1187
+ self.feat_strides = config.feat_strides
1188
+ self.encoder_hidden_dim = config.encoder_hidden_dim
1189
+ self.encode_proj_layers = config.encode_proj_layers
1190
+ self.positional_encoding_temperature = config.positional_encoding_temperature
1191
+ self.eval_size = config.eval_size
1192
+ self.out_channels = [self.encoder_hidden_dim for _ in self.in_channels]
1193
+ self.out_strides = self.feat_strides
1194
+
1195
+ # encoder transformer
1196
+ self.encoder = nn.ModuleList([DFineEncoder(config) for _ in range(len(self.encode_proj_layers))])
1197
+ # top-down fpn
1198
+ self.lateral_convs = nn.ModuleList()
1199
+ self.fpn_blocks = nn.ModuleList()
1200
+ for _ in range(len(self.in_channels) - 1, 0, -1):
1201
+ lateral_layer = DFineConvNormLayer(config, self.encoder_hidden_dim, self.encoder_hidden_dim, 1, 1)
1202
+ self.lateral_convs.append(lateral_layer)
1203
+ num_blocks = round(3 * config.depth_mult)
1204
+ fpn_layer = DFineRepNCSPELAN4(config, numb_blocks=num_blocks)
1205
+ self.fpn_blocks.append(fpn_layer)
1206
+
1207
+ # bottom-up pan
1208
+ self.downsample_convs = nn.ModuleList()
1209
+ self.pan_blocks = nn.ModuleList()
1210
+ for _ in range(len(self.in_channels) - 1):
1211
+ self.downsample_convs.append(DFineSCDown(config, 3, 2))
1212
+ num_blocks = round(3 * config.depth_mult)
1213
+ self.pan_blocks.append(DFineRepNCSPELAN4(config, numb_blocks=num_blocks))
1214
+
1215
+
1216
+ __all__ = [
1217
+ "DFineConfig",
1218
+ "DFineModel",
1219
+ "DFinePreTrainedModel",
1220
+ "DFineForObjectDetection",
1221
+ ]
phivenv/Lib/site-packages/transformers/models/depth_pro/__init__.py ADDED
@@ -0,0 +1,29 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_depth_pro import *
22
+ from .image_processing_depth_pro import *
23
+ from .image_processing_depth_pro_fast import *
24
+ from .modeling_depth_pro import *
25
+ else:
26
+ import sys
27
+
28
+ _file = globals()["__file__"]
29
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
phivenv/Lib/site-packages/transformers/models/depth_pro/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (618 Bytes).
 
phivenv/Lib/site-packages/transformers/models/depth_pro/__pycache__/configuration_depth_pro.cpython-39.pyc ADDED
Binary file (7.61 kB).
 
phivenv/Lib/site-packages/transformers/models/depth_pro/__pycache__/image_processing_depth_pro.cpython-39.pyc ADDED
Binary file (14.6 kB).
 
phivenv/Lib/site-packages/transformers/models/depth_pro/__pycache__/image_processing_depth_pro_fast.cpython-39.pyc ADDED
Binary file (4.79 kB).
 
phivenv/Lib/site-packages/transformers/models/depth_pro/__pycache__/modeling_depth_pro.cpython-39.pyc ADDED
Binary file (27.7 kB).
 
phivenv/Lib/site-packages/transformers/models/depth_pro/configuration_depth_pro.py ADDED
@@ -0,0 +1,205 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """DepthPro model configuration"""
16
+
17
+ from copy import deepcopy
18
+
19
+ from ...configuration_utils import PretrainedConfig
20
+ from ...utils import logging
21
+ from ..auto.configuration_auto import CONFIG_MAPPING, AutoConfig
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
+ class DepthProConfig(PretrainedConfig):
28
+ r"""
29
+ This is the configuration class to store the configuration of a [`DepthProModel`]. It is used to instantiate a
30
+ DepthPro model according to the specified arguments, defining the model architecture. Instantiating a configuration
31
+ with the defaults will yield a similar configuration to that of the DepthPro
32
+ [apple/DepthPro](https://huggingface.co/apple/DepthPro) architecture.
33
+
34
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
35
+ documentation from [`PretrainedConfig`] for more information.
36
+
37
+ Args:
38
+ fusion_hidden_size (`int`, *optional*, defaults to 256):
39
+ The number of channels before fusion.
40
+ patch_size (`int`, *optional*, defaults to 384):
41
+ The size (resolution) of each patch. This is also the image_size for backbone model.
42
+ initializer_range (`float`, *optional*, defaults to 0.02):
43
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
44
+ intermediate_hook_ids (`list[int]`, *optional*, defaults to `[11, 5]`):
45
+ Indices of the intermediate hidden states from the patch encoder to use for fusion.
46
+ intermediate_feature_dims (`list[int]`, *optional*, defaults to `[256, 256]`):
47
+ Hidden state dimensions during upsampling for each intermediate hidden state in `intermediate_hook_ids`.
48
+ scaled_images_ratios (`list[float]`, *optional*, defaults to `[0.25, 0.5, 1]`):
49
+ Ratios of scaled images to be used by the patch encoder.
50
+ scaled_images_overlap_ratios (`list[float]`, *optional*, defaults to `[0.0, 0.5, 0.25]`):
51
+ Overlap ratios between patches for each scaled image in `scaled_images_ratios`.
52
+ scaled_images_feature_dims (`list[int]`, *optional*, defaults to `[1024, 1024, 512]`):
53
+ Hidden state dimensions during upsampling for each scaled image in `scaled_images_ratios`.
54
+ merge_padding_value (`int`, *optional*, defaults to 3):
55
+ When merging smaller patches back to the image size, overlapping sections of this size are removed.
56
+ use_batch_norm_in_fusion_residual (`bool`, *optional*, defaults to `False`):
57
+ Whether to use batch normalization in the pre-activate residual units of the fusion blocks.
58
+ use_bias_in_fusion_residual (`bool`, *optional*, defaults to `True`):
59
+ Whether to use bias in the pre-activate residual units of the fusion blocks.
60
+ use_fov_model (`bool`, *optional*, defaults to `False`):
61
+ Whether to use `DepthProFovModel` to generate the field of view.
62
+ num_fov_head_layers (`int`, *optional*, defaults to 2):
63
+ Number of convolution layers in the head of `DepthProFovModel`.
64
+ image_model_config (`Union[dict[str, Any], PretrainedConfig]`, *optional*):
65
+ The configuration of the image encoder model, which is loaded using the [`AutoModel`] API.
66
+ By default, Dinov2 model is used as backbone.
67
+ patch_model_config (`Union[dict[str, Any], PretrainedConfig]`, *optional*):
68
+ The configuration of the patch encoder model, which is loaded using the [`AutoModel`] API.
69
+ By default, Dinov2 model is used as backbone.
70
+ fov_model_config (`Union[dict[str, Any], PretrainedConfig]`, *optional*):
71
+ The configuration of the fov encoder model, which is loaded using the [`AutoModel`] API.
72
+ By default, Dinov2 model is used as backbone.
73
+
74
+ Example:
75
+
76
+ ```python
77
+ >>> from transformers import DepthProConfig, DepthProModel
78
+
79
+ >>> # Initializing a DepthPro apple/DepthPro style configuration
80
+ >>> configuration = DepthProConfig()
81
+
82
+ >>> # Initializing a model (with random weights) from the apple/DepthPro style configuration
83
+ >>> model = DepthProModel(configuration)
84
+
85
+ >>> # Accessing the model configuration
86
+ >>> configuration = model.config
87
+ ```"""
88
+
89
+ model_type = "depth_pro"
90
+ sub_configs = {"image_model_config": AutoConfig, "patch_model_config": AutoConfig, "fov_model_config": AutoConfig}
91
+
92
+ def __init__(
93
+ self,
94
+ fusion_hidden_size=256,
95
+ patch_size=384,
96
+ initializer_range=0.02,
97
+ intermediate_hook_ids=[11, 5],
98
+ intermediate_feature_dims=[256, 256],
99
+ scaled_images_ratios=[0.25, 0.5, 1],
100
+ scaled_images_overlap_ratios=[0.0, 0.5, 0.25],
101
+ scaled_images_feature_dims=[1024, 1024, 512],
102
+ merge_padding_value=3,
103
+ use_batch_norm_in_fusion_residual=False,
104
+ use_bias_in_fusion_residual=True,
105
+ use_fov_model=False,
106
+ num_fov_head_layers=2,
107
+ image_model_config=None,
108
+ patch_model_config=None,
109
+ fov_model_config=None,
110
+ **kwargs,
111
+ ):
112
+ super().__init__(**kwargs)
113
+
114
+ # scaled_images_ratios is sorted
115
+ if scaled_images_ratios != sorted(scaled_images_ratios):
116
+ raise ValueError(
117
+ f"Values in scaled_images_ratios={scaled_images_ratios} should be sorted from low to high"
118
+ )
119
+
120
+ # scaled_images_ratios, scaled_images_overlap_ratios, scaled_images_feature_dims should be consistent
121
+ if not (len(scaled_images_ratios) == len(scaled_images_overlap_ratios) == len(scaled_images_feature_dims)):
122
+ raise ValueError(
123
+ f"len(scaled_images_ratios)={len(scaled_images_ratios)} and "
124
+ f"len(scaled_images_overlap_ratios)={len(scaled_images_overlap_ratios)} and "
125
+ f"len(scaled_images_feature_dims)={len(scaled_images_feature_dims)}, "
126
+ f"should match in config."
127
+ )
128
+
129
+ # intermediate_hook_ids, intermediate_feature_dims should be consistent
130
+ if not (len(intermediate_hook_ids) == len(intermediate_feature_dims)):
131
+ raise ValueError(
132
+ f"len(intermediate_hook_ids)={len(intermediate_hook_ids)} and "
133
+ f"len(intermediate_feature_dims)={len(intermediate_feature_dims)}, "
134
+ f"should match in config."
135
+ )
136
+
137
+ # fusion_hidden_size should be consistent with num_fov_head_layers
138
+ if fusion_hidden_size // 2**num_fov_head_layers == 0:
139
+ raise ValueError(
140
+ f"fusion_hidden_size={fusion_hidden_size} should be consistent with num_fov_head_layers={num_fov_head_layers} "
141
+ "i.e fusion_hidden_size // 2**num_fov_head_layers > 0"
142
+ )
143
+
144
+ self.fusion_hidden_size = fusion_hidden_size
145
+ self.patch_size = patch_size
146
+ self.initializer_range = initializer_range
147
+ self.use_batch_norm_in_fusion_residual = use_batch_norm_in_fusion_residual
148
+ self.use_bias_in_fusion_residual = use_bias_in_fusion_residual
149
+ self.use_fov_model = use_fov_model
150
+ self.num_fov_head_layers = num_fov_head_layers
151
+ self.intermediate_hook_ids = intermediate_hook_ids
152
+ self.intermediate_feature_dims = intermediate_feature_dims
153
+ self.scaled_images_ratios = scaled_images_ratios
154
+ self.scaled_images_overlap_ratios = scaled_images_overlap_ratios
155
+ self.scaled_images_feature_dims = scaled_images_feature_dims
156
+ self.merge_padding_value = merge_padding_value
157
+ self.image_model_config = image_model_config
158
+ self.patch_model_config = patch_model_config
159
+ self.fov_model_config = fov_model_config
160
+
161
+ for sub_config_key in self.sub_configs:
162
+ sub_config = getattr(self, sub_config_key)
163
+
164
+ if sub_config is None:
165
+ sub_config = CONFIG_MAPPING["dinov2"](image_size=patch_size)
166
+ logger.info(
167
+ f"`{sub_config_key}` is `None`. Initializing `{sub_config_key}` with the `Dinov2Config` "
168
+ f"with default values except `{sub_config_key}.image_size` is set to `config.patch_size`."
169
+ )
170
+ elif isinstance(sub_config, dict):
171
+ sub_config = deepcopy(sub_config)
172
+ if "model_type" not in sub_config:
173
+ raise KeyError(
174
+ f"The `model_type` key is missing in the `{sub_config_key}` dictionary. Please provide the model type."
175
+ )
176
+ elif sub_config["model_type"] not in CONFIG_MAPPING:
177
+ raise ValueError(
178
+ f"The model type `{sub_config['model_type']}` in `{sub_config_key}` is not supported. Please provide a valid model type."
179
+ )
180
+ image_size = sub_config.get("image_size")
181
+ if image_size != patch_size:
182
+ logger.info(
183
+ f"The `image_size` in `{sub_config_key}` is set to `{image_size}`, "
184
+ f"but it does not match the required `patch_size` of `{patch_size}`. "
185
+ f"Updating `image_size` to `{patch_size}` for consistency. "
186
+ f"Ensure that `image_size` aligns with `patch_size` in the configuration."
187
+ )
188
+ sub_config.update({"image_size": patch_size})
189
+ sub_config = CONFIG_MAPPING[sub_config["model_type"]](**sub_config)
190
+ elif isinstance(sub_config, PretrainedConfig):
191
+ sub_config = sub_config
192
+ image_size = getattr(sub_config, "image_size", None)
193
+ if image_size != patch_size:
194
+ raise ValueError(
195
+ f"`config.{sub_config_key}.image_size={image_size}` should match `config.patch_size={patch_size}`."
196
+ )
197
+ else:
198
+ raise TypeError(
199
+ f"Invalid type for `sub_config`. Expected `PretrainedConfig`, `dict`, or `None`, but got {type(sub_config)}."
200
+ )
201
+
202
+ setattr(self, sub_config_key, sub_config)
203
+
204
+
205
+ __all__ = ["DepthProConfig"]
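A small usage sketch of the consistency rules enforced in the constructor above: the three `scaled_images_*` lists must have equal length, the ratios must be sorted low to high, and any missing sub-config falls back to a Dinov2 backbone. The values below are illustrative, not a recommended configuration:

```python
from transformers import DepthProConfig

config = DepthProConfig(
    scaled_images_ratios=[0.5, 1.0],             # must be sorted low -> high
    scaled_images_overlap_ratios=[0.25, 0.0],
    scaled_images_feature_dims=[1024, 512],      # same length as the two lists above
)
print(type(config.image_model_config).__name__)  # Dinov2Config, created with image_size=384
```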
phivenv/Lib/site-packages/transformers/models/depth_pro/image_processing_depth_pro.py ADDED
@@ -0,0 +1,389 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Image processor class for DepthPro."""
16
+
17
+ from typing import TYPE_CHECKING, Optional, Union
18
+
19
+ import numpy as np
20
+
21
+ from ...utils.import_utils import requires
22
+
23
+
24
+ if TYPE_CHECKING:
25
+ from .modeling_depth_pro import DepthProDepthEstimatorOutput
26
+
27
+ from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
28
+ from ...image_transforms import to_channel_dimension_format
29
+ from ...image_utils import (
30
+ IMAGENET_STANDARD_MEAN,
31
+ IMAGENET_STANDARD_STD,
32
+ ChannelDimension,
33
+ ImageInput,
34
+ PILImageResampling,
35
+ infer_channel_dimension_format,
36
+ is_scaled_image,
37
+ is_torch_available,
38
+ make_list_of_images,
39
+ to_numpy_array,
40
+ valid_images,
41
+ )
42
+ from ...utils import TensorType, filter_out_non_signature_kwargs, is_torchvision_available, logging, requires_backends
43
+
44
+
45
+ if is_torch_available():
46
+ import torch
47
+
48
+ if is_torchvision_available():
49
+ from ...image_utils import pil_torch_interpolation_mapping
50
+
51
+
52
+ logger = logging.get_logger(__name__)
53
+
54
+
55
+ @requires(backends=("torchvision", "torch"))
56
+ class DepthProImageProcessor(BaseImageProcessor):
57
+ r"""
58
+ Constructs a DepthPro image processor.
59
+
60
+ Args:
61
+ do_resize (`bool`, *optional*, defaults to `True`):
62
+ Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
63
+ size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
64
+ size (`dict`, *optional*, defaults to `{"height": 1536, "width": 1536}`):
65
+ Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
66
+ method.
67
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
68
+ Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
69
+ `preprocess` method.
70
+ do_rescale (`bool`, *optional*, defaults to `True`):
71
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
72
+ parameter in the `preprocess` method.
73
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
74
+ Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
75
+ `preprocess` method.
76
+ do_normalize (`bool`, *optional*, defaults to `True`):
77
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
78
+ method.
79
+ image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
80
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
81
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
82
+ image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
83
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
84
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
85
+ """
86
+
87
+ model_input_names = ["pixel_values"]
88
+
89
+ def __init__(
90
+ self,
91
+ do_resize: bool = True,
92
+ size: Optional[dict[str, int]] = None,
93
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
94
+ do_rescale: bool = True,
95
+ rescale_factor: Union[int, float] = 1 / 255,
96
+ do_normalize: bool = True,
97
+ image_mean: Optional[Union[float, list[float]]] = None,
98
+ image_std: Optional[Union[float, list[float]]] = None,
99
+ **kwargs,
100
+ ):
101
+ super().__init__(**kwargs)
102
+ size = size if size is not None else {"height": 1536, "width": 1536}
103
+ size = get_size_dict(size)
104
+ self.do_resize = do_resize
105
+ self.do_rescale = do_rescale
106
+ self.do_normalize = do_normalize
107
+ self.size = size
108
+ self.resample = resample
109
+ self.rescale_factor = rescale_factor
110
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
111
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
112
+
113
+ def resize(
114
+ self,
115
+ image: np.ndarray,
116
+ size: dict[str, int],
117
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
118
+ data_format: Optional[Union[str, ChannelDimension]] = None,
119
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
120
+ **kwargs,
121
+ ) -> np.ndarray:
122
+ """
123
+ Resize an image to `(size["height"], size["width"])`.
124
+
125
+ Args:
126
+ image (`np.ndarray`):
127
+ Image to resize.
128
+ size (`dict[str, int]`):
129
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
130
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
131
+ `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
132
+ data_format (`ChannelDimension` or `str`, *optional*):
133
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
134
+ image is used. Can be one of:
135
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
136
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
137
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
138
+ input_data_format (`ChannelDimension` or `str`, *optional*):
139
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
140
+ from the input image. Can be one of:
141
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
142
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
143
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
144
+
145
+ Returns:
146
+ `np.ndarray`: The resized images.
147
+ """
148
+ requires_backends(self, "torch")
149
+
150
+ size = get_size_dict(size)
151
+ if "height" not in size or "width" not in size:
152
+ raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
153
+ output_size = (size["height"], size["width"])
154
+
155
+ # we use torch interpolation instead of image.resize because DepthProImageProcessor
156
+ # rescales, then normalizes, which may cause some values to become negative, before resizing the image.
157
+ # image.resize expects all values to be in range [0, 1] or [0, 255] and throws an exception otherwise,
158
+ # however pytorch interpolation works with negative values.
159
+ # relevant issue here: https://github.com/huggingface/transformers/issues/34920
160
+ # input should be (B, C, H, W)
161
+ image_tensor = torch.from_numpy(image).unsqueeze(0)
162
+ resized_image = torch.nn.functional.interpolate(
163
+ input=image_tensor,
164
+ size=output_size,
165
+ mode=pil_torch_interpolation_mapping[resample].value,
166
+ )
167
+ resized_image = resized_image.squeeze(0).numpy()
168
+ return resized_image
169
+
170
+ def _validate_input_arguments(
171
+ self,
172
+ do_resize: bool,
173
+ size: dict[str, int],
174
+ resample: PILImageResampling,
175
+ do_rescale: bool,
176
+ rescale_factor: float,
177
+ do_normalize: bool,
178
+ image_mean: Union[float, list[float]],
179
+ image_std: Union[float, list[float]],
180
+ data_format: Union[str, ChannelDimension],
181
+ ):
182
+ if do_resize and None in (size, resample):
183
+ raise ValueError("Size and resample must be specified if do_resize is True.")
184
+
185
+ if do_rescale and rescale_factor is None:
186
+ raise ValueError("Rescale factor must be specified if do_rescale is True.")
187
+
188
+ if do_normalize and None in (image_mean, image_std):
189
+ raise ValueError("Image mean and standard deviation must be specified if do_normalize is True.")
190
+
191
+ @filter_out_non_signature_kwargs()
192
+ def preprocess(
193
+ self,
194
+ images: ImageInput,
195
+ do_resize: Optional[bool] = None,
196
+ size: Optional[dict[str, int]] = None,
197
+ resample: Optional[PILImageResampling] = None,
198
+ do_rescale: Optional[bool] = None,
199
+ rescale_factor: Optional[float] = None,
200
+ do_normalize: Optional[bool] = None,
201
+ image_mean: Optional[Union[float, list[float]]] = None,
202
+ image_std: Optional[Union[float, list[float]]] = None,
203
+ return_tensors: Optional[Union[str, TensorType]] = None,
204
+ data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
205
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
206
+ ):
207
+ """
208
+ Preprocess an image or batch of images.
209
+
210
+ Args:
211
+ images (`ImageInput`):
212
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
213
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
214
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
215
+ Whether to resize the image.
216
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
217
+ Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after
218
+ resizing.
219
+ resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`):
220
+ `PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BILINEAR`. Only has
221
+ an effect if `do_resize` is set to `True`.
222
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
223
+ Whether to rescale the image values between [0 - 1].
224
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
225
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
226
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
227
+ Whether to normalize the image.
228
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
229
+ Image mean to use if `do_normalize` is set to `True`.
230
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
231
+ Image standard deviation to use if `do_normalize` is set to `True`.
232
+ return_tensors (`str` or `TensorType`, *optional*):
233
+ The type of tensors to return. Can be one of:
234
+ - Unset: Return a list of `np.ndarray`.
235
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
236
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
237
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
238
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
239
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
240
+ The channel dimension format for the output image. Can be one of:
241
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
242
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
243
+ - Unset: Use the channel dimension format of the input image.
244
+ input_data_format (`ChannelDimension` or `str`, *optional*):
245
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
246
+ from the input image. Can be one of:
247
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
248
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
249
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
250
+ """
251
+ do_resize = do_resize if do_resize is not None else self.do_resize
252
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
253
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
254
+ resample = resample if resample is not None else self.resample
255
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
256
+ image_mean = image_mean if image_mean is not None else self.image_mean
257
+ image_std = image_std if image_std is not None else self.image_std
258
+
259
+ size = size if size is not None else self.size
260
+
261
+ images = make_list_of_images(images)
262
+
263
+ if not valid_images(images):
264
+ raise ValueError(
265
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
266
+ "torch.Tensor, tf.Tensor or jax.ndarray."
267
+ )
268
+ self._validate_input_arguments(
269
+ do_resize=do_resize,
270
+ size=size,
271
+ resample=resample,
272
+ do_rescale=do_rescale,
273
+ rescale_factor=rescale_factor,
274
+ do_normalize=do_normalize,
275
+ image_mean=image_mean,
276
+ image_std=image_std,
277
+ data_format=data_format,
278
+ )
279
+
280
+ # All transformations expect numpy arrays.
281
+ images = [to_numpy_array(image) for image in images]
282
+
283
+ if is_scaled_image(images[0]) and do_rescale:
284
+ logger.warning_once(
285
+ "It looks like you are trying to rescale already rescaled images. If the input"
286
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
287
+ )
288
+
289
+ if input_data_format is None:
290
+ # We assume that all images have the same channel dimension format.
291
+ input_data_format = infer_channel_dimension_format(images[0])
292
+
293
+ all_images = []
294
+ for image in images:
295
+ if do_rescale:
296
+ image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
297
+
298
+ if do_normalize:
299
+ image = self.normalize(
300
+ image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
301
+ )
302
+
303
+ # depth-pro rescales and normalizes the image before resizing it
304
+ # uses torch interpolation which requires ChannelDimension.FIRST
305
+ if do_resize:
306
+ image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_channel_dim=input_data_format)
307
+ image = self.resize(image=image, size=size, resample=resample)
308
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=ChannelDimension.FIRST)
309
+ else:
310
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
311
+
312
+ all_images.append(image)
313
+
314
+ data = {"pixel_values": all_images}
315
+ return BatchFeature(data=data, tensor_type=return_tensors)
316
+
317
+ def post_process_depth_estimation(
318
+ self,
319
+ outputs: "DepthProDepthEstimatorOutput",
320
+ target_sizes: Optional[Union[TensorType, list[tuple[int, int]], None]] = None,
321
+ ) -> list[dict[str, TensorType]]:
322
+ """
323
+ Post-processes the raw depth predictions from the model to generate
324
+ final depth predictions, which are calibrated using the field of view if provided
325
+ and resized to specified target sizes if provided.
326
+
327
+ Args:
328
+ outputs ([`DepthProDepthEstimatorOutput`]):
329
+ Raw outputs of the model.
330
+ target_sizes (`Optional[Union[TensorType, list[tuple[int, int]], None]]`, *optional*, defaults to `None`):
331
+ Target sizes to resize the depth predictions. Can be a tensor of shape `(batch_size, 2)`
332
+ or a list of tuples `(height, width)` for each image in the batch. If `None`, no resizing
333
+ is performed.
334
+
335
+ Returns:
336
+ `list[dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth
337
+ predictions, and field of view (degrees) and focal length (pixels) if `field_of_view` is given in `outputs`.
338
+
339
+ Raises:
340
+ `ValueError`:
341
+ If the lengths of `predicted_depths`, `fovs`, or `target_sizes` are mismatched.
342
+ """
343
+ requires_backends(self, "torch")
344
+
345
+ predicted_depth = outputs.predicted_depth
346
+ fov = outputs.field_of_view
347
+
348
+ batch_size = len(predicted_depth)
349
+
350
+ if target_sizes is not None and batch_size != len(target_sizes):
351
+ raise ValueError(
352
+ "Make sure that you pass in as many fov values as the batch dimension of the predicted depth"
353
+ )
354
+
355
+ results = []
356
+ fov = [None] * batch_size if fov is None else fov
357
+ target_sizes = [None] * batch_size if target_sizes is None else target_sizes
358
+ for depth, fov_value, target_size in zip(predicted_depth, fov, target_sizes):
359
+ focal_length = None
360
+ if target_size is not None:
361
+ # scale image w.r.t fov
362
+ if fov_value is not None:
363
+ width = target_size[1]
364
+ focal_length = 0.5 * width / torch.tan(0.5 * torch.deg2rad(fov_value))
365
+ depth = depth * width / focal_length
366
+
367
+ # interpolate
368
+ depth = torch.nn.functional.interpolate(
369
+ # input should be (B, C, H, W)
370
+ input=depth.unsqueeze(0).unsqueeze(1),
371
+ size=target_size,
372
+ mode=pil_torch_interpolation_mapping[self.resample].value,
373
+ ).squeeze()
374
+
375
+ # invert the depth
376
+ depth = 1.0 / torch.clamp(depth, min=1e-4, max=1e4)
377
+
378
+ results.append(
379
+ {
380
+ "predicted_depth": depth,
381
+ "field_of_view": fov_value,
382
+ "focal_length": focal_length,
383
+ }
384
+ )
385
+
386
+ return results
387
+
388
+
389
+ __all__ = ["DepthProImageProcessor"]
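A minimal usage sketch for the processor defined above, assuming its defaults mirror the fast variant added below (1536x1536 output, bilinear resampling, ImageNet mean/std) and using a hypothetical local `example.jpg`:

from PIL import Image
from transformers import DepthProImageProcessor

processor = DepthProImageProcessor()                  # library defaults; no checkpoint download needed
image = Image.open("example.jpg")                     # hypothetical input file
inputs = processor(images=image, return_tensors="pt")
print(inputs["pixel_values"].shape)                   # expected: torch.Size([1, 3, 1536, 1536])

Because resizing happens only after rescaling and normalization, the returned `pixel_values` are already channels-first and ready for the model.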
phivenv/Lib/site-packages/transformers/models/depth_pro/image_processing_depth_pro_fast.py ADDED
@@ -0,0 +1,177 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Fast Image processor class for DepthPro."""
16
+
17
+ from typing import TYPE_CHECKING, Optional, Union
18
+
19
+ from ...image_processing_base import BatchFeature
20
+ from ...image_processing_utils_fast import BaseImageProcessorFast, group_images_by_shape, reorder_images
21
+ from ...image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, PILImageResampling, SizeDict
22
+ from ...utils import (
23
+ TensorType,
24
+ auto_docstring,
25
+ is_torch_available,
26
+ is_torchvision_available,
27
+ is_torchvision_v2_available,
28
+ logging,
29
+ requires_backends,
30
+ )
31
+ from ...utils.import_utils import requires
32
+
33
+
34
+ if TYPE_CHECKING:
35
+ from .modeling_depth_pro import DepthProDepthEstimatorOutput
36
+
37
+ logger = logging.get_logger(__name__)
38
+
39
+
40
+ if is_torch_available():
41
+ import torch
42
+
43
+
44
+ if is_torchvision_available():
45
+ from ...image_utils import pil_torch_interpolation_mapping
46
+
47
+ if is_torchvision_v2_available():
48
+ from torchvision.transforms.v2 import functional as F
49
+ else:
50
+ from torchvision.transforms import functional as F
51
+
52
+
53
+ @auto_docstring
54
+ @requires(backends=("torchvision", "torch"))
55
+ class DepthProImageProcessorFast(BaseImageProcessorFast):
56
+ resample = PILImageResampling.BILINEAR
57
+ image_mean = IMAGENET_STANDARD_MEAN
58
+ image_std = IMAGENET_STANDARD_STD
59
+ size = {"height": 1536, "width": 1536}
60
+ do_resize = True
61
+ do_rescale = True
62
+ do_normalize = True
63
+
64
+ # DepthPro resizes image after rescaling and normalizing,
65
+ # which makes it different from BaseImageProcessorFast._preprocess
66
+ def _preprocess(
67
+ self,
68
+ images: list["torch.Tensor"],
69
+ do_resize: bool,
70
+ size: SizeDict,
71
+ interpolation: Optional["F.InterpolationMode"],
72
+ do_center_crop: bool,
73
+ crop_size: SizeDict,
74
+ do_rescale: bool,
75
+ rescale_factor: float,
76
+ do_normalize: bool,
77
+ image_mean: Optional[Union[float, list[float]]],
78
+ image_std: Optional[Union[float, list[float]]],
79
+ disable_grouping: Optional[bool],
80
+ return_tensors: Optional[Union[str, TensorType]],
81
+ ) -> BatchFeature:
82
+ # Group images by size for batched scaling
83
+ grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
84
+ processed_images_grouped = {}
85
+ for shape, stacked_images in grouped_images.items():
86
+ # Fused rescale and normalize
87
+ stacked_images = self.rescale_and_normalize(
88
+ stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
89
+ )
90
+ if do_resize:
91
+ stacked_images = self.resize(
92
+ image=stacked_images,
93
+ size=size,
94
+ interpolation=interpolation,
95
+ antialias=False,
96
+ )
97
+ processed_images_grouped[shape] = stacked_images
98
+
99
+ processed_images = reorder_images(processed_images_grouped, grouped_images_index)
100
+ processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
101
+
102
+ return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
103
+
104
+ # Copied from transformers.models.depth_pro.image_processing_depth_pro.DepthProImageProcessor.post_process_depth_estimation
105
+ def post_process_depth_estimation(
106
+ self,
107
+ outputs: "DepthProDepthEstimatorOutput",
108
+ target_sizes: Optional[Union[TensorType, list[tuple[int, int]], None]] = None,
109
+ ) -> list[dict[str, TensorType]]:
110
+ """
111
+ Post-processes the raw depth predictions from the model to generate
112
+ final depth predictions, which are calibrated using the field of view if provided
113
+ and resized to specified target sizes if provided.
114
+
115
+ Args:
116
+ outputs ([`DepthProDepthEstimatorOutput`]):
117
+ Raw outputs of the model.
118
+ target_sizes (`Optional[Union[TensorType, list[tuple[int, int]], None]]`, *optional*, defaults to `None`):
119
+ Target sizes to resize the depth predictions. Can be a tensor of shape `(batch_size, 2)`
120
+ or a list of tuples `(height, width)` for each image in the batch. If `None`, no resizing
121
+ is performed.
122
+
123
+ Returns:
124
+ `list[dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth
125
+ predictions, and field of view (degrees) and focal length (pixels) if `field_of_view` is given in `outputs`.
126
+
127
+ Raises:
128
+ `ValueError`:
129
+ If the lengths of `predicted_depths`, `fovs`, or `target_sizes` are mismatched.
130
+ """
131
+ requires_backends(self, "torch")
132
+
133
+ predicted_depth = outputs.predicted_depth
134
+ fov = outputs.field_of_view
135
+
136
+ batch_size = len(predicted_depth)
137
+
138
+ if target_sizes is not None and batch_size != len(target_sizes):
139
+ raise ValueError(
140
+ "Make sure that you pass in as many fov values as the batch dimension of the predicted depth"
141
+ )
142
+
143
+ results = []
144
+ fov = [None] * batch_size if fov is None else fov
145
+ target_sizes = [None] * batch_size if target_sizes is None else target_sizes
146
+ for depth, fov_value, target_size in zip(predicted_depth, fov, target_sizes):
147
+ focal_length = None
148
+ if target_size is not None:
149
+ # scale image w.r.t fov
150
+ if fov_value is not None:
151
+ width = target_size[1]
152
+ focal_length = 0.5 * width / torch.tan(0.5 * torch.deg2rad(fov_value))
153
+ depth = depth * width / focal_length
154
+
155
+ # interpolate
156
+ depth = torch.nn.functional.interpolate(
157
+ # input should be (B, C, H, W)
158
+ input=depth.unsqueeze(0).unsqueeze(1),
159
+ size=target_size,
160
+ mode=pil_torch_interpolation_mapping[self.resample].value,
161
+ ).squeeze()
162
+
163
+ # inverse the depth
164
+ depth = 1.0 / torch.clamp(depth, min=1e-4, max=1e4)
165
+
166
+ results.append(
167
+ {
168
+ "predicted_depth": depth,
169
+ "field_of_view": fov_value,
170
+ "focal_length": focal_length,
171
+ }
172
+ )
173
+
174
+ return results
175
+
176
+
177
+ __all__ = ["DepthProImageProcessorFast"]
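The fast variant can be exercised the same way; this is a sketch only, assuming the torch and torchvision backends required by the class are installed and that `example.jpg` is a placeholder path:

from PIL import Image
from transformers import DepthProImageProcessorFast

processor = DepthProImageProcessorFast()              # defaults declared above: 1536x1536, bilinear
image = Image.open("example.jpg")                     # hypothetical input file
inputs = processor(images=image, return_tensors="pt")
print(inputs["pixel_values"].shape, inputs["pixel_values"].dtype)
# expected: torch.Size([1, 3, 1536, 1536]) torch.float32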
phivenv/Lib/site-packages/transformers/models/depth_pro/modeling_depth_pro.py ADDED
@@ -0,0 +1,1132 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Apple Research Team Authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch DepthPro model."""
16
+
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Optional, Union
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from torch import nn
24
+
25
+ from ...modeling_utils import PreTrainedModel
26
+ from ...utils import ModelOutput, auto_docstring, logging, torch_int
27
+ from ..auto import AutoModel
28
+ from .configuration_depth_pro import DepthProConfig
29
+
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+
34
+ @dataclass
35
+ @auto_docstring(
36
+ custom_intro="""
37
+ Base class for DepthPro's outputs.
38
+ """
39
+ )
40
+ class DepthProOutput(ModelOutput):
41
+ r"""
42
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, n_patches_per_batch, sequence_length, hidden_size)`):
43
+ Sequence of hidden-states at the output of the last layer of the model.
44
+ features (`Union[torch.FloatTensor, List[torch.FloatTensor]]`, *optional*):
45
+ Features from encoders. Can be a single feature or a list of features.
46
+ """
47
+
48
+ last_hidden_state: Optional[torch.FloatTensor] = None
49
+ features: Union[torch.FloatTensor, list[torch.FloatTensor]] = None
50
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
51
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
52
+
53
+
54
+ @dataclass
55
+ @auto_docstring(
56
+ custom_intro="""
57
+ Base class for DepthProForDepthEstimation's output.
58
+ """
59
+ )
60
+ class DepthProDepthEstimatorOutput(ModelOutput):
61
+ r"""
62
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
63
+ Classification (or regression if config.num_labels==1) loss.
64
+ field_of_view (`torch.FloatTensor` of shape `(batch_size,)`, *optional*, returned when `use_fov_model` is provided):
65
+ Predicted field of view, in degrees.
66
+ """
67
+
68
+ loss: Optional[torch.FloatTensor] = None
69
+ predicted_depth: Optional[torch.FloatTensor] = None
70
+ field_of_view: Optional[torch.FloatTensor] = None
71
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
72
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
73
+
74
+
75
+ def split_to_patches(pixel_values: torch.Tensor, patch_size: int, overlap_ratio: float) -> torch.Tensor:
76
+ """Creates Patches from Batch."""
77
+ batch_size, num_channels, height, width = pixel_values.shape
78
+
79
+ if height == width == patch_size:
80
+ # create patches only if scaled image is not already equal to patch size
81
+ return pixel_values
82
+
83
+ stride = torch_int(patch_size * (1 - overlap_ratio))
84
+
85
+ patches = F.unfold(pixel_values, kernel_size=(patch_size, patch_size), stride=(stride, stride))
86
+ patches = patches.permute(2, 0, 1)
87
+ patches = patches.reshape(-1, num_channels, patch_size, patch_size)
88
+
89
+ return patches
90
+
91
+
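As a sanity check on the stride arithmetic in `split_to_patches` above, a small sketch on dummy data (shapes are illustrative, not taken from any checkpoint): with `patch_size=384` and `overlap_ratio=0.25` the stride is 288, so a 960x960 input yields a 3x3 grid of overlapping patches.

import torch
import torch.nn.functional as F

pixel_values = torch.randn(1, 3, 960, 960)            # dummy batch with a single image
patch_size, overlap_ratio = 384, 0.25
stride = int(patch_size * (1 - overlap_ratio))        # 288

patches = F.unfold(pixel_values, kernel_size=(patch_size, patch_size), stride=(stride, stride))
patches = patches.permute(2, 0, 1).reshape(-1, 3, patch_size, patch_size)
print(patches.shape)                                  # torch.Size([9, 3, 384, 384])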
92
+ def reshape_features(hidden_states: torch.Tensor) -> torch.Tensor:
93
+ """Discard class token and reshape 1D feature map to a 2D grid."""
94
+ n_samples, seq_len, hidden_size = hidden_states.shape
95
+ size = torch_int(seq_len**0.5)
96
+
97
+ hidden_states = hidden_states[:, -(size**2) :, :] # remove special tokens if there are any
98
+ hidden_states = hidden_states.reshape(n_samples, size, size, hidden_size)
99
+ hidden_states = hidden_states.permute(0, 3, 1, 2)
100
+
101
+ return hidden_states
102
+
103
+
104
+ def merge_patches(patches: torch.Tensor, batch_size: int, padding: int) -> torch.Tensor:
105
+ """Merges smaller patches into image-like feature map."""
106
+ n_patches, hidden_size, out_size, out_size = patches.shape
107
+ n_patches_per_batch = n_patches // batch_size
108
+ sqrt_n_patches_per_batch = torch_int(n_patches_per_batch**0.5)
109
+ new_out_size = sqrt_n_patches_per_batch * out_size
110
+
111
+ if n_patches == batch_size:
112
+ # merge only if the patches were created from scaled image
113
+ # patches are not created when scaled image size is equal to patch size
114
+ return patches
115
+
116
+ if n_patches_per_batch < 4:
117
+ # for each batch, at least 4 small patches are required to
118
+ # recreate a large square patch from merging them and later padding is applied
119
+ # 3 x (8x8) patches becomes 1 x ( 8x8 ) patch (extra patch ignored, no padding)
120
+ # 4 x (8x8) patches becomes 1 x (16x16) patch (padding later)
121
+ # 5 x (8x8) patches becomes 1 x (16x16) patch (extra patch ignored, padding later)
122
+ # 9 x (8x8) patches becomes 1 x (24x24) patch (padding later)
123
+ # thus the following code only rearranges the patches and removes extra ones
124
+ padding = 0
125
+
126
+ # make sure padding is not large enough to remove more than half of the patch
127
+ padding = min(out_size // 4, padding)
128
+
129
+ if padding == 0:
130
+ # faster when no padding is required
131
+ merged = patches.reshape(n_patches_per_batch, batch_size, hidden_size, out_size, out_size)
132
+ merged = merged.permute(1, 2, 0, 3, 4)
133
+ merged = merged[:, :, : sqrt_n_patches_per_batch**2, :, :]
134
+ merged = merged.reshape(
135
+ batch_size, hidden_size, sqrt_n_patches_per_batch, sqrt_n_patches_per_batch, out_size, out_size
136
+ )
137
+ merged = merged.permute(0, 1, 2, 4, 3, 5)
138
+ merged = merged.reshape(batch_size, hidden_size, new_out_size, new_out_size)
139
+ else:
140
+ # padding example:
141
+ # let out_size = 8, new_out_size = 32, padding = 2
142
+ # each patch is separated by "|"
143
+ # and padding is applied to the merging edges of each patch
144
+ # 00 01 02 03 04 05 06 07 | 08 09 10 11 12 13 14 15 | 16 17 18 19 20 21 22 23 | 24 25 26 27 28 29 30 31
145
+ # 00 01 02 03 04 05 -- -- | -- -- 10 11 12 13 -- -- | -- -- 18 19 20 21 -- -- | -- -- 26 27 28 29 30 31
146
+ i = 0
147
+ boxes = []
148
+ for h in range(sqrt_n_patches_per_batch):
149
+ boxes_in_row = []
150
+ for w in range(sqrt_n_patches_per_batch):
151
+ box = patches[batch_size * i : batch_size * (i + 1)]
152
+
153
+ # collect paddings
154
+ paddings = [0, 0, 0, 0]
155
+ if h != 0:
156
+ # remove pad from height if box is not at top border
157
+ paddings[0] = padding
158
+ if w != 0:
159
+ # remove pad from width if box is not at left border
160
+ paddings[2] = padding
161
+ if h != sqrt_n_patches_per_batch - 1:
162
+ # remove pad from height if box is not at bottom border
163
+ paddings[1] = padding
164
+ if w != sqrt_n_patches_per_batch - 1:
165
+ # remove pad from width if box is not at right border
166
+ paddings[3] = padding
167
+
168
+ # remove paddings
169
+ _, _, box_h, box_w = box.shape
170
+ pad_top, pad_bottom, pad_left, pad_right = paddings
171
+ box = box[:, :, pad_top : box_h - pad_bottom, pad_left : box_w - pad_right]
172
+
173
+ boxes_in_row.append(box)
174
+ i += 1
175
+ boxes_in_row = torch.cat(boxes_in_row, dim=-1)
176
+ boxes.append(boxes_in_row)
177
+ merged = torch.cat(boxes, dim=-2)
178
+
179
+ return merged
180
+
181
+
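A quick shape check for the padded branch of `merge_patches` above, assuming the function can be imported from an installed build of this module (the values are illustrative only): four 24x24 feature patches per image merged with `padding=2` lose 2 pixels on each interior edge, giving a 44x44 map instead of 48x48.

import torch
from transformers.models.depth_pro.modeling_depth_pro import merge_patches

patches = torch.randn(4, 8, 24, 24)                   # 4 patches per image, hidden_size=8
merged = merge_patches(patches, batch_size=1, padding=2)
print(merged.shape)                                   # torch.Size([1, 8, 44, 44])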
182
+ def reconstruct_feature_maps(
183
+ hidden_state: torch.Tensor, batch_size: int, padding: int, output_size: tuple[float, float]
184
+ ) -> torch.Tensor:
185
+ """
186
+ Reconstructs feature maps from the hidden state produced by any of the encoder. Converts the hidden state of shape
187
+ `(n_patches_per_batch * batch_size, seq_len, hidden_size)` to feature maps of shape
188
+ `(batch_size, hidden_size, output_size[0], output_size[1])`.
189
+
190
+ Args:
191
+ hidden_state (torch.Tensor): Input tensor of shape `(n_patches_per_batch * batch_size, seq_len, hidden_size)`
192
+ representing the encoded patches.
193
+ batch_size (int): The number of samples in a batch.
194
+ padding (int): The amount of padding to be removed when merging patches.
195
+ output_size (tuple[float, float]): The desired output size for the feature maps, specified as `(height, width)`.
196
+
197
+ Returns:
198
+ torch.Tensor: Reconstructed feature maps of shape `(batch_size, hidden_size, output_size[0], output_size[1])`.
199
+ """
200
+ # reshape back to image like
201
+ features = reshape_features(hidden_state)
202
+
203
+ # merge all patches in a batch to create one large patch per batch
204
+ features = merge_patches(
205
+ features,
206
+ batch_size=batch_size,
207
+ padding=padding,
208
+ )
209
+
210
+ # interpolate patches to base size
211
+ features = F.interpolate(
212
+ features,
213
+ size=output_size,
214
+ mode="bilinear",
215
+ align_corners=False,
216
+ )
217
+
218
+ return features
219
+
220
+
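Putting the three helpers together, a sketch on dummy encoder output (shapes chosen only for illustration, and again assuming the module is importable): four patches per image, each a 577-token ViT sequence (a 24x24 grid plus a class token) with hidden size 8, reconstructed into a single 64x64 feature map.

import torch
from transformers.models.depth_pro.modeling_depth_pro import reconstruct_feature_maps

hidden_state = torch.randn(4, 577, 8)                 # (n_patches * batch_size, seq_len, hidden_size)
features = reconstruct_feature_maps(hidden_state, batch_size=1, padding=0, output_size=(64, 64))
print(features.shape)                                 # torch.Size([1, 8, 64, 64])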
221
+ class DepthProPatchEncoder(nn.Module):
222
+ def __init__(self, config: DepthProConfig):
223
+ super().__init__()
224
+ self.config = config
225
+
226
+ self.intermediate_hook_ids = config.intermediate_hook_ids
227
+ self.intermediate_feature_dims = config.intermediate_feature_dims
228
+ self.scaled_images_ratios = config.scaled_images_ratios
229
+ self.scaled_images_overlap_ratios = config.scaled_images_overlap_ratios
230
+ self.scaled_images_feature_dims = config.scaled_images_feature_dims
231
+ self.merge_padding_value = config.merge_padding_value
232
+
233
+ self.n_scaled_images = len(config.scaled_images_ratios)
234
+ self.n_intermediate_hooks = len(config.intermediate_hook_ids)
235
+ self.out_size = config.image_model_config.image_size // config.image_model_config.patch_size
236
+
237
+ self.model = AutoModel.from_config(config.patch_model_config)
238
+
239
+ def forward(
240
+ self,
241
+ pixel_values: torch.Tensor,
242
+ head_mask: Optional[torch.Tensor] = None,
243
+ ) -> list[torch.Tensor]:
244
+ batch_size, num_channels, height, width = pixel_values.shape
245
+
246
+ if min(self.scaled_images_ratios) * min(height, width) < self.config.patch_size:
247
+ raise ValueError(
248
+ f"Image size {height}x{width} is too small to be scaled "
249
+ f"with scaled_images_ratios={self.scaled_images_ratios} "
250
+ f"when patch_size={self.config.patch_size}."
251
+ )
252
+
253
+ # STEP 1: create 3-level image
254
+
255
+ scaled_images = []
256
+ for ratio in self.scaled_images_ratios:
257
+ scaled_images.append(
258
+ F.interpolate(
259
+ pixel_values,
260
+ scale_factor=ratio,
261
+ mode="bilinear",
262
+ align_corners=False,
263
+ )
264
+ )
265
+
266
+ # STEP 2: create patches
267
+
268
+ for i in range(self.n_scaled_images):
269
+ scaled_images[i] = split_to_patches(
270
+ scaled_images[i],
271
+ patch_size=self.config.patch_size,
272
+ overlap_ratio=self.scaled_images_overlap_ratios[i],
273
+ )
274
+ n_patches_per_scaled_image = [len(i) for i in scaled_images]
275
+ patches = torch.cat(scaled_images[::-1], dim=0) # -1 as patch encoder expects high res patches first
276
+
277
+ # STEP 3: apply patch encoder
278
+
279
+ encodings = self.model(
280
+ # each patch is processed as a separate batch
281
+ patches,
282
+ head_mask=head_mask,
283
+ # required for intermediate features
284
+ output_hidden_states=self.n_intermediate_hooks > 0,
285
+ )
286
+
287
+ scaled_images_last_hidden_state = torch.split_with_sizes(encodings[0], n_patches_per_scaled_image[::-1])
288
+ # -1 (reverse list) as patch encoder returns high res patches first, we need low res first
289
+ scaled_images_last_hidden_state = scaled_images_last_hidden_state[::-1]
290
+
291
+ # calculate base height and width
292
+ # base height and width are the dimensions of the lowest resolution features
293
+ exponent_value = torch_int(math.log2(width / self.out_size))
294
+ base_height = height // 2**exponent_value
295
+ base_width = width // 2**exponent_value
296
+
297
+ # STEP 4: get patch features (high_res, med_res, low_res) - (3-5) in diagram
298
+
299
+ scaled_images_features = []
300
+ for i in range(self.n_scaled_images):
301
+ hidden_state = scaled_images_last_hidden_state[i]
302
+ batch_size = batch_size
303
+ padding = torch_int(self.merge_padding_value * (1 / self.scaled_images_ratios[i]))
304
+ output_height = base_height * 2**i
305
+ output_width = base_width * 2**i
306
+ features = reconstruct_feature_maps(
307
+ hidden_state,
308
+ batch_size=batch_size,
309
+ padding=padding,
310
+ output_size=(output_height, output_width),
311
+ )
312
+ scaled_images_features.append(features)
313
+
314
+ # STEP 5: get intermediate features - (1-2) in diagram
315
+
316
+ intermediate_features = []
317
+ for i in range(self.n_intermediate_hooks):
318
+ # +1 to correct index position as hidden_states contain embedding output as well
319
+ hidden_state = encodings[2][self.intermediate_hook_ids[i] + 1]
320
+ padding = torch_int(self.merge_padding_value * (1 / self.scaled_images_ratios[-1]))
321
+ output_height = base_height * 2 ** (self.n_scaled_images - 1)
322
+ output_width = base_width * 2 ** (self.n_scaled_images - 1)
323
+ features = reconstruct_feature_maps(
324
+ hidden_state,
325
+ batch_size=batch_size,
326
+ padding=padding,
327
+ output_size=(output_height, output_width),
328
+ )
329
+ intermediate_features.append(features)
330
+
331
+ # STEP 7: combine all features
332
+ features = [*scaled_images_features, *intermediate_features]
333
+
334
+ return features
335
+
336
+
337
+ class DepthProImageEncoder(nn.Module):
338
+ def __init__(self, config: DepthProConfig):
339
+ super().__init__()
340
+ self.config = config
341
+ self.out_size = config.image_model_config.image_size // config.image_model_config.patch_size
342
+
343
+ self.model = AutoModel.from_config(config.image_model_config)
344
+
345
+ def forward(
346
+ self,
347
+ pixel_values: torch.Tensor,
348
+ head_mask: Optional[torch.Tensor] = None,
349
+ output_attentions: bool = False,
350
+ output_hidden_states: bool = False,
351
+ return_dict: bool = True,
352
+ ) -> Union[tuple, DepthProOutput]:
353
+ batch_size, num_channels, height, width = pixel_values.shape
354
+
355
+ # scale the image for image_encoder
356
+ size = self.config.image_model_config.image_size
357
+ pixel_values = F.interpolate(
358
+ pixel_values,
359
+ size=(size, size),
360
+ mode="bilinear",
361
+ align_corners=False,
362
+ )
363
+ encodings = self.model(
364
+ pixel_values=pixel_values,
365
+ head_mask=head_mask,
366
+ output_attentions=output_attentions,
367
+ output_hidden_states=output_hidden_states,
368
+ )
369
+
370
+ # calculate base height and width
371
+ # base height and width are the dimensions of the lowest resolution features
372
+ exponent_value = torch_int(math.log2(width / self.out_size))
373
+ base_height = height // 2**exponent_value
374
+ base_width = width // 2**exponent_value
375
+
376
+ features = reconstruct_feature_maps(
377
+ encodings[0],
378
+ batch_size=batch_size,
379
+ padding=0,
380
+ output_size=(base_height, base_width),
381
+ )
382
+
383
+ if not return_dict:
384
+ return (encodings[0], features) + encodings[2:]  # drop the pooler output (encodings[1])
385
+
386
+ return DepthProOutput(
387
+ last_hidden_state=encodings.last_hidden_state,
388
+ features=features,
389
+ hidden_states=encodings.hidden_states,
390
+ attentions=encodings.attentions,
391
+ )
392
+
393
+
394
+ class DepthProEncoder(nn.Module):
395
+ def __init__(self, config: DepthProConfig):
396
+ super().__init__()
397
+ self.config = config
398
+ self.intermediate_hook_ids = config.intermediate_hook_ids
399
+ self.intermediate_feature_dims = config.intermediate_feature_dims
400
+ self.scaled_images_ratios = config.scaled_images_ratios
401
+ self.scaled_images_overlap_ratios = config.scaled_images_overlap_ratios
402
+ self.scaled_images_feature_dims = config.scaled_images_feature_dims
403
+ self.merge_padding_value = config.merge_padding_value
404
+
405
+ self.n_scaled_images = len(self.scaled_images_ratios)
406
+ self.n_intermediate_hooks = len(self.intermediate_hook_ids)
407
+
408
+ self.patch_encoder = DepthProPatchEncoder(config)
409
+ self.image_encoder = DepthProImageEncoder(config)
410
+
411
+ def forward(
412
+ self,
413
+ pixel_values: torch.Tensor,
414
+ head_mask: Optional[torch.Tensor] = None,
415
+ output_attentions: bool = False,
416
+ output_hidden_states: bool = False,
417
+ return_dict: bool = True,
418
+ ) -> Union[tuple, DepthProOutput]:
419
+ batch_size, num_channels, height, width = pixel_values.shape
420
+
421
+ patch_features = self.patch_encoder(
422
+ pixel_values,
423
+ head_mask=head_mask,
424
+ )
425
+ image_encodings = self.image_encoder(
426
+ pixel_values,
427
+ head_mask=head_mask,
428
+ output_attentions=output_attentions,
429
+ output_hidden_states=output_hidden_states,
430
+ return_dict=return_dict,
431
+ )
432
+ image_features = image_encodings[1] # index 1 contains features
433
+
434
+ features = [image_features, *patch_features]
435
+
436
+ if not return_dict:
437
+ return (image_encodings[0], features) + image_encodings[2:]
438
+
439
+ return DepthProOutput(
440
+ last_hidden_state=image_encodings.last_hidden_state,
441
+ features=features,
442
+ hidden_states=image_encodings.hidden_states,
443
+ attentions=image_encodings.attentions,
444
+ )
445
+
446
+
447
+ class DepthProFeatureUpsampleBlock(nn.Module):
448
+ def __init__(
449
+ self,
450
+ config: DepthProConfig,
451
+ input_dims: int,
452
+ intermediate_dims: int,
453
+ output_dims: int,
454
+ n_upsample_layers: int,
455
+ use_proj: bool = True,
456
+ bias: bool = False,
457
+ ):
458
+ super().__init__()
459
+ self.config = config
460
+ self.layers = nn.ModuleList()
461
+
462
+ # create first projection layer
463
+ if use_proj:
464
+ proj = nn.Conv2d(
465
+ in_channels=input_dims,
466
+ out_channels=intermediate_dims,
467
+ kernel_size=1,
468
+ stride=1,
469
+ padding=0,
470
+ bias=bias,
471
+ )
472
+ self.layers.append(proj)
473
+
474
+ # create following upsample layers
475
+ for i in range(n_upsample_layers):
476
+ in_channels = intermediate_dims if i == 0 else output_dims
477
+ layer = nn.ConvTranspose2d(
478
+ in_channels=in_channels,
479
+ out_channels=output_dims,
480
+ kernel_size=2,
481
+ stride=2,
482
+ padding=0,
483
+ bias=bias,
484
+ )
485
+ self.layers.append(layer)
486
+
487
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
488
+ for layer in self.layers:
489
+ features = layer(features)
490
+ return features
491
+
492
+
493
+ class DepthProFeatureUpsample(nn.Module):
494
+ def __init__(self, config: DepthProConfig):
495
+ super().__init__()
496
+ self.config = config
497
+ self.n_scaled_images = len(self.config.scaled_images_ratios)
498
+ self.n_intermediate_hooks = len(self.config.intermediate_hook_ids)
499
+
500
+ # for image_features
501
+ self.image_block = DepthProFeatureUpsampleBlock(
502
+ config=config,
503
+ input_dims=config.image_model_config.hidden_size,
504
+ intermediate_dims=config.image_model_config.hidden_size,
505
+ output_dims=config.scaled_images_feature_dims[0],
506
+ n_upsample_layers=1,
507
+ use_proj=False,
508
+ bias=True,
509
+ )
510
+
511
+ # for scaled_images_features
512
+ self.scaled_images = nn.ModuleList()
513
+ for i, feature_dims in enumerate(config.scaled_images_feature_dims):
514
+ block = DepthProFeatureUpsampleBlock(
515
+ config=config,
516
+ input_dims=config.patch_model_config.hidden_size,
517
+ intermediate_dims=feature_dims,
518
+ output_dims=feature_dims,
519
+ n_upsample_layers=1,
520
+ )
521
+ self.scaled_images.append(block)
522
+
523
+ # for intermediate_features
524
+ self.intermediate = nn.ModuleList()
525
+ for i, feature_dims in enumerate(config.intermediate_feature_dims):
526
+ intermediate_dims = config.fusion_hidden_size if i == 0 else feature_dims
527
+ block = DepthProFeatureUpsampleBlock(
528
+ config=config,
529
+ input_dims=config.patch_model_config.hidden_size,
530
+ intermediate_dims=intermediate_dims,
531
+ output_dims=feature_dims,
532
+ n_upsample_layers=2 + i,
533
+ )
534
+ self.intermediate.append(block)
535
+
536
+ def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
537
+ features[0] = self.image_block(features[0])
538
+
539
+ for i in range(self.n_scaled_images):
540
+ features[i + 1] = self.scaled_images[i](features[i + 1])
541
+
542
+ for i in range(self.n_intermediate_hooks):
543
+ features[self.n_scaled_images + i + 1] = self.intermediate[i](features[self.n_scaled_images + i + 1])
544
+
545
+ return features
546
+
547
+
548
+ class DepthProFeatureProjection(nn.Module):
549
+ def __init__(self, config: DepthProConfig):
550
+ super().__init__()
551
+ self.config = config
552
+
553
+ combined_feature_dims = config.scaled_images_feature_dims + config.intermediate_feature_dims
554
+ self.projections = nn.ModuleList()
555
+ for i, in_channels in enumerate(combined_feature_dims):
556
+ if i == len(combined_feature_dims) - 1 and in_channels == config.fusion_hidden_size:
557
+ # projection for last layer can be ignored if input and output channels already match
558
+ self.projections.append(nn.Identity())
559
+ else:
560
+ self.projections.append(
561
+ nn.Conv2d(
562
+ in_channels=in_channels,
563
+ out_channels=config.fusion_hidden_size,
564
+ kernel_size=3,
565
+ stride=1,
566
+ padding=1,
567
+ bias=False,
568
+ )
569
+ )
570
+
571
+ def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
572
+ projected_features = []
573
+ for i, projection in enumerate(self.projections):
574
+ upsampled_feature = projection(features[i])
575
+ projected_features.append(upsampled_feature)
576
+ return projected_features
577
+
578
+
579
+ class DepthProNeck(nn.Module):
580
+ def __init__(self, config: DepthProConfig):
581
+ super().__init__()
582
+ self.config = config
583
+
584
+ self.feature_upsample = DepthProFeatureUpsample(config)
585
+ self.fuse_image_with_low_res = nn.Conv2d(
586
+ in_channels=config.scaled_images_feature_dims[0] * 2,
587
+ out_channels=config.scaled_images_feature_dims[0],
588
+ kernel_size=1,
589
+ stride=1,
590
+ padding=0,
591
+ bias=True,
592
+ )
593
+ self.feature_projection = DepthProFeatureProjection(config)
594
+
595
+ def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
596
+ features = self.feature_upsample(features)
597
+ # global features = low res features + image features
598
+ global_features = torch.cat((features[1], features[0]), dim=1)
599
+ global_features = self.fuse_image_with_low_res(global_features)
600
+ features = [global_features, *features[2:]]
601
+ features = self.feature_projection(features)
602
+ return features
603
+
604
+
605
+ # General docstring
606
+
607
+
608
+ @auto_docstring
609
+ class DepthProPreTrainedModel(PreTrainedModel):
610
+ config: DepthProConfig
611
+ base_model_prefix = "depth_pro"
612
+ main_input_name = "pixel_values"
613
+ supports_gradient_checkpointing = True
614
+ _supports_sdpa = True
615
+ _no_split_modules = ["DepthProPreActResidualLayer"]
616
+ _keys_to_ignore_on_load_unexpected = ["fov_model.*"]
617
+
618
+ def _init_weights(self, module):
619
+ """Initialize the weights"""
620
+ if isinstance(module, nn.Linear):
621
+ # Slightly different from the TF version which uses truncated_normal for initialization
622
+ # cf https://github.com/pytorch/pytorch/pull/5617
623
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
624
+ if module.bias is not None:
625
+ module.bias.data.zero_()
626
+ elif isinstance(module, nn.LayerNorm):
627
+ module.bias.data.zero_()
628
+ module.weight.data.fill_(1.0)
629
+ elif isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
630
+ nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
631
+ if module.bias is not None:
632
+ module.bias.data.zero_()
633
+
634
+
635
+ @auto_docstring
636
+ class DepthProModel(DepthProPreTrainedModel):
637
+ def __init__(self, config):
638
+ super().__init__(config)
639
+ self.config = config
640
+ self.encoder = DepthProEncoder(config)
641
+ self.neck = DepthProNeck(config)
642
+ # Initialize weights and apply final processing
643
+ self.post_init()
644
+
645
+ def get_input_embeddings(self):
646
+ return self.encoder.image_encoder.model.get_input_embeddings()
647
+
648
+ @auto_docstring
649
+ def forward(
650
+ self,
651
+ pixel_values: torch.FloatTensor,
652
+ head_mask: Optional[torch.FloatTensor] = None,
653
+ output_attentions: Optional[bool] = None,
654
+ output_hidden_states: Optional[bool] = None,
655
+ return_dict: Optional[bool] = None,
656
+ ) -> Union[tuple, DepthProOutput]:
657
+ r"""
658
+ Examples:
659
+
660
+ ```python
661
+ >>> import torch
662
+ >>> from PIL import Image
663
+ >>> import requests
664
+ >>> from transformers import AutoProcessor, DepthProModel
665
+
666
+ >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
667
+ >>> image = Image.open(requests.get(url, stream=True).raw)
668
+
669
+ >>> checkpoint = "apple/DepthPro-hf"
670
+ >>> processor = AutoProcessor.from_pretrained(checkpoint)
671
+ >>> model = DepthProModel.from_pretrained(checkpoint)
672
+
673
+ >>> # prepare image for the model
674
+ >>> inputs = processor(images=image, return_tensors="pt")
675
+
676
+ >>> with torch.no_grad():
677
+ ... output = model(**inputs)
678
+
679
+ >>> output.last_hidden_state.shape
680
+ torch.Size([1, 35, 577, 1024])
681
+ ```"""
682
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
683
+ output_hidden_states = (
684
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
685
+ )
686
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
687
+
688
+ encodings = self.encoder(
689
+ pixel_values,
690
+ head_mask=head_mask,
691
+ output_attentions=output_attentions,
692
+ output_hidden_states=output_hidden_states,
693
+ return_dict=return_dict,
694
+ )
695
+ features = encodings[1] # index 1 contains features
696
+ features = self.neck(features)
697
+
698
+ if not return_dict:
699
+ return (encodings[0], features) + encodings[2:]
700
+
701
+ return DepthProOutput(
702
+ last_hidden_state=encodings.last_hidden_state,
703
+ features=features,
704
+ hidden_states=encodings.hidden_states,
705
+ attentions=encodings.attentions,
706
+ )
707
+
708
+
709
+ # Copied from transformers.models.dpt.modeling_dpt.DPTPreActResidualLayer DPT->DepthPro
710
+ class DepthProPreActResidualLayer(nn.Module):
711
+ """
712
+ ResidualConvUnit, pre-activate residual unit.
713
+
714
+ Args:
715
+ config (`[DepthProConfig]`):
716
+ Model configuration class defining the model architecture.
717
+ """
718
+
719
+ def __init__(self, config: DepthProConfig):
720
+ super().__init__()
721
+
722
+ self.use_batch_norm = config.use_batch_norm_in_fusion_residual
723
+ use_bias_in_fusion_residual = (
724
+ config.use_bias_in_fusion_residual
725
+ if config.use_bias_in_fusion_residual is not None
726
+ else not self.use_batch_norm
727
+ )
728
+
729
+ self.activation1 = nn.ReLU()
730
+ self.convolution1 = nn.Conv2d(
731
+ config.fusion_hidden_size,
732
+ config.fusion_hidden_size,
733
+ kernel_size=3,
734
+ stride=1,
735
+ padding=1,
736
+ bias=use_bias_in_fusion_residual,
737
+ )
738
+
739
+ self.activation2 = nn.ReLU()
740
+ self.convolution2 = nn.Conv2d(
741
+ config.fusion_hidden_size,
742
+ config.fusion_hidden_size,
743
+ kernel_size=3,
744
+ stride=1,
745
+ padding=1,
746
+ bias=use_bias_in_fusion_residual,
747
+ )
748
+
749
+ if self.use_batch_norm:
750
+ self.batch_norm1 = nn.BatchNorm2d(config.fusion_hidden_size)
751
+ self.batch_norm2 = nn.BatchNorm2d(config.fusion_hidden_size)
752
+
753
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
754
+ residual = hidden_state
755
+ hidden_state = self.activation1(hidden_state)
756
+
757
+ hidden_state = self.convolution1(hidden_state)
758
+
759
+ if self.use_batch_norm:
760
+ hidden_state = self.batch_norm1(hidden_state)
761
+
762
+ hidden_state = self.activation2(hidden_state)
763
+ hidden_state = self.convolution2(hidden_state)
764
+
765
+ if self.use_batch_norm:
766
+ hidden_state = self.batch_norm2(hidden_state)
767
+
768
+ return hidden_state + residual
769
+
770
+
771
+ # Modified from transformers.models.dpt.modeling_dpt.DPTFeatureFusionLayer
772
+ # except it uses deconv and skip_add and needs no interpolation
773
+ class DepthProFeatureFusionLayer(nn.Module):
774
+ def __init__(self, config: DepthProConfig, use_deconv: bool = True):
775
+ super().__init__()
776
+ self.config = config
777
+ self.use_deconv = use_deconv
778
+
779
+ self.residual_layer1 = DepthProPreActResidualLayer(config)
780
+ self.residual_layer2 = DepthProPreActResidualLayer(config)
781
+
782
+ if self.use_deconv:
783
+ self.deconv = nn.ConvTranspose2d(
784
+ in_channels=config.fusion_hidden_size,
785
+ out_channels=config.fusion_hidden_size,
786
+ kernel_size=2,
787
+ stride=2,
788
+ padding=0,
789
+ bias=False,
790
+ )
791
+
792
+ self.projection = nn.Conv2d(config.fusion_hidden_size, config.fusion_hidden_size, kernel_size=1, bias=True)
793
+
794
+ def forward(self, hidden_state: torch.Tensor, residual: Optional[torch.Tensor] = None) -> torch.Tensor:
795
+ if residual is not None:
796
+ residual = self.residual_layer1(residual)
797
+ hidden_state = hidden_state + residual
798
+
799
+ hidden_state = self.residual_layer2(hidden_state)
800
+ if self.use_deconv:
801
+ hidden_state = self.deconv(hidden_state)
802
+ hidden_state = self.projection(hidden_state)
803
+
804
+ return hidden_state
805
+
806
+
807
+ # Modified from transformers.models.dpt.modeling_dpt.DPTFeatureFusionStage with DPT->DepthPro
808
+ # with deconv and reversed layers
809
+ class DepthProFeatureFusionStage(nn.Module):
810
+ def __init__(self, config):
811
+ super().__init__()
812
+ self.config = config
813
+
814
+ self.num_layers = len(config.intermediate_hook_ids) + len(config.scaled_images_ratios)
815
+ self.intermediate = nn.ModuleList()
816
+ for _ in range(self.num_layers - 1):
817
+ self.intermediate.append(DepthProFeatureFusionLayer(config))
818
+
819
+ # final layer does not require deconvolution
820
+ self.final = DepthProFeatureFusionLayer(config, use_deconv=False)
821
+
822
+ def forward(self, hidden_states: list[torch.Tensor]) -> list[torch.Tensor]:
823
+ if self.num_layers != len(hidden_states):
824
+ raise ValueError(
825
+ f"num_layers={self.num_layers} in DepthProFeatureFusionStage"
826
+ f"does not match len(hidden_states)={len(hidden_states)}"
827
+ )
828
+
829
+ fused_hidden_states = []
830
+ fused_hidden_state = None
831
+ for hidden_state, layer in zip(hidden_states[:-1], self.intermediate):
832
+ if fused_hidden_state is None:
833
+ # first layer only uses the last hidden_state
834
+ fused_hidden_state = layer(hidden_state)
835
+ else:
836
+ fused_hidden_state = layer(fused_hidden_state, hidden_state)
837
+ fused_hidden_states.append(fused_hidden_state)
838
+
839
+ hidden_state = hidden_states[-1]
840
+ fused_hidden_state = self.final(fused_hidden_state, hidden_state)
841
+ fused_hidden_states.append(fused_hidden_state)
842
+
843
+ return fused_hidden_states
844
+
845
+
846
+ class DepthProFovEncoder(nn.Module):
847
+ def __init__(self, config: DepthProConfig):
848
+ super().__init__()
849
+ self.config = config
850
+ self.out_size = config.image_model_config.image_size // config.image_model_config.patch_size
851
+
852
+ self.model = AutoModel.from_config(config.fov_model_config)
853
+ self.neck = nn.Linear(config.fov_model_config.hidden_size, config.fusion_hidden_size // 2)
854
+
855
+ def forward(
856
+ self,
857
+ pixel_values: torch.Tensor,
858
+ head_mask: Optional[torch.Tensor] = None,
859
+ ) -> torch.Tensor:
860
+ batch_size, num_channels, height, width = pixel_values.shape
861
+
862
+ # scale the image for fov_encoder
863
+ size = self.config.fov_model_config.image_size
864
+ pixel_values = F.interpolate(
865
+ pixel_values,
866
+ size=(size, size),
867
+ mode="bilinear",
868
+ align_corners=False,
869
+ )
870
+ encodings = self.model(
871
+ pixel_values=pixel_values,
872
+ head_mask=head_mask,
873
+ )
874
+ hidden_state = encodings[0]
875
+ hidden_state = self.neck(hidden_state)
876
+
877
+ # calculate base height and width
878
+ # base height and width are the dimensions of the lowest resolution features
879
+ exponent_value = torch_int(math.log2(width / self.out_size))
880
+ base_height = height // 2**exponent_value
881
+ base_width = width // 2**exponent_value
882
+
883
+ features = reconstruct_feature_maps(
884
+ hidden_state,
885
+ batch_size=batch_size,
886
+ padding=0,
887
+ output_size=(base_height, base_width),
888
+ )
889
+
890
+ return features
891
+
892
+
893
+ class DepthProFovHead(nn.Module):
894
+ def __init__(self, config: DepthProConfig):
895
+ super().__init__()
896
+ self.config = config
897
+ self.fusion_hidden_size = config.fusion_hidden_size
898
+ self.out_size = config.image_model_config.image_size // config.image_model_config.patch_size
899
+
900
+ # create initial head layers
901
+ self.layers = nn.ModuleList()
902
+ for i in range(config.num_fov_head_layers):
903
+ self.layers.append(
904
+ nn.Conv2d(
905
+ math.ceil(self.fusion_hidden_size / 2 ** (i + 1)),
906
+ math.ceil(self.fusion_hidden_size / 2 ** (i + 2)),
907
+ kernel_size=3,
908
+ stride=2,
909
+ padding=1,
910
+ )
911
+ )
912
+ self.layers.append(nn.ReLU(True))
913
+ # calculate expected shapes to finally generate a scalar output from final head layer
914
+ final_in_channels = math.ceil(self.fusion_hidden_size / 2 ** (config.num_fov_head_layers + 1))
915
+ final_kernel_size = torch_int((self.out_size - 1) / 2**config.num_fov_head_layers + 1)
916
+ self.layers.append(
917
+ nn.Conv2d(
918
+ in_channels=final_in_channels, out_channels=1, kernel_size=final_kernel_size, stride=1, padding=0
919
+ )
920
+ )
921
+
922
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
923
+ features = F.interpolate(
924
+ features,
925
+ size=(self.out_size, self.out_size),
926
+ mode="bilinear",
927
+ align_corners=False,
928
+ )
929
+ for layer in self.layers:
930
+ features = layer(features)
931
+ return features
932
+
933
+
934
+ class DepthProFovModel(nn.Module):
935
+ def __init__(self, config: DepthProConfig):
936
+ super().__init__()
937
+ self.config = config
938
+ self.fusion_hidden_size = config.fusion_hidden_size
939
+
940
+ self.fov_encoder = DepthProFovEncoder(config)
941
+ self.conv = nn.Conv2d(
942
+ self.fusion_hidden_size, self.fusion_hidden_size // 2, kernel_size=3, stride=2, padding=1
943
+ )
944
+ self.activation = nn.ReLU(inplace=True)
945
+ self.head = DepthProFovHead(config)
946
+
947
+ def forward(
948
+ self,
949
+ pixel_values: torch.Tensor,
950
+ global_features: torch.Tensor,
951
+ head_mask: Optional[torch.Tensor] = None,
952
+ ) -> torch.Tensor:
953
+ fov_features = self.fov_encoder(pixel_values, head_mask)
954
+
955
+ global_features = self.conv(global_features)
956
+ global_features = self.activation(global_features)
957
+
958
+ fov_features = fov_features + global_features
959
+ fov_output = self.head(fov_features)
960
+ fov_output = fov_output.flatten()
961
+
962
+ return fov_output
963
+
964
+
965
+ class DepthProDepthEstimationHead(nn.Module):
966
+ """
967
+ The DepthProDepthEstimationHead module serves as the output head for depth estimation tasks.
968
+ This module comprises a sequence of convolutional and transposed convolutional layers
969
+ that process the feature map from the fusion to produce a single-channel depth map.
970
+ Key operations include dimensionality reduction and upsampling to match the input resolution.
971
+ """
972
+
973
+ def __init__(self, config):
974
+ super().__init__()
975
+ self.config = config
976
+
977
+ features = config.fusion_hidden_size
978
+ self.layers = nn.ModuleList(
979
+ [
980
+ nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
981
+ nn.ConvTranspose2d(
982
+ in_channels=features // 2,
983
+ out_channels=features // 2,
984
+ kernel_size=2,
985
+ stride=2,
986
+ padding=0,
987
+ bias=True,
988
+ ),
989
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
990
+ nn.ReLU(True),
991
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
992
+ nn.ReLU(),
993
+ ]
994
+ )
995
+
996
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
997
+ for layer in self.layers:
998
+ hidden_states = layer(hidden_states)
999
+
1000
+ predicted_depth = hidden_states.squeeze(dim=1)
1001
+ return predicted_depth
1002
+
1003
+
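A quick shape check for the head above, run against a default `DepthProConfig` with random input (a sketch only, no pretrained weights involved): the single transposed convolution doubles the spatial size and the channel dimension is squeezed away.

import torch
from transformers import DepthProConfig
from transformers.models.depth_pro.modeling_depth_pro import DepthProDepthEstimationHead

config = DepthProConfig()                                    # default configuration, randomly initialized
head = DepthProDepthEstimationHead(config)
fused = torch.randn(1, config.fusion_hidden_size, 32, 32)    # dummy fused feature map
print(head(fused).shape)                                     # torch.Size([1, 64, 64])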
1004
+ @auto_docstring(
1005
+ custom_intro="""
1006
+ DepthPro Model with a depth estimation head on top (consisting of 3 convolutional layers).
1007
+ """
1008
+ )
1009
+ class DepthProForDepthEstimation(DepthProPreTrainedModel):
1010
+ def __init__(self, config, use_fov_model=None):
1011
+ r"""
1012
+ use_fov_model (bool, *optional*):
1013
+ Whether to use the field of view model.
1014
+ """
1015
+ super().__init__(config)
1016
+ self.config = config
1017
+ self.use_fov_model = use_fov_model if use_fov_model is not None else self.config.use_fov_model
1018
+
1019
+ # dinov2 (vit) like encoders
1020
+ self.depth_pro = DepthProModel(config)
1021
+
1022
+ # dpt (vit) like fusion stage
1023
+ self.fusion_stage = DepthProFeatureFusionStage(config)
1024
+
1025
+ # depth estimation head
1026
+ self.head = DepthProDepthEstimationHead(config)
1027
+
1028
+ # dinov2 (vit) like encoder
1029
+ self.fov_model = DepthProFovModel(config) if self.use_fov_model else None
1030
+
1031
+ # Initialize weights and apply final processing
1032
+ self.post_init()
1033
+
1034
+ @auto_docstring
1035
+ def forward(
1036
+ self,
1037
+ pixel_values: torch.FloatTensor,
1038
+ head_mask: Optional[torch.FloatTensor] = None,
1039
+ labels: Optional[torch.LongTensor] = None,
1040
+ output_attentions: Optional[bool] = None,
1041
+ output_hidden_states: Optional[bool] = None,
1042
+ return_dict: Optional[bool] = None,
1043
+ ) -> Union[tuple[torch.Tensor], DepthProDepthEstimatorOutput]:
1044
+ r"""
1045
+ labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
1046
+ Ground truth depth estimation maps for computing the loss.
1047
+
1048
+ Examples:
1049
+
1050
+ ```python
1051
+ >>> from transformers import AutoImageProcessor, DepthProForDepthEstimation
1052
+ >>> import torch
1053
+ >>> from PIL import Image
1054
+ >>> import requests
1055
+
1056
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1057
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1058
+
1059
+ >>> checkpoint = "apple/DepthPro-hf"
1060
+ >>> processor = AutoImageProcessor.from_pretrained(checkpoint)
1061
+ >>> model = DepthProForDepthEstimation.from_pretrained(checkpoint)
1062
+
1063
+ >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1064
+ >>> model.to(device)
1065
+
1066
+ >>> # prepare image for the model
1067
+ >>> inputs = processor(images=image, return_tensors="pt").to(device)
1068
+
1069
+ >>> with torch.no_grad():
1070
+ ... outputs = model(**inputs)
1071
+
1072
+ >>> # interpolate to original size
1073
+ >>> post_processed_output = processor.post_process_depth_estimation(
1074
+ ... outputs, target_sizes=[(image.height, image.width)],
1075
+ ... )
1076
+
1077
+ >>> # get the field of view (fov) predictions
1078
+ >>> field_of_view = post_processed_output[0]["field_of_view"]
1079
+ >>> focal_length = post_processed_output[0]["focal_length"]
1080
+
1081
+ >>> # visualize the prediction
1082
+ >>> predicted_depth = post_processed_output[0]["predicted_depth"]
1083
+ >>> depth = predicted_depth * 255 / predicted_depth.max()
1084
+ >>> depth = depth.detach().cpu().numpy()
1085
+ >>> depth = Image.fromarray(depth.astype("uint8"))
1086
+ ```"""
1087
+ loss = None
1088
+ if labels is not None:
1089
+ raise NotImplementedError("Training is not implemented yet")
1090
+
1091
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1092
+ output_hidden_states = (
1093
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1094
+ )
1095
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1096
+
1097
+ depth_pro_outputs = self.depth_pro(
1098
+ pixel_values=pixel_values,
1099
+ head_mask=head_mask,
1100
+ output_attentions=output_attentions,
1101
+ output_hidden_states=output_hidden_states,
1102
+ return_dict=True,
1103
+ )
1104
+ features = depth_pro_outputs.features
1105
+ fused_hidden_states = self.fusion_stage(features)
1106
+ predicted_depth = self.head(fused_hidden_states[-1])
1107
+
1108
+ if self.use_fov_model:
1109
+ # frozen features from encoder are used
1110
+ features_for_fov = features[0].detach()
1111
+ fov = self.fov_model(
1112
+ pixel_values=pixel_values,
1113
+ global_features=features_for_fov,
1114
+ head_mask=head_mask,
1115
+ )
1116
+ else:
1117
+ fov = None
1118
+
1119
+ if not return_dict:
1120
+ outputs = [loss, predicted_depth, fov, depth_pro_outputs.hidden_states, depth_pro_outputs.attentions]
1121
+ return tuple(v for v in outputs if v is not None)
1122
+
1123
+ return DepthProDepthEstimatorOutput(
1124
+ loss=loss,
1125
+ predicted_depth=predicted_depth,
1126
+ field_of_view=fov,
1127
+ hidden_states=depth_pro_outputs.hidden_states,
1128
+ attentions=depth_pro_outputs.attentions,
1129
+ )
1130
+
1131
+
1132
+ __all__ = ["DepthProPreTrainedModel", "DepthProModel", "DepthProForDepthEstimation"]
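
Since the tuple code path above filters out every `None` entry, the positional layout of the non-dict output depends on which optional pieces are enabled. A minimal sketch of consuming the tuple form, reusing the checkpoint from the docstring example above:

```python
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, DepthProForDepthEstimation

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoImageProcessor.from_pretrained("apple/DepthPro-hf")
model = DepthProForDepthEstimation.from_pretrained("apple/DepthPro-hf")
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs, return_dict=False)

# `loss` is None at inference time and gets filtered out of the tuple, so the
# predicted depth comes first; the field-of-view tensor follows only when the
# checkpoint was built with `use_fov_model=True`.
predicted_depth = outputs[0]
print(predicted_depth.shape)
```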
phivenv/Lib/site-packages/transformers/models/detr/__init__.py ADDED
@@ -0,0 +1,31 @@
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING
+
+ from ...utils import _LazyModule
+ from ...utils.import_utils import define_import_structure
+
+
+ if TYPE_CHECKING:
+     from .configuration_detr import *
+     from .feature_extraction_detr import *
+     from .image_processing_detr import *
+     from .image_processing_detr_fast import *
+     from .modeling_detr import *
+ else:
+     import sys
+
+     _file = globals()["__file__"]
+     sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
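
Because the module body above hands everything to `_LazyModule`, the submodules listed in the `TYPE_CHECKING` block are only imported when one of their attributes is first accessed. A small sketch of what that means for callers:

```python
# Attribute access on the lazy module triggers the real import: configuration_detr
# is loaded for DetrConfig, image_processing_detr for DetrImageProcessor.
from transformers.models.detr import DetrConfig, DetrImageProcessor

config = DetrConfig()
processor = DetrImageProcessor()
print(config.model_type)          # "detr"
print(type(processor).__name__)   # "DetrImageProcessor"
```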
phivenv/Lib/site-packages/transformers/models/detr/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (628 Bytes).
 
phivenv/Lib/site-packages/transformers/models/detr/__pycache__/configuration_detr.cpython-39.pyc ADDED
Binary file (11.9 kB).
 
phivenv/Lib/site-packages/transformers/models/detr/__pycache__/feature_extraction_detr.cpython-39.pyc ADDED
Binary file (1.4 kB).
 
phivenv/Lib/site-packages/transformers/models/detr/__pycache__/image_processing_detr.cpython-39.pyc ADDED
Binary file (71.3 kB).
 
phivenv/Lib/site-packages/transformers/models/detr/__pycache__/image_processing_detr_fast.cpython-39.pyc ADDED
Binary file (43.1 kB).
 
phivenv/Lib/site-packages/transformers/models/detr/__pycache__/modeling_detr.cpython-39.pyc ADDED
Binary file (53.3 kB).
 
phivenv/Lib/site-packages/transformers/models/detr/configuration_detr.py ADDED
@@ -0,0 +1,297 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 Facebook AI Research and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """DETR model configuration"""
16
+
17
+ from collections import OrderedDict
18
+ from collections.abc import Mapping
19
+
20
+ from packaging import version
21
+
22
+ from ...configuration_utils import PretrainedConfig
23
+ from ...onnx import OnnxConfig
24
+ from ...utils import logging
25
+ from ...utils.backbone_utils import verify_backbone_config_arguments
26
+ from ..auto import CONFIG_MAPPING
27
+
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
+ class DetrConfig(PretrainedConfig):
33
+ r"""
34
+ This is the configuration class to store the configuration of a [`DetrModel`]. It is used to instantiate a DETR
35
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
36
+ defaults will yield a similar configuration to that of the DETR
37
+ [facebook/detr-resnet-50](https://huggingface.co/facebook/detr-resnet-50) architecture.
38
+
39
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
40
+ documentation from [`PretrainedConfig`] for more information.
41
+
42
+ Args:
43
+ use_timm_backbone (`bool`, *optional*, defaults to `True`):
44
+ Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`]
45
+ API.
46
+ backbone_config (`PretrainedConfig` or `dict`, *optional*):
47
+ The configuration of the backbone model. Only used in case `use_timm_backbone` is set to `False` in which
48
+ case it will default to `ResNetConfig()`.
49
+ num_channels (`int`, *optional*, defaults to 3):
50
+ The number of input channels.
51
+ num_queries (`int`, *optional*, defaults to 100):
52
+ Number of object queries, i.e. detection slots. This is the maximal number of objects [`DetrModel`] can
53
+ detect in a single image. For COCO, we recommend 100 queries.
54
+ d_model (`int`, *optional*, defaults to 256):
55
+ Dimension of the layers. This single value sets the hidden size used throughout the model, e.g. the encoder layers and the projection layers in the decoder.
56
+ encoder_layers (`int`, *optional*, defaults to 6):
57
+ Number of encoder layers.
58
+ decoder_layers (`int`, *optional*, defaults to 6):
59
+ Number of decoder layers.
60
+ encoder_attention_heads (`int`, *optional*, defaults to 8):
61
+ Number of attention heads for each attention layer in the Transformer encoder.
62
+ decoder_attention_heads (`int`, *optional*, defaults to 8):
63
+ Number of attention heads for each attention layer in the Transformer decoder.
64
+ decoder_ffn_dim (`int`, *optional*, defaults to 2048):
65
+ Dimension of the "intermediate" (often named feed-forward) layer in decoder.
66
+ encoder_ffn_dim (`int`, *optional*, defaults to 2048):
67
+ Dimension of the "intermediate" (often named feed-forward) layer in encoder.
68
+ activation_function (`str` or `function`, *optional*, defaults to `"relu"`):
69
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
70
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
71
+ dropout (`float`, *optional*, defaults to 0.1):
72
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
73
+ attention_dropout (`float`, *optional*, defaults to 0.0):
74
+ The dropout ratio for the attention probabilities.
75
+ activation_dropout (`float`, *optional*, defaults to 0.0):
76
+ The dropout ratio for activations inside the fully connected layer.
77
+ init_std (`float`, *optional*, defaults to 0.02):
78
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
79
+ init_xavier_std (`float`, *optional*, defaults to 1):
80
+ The scaling factor used for the Xavier initialization gain in the HM Attention map module.
81
+ encoder_layerdrop (`float`, *optional*, defaults to 0.0):
82
+ The LayerDrop probability for the encoder. See the [LayerDrop paper](https://huggingface.co/papers/1909.11556)
83
+ for more details.
84
+ decoder_layerdrop (`float`, *optional*, defaults to 0.0):
85
+ The LayerDrop probability for the decoder. See the [LayerDrop paper](https://huggingface.co/papers/1909.11556)
86
+ for more details.
87
+ auxiliary_loss (`bool`, *optional*, defaults to `False`):
88
+ Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
89
+ position_embedding_type (`str`, *optional*, defaults to `"sine"`):
90
+ Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`.
91
+ backbone (`str`, *optional*, defaults to `"resnet50"`):
92
+ Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
93
+ will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
94
+ is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
95
+ use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
96
+ Whether to use pretrained weights for the backbone.
97
+ backbone_kwargs (`dict`, *optional*):
98
+ Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
99
+ e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
100
+ dilation (`bool`, *optional*, defaults to `False`):
101
+ Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
102
+ `use_timm_backbone` = `True`.
103
+ class_cost (`float`, *optional*, defaults to 1):
104
+ Relative weight of the classification error in the Hungarian matching cost.
105
+ bbox_cost (`float`, *optional*, defaults to 5):
106
+ Relative weight of the L1 error of the bounding box coordinates in the Hungarian matching cost.
107
+ giou_cost (`float`, *optional*, defaults to 2):
108
+ Relative weight of the generalized IoU loss of the bounding box in the Hungarian matching cost.
109
+ mask_loss_coefficient (`float`, *optional*, defaults to 1):
110
+ Relative weight of the Focal loss in the panoptic segmentation loss.
111
+ dice_loss_coefficient (`float`, *optional*, defaults to 1):
112
+ Relative weight of the DICE/F-1 loss in the panoptic segmentation loss.
113
+ bbox_loss_coefficient (`float`, *optional*, defaults to 5):
114
+ Relative weight of the L1 bounding box loss in the object detection loss.
115
+ giou_loss_coefficient (`float`, *optional*, defaults to 2):
116
+ Relative weight of the generalized IoU loss in the object detection loss.
117
+ eos_coefficient (`float`, *optional*, defaults to 0.1):
118
+ Relative classification weight of the 'no-object' class in the object detection loss.
119
+
120
+ Examples:
121
+
122
+ ```python
123
+ >>> from transformers import DetrConfig, DetrModel
124
+
125
+ >>> # Initializing a DETR facebook/detr-resnet-50 style configuration
126
+ >>> configuration = DetrConfig()
127
+
128
+ >>> # Initializing a model (with random weights) from the facebook/detr-resnet-50 style configuration
129
+ >>> model = DetrModel(configuration)
130
+
131
+ >>> # Accessing the model configuration
132
+ >>> configuration = model.config
133
+ ```"""
134
+
135
+ model_type = "detr"
136
+ keys_to_ignore_at_inference = ["past_key_values"]
137
+ attribute_map = {
138
+ "hidden_size": "d_model",
139
+ "num_attention_heads": "encoder_attention_heads",
140
+ }
141
+
142
+ def __init__(
143
+ self,
144
+ use_timm_backbone=True,
145
+ backbone_config=None,
146
+ num_channels=3,
147
+ num_queries=100,
148
+ encoder_layers=6,
149
+ encoder_ffn_dim=2048,
150
+ encoder_attention_heads=8,
151
+ decoder_layers=6,
152
+ decoder_ffn_dim=2048,
153
+ decoder_attention_heads=8,
154
+ encoder_layerdrop=0.0,
155
+ decoder_layerdrop=0.0,
156
+ is_encoder_decoder=True,
157
+ activation_function="relu",
158
+ d_model=256,
159
+ dropout=0.1,
160
+ attention_dropout=0.0,
161
+ activation_dropout=0.0,
162
+ init_std=0.02,
163
+ init_xavier_std=1.0,
164
+ auxiliary_loss=False,
165
+ position_embedding_type="sine",
166
+ backbone="resnet50",
167
+ use_pretrained_backbone=True,
168
+ backbone_kwargs=None,
169
+ dilation=False,
170
+ class_cost=1,
171
+ bbox_cost=5,
172
+ giou_cost=2,
173
+ mask_loss_coefficient=1,
174
+ dice_loss_coefficient=1,
175
+ bbox_loss_coefficient=5,
176
+ giou_loss_coefficient=2,
177
+ eos_coefficient=0.1,
178
+ **kwargs,
179
+ ):
180
+ # We default to values which were previously hard-coded in the model. This enables configurability of the config
181
+ # while keeping the default behavior the same.
182
+ if use_timm_backbone and backbone_kwargs is None:
183
+ backbone_kwargs = {}
184
+ if dilation:
185
+ backbone_kwargs["output_stride"] = 16
186
+ backbone_kwargs["out_indices"] = [1, 2, 3, 4]
187
+ backbone_kwargs["in_chans"] = num_channels
188
+ # Backwards compatibility
189
+ elif not use_timm_backbone and backbone in (None, "resnet50"):
190
+ if backbone_config is None:
191
+ logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
192
+ backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
193
+ elif isinstance(backbone_config, dict):
194
+ backbone_model_type = backbone_config.get("model_type")
195
+ config_class = CONFIG_MAPPING[backbone_model_type]
196
+ backbone_config = config_class.from_dict(backbone_config)
197
+ backbone = None
198
+ # set timm attributes to None
199
+ dilation = None
200
+
201
+ verify_backbone_config_arguments(
202
+ use_timm_backbone=use_timm_backbone,
203
+ use_pretrained_backbone=use_pretrained_backbone,
204
+ backbone=backbone,
205
+ backbone_config=backbone_config,
206
+ backbone_kwargs=backbone_kwargs,
207
+ )
208
+
209
+ self.use_timm_backbone = use_timm_backbone
210
+ self.backbone_config = backbone_config
211
+ self.num_channels = num_channels
212
+ self.num_queries = num_queries
213
+ self.d_model = d_model
214
+ self.encoder_ffn_dim = encoder_ffn_dim
215
+ self.encoder_layers = encoder_layers
216
+ self.encoder_attention_heads = encoder_attention_heads
217
+ self.decoder_ffn_dim = decoder_ffn_dim
218
+ self.decoder_layers = decoder_layers
219
+ self.decoder_attention_heads = decoder_attention_heads
220
+ self.dropout = dropout
221
+ self.attention_dropout = attention_dropout
222
+ self.activation_dropout = activation_dropout
223
+ self.activation_function = activation_function
224
+ self.init_std = init_std
225
+ self.init_xavier_std = init_xavier_std
226
+ self.encoder_layerdrop = encoder_layerdrop
227
+ self.decoder_layerdrop = decoder_layerdrop
228
+ self.num_hidden_layers = encoder_layers
229
+ self.auxiliary_loss = auxiliary_loss
230
+ self.position_embedding_type = position_embedding_type
231
+ self.backbone = backbone
232
+ self.use_pretrained_backbone = use_pretrained_backbone
233
+ self.backbone_kwargs = backbone_kwargs
234
+ self.dilation = dilation
235
+ # Hungarian matcher
236
+ self.class_cost = class_cost
237
+ self.bbox_cost = bbox_cost
238
+ self.giou_cost = giou_cost
239
+ # Loss coefficients
240
+ self.mask_loss_coefficient = mask_loss_coefficient
241
+ self.dice_loss_coefficient = dice_loss_coefficient
242
+ self.bbox_loss_coefficient = bbox_loss_coefficient
243
+ self.giou_loss_coefficient = giou_loss_coefficient
244
+ self.eos_coefficient = eos_coefficient
245
+ super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
246
+
247
+ @property
248
+ def num_attention_heads(self) -> int:
249
+ return self.encoder_attention_heads
250
+
251
+ @property
252
+ def hidden_size(self) -> int:
253
+ return self.d_model
254
+
255
+ @property
256
+ def sub_configs(self):
257
+ return (
258
+ {"backbone_config": type(self.backbone_config)}
259
+ if getattr(self, "backbone_config", None) is not None
260
+ else {}
261
+ )
262
+
263
+ @classmethod
264
+ def from_backbone_config(cls, backbone_config: PretrainedConfig, **kwargs):
265
+ """Instantiate a [`DetrConfig`] (or a derived class) from a pre-trained backbone model configuration.
266
+
267
+ Args:
268
+ backbone_config ([`PretrainedConfig`]):
269
+ The backbone configuration.
270
+ Returns:
271
+ [`DetrConfig`]: An instance of a configuration object
272
+ """
273
+ return cls(backbone_config=backbone_config, **kwargs)
274
+
275
+
276
+ class DetrOnnxConfig(OnnxConfig):
277
+ torch_onnx_minimum_version = version.parse("1.11")
278
+
279
+ @property
280
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
281
+ return OrderedDict(
282
+ [
283
+ ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
284
+ ("pixel_mask", {0: "batch"}),
285
+ ]
286
+ )
287
+
288
+ @property
289
+ def atol_for_validation(self) -> float:
290
+ return 1e-5
291
+
292
+ @property
293
+ def default_onnx_opset(self) -> int:
294
+ return 12
295
+
296
+
297
+ __all__ = ["DetrConfig", "DetrOnnxConfig"]
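
As a usage sketch of the config above, building DETR around a transformers (non-timm) ResNet backbone through the `from_backbone_config` helper; the argument values are arbitrary:

```python
from transformers import DetrConfig, ResNetConfig

# `use_timm_backbone=False` routes backbone creation through the AutoBackbone API,
# and `use_pretrained_backbone=False` keeps the backbone randomly initialized so
# no external weights are needed for this sketch.
backbone_config = ResNetConfig(out_features=["stage4"])
config = DetrConfig.from_backbone_config(
    backbone_config,
    use_timm_backbone=False,
    use_pretrained_backbone=False,
    num_queries=50,
)

print(config.backbone_config.model_type)  # "resnet"
print(config.num_queries)                 # 50
print(config.hidden_size)                 # 256, the `d_model` alias exposed by the property above
```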
phivenv/Lib/site-packages/transformers/models/detr/feature_extraction_detr.py ADDED
@@ -0,0 +1,48 @@
+ # coding=utf-8
+ # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Feature extractor class for DETR."""
+
+ import warnings
+
+ from ...image_transforms import rgb_to_id as _rgb_to_id
+ from ...utils import logging
+ from ...utils.import_utils import requires
+ from .image_processing_detr import DetrImageProcessor
+
+
+ logger = logging.get_logger(__name__)
+
+
+ def rgb_to_id(x):
+     warnings.warn(
+         "rgb_to_id has moved and will not be importable from this module from v5. "
+         "Please import from transformers.image_transforms instead.",
+         FutureWarning,
+     )
+     return _rgb_to_id(x)
+
+
+ @requires(backends=("vision",))
+ class DetrFeatureExtractor(DetrImageProcessor):
+     def __init__(self, *args, **kwargs) -> None:
+         warnings.warn(
+             "The class DetrFeatureExtractor is deprecated and will be removed in version 5 of Transformers."
+             " Please use DetrImageProcessor instead.",
+             FutureWarning,
+         )
+         super().__init__(*args, **kwargs)
+
+
+ __all__ = ["DetrFeatureExtractor"]
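
Since the shim above only subclasses the image processor under the legacy name, migrating is a rename; a small sketch of the deprecation behaviour:

```python
import warnings

from transformers import DetrFeatureExtractor, DetrImageProcessor

# Instantiating the legacy class still works but emits a FutureWarning; because it
# subclasses DetrImageProcessor, every preprocessing method behaves identically.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy = DetrFeatureExtractor()

print(any(issubclass(w.category, FutureWarning) for w in caught))  # True
print(isinstance(legacy, DetrImageProcessor))                      # True
```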
phivenv/Lib/site-packages/transformers/models/detr/image_processing_detr.py ADDED
@@ -0,0 +1,2049 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Image processor class for DETR."""
16
+
17
+ import io
18
+ import pathlib
19
+ from collections import defaultdict
20
+ from collections.abc import Iterable
21
+ from typing import Any, Callable, Optional, Union
22
+
23
+ import numpy as np
24
+
25
+ from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
26
+ from ...image_transforms import (
27
+ PaddingMode,
28
+ center_to_corners_format,
29
+ corners_to_center_format,
30
+ id_to_rgb,
31
+ pad,
32
+ rescale,
33
+ resize,
34
+ rgb_to_id,
35
+ to_channel_dimension_format,
36
+ )
37
+ from ...image_utils import (
38
+ IMAGENET_DEFAULT_MEAN,
39
+ IMAGENET_DEFAULT_STD,
40
+ AnnotationFormat,
41
+ AnnotationType,
42
+ ChannelDimension,
43
+ ImageInput,
44
+ PILImageResampling,
45
+ get_image_size,
46
+ infer_channel_dimension_format,
47
+ is_scaled_image,
48
+ make_list_of_images,
49
+ to_numpy_array,
50
+ valid_images,
51
+ validate_annotations,
52
+ validate_kwargs,
53
+ validate_preprocess_arguments,
54
+ )
55
+ from ...utils import (
56
+ TensorType,
57
+ is_flax_available,
58
+ is_jax_tensor,
59
+ is_scipy_available,
60
+ is_tf_available,
61
+ is_tf_tensor,
62
+ is_torch_available,
63
+ is_torch_tensor,
64
+ is_vision_available,
65
+ logging,
66
+ )
67
+ from ...utils.import_utils import requires
68
+
69
+
70
+ if is_torch_available():
71
+ import torch
72
+ from torch import nn
73
+
74
+
75
+ if is_vision_available():
76
+ import PIL
77
+
78
+
79
+ if is_scipy_available():
80
+ import scipy.special
81
+ import scipy.stats
82
+
83
+
84
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
85
+
86
+ SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION, AnnotationFormat.COCO_PANOPTIC)
87
+
88
+
89
+ # From the original repo: https://github.com/facebookresearch/detr/blob/3af9fa878e73b6894ce3596450a8d9b89d918ca9/datasets/transforms.py#L76
90
+ def get_size_with_aspect_ratio(image_size, size, max_size=None) -> tuple[int, int]:
91
+ """
92
+ Computes the output image size given the input image size and the desired output size.
93
+
94
+ Args:
95
+ image_size (`tuple[int, int]`):
96
+ The input image size.
97
+ size (`int`):
98
+ The desired output size.
99
+ max_size (`int`, *optional*):
100
+ The maximum allowed output size.
101
+ """
102
+ height, width = image_size
103
+ raw_size = None
104
+ if max_size is not None:
105
+ min_original_size = float(min((height, width)))
106
+ max_original_size = float(max((height, width)))
107
+ if max_original_size / min_original_size * size > max_size:
108
+ raw_size = max_size * min_original_size / max_original_size
109
+ size = int(round(raw_size))
110
+
111
+ if (height <= width and height == size) or (width <= height and width == size):
112
+ oh, ow = height, width
113
+ elif width < height:
114
+ ow = size
115
+ if max_size is not None and raw_size is not None:
116
+ oh = int(raw_size * height / width)
117
+ else:
118
+ oh = int(size * height / width)
119
+ else:
120
+ oh = size
121
+ if max_size is not None and raw_size is not None:
122
+ ow = int(raw_size * width / height)
123
+ else:
124
+ ow = int(size * width / height)
125
+
126
+ return (oh, ow)
127
+
128
+
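
To make the shortest-edge / longest-edge rule above concrete, a worked call of the helper (the input dimensions are chosen so the arithmetic is exact):

```python
from transformers.models.detr.image_processing_detr import get_size_with_aspect_ratio

# For a 512x2048 (height, width) image, scaling the short edge to 800 would push
# the long edge to 3200, past max_size=1333, so both targets are scaled back while
# keeping the aspect ratio: the result is (333, 1333).
print(get_size_with_aspect_ratio((512, 2048), size=800, max_size=1333))  # (333, 1333)
```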
129
+ def get_image_size_for_max_height_width(
130
+ input_image: np.ndarray,
131
+ max_height: int,
132
+ max_width: int,
133
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
134
+ ) -> tuple[int, int]:
135
+ """
136
+ Computes the output image size given the input image and the maximum allowed height and width. Keep aspect ratio.
137
+ Important: even if image_height < max_height and image_width < max_width, the image will be resized
+ so that at least one of its edges equals max_height or max_width.
139
+
140
+ For example:
141
+ - input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50)
142
+ - input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400)
143
+
144
+ Args:
145
+ input_image (`np.ndarray`):
146
+ The image to resize.
147
+ max_height (`int`):
148
+ The maximum allowed height.
149
+ max_width (`int`):
150
+ The maximum allowed width.
151
+ input_data_format (`ChannelDimension` or `str`, *optional*):
152
+ The channel dimension format of the input image. If not provided, it will be inferred from the input image.
153
+ """
154
+ image_size = get_image_size(input_image, input_data_format)
155
+ height, width = image_size
156
+ height_scale = max_height / height
157
+ width_scale = max_width / width
158
+ min_scale = min(height_scale, width_scale)
159
+ new_height = int(height * min_scale)
160
+ new_width = int(width * min_scale)
161
+ return new_height, new_width
162
+
163
+
164
+ def get_resize_output_image_size(
165
+ input_image: np.ndarray,
166
+ size: Union[int, tuple[int, int], list[int]],
167
+ max_size: Optional[int] = None,
168
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
169
+ ) -> tuple[int, int]:
170
+ """
171
+ Computes the output image size given the input image size and the desired output size. If the desired output size
172
+ is a tuple or list, the output image size is returned as is. If the desired output size is an integer, the output
173
+ image size is computed by keeping the aspect ratio of the input image size.
174
+
175
+ Args:
176
+ input_image (`np.ndarray`):
177
+ The image to resize.
178
+ size (`int` or `tuple[int, int]` or `list[int]`):
179
+ The desired output size.
180
+ max_size (`int`, *optional*):
181
+ The maximum allowed output size.
182
+ input_data_format (`ChannelDimension` or `str`, *optional*):
183
+ The channel dimension format of the input image. If not provided, it will be inferred from the input image.
184
+ """
185
+ image_size = get_image_size(input_image, input_data_format)
186
+ if isinstance(size, (list, tuple)):
187
+ return size
188
+
189
+ return get_size_with_aspect_ratio(image_size, size, max_size)
190
+
191
+
192
+ def get_numpy_to_framework_fn(arr) -> Callable:
193
+ """
194
+ Returns a function that converts a numpy array to the framework of the input array.
195
+
196
+ Args:
197
+ arr (`np.ndarray`): The array to convert.
198
+ """
199
+ if isinstance(arr, np.ndarray):
200
+ return np.array
201
+ if is_tf_available() and is_tf_tensor(arr):
202
+ import tensorflow as tf
203
+
204
+ return tf.convert_to_tensor
205
+ if is_torch_available() and is_torch_tensor(arr):
206
+ import torch
207
+
208
+ return torch.tensor
209
+ if is_flax_available() and is_jax_tensor(arr):
210
+ import jax.numpy as jnp
211
+
212
+ return jnp.array
213
+ raise ValueError(f"Cannot convert arrays of type {type(arr)}")
214
+
215
+
216
+ def safe_squeeze(arr: np.ndarray, axis: Optional[int] = None) -> np.ndarray:
217
+ """
218
+ Squeezes an array, but only if the axis specified has dim 1.
219
+ """
220
+ if axis is None:
221
+ return arr.squeeze()
222
+
223
+ try:
224
+ return arr.squeeze(axis=axis)
225
+ except ValueError:
226
+ return arr
227
+
228
+
229
+ def normalize_annotation(annotation: dict, image_size: tuple[int, int]) -> dict:
230
+ image_height, image_width = image_size
231
+ norm_annotation = {}
232
+ for key, value in annotation.items():
233
+ if key == "boxes":
234
+ boxes = value
235
+ boxes = corners_to_center_format(boxes)
236
+ boxes /= np.asarray([image_width, image_height, image_width, image_height], dtype=np.float32)
237
+ norm_annotation[key] = boxes
238
+ else:
239
+ norm_annotation[key] = value
240
+ return norm_annotation
241
+
242
+
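
A quick worked example of the normalization above: corner boxes in pixels become center-format boxes scaled into `[0, 1]`. The box value is invented:

```python
import numpy as np

from transformers.models.detr.image_processing_detr import normalize_annotation

# One (x_min, y_min, x_max, y_max) box covering the left half of a
# 100x200 (height, width) image.
annotation = {"boxes": np.array([[0.0, 0.0, 100.0, 100.0]], dtype=np.float32)}
normalized = normalize_annotation(annotation, image_size=(100, 200))

# Center format (cx, cy, w, h) divided by (width, height, width, height):
print(normalized["boxes"])  # [[0.25 0.5  0.5  1.  ]]
```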
243
+ # Copied from transformers.models.vilt.image_processing_vilt.max_across_indices
244
+ def max_across_indices(values: Iterable[Any]) -> list[Any]:
245
+ """
246
+ Return the maximum value across all indices of an iterable of values.
247
+ """
248
+ return [max(values_i) for values_i in zip(*values)]
249
+
250
+
251
+ # Copied from transformers.models.vilt.image_processing_vilt.get_max_height_width
252
+ def get_max_height_width(
253
+ images: list[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
254
+ ) -> list[int]:
255
+ """
256
+ Get the maximum height and width across all images in a batch.
257
+ """
258
+ if input_data_format is None:
259
+ input_data_format = infer_channel_dimension_format(images[0])
260
+
261
+ if input_data_format == ChannelDimension.FIRST:
262
+ _, max_height, max_width = max_across_indices([img.shape for img in images])
263
+ elif input_data_format == ChannelDimension.LAST:
264
+ max_height, max_width, _ = max_across_indices([img.shape for img in images])
265
+ else:
266
+ raise ValueError(f"Invalid channel dimension format: {input_data_format}")
267
+ return (max_height, max_width)
268
+
269
+
270
+ # Copied from transformers.models.vilt.image_processing_vilt.make_pixel_mask
271
+ def make_pixel_mask(
272
+ image: np.ndarray, output_size: tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
273
+ ) -> np.ndarray:
274
+ """
275
+ Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
276
+
277
+ Args:
278
+ image (`np.ndarray`):
279
+ Image to make the pixel mask for.
280
+ output_size (`tuple[int, int]`):
281
+ Output size of the mask.
282
+ """
283
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
284
+ mask = np.zeros(output_size, dtype=np.int64)
285
+ mask[:input_height, :input_width] = 1
286
+ return mask
287
+
288
+
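
A tiny example of the padding mask produced above; the image and canvas sizes are arbitrary:

```python
import numpy as np

from transformers.models.detr.image_processing_detr import make_pixel_mask

# A channels-first image of height 2 and width 4, placed in a 4x5 padded canvas:
# ones mark real pixels, zeros mark padding added on the bottom and right.
image = np.zeros((3, 2, 4), dtype=np.float32)
mask = make_pixel_mask(image, output_size=(4, 5))
print(mask)
# [[1 1 1 1 0]
#  [1 1 1 1 0]
#  [0 0 0 0 0]
#  [0 0 0 0 0]]
```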
289
+ # inspired by https://github.com/facebookresearch/detr/blob/master/datasets/coco.py#L33
290
+ def convert_coco_poly_to_mask(segmentations, height: int, width: int) -> np.ndarray:
291
+ """
292
+ Convert a COCO polygon annotation to a mask.
293
+
294
+ Args:
295
+ segmentations (`list[list[float]]`):
296
+ List of polygons, each polygon represented by a list of x-y coordinates.
297
+ height (`int`):
298
+ Height of the mask.
299
+ width (`int`):
300
+ Width of the mask.
301
+ """
302
+ try:
303
+ from pycocotools import mask as coco_mask
304
+ except ImportError:
305
+ raise ImportError("Pycocotools is not installed in your environment.")
306
+
307
+ masks = []
308
+ for polygons in segmentations:
309
+ rles = coco_mask.frPyObjects(polygons, height, width)
310
+ mask = coco_mask.decode(rles)
311
+ if len(mask.shape) < 3:
312
+ mask = mask[..., None]
313
+ mask = np.asarray(mask, dtype=np.uint8)
314
+ mask = np.any(mask, axis=2)
315
+ masks.append(mask)
316
+ if masks:
317
+ masks = np.stack(masks, axis=0)
318
+ else:
319
+ masks = np.zeros((0, height, width), dtype=np.uint8)
320
+
321
+ return masks
322
+
323
+
324
+ # inspired by https://github.com/facebookresearch/detr/blob/master/datasets/coco.py#L50
325
+ def prepare_coco_detection_annotation(
326
+ image,
327
+ target,
328
+ return_segmentation_masks: bool = False,
329
+ input_data_format: Optional[Union[ChannelDimension, str]] = None,
330
+ ):
331
+ """
332
+ Convert the target in COCO format into the format expected by DETR.
333
+ """
334
+ image_height, image_width = get_image_size(image, channel_dim=input_data_format)
335
+
336
+ image_id = target["image_id"]
337
+ image_id = np.asarray([image_id], dtype=np.int64)
338
+
339
+ # Get all COCO annotations for the given image.
340
+ annotations = target["annotations"]
341
+ annotations = [obj for obj in annotations if "iscrowd" not in obj or obj["iscrowd"] == 0]
342
+
343
+ classes = [obj["category_id"] for obj in annotations]
344
+ classes = np.asarray(classes, dtype=np.int64)
345
+
346
+ # for conversion to coco api
347
+ area = np.asarray([obj["area"] for obj in annotations], dtype=np.float32)
348
+ iscrowd = np.asarray([obj.get("iscrowd", 0) for obj in annotations], dtype=np.int64)
349
+
350
+ boxes = [obj["bbox"] for obj in annotations]
351
+ # guard against no boxes via resizing
352
+ boxes = np.asarray(boxes, dtype=np.float32).reshape(-1, 4)
353
+ boxes[:, 2:] += boxes[:, :2]
354
+ boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width)
355
+ boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height)
356
+
357
+ keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
358
+
359
+ new_target = {}
360
+ new_target["image_id"] = image_id
361
+ new_target["class_labels"] = classes[keep]
362
+ new_target["boxes"] = boxes[keep]
363
+ new_target["area"] = area[keep]
364
+ new_target["iscrowd"] = iscrowd[keep]
365
+ new_target["orig_size"] = np.asarray([int(image_height), int(image_width)], dtype=np.int64)
366
+
367
+ if annotations and "keypoints" in annotations[0]:
368
+ keypoints = [obj["keypoints"] for obj in annotations]
369
+ # Converting the filtered keypoints list to a numpy array
370
+ keypoints = np.asarray(keypoints, dtype=np.float32)
371
+ # Apply the keep mask here to filter the relevant annotations
372
+ keypoints = keypoints[keep]
373
+ num_keypoints = keypoints.shape[0]
374
+ keypoints = keypoints.reshape((-1, 3)) if num_keypoints else keypoints
375
+ new_target["keypoints"] = keypoints
376
+
377
+ if return_segmentation_masks:
378
+ segmentation_masks = [obj["segmentation"] for obj in annotations]
379
+ masks = convert_coco_poly_to_mask(segmentation_masks, image_height, image_width)
380
+ new_target["masks"] = masks[keep]
381
+
382
+ return new_target
383
+
384
+
385
+ def masks_to_boxes(masks: np.ndarray) -> np.ndarray:
386
+ """
387
+ Compute the bounding boxes around the provided panoptic segmentation masks.
388
+
389
+ Args:
390
+ masks: masks in format `[number_masks, height, width]`, where `number_masks` is the number of masks
391
+
392
+ Returns:
393
+ boxes: bounding boxes in format `[number_masks, 4]` in xyxy format
394
+ """
395
+ if masks.size == 0:
396
+ return np.zeros((0, 4))
397
+
398
+ h, w = masks.shape[-2:]
399
+ y = np.arange(0, h, dtype=np.float32)
400
+ x = np.arange(0, w, dtype=np.float32)
401
+ # see https://github.com/pytorch/pytorch/issues/50276
402
+ y, x = np.meshgrid(y, x, indexing="ij")
403
+
404
+ x_mask = masks * np.expand_dims(x, axis=0)
405
+ x_max = x_mask.reshape(x_mask.shape[0], -1).max(-1)
406
+ x = np.ma.array(x_mask, mask=~(np.array(masks, dtype=bool)))
407
+ x_min = x.filled(fill_value=1e8)
408
+ x_min = x_min.reshape(x_min.shape[0], -1).min(-1)
409
+
410
+ y_mask = masks * np.expand_dims(y, axis=0)
411
+ y_max = y_mask.reshape(x_mask.shape[0], -1).max(-1)
412
+ y = np.ma.array(y_mask, mask=~(np.array(masks, dtype=bool)))
413
+ y_min = y.filled(fill_value=1e8)
414
+ y_min = y_min.reshape(y_min.shape[0], -1).min(-1)
415
+
416
+ return np.stack([x_min, y_min, x_max, y_max], 1)
417
+
418
+
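
A toy check of the mask-to-box conversion above, with a single invented mask:

```python
import numpy as np

from transformers.models.detr.image_processing_detr import masks_to_boxes

# One 6x6 mask whose filled rectangle spans rows 1-2 and columns 2-4.
mask = np.zeros((1, 6, 6), dtype=np.uint8)
mask[0, 1:3, 2:5] = 1

# xyxy box of the occupied region: x_min=2, y_min=1, x_max=4, y_max=2.
print(masks_to_boxes(mask))  # [[2. 1. 4. 2.]]
```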
419
+ def prepare_coco_panoptic_annotation(
420
+ image: np.ndarray,
421
+ target: dict,
422
+ masks_path: Union[str, pathlib.Path],
423
+ return_masks: bool = True,
424
+ input_data_format: Union[ChannelDimension, str] = None,
425
+ ) -> dict:
426
+ """
427
+ Prepare a coco panoptic annotation for DETR.
428
+ """
429
+ image_height, image_width = get_image_size(image, channel_dim=input_data_format)
430
+ annotation_path = pathlib.Path(masks_path) / target["file_name"]
431
+
432
+ new_target = {}
433
+ new_target["image_id"] = np.asarray([target["image_id"] if "image_id" in target else target["id"]], dtype=np.int64)
434
+ new_target["size"] = np.asarray([image_height, image_width], dtype=np.int64)
435
+ new_target["orig_size"] = np.asarray([image_height, image_width], dtype=np.int64)
436
+
437
+ if "segments_info" in target:
438
+ masks = np.asarray(PIL.Image.open(annotation_path), dtype=np.uint32)
439
+ masks = rgb_to_id(masks)
440
+
441
+ ids = np.array([segment_info["id"] for segment_info in target["segments_info"]])
442
+ masks = masks == ids[:, None, None]
443
+ masks = masks.astype(np.uint8)
444
+ if return_masks:
445
+ new_target["masks"] = masks
446
+ new_target["boxes"] = masks_to_boxes(masks)
447
+ new_target["class_labels"] = np.array(
448
+ [segment_info["category_id"] for segment_info in target["segments_info"]], dtype=np.int64
449
+ )
450
+ new_target["iscrowd"] = np.asarray(
451
+ [segment_info["iscrowd"] for segment_info in target["segments_info"]], dtype=np.int64
452
+ )
453
+ new_target["area"] = np.asarray(
454
+ [segment_info["area"] for segment_info in target["segments_info"]], dtype=np.float32
455
+ )
456
+
457
+ return new_target
458
+
459
+
460
+ def get_segmentation_image(
461
+ masks: np.ndarray, input_size: tuple, target_size: tuple, stuff_equiv_classes, deduplicate=False
462
+ ):
463
+ h, w = input_size
464
+ final_h, final_w = target_size
465
+
466
+ m_id = scipy.special.softmax(masks.transpose(0, 1), -1)
467
+
468
+ if m_id.shape[-1] == 0:
469
+ # We didn't detect any mask :(
470
+ m_id = np.zeros((h, w), dtype=np.int64)
471
+ else:
472
+ m_id = m_id.argmax(-1).reshape(h, w)
473
+
474
+ if deduplicate:
475
+ # Merge the masks corresponding to the same stuff class
476
+ for equiv in stuff_equiv_classes.values():
477
+ for eq_id in equiv:
478
+ m_id[m_id == eq_id] = equiv[0]
479
+
480
+ seg_img = id_to_rgb(m_id)
481
+ seg_img = resize(seg_img, (final_w, final_h), resample=PILImageResampling.NEAREST)
482
+ return seg_img
483
+
484
+
485
+ def get_mask_area(seg_img: np.ndarray, target_size: tuple[int, int], n_classes: int) -> np.ndarray:
486
+ final_h, final_w = target_size
487
+ np_seg_img = seg_img.astype(np.uint8)
488
+ np_seg_img = np_seg_img.reshape(final_h, final_w, 3)
489
+ m_id = rgb_to_id(np_seg_img)
490
+ area = [(m_id == i).sum() for i in range(n_classes)]
491
+ return area
492
+
493
+
494
+ def score_labels_from_class_probabilities(logits: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
495
+ probs = scipy.special.softmax(logits, axis=-1)
496
+ labels = probs.argmax(-1, keepdims=True)
497
+ scores = np.take_along_axis(probs, labels, axis=-1)
498
+ scores, labels = scores.squeeze(-1), labels.squeeze(-1)
499
+ return scores, labels
500
+
501
+
502
+ def post_process_panoptic_sample(
503
+ out_logits: np.ndarray,
504
+ masks: np.ndarray,
505
+ boxes: np.ndarray,
506
+ processed_size: tuple[int, int],
507
+ target_size: tuple[int, int],
508
+ is_thing_map: dict,
509
+ threshold=0.85,
510
+ ) -> dict:
511
+ """
512
+ Converts the output of [`DetrForSegmentation`] into panoptic segmentation predictions for a single sample.
513
+
514
+ Args:
515
+ out_logits (`torch.Tensor`):
516
+ The logits for this sample.
517
+ masks (`torch.Tensor`):
518
+ The predicted segmentation masks for this sample.
519
+ boxes (`torch.Tensor`):
520
+ The predicted bounding boxes for this sample. The boxes are in the normalized format `(center_x, center_y,
521
+ width, height)` with values in `[0, 1]`, relative to the size of the image (disregarding padding).
522
+ processed_size (`tuple[int, int]`):
523
+ The processed size of the image `(height, width)`, as returned by the preprocessing step i.e. the size
524
+ after data augmentation but before batching.
525
+ target_size (`tuple[int, int]`):
526
+ The target size of the image, `(height, width)` corresponding to the requested final size of the
527
+ prediction.
528
+ is_thing_map (`Dict`):
529
+ A dictionary mapping class indices to a boolean value indicating whether the class is a thing or not.
530
+ threshold (`float`, *optional*, defaults to 0.85):
531
+ The threshold used to binarize the segmentation masks.
532
+ """
533
+ # we filter empty queries and detection below threshold
534
+ scores, labels = score_labels_from_class_probabilities(out_logits)
535
+ keep = (labels != out_logits.shape[-1] - 1) & (scores > threshold)
536
+
537
+ cur_scores = scores[keep]
538
+ cur_classes = labels[keep]
539
+ cur_boxes = center_to_corners_format(boxes[keep])
540
+
541
+ if len(cur_boxes) != len(cur_classes):
542
+ raise ValueError("Not as many boxes as there are classes")
543
+
544
+ cur_masks = masks[keep]
545
+ cur_masks = resize(cur_masks[:, None], processed_size, resample=PILImageResampling.BILINEAR)
546
+ cur_masks = safe_squeeze(cur_masks, 1)
547
+ b, h, w = cur_masks.shape
548
+
549
+ # It may be that we have several predicted masks for the same stuff class.
550
+ # In the following, we track the list of masks ids for each stuff class (they are merged later on)
551
+ cur_masks = cur_masks.reshape(b, -1)
552
+ stuff_equiv_classes = defaultdict(list)
553
+ for k, label in enumerate(cur_classes):
554
+ if not is_thing_map[label]:
555
+ stuff_equiv_classes[label].append(k)
556
+
557
+ seg_img = get_segmentation_image(cur_masks, processed_size, target_size, stuff_equiv_classes, deduplicate=True)
558
+ area = get_mask_area(cur_masks, processed_size, n_classes=len(cur_scores))
559
+
560
+ # We filter out any mask that is too small
561
+ if cur_classes.size > 0:
562
+ # We now filter out empty masks, as long as we find some
563
+ filtered_small = np.array([a <= 4 for a in area], dtype=bool)
564
+ while filtered_small.any():
565
+ cur_masks = cur_masks[~filtered_small]
566
+ cur_scores = cur_scores[~filtered_small]
567
+ cur_classes = cur_classes[~filtered_small]
568
+ seg_img = get_segmentation_image(cur_masks, (h, w), target_size, stuff_equiv_classes, deduplicate=True)
569
+ area = get_mask_area(seg_img, target_size, n_classes=len(cur_scores))
570
+ filtered_small = np.array([a <= 4 for a in area], dtype=bool)
571
+ else:
572
+ cur_classes = np.ones((1, 1), dtype=np.int64)
573
+
574
+ segments_info = [
575
+ {"id": i, "isthing": is_thing_map[cat], "category_id": int(cat), "area": a}
576
+ for i, (cat, a) in enumerate(zip(cur_classes, area))
577
+ ]
578
+ del cur_classes
579
+
580
+ with io.BytesIO() as out:
581
+ PIL.Image.fromarray(seg_img).save(out, format="PNG")
582
+ predictions = {"png_string": out.getvalue(), "segments_info": segments_info}
583
+
584
+ return predictions
585
+
586
+
587
+ def resize_annotation(
588
+ annotation: dict[str, Any],
589
+ orig_size: tuple[int, int],
590
+ target_size: tuple[int, int],
591
+ threshold: float = 0.5,
592
+ resample: PILImageResampling = PILImageResampling.NEAREST,
593
+ ):
594
+ """
595
+ Resizes an annotation to a target size.
596
+
597
+ Args:
598
+ annotation (`dict[str, Any]`):
599
+ The annotation dictionary.
600
+ orig_size (`tuple[int, int]`):
601
+ The original size of the input image.
602
+ target_size (`tuple[int, int]`):
603
+ The target size of the image, as returned by the preprocessing `resize` step.
604
+ threshold (`float`, *optional*, defaults to 0.5):
605
+ The threshold used to binarize the segmentation masks.
606
+ resample (`PILImageResampling`, defaults to `PILImageResampling.NEAREST`):
607
+ The resampling filter to use when resizing the masks.
608
+ """
609
+ ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(target_size, orig_size))
610
+ ratio_height, ratio_width = ratios
611
+
612
+ new_annotation = {}
613
+ new_annotation["size"] = target_size
614
+
615
+ for key, value in annotation.items():
616
+ if key == "boxes":
617
+ boxes = value
618
+ scaled_boxes = boxes * np.asarray([ratio_width, ratio_height, ratio_width, ratio_height], dtype=np.float32)
619
+ new_annotation["boxes"] = scaled_boxes
620
+ elif key == "area":
621
+ area = value
622
+ scaled_area = area * (ratio_width * ratio_height)
623
+ new_annotation["area"] = scaled_area
624
+ elif key == "masks":
625
+ masks = value[:, None]
626
+ masks = np.array([resize(mask, target_size, resample=resample) for mask in masks])
627
+ masks = masks.astype(np.float32)
628
+ masks = masks[:, 0] > threshold
629
+ new_annotation["masks"] = masks
630
+ elif key == "size":
631
+ new_annotation["size"] = target_size
632
+ else:
633
+ new_annotation[key] = value
634
+
635
+ return new_annotation
636
+
637
+
638
+ # TODO - (Amy) make compatible with other frameworks
639
+ def binary_mask_to_rle(mask):
640
+ """
641
+ Converts given binary mask of shape `(height, width)` to the run-length encoding (RLE) format.
642
+
643
+ Args:
644
+ mask (`torch.Tensor` or `numpy.array`):
645
+ A binary mask tensor of shape `(height, width)` where 0 denotes background and 1 denotes the target
646
+ segment_id or class_id.
647
+ Returns:
648
+ `List`: Run-length encoded list of the binary mask. Refer to COCO API for more information about the RLE
649
+ format.
650
+ """
651
+ if is_torch_tensor(mask):
652
+ mask = mask.numpy()
653
+
654
+ pixels = mask.flatten()
655
+ pixels = np.concatenate([[0], pixels, [0]])
656
+ runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
657
+ runs[1::2] -= runs[::2]
658
+ return list(runs)
659
+
660
+
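
The helper above encodes the flattened mask as alternating (1-indexed start position, run length) values; a toy example with an invented mask:

```python
import numpy as np

from transformers.models.detr.image_processing_detr import binary_mask_to_rle

# Flattened, the mask reads 0 1 1 0 0 1: a run of two starting at position 2
# and a run of one starting at position 6 (1-indexed, as in the COCO tooling).
mask = np.array([[0, 1, 1], [0, 0, 1]])
print(list(map(int, binary_mask_to_rle(mask))))  # [2, 2, 6, 1]
```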
661
+ # TODO - (Amy) make compatible with other frameworks
662
+ def convert_segmentation_to_rle(segmentation):
663
+ """
664
+ Converts given segmentation map of shape `(height, width)` to the run-length encoding (RLE) format.
665
+
666
+ Args:
667
+ segmentation (`torch.Tensor` or `numpy.array`):
668
+ A segmentation map of shape `(height, width)` where each value denotes a segment or class id.
669
+ Returns:
670
+ `list[List]`: A list of lists, where each list is the run-length encoding of a segment / class id.
671
+ """
672
+ segment_ids = torch.unique(segmentation)
673
+
674
+ run_length_encodings = []
675
+ for idx in segment_ids:
676
+ mask = torch.where(segmentation == idx, 1, 0)
677
+ rle = binary_mask_to_rle(mask)
678
+ run_length_encodings.append(rle)
679
+
680
+ return run_length_encodings
681
+
682
+
683
+ def remove_low_and_no_objects(masks, scores, labels, object_mask_threshold, num_labels):
684
+ """
685
+ Binarize the given masks using `object_mask_threshold`, it returns the associated values of `masks`, `scores` and
686
+ `labels`.
687
+
688
+ Args:
689
+ masks (`torch.Tensor`):
690
+ A tensor of shape `(num_queries, height, width)`.
691
+ scores (`torch.Tensor`):
692
+ A tensor of shape `(num_queries)`.
693
+ labels (`torch.Tensor`):
694
+ A tensor of shape `(num_queries)`.
695
+ object_mask_threshold (`float`):
696
+ A number between 0 and 1 used to binarize the masks.
697
+ Raises:
698
+ `ValueError`: Raised when the first dimension doesn't match in all input tensors.
699
+ Returns:
700
+ `tuple[`torch.Tensor`, `torch.Tensor`, `torch.Tensor`]`: The `masks`, `scores` and `labels` with the
+ entries whose score is below `object_mask_threshold` removed.
702
+ """
703
+ if not (masks.shape[0] == scores.shape[0] == labels.shape[0]):
704
+ raise ValueError("mask, scores and labels must have the same shape!")
705
+
706
+ to_keep = labels.ne(num_labels) & (scores > object_mask_threshold)
707
+
708
+ return masks[to_keep], scores[to_keep], labels[to_keep]
709
+
710
+
711
+ def check_segment_validity(mask_labels, mask_probs, k, mask_threshold=0.5, overlap_mask_area_threshold=0.8):
712
+ # Get the mask associated with the k class
713
+ mask_k = mask_labels == k
714
+ mask_k_area = mask_k.sum()
715
+
716
+ # Compute the area of all the stuff in query k
717
+ original_area = (mask_probs[k] >= mask_threshold).sum()
718
+ mask_exists = mask_k_area > 0 and original_area > 0
719
+
720
+ # Eliminate disconnected tiny segments
721
+ if mask_exists:
722
+ area_ratio = mask_k_area / original_area
723
+ if not area_ratio.item() > overlap_mask_area_threshold:
724
+ mask_exists = False
725
+
726
+ return mask_exists, mask_k
727
+
728
+
729
+ def compute_segments(
730
+ mask_probs,
731
+ pred_scores,
732
+ pred_labels,
733
+ mask_threshold: float = 0.5,
734
+ overlap_mask_area_threshold: float = 0.8,
735
+ label_ids_to_fuse: Optional[set[int]] = None,
736
+ target_size: Optional[tuple[int, int]] = None,
737
+ ):
738
+ height = mask_probs.shape[1] if target_size is None else target_size[0]
739
+ width = mask_probs.shape[2] if target_size is None else target_size[1]
740
+
741
+ segmentation = torch.zeros((height, width), dtype=torch.int32, device=mask_probs.device)
742
+ segments: list[dict] = []
743
+
744
+ if target_size is not None:
745
+ mask_probs = nn.functional.interpolate(
746
+ mask_probs.unsqueeze(0), size=target_size, mode="bilinear", align_corners=False
747
+ )[0]
748
+
749
+ current_segment_id = 0
750
+
751
+ # Weigh each mask by its prediction score
752
+ mask_probs *= pred_scores.view(-1, 1, 1)
753
+ mask_labels = mask_probs.argmax(0) # [height, width]
754
+
755
+ # Keep track of instances of each class
756
+ stuff_memory_list: dict[str, int] = {}
757
+ for k in range(pred_labels.shape[0]):
758
+ pred_class = pred_labels[k].item()
759
+ should_fuse = pred_class in label_ids_to_fuse
760
+
761
+ # Check if mask exists and large enough to be a segment
762
+ mask_exists, mask_k = check_segment_validity(
763
+ mask_labels, mask_probs, k, mask_threshold, overlap_mask_area_threshold
764
+ )
765
+
766
+ if mask_exists:
767
+ if pred_class in stuff_memory_list:
768
+ current_segment_id = stuff_memory_list[pred_class]
769
+ else:
770
+ current_segment_id += 1
771
+
772
+ # Add current object segment to final segmentation map
773
+ segmentation[mask_k] = current_segment_id
774
+ segment_score = round(pred_scores[k].item(), 6)
775
+ segments.append(
776
+ {
777
+ "id": current_segment_id,
778
+ "label_id": pred_class,
779
+ "was_fused": should_fuse,
780
+ "score": segment_score,
781
+ }
782
+ )
783
+ if should_fuse:
784
+ stuff_memory_list[pred_class] = current_segment_id
785
+
786
+ return segmentation, segments
787
+
788
+
789
+ @requires(backends=("vision",))
790
+ class DetrImageProcessor(BaseImageProcessor):
791
+ r"""
792
+ Constructs a Detr image processor.
793
+
794
+ Args:
795
+ format (`str`, *optional*, defaults to `"coco_detection"`):
796
+ Data format of the annotations. One of "coco_detection" or "coco_panoptic".
797
+ do_resize (`bool`, *optional*, defaults to `True`):
798
+ Controls whether to resize the image's `(height, width)` dimensions to the specified `size`. Can be
799
+ overridden by the `do_resize` parameter in the `preprocess` method.
800
+ size (`dict[str, int]` *optional*, defaults to `{"shortest_edge": 800, "longest_edge": 1333}`):
801
+ Size of the image's `(height, width)` dimensions after resizing. Can be overridden by the `size` parameter
802
+ in the `preprocess` method. Available options are:
803
+ - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
804
+ Do NOT keep the aspect ratio.
805
+ - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
806
+ the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
807
+ less or equal to `longest_edge`.
808
+ - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
809
+ aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
810
+ `max_width`.
811
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
812
+ Resampling filter to use if resizing the image.
813
+ do_rescale (`bool`, *optional*, defaults to `True`):
814
+ Controls whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
815
+ `do_rescale` parameter in the `preprocess` method.
816
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
817
+ Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
818
+ `preprocess` method.
819
+ do_normalize (`bool`, *optional*, defaults to True):
820
+ Controls whether to normalize the image. Can be overridden by the `do_normalize` parameter in the
821
+ `preprocess` method.
822
+ image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
823
+ Mean values to use when normalizing the image. Can be a single value or a list of values, one for each
824
+ channel. Can be overridden by the `image_mean` parameter in the `preprocess` method.
825
+ image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
826
+ Standard deviation values to use when normalizing the image. Can be a single value or a list of values, one
827
+ for each channel. Can be overridden by the `image_std` parameter in the `preprocess` method.
828
+ do_convert_annotations (`bool`, *optional*, defaults to `True`):
829
+ Controls whether to convert the annotations to the format expected by the DETR model. Converts the
830
+ bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`.
831
+ Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method.
832
+ do_pad (`bool`, *optional*, defaults to `True`):
833
+ Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess`
834
+ method. If `True`, padding will be applied to the bottom and right of the image with zeros.
835
+ If `pad_size` is provided, the image will be padded to the specified dimensions.
836
+ Otherwise, the image will be padded to the maximum height and width of the batch.
837
+ pad_size (`dict[str, int]`, *optional*):
838
+ The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
839
+ provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
840
+ height and width in the batch.
841
+ """
842
+
843
+ model_input_names = ["pixel_values", "pixel_mask"]
844
+
845
+ def __init__(
846
+ self,
847
+ format: Union[str, AnnotationFormat] = AnnotationFormat.COCO_DETECTION,
848
+ do_resize: bool = True,
849
+ size: Optional[dict[str, int]] = None,
850
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
851
+ do_rescale: bool = True,
852
+ rescale_factor: Union[int, float] = 1 / 255,
853
+ do_normalize: bool = True,
854
+ image_mean: Optional[Union[float, list[float]]] = None,
855
+ image_std: Optional[Union[float, list[float]]] = None,
856
+ do_convert_annotations: Optional[bool] = None,
857
+ do_pad: bool = True,
858
+ pad_size: Optional[dict[str, int]] = None,
859
+ **kwargs,
860
+ ) -> None:
861
+ if "pad_and_return_pixel_mask" in kwargs:
862
+ do_pad = kwargs.pop("pad_and_return_pixel_mask")
863
+
864
+ if "max_size" in kwargs:
865
+ logger.warning_once(
866
+ "The `max_size` parameter is deprecated and will be removed in v4.26. "
867
+ "Please specify in `size['longest_edge'] instead`.",
868
+ )
869
+ max_size = kwargs.pop("max_size")
870
+ else:
871
+ max_size = None if size is None else 1333
872
+
873
+ size = size if size is not None else {"shortest_edge": 800, "longest_edge": 1333}
874
+ size = get_size_dict(size, max_size=max_size, default_to_square=False)
875
+
876
+ # Backwards compatibility
877
+ if do_convert_annotations is None:
878
+ do_convert_annotations = do_normalize
879
+
880
+ super().__init__(**kwargs)
881
+ self.format = format
882
+ self.do_resize = do_resize
883
+ self.size = size
884
+ self.resample = resample
885
+ self.do_rescale = do_rescale
886
+ self.rescale_factor = rescale_factor
887
+ self.do_normalize = do_normalize
888
+ self.do_convert_annotations = do_convert_annotations
889
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
890
+ self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
891
+ self.do_pad = do_pad
892
+ self.pad_size = pad_size
893
+ self._valid_processor_keys = [
894
+ "images",
895
+ "annotations",
896
+ "return_segmentation_masks",
897
+ "masks_path",
898
+ "do_resize",
899
+ "size",
900
+ "resample",
901
+ "do_rescale",
902
+ "rescale_factor",
903
+ "do_normalize",
904
+ "do_convert_annotations",
905
+ "image_mean",
906
+ "image_std",
907
+ "do_pad",
908
+ "pad_size",
909
+ "format",
910
+ "return_tensors",
911
+ "data_format",
912
+ "input_data_format",
913
+ ]
914
+
915
+ @classmethod
916
+ def from_dict(cls, image_processor_dict: dict[str, Any], **kwargs):
917
+ """
918
+ Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is
919
+ created using from_dict and kwargs e.g. `DetrImageProcessor.from_pretrained(checkpoint, size=600,
920
+ max_size=800)`
921
+ """
922
+ image_processor_dict = image_processor_dict.copy()
923
+ if "max_size" in kwargs:
924
+ image_processor_dict["max_size"] = kwargs.pop("max_size")
925
+ if "pad_and_return_pixel_mask" in kwargs:
926
+ image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask")
927
+ return super().from_dict(image_processor_dict, **kwargs)
928
+
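+ # A brief usage sketch for the `from_dict` override above (illustrative only; the checkpoint name is
+ # just a commonly used public DETR checkpoint, not something this module depends on):
+ #
+ #     processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", size=600, max_size=800)
+ #     # The legacy `max_size` kwarg is routed through `from_dict` and folded into `size` in `__init__`.
+ #     print(processor.size)  # e.g. {"shortest_edge": 600, "longest_edge": 800}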
929
+ def prepare_annotation(
930
+ self,
931
+ image: np.ndarray,
932
+ target: dict,
933
+ format: Optional[AnnotationFormat] = None,
934
+ return_segmentation_masks: Optional[bool] = None,
935
+ masks_path: Optional[Union[str, pathlib.Path]] = None,
936
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
937
+ ) -> dict:
938
+ """
939
+ Prepare an annotation for feeding into DETR model.
940
+ """
941
+ format = format if format is not None else self.format
942
+
943
+ if format == AnnotationFormat.COCO_DETECTION:
944
+ return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks
945
+ target = prepare_coco_detection_annotation(
946
+ image, target, return_segmentation_masks, input_data_format=input_data_format
947
+ )
948
+ elif format == AnnotationFormat.COCO_PANOPTIC:
949
+ return_segmentation_masks = True if return_segmentation_masks is None else return_segmentation_masks
950
+ target = prepare_coco_panoptic_annotation(
951
+ image,
952
+ target,
953
+ masks_path=masks_path,
954
+ return_masks=return_segmentation_masks,
955
+ input_data_format=input_data_format,
956
+ )
957
+ else:
958
+ raise ValueError(f"Format {format} is not supported.")
959
+ return target
960
+
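+ # For reference, a minimal sketch of a COCO-detection `target` as expected by `prepare_annotation`
+ # (field values below are made up for illustration; `np_image` and `processor` are assumed names):
+ #
+ #     target = {
+ #         "image_id": 42,
+ #         "annotations": [
+ #             {"bbox": [10.0, 20.0, 30.0, 40.0], "category_id": 1, "area": 1200.0, "iscrowd": 0},
+ #         ],
+ #     }
+ #     target = processor.prepare_annotation(np_image, target, format=AnnotationFormat.COCO_DETECTION)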
961
+ def resize(
962
+ self,
963
+ image: np.ndarray,
964
+ size: dict[str, int],
965
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
966
+ data_format: Optional[ChannelDimension] = None,
967
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
968
+ **kwargs,
969
+ ) -> np.ndarray:
970
+ """
971
+ Resize the image to the given size. `size` is a dictionary whose supported key combinations are
972
+ described under the `size` argument below.
973
+
974
+ Args:
975
+ image (`np.ndarray`):
976
+ Image to resize.
977
+ size (`dict[str, int]`):
978
+ Size of the image's `(height, width)` dimensions after resizing. Available options are:
979
+ - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
980
+ Do NOT keep the aspect ratio.
981
+ - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
982
+ the aspect ratio and keeping the shortest edge less than or equal to `shortest_edge` and the longest edge
983
+ less than or equal to `longest_edge`.
984
+ - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
985
+ aspect ratio and keeping the height less than or equal to `max_height` and the width less than or equal to
986
+ `max_width`.
987
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
988
+ Resampling filter to use if resizing the image.
989
+ data_format (`str` or `ChannelDimension`, *optional*):
990
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
991
+ image is used.
992
+ input_data_format (`ChannelDimension` or `str`, *optional*):
993
+ The channel dimension format of the input image. If not provided, it will be inferred.
994
+ """
995
+ if "max_size" in kwargs:
996
+ logger.warning_once(
997
+ "The `max_size` parameter is deprecated and will be removed in v4.26. "
998
+ "Please specify in `size['longest_edge'] instead`.",
999
+ )
1000
+ max_size = kwargs.pop("max_size")
1001
+ else:
1002
+ max_size = None
1003
+ size = get_size_dict(size, max_size=max_size, default_to_square=False)
1004
+ if "shortest_edge" in size and "longest_edge" in size:
1005
+ new_size = get_resize_output_image_size(
1006
+ image, size["shortest_edge"], size["longest_edge"], input_data_format=input_data_format
1007
+ )
1008
+ elif "max_height" in size and "max_width" in size:
1009
+ new_size = get_image_size_for_max_height_width(
1010
+ image, size["max_height"], size["max_width"], input_data_format=input_data_format
1011
+ )
1012
+ elif "height" in size and "width" in size:
1013
+ new_size = (size["height"], size["width"])
1014
+ else:
1015
+ raise ValueError(
1016
+ "Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got"
1017
+ f" {size.keys()}."
1018
+ )
1019
+ image = resize(
1020
+ image,
1021
+ size=new_size,
1022
+ resample=resample,
1023
+ data_format=data_format,
1024
+ input_data_format=input_data_format,
1025
+ **kwargs,
1026
+ )
1027
+ return image
1028
+
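+ # A small sketch of the three `size` variants accepted by `resize` above (illustrative only;
+ # `np_image` is assumed to be a numpy array and `processor` a DetrImageProcessor instance):
+ #
+ #     processor.resize(np_image, size={"shortest_edge": 800, "longest_edge": 1333})  # aspect ratio kept
+ #     processor.resize(np_image, size={"max_height": 480, "max_width": 640})         # bounded, ratio kept
+ #     processor.resize(np_image, size={"height": 480, "width": 640})                 # exact size, ratio not kept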
1029
+ def resize_annotation(
1030
+ self,
1031
+ annotation,
1032
+ orig_size,
1033
+ size,
1034
+ resample: PILImageResampling = PILImageResampling.NEAREST,
1035
+ ) -> dict:
1036
+ """
1037
+ Resize the annotation (boxes, masks, etc.) to match the resized image, given the original and target
1038
+ `(height, width)` sizes.
1039
+ """
1040
+ return resize_annotation(annotation, orig_size=orig_size, target_size=size, resample=resample)
1041
+
1042
+ # TODO (Amy) - update to use `rescale_factor` instead of `scale`
1043
+ def rescale(
1044
+ self,
1045
+ image: np.ndarray,
1046
+ rescale_factor: float,
1047
+ data_format: Optional[Union[str, ChannelDimension]] = None,
1048
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
1049
+ ) -> np.ndarray:
1050
+ """
1051
+ Rescale the image by the given factor. image = image * rescale_factor.
1052
+
1053
+ Args:
1054
+ image (`np.ndarray`):
1055
+ Image to rescale.
1056
+ rescale_factor (`float`):
1057
+ The value to use for rescaling.
1058
+ data_format (`str` or `ChannelDimension`, *optional*):
1059
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
1060
+ image is used. Can be one of:
1061
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
1062
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
1063
+ input_data_format (`str` or `ChannelDimension`, *optional*):
1064
+ The channel dimension format for the input image. If unset, is inferred from the input image. Can be
1065
+ one of:
1066
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
1067
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
1068
+ """
1069
+ return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format)
1070
+
1071
+ def normalize_annotation(self, annotation: dict, image_size: tuple[int, int]) -> dict:
1072
+ """
1073
+ Normalize the boxes in the annotation from `[top_left_x, top_left_y, bottom_right_x, bottom_right_y]` to
1074
+ `[center_x, center_y, width, height]` format and from absolute to relative pixel values.
1075
+ """
1076
+ return normalize_annotation(annotation, image_size=image_size)
1077
+
1078
+ def _update_annotation_for_padded_image(
1079
+ self,
1080
+ annotation: dict,
1081
+ input_image_size: tuple[int, int],
1082
+ output_image_size: tuple[int, int],
1083
+ padding,
1084
+ update_bboxes,
1085
+ ) -> dict:
1086
+ """
1087
+ Update the annotation for a padded image.
1088
+ """
1089
+ new_annotation = {}
1090
+ new_annotation["size"] = output_image_size
1091
+
1092
+ for key, value in annotation.items():
1093
+ if key == "masks":
1094
+ masks = value
1095
+ masks = pad(
1096
+ masks,
1097
+ padding,
1098
+ mode=PaddingMode.CONSTANT,
1099
+ constant_values=0,
1100
+ input_data_format=ChannelDimension.FIRST,
1101
+ )
1102
+ masks = safe_squeeze(masks, 1)
1103
+ new_annotation["masks"] = masks
1104
+ elif key == "boxes" and update_bboxes:
1105
+ boxes = value
1106
+ boxes *= np.asarray(
1107
+ [
1108
+ input_image_size[1] / output_image_size[1],
1109
+ input_image_size[0] / output_image_size[0],
1110
+ input_image_size[1] / output_image_size[1],
1111
+ input_image_size[0] / output_image_size[0],
1112
+ ]
1113
+ )
1114
+ new_annotation["boxes"] = boxes
1115
+ elif key == "size":
1116
+ new_annotation["size"] = output_image_size
1117
+ else:
1118
+ new_annotation[key] = value
1119
+ return new_annotation
1120
+
1121
+ def _pad_image(
1122
+ self,
1123
+ image: np.ndarray,
1124
+ output_size: tuple[int, int],
1125
+ annotation: Optional[dict[str, Any]] = None,
1126
+ constant_values: Union[float, Iterable[float]] = 0,
1127
+ data_format: Optional[ChannelDimension] = None,
1128
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
1129
+ update_bboxes: bool = True,
1130
+ ) -> np.ndarray:
1131
+ """
1132
+ Pad an image with zeros to the given size.
1133
+ """
1134
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
1135
+ output_height, output_width = output_size
1136
+
1137
+ pad_bottom = output_height - input_height
1138
+ pad_right = output_width - input_width
1139
+ padding = ((0, pad_bottom), (0, pad_right))
1140
+ padded_image = pad(
1141
+ image,
1142
+ padding,
1143
+ mode=PaddingMode.CONSTANT,
1144
+ constant_values=constant_values,
1145
+ data_format=data_format,
1146
+ input_data_format=input_data_format,
1147
+ )
1148
+ if annotation is not None:
1149
+ annotation = self._update_annotation_for_padded_image(
1150
+ annotation, (input_height, input_width), (output_height, output_width), padding, update_bboxes
1151
+ )
1152
+ return padded_image, annotation
1153
+
1154
+ def pad(
1155
+ self,
1156
+ images: list[np.ndarray],
1157
+ annotations: Optional[Union[AnnotationType, list[AnnotationType]]] = None,
1158
+ constant_values: Union[float, Iterable[float]] = 0,
1159
+ return_pixel_mask: bool = True,
1160
+ return_tensors: Optional[Union[str, TensorType]] = None,
1161
+ data_format: Optional[ChannelDimension] = None,
1162
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
1163
+ update_bboxes: bool = True,
1164
+ pad_size: Optional[dict[str, int]] = None,
1165
+ ) -> BatchFeature:
1166
+ """
1167
+ Pads a batch of images on the bottom and right with zeros, up to the largest height and width in the
1168
+ batch (or to `pad_size` if provided), and optionally returns the corresponding pixel masks.
1169
+
1170
+ Args:
1171
+ images (list[`np.ndarray`]):
1172
+ Images to pad.
1173
+ annotations (`AnnotationType` or `list[AnnotationType]`, *optional*):
1174
+ Annotations to transform according to the padding that is applied to the images.
1175
+ constant_values (`float` or `Iterable[float]`, *optional*):
1176
+ The value to use for the padding if `mode` is `"constant"`.
1177
+ return_pixel_mask (`bool`, *optional*, defaults to `True`):
1178
+ Whether to return a pixel mask.
1179
+ return_tensors (`str` or `TensorType`, *optional*):
1180
+ The type of tensors to return. Can be one of:
1181
+ - Unset: Return a list of `np.ndarray`.
1182
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
1183
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
1184
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
1185
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
1186
+ data_format (`str` or `ChannelDimension`, *optional*):
1187
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
1188
+ input_data_format (`ChannelDimension` or `str`, *optional*):
1189
+ The channel dimension format of the input image. If not provided, it will be inferred.
1190
+ update_bboxes (`bool`, *optional*, defaults to `True`):
1191
+ Whether to update the bounding boxes in the annotations to match the padded images. If the
1192
+ bounding boxes have not been converted to relative coordinates and `(center_x, center_y, width, height)`
1193
+ format, the bounding boxes will not be updated.
1194
+ pad_size (`dict[str, int]`, *optional*):
1195
+ The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
1196
+ provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
1197
+ height and width in the batch.
1198
+ """
1199
+ pad_size = pad_size if pad_size is not None else self.pad_size
1200
+ if pad_size is not None:
1201
+ padded_size = (pad_size["height"], pad_size["width"])
1202
+ else:
1203
+ padded_size = get_max_height_width(images, input_data_format=input_data_format)
1204
+
1205
+ annotation_list = annotations if annotations is not None else [None] * len(images)
1206
+ padded_images = []
1207
+ padded_annotations = []
1208
+ for image, annotation in zip(images, annotation_list):
1209
+ padded_image, padded_annotation = self._pad_image(
1210
+ image,
1211
+ padded_size,
1212
+ annotation,
1213
+ constant_values=constant_values,
1214
+ data_format=data_format,
1215
+ input_data_format=input_data_format,
1216
+ update_bboxes=update_bboxes,
1217
+ )
1218
+ padded_images.append(padded_image)
1219
+ padded_annotations.append(padded_annotation)
1220
+
1221
+ data = {"pixel_values": padded_images}
1222
+
1223
+ if return_pixel_mask:
1224
+ masks = [
1225
+ make_pixel_mask(image=image, output_size=padded_size, input_data_format=input_data_format)
1226
+ for image in images
1227
+ ]
1228
+ data["pixel_mask"] = masks
1229
+
1230
+ encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
1231
+
1232
+ if annotations is not None:
1233
+ encoded_inputs["labels"] = [
1234
+ BatchFeature(annotation, tensor_type=return_tensors) for annotation in padded_annotations
1235
+ ]
1236
+
1237
+ return encoded_inputs
1238
+
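+ # A minimal sketch of calling `pad` directly (illustrative only; `images` is assumed to be a list of
+ # channels-first numpy arrays that have already been resized and normalized):
+ #
+ #     batch = processor.pad(images, return_pixel_mask=True, return_tensors="pt")
+ #     batch["pixel_values"].shape  # (batch_size, 3, max_height, max_width)
+ #     batch["pixel_mask"].shape    # (batch_size, max_height, max_width); 1 = real pixel, 0 = padding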
1239
+ def preprocess(
1240
+ self,
1241
+ images: ImageInput,
1242
+ annotations: Optional[Union[AnnotationType, list[AnnotationType]]] = None,
1243
+ return_segmentation_masks: Optional[bool] = None,
1244
+ masks_path: Optional[Union[str, pathlib.Path]] = None,
1245
+ do_resize: Optional[bool] = None,
1246
+ size: Optional[dict[str, int]] = None,
1247
+ resample=None, # PILImageResampling
1248
+ do_rescale: Optional[bool] = None,
1249
+ rescale_factor: Optional[Union[int, float]] = None,
1250
+ do_normalize: Optional[bool] = None,
1251
+ do_convert_annotations: Optional[bool] = None,
1252
+ image_mean: Optional[Union[float, list[float]]] = None,
1253
+ image_std: Optional[Union[float, list[float]]] = None,
1254
+ do_pad: Optional[bool] = None,
1255
+ format: Optional[Union[str, AnnotationFormat]] = None,
1256
+ return_tensors: Optional[Union[TensorType, str]] = None,
1257
+ data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
1258
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
1259
+ pad_size: Optional[dict[str, int]] = None,
1260
+ **kwargs,
1261
+ ) -> BatchFeature:
1262
+ """
1263
+ Preprocess an image or a batch of images so that it can be used by the model.
1264
+
1265
+ Args:
1266
+ images (`ImageInput`):
1267
+ Image or batch of images to preprocess. Expects a single or batch of images with pixel values ranging
1268
+ from 0 to 255. If passing in images with pixel values between 0 and 1, set `do_rescale=False`.
1269
+ annotations (`AnnotationType` or `list[AnnotationType]`, *optional*):
1270
+ List of annotations associated with the image or batch of images. If annotation is for object
1271
+ detection, the annotations should be a dictionary with the following keys:
1272
+ - "image_id" (`int`): The image id.
1273
+ - "annotations" (`list[Dict]`): List of annotations for an image. Each annotation should be a
1274
+ dictionary. An image can have no annotations, in which case the list should be empty.
1275
+ If annotation is for segmentation, the annotations should be a dictionary with the following keys:
1276
+ - "image_id" (`int`): The image id.
1277
+ - "segments_info" (`list[Dict]`): List of segments for an image. Each segment should be a dictionary.
1278
+ An image can have no segments, in which case the list should be empty.
1279
+ - "file_name" (`str`): The file name of the image.
1280
+ return_segmentation_masks (`bool`, *optional*, defaults to self.return_segmentation_masks):
1281
+ Whether to return segmentation masks.
1282
+ masks_path (`str` or `pathlib.Path`, *optional*):
1283
+ Path to the directory containing the segmentation masks.
1284
+ do_resize (`bool`, *optional*, defaults to self.do_resize):
1285
+ Whether to resize the image.
1286
+ size (`dict[str, int]`, *optional*, defaults to self.size):
1287
+ Size of the image's `(height, width)` dimensions after resizing. Available options are:
1288
+ - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
1289
+ Do NOT keep the aspect ratio.
1290
+ - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
1291
+ the aspect ratio and keeping the shortest edge less than or equal to `shortest_edge` and the longest edge
1292
+ less than or equal to `longest_edge`.
1293
+ - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
1294
+ aspect ratio and keeping the height less than or equal to `max_height` and the width less than or equal to
1295
+ `max_width`.
1296
+ resample (`PILImageResampling`, *optional*, defaults to self.resample):
1297
+ Resampling filter to use when resizing the image.
1298
+ do_rescale (`bool`, *optional*, defaults to self.do_rescale):
1299
+ Whether to rescale the image.
1300
+ rescale_factor (`float`, *optional*, defaults to self.rescale_factor):
1301
+ Rescale factor to use when rescaling the image.
1302
+ do_normalize (`bool`, *optional*, defaults to self.do_normalize):
1303
+ Whether to normalize the image.
1304
+ do_convert_annotations (`bool`, *optional*, defaults to self.do_convert_annotations):
1305
+ Whether to convert the annotations to the format expected by the model. Converts the bounding
1306
+ boxes from the format `(top_left_x, top_left_y, width, height)` to `(center_x, center_y, width, height)`
1307
+ and to relative coordinates.
1308
+ image_mean (`float` or `list[float]`, *optional*, defaults to self.image_mean):
1309
+ Mean to use when normalizing the image.
1310
+ image_std (`float` or `list[float]`, *optional*, defaults to self.image_std):
1311
+ Standard deviation to use when normalizing the image.
1312
+ do_pad (`bool`, *optional*, defaults to self.do_pad):
1313
+ Whether to pad the image. If `True`, padding will be applied to the bottom and right of
1314
+ the image with zeros. If `pad_size` is provided, the image will be padded to the specified
1315
+ dimensions. Otherwise, the image will be padded to the maximum height and width of the batch.
1316
+ format (`str` or `AnnotationFormat`, *optional*, defaults to self.format):
1317
+ Format of the annotations.
1318
+ return_tensors (`str` or `TensorType`, *optional*, defaults to self.return_tensors):
1319
+ Type of tensors to return. If `None`, will return the list of images.
1320
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
1321
+ The channel dimension format for the output image. Can be one of:
1322
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
1323
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
1324
+ - Unset: Use the channel dimension format of the input image.
1325
+ input_data_format (`ChannelDimension` or `str`, *optional*):
1326
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
1327
+ from the input image. Can be one of:
1328
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
1329
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
1330
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
1331
+ pad_size (`dict[str, int]`, *optional*):
1332
+ The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
1333
+ provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
1334
+ height and width in the batch.
1335
+ """
1336
+ if "pad_and_return_pixel_mask" in kwargs:
1337
+ logger.warning_once(
1338
+ "The `pad_and_return_pixel_mask` argument is deprecated and will be removed in a future version, "
1339
+ "use `do_pad` instead."
1340
+ )
1341
+ do_pad = kwargs.pop("pad_and_return_pixel_mask")
1342
+
1343
+ if "max_size" in kwargs:
1344
+ logger.warning_once(
1345
+ "The `max_size` argument is deprecated and will be removed in a future version, use"
1346
+ " `size['longest_edge']` instead."
1347
+ )
1348
+ size = kwargs.pop("max_size")
1349
+
1350
+ do_resize = self.do_resize if do_resize is None else do_resize
1351
+ size = self.size if size is None else size
1352
+ size = get_size_dict(size=size, default_to_square=False)
1353
+ resample = self.resample if resample is None else resample
1354
+ do_rescale = self.do_rescale if do_rescale is None else do_rescale
1355
+ rescale_factor = self.rescale_factor if rescale_factor is None else rescale_factor
1356
+ do_normalize = self.do_normalize if do_normalize is None else do_normalize
1357
+ image_mean = self.image_mean if image_mean is None else image_mean
1358
+ image_std = self.image_std if image_std is None else image_std
1359
+ do_convert_annotations = (
1360
+ self.do_convert_annotations if do_convert_annotations is None else do_convert_annotations
1361
+ )
1362
+ do_pad = self.do_pad if do_pad is None else do_pad
1363
+ pad_size = self.pad_size if pad_size is None else pad_size
1364
+ format = self.format if format is None else format
1365
+
1366
+ images = make_list_of_images(images)
1367
+
1368
+ if not valid_images(images):
1369
+ raise ValueError(
1370
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
1371
+ "torch.Tensor, tf.Tensor or jax.ndarray."
1372
+ )
1373
+ validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
1374
+
1375
+ # Here, the pad() method pads to the maximum of (width, height). It does not need to be validated.
1376
+ validate_preprocess_arguments(
1377
+ do_rescale=do_rescale,
1378
+ rescale_factor=rescale_factor,
1379
+ do_normalize=do_normalize,
1380
+ image_mean=image_mean,
1381
+ image_std=image_std,
1382
+ do_resize=do_resize,
1383
+ size=size,
1384
+ resample=resample,
1385
+ )
1386
+
1387
+ if annotations is not None and isinstance(annotations, dict):
1388
+ annotations = [annotations]
1389
+
1390
+ if annotations is not None and len(images) != len(annotations):
1391
+ raise ValueError(
1392
+ f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match."
1393
+ )
1394
+
1395
+ format = AnnotationFormat(format)
1396
+ if annotations is not None:
1397
+ validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations)
1398
+
1399
+ if (
1400
+ masks_path is not None
1401
+ and format == AnnotationFormat.COCO_PANOPTIC
1402
+ and not isinstance(masks_path, (pathlib.Path, str))
1403
+ ):
1404
+ raise ValueError(
1405
+ "The path to the directory containing the mask PNG files should be provided as a"
1406
+ f" `pathlib.Path` or string object, but is {type(masks_path)} instead."
1407
+ )
1408
+
1409
+ # All transformations expect numpy arrays
1410
+ images = [to_numpy_array(image) for image in images]
1411
+
1412
+ if do_rescale and is_scaled_image(images[0]):
1413
+ logger.warning_once(
1414
+ "It looks like you are trying to rescale already rescaled images. If the input"
1415
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
1416
+ )
1417
+
1418
+ if input_data_format is None:
1419
+ # We assume that all images have the same channel dimension format.
1420
+ input_data_format = infer_channel_dimension_format(images[0])
1421
+
1422
+ # prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image)
1423
+ if annotations is not None:
1424
+ prepared_images = []
1425
+ prepared_annotations = []
1426
+ for image, target in zip(images, annotations):
1427
+ target = self.prepare_annotation(
1428
+ image,
1429
+ target,
1430
+ format,
1431
+ return_segmentation_masks=return_segmentation_masks,
1432
+ masks_path=masks_path,
1433
+ input_data_format=input_data_format,
1434
+ )
1435
+ prepared_images.append(image)
1436
+ prepared_annotations.append(target)
1437
+ images = prepared_images
1438
+ annotations = prepared_annotations
1439
+ del prepared_images, prepared_annotations
1440
+
1441
+ # transformations
1442
+ if do_resize:
1443
+ if annotations is not None:
1444
+ resized_images, resized_annotations = [], []
1445
+ for image, target in zip(images, annotations):
1446
+ orig_size = get_image_size(image, input_data_format)
1447
+ resized_image = self.resize(
1448
+ image, size=size, resample=resample, input_data_format=input_data_format
1449
+ )
1450
+ resized_annotation = self.resize_annotation(
1451
+ target, orig_size, get_image_size(resized_image, input_data_format)
1452
+ )
1453
+ resized_images.append(resized_image)
1454
+ resized_annotations.append(resized_annotation)
1455
+ images = resized_images
1456
+ annotations = resized_annotations
1457
+ del resized_images, resized_annotations
1458
+ else:
1459
+ images = [
1460
+ self.resize(image, size=size, resample=resample, input_data_format=input_data_format)
1461
+ for image in images
1462
+ ]
1463
+
1464
+ if do_rescale:
1465
+ images = [self.rescale(image, rescale_factor, input_data_format=input_data_format) for image in images]
1466
+
1467
+ if do_normalize:
1468
+ images = [
1469
+ self.normalize(image, image_mean, image_std, input_data_format=input_data_format) for image in images
1470
+ ]
1471
+
1472
+ if do_convert_annotations and annotations is not None:
1473
+ annotations = [
1474
+ self.normalize_annotation(annotation, get_image_size(image, input_data_format))
1475
+ for annotation, image in zip(annotations, images)
1476
+ ]
1477
+
1478
+ if do_pad:
1479
+ # Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...}
1480
+ encoded_inputs = self.pad(
1481
+ images,
1482
+ annotations=annotations,
1483
+ return_pixel_mask=True,
1484
+ data_format=data_format,
1485
+ input_data_format=input_data_format,
1486
+ update_bboxes=do_convert_annotations,
1487
+ return_tensors=return_tensors,
1488
+ pad_size=pad_size,
1489
+ )
1490
+ else:
1491
+ images = [
1492
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
1493
+ for image in images
1494
+ ]
1495
+ encoded_inputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
1496
+ if annotations is not None:
1497
+ encoded_inputs["labels"] = [
1498
+ BatchFeature(annotation, tensor_type=return_tensors) for annotation in annotations
1499
+ ]
1500
+
1501
+ return encoded_inputs
1502
+
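+ # A typical end-to-end call of the processor (illustrative only; `image` is assumed to be a PIL image
+ # and the checkpoint name is just a commonly used public DETR checkpoint):
+ #
+ #     from transformers import DetrImageProcessor
+ #     processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
+ #     inputs = processor(images=image, return_tensors="pt")  # __call__ dispatches to `preprocess`
+ #     # `inputs` contains "pixel_values" and, because `do_pad=True` by default, "pixel_mask"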
1503
+ # POSTPROCESSING METHODS - TODO: add support for other frameworks
1504
+ # inspired by https://github.com/facebookresearch/detr/blob/master/models/detr.py#L258
1505
+ def post_process(self, outputs, target_sizes):
1506
+ """
1507
+ Converts the raw output of [`DetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
1508
+ bottom_right_x, bottom_right_y) format. Only supports PyTorch.
1509
+
1510
+ Args:
1511
+ outputs ([`DetrObjectDetectionOutput`]):
1512
+ Raw outputs of the model.
1513
+ target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
1514
+ Tensor containing the size (height, width) of each image of the batch. For evaluation, this must be the
1515
+ original image size (before any data augmentation). For visualization, this should be the image size
1516
+ after data augmentation, but before padding.
1517
+ Returns:
1518
+ `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
1519
+ in the batch as predicted by the model.
1520
+ """
1521
+ logger.warning_once(
1522
+ "`post_process` is deprecated and will be removed in v5 of Transformers, please use"
1523
+ " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.",
1524
+ )
1525
+
1526
+ out_logits, out_bbox = outputs.logits, outputs.pred_boxes
1527
+
1528
+ if len(out_logits) != len(target_sizes):
1529
+ raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
1530
+ if target_sizes.shape[1] != 2:
1531
+ raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
1532
+
1533
+ prob = nn.functional.softmax(out_logits, -1)
1534
+ scores, labels = prob[..., :-1].max(-1)
1535
+
1536
+ # convert to [x0, y0, x1, y1] format
1537
+ boxes = center_to_corners_format(out_bbox)
1538
+ # and from relative [0, 1] to absolute [0, height] coordinates
1539
+ img_h, img_w = target_sizes.unbind(1)
1540
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
1541
+ boxes = boxes * scale_fct[:, None, :]
1542
+
1543
+ results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)]
1544
+ return results
1545
+
1546
+ def post_process_segmentation(self, outputs, target_sizes, threshold=0.9, mask_threshold=0.5):
1547
+ """
1548
+ Converts the output of [`DetrForSegmentation`] into image segmentation predictions. Only supports PyTorch.
1549
+
1550
+ Args:
1551
+ outputs ([`DetrSegmentationOutput`]):
1552
+ Raw outputs of the model.
1553
+ target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `list[Tuple]` of length `batch_size`):
1554
+ Torch Tensor (or list) corresponding to the requested final size (h, w) of each prediction.
1555
+ threshold (`float`, *optional*, defaults to 0.9):
1556
+ Threshold to use to filter out queries.
1557
+ mask_threshold (`float`, *optional*, defaults to 0.5):
1558
+ Threshold to use when turning the predicted masks into binary values.
1559
+ Returns:
1560
+ `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, and masks for an image
1561
+ in the batch as predicted by the model.
1562
+ """
1563
+ logger.warning_once(
1564
+ "`post_process_segmentation` is deprecated and will be removed in v5 of Transformers, please use"
1565
+ " `post_process_semantic_segmentation`.",
1566
+ )
1567
+ out_logits, raw_masks = outputs.logits, outputs.pred_masks
1568
+ empty_label = out_logits.shape[-1] - 1
1569
+ preds = []
1570
+
1571
+ def to_tuple(tup):
1572
+ if isinstance(tup, tuple):
1573
+ return tup
1574
+ return tuple(tup.tolist())
1575
+
1576
+ for cur_logits, cur_masks, size in zip(out_logits, raw_masks, target_sizes):
1577
+ # we filter empty queries and detection below threshold
1578
+ cur_scores, cur_labels = cur_logits.softmax(-1).max(-1)
1579
+ keep = cur_labels.ne(empty_label) & (cur_scores > threshold)
1580
+ cur_scores = cur_scores[keep]
1581
+ cur_labels = cur_labels[keep]
1582
+ cur_masks = cur_masks[keep]
1583
+ cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
1584
+ cur_masks = (cur_masks.sigmoid() > mask_threshold) * 1
1585
+
1586
+ predictions = {"scores": cur_scores, "labels": cur_labels, "masks": cur_masks}
1587
+ preds.append(predictions)
1588
+ return preds
1589
+
1590
+ # inspired by https://github.com/facebookresearch/detr/blob/master/models/segmentation.py#L218
1591
+ def post_process_instance(self, results, outputs, orig_target_sizes, max_target_sizes, threshold=0.5):
1592
+ """
1593
+ Converts the output of [`DetrForSegmentation`] into actual instance segmentation predictions. Only supports
1594
+ PyTorch.
1595
+
1596
+ Args:
1597
+ results (`list[Dict]`):
1598
+ Results list obtained by [`~DetrImageProcessor.post_process`], to which "masks" results will be added.
1599
+ outputs ([`DetrSegmentationOutput`]):
1600
+ Raw outputs of the model.
1601
+ orig_target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
1602
+ Tensor containing the size (h, w) of each image of the batch. For evaluation, this must be the original
1603
+ image size (before any data augmentation).
1604
+ max_target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
1605
+ Tensor containing the maximum size (h, w) of each image of the batch. For evaluation, this must be the
1606
+ original image size (before any data augmentation).
1607
+ threshold (`float`, *optional*, defaults to 0.5):
1608
+ Threshold to use when turning the predicted masks into binary values.
1609
+ Returns:
1610
+ `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, boxes and masks for an
1611
+ image in the batch as predicted by the model.
1612
+ """
1613
+ logger.warning_once(
1614
+ "`post_process_instance` is deprecated and will be removed in v5 of Transformers, please use"
1615
+ " `post_process_instance_segmentation`.",
1616
+ )
1617
+
1618
+ if len(orig_target_sizes) != len(max_target_sizes):
1619
+ raise ValueError("Make sure to pass in as many orig_target_sizes as max_target_sizes")
1620
+ max_h, max_w = max_target_sizes.max(0)[0].tolist()
1621
+ outputs_masks = outputs.pred_masks.squeeze(2)
1622
+ outputs_masks = nn.functional.interpolate(
1623
+ outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False
1624
+ )
1625
+ outputs_masks = (outputs_masks.sigmoid() > threshold).cpu()
1626
+
1627
+ for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)):
1628
+ img_h, img_w = t[0], t[1]
1629
+ results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1)
1630
+ results[i]["masks"] = nn.functional.interpolate(
1631
+ results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest"
1632
+ ).byte()
1633
+
1634
+ return results
1635
+
1636
+ # inspired by https://github.com/facebookresearch/detr/blob/master/models/segmentation.py#L241
1637
+ def post_process_panoptic(self, outputs, processed_sizes, target_sizes=None, is_thing_map=None, threshold=0.85):
1638
+ """
1639
+ Converts the output of [`DetrForSegmentation`] into actual panoptic predictions. Only supports PyTorch.
1640
+
1641
+ Args:
1642
+ outputs ([`DetrSegmentationOutput`]):
1643
+ Raw outputs of the model.
1644
+ processed_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `list[Tuple]` of length `batch_size`):
1645
+ Torch Tensor (or list) containing the size (h, w) of each image of the batch, i.e. the size after data
1646
+ augmentation but before batching.
1647
+ target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `list[Tuple]` of length `batch_size`, *optional*):
1648
+ Torch Tensor (or list) corresponding to the requested final size `(height, width)` of each prediction.
1649
+ If left to None, it will default to the `processed_sizes`.
1650
+ is_thing_map (`torch.Tensor` of shape `(batch_size, 2)`, *optional*):
1651
+ Dictionary mapping class indices to either True or False, depending on whether or not they are a thing.
1652
+ If not set, defaults to the `is_thing_map` of COCO panoptic.
1653
+ threshold (`float`, *optional*, defaults to 0.85):
1654
+ Threshold to use to filter out queries.
1655
+ Returns:
1656
+ `list[Dict]`: A list of dictionaries, each dictionary containing a PNG string and segments_info values for
1657
+ an image in the batch as predicted by the model.
1658
+ """
1659
+ logger.warning_once(
1660
+ "`post_process_panoptic is deprecated and will be removed in v5 of Transformers, please use"
1661
+ " `post_process_panoptic_segmentation`.",
1662
+ )
1663
+ if target_sizes is None:
1664
+ target_sizes = processed_sizes
1665
+ if len(processed_sizes) != len(target_sizes):
1666
+ raise ValueError("Make sure to pass in as many processed_sizes as target_sizes")
1667
+
1668
+ if is_thing_map is None:
1669
+ # default to is_thing_map of COCO panoptic
1670
+ is_thing_map = {i: i <= 90 for i in range(201)}
1671
+
1672
+ out_logits, raw_masks, raw_boxes = outputs.logits, outputs.pred_masks, outputs.pred_boxes
1673
+ if not len(out_logits) == len(raw_masks) == len(target_sizes):
1674
+ raise ValueError(
1675
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits and masks"
1676
+ )
1677
+ empty_label = out_logits.shape[-1] - 1
1678
+ preds = []
1679
+
1680
+ def to_tuple(tup):
1681
+ if isinstance(tup, tuple):
1682
+ return tup
1683
+ return tuple(tup.tolist())
1684
+
1685
+ for cur_logits, cur_masks, cur_boxes, size, target_size in zip(
1686
+ out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes
1687
+ ):
1688
+ # we filter empty queries and detection below threshold
1689
+ cur_scores, cur_labels = cur_logits.softmax(-1).max(-1)
1690
+ keep = cur_labels.ne(empty_label) & (cur_scores > threshold)
1691
+ cur_scores = cur_scores[keep]
1692
+ cur_labels = cur_labels[keep]
1693
+ cur_masks = cur_masks[keep]
1694
+ cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
1695
+ cur_boxes = center_to_corners_format(cur_boxes[keep])
1696
+
1697
+ h, w = cur_masks.shape[-2:]
1698
+ if len(cur_boxes) != len(cur_labels):
1699
+ raise ValueError("Not as many boxes as there are classes")
1700
+
1701
+ # It may be that we have several predicted masks for the same stuff class.
1702
+ # In the following, we track the list of mask ids for each stuff class (they are merged later on)
1703
+ cur_masks = cur_masks.flatten(1)
1704
+ stuff_equiv_classes = defaultdict(lambda: [])
1705
+ for k, label in enumerate(cur_labels):
1706
+ if not is_thing_map[label.item()]:
1707
+ stuff_equiv_classes[label.item()].append(k)
1708
+
1709
+ def get_ids_area(masks, scores, dedup=False):
1710
+ # This helper function creates the final panoptic segmentation image
1711
+ # It also returns the area of the masks that appear on the image
1712
+
1713
+ m_id = masks.transpose(0, 1).softmax(-1)
1714
+
1715
+ if m_id.shape[-1] == 0:
1716
+ # We didn't detect any mask :(
1717
+ m_id = torch.zeros((h, w), dtype=torch.long, device=m_id.device)
1718
+ else:
1719
+ m_id = m_id.argmax(-1).view(h, w)
1720
+
1721
+ if dedup:
1722
+ # Merge the masks corresponding to the same stuff class
1723
+ for equiv in stuff_equiv_classes.values():
1724
+ if len(equiv) > 1:
1725
+ for eq_id in equiv:
1726
+ m_id.masked_fill_(m_id.eq(eq_id), equiv[0])
1727
+
1728
+ final_h, final_w = to_tuple(target_size)
1729
+
1730
+ seg_img = PIL.Image.fromarray(id_to_rgb(m_id.view(h, w).cpu().numpy()))
1731
+ seg_img = seg_img.resize(size=(final_w, final_h), resample=PILImageResampling.NEAREST)
1732
+
1733
+ np_seg_img = torch.ByteTensor(torch.ByteStorage.from_buffer(seg_img.tobytes()))
1734
+ np_seg_img = np_seg_img.view(final_h, final_w, 3)
1735
+ np_seg_img = np_seg_img.numpy()
1736
+
1737
+ m_id = torch.from_numpy(rgb_to_id(np_seg_img))
1738
+
1739
+ area = []
1740
+ for i in range(len(scores)):
1741
+ area.append(m_id.eq(i).sum().item())
1742
+ return area, seg_img
1743
+
1744
+ area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True)
1745
+ if cur_labels.numel() > 0:
1746
+ # We now filter empty masks as long as we find some
1747
+ while True:
1748
+ filtered_small = torch.as_tensor(
1749
+ [area[i] <= 4 for i, c in enumerate(cur_labels)], dtype=torch.bool, device=keep.device
1750
+ )
1751
+ if filtered_small.any().item():
1752
+ cur_scores = cur_scores[~filtered_small]
1753
+ cur_labels = cur_labels[~filtered_small]
1754
+ cur_masks = cur_masks[~filtered_small]
1755
+ area, seg_img = get_ids_area(cur_masks, cur_scores)
1756
+ else:
1757
+ break
1758
+
1759
+ else:
1760
+ cur_labels = torch.ones(1, dtype=torch.long, device=cur_labels.device)
1761
+
1762
+ segments_info = []
1763
+ for i, a in enumerate(area):
1764
+ cat = cur_labels[i].item()
1765
+ segments_info.append({"id": i, "isthing": is_thing_map[cat], "category_id": cat, "area": a})
1766
+ del cur_labels
1767
+
1768
+ with io.BytesIO() as out:
1769
+ seg_img.save(out, format="PNG")
1770
+ predictions = {"png_string": out.getvalue(), "segments_info": segments_info}
1771
+ preds.append(predictions)
1772
+ return preds
1773
+
1774
+ # inspired by https://github.com/facebookresearch/detr/blob/master/models/detr.py#L258
1775
+ def post_process_object_detection(
1776
+ self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, list[tuple]] = None
1777
+ ):
1778
+ """
1779
+ Converts the raw output of [`DetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
1780
+ bottom_right_x, bottom_right_y) format. Only supports PyTorch.
1781
+
1782
+ Args:
1783
+ outputs ([`DetrObjectDetectionOutput`]):
1784
+ Raw outputs of the model.
1785
+ threshold (`float`, *optional*, defaults to 0.5):
1786
+ Score threshold to keep object detection predictions.
1787
+ target_sizes (`torch.Tensor` or `list[tuple[int, int]]`, *optional*):
1788
+ Tensor of shape `(batch_size, 2)` or list of tuples (`tuple[int, int]`) containing the target size
1789
+ `(height, width)` of each image in the batch. If unset, predictions will not be resized.
1790
+ Returns:
1791
+ `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
1792
+ in the batch as predicted by the model.
1793
+ """
1794
+ out_logits, out_bbox = outputs.logits, outputs.pred_boxes
1795
+
1796
+ if target_sizes is not None:
1797
+ if len(out_logits) != len(target_sizes):
1798
+ raise ValueError(
1799
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits"
1800
+ )
1801
+
1802
+ prob = nn.functional.softmax(out_logits, -1)
1803
+ scores, labels = prob[..., :-1].max(-1)
1804
+
1805
+ # Convert to [x0, y0, x1, y1] format
1806
+ boxes = center_to_corners_format(out_bbox)
1807
+
1808
+ # Convert from relative [0, 1] to absolute [0, height] coordinates
1809
+ if target_sizes is not None:
1810
+ if isinstance(target_sizes, list):
1811
+ img_h = torch.Tensor([i[0] for i in target_sizes])
1812
+ img_w = torch.Tensor([i[1] for i in target_sizes])
1813
+ else:
1814
+ img_h, img_w = target_sizes.unbind(1)
1815
+
1816
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
1817
+ boxes = boxes * scale_fct[:, None, :]
1818
+
1819
+ results = []
1820
+ for s, l, b in zip(scores, labels, boxes):
1821
+ score = s[s > threshold]
1822
+ label = l[s > threshold]
1823
+ box = b[s > threshold]
1824
+ results.append({"scores": score, "labels": label, "boxes": box})
1825
+
1826
+ return results
1827
+
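+ # A short sketch of the usual detection post-processing flow (illustrative only; `model` is assumed
+ # to be a `DetrForObjectDetection` loaded from the same checkpoint as `processor`):
+ #
+ #     outputs = model(**inputs)
+ #     target_sizes = torch.tensor([image.size[::-1]])  # PIL size is (width, height) -> (height, width)
+ #     results = processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[0]
+ #     for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
+ #         print(model.config.id2label[label.item()], round(score.item(), 3), [round(c, 1) for c in box.tolist()])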
1828
+ def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple[int, int]]] = None):
1829
+ """
1830
+ Converts the output of [`DetrForSegmentation`] into semantic segmentation maps. Only supports PyTorch.
1831
+
1832
+ Args:
1833
+ outputs ([`DetrForSegmentation`]):
1834
+ Raw outputs of the model.
1835
+ target_sizes (`list[tuple[int, int]]`, *optional*):
1836
+ A list of tuples (`tuple[int, int]`) containing the target size (height, width) of each image in the
1837
+ batch. If unset, predictions will not be resized.
1838
+ Returns:
1839
+ `list[torch.Tensor]`:
1840
+ A list of length `batch_size`, where each item is a semantic segmentation map of shape (height, width)
1841
+ corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each
1842
+ `torch.Tensor` corresponds to a semantic class id.
1843
+ """
1844
+ class_queries_logits = outputs.logits # [batch_size, num_queries, num_classes+1]
1845
+ masks_queries_logits = outputs.pred_masks # [batch_size, num_queries, height, width]
1846
+
1847
+ # Remove the null class `[..., :-1]`
1848
+ masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1]
1849
+ masks_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width]
1850
+
1851
+ # Semantic segmentation logits of shape (batch_size, num_classes, height, width)
1852
+ segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
1853
+ batch_size = class_queries_logits.shape[0]
1854
+
1855
+ # Resize logits and compute semantic segmentation maps
1856
+ if target_sizes is not None:
1857
+ if batch_size != len(target_sizes):
1858
+ raise ValueError(
1859
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits"
1860
+ )
1861
+
1862
+ semantic_segmentation = []
1863
+ for idx in range(batch_size):
1864
+ resized_logits = nn.functional.interpolate(
1865
+ segmentation[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
1866
+ )
1867
+ semantic_map = resized_logits[0].argmax(dim=0)
1868
+ semantic_segmentation.append(semantic_map)
1869
+ else:
1870
+ semantic_segmentation = segmentation.argmax(dim=1)
1871
+ semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
1872
+
1873
+ return semantic_segmentation
1874
+
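+ # Sketch of semantic-map extraction (illustrative only; `seg_outputs` is assumed to come from a
+ # `DetrForSegmentation` forward pass on the same preprocessed inputs):
+ #
+ #     maps = processor.post_process_semantic_segmentation(seg_outputs, target_sizes=[(image.height, image.width)])
+ #     maps[0].shape  # (image.height, image.width); each entry is a semantic class id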
1875
+ # inspired by https://github.com/facebookresearch/detr/blob/master/models/segmentation.py#L218
1876
+ def post_process_instance_segmentation(
1877
+ self,
1878
+ outputs,
1879
+ threshold: float = 0.5,
1880
+ mask_threshold: float = 0.5,
1881
+ overlap_mask_area_threshold: float = 0.8,
1882
+ target_sizes: Optional[list[tuple[int, int]]] = None,
1883
+ return_coco_annotation: Optional[bool] = False,
1884
+ ) -> list[dict]:
1885
+ """
1886
+ Converts the output of [`DetrForSegmentation`] into instance segmentation predictions. Only supports PyTorch.
1887
+
1888
+ Args:
1889
+ outputs ([`DetrForSegmentation`]):
1890
+ Raw outputs of the model.
1891
+ threshold (`float`, *optional*, defaults to 0.5):
1892
+ The probability score threshold to keep predicted instance masks.
1893
+ mask_threshold (`float`, *optional*, defaults to 0.5):
1894
+ Threshold to use when turning the predicted masks into binary values.
1895
+ overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
1896
+ The overlap mask area threshold to merge or discard small disconnected parts within each binary
1897
+ instance mask.
1898
+ target_sizes (`list[Tuple]`, *optional*):
1899
+ List of length (batch_size), where each list item (`tuple[int, int]`) corresponds to the requested
1900
+ final size (height, width) of each prediction. If unset, predictions will not be resized.
1901
+ return_coco_annotation (`bool`, *optional*):
1902
+ Defaults to `False`. If set to `True`, segmentation maps are returned in COCO run-length encoding (RLE)
1903
+ format.
1904
+ Returns:
1905
+ `list[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
1906
+ - **segmentation** -- A tensor of shape `(height, width)` where each pixel represents a `segment_id` or
1907
+ `list[List]` run-length encoding (RLE) of the segmentation map if return_coco_annotation is set to
1908
+ `True`. Set to `None` if no mask is found above `threshold`.
1909
+ - **segments_info** -- A dictionary that contains additional information on each segment.
1910
+ - **id** -- An integer representing the `segment_id`.
1911
+ - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
1912
+ - **score** -- Prediction score of segment with `segment_id`.
1913
+ """
1914
+ class_queries_logits = outputs.logits # [batch_size, num_queries, num_classes+1]
1915
+ masks_queries_logits = outputs.pred_masks # [batch_size, num_queries, height, width]
1916
+
1917
+ batch_size = class_queries_logits.shape[0]
1918
+ num_labels = class_queries_logits.shape[-1] - 1
1919
+
1920
+ mask_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width]
1921
+
1922
+ # Predicted label and score of each query (batch_size, num_queries)
1923
+ pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1)
1924
+
1925
+ # Loop over items in batch size
1926
+ results: list[dict[str, TensorType]] = []
1927
+
1928
+ for i in range(batch_size):
1929
+ mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects(
1930
+ mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels
1931
+ )
1932
+
1933
+ # No mask found
1934
+ if mask_probs_item.shape[0] <= 0:
1935
+ height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:]
1936
+ segmentation = torch.zeros((height, width)) - 1
1937
+ results.append({"segmentation": segmentation, "segments_info": []})
1938
+ continue
1939
+
1940
+ # Get segmentation map and segment information of batch item
1941
+ target_size = target_sizes[i] if target_sizes is not None else None
1942
+ segmentation, segments = compute_segments(
1943
+ mask_probs=mask_probs_item,
1944
+ pred_scores=pred_scores_item,
1945
+ pred_labels=pred_labels_item,
1946
+ mask_threshold=mask_threshold,
1947
+ overlap_mask_area_threshold=overlap_mask_area_threshold,
1948
+ label_ids_to_fuse=[],
1949
+ target_size=target_size,
1950
+ )
1951
+
1952
+ # Return segmentation map in run-length encoding (RLE) format
1953
+ if return_coco_annotation:
1954
+ segmentation = convert_segmentation_to_rle(segmentation)
1955
+
1956
+ results.append({"segmentation": segmentation, "segments_info": segments})
1957
+ return results
1958
+
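+ # Sketch of instance-level post-processing (illustrative only; same `seg_outputs` assumption as above):
+ #
+ #     inst = processor.post_process_instance_segmentation(seg_outputs, threshold=0.5,
+ #                                                         target_sizes=[(image.height, image.width)])[0]
+ #     inst["segmentation"]   # (height, width) tensor of segment ids, or RLE if return_coco_annotation=True
+ #     inst["segments_info"]  # list of per-segment dictionaries with "id", "label_id" and "score"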
1959
+ # inspired by https://github.com/facebookresearch/detr/blob/master/models/segmentation.py#L241
1960
+ def post_process_panoptic_segmentation(
1961
+ self,
1962
+ outputs,
1963
+ threshold: float = 0.5,
1964
+ mask_threshold: float = 0.5,
1965
+ overlap_mask_area_threshold: float = 0.8,
1966
+ label_ids_to_fuse: Optional[set[int]] = None,
1967
+ target_sizes: Optional[list[tuple[int, int]]] = None,
1968
+ ) -> list[dict]:
1969
+ """
1970
+ Converts the output of [`DetrForSegmentation`] into image panoptic segmentation predictions. Only supports
1971
+ PyTorch.
1972
+
1973
+ Args:
1974
+ outputs ([`DetrForSegmentation`]):
1975
+ The outputs from [`DetrForSegmentation`].
1976
+ threshold (`float`, *optional*, defaults to 0.5):
1977
+ The probability score threshold to keep predicted instance masks.
1978
+ mask_threshold (`float`, *optional*, defaults to 0.5):
1979
+ Threshold to use when turning the predicted masks into binary values.
1980
+ overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
1981
+ The overlap mask area threshold to merge or discard small disconnected parts within each binary
1982
+ instance mask.
1983
+ label_ids_to_fuse (`Set[int]`, *optional*):
1984
+ The labels in this set will have all their instances fused together. For instance, we could say
1985
+ there can only be one sky in an image, but several persons, so the label ID for sky would be in that
1986
+ set, but not the one for person.
1987
+ target_sizes (`list[Tuple]`, *optional*):
1988
+ List of length (batch_size), where each list item (`tuple[int, int]`) corresponds to the requested
1989
+ final size (height, width) of each prediction in batch. If unset, predictions will not be resized.
1990
+ Returns:
1991
+ `list[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
1992
+ - **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id` or
1993
+ `None` if no mask is found above `threshold`. If `target_sizes` is specified, segmentation is resized to
1994
+ the corresponding `target_sizes` entry.
1995
+ - **segments_info** -- A dictionary that contains additional information on each segment.
1996
+ - **id** -- an integer representing the `segment_id`.
1997
+ - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
1998
+ - **was_fused** -- a boolean, `True` if `label_id` was in `label_ids_to_fuse`, `False` otherwise.
1999
+ Multiple instances of the same class / label were fused and assigned a single `segment_id`.
2000
+ - **score** -- Prediction score of segment with `segment_id`.
2001
+ """
2002
+
2003
+ if label_ids_to_fuse is None:
2004
+ logger.warning_once("`label_ids_to_fuse` unset. No instance will be fused.")
2005
+ label_ids_to_fuse = set()
2006
+
2007
+ class_queries_logits = outputs.logits # [batch_size, num_queries, num_classes+1]
2008
+ masks_queries_logits = outputs.pred_masks # [batch_size, num_queries, height, width]
2009
+
2010
+ batch_size = class_queries_logits.shape[0]
2011
+ num_labels = class_queries_logits.shape[-1] - 1
2012
+
2013
+ mask_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width]
2014
+
2015
+ # Predicted label and score of each query (batch_size, num_queries)
2016
+ pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1)
2017
+
2018
+ # Loop over items in batch size
2019
+ results: list[dict[str, TensorType]] = []
2020
+
2021
+ for i in range(batch_size):
2022
+ mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects(
2023
+ mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels
2024
+ )
2025
+
2026
+ # No mask found
2027
+ if mask_probs_item.shape[0] <= 0:
2028
+ height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:]
2029
+ segmentation = torch.zeros((height, width)) - 1
2030
+ results.append({"segmentation": segmentation, "segments_info": []})
2031
+ continue
2032
+
2033
+ # Get segmentation map and segment information of batch item
2034
+ target_size = target_sizes[i] if target_sizes is not None else None
2035
+ segmentation, segments = compute_segments(
2036
+ mask_probs=mask_probs_item,
2037
+ pred_scores=pred_scores_item,
2038
+ pred_labels=pred_labels_item,
2039
+ mask_threshold=mask_threshold,
2040
+ overlap_mask_area_threshold=overlap_mask_area_threshold,
2041
+ label_ids_to_fuse=label_ids_to_fuse,
2042
+ target_size=target_size,
2043
+ )
2044
+
2045
+ results.append({"segmentation": segmentation, "segments_info": segments})
2046
+ return results
2047
+
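+ # Sketch of panoptic post-processing (illustrative only; the fused label id below is an arbitrary example):
+ #
+ #     pan = processor.post_process_panoptic_segmentation(seg_outputs, label_ids_to_fuse={0},
+ #                                                        target_sizes=[(image.height, image.width)])[0]
+ #     pan["segmentation"]   # (height, width) tensor of segment ids
+ #     pan["segments_info"]  # per-segment metadata, including "was_fused" for labels in `label_ids_to_fuse`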
2048
+
2049
+ __all__ = ["DetrImageProcessor"]
phivenv/Lib/site-packages/transformers/models/detr/image_processing_detr_fast.py ADDED
@@ -0,0 +1,1291 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Fast Image processor class for DETR."""
16
+
17
+ import io
18
+ import pathlib
19
+ from collections import defaultdict
20
+ from typing import Any, Optional, Union
21
+
22
+ from ...image_processing_utils import BatchFeature, get_size_dict
23
+ from ...image_processing_utils_fast import (
24
+ BaseImageProcessorFast,
25
+ DefaultFastImageProcessorKwargs,
26
+ SizeDict,
27
+ get_image_size_for_max_height_width,
28
+ get_max_height_width,
29
+ safe_squeeze,
30
+ )
31
+ from ...image_transforms import center_to_corners_format, corners_to_center_format, id_to_rgb
32
+ from ...image_utils import (
33
+ IMAGENET_DEFAULT_MEAN,
34
+ IMAGENET_DEFAULT_STD,
35
+ AnnotationFormat,
36
+ AnnotationType,
37
+ ChannelDimension,
38
+ ImageInput,
39
+ PILImageResampling,
40
+ get_image_size,
41
+ validate_annotations,
42
+ )
43
+ from ...processing_utils import Unpack
44
+ from ...utils import (
45
+ TensorType,
46
+ auto_docstring,
47
+ is_torch_available,
48
+ is_torchvision_available,
49
+ is_torchvision_v2_available,
50
+ is_vision_available,
51
+ logging,
52
+ )
53
+ from ...utils.import_utils import requires
54
+ from .image_processing_detr import (
55
+ compute_segments,
56
+ convert_segmentation_to_rle,
57
+ get_size_with_aspect_ratio,
58
+ remove_low_and_no_objects,
59
+ )
60
+
61
+
62
+ if is_torch_available():
63
+ import torch
64
+ from torch import nn
65
+
66
+ if is_vision_available():
67
+ import PIL
68
+
69
+
70
+ if is_torchvision_v2_available():
71
+ from torchvision.io import read_image
72
+ from torchvision.transforms.v2 import functional as F
73
+
74
+ elif is_torchvision_available():
75
+ from torchvision.io import read_image
76
+ from torchvision.transforms import functional as F
77
+
78
+
79
+ logger = logging.get_logger(__name__)
80
+
81
+ SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION, AnnotationFormat.COCO_PANOPTIC)
82
+
83
+
84
+ # inspired by https://github.com/facebookresearch/detr/blob/master/datasets/coco.py#L33
85
+ def convert_coco_poly_to_mask(segmentations, height: int, width: int, device: torch.device) -> torch.Tensor:
86
+ """
87
+ Convert a COCO polygon annotation to a mask.
88
+
89
+ Args:
90
+ segmentations (`list[list[float]]`):
91
+ List of polygons, each polygon represented by a list of x-y coordinates.
92
+ height (`int`):
93
+ Height of the mask.
94
+ width (`int`):
95
+ Width of the mask.
96
+ """
97
+ try:
98
+ from pycocotools import mask as coco_mask
99
+ except ImportError:
100
+ raise ImportError("Pycocotools is not installed in your environment.")
101
+
102
+ masks = []
103
+ for polygons in segmentations:
104
+ rles = coco_mask.frPyObjects(polygons, height, width)
105
+ mask = coco_mask.decode(rles)
106
+ if len(mask.shape) < 3:
107
+ mask = mask[..., None]
108
+ mask = torch.as_tensor(mask, dtype=torch.uint8, device=device)
109
+ mask = torch.any(mask, axis=2)
110
+ masks.append(mask)
111
+ if masks:
112
+ masks = torch.stack(masks, axis=0)
113
+ else:
114
+ masks = torch.zeros((0, height, width), dtype=torch.uint8, device=device)
115
+
116
+ return masks
117
+
118
+
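A rough sketch of calling `convert_coco_poly_to_mask` (assumes the helper is imported from `transformers.models.detr.image_processing_detr_fast`, that `pycocotools` is installed, and that the triangle coordinates are made up for illustration):

import torch

# One object described by a single triangular polygon [x1, y1, x2, y2, x3, y3]
segmentations = [[[0.0, 0.0, 40.0, 0.0, 40.0, 30.0]]]
masks = convert_coco_poly_to_mask(segmentations, height=50, width=50, device=torch.device("cpu"))
# masks has shape (1, 50, 50); entries inside the filled triangle are nonzero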
119
+ # inspired by https://github.com/facebookresearch/detr/blob/master/datasets/coco.py#L50
120
+ def prepare_coco_detection_annotation(
121
+ image,
122
+ target,
123
+ return_segmentation_masks: bool = False,
124
+ input_data_format: Optional[Union[ChannelDimension, str]] = None,
125
+ ):
126
+ """
127
+ Convert the target in COCO format into the format expected by DETR.
128
+ """
129
+ image_height, image_width = image.size()[-2:]
130
+
131
+ image_id = target["image_id"]
132
+ image_id = torch.as_tensor([image_id], dtype=torch.int64, device=image.device)
133
+
134
+ # Get all COCO annotations for the given image.
135
+ annotations = target["annotations"]
136
+ classes = []
137
+ area = []
138
+ boxes = []
139
+ keypoints = []
140
+ for obj in annotations:
141
+ if "iscrowd" not in obj or obj["iscrowd"] == 0:
142
+ classes.append(obj["category_id"])
143
+ area.append(obj["area"])
144
+ boxes.append(obj["bbox"])
145
+ if "keypoints" in obj:
146
+ keypoints.append(obj["keypoints"])
147
+
148
+ classes = torch.as_tensor(classes, dtype=torch.int64, device=image.device)
149
+ area = torch.as_tensor(area, dtype=torch.float32, device=image.device)
150
+ iscrowd = torch.zeros_like(classes, dtype=torch.int64, device=image.device)
151
+ # guard against no boxes via resizing
152
+ boxes = torch.as_tensor(boxes, dtype=torch.float32, device=image.device).reshape(-1, 4)
153
+ boxes[:, 2:] += boxes[:, :2]
154
+ boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width)
155
+ boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height)
156
+
157
+ keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
158
+
159
+ new_target = {
160
+ "image_id": image_id,
161
+ "class_labels": classes[keep],
162
+ "boxes": boxes[keep],
163
+ "area": area[keep],
164
+ "iscrowd": iscrowd[keep],
165
+ "orig_size": torch.as_tensor([int(image_height), int(image_width)], dtype=torch.int64, device=image.device),
166
+ }
167
+
168
+ if keypoints:
169
+ keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=image.device)
170
+ # Apply the keep mask here to filter the relevant annotations
171
+ keypoints = keypoints[keep]
172
+ num_keypoints = keypoints.shape[0]
173
+ keypoints = keypoints.reshape((-1, 3)) if num_keypoints else keypoints
174
+ new_target["keypoints"] = keypoints
175
+
176
+ if return_segmentation_masks:
177
+ segmentation_masks = [obj["segmentation"] for obj in annotations]
178
+ masks = convert_coco_poly_to_mask(segmentation_masks, image_height, image_width, device=image.device)
179
+ new_target["masks"] = masks[keep]
180
+
181
+ return new_target
182
+
183
+
184
+ def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
185
+ """
186
+ Compute the bounding boxes around the provided panoptic segmentation masks.
187
+
188
+ Args:
189
+ masks: masks in format `[number_masks, height, width]`, where `number_masks` is the number of masks
190
+
191
+ Returns:
192
+ boxes: bounding boxes in format `[number_masks, 4]` in xyxy format
193
+ """
194
+ if masks.numel() == 0:
195
+ return torch.zeros((0, 4), device=masks.device)
196
+
197
+ h, w = masks.shape[-2:]
198
+ y = torch.arange(0, h, dtype=torch.float32, device=masks.device)
199
+ x = torch.arange(0, w, dtype=torch.float32, device=masks.device)
200
+ # see https://github.com/pytorch/pytorch/issues/50276
201
+ y, x = torch.meshgrid(y, x, indexing="ij")
202
+
203
+ x_mask = masks * torch.unsqueeze(x, 0)
204
+ x_max = x_mask.view(x_mask.shape[0], -1).max(-1)[0]
205
+ x_min = (
206
+ torch.where(masks, x.unsqueeze(0), torch.tensor(1e8, device=masks.device)).view(masks.shape[0], -1).min(-1)[0]
207
+ )
208
+
209
+ y_mask = masks * torch.unsqueeze(y, 0)
210
+ y_max = y_mask.view(y_mask.shape[0], -1).max(-1)[0]
211
+ y_min = (
212
+ torch.where(masks, y.unsqueeze(0), torch.tensor(1e8, device=masks.device)).view(masks.shape[0], -1).min(-1)[0]
213
+ )
214
+
215
+ return torch.stack([x_min, y_min, x_max, y_max], 1)
216
+
217
+
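A small self-contained check of `masks_to_boxes` (assuming the helper is in scope, e.g. imported from `transformers.models.detr.image_processing_detr_fast`); the expected box is worked out by hand:

import torch

mask = torch.zeros(1, 8, 8, dtype=torch.bool)
mask[0, 2:5, 3:7] = True      # rows 2-4, columns 3-6
masks_to_boxes(mask)          # tensor([[3., 2., 6., 4.]]) -> (x_min, y_min, x_max, y_max)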
218
+ # 2 functions below adapted from https://github.com/cocodataset/panopticapi/blob/master/panopticapi/utils.py
219
+ # Copyright (c) 2018, Alexander Kirillov
220
+ # All rights reserved.
221
+ def rgb_to_id(color):
222
+ """
223
+ Converts RGB color to unique ID.
224
+ """
225
+ if isinstance(color, torch.Tensor) and len(color.shape) == 3:
226
+ if color.dtype == torch.uint8:
227
+ color = color.to(torch.int32)
228
+ return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2]
229
+ return int(color[0] + 256 * color[1] + 256 * 256 * color[2])
230
+
231
+
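The encoding is the standard panopticapi one, id = R + 256 * G + 256^2 * B, with `id_to_rgb` (imported above) as its inverse. For example:

rgb_to_id([44, 1, 0])   # 44 + 256 * 1 + 65536 * 0 == 300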
232
+ def prepare_coco_panoptic_annotation(
233
+ image: torch.Tensor,
234
+ target: dict,
235
+ masks_path: Union[str, pathlib.Path],
236
+ return_masks: bool = True,
237
+ input_data_format: Union[ChannelDimension, str] = None,
238
+ ) -> dict:
239
+ """
240
+ Prepare a coco panoptic annotation for DETR.
241
+ """
242
+ image_height, image_width = get_image_size(image, channel_dim=input_data_format)
243
+ annotation_path = pathlib.Path(masks_path) / target["file_name"]
244
+
245
+ new_target = {}
246
+ new_target["image_id"] = torch.as_tensor(
247
+ [target["image_id"] if "image_id" in target else target["id"]], dtype=torch.int64, device=image.device
248
+ )
249
+ new_target["size"] = torch.as_tensor([image_height, image_width], dtype=torch.int64, device=image.device)
250
+ new_target["orig_size"] = torch.as_tensor([image_height, image_width], dtype=torch.int64, device=image.device)
251
+
252
+ if "segments_info" in target:
253
+ masks = read_image(annotation_path).permute(1, 2, 0).to(dtype=torch.int32, device=image.device)
254
+ masks = rgb_to_id(masks)
255
+
256
+ ids = torch.as_tensor([segment_info["id"] for segment_info in target["segments_info"]], device=image.device)
257
+ masks = masks == ids[:, None, None]
258
+ masks = masks.to(torch.bool)
259
+ if return_masks:
260
+ new_target["masks"] = masks
261
+ new_target["boxes"] = masks_to_boxes(masks)
262
+ new_target["class_labels"] = torch.as_tensor(
263
+ [segment_info["category_id"] for segment_info in target["segments_info"]],
264
+ dtype=torch.int64,
265
+ device=image.device,
266
+ )
267
+ new_target["iscrowd"] = torch.as_tensor(
268
+ [segment_info["iscrowd"] for segment_info in target["segments_info"]],
269
+ dtype=torch.int64,
270
+ device=image.device,
271
+ )
272
+ new_target["area"] = torch.as_tensor(
273
+ [segment_info["area"] for segment_info in target["segments_info"]],
274
+ dtype=torch.float32,
275
+ device=image.device,
276
+ )
277
+
278
+ return new_target
279
+
280
+
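For reference, the `target` dict this function expects mirrors one entry of a COCO panoptic `annotations` list; the values below are invented and the PNG named in `file_name` is looked up under `masks_path`:

target = {
    "image_id": 12345,
    "file_name": "000000012345.png",  # PNG with per-pixel RGB-encoded segment ids
    "segments_info": [
        {"id": 3226956, "category_id": 1, "iscrowd": 0, "area": 4278.0, "bbox": [44, 102, 87, 60]},
    ],
}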
281
+ class DetrFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
282
+ r"""
283
+ format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
284
+ Data format of the annotations. One of "coco_detection" or "coco_panoptic".
285
+ do_convert_annotations (`bool`, *optional*, defaults to `True`):
286
+ Controls whether to convert the annotations to the format expected by the DETR model. Converts the
287
+ bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`.
288
+ Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method.
289
+ do_pad (`bool`, *optional*, defaults to `True`):
290
+ Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess`
291
+ method. If `True`, padding will be applied to the bottom and right of the image with zeros.
292
+ If `pad_size` is provided, the image will be padded to the specified dimensions.
293
+ Otherwise, the image will be padded to the maximum height and width of the batch.
294
+ pad_size (`dict[str, int]`, *optional*):
295
+ The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
296
+ provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
297
+ height and width in the batch.
298
+ return_segmentation_masks (`bool`, *optional*, defaults to `False`):
299
+ Whether to return segmentation masks.
300
+ """
301
+
302
+ format: Optional[Union[str, AnnotationFormat]]
303
+ do_convert_annotations: Optional[bool]
304
+ do_pad: Optional[bool]
305
+ pad_size: Optional[dict[str, int]]
306
+ return_segmentation_masks: Optional[bool]
307
+
308
+
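A minimal sketch of passing the kwargs documented above when constructing the fast processor defined just below (values are illustrative; torch and torchvision are required):

from transformers import DetrImageProcessorFast

processor = DetrImageProcessorFast(
    format="coco_detection",
    do_convert_annotations=True,
    do_pad=True,
    pad_size={"height": 800, "width": 1333},
    return_segmentation_masks=False,
)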
309
+ @auto_docstring
310
+ @requires(backends=("torchvision", "torch"))
311
+ class DetrImageProcessorFast(BaseImageProcessorFast):
312
+ resample = PILImageResampling.BILINEAR
313
+ image_mean = IMAGENET_DEFAULT_MEAN
314
+ image_std = IMAGENET_DEFAULT_STD
315
+ format = AnnotationFormat.COCO_DETECTION
316
+ do_resize = True
317
+ do_rescale = True
318
+ do_normalize = True
319
+ do_pad = True
320
+ size = {"shortest_edge": 800, "longest_edge": 1333}
321
+ default_to_square = False
322
+ model_input_names = ["pixel_values", "pixel_mask"]
323
+ valid_kwargs = DetrFastImageProcessorKwargs
324
+
325
+ def __init__(self, **kwargs: Unpack[DetrFastImageProcessorKwargs]) -> None:
326
+ if "pad_and_return_pixel_mask" in kwargs:
327
+ kwargs["do_pad"] = kwargs.pop("pad_and_return_pixel_mask")
328
+
329
+ size = kwargs.pop("size", None)
330
+ if "max_size" in kwargs:
331
+ logger.warning_once(
332
+ "The `max_size` parameter is deprecated and will be removed in v4.26. "
333
+ "Please specify in `size['longest_edge'] instead`.",
334
+ )
335
+ max_size = kwargs.pop("max_size")
336
+ else:
337
+ max_size = None if size is None else 1333
338
+
339
+ size = size if size is not None else {"shortest_edge": 800, "longest_edge": 1333}
340
+ self.size = get_size_dict(size, max_size=max_size, default_to_square=False)
341
+
342
+ # Backwards compatibility
343
+ do_convert_annotations = kwargs.get("do_convert_annotations")
344
+ do_normalize = kwargs.get("do_normalize")
345
+ if do_convert_annotations is None and getattr(self, "do_convert_annotations", None) is None:
346
+ self.do_convert_annotations = do_normalize if do_normalize is not None else self.do_normalize
347
+
348
+ super().__init__(**kwargs)
349
+
350
+ @classmethod
351
+ def from_dict(cls, image_processor_dict: dict[str, Any], **kwargs):
352
+ """
353
+ Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is
354
+ created using from_dict and kwargs e.g. `DetrImageProcessorFast.from_pretrained(checkpoint, size=600,
355
+ max_size=800)`
356
+ """
357
+ image_processor_dict = image_processor_dict.copy()
358
+ if "max_size" in kwargs:
359
+ image_processor_dict["max_size"] = kwargs.pop("max_size")
360
+ if "pad_and_return_pixel_mask" in kwargs:
361
+ image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask")
362
+ return super().from_dict(image_processor_dict, **kwargs)
363
+
364
+ def prepare_annotation(
365
+ self,
366
+ image: torch.Tensor,
367
+ target: dict,
368
+ format: Optional[AnnotationFormat] = None,
369
+ return_segmentation_masks: Optional[bool] = None,
370
+ masks_path: Optional[Union[str, pathlib.Path]] = None,
371
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
372
+ ) -> dict:
373
+ """
374
+ Prepare an annotation for feeding into DETR model.
375
+ """
376
+ format = format if format is not None else self.format
377
+
378
+ if format == AnnotationFormat.COCO_DETECTION:
379
+ return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks
380
+ target = prepare_coco_detection_annotation(
381
+ image, target, return_segmentation_masks, input_data_format=input_data_format
382
+ )
383
+ elif format == AnnotationFormat.COCO_PANOPTIC:
384
+ return_segmentation_masks = True if return_segmentation_masks is None else return_segmentation_masks
385
+ target = prepare_coco_panoptic_annotation(
386
+ image,
387
+ target,
388
+ masks_path=masks_path,
389
+ return_masks=return_segmentation_masks,
390
+ input_data_format=input_data_format,
391
+ )
392
+ else:
393
+ raise ValueError(f"Format {format} is not supported.")
394
+ return target
395
+
396
+ def resize(
397
+ self,
398
+ image: torch.Tensor,
399
+ size: SizeDict,
400
+ interpolation: "F.InterpolationMode" = None,
401
+ **kwargs,
402
+ ) -> torch.Tensor:
403
+ """
404
+ Resize the image to the given size. Size can be `min_size` (scalar) or `(height, width)` tuple. If size is an
405
+ int, smaller edge of the image will be matched to this number.
406
+
407
+ Args:
408
+ image (`torch.Tensor`):
409
+ Image to resize.
410
+ size (`SizeDict`):
411
+ Size of the image's `(height, width)` dimensions after resizing. Available options are:
412
+ - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
413
+ Do NOT keep the aspect ratio.
414
+ - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
415
+ the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
416
+ less or equal to `longest_edge`.
417
+ - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
418
+ aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
419
+ `max_width`.
420
+ interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
421
+ Resampling filter to use if resizing the image.
422
+ """
423
+ interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
424
+ if size.shortest_edge and size.longest_edge:
425
+ # Resize the image so that the shortest edge or the longest edge is of the given size
426
+ # while maintaining the aspect ratio of the original image.
427
+ new_size = get_size_with_aspect_ratio(
428
+ image.size()[-2:],
429
+ size["shortest_edge"],
430
+ size["longest_edge"],
431
+ )
432
+ elif size.max_height and size.max_width:
433
+ new_size = get_image_size_for_max_height_width(image.size()[-2:], size["max_height"], size["max_width"])
434
+ elif size.height and size.width:
435
+ new_size = (size["height"], size["width"])
436
+ else:
437
+ raise ValueError(
438
+ "Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got"
439
+ f" {size.keys()}."
440
+ )
441
+
442
+ image = F.resize(
443
+ image,
444
+ size=new_size,
445
+ interpolation=interpolation,
446
+ **kwargs,
447
+ )
448
+ return image
449
+
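To make the `shortest_edge` / `longest_edge` behaviour concrete, here is the expected size arithmetic for the default `{"shortest_edge": 800, "longest_edge": 1333}` applied to a 480x640 image (a back-of-the-envelope sketch; exact rounding is delegated to `get_size_with_aspect_ratio`):

height, width = 480, 640
scale = min(800 / min(height, width), 1333 / max(height, width))  # limited by shortest_edge here
new_height, new_width = round(height * scale), round(width * scale)  # roughly (800, 1067)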
450
+ def resize_annotation(
451
+ self,
452
+ annotation: dict[str, Any],
453
+ orig_size: tuple[int, int],
454
+ target_size: tuple[int, int],
455
+ threshold: float = 0.5,
456
+ interpolation: "F.InterpolationMode" = None,
457
+ ):
458
+ """
459
+ Resizes an annotation to a target size.
460
+
461
+ Args:
462
+ annotation (`dict[str, Any]`):
463
+ The annotation dictionary.
464
+ orig_size (`tuple[int, int]`):
465
+ The original size of the input image.
466
+ target_size (`tuple[int, int]`):
467
+ The target size of the image, as returned by the preprocessing `resize` step.
468
+ threshold (`float`, *optional*, defaults to 0.5):
469
+ The threshold used to binarize the segmentation masks.
470
+ interpolation (`InterpolationMode`, *optional*, defaults to `F.InterpolationMode.NEAREST_EXACT`):
471
+ The resampling filter to use when resizing the masks.
472
+ """
473
+ interpolation = (
474
+ interpolation
475
+ if interpolation is not None
476
+ else F.InterpolationMode.NEAREST_EXACT
477
+ if is_torchvision_v2_available()
478
+ else F.InterpolationMode.NEAREST
479
+ )
480
+ ratio_height, ratio_width = [target / orig for target, orig in zip(target_size, orig_size)]
481
+
482
+ new_annotation = {}
483
+ new_annotation["size"] = target_size
484
+
485
+ for key, value in annotation.items():
486
+ if key == "boxes":
487
+ boxes = value
488
+ scaled_boxes = boxes * torch.as_tensor(
489
+ [ratio_width, ratio_height, ratio_width, ratio_height], dtype=torch.float32, device=boxes.device
490
+ )
491
+ new_annotation["boxes"] = scaled_boxes
492
+ elif key == "area":
493
+ area = value
494
+ scaled_area = area * (ratio_width * ratio_height)
495
+ new_annotation["area"] = scaled_area
496
+ elif key == "masks":
497
+ masks = value[:, None]
498
+ masks = [F.resize(mask, target_size, interpolation=interpolation) for mask in masks]
499
+ masks = torch.stack(masks).to(torch.float32)
500
+ masks = masks[:, 0] > threshold
501
+ new_annotation["masks"] = masks
502
+ elif key == "size":
503
+ new_annotation["size"] = target_size
504
+ else:
505
+ new_annotation[key] = value
506
+
507
+ return new_annotation
508
+
509
+ def normalize_annotation(self, annotation: dict, image_size: tuple[int, int]) -> dict:
510
+ image_height, image_width = image_size
511
+ norm_annotation = {}
512
+ for key, value in annotation.items():
513
+ if key == "boxes":
514
+ boxes = value
515
+ boxes = corners_to_center_format(boxes)
516
+ boxes /= torch.as_tensor(
517
+ [image_width, image_height, image_width, image_height], dtype=torch.float32, device=boxes.device
518
+ )
519
+ norm_annotation[key] = boxes
520
+ else:
521
+ norm_annotation[key] = value
522
+ return norm_annotation
523
+
524
+ def _update_annotation_for_padded_image(
525
+ self,
526
+ annotation: dict,
527
+ input_image_size: tuple[int, int],
528
+ output_image_size: tuple[int, int],
529
+ padding,
530
+ update_bboxes,
531
+ ) -> dict:
532
+ """
533
+ Update the annotation for a padded image.
534
+ """
535
+ new_annotation = {}
536
+ new_annotation["size"] = output_image_size
537
+ ratio_height, ratio_width = (input / output for output, input in zip(output_image_size, input_image_size))
538
+
539
+ for key, value in annotation.items():
540
+ if key == "masks":
541
+ masks = value
542
+ masks = F.pad(
543
+ masks,
544
+ padding,
545
+ fill=0,
546
+ )
547
+ masks = safe_squeeze(masks, 1)
548
+ new_annotation["masks"] = masks
549
+ elif key == "boxes" and update_bboxes:
550
+ boxes = value
551
+ boxes *= torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height], device=boxes.device)
552
+ new_annotation["boxes"] = boxes
553
+ elif key == "size":
554
+ new_annotation["size"] = output_image_size
555
+ else:
556
+ new_annotation[key] = value
557
+ return new_annotation
558
+
559
+ def pad(
560
+ self,
561
+ image: torch.Tensor,
562
+ padded_size: tuple[int, int],
563
+ annotation: Optional[dict[str, Any]] = None,
564
+ update_bboxes: bool = True,
565
+ fill: int = 0,
566
+ ):
567
+ original_size = image.size()[-2:]
568
+ padding_bottom = padded_size[0] - original_size[0]
569
+ padding_right = padded_size[1] - original_size[1]
570
+ if padding_bottom < 0 or padding_right < 0:
571
+ raise ValueError(
572
+ f"Padding dimensions are negative. Please make sure that the padded size is larger than the "
573
+ f"original size. Got padded size: {padded_size}, original size: {original_size}."
574
+ )
575
+ if original_size != padded_size:
576
+ padding = [0, 0, padding_right, padding_bottom]
577
+ image = F.pad(image, padding, fill=fill)
578
+ if annotation is not None:
579
+ annotation = self._update_annotation_for_padded_image(
580
+ annotation, original_size, padded_size, padding, update_bboxes
581
+ )
582
+
583
+ # Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
584
+ pixel_mask = torch.zeros(padded_size, dtype=torch.int64, device=image.device)
585
+ pixel_mask[: original_size[0], : original_size[1]] = 1
586
+
587
+ return image, pixel_mask, annotation
588
+
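A minimal sketch of `pad` on a random image tensor (sizes are arbitrary; torchvision is required):

import torch
from transformers import DetrImageProcessorFast

processor = DetrImageProcessorFast()
image = torch.rand(3, 480, 640)
padded, pixel_mask, _ = processor.pad(image, padded_size=(800, 800))
# padded.shape == (3, 800, 800); pixel_mask is 1 over the original 480x640 region and 0 elsewhere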
589
+ @auto_docstring
590
+ def preprocess(
591
+ self,
592
+ images: ImageInput,
593
+ annotations: Optional[Union[AnnotationType, list[AnnotationType]]] = None,
594
+ masks_path: Optional[Union[str, pathlib.Path]] = None,
595
+ **kwargs: Unpack[DetrFastImageProcessorKwargs],
596
+ ) -> BatchFeature:
597
+ r"""
598
+ annotations (`AnnotationType` or `list[AnnotationType]`, *optional*):
599
+ List of annotations associated with the image or batch of images. If annotation is for object
600
+ detection, the annotations should be a dictionary with the following keys:
601
+ - "image_id" (`int`): The image id.
602
+ - "annotations" (`list[Dict]`): List of annotations for an image. Each annotation should be a
603
+ dictionary. An image can have no annotations, in which case the list should be empty.
604
+ If annotation is for segmentation, the annotations should be a dictionary with the following keys:
605
+ - "image_id" (`int`): The image id.
606
+ - "segments_info" (`list[Dict]`): List of segments for an image. Each segment should be a dictionary.
607
+ An image can have no segments, in which case the list should be empty.
608
+ - "file_name" (`str`): The file name of the image.
609
+ masks_path (`str` or `pathlib.Path`, *optional*):
610
+ Path to the directory containing the segmentation masks.
611
+ """
612
+ if "pad_and_return_pixel_mask" in kwargs:
613
+ kwargs["do_pad"] = kwargs.pop("pad_and_return_pixel_mask")
614
+ logger.warning_once(
615
+ "The `pad_and_return_pixel_mask` argument is deprecated and will be removed in a future version, "
616
+ "use `do_pad` instead."
617
+ )
618
+
619
+ if "max_size" in kwargs:
620
+ logger.warning_once(
621
+ "The `max_size` argument is deprecated and will be removed in a future version, use"
622
+ " `size['longest_edge']` instead."
623
+ )
624
+ kwargs["size"] = kwargs.pop("max_size")
625
+
626
+ return super().preprocess(images, annotations, masks_path, **kwargs)
627
+
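End to end, `preprocess` is usually reached through `__call__`; a minimal sketch with a blank image and one made-up COCO-style box (the `bbox`, `category_id` and `area` values are invented):

import numpy as np
from PIL import Image
from transformers import DetrImageProcessorFast

processor = DetrImageProcessorFast()
image = Image.fromarray(np.zeros((480, 640, 3), dtype=np.uint8))
annotation = {
    "image_id": 0,
    "annotations": [
        {"bbox": [10.0, 20.0, 100.0, 80.0], "category_id": 1, "area": 8000.0, "iscrowd": 0},
    ],
}
inputs = processor(images=image, annotations=annotation, return_tensors="pt")
# inputs holds "pixel_values" (1, 3, H, W), "pixel_mask" (1, H, W) and a "labels" list with one entry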
628
+ def _preprocess(
629
+ self,
630
+ images: list["torch.Tensor"],
631
+ annotations: Optional[Union[AnnotationType, list[AnnotationType]]],
632
+ masks_path: Optional[Union[str, pathlib.Path]],
633
+ return_segmentation_masks: bool,
634
+ do_resize: bool,
635
+ size: SizeDict,
636
+ interpolation: Optional["F.InterpolationMode"],
637
+ do_rescale: bool,
638
+ rescale_factor: float,
639
+ do_normalize: bool,
640
+ do_convert_annotations: bool,
641
+ image_mean: Optional[Union[float, list[float]]],
642
+ image_std: Optional[Union[float, list[float]]],
643
+ do_pad: bool,
644
+ pad_size: Optional[dict[str, int]],
645
+ format: Optional[Union[str, AnnotationFormat]],
646
+ return_tensors: Optional[Union[str, TensorType]],
647
+ **kwargs,
648
+ ) -> BatchFeature:
649
+ """
650
+ Preprocess an image or a batch of images so that it can be used by the model.
651
+ """
652
+ if annotations is not None and isinstance(annotations, dict):
653
+ annotations = [annotations]
654
+
655
+ if annotations is not None and len(images) != len(annotations):
656
+ raise ValueError(
657
+ f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match."
658
+ )
659
+
660
+ format = AnnotationFormat(format)
661
+ if annotations is not None:
662
+ validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations)
663
+
664
+ if (
665
+ masks_path is not None
666
+ and format == AnnotationFormat.COCO_PANOPTIC
667
+ and not isinstance(masks_path, (pathlib.Path, str))
668
+ ):
669
+ raise ValueError(
670
+ "The path to the directory containing the mask PNG files should be provided as a"
671
+ f" `pathlib.Path` or string object, but is {type(masks_path)} instead."
672
+ )
673
+
674
+ data = {}
675
+
676
+ processed_images = []
677
+ processed_annotations = []
678
+ pixel_masks = [] # Initialize pixel_masks here
679
+ for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)):
680
+ # prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image)
681
+ if annotations is not None:
682
+ annotation = self.prepare_annotation(
683
+ image,
684
+ annotation,
685
+ format,
686
+ return_segmentation_masks=return_segmentation_masks,
687
+ masks_path=masks_path,
688
+ input_data_format=ChannelDimension.FIRST,
689
+ )
690
+
691
+ if do_resize:
692
+ resized_image = self.resize(image, size=size, interpolation=interpolation)
693
+ if annotations is not None:
694
+ annotation = self.resize_annotation(
695
+ annotation,
696
+ orig_size=image.size()[-2:],
697
+ target_size=resized_image.size()[-2:],
698
+ )
699
+ image = resized_image
700
+ # Fused rescale and normalize
701
+ image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std)
702
+ if do_convert_annotations and annotations is not None:
703
+ annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST))
704
+
705
+ processed_images.append(image)
706
+ processed_annotations.append(annotation)
707
+ images = processed_images
708
+ annotations = processed_annotations if annotations is not None else None
709
+
710
+ if do_pad:
711
+ # depends on all resized image shapes so we need another loop
712
+ if pad_size is not None:
713
+ padded_size = (pad_size["height"], pad_size["width"])
714
+ else:
715
+ padded_size = get_max_height_width(images)
716
+
717
+ padded_images = []
718
+ padded_annotations = []
719
+ for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)):
720
+ # Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...}
721
+ if padded_size == image.size()[-2:]:
722
+ padded_images.append(image)
723
+ pixel_masks.append(torch.ones(padded_size, dtype=torch.int64, device=image.device))
724
+ padded_annotations.append(annotation)
725
+ continue
726
+ image, pixel_mask, annotation = self.pad(
727
+ image, padded_size, annotation=annotation, update_bboxes=do_convert_annotations
728
+ )
729
+ padded_images.append(image)
730
+ padded_annotations.append(annotation)
731
+ pixel_masks.append(pixel_mask)
732
+ images = padded_images
733
+ annotations = padded_annotations if annotations is not None else None
734
+ data.update({"pixel_mask": torch.stack(pixel_masks, dim=0)})
735
+
736
+ data.update({"pixel_values": torch.stack(images, dim=0)})
737
+ encoded_inputs = BatchFeature(data, tensor_type=return_tensors)
738
+ if annotations is not None:
739
+ encoded_inputs["labels"] = [
740
+ BatchFeature(annotation, tensor_type=return_tensors) for annotation in annotations
741
+ ]
742
+ return encoded_inputs
743
+
744
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process
745
+ def post_process(self, outputs, target_sizes):
746
+ """
747
+ Converts the raw output of [`DetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
748
+ bottom_right_x, bottom_right_y) format. Only supports PyTorch.
749
+
750
+ Args:
751
+ outputs ([`DetrObjectDetectionOutput`]):
752
+ Raw outputs of the model.
753
+ target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
754
+ Tensor containing the size (height, width) of each image of the batch. For evaluation, this must be the
755
+ original image size (before any data augmentation). For visualization, this should be the image size
756
+ after data augment, but before padding.
757
+ Returns:
758
+ `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
759
+ in the batch as predicted by the model.
760
+ """
761
+ logger.warning_once(
762
+ "`post_process` is deprecated and will be removed in v5 of Transformers, please use"
763
+ " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.",
764
+ )
765
+
766
+ out_logits, out_bbox = outputs.logits, outputs.pred_boxes
767
+
768
+ if len(out_logits) != len(target_sizes):
769
+ raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
770
+ if target_sizes.shape[1] != 2:
771
+ raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
772
+
773
+ prob = nn.functional.softmax(out_logits, -1)
774
+ scores, labels = prob[..., :-1].max(-1)
775
+
776
+ # convert to [x0, y0, x1, y1] format
777
+ boxes = center_to_corners_format(out_bbox)
778
+ # and from relative [0, 1] to absolute [0, height] coordinates
779
+ img_h, img_w = target_sizes.unbind(1)
780
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
781
+ boxes = boxes * scale_fct[:, None, :]
782
+
783
+ results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)]
784
+ return results
785
+
786
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process_segmentation
787
+ def post_process_segmentation(self, outputs, target_sizes, threshold=0.9, mask_threshold=0.5):
788
+ """
789
+ Converts the output of [`DetrForSegmentation`] into image segmentation predictions. Only supports PyTorch.
790
+
791
+ Args:
792
+ outputs ([`DetrSegmentationOutput`]):
793
+ Raw outputs of the model.
794
+ target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `list[Tuple]` of length `batch_size`):
795
+ Torch Tensor (or list) corresponding to the requested final size (h, w) of each prediction.
796
+ threshold (`float`, *optional*, defaults to 0.9):
797
+ Threshold to use to filter out queries.
798
+ mask_threshold (`float`, *optional*, defaults to 0.5):
799
+ Threshold to use when turning the predicted masks into binary values.
800
+ Returns:
801
+ `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, and masks for an image
802
+ in the batch as predicted by the model.
803
+ """
804
+ logger.warning_once(
805
+ "`post_process_segmentation` is deprecated and will be removed in v5 of Transformers, please use"
806
+ " `post_process_semantic_segmentation`.",
807
+ )
808
+ out_logits, raw_masks = outputs.logits, outputs.pred_masks
809
+ empty_label = out_logits.shape[-1] - 1
810
+ preds = []
811
+
812
+ def to_tuple(tup):
813
+ if isinstance(tup, tuple):
814
+ return tup
815
+ return tuple(tup.tolist())
816
+
817
+ for cur_logits, cur_masks, size in zip(out_logits, raw_masks, target_sizes):
818
+ # we filter empty queries and detection below threshold
819
+ cur_scores, cur_labels = cur_logits.softmax(-1).max(-1)
820
+ keep = cur_labels.ne(empty_label) & (cur_scores > threshold)
821
+ cur_scores = cur_scores[keep]
822
+ cur_labels = cur_labels[keep]
823
+ cur_masks = cur_masks[keep]
824
+ cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
825
+ cur_masks = (cur_masks.sigmoid() > mask_threshold) * 1
826
+
827
+ predictions = {"scores": cur_scores, "labels": cur_labels, "masks": cur_masks}
828
+ preds.append(predictions)
829
+ return preds
830
+
831
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process_instance
832
+ def post_process_instance(self, results, outputs, orig_target_sizes, max_target_sizes, threshold=0.5):
833
+ """
834
+ Converts the output of [`DetrForSegmentation`] into actual instance segmentation predictions. Only supports
835
+ PyTorch.
836
+
837
+ Args:
838
+ results (`list[Dict]`):
839
+ Results list obtained by [`~DetrImageProcessor.post_process`], to which "masks" results will be added.
840
+ outputs ([`DetrSegmentationOutput`]):
841
+ Raw outputs of the model.
842
+ orig_target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
843
+ Tensor containing the size (h, w) of each image of the batch. For evaluation, this must be the original
844
+ image size (before any data augmentation).
845
+ max_target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
846
+ Tensor containing the maximum size (h, w) of each image of the batch. For evaluation, this must be the
847
+ original image size (before any data augmentation).
848
+ threshold (`float`, *optional*, defaults to 0.5):
849
+ Threshold to use when turning the predicted masks into binary values.
850
+ Returns:
851
+ `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, boxes and masks for an
852
+ image in the batch as predicted by the model.
853
+ """
854
+ logger.warning_once(
855
+ "`post_process_instance` is deprecated and will be removed in v5 of Transformers, please use"
856
+ " `post_process_instance_segmentation`.",
857
+ )
858
+
859
+ if len(orig_target_sizes) != len(max_target_sizes):
860
+ raise ValueError("Make sure to pass in as many orig_target_sizes as max_target_sizes")
861
+ max_h, max_w = max_target_sizes.max(0)[0].tolist()
862
+ outputs_masks = outputs.pred_masks.squeeze(2)
863
+ outputs_masks = nn.functional.interpolate(
864
+ outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False
865
+ )
866
+ outputs_masks = (outputs_masks.sigmoid() > threshold).cpu()
867
+
868
+ for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)):
869
+ img_h, img_w = t[0], t[1]
870
+ results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1)
871
+ results[i]["masks"] = nn.functional.interpolate(
872
+ results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest"
873
+ ).byte()
874
+
875
+ return results
876
+
877
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process_panoptic
878
+ def post_process_panoptic(self, outputs, processed_sizes, target_sizes=None, is_thing_map=None, threshold=0.85):
879
+ """
880
+ Converts the output of [`DetrForSegmentation`] into actual panoptic predictions. Only supports PyTorch.
881
+
882
+ Args:
883
+ outputs ([`DetrSegmentationOutput`]):
884
+ Raw outputs of the model.
885
+ processed_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `list[Tuple]` of length `batch_size`):
886
+ Torch Tensor (or list) containing the size (h, w) of each image of the batch, i.e. the size after data
887
+ augmentation but before batching.
888
+ target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `list[Tuple]` of length `batch_size`, *optional*):
889
+ Torch Tensor (or list) corresponding to the requested final size `(height, width)` of each prediction.
890
+ If left to None, it will default to the `processed_sizes`.
891
+ is_thing_map (`torch.Tensor` of shape `(batch_size, 2)`, *optional*):
892
+ Dictionary mapping class indices to either True or False, depending on whether or not they are a thing.
893
+ If not set, defaults to the `is_thing_map` of COCO panoptic.
894
+ threshold (`float`, *optional*, defaults to 0.85):
895
+ Threshold to use to filter out queries.
896
+ Returns:
897
+ `list[Dict]`: A list of dictionaries, each dictionary containing a PNG string and segments_info values for
898
+ an image in the batch as predicted by the model.
899
+ """
900
+ logger.warning_once(
901
+ "`post_process_panoptic is deprecated and will be removed in v5 of Transformers, please use"
902
+ " `post_process_panoptic_segmentation`.",
903
+ )
904
+ if target_sizes is None:
905
+ target_sizes = processed_sizes
906
+ if len(processed_sizes) != len(target_sizes):
907
+ raise ValueError("Make sure to pass in as many processed_sizes as target_sizes")
908
+
909
+ if is_thing_map is None:
910
+ # default to is_thing_map of COCO panoptic
911
+ is_thing_map = {i: i <= 90 for i in range(201)}
912
+
913
+ out_logits, raw_masks, raw_boxes = outputs.logits, outputs.pred_masks, outputs.pred_boxes
914
+ if not len(out_logits) == len(raw_masks) == len(target_sizes):
915
+ raise ValueError(
916
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits and masks"
917
+ )
918
+ empty_label = out_logits.shape[-1] - 1
919
+ preds = []
920
+
921
+ def to_tuple(tup):
922
+ if isinstance(tup, tuple):
923
+ return tup
924
+ return tuple(tup.tolist())
925
+
926
+ for cur_logits, cur_masks, cur_boxes, size, target_size in zip(
927
+ out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes
928
+ ):
929
+ # we filter empty queries and detection below threshold
930
+ cur_scores, cur_labels = cur_logits.softmax(-1).max(-1)
931
+ keep = cur_labels.ne(empty_label) & (cur_scores > threshold)
932
+ cur_scores = cur_scores[keep]
933
+ cur_labels = cur_labels[keep]
934
+ cur_masks = cur_masks[keep]
935
+ cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
936
+ cur_boxes = center_to_corners_format(cur_boxes[keep])
937
+
938
+ h, w = cur_masks.shape[-2:]
939
+ if len(cur_boxes) != len(cur_labels):
940
+ raise ValueError("Not as many boxes as there are classes")
941
+
942
+ # It may be that we have several predicted masks for the same stuff class.
943
+ # In the following, we track the list of mask ids for each stuff class (they are merged later on)
944
+ cur_masks = cur_masks.flatten(1)
945
+ stuff_equiv_classes = defaultdict(lambda: [])
946
+ for k, label in enumerate(cur_labels):
947
+ if not is_thing_map[label.item()]:
948
+ stuff_equiv_classes[label.item()].append(k)
949
+
950
+ def get_ids_area(masks, scores, dedup=False):
951
+ # This helper function creates the final panoptic segmentation image
952
+ # It also returns the area of the masks that appears on the image
953
+
954
+ m_id = masks.transpose(0, 1).softmax(-1)
955
+
956
+ if m_id.shape[-1] == 0:
957
+ # We didn't detect any mask :(
958
+ m_id = torch.zeros((h, w), dtype=torch.long, device=m_id.device)
959
+ else:
960
+ m_id = m_id.argmax(-1).view(h, w)
961
+
962
+ if dedup:
963
+ # Merge the masks corresponding to the same stuff class
964
+ for equiv in stuff_equiv_classes.values():
965
+ if len(equiv) > 1:
966
+ for eq_id in equiv:
967
+ m_id.masked_fill_(m_id.eq(eq_id), equiv[0])
968
+
969
+ final_h, final_w = to_tuple(target_size)
970
+
971
+ seg_img = PIL.Image.fromarray(id_to_rgb(m_id.view(h, w).cpu().numpy()))
972
+ seg_img = seg_img.resize(size=(final_w, final_h), resample=PILImageResampling.NEAREST)
973
+
974
+ np_seg_img = torch.ByteTensor(torch.ByteStorage.from_buffer(seg_img.tobytes()))
975
+ np_seg_img = np_seg_img.view(final_h, final_w, 3)
976
+ np_seg_img = np_seg_img.numpy()
977
+
978
+ m_id = torch.from_numpy(rgb_to_id(np_seg_img))
979
+
980
+ area = []
981
+ for i in range(len(scores)):
982
+ area.append(m_id.eq(i).sum().item())
983
+ return area, seg_img
984
+
985
+ area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True)
986
+ if cur_labels.numel() > 0:
987
+ # We now filter empty masks as long as we find some
988
+ while True:
989
+ filtered_small = torch.as_tensor(
990
+ [area[i] <= 4 for i, c in enumerate(cur_labels)], dtype=torch.bool, device=keep.device
991
+ )
992
+ if filtered_small.any().item():
993
+ cur_scores = cur_scores[~filtered_small]
994
+ cur_labels = cur_labels[~filtered_small]
995
+ cur_masks = cur_masks[~filtered_small]
996
+ area, seg_img = get_ids_area(cur_masks, cur_scores)
997
+ else:
998
+ break
999
+
1000
+ else:
1001
+ cur_labels = torch.ones(1, dtype=torch.long, device=cur_labels.device)
1002
+
1003
+ segments_info = []
1004
+ for i, a in enumerate(area):
1005
+ cat = cur_labels[i].item()
1006
+ segments_info.append({"id": i, "isthing": is_thing_map[cat], "category_id": cat, "area": a})
1007
+ del cur_labels
1008
+
1009
+ with io.BytesIO() as out:
1010
+ seg_img.save(out, format="PNG")
1011
+ predictions = {"png_string": out.getvalue(), "segments_info": segments_info}
1012
+ preds.append(predictions)
1013
+ return preds
1014
+
1015
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process_object_detection
1016
+ def post_process_object_detection(
1017
+ self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, list[tuple]] = None
1018
+ ):
1019
+ """
1020
+ Converts the raw output of [`DetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
1021
+ bottom_right_x, bottom_right_y) format. Only supports PyTorch.
1022
+
1023
+ Args:
1024
+ outputs ([`DetrObjectDetectionOutput`]):
1025
+ Raw outputs of the model.
1026
+ threshold (`float`, *optional*):
1027
+ Score threshold to keep object detection predictions.
1028
+ target_sizes (`torch.Tensor` or `list[tuple[int, int]]`, *optional*):
1029
+ Tensor of shape `(batch_size, 2)` or list of tuples (`tuple[int, int]`) containing the target size
1030
+ `(height, width)` of each image in the batch. If unset, predictions will not be resized.
1031
+ Returns:
1032
+ `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
1033
+ in the batch as predicted by the model.
1034
+ """
1035
+ out_logits, out_bbox = outputs.logits, outputs.pred_boxes
1036
+
1037
+ if target_sizes is not None:
1038
+ if len(out_logits) != len(target_sizes):
1039
+ raise ValueError(
1040
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits"
1041
+ )
1042
+
1043
+ prob = nn.functional.softmax(out_logits, -1)
1044
+ scores, labels = prob[..., :-1].max(-1)
1045
+
1046
+ # Convert to [x0, y0, x1, y1] format
1047
+ boxes = center_to_corners_format(out_bbox)
1048
+
1049
+ # Convert from relative [0, 1] to absolute [0, height] coordinates
1050
+ if target_sizes is not None:
1051
+ if isinstance(target_sizes, list):
1052
+ img_h = torch.Tensor([i[0] for i in target_sizes])
1053
+ img_w = torch.Tensor([i[1] for i in target_sizes])
1054
+ else:
1055
+ img_h, img_w = target_sizes.unbind(1)
1056
+
1057
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
1058
+ boxes = boxes * scale_fct[:, None, :]
1059
+
1060
+ results = []
1061
+ for s, l, b in zip(scores, labels, boxes):
1062
+ score = s[s > threshold]
1063
+ label = l[s > threshold]
1064
+ box = b[s > threshold]
1065
+ results.append({"scores": score, "labels": label, "boxes": box})
1066
+
1067
+ return results
1068
+
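A minimal sketch of `post_process_object_detection` on stand-in outputs (random tensors instead of a real `DetrObjectDetectionOutput`, so most or all detections will fall below the threshold; the call signature and result structure are the point):

from types import SimpleNamespace

import torch
from transformers import DetrImageProcessorFast

processor = DetrImageProcessorFast()
outputs = SimpleNamespace(
    logits=torch.randn(2, 100, 92),    # (batch, num_queries, num_labels + 1 for the "no object" class)
    pred_boxes=torch.rand(2, 100, 4),  # normalized (center_x, center_y, width, height)
)
results = processor.post_process_object_detection(
    outputs, threshold=0.5, target_sizes=[(480, 640), (600, 800)]
)
# results[i]["boxes"] are absolute (x0, y0, x1, y1) for image i; only scores > 0.5 are kept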
1069
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process_semantic_segmentation
1070
+ def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple[int, int]]] = None):
1071
+ """
1072
+ Converts the output of [`DetrForSegmentation`] into semantic segmentation maps. Only supports PyTorch.
1073
+
1074
+ Args:
1075
+ outputs ([`DetrForSegmentation`]):
1076
+ Raw outputs of the model.
1077
+ target_sizes (`list[tuple[int, int]]`, *optional*):
1078
+ A list of tuples (`tuple[int, int]`) containing the target size (height, width) of each image in the
1079
+ batch. If unset, predictions will not be resized.
1080
+ Returns:
1081
+ `list[torch.Tensor]`:
1082
+ A list of length `batch_size`, where each item is a semantic segmentation map of shape (height, width)
1083
+ corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each
1084
+ `torch.Tensor` corresponds to a semantic class id.
1085
+ """
1086
+ class_queries_logits = outputs.logits # [batch_size, num_queries, num_classes+1]
1087
+ masks_queries_logits = outputs.pred_masks # [batch_size, num_queries, height, width]
1088
+
1089
+ # Remove the null class `[..., :-1]`
1090
+ masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1]
1091
+ masks_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width]
1092
+
1093
+ # Semantic segmentation logits of shape (batch_size, num_classes, height, width)
1094
+ segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
1095
+ batch_size = class_queries_logits.shape[0]
1096
+
1097
+ # Resize logits and compute semantic segmentation maps
1098
+ if target_sizes is not None:
1099
+ if batch_size != len(target_sizes):
1100
+ raise ValueError(
1101
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits"
1102
+ )
1103
+
1104
+ semantic_segmentation = []
1105
+ for idx in range(batch_size):
1106
+ resized_logits = nn.functional.interpolate(
1107
+ segmentation[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
1108
+ )
1109
+ semantic_map = resized_logits[0].argmax(dim=0)
1110
+ semantic_segmentation.append(semantic_map)
1111
+ else:
1112
+ semantic_segmentation = segmentation.argmax(dim=1)
1113
+ semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
1114
+
1115
+ return semantic_segmentation
1116
+
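Likewise for `post_process_semantic_segmentation`, sketched with random stand-in tensors:

from types import SimpleNamespace

import torch
from transformers import DetrImageProcessorFast

processor = DetrImageProcessorFast()
outputs = SimpleNamespace(
    logits=torch.randn(1, 100, 251),           # (batch, num_queries, num_classes + 1)
    pred_masks=torch.randn(1, 100, 200, 267),  # (batch, num_queries, height, width)
)
maps = processor.post_process_semantic_segmentation(outputs, target_sizes=[(800, 1066)])
# maps[0] is an (800, 1066) tensor of per-pixel class ids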
1117
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process_instance_segmentation
1118
+ def post_process_instance_segmentation(
1119
+ self,
1120
+ outputs,
1121
+ threshold: float = 0.5,
1122
+ mask_threshold: float = 0.5,
1123
+ overlap_mask_area_threshold: float = 0.8,
1124
+ target_sizes: Optional[list[tuple[int, int]]] = None,
1125
+ return_coco_annotation: Optional[bool] = False,
1126
+ ) -> list[dict]:
1127
+ """
1128
+ Converts the output of [`DetrForSegmentation`] into instance segmentation predictions. Only supports PyTorch.
1129
+
1130
+ Args:
1131
+ outputs ([`DetrForSegmentation`]):
1132
+ Raw outputs of the model.
1133
+ threshold (`float`, *optional*, defaults to 0.5):
1134
+ The probability score threshold to keep predicted instance masks.
1135
+ mask_threshold (`float`, *optional*, defaults to 0.5):
1136
+ Threshold to use when turning the predicted masks into binary values.
1137
+ overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
1138
+ The overlap mask area threshold to merge or discard small disconnected parts within each binary
1139
+ instance mask.
1140
+ target_sizes (`list[Tuple]`, *optional*):
1141
+ List of length (batch_size), where each list item (`tuple[int, int]`) corresponds to the requested
1142
+ final size (height, width) of each prediction. If unset, predictions will not be resized.
1143
+ return_coco_annotation (`bool`, *optional*):
1144
+ Defaults to `False`. If set to `True`, segmentation maps are returned in COCO run-length encoding (RLE)
1145
+ format.
1146
+ Returns:
1147
+ `list[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
1148
+ - **segmentation** -- A tensor of shape `(height, width)` where each pixel represents a `segment_id` or
1149
+ `list[List]` run-length encoding (RLE) of the segmentation map if return_coco_annotation is set to
1150
+ `True`. Set to `None` if no mask is found above `threshold`.
1151
+ - **segments_info** -- A dictionary that contains additional information on each segment.
1152
+ - **id** -- An integer representing the `segment_id`.
1153
+ - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
1154
+ - **score** -- Prediction score of segment with `segment_id`.
1155
+ """
1156
+ class_queries_logits = outputs.logits # [batch_size, num_queries, num_classes+1]
1157
+ masks_queries_logits = outputs.pred_masks # [batch_size, num_queries, height, width]
1158
+
1159
+ batch_size = class_queries_logits.shape[0]
1160
+ num_labels = class_queries_logits.shape[-1] - 1
1161
+
1162
+ mask_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width]
1163
+
1164
+ # Predicted label and score of each query (batch_size, num_queries)
1165
+ pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1)
1166
+
1167
+ # Loop over items in the batch
1168
+ results: list[dict[str, TensorType]] = []
1169
+
1170
+ for i in range(batch_size):
1171
+ mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects(
1172
+ mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels
1173
+ )
1174
+
1175
+ # No mask found
1176
+ if mask_probs_item.shape[0] <= 0:
1177
+ height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:]
1178
+ segmentation = torch.zeros((height, width)) - 1
1179
+ results.append({"segmentation": segmentation, "segments_info": []})
1180
+ continue
1181
+
1182
+ # Get segmentation map and segment information of batch item
1183
+ target_size = target_sizes[i] if target_sizes is not None else None
1184
+ segmentation, segments = compute_segments(
1185
+ mask_probs=mask_probs_item,
1186
+ pred_scores=pred_scores_item,
1187
+ pred_labels=pred_labels_item,
1188
+ mask_threshold=mask_threshold,
1189
+ overlap_mask_area_threshold=overlap_mask_area_threshold,
1190
+ label_ids_to_fuse=[],
1191
+ target_size=target_size,
1192
+ )
1193
+
1194
+ # Return segmentation map in run-length encoding (RLE) format
1195
+ if return_coco_annotation:
1196
+ segmentation = convert_segmentation_to_rle(segmentation)
1197
+
1198
+ results.append({"segmentation": segmentation, "segments_info": segments})
1199
+ return results
1200
+
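And `post_process_instance_segmentation` with the same kind of stand-in outputs; passing `return_coco_annotation=True` swaps the dense map for per-segment run-length encodings:

from types import SimpleNamespace

import torch
from transformers import DetrImageProcessorFast

processor = DetrImageProcessorFast()
outputs = SimpleNamespace(
    logits=torch.randn(1, 100, 251),
    pred_masks=torch.randn(1, 100, 200, 267),
)
instances = processor.post_process_instance_segmentation(outputs, threshold=0.5, target_sizes=[(800, 1066)])
rle = processor.post_process_instance_segmentation(outputs, threshold=0.5, return_coco_annotation=True)
# instances[0]["segmentation"]: (800, 1066) tensor of segment ids (-1 where nothing was kept)
# rle[0]["segmentation"]: list of run-length encodings, one per segment (when any segment survives the threshold)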
1201
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process_panoptic_segmentation
1202
+ def post_process_panoptic_segmentation(
1203
+ self,
1204
+ outputs,
1205
+ threshold: float = 0.5,
1206
+ mask_threshold: float = 0.5,
1207
+ overlap_mask_area_threshold: float = 0.8,
1208
+ label_ids_to_fuse: Optional[set[int]] = None,
1209
+ target_sizes: Optional[list[tuple[int, int]]] = None,
1210
+ ) -> list[dict]:
1211
+ """
1212
+ Converts the output of [`DetrForSegmentation`] into image panoptic segmentation predictions. Only supports
1213
+ PyTorch.
1214
+
1215
+ Args:
1216
+ outputs ([`DetrForSegmentation`]):
1217
+ The outputs from [`DetrForSegmentation`].
1218
+ threshold (`float`, *optional*, defaults to 0.5):
1219
+ The probability score threshold to keep predicted instance masks.
1220
+ mask_threshold (`float`, *optional*, defaults to 0.5):
1221
+ Threshold to use when turning the predicted masks into binary values.
1222
+ overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
1223
+ The overlap mask area threshold to merge or discard small disconnected parts within each binary
1224
+ instance mask.
1225
+ label_ids_to_fuse (`Set[int]`, *optional*):
1226
+ The labels in this set will have all their instances fused together. For instance we could say
1227
+ there can only be one sky in an image, but several persons, so the label ID for sky would be in that
1228
+ set, but not the one for person.
1229
+ target_sizes (`list[Tuple]`, *optional*):
1230
+ List of length (batch_size), where each list item (`tuple[int, int]`) corresponds to the requested
1231
+ final size (height, width) of each prediction in batch. If unset, predictions will not be resized.
1232
+ Returns:
1233
+ `list[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
1234
+ - **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id` or
1235
+ `None` if no mask is found above `threshold`. If `target_sizes` is specified, segmentation is resized to
1236
+ the corresponding `target_sizes` entry.
1237
+ - **segments_info** -- A dictionary that contains additional information on each segment.
1238
+ - **id** -- an integer representing the `segment_id`.
1239
+ - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
1240
+ - **was_fused** -- a boolean, `True` if `label_id` was in `label_ids_to_fuse`, `False` otherwise.
1241
+ Multiple instances of the same class / label were fused and assigned a single `segment_id`.
1242
+ - **score** -- Prediction score of segment with `segment_id`.
1243
+ """
1244
+
1245
+ if label_ids_to_fuse is None:
1246
+ logger.warning_once("`label_ids_to_fuse` unset. No instance will be fused.")
1247
+ label_ids_to_fuse = set()
1248
+
1249
+ class_queries_logits = outputs.logits # [batch_size, num_queries, num_classes+1]
1250
+ masks_queries_logits = outputs.pred_masks # [batch_size, num_queries, height, width]
1251
+
1252
+ batch_size = class_queries_logits.shape[0]
1253
+ num_labels = class_queries_logits.shape[-1] - 1
1254
+
1255
+ mask_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width]
1256
+
1257
+ # Predicted label and score of each query (batch_size, num_queries)
1258
+ pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1)
1259
+
1260
+ # Loop over items in the batch
1261
+ results: list[dict[str, TensorType]] = []
1262
+
1263
+ for i in range(batch_size):
1264
+ mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects(
1265
+ mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels
1266
+ )
1267
+
1268
+ # No mask found
1269
+ if mask_probs_item.shape[0] <= 0:
1270
+ height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:]
1271
+ segmentation = torch.zeros((height, width)) - 1
1272
+ results.append({"segmentation": segmentation, "segments_info": []})
1273
+ continue
1274
+
1275
+ # Get segmentation map and segment information of batch item
1276
+ target_size = target_sizes[i] if target_sizes is not None else None
1277
+ segmentation, segments = compute_segments(
1278
+ mask_probs=mask_probs_item,
1279
+ pred_scores=pred_scores_item,
1280
+ pred_labels=pred_labels_item,
1281
+ mask_threshold=mask_threshold,
1282
+ overlap_mask_area_threshold=overlap_mask_area_threshold,
1283
+ label_ids_to_fuse=label_ids_to_fuse,
1284
+ target_size=target_size,
1285
+ )
1286
+
1287
+ results.append({"segmentation": segmentation, "segments_info": segments})
1288
+ return results
1289
+
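+ # Illustrative usage sketch for the panoptic post-processing above (checkpoint name, threshold and
+ # image URL are assumptions taken from the usual DETR examples, not prescriptions):
+ #
+ #     import torch, requests
+ #     from PIL import Image
+ #     from transformers import AutoImageProcessor, DetrForSegmentation
+ #
+ #     url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ #     image = Image.open(requests.get(url, stream=True).raw)
+ #     processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50-panoptic")
+ #     model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50-panoptic")
+ #     inputs = processor(images=image, return_tensors="pt")
+ #     with torch.no_grad():
+ #         outputs = model(**inputs)
+ #     # one (height, width) tuple per image so the segmentation map matches the original resolution
+ #     result = processor.post_process_panoptic_segmentation(
+ #         outputs, threshold=0.5, target_sizes=[image.size[::-1]]
+ #     )[0]
+ #     segmentation_map = result["segmentation"]  # (height, width) tensor of segment ids
+ #     segments_info = result["segments_info"]    # list of dicts with id / label_id / was_fused / score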
1290
+
1291
+ __all__ = ["DetrImageProcessorFast"]
phivenv/Lib/site-packages/transformers/models/detr/modeling_detr.py ADDED
@@ -0,0 +1,1693 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 Facebook AI Research The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch DETR model."""
16
+
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Optional, Union
20
+
21
+ import torch
22
+ from torch import Tensor, nn
23
+
24
+ from ...activations import ACT2FN
25
+ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
26
+ from ...modeling_layers import GradientCheckpointingLayer
27
+ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput
28
+ from ...modeling_utils import PreTrainedModel
29
+ from ...utils import (
30
+ ModelOutput,
31
+ auto_docstring,
32
+ is_timm_available,
33
+ logging,
34
+ requires_backends,
35
+ )
36
+ from ...utils.backbone_utils import load_backbone
37
+ from .configuration_detr import DetrConfig
38
+
39
+
40
+ if is_timm_available():
41
+ from timm import create_model
42
+
43
+
44
+ logger = logging.get_logger(__name__)
45
+
46
+
47
+ @dataclass
48
+ @auto_docstring(
49
+ custom_intro="""
50
+ Base class for outputs of the DETR decoder. This class adds one attribute to BaseModelOutputWithCrossAttentions,
51
+ namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them
52
+ gone through a layernorm. This is useful when training the model with auxiliary decoding losses.
53
+ """
54
+ )
55
+ class DetrDecoderOutput(BaseModelOutputWithCrossAttentions):
56
+ r"""
57
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
58
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
59
+ sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
60
+ used to compute the weighted average in the cross-attention heads.
61
+ intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
62
+ Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
63
+ layernorm.
64
+ """
65
+
66
+ intermediate_hidden_states: Optional[torch.FloatTensor] = None
67
+
68
+
69
+ @dataclass
70
+ @auto_docstring(
71
+ custom_intro="""
72
+ Base class for outputs of the DETR encoder-decoder model. This class adds one attribute to Seq2SeqModelOutput,
73
+ namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them
74
+ gone through a layernorm. This is useful when training the model with auxiliary decoding losses.
75
+ """
76
+ )
77
+ class DetrModelOutput(Seq2SeqModelOutput):
78
+ r"""
79
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
80
+ Sequence of hidden-states at the output of the last layer of the decoder of the model.
81
+ intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, sequence_length, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
82
+ Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
83
+ layernorm.
84
+ """
85
+
86
+ intermediate_hidden_states: Optional[torch.FloatTensor] = None
87
+
88
+
89
+ @dataclass
90
+ @auto_docstring(
91
+ custom_intro="""
92
+ Output type of [`DetrForObjectDetection`].
93
+ """
94
+ )
95
+ class DetrObjectDetectionOutput(ModelOutput):
96
+ r"""
97
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
98
+ Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
99
+ bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
100
+ scale-invariant IoU loss.
101
+ loss_dict (`Dict`, *optional*):
102
+ A dictionary containing the individual losses. Useful for logging.
103
+ logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
104
+ Classification logits (including no-object) for all queries.
105
+ pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
106
+ Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
107
+ values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
108
+ possible padding). You can use [`~DetrImageProcessor.post_process_object_detection`] to retrieve the
109
+ unnormalized bounding boxes.
110
+ auxiliary_outputs (`list[Dict]`, *optional*):
111
+ Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
112
+ and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
113
+ `pred_boxes`) for each decoder layer.
114
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
115
+ Sequence of hidden-states at the output of the last layer of the decoder of the model.
116
+ """
117
+
118
+ loss: Optional[torch.FloatTensor] = None
119
+ loss_dict: Optional[dict] = None
120
+ logits: Optional[torch.FloatTensor] = None
121
+ pred_boxes: Optional[torch.FloatTensor] = None
122
+ auxiliary_outputs: Optional[list[dict]] = None
123
+ last_hidden_state: Optional[torch.FloatTensor] = None
124
+ decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
125
+ decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
126
+ cross_attentions: Optional[tuple[torch.FloatTensor]] = None
127
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
128
+ encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
129
+ encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
130
+
131
+
132
+ @dataclass
133
+ @auto_docstring(
134
+ custom_intro="""
135
+ Output type of [`DetrForSegmentation`].
136
+ """
137
+ )
138
+ class DetrSegmentationOutput(ModelOutput):
139
+ r"""
140
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
141
+ Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
142
+ bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
143
+ scale-invariant IoU loss.
144
+ loss_dict (`Dict`, *optional*):
145
+ A dictionary containing the individual losses. Useful for logging.
146
+ logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
147
+ Classification logits (including no-object) for all queries.
148
+ pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
149
+ Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
150
+ values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
151
+ possible padding). You can use [`~DetrImageProcessor.post_process_object_detection`] to retrieve the
152
+ unnormalized bounding boxes.
153
+ pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height/4, width/4)`):
154
+ Segmentation masks logits for all queries. See also
155
+ [`~DetrImageProcessor.post_process_semantic_segmentation`] or
156
+ [`~DetrImageProcessor.post_process_instance_segmentation`] or
157
+ [`~DetrImageProcessor.post_process_panoptic_segmentation`] to evaluate semantic, instance and panoptic
158
+ segmentation masks respectively.
159
+ auxiliary_outputs (`list[Dict]`, *optional*):
160
+ Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
161
+ and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
162
+ `pred_boxes`) for each decoder layer.
163
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
164
+ Sequence of hidden-states at the output of the last layer of the decoder of the model.
165
+ """
166
+
167
+ loss: Optional[torch.FloatTensor] = None
168
+ loss_dict: Optional[dict] = None
169
+ logits: Optional[torch.FloatTensor] = None
170
+ pred_boxes: Optional[torch.FloatTensor] = None
171
+ pred_masks: Optional[torch.FloatTensor] = None
172
+ auxiliary_outputs: Optional[list[dict]] = None
173
+ last_hidden_state: Optional[torch.FloatTensor] = None
174
+ decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
175
+ decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
176
+ cross_attentions: Optional[tuple[torch.FloatTensor]] = None
177
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
178
+ encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
179
+ encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
180
+
181
+
182
+ # BELOW: utilities copied from
183
+ # https://github.com/facebookresearch/detr/blob/master/backbone.py
184
+ class DetrFrozenBatchNorm2d(nn.Module):
185
+ """
186
+ BatchNorm2d where the batch statistics and the affine parameters are fixed.
187
+
188
+ Copy-paste from torchvision.misc.ops with added eps before rsqrt, without which any other models than
189
+ torchvision.models.resnet[18,34,50,101] produce nans.
190
+ """
191
+
192
+ def __init__(self, n):
193
+ super().__init__()
194
+ self.register_buffer("weight", torch.ones(n))
195
+ self.register_buffer("bias", torch.zeros(n))
196
+ self.register_buffer("running_mean", torch.zeros(n))
197
+ self.register_buffer("running_var", torch.ones(n))
198
+
199
+ def _load_from_state_dict(
200
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
201
+ ):
202
+ num_batches_tracked_key = prefix + "num_batches_tracked"
203
+ if num_batches_tracked_key in state_dict:
204
+ del state_dict[num_batches_tracked_key]
205
+
206
+ super()._load_from_state_dict(
207
+ state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
208
+ )
209
+
210
+ def forward(self, x):
211
+ # move reshapes to the beginning
212
+ # to make it user-friendly
213
+ weight = self.weight.reshape(1, -1, 1, 1)
214
+ bias = self.bias.reshape(1, -1, 1, 1)
215
+ running_var = self.running_var.reshape(1, -1, 1, 1)
216
+ running_mean = self.running_mean.reshape(1, -1, 1, 1)
217
+ epsilon = 1e-5
218
+ scale = weight * (running_var + epsilon).rsqrt()
219
+ bias = bias - running_mean * scale
220
+ return x * scale + bias
221
+
222
+
223
+ def replace_batch_norm(model):
224
+ r"""
225
+ Recursively replace all `torch.nn.BatchNorm2d` with `DetrFrozenBatchNorm2d`.
226
+
227
+ Args:
228
+ model (torch.nn.Module):
229
+ input model
230
+ """
231
+ for name, module in model.named_children():
232
+ if isinstance(module, nn.BatchNorm2d):
233
+ new_module = DetrFrozenBatchNorm2d(module.num_features)
234
+
235
+ if module.weight.device != torch.device("meta"):
236
+ new_module.weight.data.copy_(module.weight)
237
+ new_module.bias.data.copy_(module.bias)
238
+ new_module.running_mean.data.copy_(module.running_mean)
239
+ new_module.running_var.data.copy_(module.running_var)
240
+
241
+ model._modules[name] = new_module
242
+
243
+ if len(list(module.children())) > 0:
244
+ replace_batch_norm(module)
245
+
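+ # Illustrative sketch of how `replace_batch_norm` is used (torchvision's resnet50 stands in for the
+ # configured backbone; this is an assumption for demonstration only):
+ #
+ #     import torch
+ #     from torchvision.models import resnet50
+ #
+ #     backbone = resnet50()
+ #     with torch.no_grad():
+ #         replace_batch_norm(backbone)  # every nn.BatchNorm2d becomes a DetrFrozenBatchNorm2d
+ #     assert not any(isinstance(m, torch.nn.BatchNorm2d) for m in backbone.modules())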
246
+
247
+ class DetrConvEncoder(nn.Module):
248
+ """
249
+ Convolutional backbone, using either the AutoBackbone API or one from the timm library.
250
+
251
+ nn.BatchNorm2d layers are replaced by DetrFrozenBatchNorm2d as defined above.
252
+
253
+ """
254
+
255
+ def __init__(self, config):
256
+ super().__init__()
257
+
258
+ self.config = config
259
+
260
+ # For backwards compatibility we have to use the timm library directly instead of the AutoBackbone API
261
+ if config.use_timm_backbone:
262
+ # We default to values which were previously hard-coded. This enables configurability from the config
263
+ # using backbone arguments, while keeping the default behavior the same.
264
+ requires_backends(self, ["timm"])
265
+ kwargs = getattr(config, "backbone_kwargs", {})
266
+ kwargs = {} if kwargs is None else kwargs.copy()
267
+ out_indices = kwargs.pop("out_indices", (1, 2, 3, 4))
268
+ num_channels = kwargs.pop("in_chans", config.num_channels)
269
+ if config.dilation:
270
+ kwargs["output_stride"] = kwargs.get("output_stride", 16)
271
+ backbone = create_model(
272
+ config.backbone,
273
+ pretrained=config.use_pretrained_backbone,
274
+ features_only=True,
275
+ out_indices=out_indices,
276
+ in_chans=num_channels,
277
+ **kwargs,
278
+ )
279
+ else:
280
+ backbone = load_backbone(config)
281
+
282
+ # replace batch norm by frozen batch norm
283
+ with torch.no_grad():
284
+ replace_batch_norm(backbone)
285
+ self.model = backbone
286
+ self.intermediate_channel_sizes = (
287
+ self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels
288
+ )
289
+
290
+ backbone_model_type = None
291
+ if config.backbone is not None:
292
+ backbone_model_type = config.backbone
293
+ elif config.backbone_config is not None:
294
+ backbone_model_type = config.backbone_config.model_type
295
+ else:
296
+ raise ValueError("Either `backbone` or `backbone_config` should be provided in the config")
297
+
298
+ if "resnet" in backbone_model_type:
299
+ for name, parameter in self.model.named_parameters():
300
+ if config.use_timm_backbone:
301
+ if "layer2" not in name and "layer3" not in name and "layer4" not in name:
302
+ parameter.requires_grad_(False)
303
+ else:
304
+ if "stage.1" not in name and "stage.2" not in name and "stage.3" not in name:
305
+ parameter.requires_grad_(False)
306
+
307
+ def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
308
+ # send pixel_values through the model to get list of feature maps
309
+ features = self.model(pixel_values) if self.config.use_timm_backbone else self.model(pixel_values).feature_maps
310
+
311
+ out = []
312
+ for feature_map in features:
313
+ # downsample pixel_mask to match shape of corresponding feature_map
314
+ mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0]
315
+ out.append((feature_map, mask))
316
+ return out
317
+
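+ # Illustrative shape sketch for DetrConvEncoder (input resolution and channel counts are assumed,
+ # matching a default ResNet-50 timm backbone):
+ #
+ #     encoder = DetrConvEncoder(DetrConfig())        # loads the backbone configured in DetrConfig
+ #     pixels = torch.randn(1, 3, 800, 1066)
+ #     mask = torch.ones(1, 800, 1066, dtype=torch.long)
+ #     features = encoder(pixels, mask)               # one (feature_map, bool mask) pair per stage
+ #     last_map, last_mask = features[-1]             # e.g. roughly (1, 2048, 25, 34) at stride 32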
318
+
319
+ class DetrConvModel(nn.Module):
320
+ """
321
+ This module adds 2D position embeddings to all intermediate feature maps of the convolutional encoder.
322
+ """
323
+
324
+ def __init__(self, conv_encoder, position_embedding):
325
+ super().__init__()
326
+ self.conv_encoder = conv_encoder
327
+ self.position_embedding = position_embedding
328
+
329
+ def forward(self, pixel_values, pixel_mask):
330
+ # send pixel_values and pixel_mask through backbone to get list of (feature_map, pixel_mask) tuples
331
+ out = self.conv_encoder(pixel_values, pixel_mask)
332
+ pos = []
333
+ for feature_map, mask in out:
334
+ # position encoding
335
+ pos.append(self.position_embedding(feature_map, mask).to(feature_map.dtype))
336
+
337
+ return out, pos
338
+
339
+
340
+ class DetrSinePositionEmbedding(nn.Module):
341
+ """
342
+ This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
343
+ need paper, generalized to work on images.
344
+ """
345
+
346
+ def __init__(self, embedding_dim=64, temperature=10000, normalize=False, scale=None):
347
+ super().__init__()
348
+ self.embedding_dim = embedding_dim
349
+ self.temperature = temperature
350
+ self.normalize = normalize
351
+ if scale is not None and normalize is False:
352
+ raise ValueError("normalize should be True if scale is passed")
353
+ if scale is None:
354
+ scale = 2 * math.pi
355
+ self.scale = scale
356
+
357
+ def forward(self, pixel_values, pixel_mask):
358
+ if pixel_mask is None:
359
+ raise ValueError("No pixel mask provided")
360
+ y_embed = pixel_mask.cumsum(1, dtype=torch.float32)
361
+ x_embed = pixel_mask.cumsum(2, dtype=torch.float32)
362
+ if self.normalize:
363
+ y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale
364
+ x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale
365
+
366
+ dim_t = torch.arange(self.embedding_dim, dtype=torch.int64, device=pixel_values.device).float()
367
+ dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim)
368
+
369
+ pos_x = x_embed[:, :, :, None] / dim_t
370
+ pos_y = y_embed[:, :, :, None] / dim_t
371
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
372
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
373
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
374
+ return pos
375
+
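+ # Illustrative shape sketch: with d_model = 256 the sine embedding is built with
+ # embedding_dim = d_model // 2 = 128 per axis, so the concatenated (y, x) encoding has 256 channels
+ # and lines up with the projected feature map (shapes below are assumptions):
+ #
+ #     feature_map = torch.zeros(2, 2048, 25, 34)           # (batch, channels, height, width)
+ #     pixel_mask = torch.ones(2, 25, 34, dtype=torch.long)
+ #     pos = DetrSinePositionEmbedding(embedding_dim=128, normalize=True)(feature_map, pixel_mask)
+ #     # pos.shape == torch.Size([2, 256, 25, 34])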
376
+
377
+ class DetrLearnedPositionEmbedding(nn.Module):
378
+ """
379
+ This module learns positional embeddings up to a fixed maximum size.
380
+ """
381
+
382
+ def __init__(self, embedding_dim=256):
383
+ super().__init__()
384
+ self.row_embeddings = nn.Embedding(50, embedding_dim)
385
+ self.column_embeddings = nn.Embedding(50, embedding_dim)
386
+
387
+ def forward(self, pixel_values, pixel_mask=None):
388
+ height, width = pixel_values.shape[-2:]
389
+ width_values = torch.arange(width, device=pixel_values.device)
390
+ height_values = torch.arange(height, device=pixel_values.device)
391
+ x_emb = self.column_embeddings(width_values)
392
+ y_emb = self.row_embeddings(height_values)
393
+ pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1)
394
+ pos = pos.permute(2, 0, 1)
395
+ pos = pos.unsqueeze(0)
396
+ pos = pos.repeat(pixel_values.shape[0], 1, 1, 1)
397
+ return pos
398
+
399
+
400
+ def build_position_encoding(config):
401
+ n_steps = config.d_model // 2
402
+ if config.position_embedding_type == "sine":
403
+ # TODO find a better way of exposing other arguments
404
+ position_embedding = DetrSinePositionEmbedding(n_steps, normalize=True)
405
+ elif config.position_embedding_type == "learned":
406
+ position_embedding = DetrLearnedPositionEmbedding(n_steps)
407
+ else:
408
+ raise ValueError(f"Not supported {config.position_embedding_type}")
409
+
410
+ return position_embedding
411
+
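+ # Illustrative sketch: `build_position_encoding` selects the embedding type from the config, and
+ # d_model is halved because the y- and x-encodings are concatenated channel-wise:
+ #
+ #     from transformers import DetrConfig
+ #
+ #     config = DetrConfig(d_model=256, position_embedding_type="sine")
+ #     pos_embed = build_position_encoding(config)  # DetrSinePositionEmbedding(128, normalize=True)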
412
+
413
+ class DetrAttention(nn.Module):
414
+ """
415
+ Multi-headed attention from 'Attention Is All You Need' paper.
416
+
417
+ Here, we add position embeddings to the queries and keys (as explained in the DETR paper).
418
+ """
419
+
420
+ def __init__(
421
+ self,
422
+ embed_dim: int,
423
+ num_heads: int,
424
+ dropout: float = 0.0,
425
+ bias: bool = True,
426
+ ):
427
+ super().__init__()
428
+ self.embed_dim = embed_dim
429
+ self.num_heads = num_heads
430
+ self.dropout = dropout
431
+ self.head_dim = embed_dim // num_heads
432
+ if self.head_dim * num_heads != self.embed_dim:
433
+ raise ValueError(
434
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
435
+ f" {num_heads})."
436
+ )
437
+ self.scaling = self.head_dim**-0.5
438
+
439
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
440
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
441
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
442
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
443
+
444
+ def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
445
+ return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
446
+
447
+ def with_pos_embed(self, tensor: torch.Tensor, object_queries: Optional[Tensor]):
448
+ return tensor if object_queries is None else tensor + object_queries
449
+
450
+ def forward(
451
+ self,
452
+ hidden_states: torch.Tensor,
453
+ attention_mask: Optional[torch.Tensor] = None,
454
+ object_queries: Optional[torch.Tensor] = None,
455
+ key_value_states: Optional[torch.Tensor] = None,
456
+ spatial_position_embeddings: Optional[torch.Tensor] = None,
457
+ output_attentions: bool = False,
458
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
459
+ """Input shape: Batch x Time x Channel"""
460
+ # if key_value_states are provided this layer is used as a cross-attention layer
461
+ # for the decoder
462
+ is_cross_attention = key_value_states is not None
463
+ batch_size, target_len, embed_dim = hidden_states.size()
464
+
465
+ # add position embeddings to the hidden states before projecting to queries and keys
466
+ if object_queries is not None:
467
+ hidden_states_original = hidden_states
468
+ hidden_states = self.with_pos_embed(hidden_states, object_queries)
469
+
470
+ # add key-value position embeddings to the key value states
471
+ if spatial_position_embeddings is not None:
472
+ key_value_states_original = key_value_states
473
+ key_value_states = self.with_pos_embed(key_value_states, spatial_position_embeddings)
474
+
475
+ # get query proj
476
+ query_states = self.q_proj(hidden_states) * self.scaling
477
+ # get key, value proj
478
+ if is_cross_attention:
479
+ # cross_attentions
480
+ key_states = self._shape(self.k_proj(key_value_states), -1, batch_size)
481
+ value_states = self._shape(self.v_proj(key_value_states_original), -1, batch_size)
482
+ else:
483
+ # self_attention
484
+ key_states = self._shape(self.k_proj(hidden_states), -1, batch_size)
485
+ value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size)
486
+
487
+ proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
488
+ query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape)
489
+ key_states = key_states.view(*proj_shape)
490
+ value_states = value_states.view(*proj_shape)
491
+
492
+ source_len = key_states.size(1)
493
+
494
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
495
+
496
+ if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
497
+ raise ValueError(
498
+ f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
499
+ f" {attn_weights.size()}"
500
+ )
501
+
502
+ if attention_mask is not None:
503
+ if attention_mask.size() != (batch_size, 1, target_len, source_len):
504
+ raise ValueError(
505
+ f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
506
+ f" {attention_mask.size()}"
507
+ )
508
+ if attention_mask.dtype == torch.bool:
509
+ attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
510
+ attention_mask, -torch.inf
511
+ )
512
+ attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
513
+ attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)
514
+
515
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
516
+
517
+ if output_attentions:
518
+ # this operation is a bit awkward, but it's required to
519
+ # make sure that attn_weights keeps its gradient.
520
+ # In order to do so, attn_weights have to be reshaped
521
+ # twice and have to be reused in the following
522
+ attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
523
+ attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
524
+ else:
525
+ attn_weights_reshaped = None
526
+
527
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
528
+
529
+ attn_output = torch.bmm(attn_probs, value_states)
530
+
531
+ if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim):
532
+ raise ValueError(
533
+ f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is"
534
+ f" {attn_output.size()}"
535
+ )
536
+
537
+ attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim)
538
+ attn_output = attn_output.transpose(1, 2)
539
+ attn_output = attn_output.reshape(batch_size, target_len, embed_dim)
540
+
541
+ attn_output = self.out_proj(attn_output)
542
+
543
+ return attn_output, attn_weights_reshaped
544
+
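+ # Illustrative sketch: DetrAttention adds position embeddings to the queries and keys only, while
+ # the values stay position-free (token count and dimensions below are assumptions):
+ #
+ #     attn = DetrAttention(embed_dim=256, num_heads=8)
+ #     hidden = torch.randn(2, 850, 256)  # flattened H*W feature tokens
+ #     pos = torch.randn(2, 850, 256)     # sine or learned position encodings
+ #     out, weights = attn(hidden_states=hidden, object_queries=pos, output_attentions=True)
+ #     # out.shape == (2, 850, 256); weights.shape == (2, 8, 850, 850)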
545
+
546
+ class DetrEncoderLayer(nn.Module):
547
+ def __init__(self, config: DetrConfig):
548
+ super().__init__()
549
+ self.embed_dim = config.d_model
550
+ self.self_attn = DetrAttention(
551
+ embed_dim=self.embed_dim,
552
+ num_heads=config.encoder_attention_heads,
553
+ dropout=config.attention_dropout,
554
+ )
555
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
556
+ self.dropout = config.dropout
557
+ self.activation_fn = ACT2FN[config.activation_function]
558
+ self.activation_dropout = config.activation_dropout
559
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
560
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
561
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
562
+
563
+ def forward(
564
+ self,
565
+ hidden_states: torch.Tensor,
566
+ attention_mask: torch.Tensor,
567
+ object_queries: Optional[torch.Tensor] = None,
568
+ output_attentions: bool = False,
569
+ ):
570
+ """
571
+ Args:
572
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
573
+ attention_mask (`torch.FloatTensor`): attention mask of size
574
+ `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
575
+ values.
576
+ object_queries (`torch.FloatTensor`, *optional*):
577
+ Object queries (also called content embeddings), to be added to the hidden states.
578
+ output_attentions (`bool`, *optional*):
579
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
580
+ returned tensors for more detail.
581
+ """
582
+ residual = hidden_states
583
+ hidden_states, attn_weights = self.self_attn(
584
+ hidden_states=hidden_states,
585
+ attention_mask=attention_mask,
586
+ object_queries=object_queries,
587
+ output_attentions=output_attentions,
588
+ )
589
+
590
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
591
+ hidden_states = residual + hidden_states
592
+ hidden_states = self.self_attn_layer_norm(hidden_states)
593
+
594
+ residual = hidden_states
595
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
596
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
597
+
598
+ hidden_states = self.fc2(hidden_states)
599
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
600
+
601
+ hidden_states = residual + hidden_states
602
+ hidden_states = self.final_layer_norm(hidden_states)
603
+
604
+ if self.training:
605
+ if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
606
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
607
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
608
+
609
+ outputs = (hidden_states,)
610
+
611
+ if output_attentions:
612
+ outputs += (attn_weights,)
613
+
614
+ return outputs
615
+
616
+
617
+ class DetrDecoderLayer(GradientCheckpointingLayer):
618
+ def __init__(self, config: DetrConfig):
619
+ super().__init__()
620
+ self.embed_dim = config.d_model
621
+
622
+ self.self_attn = DetrAttention(
623
+ embed_dim=self.embed_dim,
624
+ num_heads=config.decoder_attention_heads,
625
+ dropout=config.attention_dropout,
626
+ )
627
+ self.dropout = config.dropout
628
+ self.activation_fn = ACT2FN[config.activation_function]
629
+ self.activation_dropout = config.activation_dropout
630
+
631
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
632
+ self.encoder_attn = DetrAttention(
633
+ self.embed_dim,
634
+ config.decoder_attention_heads,
635
+ dropout=config.attention_dropout,
636
+ )
637
+ self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
638
+ self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
639
+ self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
640
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
641
+
642
+ def forward(
643
+ self,
644
+ hidden_states: torch.Tensor,
645
+ attention_mask: Optional[torch.Tensor] = None,
646
+ object_queries: Optional[torch.Tensor] = None,
647
+ query_position_embeddings: Optional[torch.Tensor] = None,
648
+ encoder_hidden_states: Optional[torch.Tensor] = None,
649
+ encoder_attention_mask: Optional[torch.Tensor] = None,
650
+ output_attentions: Optional[bool] = False,
651
+ ):
652
+ """
653
+ Args:
654
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
655
+ attention_mask (`torch.FloatTensor`): attention mask of size
656
+ `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
657
+ values.
658
+ object_queries (`torch.FloatTensor`, *optional*):
659
+ object_queries that are added to the hidden states
660
+ in the cross-attention layer.
661
+ query_position_embeddings (`torch.FloatTensor`, *optional*):
662
+ position embeddings that are added to the queries and keys
663
+ in the self-attention layer.
664
+ encoder_hidden_states (`torch.FloatTensor`):
665
+ cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
666
+ encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
667
+ `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
668
+ values.
669
+ output_attentions (`bool`, *optional*):
670
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
671
+ returned tensors for more detail.
672
+ """
673
+ residual = hidden_states
674
+
675
+ # Self Attention
676
+ hidden_states, self_attn_weights = self.self_attn(
677
+ hidden_states=hidden_states,
678
+ object_queries=query_position_embeddings,
679
+ attention_mask=attention_mask,
680
+ output_attentions=output_attentions,
681
+ )
682
+
683
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
684
+ hidden_states = residual + hidden_states
685
+ hidden_states = self.self_attn_layer_norm(hidden_states)
686
+
687
+ # Cross-Attention Block
688
+ cross_attn_weights = None
689
+ if encoder_hidden_states is not None:
690
+ residual = hidden_states
691
+
692
+ hidden_states, cross_attn_weights = self.encoder_attn(
693
+ hidden_states=hidden_states,
694
+ object_queries=query_position_embeddings,
695
+ key_value_states=encoder_hidden_states,
696
+ attention_mask=encoder_attention_mask,
697
+ spatial_position_embeddings=object_queries,
698
+ output_attentions=output_attentions,
699
+ )
700
+
701
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
702
+ hidden_states = residual + hidden_states
703
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
704
+
705
+ # Fully Connected
706
+ residual = hidden_states
707
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
708
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
709
+ hidden_states = self.fc2(hidden_states)
710
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
711
+ hidden_states = residual + hidden_states
712
+ hidden_states = self.final_layer_norm(hidden_states)
713
+
714
+ outputs = (hidden_states,)
715
+
716
+ if output_attentions:
717
+ outputs += (self_attn_weights, cross_attn_weights)
718
+
719
+ return outputs
720
+
721
+
722
+ @auto_docstring
723
+ class DetrPreTrainedModel(PreTrainedModel):
724
+ config: DetrConfig
725
+ base_model_prefix = "model"
726
+ main_input_name = "pixel_values"
727
+ _no_split_modules = [r"DetrConvEncoder", r"DetrEncoderLayer", r"DetrDecoderLayer"]
728
+
729
+ def _init_weights(self, module):
730
+ std = self.config.init_std
731
+ xavier_std = self.config.init_xavier_std
732
+
733
+ if isinstance(module, DetrMHAttentionMap):
734
+ nn.init.zeros_(module.k_linear.bias)
735
+ nn.init.zeros_(module.q_linear.bias)
736
+ nn.init.xavier_uniform_(module.k_linear.weight, gain=xavier_std)
737
+ nn.init.xavier_uniform_(module.q_linear.weight, gain=xavier_std)
738
+ elif isinstance(module, DetrLearnedPositionEmbedding):
739
+ nn.init.uniform_(module.row_embeddings.weight)
740
+ nn.init.uniform_(module.column_embeddings.weight)
741
+ if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
742
+ # Slightly different from the TF version which uses truncated_normal for initialization
743
+ # cf https://github.com/pytorch/pytorch/pull/5617
744
+ module.weight.data.normal_(mean=0.0, std=std)
745
+ if module.bias is not None:
746
+ module.bias.data.zero_()
747
+ elif isinstance(module, nn.Embedding):
748
+ module.weight.data.normal_(mean=0.0, std=std)
749
+ if module.padding_idx is not None:
750
+ module.weight.data[module.padding_idx].zero_()
751
+
752
+
753
+ class DetrEncoder(DetrPreTrainedModel):
754
+ """
755
+ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
756
+ [`DetrEncoderLayer`].
757
+
758
+ The encoder updates the flattened feature map through multiple self-attention layers.
759
+
760
+ Small tweak for DETR:
761
+
762
+ - object_queries are added to the forward pass.
763
+
764
+ Args:
765
+ config: DetrConfig
766
+ """
767
+
768
+ def __init__(self, config: DetrConfig):
769
+ super().__init__(config)
770
+
771
+ self.dropout = config.dropout
772
+ self.layerdrop = config.encoder_layerdrop
773
+
774
+ self.layers = nn.ModuleList([DetrEncoderLayer(config) for _ in range(config.encoder_layers)])
775
+
776
+ # in the original DETR, no layernorm is used at the end of the encoder, as "normalize_before" is set to False by default
777
+
778
+ # Initialize weights and apply final processing
779
+ self.post_init()
780
+
781
+ def forward(
782
+ self,
783
+ inputs_embeds=None,
784
+ attention_mask=None,
785
+ object_queries=None,
786
+ output_attentions=None,
787
+ output_hidden_states=None,
788
+ return_dict=None,
789
+ ):
790
+ r"""
791
+ Args:
792
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
793
+ Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.
794
+
795
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
796
+ Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`:
797
+
798
+ - 1 for pixel features that are real (i.e. **not masked**),
799
+ - 0 for pixel features that are padding (i.e. **masked**).
800
+
801
+ [What are attention masks?](../glossary#attention-mask)
802
+
803
+ object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
804
+ Object queries that are added to the queries in each self-attention layer.
805
+
806
+ output_attentions (`bool`, *optional*):
807
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
808
+ returned tensors for more detail.
809
+ output_hidden_states (`bool`, *optional*):
810
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
811
+ for more detail.
812
+ return_dict (`bool`, *optional*):
813
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
814
+ """
815
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
816
+ output_hidden_states = (
817
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
818
+ )
819
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
820
+
821
+ hidden_states = inputs_embeds
822
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
823
+
824
+ # expand attention_mask
825
+ if attention_mask is not None:
826
+ # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
827
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
828
+
829
+ encoder_states = () if output_hidden_states else None
830
+ all_attentions = () if output_attentions else None
831
+ for i, encoder_layer in enumerate(self.layers):
832
+ if output_hidden_states:
833
+ encoder_states = encoder_states + (hidden_states,)
834
+ # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
835
+ to_drop = False
836
+ if self.training:
837
+ dropout_probability = torch.rand([])
838
+ if dropout_probability < self.layerdrop: # skip the layer
839
+ to_drop = True
840
+
841
+ if to_drop:
842
+ layer_outputs = (None, None)
843
+ else:
844
+ # we add object_queries as extra input to the encoder_layer
845
+ layer_outputs = encoder_layer(
846
+ hidden_states,
847
+ attention_mask,
848
+ object_queries=object_queries,
849
+ output_attentions=output_attentions,
850
+ )
851
+
852
+ hidden_states = layer_outputs[0]
853
+
854
+ if output_attentions:
855
+ all_attentions = all_attentions + (layer_outputs[1],)
856
+
857
+ if output_hidden_states:
858
+ encoder_states = encoder_states + (hidden_states,)
859
+
860
+ if not return_dict:
861
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
862
+ return BaseModelOutput(
863
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
864
+ )
865
+
866
+
867
+ class DetrDecoder(DetrPreTrainedModel):
868
+ """
869
+ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DetrDecoderLayer`].
870
+
871
+ The decoder updates the query embeddings through multiple self-attention and cross-attention layers.
872
+
873
+ Some small tweaks for DETR:
874
+
875
+ - object_queries and query_position_embeddings are added to the forward pass.
876
+ - if self.config.auxiliary_loss is set to True, also returns a stack of activations from all decoding layers.
877
+
878
+ Args:
879
+ config: DetrConfig
880
+ """
881
+
882
+ def __init__(self, config: DetrConfig):
883
+ super().__init__(config)
884
+ self.dropout = config.dropout
885
+ self.layerdrop = config.decoder_layerdrop
886
+
887
+ self.layers = nn.ModuleList([DetrDecoderLayer(config) for _ in range(config.decoder_layers)])
888
+ # in DETR, the decoder uses layernorm after the last decoder layer output
889
+ self.layernorm = nn.LayerNorm(config.d_model)
890
+
891
+ self.gradient_checkpointing = False
892
+ # Initialize weights and apply final processing
893
+ self.post_init()
894
+
895
+ def forward(
896
+ self,
897
+ inputs_embeds=None,
898
+ attention_mask=None,
899
+ encoder_hidden_states=None,
900
+ encoder_attention_mask=None,
901
+ object_queries=None,
902
+ query_position_embeddings=None,
903
+ output_attentions=None,
904
+ output_hidden_states=None,
905
+ return_dict=None,
906
+ ):
907
+ r"""
908
+ Args:
909
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
910
+ The query embeddings that are passed into the decoder.
911
+
912
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
913
+ Mask to avoid performing attention on certain queries. Mask values selected in `[0, 1]`:
914
+
915
+ - 1 for queries that are **not masked**,
916
+ - 0 for queries that are **masked**.
917
+
918
+ [What are attention masks?](../glossary#attention-mask)
919
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
920
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
921
+ of the decoder.
922
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
923
+ Mask to avoid performing cross-attention on padding pixel_values of the encoder. Mask values selected
924
+ in `[0, 1]`:
925
+
926
+ - 1 for pixels that are real (i.e. **not masked**),
927
+ - 0 for pixels that are padding (i.e. **masked**).
928
+
929
+ object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
930
+ Object queries that are added to the queries and keys in each cross-attention layer.
931
+ query_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
932
+ Position embeddings that are added to the queries and keys in each self-attention layer.
933
+
934
+ output_attentions (`bool`, *optional*):
935
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
936
+ returned tensors for more detail.
937
+ output_hidden_states (`bool`, *optional*):
938
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
939
+ for more detail.
940
+ return_dict (`bool`, *optional*):
941
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
942
+ """
943
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
944
+ output_hidden_states = (
945
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
946
+ )
947
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
948
+
949
+ if inputs_embeds is not None:
950
+ hidden_states = inputs_embeds
951
+ input_shape = inputs_embeds.size()[:-1]
952
+
953
+ combined_attention_mask = None
954
+
955
+ if attention_mask is not None and combined_attention_mask is not None:
956
+ # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
957
+ combined_attention_mask = combined_attention_mask + _prepare_4d_attention_mask(
958
+ attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
959
+ )
960
+
961
+ # expand encoder attention mask
962
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
963
+ # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
964
+ encoder_attention_mask = _prepare_4d_attention_mask(
965
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
966
+ )
967
+
968
+ # optional intermediate hidden states
969
+ intermediate = () if self.config.auxiliary_loss else None
970
+
971
+ # decoder layers
972
+ all_hidden_states = () if output_hidden_states else None
973
+ all_self_attns = () if output_attentions else None
974
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
975
+
976
+ for idx, decoder_layer in enumerate(self.layers):
977
+ # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
978
+ if output_hidden_states:
979
+ all_hidden_states += (hidden_states,)
980
+ if self.training:
981
+ dropout_probability = torch.rand([])
982
+ if dropout_probability < self.layerdrop:
983
+ continue
984
+
985
+ layer_outputs = decoder_layer(
986
+ hidden_states,
987
+ combined_attention_mask,
988
+ object_queries,
989
+ query_position_embeddings,
990
+ encoder_hidden_states, # as a positional argument for gradient checkpointing
991
+ encoder_attention_mask=encoder_attention_mask,
992
+ output_attentions=output_attentions,
993
+ )
994
+
995
+ hidden_states = layer_outputs[0]
996
+
997
+ if self.config.auxiliary_loss:
998
+ hidden_states = self.layernorm(hidden_states)
999
+ intermediate += (hidden_states,)
1000
+
1001
+ if output_attentions:
1002
+ all_self_attns += (layer_outputs[1],)
1003
+
1004
+ if encoder_hidden_states is not None:
1005
+ all_cross_attentions += (layer_outputs[2],)
1006
+
1007
+ # finally, apply layernorm
1008
+ hidden_states = self.layernorm(hidden_states)
1009
+
1010
+ # add hidden states from the last decoder layer
1011
+ if output_hidden_states:
1012
+ all_hidden_states += (hidden_states,)
1013
+
1014
+ # stack intermediate decoder activations
1015
+ if self.config.auxiliary_loss:
1016
+ intermediate = torch.stack(intermediate)
1017
+
1018
+ if not return_dict:
1019
+ return tuple(
1020
+ v
1021
+ for v in [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions, intermediate]
1022
+ if v is not None
1023
+ )
1024
+ return DetrDecoderOutput(
1025
+ last_hidden_state=hidden_states,
1026
+ hidden_states=all_hidden_states,
1027
+ attentions=all_self_attns,
1028
+ cross_attentions=all_cross_attentions,
1029
+ intermediate_hidden_states=intermediate,
1030
+ )
1031
+
1032
+
1033
+ @auto_docstring(
1034
+ custom_intro="""
1035
+ The bare DETR Model (consisting of a backbone and encoder-decoder Transformer) outputting raw hidden-states without
1036
+ any specific head on top.
1037
+ """
1038
+ )
1039
+ class DetrModel(DetrPreTrainedModel):
1040
+ def __init__(self, config: DetrConfig):
1041
+ super().__init__(config)
1042
+
1043
+ # Create backbone + positional encoding
1044
+ backbone = DetrConvEncoder(config)
1045
+ object_queries = build_position_encoding(config)
1046
+ self.backbone = DetrConvModel(backbone, object_queries)
1047
+
1048
+ # Create projection layer
1049
+ self.input_projection = nn.Conv2d(backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1)
1050
+
1051
+ self.query_position_embeddings = nn.Embedding(config.num_queries, config.d_model)
1052
+
1053
+ self.encoder = DetrEncoder(config)
1054
+ self.decoder = DetrDecoder(config)
1055
+
1056
+ # Initialize weights and apply final processing
1057
+ self.post_init()
1058
+
1059
+ def get_encoder(self):
1060
+ return self.encoder
1061
+
1062
+ def freeze_backbone(self):
1063
+ for name, param in self.backbone.conv_encoder.model.named_parameters():
1064
+ param.requires_grad_(False)
1065
+
1066
+ def unfreeze_backbone(self):
1067
+ for name, param in self.backbone.conv_encoder.model.named_parameters():
1068
+ param.requires_grad_(True)
1069
+
1070
+ @auto_docstring
1071
+ def forward(
1072
+ self,
1073
+ pixel_values: torch.FloatTensor,
1074
+ pixel_mask: Optional[torch.LongTensor] = None,
1075
+ decoder_attention_mask: Optional[torch.FloatTensor] = None,
1076
+ encoder_outputs: Optional[torch.FloatTensor] = None,
1077
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1078
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1079
+ output_attentions: Optional[bool] = None,
1080
+ output_hidden_states: Optional[bool] = None,
1081
+ return_dict: Optional[bool] = None,
1082
+ ) -> Union[tuple[torch.FloatTensor], DetrModelOutput]:
1083
+ r"""
1084
+ decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
1085
+ Not used by default. Can be used to mask object queries.
1086
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1087
+ Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
1088
+ can choose to directly pass a flattened representation of an image.
1089
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
1090
+ Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
1091
+ embedded representation.
1092
+
1093
+ Examples:
1094
+
1095
+ ```python
1096
+ >>> from transformers import AutoImageProcessor, DetrModel
1097
+ >>> from PIL import Image
1098
+ >>> import requests
1099
+
1100
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1101
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1102
+
1103
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
1104
+ >>> model = DetrModel.from_pretrained("facebook/detr-resnet-50")
1105
+
1106
+ >>> # prepare image for the model
1107
+ >>> inputs = image_processor(images=image, return_tensors="pt")
1108
+
1109
+ >>> # forward pass
1110
+ >>> outputs = model(**inputs)
1111
+
1112
+ >>> # the last hidden states are the final query embeddings of the Transformer decoder
1113
+ >>> # these are of shape (batch_size, num_queries, hidden_size)
1114
+ >>> last_hidden_states = outputs.last_hidden_state
1115
+ >>> list(last_hidden_states.shape)
1116
+ [1, 100, 256]
1117
+ ```"""
1118
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1119
+ output_hidden_states = (
1120
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1121
+ )
1122
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1123
+
1124
+ batch_size, num_channels, height, width = pixel_values.shape
1125
+ device = pixel_values.device
1126
+
1127
+ if pixel_mask is None:
1128
+ pixel_mask = torch.ones((batch_size, height, width), device=device)
1129
+
1130
+ # First, send pixel_values + pixel_mask through the backbone to obtain the features
1131
+ # pixel_values should be of shape (batch_size, num_channels, height, width)
1132
+ # pixel_mask should be of shape (batch_size, height, width)
1133
+ features, object_queries_list = self.backbone(pixel_values, pixel_mask)
1134
+
1135
+ # get final feature map and downsampled mask
1136
+ feature_map, mask = features[-1]
1137
+
1138
+ if mask is None:
1139
+ raise ValueError("Backbone does not return downsampled pixel mask")
1140
+
1141
+ # Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
1142
+ projected_feature_map = self.input_projection(feature_map)
1143
+
1144
+ # Third, flatten the feature map + position embeddings of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
1145
+ # In other words, turn their shape into (batch_size, sequence_length, hidden_size)
1146
+ flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
1147
+ object_queries = object_queries_list[-1].flatten(2).permute(0, 2, 1)
1148
+
1149
+ flattened_mask = mask.flatten(1)
1150
+
1151
+ # Fourth, send flattened_features + flattened_mask + position embeddings through the encoder
1152
+ # flattened_features is a Tensor of shape (batch_size, height*width, hidden_size)
1153
+ # flattened_mask is a Tensor of shape (batch_size, height*width)
1154
+ if encoder_outputs is None:
1155
+ encoder_outputs = self.encoder(
1156
+ inputs_embeds=flattened_features,
1157
+ attention_mask=flattened_mask,
1158
+ object_queries=object_queries,
1159
+ output_attentions=output_attentions,
1160
+ output_hidden_states=output_hidden_states,
1161
+ return_dict=return_dict,
1162
+ )
1163
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
1164
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1165
+ encoder_outputs = BaseModelOutput(
1166
+ last_hidden_state=encoder_outputs[0],
1167
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1168
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1169
+ )
1170
+
1171
+ # Fifth, send query embeddings + object_queries through the decoder (which is conditioned on the encoder output)
1172
+ query_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(batch_size, 1, 1)
1173
+ queries = torch.zeros_like(query_position_embeddings)
1174
+
1175
+ # decoder outputs consists of (dec_features, dec_hidden, dec_attn)
1176
+ decoder_outputs = self.decoder(
1177
+ inputs_embeds=queries,
1178
+ attention_mask=None,
1179
+ object_queries=object_queries,
1180
+ query_position_embeddings=query_position_embeddings,
1181
+ encoder_hidden_states=encoder_outputs[0],
1182
+ encoder_attention_mask=flattened_mask,
1183
+ output_attentions=output_attentions,
1184
+ output_hidden_states=output_hidden_states,
1185
+ return_dict=return_dict,
1186
+ )
1187
+
1188
+ if not return_dict:
1189
+ return decoder_outputs + encoder_outputs
1190
+
1191
+ return DetrModelOutput(
1192
+ last_hidden_state=decoder_outputs.last_hidden_state,
1193
+ decoder_hidden_states=decoder_outputs.hidden_states,
1194
+ decoder_attentions=decoder_outputs.attentions,
1195
+ cross_attentions=decoder_outputs.cross_attentions,
1196
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1197
+ encoder_hidden_states=encoder_outputs.hidden_states,
1198
+ encoder_attentions=encoder_outputs.attentions,
1199
+ intermediate_hidden_states=decoder_outputs.intermediate_hidden_states,
1200
+ )
1201
+
1202
+
1203
+ # taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
1204
+ class DetrMLPPredictionHead(nn.Module):
1205
+ """
1206
+ Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
1207
+ height and width of a bounding box w.r.t. an image.
1208
+
1209
+ Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
1210
+
1211
+ """
1212
+
1213
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
1214
+ super().__init__()
1215
+ self.num_layers = num_layers
1216
+ h = [hidden_dim] * (num_layers - 1)
1217
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
1218
+
1219
+ def forward(self, x):
1220
+ for i, layer in enumerate(self.layers):
1221
+ x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
1222
+ return x
1223
+
1224
+
1225
+ @auto_docstring(
1226
+ custom_intro="""
1227
+ DETR Model (consisting of a backbone and encoder-decoder Transformer) with object detection heads on top, for tasks
1228
+ such as COCO detection.
1229
+ """
1230
+ )
1231
+ class DetrForObjectDetection(DetrPreTrainedModel):
1232
+ def __init__(self, config: DetrConfig):
1233
+ super().__init__(config)
1234
+
1235
+ # DETR encoder-decoder model
1236
+ self.model = DetrModel(config)
1237
+
1238
+ # Object detection heads
1239
+ self.class_labels_classifier = nn.Linear(
1240
+ config.d_model, config.num_labels + 1
1241
+ ) # We add one for the "no object" class
1242
+ self.bbox_predictor = DetrMLPPredictionHead(
1243
+ input_dim=config.d_model, hidden_dim=config.d_model, output_dim=4, num_layers=3
1244
+ )
1245
+
1246
+ # Initialize weights and apply final processing
1247
+ self.post_init()
1248
+
1249
+ @auto_docstring
1250
+ def forward(
1251
+ self,
1252
+ pixel_values: torch.FloatTensor,
1253
+ pixel_mask: Optional[torch.LongTensor] = None,
1254
+ decoder_attention_mask: Optional[torch.FloatTensor] = None,
1255
+ encoder_outputs: Optional[torch.FloatTensor] = None,
1256
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1257
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1258
+ labels: Optional[list[dict]] = None,
1259
+ output_attentions: Optional[bool] = None,
1260
+ output_hidden_states: Optional[bool] = None,
1261
+ return_dict: Optional[bool] = None,
1262
+ ) -> Union[tuple[torch.FloatTensor], DetrObjectDetectionOutput]:
1263
+ r"""
1264
+ decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
1265
+ Not used by default. Can be used to mask object queries.
1266
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1267
+ Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
1268
+ can choose to directly pass a flattened representation of an image.
1269
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
1270
+ Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
1271
+ embedded representation.
1272
+ labels (`list[Dict]` of len `(batch_size,)`, *optional*):
1273
+ Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
1274
+ following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
1275
+ respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
1276
+ in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.
1277
+
1278
+ Examples:
1279
+
1280
+ ```python
1281
+ >>> from transformers import AutoImageProcessor, DetrForObjectDetection
1282
+ >>> import torch
1283
+ >>> from PIL import Image
1284
+ >>> import requests
1285
+
1286
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1287
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1288
+
1289
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
1290
+ >>> model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
1291
+
1292
+ >>> inputs = image_processor(images=image, return_tensors="pt")
1293
+ >>> outputs = model(**inputs)
1294
+
1295
+ >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
1296
+ >>> target_sizes = torch.tensor([image.size[::-1]])
1297
+ >>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[
1298
+ ... 0
1299
+ ... ]
1300
+
1301
+ >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
1302
+ ... box = [round(i, 2) for i in box.tolist()]
1303
+ ... print(
1304
+ ... f"Detected {model.config.id2label[label.item()]} with confidence "
1305
+ ... f"{round(score.item(), 3)} at location {box}"
1306
+ ... )
1307
+ Detected remote with confidence 0.998 at location [40.16, 70.81, 175.55, 117.98]
1308
+ Detected remote with confidence 0.996 at location [333.24, 72.55, 368.33, 187.66]
1309
+ Detected couch with confidence 0.995 at location [-0.02, 1.15, 639.73, 473.76]
1310
+ Detected cat with confidence 0.999 at location [13.24, 52.05, 314.02, 470.93]
1311
+ Detected cat with confidence 0.999 at location [345.4, 23.85, 640.37, 368.72]
1312
+ ```"""
1313
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1314
+
1315
+ # First, send images through the DETR base model to obtain encoder + decoder outputs
1316
+ outputs = self.model(
1317
+ pixel_values,
1318
+ pixel_mask=pixel_mask,
1319
+ decoder_attention_mask=decoder_attention_mask,
1320
+ encoder_outputs=encoder_outputs,
1321
+ inputs_embeds=inputs_embeds,
1322
+ decoder_inputs_embeds=decoder_inputs_embeds,
1323
+ output_attentions=output_attentions,
1324
+ output_hidden_states=output_hidden_states,
1325
+ return_dict=return_dict,
1326
+ )
1327
+
1328
+ sequence_output = outputs[0]
1329
+
1330
+ # class logits + predicted bounding boxes
1331
+ logits = self.class_labels_classifier(sequence_output)
1332
+ pred_boxes = self.bbox_predictor(sequence_output).sigmoid()
1333
+
1334
+ loss, loss_dict, auxiliary_outputs = None, None, None
1335
+ if labels is not None:
1336
+ outputs_class, outputs_coord = None, None
1337
+ if self.config.auxiliary_loss:
1338
+ intermediate = outputs.intermediate_hidden_states if return_dict else outputs[4]
1339
+ outputs_class = self.class_labels_classifier(intermediate)
1340
+ outputs_coord = self.bbox_predictor(intermediate).sigmoid()
1341
+ loss, loss_dict, auxiliary_outputs = self.loss_function(
1342
+ logits, labels, self.device, pred_boxes, self.config, outputs_class, outputs_coord
1343
+ )
1344
+
1345
+ if not return_dict:
1346
+ if auxiliary_outputs is not None:
1347
+ output = (logits, pred_boxes) + auxiliary_outputs + outputs
1348
+ else:
1349
+ output = (logits, pred_boxes) + outputs
1350
+ return ((loss, loss_dict) + output) if loss is not None else output
1351
+
1352
+ return DetrObjectDetectionOutput(
1353
+ loss=loss,
1354
+ loss_dict=loss_dict,
1355
+ logits=logits,
1356
+ pred_boxes=pred_boxes,
1357
+ auxiliary_outputs=auxiliary_outputs,
1358
+ last_hidden_state=outputs.last_hidden_state,
1359
+ decoder_hidden_states=outputs.decoder_hidden_states,
1360
+ decoder_attentions=outputs.decoder_attentions,
1361
+ cross_attentions=outputs.cross_attentions,
1362
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1363
+ encoder_hidden_states=outputs.encoder_hidden_states,
1364
+ encoder_attentions=outputs.encoder_attentions,
1365
+ )
1366
+
1367
+
1368
+ @auto_docstring(
1369
+ custom_intro="""
1370
+ DETR Model (consisting of a backbone and encoder-decoder Transformer) with a segmentation head on top, for tasks
1371
+ such as COCO panoptic.
1372
+ """
1373
+ )
1374
+ class DetrForSegmentation(DetrPreTrainedModel):
1375
+ def __init__(self, config: DetrConfig):
1376
+ super().__init__(config)
1377
+
1378
+ # object detection model
1379
+ self.detr = DetrForObjectDetection(config)
1380
+
1381
+ # segmentation head
1382
+ hidden_size, number_of_heads = config.d_model, config.encoder_attention_heads
1383
+ intermediate_channel_sizes = self.detr.model.backbone.conv_encoder.intermediate_channel_sizes
1384
+
1385
+ self.mask_head = DetrMaskHeadSmallConv(
1386
+ hidden_size + number_of_heads, intermediate_channel_sizes[::-1][-3:], hidden_size
1387
+ )
1388
+
1389
+ self.bbox_attention = DetrMHAttentionMap(
1390
+ hidden_size, hidden_size, number_of_heads, dropout=0.0, std=config.init_xavier_std
1391
+ )
1392
+ # Initialize weights and apply final processing
1393
+ self.post_init()
1394
+
1395
+ @auto_docstring
1396
+ def forward(
1397
+ self,
1398
+ pixel_values: torch.FloatTensor,
1399
+ pixel_mask: Optional[torch.LongTensor] = None,
1400
+ decoder_attention_mask: Optional[torch.FloatTensor] = None,
1401
+ encoder_outputs: Optional[torch.FloatTensor] = None,
1402
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1403
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1404
+ labels: Optional[list[dict]] = None,
1405
+ output_attentions: Optional[bool] = None,
1406
+ output_hidden_states: Optional[bool] = None,
1407
+ return_dict: Optional[bool] = None,
1408
+ ) -> Union[tuple[torch.FloatTensor], DetrSegmentationOutput]:
1409
+ r"""
1410
+ decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
1411
+ Not used by default. Can be used to mask object queries.
1412
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1413
+ Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
1414
+ can choose to directly pass a flattened representation of an image.
1415
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
1416
+ Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
1417
+ embedded representation.
1418
+ labels (`list[Dict]` of len `(batch_size,)`, *optional*):
1419
+ Labels for computing the bipartite matching loss, DICE/F-1 loss and Focal loss. List of dicts, each
1420
+ dictionary containing at least the following 3 keys: 'class_labels', 'boxes' and 'masks' (the class labels,
1421
+ bounding boxes and segmentation masks of an image in the batch respectively). The class labels themselves
1422
+ should be a `torch.LongTensor` of len `(number of bounding boxes in the image,)`, the boxes a
1423
+ `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)` and the masks a
1424
+ `torch.FloatTensor` of shape `(number of bounding boxes in the image, height, width)`.
1425
+
1426
+ Examples:
1427
+
1428
+ ```python
1429
+ >>> import io
1430
+ >>> import requests
1431
+ >>> from PIL import Image
1432
+ >>> import torch
1433
+ >>> import numpy
1434
+
1435
+ >>> from transformers import AutoImageProcessor, DetrForSegmentation
1436
+ >>> from transformers.image_transforms import rgb_to_id
1437
+
1438
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1439
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1440
+
1441
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50-panoptic")
1442
+ >>> model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50-panoptic")
1443
+
1444
+ >>> # prepare image for the model
1445
+ >>> inputs = image_processor(images=image, return_tensors="pt")
1446
+
1447
+ >>> # forward pass
1448
+ >>> outputs = model(**inputs)
1449
+
1450
+ >>> # Use the `post_process_panoptic_segmentation` method of the `image_processor` to retrieve post-processed panoptic segmentation maps
1451
+ >>> # Segmentation results are returned as a list of dictionaries
1452
+ >>> result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[(300, 500)])
1453
+
1454
+ >>> # A tensor of shape (height, width) where each value denotes a segment id, filled with -1 if no segment is found
1455
+ >>> panoptic_seg = result[0]["segmentation"]
1456
+ >>> # Get prediction score and segment_id to class_id mapping of each segment
1457
+ >>> panoptic_segments_info = result[0]["segments_info"]
1458
+ ```"""
1459
+
1460
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1461
+
1462
+ batch_size, num_channels, height, width = pixel_values.shape
1463
+ device = pixel_values.device
1464
+
1465
+ if pixel_mask is None:
1466
+ pixel_mask = torch.ones((batch_size, height, width), device=device)
1467
+
1468
+ # First, get list of feature maps and position embeddings
1469
+ features, object_queries_list = self.detr.model.backbone(pixel_values, pixel_mask=pixel_mask)
1470
+
1471
+ # Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
1472
+ feature_map, mask = features[-1]
1473
+ batch_size, num_channels, height, width = feature_map.shape
1474
+ projected_feature_map = self.detr.model.input_projection(feature_map)
1475
+
1476
+ # Third, flatten the feature map + position embeddings of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
1477
+ # In other words, turn their shape into (batch_size, sequence_length, hidden_size)
1478
+ flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
1479
+ object_queries = object_queries_list[-1].flatten(2).permute(0, 2, 1)
1480
+
1481
+ flattened_mask = mask.flatten(1)
1482
+
1483
+ # Fourth, send flattened_features + flattened_mask + position embeddings through the encoder
1484
+ # flattened_features is a Tensor of shape (batch_size, height*width, hidden_size)
1485
+ # flattened_mask is a Tensor of shape (batch_size, height*width)
1486
+ if encoder_outputs is None:
1487
+ encoder_outputs = self.detr.model.encoder(
1488
+ inputs_embeds=flattened_features,
1489
+ attention_mask=flattened_mask,
1490
+ object_queries=object_queries,
1491
+ output_attentions=output_attentions,
1492
+ output_hidden_states=output_hidden_states,
1493
+ return_dict=return_dict,
1494
+ )
1495
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
1496
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1497
+ encoder_outputs = BaseModelOutput(
1498
+ last_hidden_state=encoder_outputs[0],
1499
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1500
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1501
+ )
1502
+
1503
+ # Fifth, send query embeddings + position embeddings through the decoder (which is conditioned on the encoder output)
1504
+ query_position_embeddings = self.detr.model.query_position_embeddings.weight.unsqueeze(0).repeat(
1505
+ batch_size, 1, 1
1506
+ )
1507
+ queries = torch.zeros_like(query_position_embeddings)
1508
+
1509
+ # decoder outputs consist of (dec_features, dec_hidden, dec_attn)
1510
+ decoder_outputs = self.detr.model.decoder(
1511
+ inputs_embeds=queries,
1512
+ attention_mask=None,
1513
+ object_queries=object_queries,
1514
+ query_position_embeddings=query_position_embeddings,
1515
+ encoder_hidden_states=encoder_outputs[0],
1516
+ encoder_attention_mask=flattened_mask,
1517
+ output_attentions=output_attentions,
1518
+ output_hidden_states=output_hidden_states,
1519
+ return_dict=return_dict,
1520
+ )
1521
+
1522
+ sequence_output = decoder_outputs[0]
1523
+
1524
+ # Sixth, compute logits, pred_boxes and pred_masks
1525
+ logits = self.detr.class_labels_classifier(sequence_output)
1526
+ pred_boxes = self.detr.bbox_predictor(sequence_output).sigmoid()
1527
+
1528
+ memory = encoder_outputs[0].permute(0, 2, 1).view(batch_size, self.config.d_model, height, width)
1529
+ mask = flattened_mask.view(batch_size, height, width)
1530
+
1531
+ # FIXME h_boxes takes the last one computed, keep this in mind
1532
+ # important: we need to invert the mask, since the original implementation uses the opposite mask convention
1533
+ # bbox_mask is of shape (batch_size, num_queries, number_of_attention_heads in bbox_attention, height/32, width/32)
1534
+ bbox_mask = self.bbox_attention(sequence_output, memory, mask=~mask)
1535
+
1536
+ seg_masks = self.mask_head(projected_feature_map, bbox_mask, [features[2][0], features[1][0], features[0][0]])
1537
+
1538
+ pred_masks = seg_masks.view(batch_size, self.detr.config.num_queries, seg_masks.shape[-2], seg_masks.shape[-1])
1539
+
1540
+ loss, loss_dict, auxiliary_outputs = None, None, None
1541
+ if labels is not None:
1542
+ outputs_class, outputs_coord = None, None
1543
+ if self.config.auxiliary_loss:
1544
+ intermediate = decoder_outputs.intermediate_hidden_states if return_dict else decoder_outputs[-1]
1545
+ outputs_class = self.detr.class_labels_classifier(intermediate)
1546
+ outputs_coord = self.detr.bbox_predictor(intermediate).sigmoid()
1547
+ loss, loss_dict, auxiliary_outputs = self.loss_function(
1548
+ logits, labels, device, pred_boxes, pred_masks, self.config, outputs_class, outputs_coord
1549
+ )
1550
+
1551
+ if not return_dict:
1552
+ if auxiliary_outputs is not None:
1553
+ output = (logits, pred_boxes, pred_masks) + auxiliary_outputs + decoder_outputs + encoder_outputs
1554
+ else:
1555
+ output = (logits, pred_boxes, pred_masks) + decoder_outputs + encoder_outputs
1556
+ return ((loss, loss_dict) + output) if loss is not None else output
1557
+
1558
+ return DetrSegmentationOutput(
1559
+ loss=loss,
1560
+ loss_dict=loss_dict,
1561
+ logits=logits,
1562
+ pred_boxes=pred_boxes,
1563
+ pred_masks=pred_masks,
1564
+ auxiliary_outputs=auxiliary_outputs,
1565
+ last_hidden_state=decoder_outputs.last_hidden_state,
1566
+ decoder_hidden_states=decoder_outputs.hidden_states,
1567
+ decoder_attentions=decoder_outputs.attentions,
1568
+ cross_attentions=decoder_outputs.cross_attentions,
1569
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1570
+ encoder_hidden_states=encoder_outputs.hidden_states,
1571
+ encoder_attentions=encoder_outputs.attentions,
1572
+ )
1573
+
1574
+
1575
+ def _expand(tensor, length: int):
1576
+ return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1)
1577
+
1578
+
1579
+ # taken from https://github.com/facebookresearch/detr/blob/master/models/segmentation.py
1580
+ class DetrMaskHeadSmallConv(nn.Module):
1581
+ """
1582
+ Simple convolutional head, using group norm. Upsampling is done using an FPN approach
1583
+ """
1584
+
1585
+ def __init__(self, dim, fpn_dims, context_dim):
1586
+ super().__init__()
1587
+
1588
+ if dim % 8 != 0:
1589
+ raise ValueError(
1590
+ "The hidden_size + number of attention heads must be divisible by 8 as the number of groups in"
1591
+ " GroupNorm is set to 8"
1592
+ )
1593
+
1594
+ inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64]
1595
+
1596
+ self.lay1 = nn.Conv2d(dim, dim, 3, padding=1)
1597
+ self.gn1 = nn.GroupNorm(8, dim)
1598
+ self.lay2 = nn.Conv2d(dim, inter_dims[1], 3, padding=1)
1599
+ self.gn2 = nn.GroupNorm(min(8, inter_dims[1]), inter_dims[1])
1600
+ self.lay3 = nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1)
1601
+ self.gn3 = nn.GroupNorm(min(8, inter_dims[2]), inter_dims[2])
1602
+ self.lay4 = nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1)
1603
+ self.gn4 = nn.GroupNorm(min(8, inter_dims[3]), inter_dims[3])
1604
+ self.lay5 = nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1)
1605
+ self.gn5 = nn.GroupNorm(min(8, inter_dims[4]), inter_dims[4])
1606
+ self.out_lay = nn.Conv2d(inter_dims[4], 1, 3, padding=1)
1607
+
1608
+ self.dim = dim
1609
+
1610
+ self.adapter1 = nn.Conv2d(fpn_dims[0], inter_dims[1], 1)
1611
+ self.adapter2 = nn.Conv2d(fpn_dims[1], inter_dims[2], 1)
1612
+ self.adapter3 = nn.Conv2d(fpn_dims[2], inter_dims[3], 1)
1613
+
1614
+ for m in self.modules():
1615
+ if isinstance(m, nn.Conv2d):
1616
+ nn.init.kaiming_uniform_(m.weight, a=1)
1617
+ nn.init.constant_(m.bias, 0)
1618
+
1619
+ def forward(self, x: Tensor, bbox_mask: Tensor, fpns: list[Tensor]):
1620
+ # here we concatenate x, the projected feature map, of shape (batch_size, d_model, height/32, width/32) with
1621
+ # the bbox_mask = the attention maps of shape (batch_size, n_queries, n_heads, height/32, width/32).
1622
+ # We expand the projected feature map to match the number of heads.
1623
+ x = torch.cat([_expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1)
1624
+
1625
+ x = self.lay1(x)
1626
+ x = self.gn1(x)
1627
+ x = nn.functional.relu(x)
1628
+ x = self.lay2(x)
1629
+ x = self.gn2(x)
1630
+ x = nn.functional.relu(x)
1631
+
1632
+ cur_fpn = self.adapter1(fpns[0])
1633
+ if cur_fpn.size(0) != x.size(0):
1634
+ cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
1635
+ x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
1636
+ x = self.lay3(x)
1637
+ x = self.gn3(x)
1638
+ x = nn.functional.relu(x)
1639
+
1640
+ cur_fpn = self.adapter2(fpns[1])
1641
+ if cur_fpn.size(0) != x.size(0):
1642
+ cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
1643
+ x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
1644
+ x = self.lay4(x)
1645
+ x = self.gn4(x)
1646
+ x = nn.functional.relu(x)
1647
+
1648
+ cur_fpn = self.adapter3(fpns[2])
1649
+ if cur_fpn.size(0) != x.size(0):
1650
+ cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
1651
+ x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
1652
+ x = self.lay5(x)
1653
+ x = self.gn5(x)
1654
+ x = nn.functional.relu(x)
1655
+
1656
+ x = self.out_lay(x)
1657
+ return x
1658
+
1659
+
1660
+ class DetrMHAttentionMap(nn.Module):
1661
+ """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""
1662
+
1663
+ def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True, std=None):
1664
+ super().__init__()
1665
+ self.num_heads = num_heads
1666
+ self.hidden_dim = hidden_dim
1667
+ self.dropout = nn.Dropout(dropout)
1668
+
1669
+ self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
1670
+ self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
1671
+
1672
+ self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5
1673
+
1674
+ def forward(self, q, k, mask: Optional[Tensor] = None):
1675
+ q = self.q_linear(q)
1676
+ k = nn.functional.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias)
1677
+ queries_per_head = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads)
1678
+ keys_per_head = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1])
1679
+ weights = torch.einsum("bqnc,bnchw->bqnhw", queries_per_head * self.normalize_fact, keys_per_head)
1680
+
1681
+ if mask is not None:
1682
+ weights = weights.masked_fill(mask.unsqueeze(1).unsqueeze(1), torch.finfo(weights.dtype).min)
1683
+ weights = nn.functional.softmax(weights.flatten(2), dim=-1).view(weights.size())
1684
+ weights = self.dropout(weights)
1685
+ return weights
1686
+
1687
+
1688
+ __all__ = [
1689
+ "DetrForObjectDetection",
1690
+ "DetrForSegmentation",
1691
+ "DetrModel",
1692
+ "DetrPreTrainedModel",
1693
+ ]
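A quick illustration of the box head defined in this file: the sketch below is not part of the diff and imports `DetrMLPPredictionHead` from its private module path (works on a local install, but it is not the public API). It shows how the 3-layer MLP maps decoder hidden states of width `d_model` to 4 normalized box coordinates, mirroring what `DetrForObjectDetection.forward` does before `.sigmoid()`.

```python
# Standalone shape check for the DETR box-prediction MLP (illustrative sketch only).
import torch
from transformers.models.detr.modeling_detr import DetrMLPPredictionHead  # private module path

d_model, num_queries, batch_size = 256, 100, 2
bbox_head = DetrMLPPredictionHead(input_dim=d_model, hidden_dim=d_model, output_dim=4, num_layers=3)

# Stand-in for the decoder's last_hidden_state: (batch_size, num_queries, d_model)
decoder_output = torch.randn(batch_size, num_queries, d_model)
pred_boxes = bbox_head(decoder_output).sigmoid()  # normalized (center_x, center_y, width, height)
print(pred_boxes.shape)  # torch.Size([2, 100, 4])
```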
phivenv/Lib/site-packages/transformers/models/dia/__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_dia import *
22
+ from .feature_extraction_dia import *
23
+ from .generation_dia import *
24
+ from .modeling_dia import *
25
+ from .processing_dia import *
26
+ from .tokenization_dia import *
27
+ else:
28
+ import sys
29
+
30
+ _file = globals()["__file__"]
31
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
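This `__init__.py` follows the lazy-import pattern used throughout the library: submodules such as `configuration_dia` and `modeling_dia` are imported only when one of their attributes is first accessed. A small illustrative sketch of what that means in practice (hypothetical usage on a working install):

```python
# Attribute access on the lazily-loaded package triggers the real import on first use.
import transformers.models.dia as dia

config_cls = dia.DiaConfig    # configuration_dia is imported only at this point
print(config_cls.model_type)  # "dia"
```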
phivenv/Lib/site-packages/transformers/models/dia/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (633 Bytes). View file
 
phivenv/Lib/site-packages/transformers/models/dia/__pycache__/configuration_dia.cpython-39.pyc ADDED
Binary file (18.3 kB). View file
 
phivenv/Lib/site-packages/transformers/models/dia/__pycache__/feature_extraction_dia.cpython-39.pyc ADDED
Binary file (6.62 kB). View file
 
phivenv/Lib/site-packages/transformers/models/dia/__pycache__/generation_dia.cpython-39.pyc ADDED
Binary file (9.73 kB). View file
 
phivenv/Lib/site-packages/transformers/models/dia/__pycache__/modeling_dia.cpython-39.pyc ADDED
Binary file (28 kB). View file
 
phivenv/Lib/site-packages/transformers/models/dia/__pycache__/modular_dia.cpython-39.pyc ADDED
Binary file (21.4 kB). View file
 
phivenv/Lib/site-packages/transformers/models/dia/__pycache__/processing_dia.cpython-39.pyc ADDED
Binary file (12.7 kB). View file
 
phivenv/Lib/site-packages/transformers/models/dia/__pycache__/tokenization_dia.cpython-39.pyc ADDED
Binary file (4.4 kB). View file
 
phivenv/Lib/site-packages/transformers/models/dia/configuration_dia.py ADDED
@@ -0,0 +1,376 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Dia model configuration"""
16
+
17
+ from typing import Optional
18
+
19
+ from ...configuration_utils import PretrainedConfig
20
+ from ...modeling_rope_utils import rope_config_validation
21
+ from ...utils import logging
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
+ class DiaEncoderConfig(PretrainedConfig):
28
+ r"""
29
+ This is the configuration class to store the configuration of a [`DiaEncoder`]. It is used to instantiate a Dia
30
+ encoder according to the specified arguments, defining the encoder architecture.
31
+
32
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
33
+ documentation from [`PretrainedConfig`] for more information.
34
+
35
+ Args:
36
+ max_position_embeddings (`int`, *optional*, defaults to 1024):
37
+ The maximum sequence length that this model might ever be used with.
38
+ num_hidden_layers (`int`, *optional*, defaults to 12):
39
+ Number of hidden layers in the Transformer encoder.
40
+ hidden_size (`int`, *optional*, defaults to 1024):
41
+ Dimensionality of the encoder layers and the pooler layer.
42
+ num_attention_heads (`int`, *optional*, defaults to 16):
43
+ Number of attention heads for each attention layer in the Transformer encoder.
44
+ num_key_value_heads (`int`, *optional*, defaults to 16):
45
+ Number of key and value heads for each attention layer in the Transformer encoder.
46
+ head_dim (`int`, *optional*, defaults to 128):
47
+ Dimensionality of the attention head.
48
+ intermediate_size (`int`, *optional*, defaults to 4096):
49
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
50
+ norm_eps (`float`, *optional*, defaults to 1e-05):
51
+ The epsilon used by the normalization layers.
52
+ vocab_size (`int`, *optional*, defaults to 256):
53
+ Vocabulary size of the Dia model. Defines the number of different tokens that can be represented by the
54
+ `inputs_ids` passed when calling [`DiaModel`].
55
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
56
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
57
+ `"relu"`, `"swish"` and `"gelu_new"` are supported.
58
+ rope_theta (`float`, *optional*, defaults to 10000.0):
59
+ The base period of the RoPE embeddings.
60
+ rope_scaling (`dict`, *optional*):
61
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
62
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
63
+ accordingly.
64
+ Expected contents:
65
+ `rope_type` (`str`):
66
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
67
+ 'llama3'], with 'default' being the original RoPE implementation.
68
+ `factor` (`float`, *optional*):
69
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
70
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
71
+ original maximum pre-trained length.
72
+ `original_max_position_embeddings` (`int`, *optional*):
73
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
74
+ pretraining.
75
+ `attention_factor` (`float`, *optional*):
76
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
77
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
78
+ `factor` field to infer the suggested value.
79
+ `beta_fast` (`float`, *optional*):
80
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
81
+ ramp function. If unspecified, it defaults to 32.
82
+ `beta_slow` (`float`, *optional*):
83
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
84
+ ramp function. If unspecified, it defaults to 1.
85
+ `short_factor` (`List[float]`, *optional*):
86
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
87
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
88
+ size divided by the number of attention heads divided by 2
89
+ `long_factor` (`List[float]`, *optional*):
90
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (>
91
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
92
+ size divided by the number of attention heads divided by 2
93
+ `low_freq_factor` (`float`, *optional*):
94
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
95
+ `high_freq_factor` (`float`, *optional*):
96
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
97
+ initializer_range (`float`, *optional*, defaults to 0.02):
98
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
99
+ """
100
+
101
+ model_type = "dia_encoder"
102
+
103
+ def __init__(
104
+ self,
105
+ max_position_embeddings: int = 1024,
106
+ num_hidden_layers: int = 12,
107
+ hidden_size: int = 1024,
108
+ num_attention_heads: int = 16,
109
+ num_key_value_heads: int = 16,
110
+ head_dim: int = 128,
111
+ intermediate_size: int = 4096,
112
+ norm_eps: float = 1e-5,
113
+ vocab_size: int = 256,
114
+ hidden_act: str = "silu",
115
+ rope_theta: float = 10000.0,
116
+ rope_scaling: Optional[dict] = None,
117
+ initializer_range: float = 0.02,
118
+ **kwargs,
119
+ ):
120
+ self.max_position_embeddings = max_position_embeddings
121
+ self.num_hidden_layers = num_hidden_layers
122
+ self.hidden_size = hidden_size
123
+ self.intermediate_size = intermediate_size
124
+ self.num_attention_heads = num_attention_heads
125
+ self.head_dim = head_dim
126
+ self.norm_eps = norm_eps
127
+ self.vocab_size = vocab_size
128
+ self.num_key_value_heads = num_key_value_heads
129
+ self.hidden_act = hidden_act
130
+ self.rope_theta = rope_theta
131
+ self.rope_scaling = rope_scaling
132
+ # Validate the correctness of rotary position embeddings parameters
133
+ # BC: if there is a 'type' field, copy it to 'rope_type'.
134
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
135
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
136
+ rope_config_validation(self)
137
+ self.initializer_range = initializer_range
138
+ super().__init__(**kwargs)
139
+
140
+
141
+ class DiaDecoderConfig(PretrainedConfig):
142
+ r"""
143
+ This is the configuration class to store the configuration of a [`DiaDecoder`]. It is used to instantiate a Dia
144
+ decoder according to the specified arguments, defining the decoder architecture.
145
+
146
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
147
+ documentation from [`PretrainedConfig`] for more information.
148
+
149
+ Args:
150
+ max_position_embeddings (`int`, *optional*, defaults to 3072):
151
+ The maximum sequence length that this model might ever be used with.
152
+ num_hidden_layers (`int`, *optional*, defaults to 18):
153
+ Number of hidden layers in the Transformer decoder.
154
+ hidden_size (`int`, *optional*, defaults to 2048):
155
+ Dimensionality of the decoder layers and the pooler layer.
156
+ intermediate_size (`int`, *optional*, defaults to 8192):
157
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer decoder.
158
+ num_attention_heads (`int`, *optional*, defaults to 16):
159
+ Number of attention heads for each attention layer in the Transformer decoder.
160
+ num_key_value_heads (`int`, *optional*, defaults to 4):
161
+ Number of key and value heads for each attention layer in the Transformer decoder.
162
+ head_dim (`int`, *optional*, defaults to 128):
163
+ Dimensionality of the attention head.
164
+ cross_num_attention_heads (`int`, *optional*, defaults to 16):
165
+ Number of attention heads for each cross-attention layer in the Transformer decoder.
166
+ cross_head_dim (`int`, *optional*, defaults to 128):
167
+ Dimensionality of the cross-attention head.
168
+ cross_num_key_value_heads (`int`, *optional*, defaults to 16):
169
+ Number of key and value heads for each cross-attention layer in the Transformer decoder.
170
+ cross_hidden_size (`int`, *optional*, defaults to 1024):
171
+ Dimensionality of the cross-attention layers.
172
+ norm_eps (`float`, *optional*, defaults to 1e-05):
173
+ The epsilon used by the normalization layers.
174
+ vocab_size (`int`, *optional*, defaults to 1028):
175
+ Vocabulary size of the Dia model. Defines the number of different tokens that can be represented by the
176
+ `inputs_ids` passed when calling [`DiaModel`].
177
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
178
+ The non-linear activation function (function or string) in the decoder. If string, `"gelu"`, `"relu"`,
179
+ `"swish"` and `"gelu_new"` are supported.
180
+ num_channels (`int`, *optional*, defaults to 9):
181
+ Number of channels for the Dia decoder.
182
+ rope_theta (`float`, *optional*, defaults to 10000.0):
183
+ The base period of the RoPE embeddings.
184
+ rope_scaling (`dict`, *optional*):
185
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
186
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
187
+ accordingly.
188
+ Expected contents:
189
+ `rope_type` (`str`):
190
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
191
+ 'llama3'], with 'default' being the original RoPE implementation.
192
+ `factor` (`float`, *optional*):
193
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
194
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
195
+ original maximum pre-trained length.
196
+ `original_max_position_embeddings` (`int`, *optional*):
197
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
198
+ pretraining.
199
+ `attention_factor` (`float`, *optional*):
200
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
201
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
202
+ `factor` field to infer the suggested value.
203
+ `beta_fast` (`float`, *optional*):
204
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
205
+ ramp function. If unspecified, it defaults to 32.
206
+ `beta_slow` (`float`, *optional*):
207
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
208
+ ramp function. If unspecified, it defaults to 1.
209
+ `short_factor` (`List[float]`, *optional*):
210
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
211
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
212
+ size divided by the number of attention heads divided by 2
213
+ `long_factor` (`List[float]`, *optional*):
214
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (>
215
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
216
+ size divided by the number of attention heads divided by 2
217
+ `low_freq_factor` (`float`, *optional*):
218
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
219
+ `high_freq_factor` (`float`, *optional*):
220
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
221
+ initializer_range (`float`, *optional*, defaults to 0.02):
222
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
223
+ use_cache (`bool`, *optional*, defaults to `True`):
224
+ Whether or not the model should return the last key/values attentions (not used by all models).
225
+ is_encoder_decoder (`bool`, *optional*, defaults to `True`):
226
+ Indicating that this model is part of an encoder-decoder architecture.
227
+ """
228
+
229
+ model_type = "dia_decoder"
230
+
231
+ def __init__(
232
+ self,
233
+ max_position_embeddings: int = 3072,
234
+ num_hidden_layers: int = 18,
235
+ hidden_size: int = 2048,
236
+ intermediate_size: int = 8192,
237
+ num_attention_heads: int = 16,
238
+ num_key_value_heads: int = 4,
239
+ head_dim: int = 128,
240
+ cross_num_attention_heads: int = 16,
241
+ cross_head_dim: int = 128,
242
+ cross_num_key_value_heads: int = 16,
243
+ cross_hidden_size: int = 1024,
244
+ norm_eps: float = 1e-5,
245
+ vocab_size: int = 1028,
246
+ hidden_act: str = "silu",
247
+ num_channels: int = 9,
248
+ rope_theta: float = 10000.0,
249
+ rope_scaling: Optional[dict] = None,
250
+ initializer_range: float = 0.02,
251
+ use_cache: bool = True,
252
+ is_encoder_decoder: bool = True,
253
+ **kwargs,
254
+ ):
255
+ self.max_position_embeddings = max_position_embeddings
256
+ self.num_hidden_layers = num_hidden_layers
257
+ self.hidden_size = hidden_size
258
+ self.intermediate_size = intermediate_size
259
+ self.num_attention_heads = num_attention_heads
260
+ self.num_key_value_heads = num_key_value_heads
261
+ self.head_dim = head_dim
262
+ self.cross_num_key_value_heads = cross_num_key_value_heads
263
+ self.cross_num_attention_heads = cross_num_attention_heads
264
+ self.cross_head_dim = cross_head_dim
265
+ self.cross_hidden_size = cross_hidden_size
266
+ self.norm_eps = norm_eps
267
+ self.vocab_size = vocab_size
268
+ self.hidden_act = hidden_act
269
+ self.num_channels = num_channels
270
+ self.rope_theta = rope_theta
271
+ self.rope_scaling = rope_scaling
272
+ # Validate the correctness of rotary position embeddings parameters
273
+ # BC: if there is a 'type' field, copy it to 'rope_type'.
274
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
275
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
276
+ rope_config_validation(self)
277
+ self.initializer_range = initializer_range
278
+ self.use_cache = use_cache
279
+ super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
280
+
281
+
282
+ class DiaConfig(PretrainedConfig):
283
+ r"""
284
+ This is the configuration class to store the configuration of a [`DiaModel`]. It is used to instantiate a
285
+ Dia model according to the specified arguments, defining the model architecture. Instantiating a configuration
286
+ with the defaults will yield a similar configuration to that of the
287
+ [nari-labs/Dia-1.6B](https://huggingface.co/nari-labs/Dia-1.6B) architecture.
288
+
289
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
290
+ documentation from [`PretrainedConfig`] for more information.
291
+
292
+ Args:
293
+ encoder_config (`DiaEncoderConfig`, *optional*):
294
+ Configuration for the encoder part of the model. If not provided, a default `DiaEncoderConfig` will be used.
295
+ decoder_config (`DiaDecoderConfig`, *optional*):
296
+ Configuration for the decoder part of the model. If not provided, a default `DiaDecoderConfig` will be used.
297
+ norm_eps (`float`, *optional*, defaults to 1e-05):
298
+ The epsilon used by the normalization layers.
299
+ is_encoder_decoder (`bool`, *optional*, defaults to `True`):
300
+ Indicating that this model uses an encoder-decoder architecture.
301
+ pad_token_id (`int`, *optional*, defaults to 1025):
302
+ Padding token id.
303
+ eos_token_id (`int`, *optional*, defaults to 1024):
304
+ End of stream token id.
305
+ bos_token_id (`int`, *optional*, defaults to 1026):
306
+ Beginning of stream token id.
307
+ delay_pattern (`list[int]`, *optional*, defaults to `[0, 8, 9, 10, 11, 12, 13, 14, 15]`):
308
+ The delay pattern for the decoder. The length of this list must match `decoder_config.num_channels`.
309
+ initializer_range (`float`, *optional*, defaults to 0.02):
310
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
311
+ use_cache (`bool`, *optional*, defaults to `True`):
312
+ Whether or not the model should return the last key/values attentions (not used by all models).
313
+
314
+ Example:
315
+
316
+ ```python
317
+ >>> from transformers import DiaConfig, DiaModel
318
+
319
+ >>> # Initializing a DiaConfig with default values
320
+ >>> configuration = DiaConfig()
321
+
322
+ >>> # Initializing a DiaModel (with random weights) from the configuration
323
+ >>> model = DiaModel(configuration)
324
+
325
+ >>> # Accessing the model configuration
326
+ >>> configuration = model.config
327
+ ```
328
+ """
329
+
330
+ model_type = "dia"
331
+ keys_to_ignore_at_inference = ["past_key_values"]
332
+ sub_configs = {"encoder_config": DiaEncoderConfig, "decoder_config": DiaDecoderConfig}
333
+
334
+ def __init__(
335
+ self,
336
+ encoder_config: Optional[DiaEncoderConfig] = None,
337
+ decoder_config: Optional[DiaDecoderConfig] = None,
338
+ norm_eps: float = 1e-5,
339
+ is_encoder_decoder: bool = True,
340
+ pad_token_id: int = 1025,
341
+ eos_token_id: int = 1024,
342
+ bos_token_id: int = 1026,
343
+ delay_pattern: Optional[list[int]] = None,
344
+ initializer_range: float = 0.02,
345
+ use_cache: bool = True,
346
+ **kwargs,
347
+ ):
348
+ if isinstance(encoder_config, dict):
349
+ encoder_config = DiaEncoderConfig(**encoder_config)
350
+ if isinstance(decoder_config, dict):
351
+ decoder_config = DiaDecoderConfig(**decoder_config)
352
+ self.encoder_config = encoder_config if encoder_config is not None else DiaEncoderConfig()
353
+ self.decoder_config = decoder_config if decoder_config is not None else DiaDecoderConfig()
354
+ self.norm_eps = norm_eps
355
+ self.delay_pattern = delay_pattern if delay_pattern is not None else [0, 8, 9, 10, 11, 12, 13, 14, 15]
356
+ self.initializer_range = initializer_range
357
+ self.use_cache = use_cache
358
+
359
+ assert self.decoder_config.num_channels == len(self.delay_pattern), (
360
+ "Number of channels must match delay pattern length."
361
+ )
362
+
363
+ super().__init__(
364
+ pad_token_id=pad_token_id,
365
+ eos_token_id=eos_token_id,
366
+ bos_token_id=bos_token_id,
367
+ is_encoder_decoder=is_encoder_decoder,
368
+ **kwargs,
369
+ )
370
+
371
+ def get_text_config(self, *args, **kwargs):
372
+ """Defaulting to audio config as it's the decoder in this case which is usually the text backbone"""
373
+ return self.decoder_config
374
+
375
+
376
+ __all__ = ["DiaConfig", "DiaEncoderConfig", "DiaDecoderConfig"]
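To make the constraints in `DiaConfig.__init__` concrete, here is a minimal sketch (not part of the diff) of composing a config from explicit sub-configs; the one hard requirement enforced above is that `len(delay_pattern)` equals `decoder_config.num_channels`, and `get_text_config()` deliberately returns the decoder config.

```python
# Illustrative composition of DiaConfig from its sub-configs (assumes the classes are
# exported by transformers, as listed in this file's __all__).
from transformers import DiaConfig, DiaDecoderConfig, DiaEncoderConfig

decoder = DiaDecoderConfig(num_channels=9)
config = DiaConfig(
    encoder_config=DiaEncoderConfig(),
    decoder_config=decoder,
    delay_pattern=[0, 8, 9, 10, 11, 12, 13, 14, 15],  # 9 entries, matching num_channels
)
print(config.get_text_config().vocab_size)  # 1028: the decoder plays the role of the "text" backbone here
```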
phivenv/Lib/site-packages/transformers/models/dia/feature_extraction_dia.py ADDED
@@ -0,0 +1,183 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Feature extractor class for Dia"""
16
+
17
+ from typing import Optional, Union
18
+
19
+ import numpy as np
20
+
21
+ from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
22
+ from ...feature_extraction_utils import BatchFeature
23
+ from ...utils import PaddingStrategy, TensorType, logging
24
+
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+ class DiaFeatureExtractor(SequenceFeatureExtractor):
30
+ r"""
31
+ Constructs a Dia feature extractor.
32
+
33
+ This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
34
+ most of the main methods. Users should refer to this superclass for more information regarding those methods.
35
+
36
+ Args:
37
+ feature_size (`int`, *optional*, defaults to 1):
38
+ The feature dimension of the extracted features. Use 1 for mono, 2 for stereo.
39
+ sampling_rate (`int`, *optional*, defaults to 16000):
40
+ The sampling rate at which the audio waveform should be digitalized, expressed in hertz (Hz).
41
+ padding_value (`float`, *optional*, defaults to 0.0):
42
+ The value that is used for padding.
43
+ hop_length (`int`, *optional*, defaults to 512):
44
+ Hop length between successive windows; audio inputs are padded to a multiple of this value.
45
+ """
46
+
47
+ model_input_names = ["input_values", "n_quantizers"]
48
+
49
+ def __init__(
50
+ self,
51
+ feature_size: int = 1,
52
+ sampling_rate: int = 16000,
53
+ padding_value: float = 0.0,
54
+ hop_length: int = 512,
55
+ **kwargs,
56
+ ):
57
+ super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
58
+ self.hop_length = hop_length
59
+
60
+ def __call__(
61
+ self,
62
+ raw_audio: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]],
63
+ padding: Optional[Union[bool, str, PaddingStrategy]] = None,
64
+ truncation: Optional[bool] = False,
65
+ max_length: Optional[int] = None,
66
+ return_tensors: Optional[Union[str, TensorType]] = None,
67
+ sampling_rate: Optional[int] = None,
68
+ ) -> BatchFeature:
69
+ """
70
+ Main method to featurize and prepare for the model one or several sequence(s).
71
+
72
+ Args:
73
+ raw_audio (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`):
74
+ The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float
75
+ values, a list of numpy arrays or a list of list of float values. The numpy array must be of shape
76
+ `(num_samples,)` for mono audio (`feature_size = 1`), or `(2, num_samples)` for stereo audio
77
+ (`feature_size = 2`).
78
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
79
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
80
+ index) among:
81
+
82
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
83
+ sequence is provided).
84
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
85
+ acceptable input length for the model if that argument is not provided.
86
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
87
+ lengths).
88
+ truncation (`bool`, *optional*, defaults to `False`):
89
+ Activates truncation to cut input sequences longer than `max_length` to `max_length`.
90
+ max_length (`int`, *optional*):
91
+ Maximum length of the returned list and optionally padding length (see above).
92
+ return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'pt'`):
93
+ If set, will return tensors instead of list of python integers. Acceptable values are:
94
+
95
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
96
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
97
+ - `'np'`: Return Numpy `np.ndarray` objects.
98
+ sampling_rate (`int`, *optional*):
99
+ The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass
100
+ `sampling_rate` at the forward call to prevent silent errors.
101
+ """
102
+ if sampling_rate is not None:
103
+ if sampling_rate != self.sampling_rate:
104
+ raise ValueError(
105
+ f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
106
+ f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with"
107
+ f" {self.sampling_rate} and not {sampling_rate}."
108
+ )
109
+ else:
110
+ logger.warning(
111
+ f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
112
+ "Failing to do so can result in silent errors that might be hard to debug."
113
+ )
114
+
115
+ if padding and truncation:
116
+ raise ValueError("Both padding and truncation were set. Make sure you only set one.")
117
+ elif padding is None:
118
+ # by default let's pad the inputs
119
+ padding = True
120
+
121
+ is_batched = bool(
122
+ isinstance(raw_audio, (list, tuple)) and (isinstance(raw_audio[0], (np.ndarray, tuple, list)))
123
+ )
124
+
125
+ if is_batched:
126
+ raw_audio = [np.asarray(audio, dtype=np.float32).T for audio in raw_audio]
127
+ elif not is_batched and not isinstance(raw_audio, np.ndarray):
128
+ raw_audio = np.asarray(raw_audio, dtype=np.float32)
129
+ elif isinstance(raw_audio, np.ndarray) and raw_audio.dtype is np.dtype(np.float64):
130
+ raw_audio = raw_audio.astype(np.float32)
131
+
132
+ # always return batch
133
+ if not is_batched:
134
+ raw_audio = [np.asarray(raw_audio).T]
135
+
136
+ # convert stereo to mono if necessary, unique to Dia
137
+ for idx, example in enumerate(raw_audio):
138
+ if self.feature_size == 2 and example.ndim == 2:
139
+ raw_audio[idx] = np.mean(example, -1)
140
+
141
+ # verify inputs are valid
142
+ for idx, example in enumerate(raw_audio):
143
+ if example.ndim > 2:
144
+ raise ValueError(f"Expected input shape (channels, length) but got shape {example.shape}")
145
+ if self.feature_size == 1 and example.ndim != 1:
146
+ raise ValueError(f"Expected mono audio but example has {example.shape[-1]} channels")
147
+ if self.feature_size == 2 and example.ndim != 1: # note the conversion before
148
+ raise ValueError(f"Expected stereo audio but example has {example.shape[-1]} channels")
149
+
150
+ input_values = BatchFeature({"input_values": raw_audio})
151
+
152
+ # temporarily treat it as if we were mono as we also convert stereo to mono
153
+ original_feature_size = self.feature_size
154
+ self.feature_size = 1
155
+
156
+ # normal padding on batch
157
+ padded_inputs = self.pad(
158
+ input_values,
159
+ max_length=max_length,
160
+ truncation=truncation,
161
+ padding=padding,
162
+ return_attention_mask=True,
163
+ pad_to_multiple_of=self.hop_length,
164
+ )
165
+ padded_inputs["padding_mask"] = padded_inputs.pop("attention_mask")
166
+
167
+ input_values = []
168
+ for example in padded_inputs.pop("input_values"):
169
+ if self.feature_size == 1:
170
+ example = example[..., None]
171
+ input_values.append(example.T)
172
+
173
+ padded_inputs["input_values"] = input_values
174
+ if return_tensors is not None:
175
+ padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
176
+
177
+ # rewrite back to original feature size
178
+ self.feature_size = origingal_feature_size
179
+
180
+ return padded_inputs
181
+
182
+
183
+ __all__ = ["DiaFeatureExtractor"]
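A minimal usage sketch for the feature extractor above (illustrative, not part of the diff): with the default `hop_length=512`, a mono waveform is padded up to the next multiple of 512 and gains a channel dimension.

```python
# Hypothetical call with a silent 1-second mono waveform at the expected sampling rate.
import numpy as np
from transformers import DiaFeatureExtractor

feature_extractor = DiaFeatureExtractor(feature_size=1, sampling_rate=16000, hop_length=512)
waveform = np.zeros(16000, dtype=np.float32)

inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")
print(inputs["input_values"].shape)  # e.g. torch.Size([1, 1, 16384]) after padding to a multiple of 512
print(inputs["padding_mask"].shape)  # matching mask over the padded samples
```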
phivenv/Lib/site-packages/transformers/models/dia/generation_dia.py ADDED
@@ -0,0 +1,464 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any, Callable, Optional, Union
17
+
18
+ import torch
19
+ import torch.distributed as dist
20
+
21
+ from ...generation.logits_process import (
22
+ DiaClassifierFreeGuidanceLogitsProcessor,
23
+ DiaEOSChannelFilterLogitsProcessor,
24
+ DiaEOSDelayPatternLogitsProcessor,
25
+ LogitsProcessorList,
26
+ TemperatureLogitsWarper,
27
+ )
28
+ from ...generation.stopping_criteria import StoppingCriteriaList
29
+ from ...generation.streamers import BaseStreamer
30
+ from ...generation.utils import GenerateOutput, GenerationConfig, GenerationMixin, GenerationMode
31
+ from ...integrations.deepspeed import is_deepspeed_zero3_enabled
32
+ from ...integrations.fsdp import is_fsdp_managed_module
33
+ from ...modeling_utils import PreTrainedModel
34
+ from ...utils import logging
35
+
36
+
37
+ logger = logging.get_logger(__name__)
38
+
39
+
40
+ class DiaGenerationMixin(GenerationMixin):
41
+ # Indicates whether CFG is used; if so, inputs need to be repeated and prepared accordingly
42
+ _uses_cfg = None
43
+
44
+ def _get_logits_processor(
45
+ self,
46
+ generation_config: GenerationConfig,
47
+ input_ids_seq_length: Optional[int] = None,
48
+ encoder_input_ids: torch.LongTensor = None,
49
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
50
+ logits_processor: Optional[LogitsProcessorList] = None,
51
+ device: Optional[str] = None,
52
+ model_kwargs: Optional[dict[str, Any]] = None,
53
+ negative_prompt_ids: Optional[torch.Tensor] = None,
54
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
55
+ ) -> LogitsProcessorList:
56
+ # Need either custom order or custom processor instead
57
+ # (Temporarily disabling those for the super function)
58
+ original_guidance_scale = generation_config.guidance_scale
59
+ original_temperature = generation_config.temperature
60
+ generation_config.guidance_scale = None
61
+ generation_config.temperature = None
62
+
63
+ # Get base processors and those we can integrate easily
64
+ custom_processors = LogitsProcessorList()
65
+
66
+ if original_temperature is not None and original_temperature != 1.0:
67
+ custom_processors.append(TemperatureLogitsWarper(original_temperature))
68
+
69
+ custom_processors.append(
70
+ DiaEOSChannelFilterLogitsProcessor(
71
+ num_channels=len(self.config.delay_pattern),
72
+ eos_token_id=self.config.eos_token_id,
73
+ )
74
+ )
75
+
76
+ merged_processors = super()._get_logits_processor(
77
+ generation_config=generation_config,
78
+ input_ids_seq_length=input_ids_seq_length,
79
+ encoder_input_ids=encoder_input_ids,
80
+ prefix_allowed_tokens_fn=None,
81
+ logits_processor=custom_processors,
82
+ device=device,
83
+ model_kwargs=model_kwargs,
84
+ negative_prompt_ids=negative_prompt_ids,
85
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
86
+ )
87
+
88
+ # Custom processors we need at specific positions
89
+ if original_guidance_scale is not None and original_guidance_scale != 1:
90
+ cfg_processor = DiaClassifierFreeGuidanceLogitsProcessor(
91
+ guidance_scale=original_guidance_scale,
92
+ guidance_top_k=generation_config.top_k,
93
+ )
94
+ merged_processors.insert(0, cfg_processor)
95
+
96
+ merged_processors.append(
97
+ DiaEOSDelayPatternLogitsProcessor(
98
+ delay_pattern=self.config.delay_pattern,
99
+ eos_token_id=self.config.eos_token_id,
100
+ max_generation_len=generation_config.max_length,
101
+ device=device,
102
+ )
103
+ )
104
+
105
+ # Enable temporarily disabled values back
106
+ generation_config.guidance_scale = original_guidance_scale
107
+ generation_config.temperature = original_temperature
108
+
109
+ return merged_processors
110
+
111
+ def _prepare_generation_config(
112
+ self, generation_config: Optional[GenerationConfig], use_model_defaults: Optional[bool] = None, **kwargs: dict
113
+ ) -> tuple[GenerationConfig, dict]:
114
+ generation_config, model_kwargs = super()._prepare_generation_config(
115
+ generation_config, use_model_defaults, **kwargs
116
+ )
117
+
118
+ # We allow generation up to max length + max delay pattern
119
+ # (will revert back to max length after generation)
120
+ generation_config.max_length += max(self.config.delay_pattern)
121
+
122
+ # Internal flag to indicate CFG that needs to prepare unconditioned input
123
+ self._uses_cfg = generation_config.guidance_scale is not None and generation_config.guidance_scale != 1
124
+
125
+ return generation_config, model_kwargs
126
+
127
+ def _prepare_model_inputs(
128
+ self,
129
+ inputs: Optional[torch.Tensor] = None,
130
+ bos_token_id: Optional[torch.Tensor] = None,
131
+ model_kwargs: Optional[dict[str, torch.Tensor]] = None,
132
+ ) -> tuple[torch.Tensor, Optional[str], dict[str, torch.Tensor]]:
133
+ inputs, input_name, model_kwargs = super()._prepare_model_inputs(
134
+ inputs=inputs,
135
+ bos_token_id=bos_token_id,
136
+ model_kwargs=model_kwargs,
137
+ )
138
+
139
+ # If CFG is requested we fill in the unconditioned parts
140
+ if self._uses_cfg:
141
+ unconditioned_inputs = torch.zeros_like(inputs)
142
+ inputs = torch.cat([inputs, unconditioned_inputs], dim=0)
143
+
144
+ if model_kwargs.get("attention_mask", None) is not None:
145
+ model_kwargs["attention_mask"] = model_kwargs["attention_mask"].repeat(2, 1)
146
+
147
+ return inputs, input_name, model_kwargs
148
+
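+ # With CFG enabled the effective batch is doubled: rows [0, B) carry the conditioned text and
+ # rows [B, 2B) an all-zero unconditioned copy. The CFG logits processor inserted at position 0
+ # in `_get_logits_processor` above is then presumably responsible for splitting the 2B logits
+ # back into (conditioned, unconditioned) halves and recombining them with `guidance_scale`.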
149
+ def _prepare_decoder_input_ids_for_generation(
150
+ self,
151
+ batch_size: int,
152
+ model_input_name: str,
153
+ model_kwargs: dict[str, torch.Tensor],
154
+ decoder_start_token_id: torch.Tensor,
155
+ device: Optional[torch.device] = None,
156
+ ) -> tuple[torch.LongTensor, dict[str, torch.Tensor]]:
157
+ """Prepares `decoder_input_ids` for generation with encoder-decoder models"""
158
+ # 1. Check whether the user has defined `decoder_input_ids` and `decoder_attention_mask`; if not error out
159
+ decoder_input_ids = decoder_attention_mask = None
160
+ if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
161
+ decoder_input_ids = model_kwargs.pop("decoder_input_ids")
162
+ if model_kwargs is not None and "decoder_attention_mask" in model_kwargs:
163
+ decoder_attention_mask = model_kwargs.pop("decoder_attention_mask")
164
+
165
+ # We allow generating without preparation (no proper delay) but discourage it
166
+ if decoder_input_ids is None or decoder_attention_mask is None:
167
+ logger.warning_once(
168
+ "In order to generate with Dia, we need the processed audio input: Got `decoder_input_ids`:"
169
+ f" {decoder_input_ids is not None} and got `decoder_attention_mask`={decoder_attention_mask is not None}."
170
+ f" These inputs can be prepared via the [`DiaProcessor`]; defaulting to non-delayed generation."
171
+ )
172
+
173
+ num_channels = self.config.decoder_config.num_channels
174
+ real_batch_size = batch_size // 2 if self._uses_cfg else batch_size
175
+
176
+ if decoder_input_ids is None:
177
+ decoder_input_ids = torch.full(
178
+ (real_batch_size, 1, num_channels), decoder_start_token_id, dtype=torch.long, device=device
179
+ )
180
+
181
+ decoder_attention_mask = torch.ones(
182
+ size=(real_batch_size, decoder_input_ids.shape[1]), dtype=torch.long, device=device
183
+ )
184
+
185
+ # 2. Determine the valid input and what works as mask within the input
186
+ delay_mask = decoder_input_ids.long()
187
+ valid_input_size = (
188
+ decoder_input_ids.shape[1] - (decoder_input_ids[:, :, 0] == self.config.pad_token_id).sum(dim=-1).max()
189
+ )
190
+ decoder_input_ids = delay_mask[:, :valid_input_size].transpose(1, 2).long()
191
+ decoder_attention_mask = decoder_attention_mask[:, :valid_input_size].long()
192
+
193
+ # 3. Overwrite into model kwargs
194
+ model_kwargs["decoder_attention_mask"] = decoder_attention_mask
195
+ model_kwargs["decoder_delay_mask"] = delay_mask
196
+
197
+ return decoder_input_ids, model_kwargs
198
+
199
+ def prepare_inputs_for_generation(
200
+ self,
201
+ input_ids,
202
+ encoder_outputs=None, # Using this to easily get the batch size
203
+ decoder_delay_mask=None,
204
+ **kwargs,
205
+ ):
206
+ # Reshape decoder input_ids to 3D to be compile friendly and to fit the expected model input shape
207
+ batch_size = encoder_outputs[0].shape[0] // 2 if self._uses_cfg else encoder_outputs[0].shape[0]
208
+ input_ids = input_ids.reshape(batch_size, self.config.decoder_config.num_channels, -1).transpose(1, 2)
209
+
210
+ # Base method handles most things except CFG and the delay pattern mask
211
+ model_inputs = super().prepare_inputs_for_generation(input_ids, encoder_outputs=encoder_outputs, **kwargs)
212
+
213
+ # Post processing for CFG and overwriting via delay pattern mask
214
+ # 1. Delay pattern mask -- force tokens if not allowed to predict (!= pad_token in mask)
215
+ model_inputs["decoder_input_ids"] = self.apply_delay_mask(
216
+ input_ids, self.config.pad_token_id, decoder_delay_mask
217
+ )
218
+
219
+ # Depending on cache usage we need to pass all or just one
220
+ if model_inputs.get("use_cache", False) and model_inputs["cache_position"][0] > 0:
221
+ model_inputs["decoder_input_ids"] = model_inputs["decoder_input_ids"][:, -1, :][:, None, :]
222
+
223
+ # Be compile friendly
224
+ model_inputs["decoder_input_ids"] = model_inputs["decoder_input_ids"].contiguous()
225
+
226
+ # 2. Apply CFG duplication if needed
227
+ if self._uses_cfg:
228
+ for key in ["decoder_input_ids", "decoder_attention_mask", "decoder_position_ids"]:
229
+ if model_inputs.get(key, None) is not None:
230
+ # double first dimension and keep everything else the same
231
+ repeat_pattern = tuple([2] + [1] * (model_inputs[key].ndim - 1))
232
+ model_inputs[key] = model_inputs[key].repeat(*repeat_pattern)
233
+
234
+ return model_inputs
235
+
236
+ @staticmethod
237
+ def apply_delay_mask(input_ids: torch.Tensor, pad_id: int, delay_mask: Optional[torch.Tensor]) -> torch.Tensor:
238
+ if delay_mask is None:
239
+ return input_ids
240
+
241
+ mask_len = min(input_ids.shape[1], delay_mask.shape[1])
242
+ valid_mask = delay_mask[:, :mask_len, :]
243
+ valid_input = input_ids[:, :mask_len, :]
244
+
245
+ # Overwrite the respective parts of the input
246
+ input_ids[:, :mask_len, :] = torch.where(valid_mask == pad_id, valid_input, valid_mask)
247
+
248
+ return input_ids
249
+
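+ # Illustration of `apply_delay_mask` (made-up values): entries equal to `pad_id` in the mask are
+ # "free" and keep whatever the model generated; every other entry is forced to the mask value.
+ # delay_mask row: [bos, bos, pad, pad]
+ # input_ids row:  [bos,  17,  42,  99]
+ # result row:     [bos, bos,  42,  99]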
250
+ def _main_generate_loop(
251
+ self,
252
+ inputs: Optional[torch.Tensor] = None,
253
+ generation_config: Optional[GenerationConfig] = None,
254
+ logits_processor: Optional[LogitsProcessorList] = None,
255
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
256
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
257
+ synced_gpus: Optional[bool] = None,
258
+ assistant_model: Optional["PreTrainedModel"] = None,
259
+ streamer: Optional["BaseStreamer"] = None,
260
+ negative_prompt_ids: Optional[torch.Tensor] = None,
261
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
262
+ use_model_defaults: Optional[bool] = None,
263
+ custom_generate: Optional[str] = None,
264
+ **kwargs,
265
+ ):
266
+ # ********** mostly taken from main generate function up to calling the different methods (see NOTE) **********
267
+ # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
268
+ tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
269
+ assistant_tokenizer = kwargs.pop("assistant_tokenizer", None) # only used for assisted generation
270
+
271
+ generation_config, model_kwargs = self._prepare_generation_config(
272
+ generation_config, use_model_defaults, **kwargs
273
+ )
274
+ self._validate_model_kwargs(model_kwargs.copy())
275
+ self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer)
276
+
277
+ # 2. Set generation parameters if not already defined
278
+ if synced_gpus is None:
279
+ synced_gpus = (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1
280
+
281
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
282
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
283
+
284
+ # 3. Define model inputs
285
+ kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
286
+ inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
287
+ inputs, generation_config.bos_token_id, model_kwargs
288
+ )
289
+ batch_size = inputs_tensor.shape[0]
290
+
291
+ device = inputs_tensor.device
292
+ self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
293
+
294
+ # 4. Define other model kwargs
295
+ if "encoder_outputs" not in model_kwargs:
296
+ # if model is encoder decoder encoder_outputs are created and added to `model_kwargs`
297
+ model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
298
+ inputs_tensor, model_kwargs, model_input_name, generation_config
299
+ )
300
+
301
+ # 5. Prepare `input_ids` which will be used for auto-regressive generation
302
+ input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
303
+ batch_size=batch_size,
304
+ model_input_name=model_input_name,
305
+ model_kwargs=model_kwargs,
306
+ decoder_start_token_id=generation_config._decoder_start_token_tensor,
307
+ device=inputs_tensor.device,
308
+ )
309
+
310
+ if generation_config.token_healing:
311
+ input_ids = self.heal_tokens(input_ids, tokenizer)
312
+
313
+ if streamer is not None:
314
+ streamer.put(input_ids.cpu())
315
+
316
+ # 6. Prepare `max_length` depending on other stopping criteria.
317
+ # NOTE: incorrect `input_ids.shape[1]` previously
318
+ input_ids_length = input_ids.shape[-1]
319
+ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
320
+ has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
321
+ generation_config = self._prepare_generated_length(
322
+ generation_config=generation_config,
323
+ has_default_max_length=has_default_max_length,
324
+ has_default_min_length=has_default_min_length,
325
+ model_input_name=model_input_name,
326
+ inputs_tensor=inputs_tensor,
327
+ input_ids_length=input_ids_length,
328
+ )
329
+
330
+ # If the model supports `logits_to_keep` in forward(), set it to 1 to avoid computing the whole
331
+ # logit matrix. This can save a lot of memory during the first forward pass. Note that assisted decoding
332
+ # dynamically overrides this value as it can need more than the last token logits
333
+ if self._supports_logits_to_keep() and "logits_to_keep" not in model_kwargs:
334
+ model_kwargs["logits_to_keep"] = 1
335
+
336
+ self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
337
+
338
+ # 7. Prepare the cache.
339
+ # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
340
+ # - different models have a different cache name expected by the model (default = "past_key_values")
341
+ # - `max_length`, prepared above, is used to determine the maximum cache length
342
+ max_cache_length = generation_config.max_length - 1
343
+ if (
344
+ inputs_tensor.shape[1] != input_ids_length
345
+ and model_input_name == "inputs_embeds"
346
+ and not self.config.is_encoder_decoder
347
+ ):
348
+ max_cache_length += inputs_tensor.shape[1]
349
+ self._prepare_cache_for_generation(
350
+ generation_config, model_kwargs, assistant_model, batch_size, max_cache_length
351
+ )
352
+
353
+ # 8. determine generation mode
354
+ generation_mode = generation_config.get_generation_mode(assistant_model)
355
+
356
+ if streamer is not None and (generation_config.num_beams > 1):
357
+ raise ValueError(
358
+ "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
359
+ )
360
+
361
+ # 9. prepare logits processors and stopping criteria
362
+ prepared_logits_processor = self._get_logits_processor(
363
+ generation_config=generation_config,
364
+ input_ids_seq_length=input_ids_length,
365
+ encoder_input_ids=inputs_tensor,
366
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
367
+ logits_processor=logits_processor,
368
+ device=inputs_tensor.device,
369
+ model_kwargs=model_kwargs,
370
+ negative_prompt_ids=negative_prompt_ids,
371
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
372
+ )
373
+ prepared_stopping_criteria = self._get_stopping_criteria(
374
+ generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
375
+ )
376
+
377
+ # Set model_kwargs `use_cache` so we can use it later in forward runs
378
+ model_kwargs["use_cache"] = generation_config.use_cache
379
+ # ******************* taken from main generate function up to calling the different methods *******************
380
+
381
+ # Prepare inner 2D logic in generation loop
382
+ input_ids = input_ids.reshape(-1, input_ids.shape[-1])
383
+
384
+ # 10. go into different generation modes
385
+ if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
386
+ # 11. expand input_ids with `num_return_sequences` additional sequences per batch
387
+ if generation_config.num_return_sequences > 1:
388
+ raise ValueError("`num_return_sequences>1` is incompatible with Dia.")
389
+
390
+ # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
391
+ return self._sample(
392
+ input_ids,
393
+ logits_processor=prepared_logits_processor,
394
+ stopping_criteria=prepared_stopping_criteria,
395
+ generation_config=generation_config,
396
+ synced_gpus=synced_gpus,
397
+ streamer=streamer,
398
+ **model_kwargs,
399
+ )
400
+ else:
401
+ raise ValueError(
402
+ "Got incompatible mode for generation, should be one of greedy or sampling. "
403
+ "Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`."
404
+ )
405
+
406
+ @torch.no_grad()
407
+ def generate(
408
+ self,
409
+ inputs: Optional[torch.Tensor] = None,
410
+ generation_config: Optional[GenerationConfig] = None,
411
+ logits_processor: Optional[LogitsProcessorList] = None,
412
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
413
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
414
+ synced_gpus: Optional[bool] = None,
415
+ assistant_model: Optional["PreTrainedModel"] = None,
416
+ streamer: Optional["BaseStreamer"] = None,
417
+ negative_prompt_ids: Optional[torch.Tensor] = None,
418
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
419
+ use_model_defaults: Optional[bool] = None,
420
+ custom_generate: Optional[str] = None,
421
+ **kwargs,
422
+ ) -> Union[GenerateOutput, torch.LongTensor]:
423
+ # We expect the initial input ids to be the complete mask (delayed input)
424
+ delay_mask = kwargs.get("decoder_input_ids")
425
+ if delay_mask is not None:
426
+ delay_mask = delay_mask.clone()
427
+
428
+ output = self._main_generate_loop(
429
+ inputs=inputs,
430
+ generation_config=generation_config,
431
+ logits_processor=logits_processor,
432
+ stopping_criteria=stopping_criteria,
433
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
434
+ synced_gpus=synced_gpus,
435
+ assistant_model=assistant_model,
436
+ streamer=streamer,
437
+ negative_prompt_ids=negative_prompt_ids,
438
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
439
+ use_model_defaults=use_model_defaults,
440
+ custom_generate=custom_generate,
441
+ **kwargs,
442
+ )
443
+
444
+ return_dict_in_generate = not isinstance(output, torch.Tensor)
445
+
446
+ if return_dict_in_generate:
447
+ output_sequences = output.sequences
448
+ else:
449
+ output_sequences = output
450
+
451
+ # Reshape from 2D (bsz * channels, seq_len) to 3D (bsz, seq_len, channels)
452
+ num_channels = self.config.decoder_config.num_channels
453
+ bsz = output_sequences.shape[0] // num_channels
454
+ output_sequences = output_sequences.reshape(bsz, num_channels, -1).transpose(1, 2)
455
+
456
+ # Apply delay mask
457
+ output_sequences = self.apply_delay_mask(output_sequences, self.config.pad_token_id, delay_mask)
458
+
459
+ if return_dict_in_generate:
460
+ output.sequences = output_sequences
461
+ else:
462
+ output = output_sequences
463
+
464
+ return output
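Taken together, `generate` above expects processor-prepared `decoder_input_ids`/`decoder_attention_mask`, doubles the batch when `guidance_scale != 1`, and reshapes the flat 2D sequences back to `(batch, seq_len, num_channels)` before re-applying the delay mask. A rough, hedged end-to-end sketch (the checkpoint name, the `[S1]` text format, and the processor keyword arguments are assumptions, not taken from this diff):

import torch
from transformers import DiaForConditionalGeneration, DiaProcessor

checkpoint = "nari-labs/Dia-1.6B"  # assumed checkpoint name
processor = DiaProcessor.from_pretrained(checkpoint)
model = DiaForConditionalGeneration.from_pretrained(checkpoint)

inputs = processor(text=["[S1] Hello there."], padding=True, return_tensors="pt")
with torch.no_grad():
    # guidance_scale != 1 activates the CFG path (`_uses_cfg`) handled by the mixin above.
    audio_codes = model.generate(**inputs, max_new_tokens=256, guidance_scale=3.0)
print(audio_codes.shape)  # (batch, seq_len, num_channels) after the final reshape in `generate`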
phivenv/Lib/site-packages/transformers/models/dia/modeling_dia.py ADDED
@@ -0,0 +1,958 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/dia/modular_dia.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_dia.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # coding=utf-8
8
+ # Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+ from typing import Callable, Optional, Union
23
+
24
+ import torch
25
+ from torch import nn
26
+
27
+ from ...activations import ACT2FN
28
+ from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
29
+ from ...integrations import use_kernel_forward_from_hub
30
+ from ...masking_utils import create_causal_mask
31
+ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
32
+ from ...modeling_flash_attention_utils import FlashAttentionKwargs
33
+ from ...modeling_layers import GradientCheckpointingLayer
34
+ from ...modeling_outputs import (
35
+ BaseModelOutput,
36
+ BaseModelOutputWithPastAndCrossAttentions,
37
+ Seq2SeqLMOutput,
38
+ Seq2SeqModelOutput,
39
+ )
40
+ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
41
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
42
+ from ...processing_utils import Unpack
43
+ from ...utils import (
44
+ TransformersKwargs,
45
+ auto_docstring,
46
+ can_return_tuple,
47
+ is_torch_flex_attn_available,
48
+ is_torchdynamo_compiling,
49
+ logging,
50
+ )
51
+ from ...utils.deprecation import deprecate_kwarg
52
+ from .configuration_dia import DiaConfig, DiaDecoderConfig, DiaEncoderConfig
53
+ from .generation_dia import DiaGenerationMixin
54
+
55
+
56
+ if is_torch_flex_attn_available():
57
+ from ...integrations.flex_attention import make_flex_block_causal_mask
58
+
59
+
60
+ logger = logging.get_logger(__name__)
61
+
62
+
63
+ @auto_docstring
64
+ class DiaPreTrainedModel(PreTrainedModel):
65
+ config: DiaConfig
66
+ base_model_prefix = "model"
67
+ supports_gradient_checkpointing = True
68
+ _supports_flash_attn = True
69
+ _supports_sdpa = True
70
+ _supports_flex_attn = True
71
+ _can_compile_fullgraph = True
72
+ main_input_name = "input_ids"
73
+ _no_split_modules = ["DiaEncoderLayer", "DiaDecoderLayer"]
74
+
75
+
76
+ class DiaMultiChannelEmbedding(nn.Module):
77
+ """In order to efficiently compute the audio embedding from the 9 different channels,
78
+ we vectorize the embedding process by using a single embedding layer and an offset.
79
+ Example:
80
+ - num_embeds = 4
81
+ - vocab_size = 8
82
+ - num_channels = 3
83
+ We would have offsets = [0, 8, 16]
84
+ If audio_codes = [0, 1, 2, 3], [1, 3, 4, 7], [5, 6, 7, 8],
85
+ then tokens = audio_codes + offsets
86
+ = [0, 1, 2, 3, 9, 11, 12, 15, 21, 22, 23, 24]
87
+ This allows us to use a single embedding layer for all channels.
88
+ """
89
+
90
+ def __init__(self, config: DiaDecoderConfig):
91
+ super().__init__()
92
+ self.embed = nn.Embedding(config.vocab_size * config.num_channels, config.hidden_size)
93
+ self.hidden_size = config.hidden_size
94
+ self.num_channels = config.num_channels
95
+ offsets = torch.arange(config.num_channels, dtype=torch.long) * config.vocab_size # (C,)
96
+ self.register_buffer("offsets", offsets, persistent=False)
97
+
98
+ def forward(self, audio_codes: torch.Tensor) -> torch.Tensor:
99
+ tokens = (audio_codes + self.offsets.to(audio_codes.device)).squeeze(1)
100
+ embeds = self.embed(tokens).view(tokens.shape[0], audio_codes.shape[1], -1, self.hidden_size)
101
+ return embeds.sum(dim=2)
102
+
103
+
104
+ class DiaMLP(nn.Module):
105
+ def __init__(self, config):
106
+ super().__init__()
107
+
108
+ self.config = config
109
+ self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
110
+ self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
111
+ self.activation_fn = ACT2FN[config.hidden_act]
112
+
113
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
114
+ up_states = self.gate_up_proj(hidden_states)
115
+
116
+ gate, up_states = up_states.chunk(2, dim=-1)
117
+ up_states = up_states * self.activation_fn(gate)
118
+
119
+ return self.down_proj(up_states)
120
+
121
+
122
+ @use_kernel_forward_from_hub("RMSNorm")
123
+ class DiaRMSNorm(nn.Module):
124
+ def __init__(self, hidden_size, eps=1e-6):
125
+ """
126
+ DiaRMSNorm is equivalent to T5LayerNorm
127
+ """
128
+ super().__init__()
129
+ self.weight = nn.Parameter(torch.ones(hidden_size))
130
+ self.variance_epsilon = eps
131
+
132
+ def forward(self, hidden_states):
133
+ input_dtype = hidden_states.dtype
134
+ hidden_states = hidden_states.to(torch.float32)
135
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
136
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
137
+ return self.weight * hidden_states.to(input_dtype)
138
+
139
+ def extra_repr(self):
140
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
141
+
142
+
143
+ class DiaRotaryEmbedding(nn.Module):
144
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
145
+
146
+ def __init__(self, config: DiaConfig, device=None):
147
+ super().__init__()
148
+ # BC: "rope_type" was originally "type"
149
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
150
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
151
+ else:
152
+ self.rope_type = "default"
153
+ self.max_seq_len_cached = config.max_position_embeddings
154
+ self.original_max_seq_len = config.max_position_embeddings
155
+
156
+ self.config = config
157
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
158
+
159
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
160
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
161
+ self.original_inv_freq = self.inv_freq
162
+
163
+ @torch.no_grad()
164
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
165
+ def forward(self, x, position_ids):
166
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
167
+ position_ids_expanded = position_ids[:, None, :].float()
168
+
169
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
170
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
171
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
172
+ emb = torch.cat((freqs, freqs), dim=-1)
173
+ cos = emb.cos() * self.attention_scaling
174
+ sin = emb.sin() * self.attention_scaling
175
+
176
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
177
+
178
+
179
+ def rotate_half(x):
180
+ """Rotates half the hidden dims of the input."""
181
+ x1 = x[..., : x.shape[-1] // 2]
182
+ x2 = x[..., x.shape[-1] // 2 :]
183
+ return torch.cat((-x2, x1), dim=-1)
184
+
185
+
186
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
187
+ """Applies Rotary Position Embedding to the query and key tensors.
188
+
189
+ Args:
190
+ q (`torch.Tensor`): The query tensor.
191
+ k (`torch.Tensor`): The key tensor.
192
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
193
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
194
+ position_ids (`torch.Tensor`, *optional*):
195
+ Deprecated and unused.
196
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
197
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
198
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
199
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
200
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
201
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
202
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
203
+ Returns:
204
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
205
+ """
206
+ cos = cos.unsqueeze(unsqueeze_dim)
207
+ sin = sin.unsqueeze(unsqueeze_dim)
208
+ q_embed = (q * cos) + (rotate_half(q) * sin)
209
+ k_embed = (k * cos) + (rotate_half(k) * sin)
210
+ return q_embed, k_embed
211
+
212
+
213
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
214
+ """
215
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
216
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
217
+ """
218
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
219
+ if n_rep == 1:
220
+ return hidden_states
221
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
222
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
223
+
224
+
225
+ def eager_attention_forward(
226
+ module: nn.Module,
227
+ query: torch.Tensor,
228
+ key: torch.Tensor,
229
+ value: torch.Tensor,
230
+ attention_mask: Optional[torch.Tensor],
231
+ scaling: float,
232
+ dropout: float = 0.0,
233
+ **kwargs: Unpack[TransformersKwargs],
234
+ ):
235
+ key_states = repeat_kv(key, module.num_key_value_groups)
236
+ value_states = repeat_kv(value, module.num_key_value_groups)
237
+
238
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
239
+ if attention_mask is not None:
240
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
241
+ attn_weights = attn_weights + causal_mask
242
+
243
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
244
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
245
+ attn_output = torch.matmul(attn_weights, value_states)
246
+ attn_output = attn_output.transpose(1, 2).contiguous()
247
+
248
+ return attn_output, attn_weights
249
+
250
+
251
+ class DiaSelfAttention(nn.Module):
252
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
253
+
254
+ def __init__(self, config: Union[DiaEncoderConfig, DiaDecoderConfig], layer_idx: int, is_causal: bool = False):
255
+ super().__init__()
256
+ self.config = config
257
+ self.layer_idx = layer_idx
258
+ self.hidden_size = config.hidden_size
259
+ self.num_heads = self.config.num_attention_heads
260
+ self.num_key_value_heads = self.config.num_key_value_heads or self.num_heads
261
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
262
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // self.num_heads)
263
+ self.scaling = 1
264
+ self.attention_dropout = 0.0
265
+ self.is_causal = is_causal
266
+
267
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
268
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
269
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
270
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
271
+
272
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
273
+ def forward(
274
+ self,
275
+ hidden_states: torch.Tensor,
276
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
277
+ attention_mask: Optional[torch.Tensor],
278
+ past_key_values: Optional[Cache] = None,
279
+ cache_position: Optional[torch.LongTensor] = None,
280
+ **kwargs: Unpack[TransformersKwargs],
281
+ ) -> tuple[torch.Tensor, torch.Tensor]:
282
+ input_shape = hidden_states.shape[:-1]
283
+ hidden_shape = (*input_shape, -1, self.head_dim)
284
+
285
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
286
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
287
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
288
+
289
+ cos, sin = position_embeddings
290
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
291
+
292
+ if past_key_values is not None:
293
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
294
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
295
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
296
+
297
+ attention_interface: Callable = eager_attention_forward
298
+ if self.config._attn_implementation != "eager":
299
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
300
+
301
+ attn_output, attn_weights = attention_interface(
302
+ self,
303
+ query_states,
304
+ key_states,
305
+ value_states,
306
+ attention_mask,
307
+ dropout=0.0 if not self.training else self.attention_dropout,
308
+ scaling=self.scaling,
309
+ **kwargs,
310
+ )
311
+
312
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
313
+ attn_output = self.o_proj(attn_output)
314
+ return attn_output, attn_weights
315
+
316
+
317
+ class DiaCrossAttention(nn.Module):
318
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
319
+
320
+ def __init__(self, config: DiaDecoderConfig, layer_idx: int):
321
+ super().__init__()
322
+ self.config = config
323
+ self.layer_idx = layer_idx
324
+ self.hidden_size = config.hidden_size
325
+ self.cross_hidden_size = config.cross_hidden_size
326
+ self.num_heads = self.config.cross_num_attention_heads
327
+ self.num_key_value_heads = self.config.cross_num_key_value_heads
328
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
329
+ self.head_dim = config.cross_head_dim
330
+ self.scaling = 1
331
+ self.attention_dropout = 0.0
332
+ self.is_causal = False
333
+
334
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
335
+ self.k_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
336
+ self.v_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
337
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
338
+
339
+ def forward(
340
+ self,
341
+ hidden_states: torch.Tensor,
342
+ cross_attention_states: torch.Tensor,
343
+ attention_mask: Optional[torch.Tensor] = None,
344
+ past_key_values: Optional[EncoderDecoderCache] = None,
345
+ **kwargs: Unpack[FlashAttentionKwargs],
346
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
347
+ input_shape = hidden_states.shape[:-1]
348
+ hidden_shape = (*input_shape, -1, self.head_dim)
349
+ cross_shape = (*cross_attention_states.shape[:-1], -1, self.head_dim)
350
+
351
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
352
+
353
+ is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False
354
+ if past_key_values is not None and is_updated:
355
+ # reuse k,v, cross_attentions
356
+ key_states = past_key_values.cross_attention_cache.layers[self.layer_idx].keys
357
+ value_states = past_key_values.cross_attention_cache.layers[self.layer_idx].values
358
+ else:
359
+ key_states = self.k_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
360
+ value_states = self.v_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
361
+
362
+ if past_key_values is not None:
363
+ # save all states to the cache
364
+ key_states, value_states = past_key_values.cross_attention_cache.update(
365
+ key_states,
366
+ value_states,
367
+ self.layer_idx,
368
+ )
369
+ # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
370
+ past_key_values.is_updated[self.layer_idx] = True
371
+
372
+ attention_interface: Callable = eager_attention_forward
373
+ if self.config._attn_implementation != "eager":
374
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
375
+
376
+ attn_output, attn_weights = attention_interface(
377
+ self,
378
+ query_states,
379
+ key_states,
380
+ value_states,
381
+ attention_mask,
382
+ scaling=self.scaling,
383
+ **kwargs,
384
+ )
385
+
386
+ attn_output = attn_output.reshape((*input_shape, -1)).contiguous()
387
+ attn_output = self.o_proj(attn_output)
388
+ return attn_output, attn_weights
389
+
390
+
391
+ class DiaEncoderLayer(GradientCheckpointingLayer):
392
+ def __init__(self, config: DiaEncoderConfig, layer_idx: int):
393
+ super().__init__()
394
+ self.pre_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
395
+ self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=False)
396
+ self.post_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
397
+ self.mlp = DiaMLP(config)
398
+
399
+ def forward(
400
+ self,
401
+ hidden_states: torch.Tensor,
402
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
403
+ attention_mask: Optional[torch.Tensor] = None,
404
+ **kwargs: Unpack[FlashAttentionKwargs],
405
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
406
+ residual = hidden_states
407
+ normed_states = self.pre_sa_norm(hidden_states)
408
+ self_attn_output, self_attn_weights = self.self_attention(
409
+ normed_states,
410
+ position_embeddings=position_embeddings,
411
+ attention_mask=attention_mask,
412
+ **kwargs,
413
+ )
414
+ hidden_states = residual + self_attn_output
415
+
416
+ residual = hidden_states
417
+ normed_states = self.post_sa_norm(hidden_states)
418
+ mlp_out = self.mlp(normed_states)
419
+ hidden_states = residual + mlp_out
420
+
421
+ return hidden_states, self_attn_weights
422
+
423
+
424
+ class DiaEncoder(DiaPreTrainedModel):
425
+ def __init__(self, config: DiaEncoderConfig):
426
+ super().__init__(config)
427
+ self.config = config
428
+
429
+ self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
430
+ self.layers = nn.ModuleList(
431
+ [DiaEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
432
+ )
433
+ self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
434
+ self.rotary_embeddings = DiaRotaryEmbedding(config)
435
+
436
+ @auto_docstring
437
+ @can_return_tuple
438
+ def forward(
439
+ self,
440
+ input_ids: torch.Tensor,
441
+ attention_mask: Optional[torch.Tensor] = None,
442
+ output_attentions: Optional[bool] = False,
443
+ output_hidden_states: Optional[bool] = False,
444
+ **kwargs: Unpack[FlashAttentionKwargs],
445
+ ) -> Union[BaseModelOutput, tuple]:
446
+ hidden_states = self.embedding(input_ids)
447
+
448
+ # RoPE
449
+ # Note: We expect right padding and hence always generate
450
+ # the position ids on the fly to reduce preparation overhead
451
+ position_ids = torch.arange(input_ids.shape[-1], device=input_ids.device)[None, :]
452
+ position_embeddings = self.rotary_embeddings(hidden_states, position_ids)
453
+
454
+ attention_mask = self._update_full_mask(
455
+ attention_mask,
456
+ hidden_states,
457
+ )
458
+
459
+ encoder_states = () if output_hidden_states else None
460
+ all_attentions = () if output_attentions else None
461
+
462
+ for encoder_layer in self.layers:
463
+ if output_hidden_states:
464
+ encoder_states = encoder_states + (hidden_states,)
465
+
466
+ layer_outputs = encoder_layer(
467
+ hidden_states,
468
+ position_embeddings=position_embeddings,
469
+ attention_mask=attention_mask,
470
+ **kwargs,
471
+ )
472
+ hidden_states = layer_outputs[0]
473
+
474
+ if output_attentions:
475
+ all_attentions = all_attentions + (layer_outputs[1],)
476
+
477
+ hidden_states = self.norm(hidden_states)
478
+
479
+ if output_hidden_states:
480
+ encoder_states += (hidden_states,)
481
+
482
+ return BaseModelOutput(
483
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
484
+ )
485
+
486
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
487
+ def _update_full_mask(
488
+ self,
489
+ attention_mask: Union[torch.Tensor, None],
490
+ inputs_embeds: torch.Tensor,
491
+ ):
492
+ if attention_mask is not None:
493
+ if self.config._attn_implementation == "flash_attention_2":
494
+ attention_mask = attention_mask if 0 in attention_mask else None
495
+ elif self.config._attn_implementation == "sdpa":
496
+ # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
497
+ # the manual implementation that requires a 4D causal mask in all cases.
498
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
499
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
500
+ elif self.config._attn_implementation == "flex_attention":
501
+ if isinstance(attention_mask, torch.Tensor):
502
+ attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
503
+ else:
504
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
505
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
506
+
507
+ return attention_mask
508
+
509
+
510
+ class DiaDecoderLayer(GradientCheckpointingLayer):
511
+ def __init__(self, config: DiaDecoderConfig, layer_idx: int):
512
+ super().__init__()
513
+ self.embed_dim = config.hidden_size
514
+ self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=True)
515
+ self.cross_attention = DiaCrossAttention(config, layer_idx)
516
+ self.pre_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
517
+ self.pre_ca_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
518
+ self.pre_mlp_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
519
+ self.mlp = DiaMLP(config)
520
+
521
+ def forward(
522
+ self,
523
+ hidden_states: torch.Tensor,
524
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
525
+ attention_mask: Optional[torch.Tensor] = None,
526
+ encoder_hidden_states: Optional[torch.Tensor] = None,
527
+ encoder_attention_mask: Optional[torch.Tensor] = None,
528
+ past_key_values: Optional[EncoderDecoderCache] = None,
529
+ cache_position: Optional[torch.LongTensor] = None,
530
+ **kwargs,
531
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
532
+ self_attn_cache = past_key_values
533
+ if isinstance(self_attn_cache, EncoderDecoderCache):
534
+ self_attn_cache = self_attn_cache.self_attention_cache
535
+
536
+ residual = hidden_states
537
+ normed_states = self.pre_sa_norm(hidden_states)
538
+ self_attn_output, self_attn_weights = self.self_attention(
539
+ normed_states,
540
+ position_embeddings,
541
+ attention_mask,
542
+ # Needs to be an arg in order to function properly
543
+ # on inplace operations to be carried (e.g. compile)
544
+ self_attn_cache,
545
+ cache_position=cache_position,
546
+ **kwargs,
547
+ )
548
+ hidden_states = residual + self_attn_output
549
+
550
+ residual = hidden_states
551
+ normed_states = self.pre_ca_norm(hidden_states)
552
+ cross_states, cross_attn_weights = self.cross_attention(
553
+ normed_states,
554
+ encoder_hidden_states,
555
+ attention_mask=encoder_attention_mask,
556
+ past_key_values=past_key_values,
557
+ **kwargs,
558
+ )
559
+ hidden_states = residual + cross_states
560
+
561
+ residual = hidden_states
562
+ normed_states = self.pre_mlp_norm(hidden_states)
563
+ mlp_out = self.mlp(normed_states)
564
+ hidden_states = residual + mlp_out
565
+
566
+ return hidden_states, self_attn_weights, cross_attn_weights
567
+
568
+
569
+ class DiaDecoder(DiaPreTrainedModel):
570
+ """Transformer Decoder Stack using DenseGeneral."""
571
+
572
+ def __init__(self, config: DiaDecoderConfig):
573
+ super().__init__(config)
574
+ self.num_channels = config.num_channels
575
+ self.vocab_size = config.vocab_size
576
+ self.embeddings = DiaMultiChannelEmbedding(config)
577
+ self.rotary_embeddings = DiaRotaryEmbedding(config)
578
+ self.layers = nn.ModuleList(
579
+ [DiaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
580
+ )
581
+ self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
582
+
583
+ @auto_docstring
584
+ @can_return_tuple
585
+ def forward(
586
+ self,
587
+ input_ids: torch.Tensor,
588
+ position_ids: Optional[torch.LongTensor] = None,
589
+ attention_mask: Optional[torch.Tensor] = None,
590
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
591
+ encoder_attention_mask: Optional[torch.LongTensor] = None,
592
+ past_key_values: Optional[EncoderDecoderCache] = None,
593
+ output_attentions: Optional[bool] = False,
594
+ output_hidden_states: Optional[bool] = False,
595
+ cache_position: Optional[torch.LongTensor] = None,
596
+ **kwargs,
597
+ ) -> Union[BaseModelOutputWithPastAndCrossAttentions, tuple]:
598
+ r"""
599
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`):
600
+ The original `decoder_input_ids` in 3D shape to facilitate more efficient computations.
601
+
602
+ [What are input IDs?](../glossary#input-ids)
603
+ """
604
+
605
+ batch_size, seq_length = input_ids.size()[:-1]
606
+ past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
607
+ if cache_position is None:
608
+ cache_position = torch.arange(
609
+ past_key_values_length, past_key_values_length + seq_length, device=input_ids.device
610
+ )
611
+ if position_ids is None:
612
+ position_ids = cache_position[None, :]
613
+
614
+ # RoPE
615
+ hidden_states = self.embeddings(input_ids)
616
+ position_embeddings = self.rotary_embeddings(hidden_states, position_ids)
617
+
618
+ if attention_mask is None and not is_torchdynamo_compiling():
619
+ # required mask seq length can be calculated via length of past cache
620
+ mask_seq_length = past_key_values_length + seq_length
621
+ attention_mask = torch.ones(batch_size, mask_seq_length, device=input_ids.device)
622
+
623
+ attention_mask = create_causal_mask(
624
+ config=self.config,
625
+ input_embeds=hidden_states,
626
+ attention_mask=attention_mask,
627
+ cache_position=cache_position,
628
+ past_key_values=past_key_values,
629
+ position_ids=position_ids,
630
+ )
631
+ encoder_attention_mask = self._update_cross_attn_mask(
632
+ encoder_hidden_states,
633
+ encoder_attention_mask,
634
+ hidden_states.shape[:2],
635
+ hidden_states,
636
+ )
637
+
638
+ all_hidden_states = () if output_hidden_states else None
639
+ all_self_attns = () if output_attentions else None
640
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
641
+
642
+ for layer in self.layers:
643
+ if output_hidden_states:
644
+ all_hidden_states += (hidden_states,)
645
+
646
+ layer_outputs = layer(
647
+ hidden_states,
648
+ position_embeddings,
649
+ attention_mask,
650
+ encoder_hidden_states,
651
+ encoder_attention_mask=encoder_attention_mask,
652
+ past_key_values=past_key_values,
653
+ cache_position=cache_position,
654
+ **kwargs,
655
+ )
656
+ hidden_states = layer_outputs[0]
657
+
658
+ if output_attentions:
659
+ all_self_attns = all_self_attns + (layer_outputs[1],)
660
+
661
+ if encoder_hidden_states is not None:
662
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
663
+
664
+ hidden_states = self.norm(hidden_states)
665
+
666
+ if output_hidden_states:
667
+ all_hidden_states += (hidden_states,)
668
+
669
+ return BaseModelOutputWithPastAndCrossAttentions(
670
+ last_hidden_state=hidden_states,
671
+ past_key_values=past_key_values,
672
+ hidden_states=all_hidden_states,
673
+ attentions=all_self_attns,
674
+ cross_attentions=all_cross_attentions,
675
+ )
676
+
677
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask
678
+ def _update_cross_attn_mask(
679
+ self,
680
+ encoder_hidden_states: Union[torch.Tensor, None],
681
+ encoder_attention_mask: Union[torch.Tensor, None],
682
+ input_shape: torch.Size,
683
+ inputs_embeds: torch.Tensor,
684
+ ):
685
+ # expand encoder attention mask
686
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
687
+ if self.config._attn_implementation == "flash_attention_2":
688
+ encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
689
+ elif self.config._attn_implementation == "sdpa":
690
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
691
+ # the manual implementation that requires a 4D causal mask in all cases.
692
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
693
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
694
+ encoder_attention_mask,
695
+ inputs_embeds.dtype,
696
+ tgt_len=input_shape[-1],
697
+ )
698
+ elif self.config._attn_implementation == "flex_attention":
699
+ if isinstance(encoder_attention_mask, torch.Tensor):
700
+ encoder_attention_mask = make_flex_block_causal_mask(
701
+ encoder_attention_mask,
702
+ query_length=input_shape[-1],
703
+ is_causal=False,
704
+ )
705
+ else:
706
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
707
+ encoder_attention_mask = _prepare_4d_attention_mask(
708
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
709
+ )
710
+
711
+ return encoder_attention_mask
712
+
713
+
714
+ @auto_docstring(
715
+ custom_intro="""
716
+ The bare Dia model outputting raw hidden-states without any specific head on top.
717
+ """
718
+ )
719
+ class DiaModel(DiaPreTrainedModel):
720
+ def __init__(self, config: DiaConfig):
721
+ super().__init__(config)
722
+ self.config = config
723
+ self.encoder = DiaEncoder(config.encoder_config)
724
+ self.decoder = DiaDecoder(config.decoder_config)
725
+ self.post_init()
726
+
727
+ def get_encoder(self):
728
+ return self.encoder
729
+
730
+ @auto_docstring
731
+ @can_return_tuple
732
+ def forward(
733
+ self,
734
+ input_ids: Optional[torch.LongTensor] = None,
735
+ attention_mask: Optional[torch.LongTensor] = None,
736
+ decoder_input_ids: Optional[torch.LongTensor] = None,
737
+ decoder_position_ids: Optional[torch.LongTensor] = None,
738
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
739
+ encoder_outputs: Optional[Union[BaseModelOutput, tuple]] = None,
740
+ past_key_values: Optional[EncoderDecoderCache] = None,
741
+ use_cache: Optional[bool] = None,
742
+ output_attentions: Optional[bool] = None,
743
+ output_hidden_states: Optional[bool] = None,
744
+ cache_position: Optional[torch.LongTensor] = None,
745
+ **kwargs,
746
+ ) -> Union[tuple, Seq2SeqModelOutput]:
747
+ r"""
748
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)
749
+ or (batch_size, target_sequence_length, num_codebooks)`, *optional*):
750
+ 1. (batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where
751
+ the audio input codebooks are flattened into the batch dimension. This also aligns with the flat-
752
+ tened audio logits which are used to calculate the loss.
753
+
754
+ 2. (batch_size, sequence_length, num_codebooks): corresponds to the internally used shape of
755
+ Dia to calculate embeddings and subsequent steps more efficiently.
756
+
757
+ If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape
758
+ `(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See
759
+ [`DiaProcessor.__call__`] for more details.
760
+
761
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
762
+ decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
763
+ Indices of positions of each input sequence tokens in the position embeddings.
764
+ Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`.
765
+
766
+ [What are position IDs?](../glossary#position-ids)
767
+ """
768
+
769
+ if input_ids is None and encoder_outputs is None:
770
+ raise ValueError(
771
+ "You should either provide text ids or the cached text encodings. Neither has been found."
772
+ )
773
+
774
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
775
+ output_hidden_states = (
776
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
777
+ )
778
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
779
+
780
+ if self.is_gradient_checkpointing and self.training:
781
+ if use_cache:
782
+ logger.warning_once(
783
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
784
+ )
785
+ use_cache = False
786
+
787
+ if use_cache and past_key_values is None:
788
+ past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
789
+
790
+ if encoder_outputs is None:
791
+ encoder_outputs = self.encoder(
792
+ input_ids=input_ids,
793
+ attention_mask=attention_mask,
794
+ output_attentions=output_attentions,
795
+ output_hidden_states=output_hidden_states,
796
+ **kwargs,
797
+ )
798
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput
799
+ elif not isinstance(encoder_outputs, BaseModelOutput):
800
+ encoder_outputs = BaseModelOutput(
801
+ last_hidden_state=encoder_outputs[0],
802
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
803
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
804
+ )
805
+
806
+ # On default we initialize the decoder with bos tokens if nothing has been provided
807
+ bsz, seq_len, channels = (encoder_outputs[0].shape[0], -1, self.config.decoder_config.num_channels)
808
+ if decoder_input_ids is None:
809
+ decoder_input_ids = torch.full(
810
+ size=(bsz, 1, channels), fill_value=self.config.bos_token_id, device=self.device
811
+ )
812
+ # Ensure 3D
813
+ if decoder_input_ids.ndim == 2:
814
+ decoder_input_ids = decoder_input_ids.reshape(bsz, channels, seq_len).transpose(1, 2)
+
+        decoder_outputs = self.decoder(
+            input_ids=decoder_input_ids,
+            position_ids=decoder_position_ids,
+            attention_mask=decoder_attention_mask,
+            encoder_hidden_states=encoder_outputs[0],
+            encoder_attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+        return Seq2SeqModelOutput(
+            last_hidden_state=decoder_outputs.last_hidden_state,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_outputs[0],
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+        )
+
+
+@auto_docstring(
+    custom_intro="""
+    The Dia model consisting of a (byte) text encoder and audio decoder with a prediction head on top.
+    """
+)
+class DiaForConditionalGeneration(DiaPreTrainedModel, DiaGenerationMixin):
+    base_model_prefix = "model"
+
+    def __init__(self, config: DiaConfig):
+        super().__init__(config)
+        self.config = config
+        self.model = DiaModel(config)
+
+        self.num_channels = config.decoder_config.num_channels
+        self.vocab_size = config.decoder_config.vocab_size
+        self.logits_dense = nn.Linear(
+            config.decoder_config.hidden_size, (self.num_channels * self.vocab_size), bias=False
+        )
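+        # single shared projection producing `num_channels * vocab_size` logits per step,
+        # split into per-codebook logits in `forward`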
+        self.loss_type = "ForMaskedLM"
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_encoder(self):
+        return self.model.get_encoder()
+
+    def get_decoder(self):
+        return self.model.get_decoder()
+
+    @auto_docstring
+    @can_return_tuple
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_position_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        encoder_outputs: Optional[Union[BaseModelOutput, tuple]] = None,
+        past_key_values: Optional[EncoderDecoderCache] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        labels: Optional[torch.LongTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Union[tuple, Seq2SeqLMOutput]:
+        r"""
+        decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)
+            or (batch_size, target_sequence_length, num_codebooks)`, *optional*):
+            1. (batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where
+            the audio input codebooks are flattened into the batch dimension. This also aligns with the
+            flattened audio logits which are used to calculate the loss.
+
+            2. (batch_size, sequence_length, num_codebooks): corresponds to the shape used internally by
+            Dia to calculate embeddings and subsequent steps more efficiently.
+
+            If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape
+            `(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See
+            [`DiaProcessor.__call__`] for more details.
+
+            [What are decoder input IDs?](../glossary#decoder-input-ids)
+        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
+            Indices of positions of each input sequence token in the position embeddings.
+            Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`.
+
+            [What are position IDs?](../glossary#position-ids)
+        labels (`torch.LongTensor` of shape `(batch_size * num_codebooks,)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in
+            `[0, ..., config.decoder_config.vocab_size - 1]` or -100. Tokens with indices set to `-100`
+            are ignored (masked).
+        """
+
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            decoder_position_ids=decoder_position_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            encoder_outputs=encoder_outputs,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+        last_hidden_state = outputs[0]
+        batch_size = last_hidden_state.shape[0]
+        # 3D <-> 2D makes it necessary to prioritize the channel dim
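+        # shape flow: (batch_size, seq_len, num_channels * vocab_size)
+        #   -> view      (batch_size, seq_len, num_channels, vocab_size)
+        #   -> transpose (batch_size, num_channels, seq_len, vocab_size)
+        #   -> view      (batch_size * num_channels, seq_len, vocab_size)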
+        audio_logits = (
+            self.logits_dense(last_hidden_state)
+            .view((batch_size, -1, self.num_channels, self.vocab_size))
+            .transpose(1, 2)
+            .contiguous()
+            .view(batch_size * self.num_channels, -1, self.vocab_size)
+        )
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits=audio_logits, labels=labels, vocab_size=self.vocab_size, **kwargs)
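+            # the `ForMaskedLM` loss flattens logits to (-1, vocab_size), so labels are expected in the
+            # matching flattened (codebook-major) layout, with `-100` marking ignored positions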
+
+        return Seq2SeqLMOutput(
+            loss=loss,
+            logits=audio_logits,
+            past_key_values=outputs.past_key_values,
+            decoder_hidden_states=outputs.decoder_hidden_states,
+            decoder_attentions=outputs.decoder_attentions,
+            cross_attentions=outputs.cross_attentions,
+            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+            encoder_hidden_states=outputs.encoder_hidden_states,
+            encoder_attentions=outputs.encoder_attentions,
+        )
+
+
+__all__ = ["DiaModel", "DiaPreTrainedModel", "DiaForConditionalGeneration"]
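
The reshape conventions above are easy to misread, so here is a standalone sketch (not part of the diff) that mirrors them in plain PyTorch. All sizes below (batch of 2, 9 codebooks, 5 decoder steps, a 1028-token audio vocab, hidden size 16) are illustrative assumptions, not the shipped Dia defaults; the snippet only reproduces the shape bookkeeping of `DiaModel.forward` and `DiaForConditionalGeneration.forward`.

```python
# Illustrative shapes only -- these numbers are assumptions, not Dia's real config values.
import torch

bsz, channels, seq_len, vocab, hidden = 2, 9, 5, 1028, 16

# 1) decoder_input_ids: user-facing flattened 2D layout -> internal 3D layout (as in DiaModel.forward)
flat_ids = torch.randint(0, vocab, (bsz * channels, seq_len))
internal_ids = flat_ids.reshape(bsz, channels, -1).transpose(1, 2)  # (bsz, seq_len, channels)
assert internal_ids.shape == (bsz, seq_len, channels)

# 2) audio logits: decoder hidden states -> flattened per-codebook logits
#    (as in DiaForConditionalGeneration.forward)
last_hidden_state = torch.randn(bsz, seq_len, hidden)
logits_dense = torch.nn.Linear(hidden, channels * vocab, bias=False)
audio_logits = (
    logits_dense(last_hidden_state)
    .view(bsz, -1, channels, vocab)
    .transpose(1, 2)
    .contiguous()
    .view(bsz * channels, -1, vocab)
)
assert audio_logits.shape == (bsz * channels, seq_len, vocab)
```

The final assertion matches the docstring: audio logits come out flattened as `(batch_size * num_codebooks, target_sequence_length, vocab_size)`, aligned with the flattened labels consumed by the `ForMaskedLM` loss.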