Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes.
- .gitattributes +6 -0
- phivenv/Lib/site-packages/torch/lib/XNNPACK.lib +3 -0
- phivenv/Lib/site-packages/torch/lib/torch_cpu.lib +3 -0
- phivenv/Lib/site-packages/torch/lib/torch_python.dll +3 -0
- phivenv/Lib/site-packages/torch/lib/torch_python.lib +3 -0
- phivenv/Lib/site-packages/torch/lib/uv.dll +3 -0
- phivenv/Lib/site-packages/torch/linalg/__pycache__/__init__.cpython-39.pyc +3 -0
- phivenv/Lib/site-packages/transformers/models/d_fine/__init__.py +29 -0
- phivenv/Lib/site-packages/transformers/models/d_fine/__pycache__/__init__.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/d_fine/__pycache__/configuration_d_fine.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/d_fine/__pycache__/modeling_d_fine.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/d_fine/__pycache__/modular_d_fine.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/d_fine/configuration_d_fine.py +433 -0
- phivenv/Lib/site-packages/transformers/models/d_fine/modeling_d_fine.py +0 -0
- phivenv/Lib/site-packages/transformers/models/d_fine/modular_d_fine.py +1221 -0
- phivenv/Lib/site-packages/transformers/models/depth_pro/__init__.py +29 -0
- phivenv/Lib/site-packages/transformers/models/depth_pro/__pycache__/__init__.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/depth_pro/__pycache__/configuration_depth_pro.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/depth_pro/__pycache__/image_processing_depth_pro.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/depth_pro/__pycache__/image_processing_depth_pro_fast.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/depth_pro/__pycache__/modeling_depth_pro.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/depth_pro/configuration_depth_pro.py +205 -0
- phivenv/Lib/site-packages/transformers/models/depth_pro/image_processing_depth_pro.py +389 -0
- phivenv/Lib/site-packages/transformers/models/depth_pro/image_processing_depth_pro_fast.py +177 -0
- phivenv/Lib/site-packages/transformers/models/depth_pro/modeling_depth_pro.py +1132 -0
- phivenv/Lib/site-packages/transformers/models/detr/__init__.py +31 -0
- phivenv/Lib/site-packages/transformers/models/detr/__pycache__/__init__.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/detr/__pycache__/configuration_detr.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/detr/__pycache__/feature_extraction_detr.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/detr/__pycache__/image_processing_detr.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/detr/__pycache__/image_processing_detr_fast.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/detr/__pycache__/modeling_detr.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/detr/configuration_detr.py +297 -0
- phivenv/Lib/site-packages/transformers/models/detr/feature_extraction_detr.py +48 -0
- phivenv/Lib/site-packages/transformers/models/detr/image_processing_detr.py +2049 -0
- phivenv/Lib/site-packages/transformers/models/detr/image_processing_detr_fast.py +1291 -0
- phivenv/Lib/site-packages/transformers/models/detr/modeling_detr.py +1693 -0
- phivenv/Lib/site-packages/transformers/models/dia/__init__.py +31 -0
- phivenv/Lib/site-packages/transformers/models/dia/__pycache__/__init__.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/dia/__pycache__/configuration_dia.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/dia/__pycache__/feature_extraction_dia.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/dia/__pycache__/generation_dia.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/dia/__pycache__/modeling_dia.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/dia/__pycache__/modular_dia.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/dia/__pycache__/processing_dia.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/dia/__pycache__/tokenization_dia.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/transformers/models/dia/configuration_dia.py +376 -0
- phivenv/Lib/site-packages/transformers/models/dia/feature_extraction_dia.py +183 -0
- phivenv/Lib/site-packages/transformers/models/dia/generation_dia.py +464 -0
- phivenv/Lib/site-packages/transformers/models/dia/modeling_dia.py +958 -0
.gitattributes CHANGED
@@ -122,3 +122,9 @@ phivenv/Lib/site-packages/torch/lib/libprotoc.lib filter=lfs diff=lfs merge=lfs
 phivenv/Lib/site-packages/torch/lib/pthreadpool.lib filter=lfs diff=lfs merge=lfs -text
 phivenv/Lib/site-packages/torch/lib/microkernels-prod.lib filter=lfs diff=lfs merge=lfs -text
 phivenv/Lib/site-packages/torch/lib/sleef.lib filter=lfs diff=lfs merge=lfs -text
+phivenv/Lib/site-packages/torch/lib/torch_cpu.lib filter=lfs diff=lfs merge=lfs -text
+phivenv/Lib/site-packages/torch/lib/torch_python.dll filter=lfs diff=lfs merge=lfs -text
+phivenv/Lib/site-packages/torch/lib/torch_python.lib filter=lfs diff=lfs merge=lfs -text
+phivenv/Lib/site-packages/torch/lib/uv.dll filter=lfs diff=lfs merge=lfs -text
+phivenv/Lib/site-packages/torch/lib/XNNPACK.lib filter=lfs diff=lfs merge=lfs -text
+phivenv/Lib/site-packages/torch/linalg/__pycache__/__init__.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
phivenv/Lib/site-packages/torch/lib/XNNPACK.lib ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3bf5c98f694f4587f5a191739ea8dd565a0696828448828e7491b9c8ca5d6fe2
+size 14049460

phivenv/Lib/site-packages/torch/lib/torch_cpu.lib ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:08b81393191ac47ebf63e92aad8b65ece890d86dd51eb1e7294f1be3e496f3d7
+size 29046564

phivenv/Lib/site-packages/torch/lib/torch_python.dll ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2cd2897d163f1029341f3ae67fcccd5f4f4fe7b4d62ecae8ca767128e3140f73
+size 16310272

phivenv/Lib/site-packages/torch/lib/torch_python.lib ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8910178ebb14175b5c8b9ccd27b38b2360297399273b6dd3312bcc733b779529
+size 287836

phivenv/Lib/site-packages/torch/lib/uv.dll ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fa569e682fc5fb7a8eb94c6829af9f30a569748dbbc6bce39735d48bc960bcf8
+size 195072

phivenv/Lib/site-packages/torch/linalg/__pycache__/__init__.cpython-39.pyc ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9aa23740f58a645a003b0df576478da22f00c55c5001053fba842208036eb483
+size 113386
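Each of the six files above is stored as a Git LFS pointer stub rather than the binary itself: three text lines giving the spec version, the SHA-256 of the stored object, and its size in bytes. The helper below is a minimal, hypothetical sketch of reading those fields back out of a checkout (it is not part of this repository or of git-lfs, and only works while the pointer has not yet been replaced by the real file):

```python
# Minimal sketch: parse the three-line Git LFS pointer format (version / oid / size).
def parse_lfs_pointer(path):
    fields = {}
    with open(path, "r", encoding="utf-8") as handle:
        for line in handle:
            if line.strip():
                key, _, value = line.strip().partition(" ")
                fields[key] = value
    oid = fields.get("oid", "")
    return {
        "version": fields.get("version"),                  # LFS spec URL
        "sha256": oid.split(":", 1)[-1] if oid else None,  # object hash
        "size_bytes": int(fields.get("size", 0)),          # size of the real file
    }

# e.g. parse_lfs_pointer("phivenv/Lib/site-packages/torch/lib/uv.dll")
# -> {"version": "https://git-lfs.github.com/spec/v1",
#     "sha256": "fa569e68...", "size_bytes": 195072}
```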
phivenv/Lib/site-packages/transformers/models/d_fine/__init__.py ADDED
@@ -0,0 +1,29 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_d_fine import *
+    from .modeling_d_fine import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
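The `__init__.py` above routes the package through `_LazyModule`: the star imports under `TYPE_CHECKING` exist only for type checkers, while at runtime the submodules are loaded on first attribute access. A small sketch of the effect, assuming a regular `transformers` install that ships the `d_fine` package:

```python
import transformers.models.d_fine as d_fine

# The package object is a _LazyModule; configuration_d_fine is imported only
# when DFineConfig is first looked up, not at package import time.
config_cls = d_fine.DFineConfig
print(config_cls.__module__)  # transformers.models.d_fine.configuration_d_fine
```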
phivenv/Lib/site-packages/transformers/models/d_fine/__pycache__/__init__.cpython-39.pyc ADDED
    Binary file (528 Bytes).

phivenv/Lib/site-packages/transformers/models/d_fine/__pycache__/configuration_d_fine.cpython-39.pyc ADDED
    Binary file (17.5 kB).

phivenv/Lib/site-packages/transformers/models/d_fine/__pycache__/modeling_d_fine.cpython-39.pyc ADDED
    Binary file (72.6 kB).

phivenv/Lib/site-packages/transformers/models/d_fine/__pycache__/modular_d_fine.cpython-39.pyc ADDED
    Binary file (43.9 kB).
phivenv/Lib/site-packages/transformers/models/d_fine/configuration_d_fine.py ADDED
@@ -0,0 +1,433 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/d_fine/modular_d_fine.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_d_fine.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 Baidu Inc and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ...utils.backbone_utils import verify_backbone_config_arguments
+from ..auto import CONFIG_MAPPING
+
+
+logger = logging.get_logger(__name__)
+
+
+# TODO: Attribute map assignment logic should be fixed in modular
+# as well as super() call parsing because otherwise we cannot re-write args after initialization
+class DFineConfig(PretrainedConfig):
+    """
+    This is the configuration class to store the configuration of a [`DFineModel`]. It is used to instantiate a D-FINE
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of D-FINE-X-COCO "[ustc-community/dfine-xlarge-coco"](https://huggingface.co/ustc-community/dfine-xlarge-coco").
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        initializer_range (`float`, *optional*, defaults to 0.01):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        initializer_bias_prior_prob (`float`, *optional*):
+            The prior probability used by the bias initializer to initialize biases for `enc_score_head` and `class_embed`.
+            If `None`, `prior_prob` computed as `prior_prob = 1 / (num_labels + 1)` while initializing model weights.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the layer normalization layers.
+        batch_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the batch normalization layers.
+        backbone_config (`Dict`, *optional*, defaults to `RTDetrResNetConfig()`):
+            The configuration of the backbone model.
+        backbone (`str`, *optional*):
+            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
+        use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to use pretrained weights for the backbone.
+        use_timm_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
+            library.
+        freeze_backbone_batch_norms (`bool`, *optional*, defaults to `True`):
+            Whether to freeze the batch normalization layers in the backbone.
+        backbone_kwargs (`dict`, *optional*):
+            Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
+            e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
+        encoder_hidden_dim (`int`, *optional*, defaults to 256):
+            Dimension of the layers in hybrid encoder.
+        encoder_in_channels (`list`, *optional*, defaults to `[512, 1024, 2048]`):
+            Multi level features input for encoder.
+        feat_strides (`list[int]`, *optional*, defaults to `[8, 16, 32]`):
+            Strides used in each feature map.
+        encoder_layers (`int`, *optional*, defaults to 1):
+            Total of layers to be used by the encoder.
+        encoder_ffn_dim (`int`, *optional*, defaults to 1024):
+            Dimension of the "intermediate" (often named feed-forward) layer in decoder.
+        encoder_attention_heads (`int`, *optional*, defaults to 8):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The ratio for all dropout layers.
+        activation_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for activations inside the fully connected layer.
+        encode_proj_layers (`list[int]`, *optional*, defaults to `[2]`):
+            Indexes of the projected layers to be used in the encoder.
+        positional_encoding_temperature (`int`, *optional*, defaults to 10000):
+            The temperature parameter used to create the positional encodings.
+        encoder_activation_function (`str`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        activation_function (`str`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the general layer. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        eval_size (`tuple[int, int]`, *optional*):
+            Height and width used to computes the effective height and width of the position embeddings after taking
+            into account the stride.
+        normalize_before (`bool`, *optional*, defaults to `False`):
+            Determine whether to apply layer normalization in the transformer encoder layer before self-attention and
+            feed-forward modules.
+        hidden_expansion (`float`, *optional*, defaults to 1.0):
+            Expansion ratio to enlarge the dimension size of RepVGGBlock and CSPRepLayer.
+        d_model (`int`, *optional*, defaults to 256):
+            Dimension of the layers exclude hybrid encoder.
+        num_queries (`int`, *optional*, defaults to 300):
+            Number of object queries.
+        decoder_in_channels (`list`, *optional*, defaults to `[256, 256, 256]`):
+            Multi level features dimension for decoder
+        decoder_ffn_dim (`int`, *optional*, defaults to 1024):
+            Dimension of the "intermediate" (often named feed-forward) layer in decoder.
+        num_feature_levels (`int`, *optional*, defaults to 3):
+            The number of input feature levels.
+        decoder_n_points (`int`, *optional*, defaults to 4):
+            The number of sampled keys in each feature level for each attention head in the decoder.
+        decoder_layers (`int`, *optional*, defaults to 6):
+            Number of decoder layers.
+        decoder_attention_heads (`int`, *optional*, defaults to 8):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        decoder_activation_function (`str`, *optional*, defaults to `"relu"`):
+            The non-linear activation function (function or string) in the decoder. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        num_denoising (`int`, *optional*, defaults to 100):
+            The total number of denoising tasks or queries to be used for contrastive denoising.
+        label_noise_ratio (`float`, *optional*, defaults to 0.5):
+            The fraction of denoising labels to which random noise should be added.
+        box_noise_scale (`float`, *optional*, defaults to 1.0):
+            Scale or magnitude of noise to be added to the bounding boxes.
+        learn_initial_query (`bool`, *optional*, defaults to `False`):
+            Indicates whether the initial query embeddings for the decoder should be learned during training
+        anchor_image_size (`tuple[int, int]`, *optional*):
+            Height and width of the input image used during evaluation to generate the bounding box anchors. If None, automatic generate anchor is applied.
+        with_box_refine (`bool`, *optional*, defaults to `True`):
+            Whether to apply iterative bounding box refinement, where each decoder layer refines the bounding boxes
+            based on the predictions from the previous layer.
+        is_encoder_decoder (`bool`, *optional*, defaults to `True`):
+            Whether the architecture has an encoder decoder structure.
+        matcher_alpha (`float`, *optional*, defaults to 0.25):
+            Parameter alpha used by the Hungarian Matcher.
+        matcher_gamma (`float`, *optional*, defaults to 2.0):
+            Parameter gamma used by the Hungarian Matcher.
+        matcher_class_cost (`float`, *optional*, defaults to 2.0):
+            The relative weight of the class loss used by the Hungarian Matcher.
+        matcher_bbox_cost (`float`, *optional*, defaults to 5.0):
+            The relative weight of the bounding box loss used by the Hungarian Matcher.
+        matcher_giou_cost (`float`, *optional*, defaults to 2.0):
+            The relative weight of the giou loss of used by the Hungarian Matcher.
+        use_focal_loss (`bool`, *optional*, defaults to `True`):
+            Parameter informing if focal focal should be used.
+        auxiliary_loss (`bool`, *optional*, defaults to `True`):
+            Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
+        focal_loss_alpha (`float`, *optional*, defaults to 0.75):
+            Parameter alpha used to compute the focal loss.
+        focal_loss_gamma (`float`, *optional*, defaults to 2.0):
+            Parameter gamma used to compute the focal loss.
+        weight_loss_vfl (`float`, *optional*, defaults to 1.0):
+            Relative weight of the varifocal loss in the object detection loss.
+        weight_loss_bbox (`float`, *optional*, defaults to 5.0):
+            Relative weight of the L1 bounding box loss in the object detection loss.
+        weight_loss_giou (`float`, *optional*, defaults to 2.0):
+            Relative weight of the generalized IoU loss in the object detection loss.
+        weight_loss_fgl (`float`, *optional*, defaults to 0.15):
+            Relative weight of the fine-grained localization loss in the object detection loss.
+        weight_loss_ddf (`float`, *optional*, defaults to 1.5):
+            Relative weight of the decoupled distillation focal loss in the object detection loss.
+        eos_coefficient (`float`, *optional*, defaults to 0.0001):
+            Relative classification weight of the 'no-object' class in the object detection loss.
+        eval_idx (`int`, *optional*, defaults to -1):
+            Index of the decoder layer to use for evaluation. If negative, counts from the end
+            (e.g., -1 means use the last layer). This allows for early prediction in the decoder
+            stack while still training later layers.
+        layer_scale (`float`, *optional*, defaults to `1.0`):
+            Scaling factor for the hidden dimension in later decoder layers. Used to adjust the
+            model capacity after the evaluation layer.
+        max_num_bins (`int`, *optional*, defaults to 32):
+            Maximum number of bins for the distribution-guided bounding box refinement.
+            Higher values allow for more fine-grained localization but increase computation.
+        reg_scale (`float`, *optional*, defaults to 4.0):
+            Scale factor for the regression distribution. Controls the range and granularity
+            of the bounding box refinement process.
+        depth_mult (`float`, *optional*, defaults to 1.0):
+            Multiplier for the number of blocks in RepNCSPELAN4 layers. Used to scale the model's
+            depth while maintaining its architecture.
+        top_prob_values (`int`, *optional*, defaults to 4):
+            Number of top probability values to consider from each corner's distribution.
+        lqe_hidden_dim (`int`, *optional*, defaults to 64):
+            Hidden dimension size for the Location Quality Estimator (LQE) network.
+        lqe_layers (`int`, *optional*, defaults to 2):
+            Number of layers in the Location Quality Estimator MLP.
+        decoder_offset_scale (`float`, *optional*, defaults to 0.5):
+            Offset scale used in deformable attention.
+        decoder_method (`str`, *optional*, defaults to `"default"`):
+            The method to use for the decoder: `"default"` or `"discrete"`.
+        up (`float`, *optional*, defaults to 0.5):
+            Controls the upper bounds of the Weighting Function.
+    """
+
+    model_type = "d_fine"
+    layer_types = ["basic", "bottleneck"]
+    attribute_map = {
+        "hidden_size": "d_model",
+        "num_attention_heads": "encoder_attention_heads",
+    }
+
+    def __init__(
+        self,
+        initializer_range=0.01,
+        initializer_bias_prior_prob=None,
+        layer_norm_eps=1e-5,
+        batch_norm_eps=1e-5,
+        # backbone
+        backbone_config=None,
+        backbone=None,
+        use_pretrained_backbone=False,
+        use_timm_backbone=False,
+        freeze_backbone_batch_norms=True,
+        backbone_kwargs=None,
+        # encoder HybridEncoder
+        encoder_hidden_dim=256,
+        encoder_in_channels=[512, 1024, 2048],
+        feat_strides=[8, 16, 32],
+        encoder_layers=1,
+        encoder_ffn_dim=1024,
+        encoder_attention_heads=8,
+        dropout=0.0,
+        activation_dropout=0.0,
+        encode_proj_layers=[2],
+        positional_encoding_temperature=10000,
+        encoder_activation_function="gelu",
+        activation_function="silu",
+        eval_size=None,
+        normalize_before=False,
+        hidden_expansion=1.0,
+        # decoder DFineTransformer
+        d_model=256,
+        num_queries=300,
+        decoder_in_channels=[256, 256, 256],
+        decoder_ffn_dim=1024,
+        num_feature_levels=3,
+        decoder_n_points=4,
+        decoder_layers=6,
+        decoder_attention_heads=8,
+        decoder_activation_function="relu",
+        attention_dropout=0.0,
+        num_denoising=100,
+        label_noise_ratio=0.5,
+        box_noise_scale=1.0,
+        learn_initial_query=False,
+        anchor_image_size=None,
+        with_box_refine=True,
+        is_encoder_decoder=True,
+        # Loss
+        matcher_alpha=0.25,
+        matcher_gamma=2.0,
+        matcher_class_cost=2.0,
+        matcher_bbox_cost=5.0,
+        matcher_giou_cost=2.0,
+        use_focal_loss=True,
+        auxiliary_loss=True,
+        focal_loss_alpha=0.75,
+        focal_loss_gamma=2.0,
+        weight_loss_vfl=1.0,
+        weight_loss_bbox=5.0,
+        weight_loss_giou=2.0,
+        weight_loss_fgl=0.15,
+        weight_loss_ddf=1.5,
+        eos_coefficient=1e-4,
+        eval_idx=-1,
+        layer_scale=1,
+        max_num_bins=32,
+        reg_scale=4.0,
+        depth_mult=1.0,
+        top_prob_values=4,
+        lqe_hidden_dim=64,
+        lqe_layers=2,
+        decoder_offset_scale=0.5,
+        decoder_method="default",
+        up=0.5,
+        **kwargs,
+    ):
+        self.initializer_range = initializer_range
+        self.initializer_bias_prior_prob = initializer_bias_prior_prob
+        self.layer_norm_eps = layer_norm_eps
+        self.batch_norm_eps = batch_norm_eps
+        # backbone
+        if backbone_config is None and backbone is None:
+            logger.info(
+                "`backbone_config` and `backbone` are `None`. Initializing the config with the default `HGNet-V2` backbone."
+            )
+            backbone_model_type = "hgnet_v2"
+            config_class = CONFIG_MAPPING[backbone_model_type]
+            # this will map it to RTDetrResNetConfig
+            # note: we can instead create HGNetV2Config
+            # and we would need to create HGNetV2Backbone
+            backbone_config = config_class(
+                num_channels=3,
+                embedding_size=64,
+                hidden_sizes=[256, 512, 1024, 2048],
+                depths=[3, 4, 6, 3],
+                layer_type="bottleneck",
+                hidden_act="relu",
+                downsample_in_first_stage=False,
+                downsample_in_bottleneck=False,
+                out_features=None,
+                out_indices=[2, 3, 4],
+            )
+        elif isinstance(backbone_config, dict):
+            backbone_model_type = backbone_config.pop("model_type")
+            config_class = CONFIG_MAPPING[backbone_model_type]
+            backbone_config = config_class.from_dict(backbone_config)
+
+        verify_backbone_config_arguments(
+            use_timm_backbone=use_timm_backbone,
+            use_pretrained_backbone=use_pretrained_backbone,
+            backbone=backbone,
+            backbone_config=backbone_config,
+            backbone_kwargs=backbone_kwargs,
+        )
+
+        self.backbone_config = backbone_config
+        self.backbone = backbone
+        self.use_pretrained_backbone = use_pretrained_backbone
+        self.use_timm_backbone = use_timm_backbone
+        self.freeze_backbone_batch_norms = freeze_backbone_batch_norms
+        self.backbone_kwargs = backbone_kwargs
+        # encoder
+        self.encoder_hidden_dim = encoder_hidden_dim
+        self.encoder_in_channels = encoder_in_channels
+        self.feat_strides = feat_strides
+        self.encoder_attention_heads = encoder_attention_heads
+        self.encoder_ffn_dim = encoder_ffn_dim
+        self.dropout = dropout
+        self.activation_dropout = activation_dropout
+        self.encode_proj_layers = encode_proj_layers
+        self.encoder_layers = encoder_layers
+        self.positional_encoding_temperature = positional_encoding_temperature
+        self.eval_size = eval_size
+        self.normalize_before = normalize_before
+        self.encoder_activation_function = encoder_activation_function
+        self.activation_function = activation_function
+        self.hidden_expansion = hidden_expansion
+        # decoder
+        self.d_model = d_model
+        self.num_queries = num_queries
+        self.decoder_ffn_dim = decoder_ffn_dim
+        self.decoder_in_channels = decoder_in_channels
+        self.num_feature_levels = num_feature_levels
+        self.decoder_n_points = decoder_n_points
+        self.decoder_layers = decoder_layers
+        self.decoder_attention_heads = decoder_attention_heads
+        self.decoder_activation_function = decoder_activation_function
+        self.attention_dropout = attention_dropout
+        self.num_denoising = num_denoising
+        self.label_noise_ratio = label_noise_ratio
+        self.box_noise_scale = box_noise_scale
+        self.learn_initial_query = learn_initial_query
+        self.anchor_image_size = anchor_image_size
+        self.auxiliary_loss = auxiliary_loss
+        self.with_box_refine = with_box_refine
+        # Loss
+        self.matcher_alpha = matcher_alpha
+        self.matcher_gamma = matcher_gamma
+        self.matcher_class_cost = matcher_class_cost
+        self.matcher_bbox_cost = matcher_bbox_cost
+        self.matcher_giou_cost = matcher_giou_cost
+        self.use_focal_loss = use_focal_loss
+        self.focal_loss_alpha = focal_loss_alpha
+        self.focal_loss_gamma = focal_loss_gamma
+        self.weight_loss_vfl = weight_loss_vfl
+        self.weight_loss_bbox = weight_loss_bbox
+        self.weight_loss_giou = weight_loss_giou
+        self.weight_loss_fgl = weight_loss_fgl
+        self.weight_loss_ddf = weight_loss_ddf
+        self.eos_coefficient = eos_coefficient
+        # add the new attributes with the given values or defaults
+        self.eval_idx = eval_idx
+        self.layer_scale = layer_scale
+        self.max_num_bins = max_num_bins
+        self.reg_scale = reg_scale
+        self.depth_mult = depth_mult
+        self.decoder_offset_scale = decoder_offset_scale
+        self.decoder_method = decoder_method
+        self.top_prob_values = top_prob_values
+        self.lqe_hidden_dim = lqe_hidden_dim
+        self.lqe_layers = lqe_layers
+        self.up = up
+
+        if isinstance(self.decoder_n_points, list):
+            if len(self.decoder_n_points) != self.num_feature_levels:
+                raise ValueError(
+                    f"Length of decoder_n_points list ({len(self.decoder_n_points)}) must match num_feature_levels ({self.num_feature_levels})."
+                )
+
+        head_dim = self.d_model // self.decoder_attention_heads
+        if head_dim * self.decoder_attention_heads != self.d_model:
+            raise ValueError(
+                f"Embedded dimension {self.d_model} must be divisible by decoder_attention_heads {self.decoder_attention_heads}"
+            )
+        super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
+
+    @property
+    def num_attention_heads(self) -> int:
+        return self.encoder_attention_heads
+
+    @property
+    def hidden_size(self) -> int:
+        return self.d_model
+
+    @property
+    def sub_configs(self):
+        return (
+            {"backbone_config": type(self.backbone_config)}
+            if getattr(self, "backbone_config", None) is not None
+            else {}
+        )
+
+    @classmethod
+    def from_backbone_configs(cls, backbone_config: PretrainedConfig, **kwargs):
+        """Instantiate a [`DFineConfig`] (or a derived class) from a pre-trained backbone model configuration and DETR model
+        configuration.
+
+        Args:
+            backbone_config ([`PretrainedConfig`]):
+                The backbone configuration.
+
+        Returns:
+            [`DFineConfig`]: An instance of a configuration object
+        """
+        return cls(
+            backbone_config=backbone_config,
+            **kwargs,
+        )
+
+
+__all__ = ["DFineConfig"]
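As the configuration docstring above says, instantiating `DFineConfig` with no arguments mirrors the ustc-community/dfine-xlarge-coco setup, and `__init__` builds a default HGNet-V2 backbone config when neither `backbone` nor `backbone_config` is given. A hedged usage sketch, assuming a transformers release that re-exports the D-FINE classes at top level (`DFineModel` comes from the `modeling_d_fine.py` whose diff is not rendered in the next entry):

```python
from transformers import DFineConfig, DFineModel

# Defaults reproduce the D-FINE-X-COCO architecture; with no backbone arguments
# the config constructs the default HGNet-V2 backbone configuration internally.
config = DFineConfig()

# Randomly initialized model with that architecture (no pretrained weights).
model = DFineModel(config)

# Aliases exposed through attribute_map / properties on the config:
assert config.hidden_size == config.d_model
assert config.num_attention_heads == config.encoder_attention_heads
```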
phivenv/Lib/site-packages/transformers/models/d_fine/modeling_d_fine.py ADDED
    The diff for this file is too large to render.
phivenv/Lib/site-packages/transformers/models/d_fine/modular_d_fine.py
ADDED
|
@@ -0,0 +1,1221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2025 Baidu Inc and The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
import math
|
| 16 |
+
from typing import Any, Optional
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
import torch.nn.init as init
|
| 21 |
+
from torch import nn
|
| 22 |
+
|
| 23 |
+
from ...activations import ACT2CLS
|
| 24 |
+
from ...configuration_utils import PretrainedConfig
|
| 25 |
+
from ...image_transforms import corners_to_center_format
|
| 26 |
+
from ...utils import is_torchdynamo_compiling, logging
|
| 27 |
+
from ...utils.backbone_utils import verify_backbone_config_arguments
|
| 28 |
+
from ..auto import CONFIG_MAPPING
|
| 29 |
+
from ..rt_detr.modeling_rt_detr import (
|
| 30 |
+
RTDetrConvNormLayer,
|
| 31 |
+
RTDetrDecoder,
|
| 32 |
+
RTDetrDecoderLayer,
|
| 33 |
+
RTDetrDecoderOutput,
|
| 34 |
+
RTDetrEncoder,
|
| 35 |
+
RTDetrForObjectDetection,
|
| 36 |
+
RTDetrHybridEncoder,
|
| 37 |
+
RTDetrMLPPredictionHead,
|
| 38 |
+
RTDetrModel,
|
| 39 |
+
RTDetrPreTrainedModel,
|
| 40 |
+
RTDetrRepVggBlock,
|
| 41 |
+
inverse_sigmoid,
|
| 42 |
+
)
|
| 43 |
+
from ..rt_detr_v2.modeling_rt_detr_v2 import multi_scale_deformable_attention_v2
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
logger = logging.get_logger(__name__)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# TODO: Attribute map assignment logic should be fixed in modular
|
| 50 |
+
# as well as super() call parsing because otherwise we cannot re-write args after initialization
|
| 51 |
+
class DFineConfig(PretrainedConfig):
|
| 52 |
+
"""
|
| 53 |
+
This is the configuration class to store the configuration of a [`DFineModel`]. It is used to instantiate a D-FINE
|
| 54 |
+
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
| 55 |
+
defaults will yield a similar configuration to that of D-FINE-X-COCO "[ustc-community/dfine-xlarge-coco"](https://huggingface.co/ustc-community/dfine-xlarge-coco").
|
| 56 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 57 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
initializer_range (`float`, *optional*, defaults to 0.01):
|
| 61 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 62 |
+
initializer_bias_prior_prob (`float`, *optional*):
|
| 63 |
+
The prior probability used by the bias initializer to initialize biases for `enc_score_head` and `class_embed`.
|
| 64 |
+
If `None`, `prior_prob` computed as `prior_prob = 1 / (num_labels + 1)` while initializing model weights.
|
| 65 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-05):
|
| 66 |
+
The epsilon used by the layer normalization layers.
|
| 67 |
+
batch_norm_eps (`float`, *optional*, defaults to 1e-05):
|
| 68 |
+
The epsilon used by the batch normalization layers.
|
| 69 |
+
backbone_config (`Dict`, *optional*, defaults to `RTDetrResNetConfig()`):
|
| 70 |
+
The configuration of the backbone model.
|
| 71 |
+
backbone (`str`, *optional*):
|
| 72 |
+
Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
|
| 73 |
+
will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
|
| 74 |
+
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
|
| 75 |
+
use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
|
| 76 |
+
Whether to use pretrained weights for the backbone.
|
| 77 |
+
use_timm_backbone (`bool`, *optional*, defaults to `False`):
|
| 78 |
+
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
|
| 79 |
+
library.
|
| 80 |
+
freeze_backbone_batch_norms (`bool`, *optional*, defaults to `True`):
|
| 81 |
+
Whether to freeze the batch normalization layers in the backbone.
|
| 82 |
+
backbone_kwargs (`dict`, *optional*):
|
| 83 |
+
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
|
| 84 |
+
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
|
| 85 |
+
encoder_hidden_dim (`int`, *optional*, defaults to 256):
|
| 86 |
+
Dimension of the layers in hybrid encoder.
|
| 87 |
+
encoder_in_channels (`list`, *optional*, defaults to `[512, 1024, 2048]`):
|
| 88 |
+
Multi level features input for encoder.
|
| 89 |
+
feat_strides (`list[int]`, *optional*, defaults to `[8, 16, 32]`):
|
| 90 |
+
Strides used in each feature map.
|
| 91 |
+
encoder_layers (`int`, *optional*, defaults to 1):
|
| 92 |
+
Total of layers to be used by the encoder.
|
| 93 |
+
encoder_ffn_dim (`int`, *optional*, defaults to 1024):
|
| 94 |
+
Dimension of the "intermediate" (often named feed-forward) layer in decoder.
|
| 95 |
+
encoder_attention_heads (`int`, *optional*, defaults to 8):
|
| 96 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 97 |
+
dropout (`float`, *optional*, defaults to 0.0):
|
| 98 |
+
The ratio for all dropout layers.
|
| 99 |
+
activation_dropout (`float`, *optional*, defaults to 0.0):
|
| 100 |
+
The dropout ratio for activations inside the fully connected layer.
|
| 101 |
+
encode_proj_layers (`list[int]`, *optional*, defaults to `[2]`):
|
| 102 |
+
Indexes of the projected layers to be used in the encoder.
|
| 103 |
+
positional_encoding_temperature (`int`, *optional*, defaults to 10000):
|
| 104 |
+
The temperature parameter used to create the positional encodings.
|
| 105 |
+
encoder_activation_function (`str`, *optional*, defaults to `"gelu"`):
|
| 106 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
| 107 |
+
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
| 108 |
+
activation_function (`str`, *optional*, defaults to `"silu"`):
|
| 109 |
+
The non-linear activation function (function or string) in the general layer. If string, `"gelu"`,
|
| 110 |
+
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
| 111 |
+
eval_size (`tuple[int, int]`, *optional*):
|
| 112 |
+
Height and width used to computes the effective height and width of the position embeddings after taking
|
| 113 |
+
into account the stride.
|
| 114 |
+
normalize_before (`bool`, *optional*, defaults to `False`):
|
| 115 |
+
Determine whether to apply layer normalization in the transformer encoder layer before self-attention and
|
| 116 |
+
feed-forward modules.
|
| 117 |
+
hidden_expansion (`float`, *optional*, defaults to 1.0):
|
| 118 |
+
Expansion ratio to enlarge the dimension size of RepVGGBlock and CSPRepLayer.
|
| 119 |
+
d_model (`int`, *optional*, defaults to 256):
|
| 120 |
+
Dimension of the layers exclude hybrid encoder.
|
| 121 |
+
num_queries (`int`, *optional*, defaults to 300):
|
| 122 |
+
Number of object queries.
|
| 123 |
+
decoder_in_channels (`list`, *optional*, defaults to `[256, 256, 256]`):
|
| 124 |
+
Multi level features dimension for decoder
|
| 125 |
+
decoder_ffn_dim (`int`, *optional*, defaults to 1024):
|
| 126 |
+
Dimension of the "intermediate" (often named feed-forward) layer in decoder.
|
| 127 |
+
num_feature_levels (`int`, *optional*, defaults to 3):
|
| 128 |
+
The number of input feature levels.
|
| 129 |
+
decoder_n_points (`int`, *optional*, defaults to 4):
|
| 130 |
+
The number of sampled keys in each feature level for each attention head in the decoder.
|
| 131 |
+
decoder_layers (`int`, *optional*, defaults to 6):
|
| 132 |
+
Number of decoder layers.
|
| 133 |
+
decoder_attention_heads (`int`, *optional*, defaults to 8):
|
| 134 |
+
Number of attention heads for each attention layer in the Transformer decoder.
|
| 135 |
+
decoder_activation_function (`str`, *optional*, defaults to `"relu"`):
|
| 136 |
+
The non-linear activation function (function or string) in the decoder. If string, `"gelu"`,
|
| 137 |
+
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
| 138 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 139 |
+
The dropout ratio for the attention probabilities.
|
| 140 |
+
num_denoising (`int`, *optional*, defaults to 100):
|
| 141 |
+
The total number of denoising tasks or queries to be used for contrastive denoising.
|
| 142 |
+
label_noise_ratio (`float`, *optional*, defaults to 0.5):
|
| 143 |
+
The fraction of denoising labels to which random noise should be added.
|
| 144 |
+
box_noise_scale (`float`, *optional*, defaults to 1.0):
|
| 145 |
+
Scale or magnitude of noise to be added to the bounding boxes.
|
| 146 |
+
learn_initial_query (`bool`, *optional*, defaults to `False`):
|
| 147 |
+
Indicates whether the initial query embeddings for the decoder should be learned during training
|
| 148 |
+
anchor_image_size (`tuple[int, int]`, *optional*):
|
| 149 |
+
Height and width of the input image used during evaluation to generate the bounding box anchors. If None, automatic generate anchor is applied.
|
| 150 |
+
with_box_refine (`bool`, *optional*, defaults to `True`):
|
| 151 |
+
Whether to apply iterative bounding box refinement, where each decoder layer refines the bounding boxes
|
| 152 |
+
based on the predictions from the previous layer.
|
| 153 |
+
is_encoder_decoder (`bool`, *optional*, defaults to `True`):
|
| 154 |
+
Whether the architecture has an encoder decoder structure.
|
| 155 |
+
matcher_alpha (`float`, *optional*, defaults to 0.25):
|
| 156 |
+
Parameter alpha used by the Hungarian Matcher.
|
| 157 |
+
matcher_gamma (`float`, *optional*, defaults to 2.0):
|
| 158 |
+
Parameter gamma used by the Hungarian Matcher.
|
| 159 |
+
matcher_class_cost (`float`, *optional*, defaults to 2.0):
|
| 160 |
+
The relative weight of the class loss used by the Hungarian Matcher.
|
| 161 |
+
matcher_bbox_cost (`float`, *optional*, defaults to 5.0):
|
| 162 |
+
The relative weight of the bounding box loss used by the Hungarian Matcher.
|
| 163 |
+
matcher_giou_cost (`float`, *optional*, defaults to 2.0):
|
| 164 |
+
The relative weight of the giou loss of used by the Hungarian Matcher.
|
| 165 |
+
use_focal_loss (`bool`, *optional*, defaults to `True`):
|
| 166 |
+
Parameter informing if focal focal should be used.
|
| 167 |
+
auxiliary_loss (`bool`, *optional*, defaults to `True`):
|
| 168 |
+
Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
|
| 169 |
+
focal_loss_alpha (`float`, *optional*, defaults to 0.75):
|
| 170 |
+
Parameter alpha used to compute the focal loss.
|
| 171 |
+
focal_loss_gamma (`float`, *optional*, defaults to 2.0):
|
| 172 |
+
Parameter gamma used to compute the focal loss.
|
| 173 |
+
weight_loss_vfl (`float`, *optional*, defaults to 1.0):
|
| 174 |
+
Relative weight of the varifocal loss in the object detection loss.
|
| 175 |
+
weight_loss_bbox (`float`, *optional*, defaults to 5.0):
|
| 176 |
+
Relative weight of the L1 bounding box loss in the object detection loss.
|
| 177 |
+
weight_loss_giou (`float`, *optional*, defaults to 2.0):
|
| 178 |
+
Relative weight of the generalized IoU loss in the object detection loss.
|
| 179 |
+
weight_loss_fgl (`float`, *optional*, defaults to 0.15):
|
| 180 |
+
Relative weight of the fine-grained localization loss in the object detection loss.
|
| 181 |
+
weight_loss_ddf (`float`, *optional*, defaults to 1.5):
|
| 182 |
+
Relative weight of the decoupled distillation focal loss in the object detection loss.
|
| 183 |
+
eos_coefficient (`float`, *optional*, defaults to 0.0001):
|
| 184 |
+
Relative classification weight of the 'no-object' class in the object detection loss.
|
| 185 |
+
eval_idx (`int`, *optional*, defaults to -1):
|
| 186 |
+
Index of the decoder layer to use for evaluation. If negative, counts from the end
|
| 187 |
+
(e.g., -1 means use the last layer). This allows for early prediction in the decoder
|
| 188 |
+
stack while still training later layers.
|
| 189 |
+
layer_scale (`float`, *optional*, defaults to `1.0`):
|
| 190 |
+
Scaling factor for the hidden dimension in later decoder layers. Used to adjust the
|
| 191 |
+
model capacity after the evaluation layer.
|
| 192 |
+
max_num_bins (`int`, *optional*, defaults to 32):
|
| 193 |
+
Maximum number of bins for the distribution-guided bounding box refinement.
|
| 194 |
+
Higher values allow for more fine-grained localization but increase computation.
|
| 195 |
+
reg_scale (`float`, *optional*, defaults to 4.0):
|
| 196 |
+
Scale factor for the regression distribution. Controls the range and granularity
|
| 197 |
+
of the bounding box refinement process.
|
| 198 |
+
depth_mult (`float`, *optional*, defaults to 1.0):
|
| 199 |
+
Multiplier for the number of blocks in RepNCSPELAN4 layers. Used to scale the model's
|
| 200 |
+
depth while maintaining its architecture.
|
| 201 |
+
top_prob_values (`int`, *optional*, defaults to 4):
|
| 202 |
+
Number of top probability values to consider from each corner's distribution.
|
| 203 |
+
lqe_hidden_dim (`int`, *optional*, defaults to 64):
|
| 204 |
+
Hidden dimension size for the Location Quality Estimator (LQE) network.
|
| 205 |
+
lqe_layers (`int`, *optional*, defaults to 2):
|
| 206 |
+
Number of layers in the Location Quality Estimator MLP.
|
| 207 |
+
decoder_offset_scale (`float`, *optional*, defaults to 0.5):
|
| 208 |
+
Offset scale used in deformable attention.
|
| 209 |
+
decoder_method (`str`, *optional*, defaults to `"default"`):
|
| 210 |
+
The method to use for the decoder: `"default"` or `"discrete"`.
|
| 211 |
+
up (`float`, *optional*, defaults to 0.5):
|
| 212 |
+
Controls the upper bounds of the Weighting Function.
|
| 213 |
+
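    Example (an illustrative sketch; it mirrors the usage shown for other configuration classes and
    assumes `DFineConfig` and `DFineModel` are importable from `transformers`):

    ```python
    >>> from transformers import DFineConfig, DFineModel

    >>> # Initializing a D-FINE configuration with the documented defaults
    >>> configuration = DFineConfig()

    >>> # Initializing a model (with random weights) from that configuration
    >>> model = DFineModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```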
"""
|
| 214 |
+
|
| 215 |
+
model_type = "d_fine"
|
| 216 |
+
layer_types = ["basic", "bottleneck"]
|
| 217 |
+
attribute_map = {
|
| 218 |
+
"hidden_size": "d_model",
|
| 219 |
+
"num_attention_heads": "encoder_attention_heads",
|
| 220 |
+
}
|
| 221 |
+
|
| 222 |
+
def __init__(
|
| 223 |
+
self,
|
| 224 |
+
initializer_range=0.01,
|
| 225 |
+
initializer_bias_prior_prob=None,
|
| 226 |
+
layer_norm_eps=1e-5,
|
| 227 |
+
batch_norm_eps=1e-5,
|
| 228 |
+
# backbone
|
| 229 |
+
backbone_config=None,
|
| 230 |
+
backbone=None,
|
| 231 |
+
use_pretrained_backbone=False,
|
| 232 |
+
use_timm_backbone=False,
|
| 233 |
+
freeze_backbone_batch_norms=True,
|
| 234 |
+
backbone_kwargs=None,
|
| 235 |
+
# encoder HybridEncoder
|
| 236 |
+
encoder_hidden_dim=256,
|
| 237 |
+
encoder_in_channels=[512, 1024, 2048],
|
| 238 |
+
feat_strides=[8, 16, 32],
|
| 239 |
+
encoder_layers=1,
|
| 240 |
+
encoder_ffn_dim=1024,
|
| 241 |
+
encoder_attention_heads=8,
|
| 242 |
+
dropout=0.0,
|
| 243 |
+
activation_dropout=0.0,
|
| 244 |
+
encode_proj_layers=[2],
|
| 245 |
+
positional_encoding_temperature=10000,
|
| 246 |
+
encoder_activation_function="gelu",
|
| 247 |
+
activation_function="silu",
|
| 248 |
+
eval_size=None,
|
| 249 |
+
normalize_before=False,
|
| 250 |
+
hidden_expansion=1.0,
|
| 251 |
+
# decoder DFineTransformer
|
| 252 |
+
d_model=256,
|
| 253 |
+
num_queries=300,
|
| 254 |
+
decoder_in_channels=[256, 256, 256],
|
| 255 |
+
decoder_ffn_dim=1024,
|
| 256 |
+
num_feature_levels=3,
|
| 257 |
+
decoder_n_points=4,
|
| 258 |
+
decoder_layers=6,
|
| 259 |
+
decoder_attention_heads=8,
|
| 260 |
+
decoder_activation_function="relu",
|
| 261 |
+
attention_dropout=0.0,
|
| 262 |
+
num_denoising=100,
|
| 263 |
+
label_noise_ratio=0.5,
|
| 264 |
+
box_noise_scale=1.0,
|
| 265 |
+
learn_initial_query=False,
|
| 266 |
+
anchor_image_size=None,
|
| 267 |
+
with_box_refine=True,
|
| 268 |
+
is_encoder_decoder=True,
|
| 269 |
+
# Loss
|
| 270 |
+
matcher_alpha=0.25,
|
| 271 |
+
matcher_gamma=2.0,
|
| 272 |
+
matcher_class_cost=2.0,
|
| 273 |
+
matcher_bbox_cost=5.0,
|
| 274 |
+
matcher_giou_cost=2.0,
|
| 275 |
+
use_focal_loss=True,
|
| 276 |
+
auxiliary_loss=True,
|
| 277 |
+
focal_loss_alpha=0.75,
|
| 278 |
+
focal_loss_gamma=2.0,
|
| 279 |
+
weight_loss_vfl=1.0,
|
| 280 |
+
weight_loss_bbox=5.0,
|
| 281 |
+
weight_loss_giou=2.0,
|
| 282 |
+
weight_loss_fgl=0.15,
|
| 283 |
+
weight_loss_ddf=1.5,
|
| 284 |
+
eos_coefficient=1e-4,
|
| 285 |
+
eval_idx=-1,
|
| 286 |
+
layer_scale=1,
|
| 287 |
+
max_num_bins=32,
|
| 288 |
+
reg_scale=4.0,
|
| 289 |
+
depth_mult=1.0,
|
| 290 |
+
top_prob_values=4,
|
| 291 |
+
lqe_hidden_dim=64,
|
| 292 |
+
lqe_layers=2,
|
| 293 |
+
decoder_offset_scale=0.5,
|
| 294 |
+
decoder_method="default",
|
| 295 |
+
up=0.5,
|
| 296 |
+
**kwargs,
|
| 297 |
+
):
|
| 298 |
+
self.initializer_range = initializer_range
|
| 299 |
+
self.initializer_bias_prior_prob = initializer_bias_prior_prob
|
| 300 |
+
self.layer_norm_eps = layer_norm_eps
|
| 301 |
+
self.batch_norm_eps = batch_norm_eps
|
| 302 |
+
# backbone
|
| 303 |
+
if backbone_config is None and backbone is None:
|
| 304 |
+
logger.info(
|
| 305 |
+
"`backbone_config` and `backbone` are `None`. Initializing the config with the default `HGNet-V2` backbone."
|
| 306 |
+
)
|
| 307 |
+
backbone_model_type = "hgnet_v2"
|
| 308 |
+
config_class = CONFIG_MAPPING[backbone_model_type]
|
| 309 |
+
# this will map it to RTDetrResNetConfig
|
| 310 |
+
# note: we can instead create HGNetV2Config
|
| 311 |
+
# and we would need to create HGNetV2Backbone
|
| 312 |
+
backbone_config = config_class(
|
| 313 |
+
num_channels=3,
|
| 314 |
+
embedding_size=64,
|
| 315 |
+
hidden_sizes=[256, 512, 1024, 2048],
|
| 316 |
+
depths=[3, 4, 6, 3],
|
| 317 |
+
layer_type="bottleneck",
|
| 318 |
+
hidden_act="relu",
|
| 319 |
+
downsample_in_first_stage=False,
|
| 320 |
+
downsample_in_bottleneck=False,
|
| 321 |
+
out_features=None,
|
| 322 |
+
out_indices=[2, 3, 4],
|
| 323 |
+
)
|
| 324 |
+
elif isinstance(backbone_config, dict):
|
| 325 |
+
backbone_model_type = backbone_config.pop("model_type")
|
| 326 |
+
config_class = CONFIG_MAPPING[backbone_model_type]
|
| 327 |
+
backbone_config = config_class.from_dict(backbone_config)
|
| 328 |
+
|
| 329 |
+
verify_backbone_config_arguments(
|
| 330 |
+
use_timm_backbone=use_timm_backbone,
|
| 331 |
+
use_pretrained_backbone=use_pretrained_backbone,
|
| 332 |
+
backbone=backbone,
|
| 333 |
+
backbone_config=backbone_config,
|
| 334 |
+
backbone_kwargs=backbone_kwargs,
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
self.backbone_config = backbone_config
|
| 338 |
+
self.backbone = backbone
|
| 339 |
+
self.use_pretrained_backbone = use_pretrained_backbone
|
| 340 |
+
self.use_timm_backbone = use_timm_backbone
|
| 341 |
+
self.freeze_backbone_batch_norms = freeze_backbone_batch_norms
|
| 342 |
+
self.backbone_kwargs = backbone_kwargs
|
| 343 |
+
# encoder
|
| 344 |
+
self.encoder_hidden_dim = encoder_hidden_dim
|
| 345 |
+
self.encoder_in_channels = encoder_in_channels
|
| 346 |
+
self.feat_strides = feat_strides
|
| 347 |
+
self.encoder_attention_heads = encoder_attention_heads
|
| 348 |
+
self.encoder_ffn_dim = encoder_ffn_dim
|
| 349 |
+
self.dropout = dropout
|
| 350 |
+
self.activation_dropout = activation_dropout
|
| 351 |
+
self.encode_proj_layers = encode_proj_layers
|
| 352 |
+
self.encoder_layers = encoder_layers
|
| 353 |
+
self.positional_encoding_temperature = positional_encoding_temperature
|
| 354 |
+
self.eval_size = eval_size
|
| 355 |
+
self.normalize_before = normalize_before
|
| 356 |
+
self.encoder_activation_function = encoder_activation_function
|
| 357 |
+
self.activation_function = activation_function
|
| 358 |
+
self.hidden_expansion = hidden_expansion
|
| 359 |
+
# decoder
|
| 360 |
+
self.d_model = d_model
|
| 361 |
+
self.num_queries = num_queries
|
| 362 |
+
self.decoder_ffn_dim = decoder_ffn_dim
|
| 363 |
+
self.decoder_in_channels = decoder_in_channels
|
| 364 |
+
self.num_feature_levels = num_feature_levels
|
| 365 |
+
self.decoder_n_points = decoder_n_points
|
| 366 |
+
self.decoder_layers = decoder_layers
|
| 367 |
+
self.decoder_attention_heads = decoder_attention_heads
|
| 368 |
+
self.decoder_activation_function = decoder_activation_function
|
| 369 |
+
self.attention_dropout = attention_dropout
|
| 370 |
+
self.num_denoising = num_denoising
|
| 371 |
+
self.label_noise_ratio = label_noise_ratio
|
| 372 |
+
self.box_noise_scale = box_noise_scale
|
| 373 |
+
self.learn_initial_query = learn_initial_query
|
| 374 |
+
self.anchor_image_size = anchor_image_size
|
| 375 |
+
self.auxiliary_loss = auxiliary_loss
|
| 376 |
+
self.with_box_refine = with_box_refine
|
| 377 |
+
# Loss
|
| 378 |
+
self.matcher_alpha = matcher_alpha
|
| 379 |
+
self.matcher_gamma = matcher_gamma
|
| 380 |
+
self.matcher_class_cost = matcher_class_cost
|
| 381 |
+
self.matcher_bbox_cost = matcher_bbox_cost
|
| 382 |
+
self.matcher_giou_cost = matcher_giou_cost
|
| 383 |
+
self.use_focal_loss = use_focal_loss
|
| 384 |
+
self.focal_loss_alpha = focal_loss_alpha
|
| 385 |
+
self.focal_loss_gamma = focal_loss_gamma
|
| 386 |
+
self.weight_loss_vfl = weight_loss_vfl
|
| 387 |
+
self.weight_loss_bbox = weight_loss_bbox
|
| 388 |
+
self.weight_loss_giou = weight_loss_giou
|
| 389 |
+
self.weight_loss_fgl = weight_loss_fgl
|
| 390 |
+
self.weight_loss_ddf = weight_loss_ddf
|
| 391 |
+
self.eos_coefficient = eos_coefficient
|
| 392 |
+
# add the new attributes with the given values or defaults
|
| 393 |
+
self.eval_idx = eval_idx
|
| 394 |
+
self.layer_scale = layer_scale
|
| 395 |
+
self.max_num_bins = max_num_bins
|
| 396 |
+
self.reg_scale = reg_scale
|
| 397 |
+
self.depth_mult = depth_mult
|
| 398 |
+
self.decoder_offset_scale = decoder_offset_scale
|
| 399 |
+
self.decoder_method = decoder_method
|
| 400 |
+
self.top_prob_values = top_prob_values
|
| 401 |
+
self.lqe_hidden_dim = lqe_hidden_dim
|
| 402 |
+
self.lqe_layers = lqe_layers
|
| 403 |
+
self.up = up
|
| 404 |
+
|
| 405 |
+
if isinstance(self.decoder_n_points, list):
|
| 406 |
+
if len(self.decoder_n_points) != self.num_feature_levels:
|
| 407 |
+
raise ValueError(
|
| 408 |
+
f"Length of decoder_n_points list ({len(self.decoder_n_points)}) must match num_feature_levels ({self.num_feature_levels})."
|
| 409 |
+
)
|
| 410 |
+
|
| 411 |
+
head_dim = self.d_model // self.decoder_attention_heads
|
| 412 |
+
if head_dim * self.decoder_attention_heads != self.d_model:
|
| 413 |
+
raise ValueError(
|
| 414 |
+
f"Embedded dimension {self.d_model} must be divisible by decoder_attention_heads {self.decoder_attention_heads}"
|
| 415 |
+
)
|
| 416 |
+
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
|
| 417 |
+
|
| 418 |
+
@property
|
| 419 |
+
def num_attention_heads(self) -> int:
|
| 420 |
+
return self.encoder_attention_heads
|
| 421 |
+
|
| 422 |
+
@property
|
| 423 |
+
def hidden_size(self) -> int:
|
| 424 |
+
return self.d_model
|
| 425 |
+
|
| 426 |
+
@property
|
| 427 |
+
def sub_configs(self):
|
| 428 |
+
return (
|
| 429 |
+
{"backbone_config": type(self.backbone_config)}
|
| 430 |
+
if getattr(self, "backbone_config", None) is not None
|
| 431 |
+
else {}
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
@classmethod
|
| 435 |
+
def from_backbone_configs(cls, backbone_config: PretrainedConfig, **kwargs):
|
| 436 |
+
"""Instantiate a [`DFineConfig`] (or a derived class) from a pre-trained backbone model configuration and DETR model
|
| 437 |
+
configuration.
|
| 438 |
+
|
| 439 |
+
Args:
|
| 440 |
+
backbone_config ([`PretrainedConfig`]):
|
| 441 |
+
The backbone configuration.
|
| 442 |
+
|
| 443 |
+
Returns:
|
| 444 |
+
[`DFineConfig`]: An instance of a configuration object
|
| 445 |
+
"""
|
| 446 |
+
return cls(
|
| 447 |
+
backbone_config=backbone_config,
|
| 448 |
+
**kwargs,
|
| 449 |
+
)
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
class DFineMultiscaleDeformableAttention(nn.Module):
|
| 453 |
+
def __init__(self, config: DFineConfig):
|
| 454 |
+
"""
|
| 455 |
+
D-Fine version of multiscale deformable attention
|
| 456 |
+
"""
|
| 457 |
+
super().__init__()
|
| 458 |
+
self.d_model = config.d_model
|
| 459 |
+
self.n_heads = config.decoder_attention_heads
|
| 460 |
+
self.n_levels = config.num_feature_levels
|
| 461 |
+
self.offset_scale = config.decoder_offset_scale
|
| 462 |
+
self.decoder_method = config.decoder_method
|
| 463 |
+
self.n_points = config.decoder_n_points
|
| 464 |
+
|
| 465 |
+
if isinstance(self.n_points, list):
|
| 466 |
+
num_points_list = self.n_points
|
| 467 |
+
else:
|
| 468 |
+
num_points_list = [self.n_points for _ in range(self.n_levels)]
|
| 469 |
+
|
| 470 |
+
self.num_points_list = num_points_list
|
| 471 |
+
num_points_scale = [1 / n for n in self.num_points_list for _ in range(n)]
|
| 472 |
+
self.register_buffer("num_points_scale", torch.tensor(num_points_scale, dtype=torch.float32))
|
| 473 |
+
|
| 474 |
+
self.total_points = self.n_heads * sum(self.num_points_list)
|
| 475 |
+
|
| 476 |
+
self.sampling_offsets = nn.Linear(self.d_model, self.total_points * 2)
|
| 477 |
+
self.attention_weights = nn.Linear(self.d_model, self.total_points)
|
| 478 |
+
|
| 479 |
+
self.ms_deformable_attn_core = multi_scale_deformable_attention_v2
|
| 480 |
+
|
| 481 |
+
def forward(
|
| 482 |
+
self,
|
| 483 |
+
hidden_states: torch.Tensor,
|
| 484 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 485 |
+
reference_points=None,
|
| 486 |
+
encoder_hidden_states=None,
|
| 487 |
+
spatial_shapes=None,
|
| 488 |
+
spatial_shapes_list=None,
|
| 489 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 490 |
+
batch_size, num_queries, _ = hidden_states.shape
|
| 491 |
+
batch_size, sequence_length, _ = encoder_hidden_states.shape
|
| 492 |
+
|
| 493 |
+
if not is_torchdynamo_compiling() and (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
|
| 494 |
+
raise ValueError(
|
| 495 |
+
"Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
|
| 496 |
+
)
|
| 497 |
+
|
| 498 |
+
# Reshape for multi-head attention
|
| 499 |
+
value = encoder_hidden_states.reshape(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)
|
| 500 |
+
if attention_mask is not None:
|
| 501 |
+
value = value.masked_fill(~attention_mask[..., None], float(0))
|
| 502 |
+
|
| 503 |
+
sampling_offsets: torch.Tensor = self.sampling_offsets(hidden_states)
|
| 504 |
+
sampling_offsets = sampling_offsets.reshape(
|
| 505 |
+
batch_size, num_queries, self.n_heads, sum(self.num_points_list), 2
|
| 506 |
+
)
|
| 507 |
+
|
| 508 |
+
attention_weights = self.attention_weights(hidden_states).reshape(
|
| 509 |
+
batch_size, num_queries, self.n_heads, sum(self.num_points_list)
|
| 510 |
+
)
|
| 511 |
+
attention_weights = F.softmax(attention_weights, dim=-1)
|
| 512 |
+
|
| 513 |
+
if reference_points.shape[-1] == 2:
|
| 514 |
+
offset_normalizer = torch.tensor(spatial_shapes)
|
| 515 |
+
offset_normalizer = offset_normalizer.flip([1]).reshape(1, 1, 1, self.n_levels, 1, 2)
|
| 516 |
+
sampling_locations = (
|
| 517 |
+
reference_points.reshape(batch_size, sequence_length, 1, self.n_levels, 1, 2)
|
| 518 |
+
+ sampling_offsets / offset_normalizer
|
| 519 |
+
)
|
| 520 |
+
elif reference_points.shape[-1] == 4:
|
| 521 |
+
# reference_points [8, 480, None, 1, 4]
|
| 522 |
+
# sampling_offsets [8, 480, 8, 12, 2]
|
| 523 |
+
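            # Note: when the reference points are boxes (cx, cy, w, h), the predicted offsets are
            # scaled by the box size (reference_points[..., 2:]), by 1/num_points for the level, and
            # by `offset_scale`, so sampling locations stay in a region proportional to each box.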
num_points_scale = self.num_points_scale.to(dtype=hidden_states.dtype).unsqueeze(-1)
|
| 524 |
+
offset = sampling_offsets * num_points_scale * reference_points[:, :, None, :, 2:] * self.offset_scale
|
| 525 |
+
sampling_locations = reference_points[:, :, None, :, :2] + offset
|
| 526 |
+
else:
|
| 527 |
+
raise ValueError(
|
| 528 |
+
f"Last dim of reference_points must be 2 or 4, but get {reference_points.shape[-1]} instead."
|
| 529 |
+
)
|
| 530 |
+
|
| 531 |
+
output = self.ms_deformable_attn_core(
|
| 532 |
+
value,
|
| 533 |
+
spatial_shapes_list,
|
| 534 |
+
sampling_locations,
|
| 535 |
+
attention_weights,
|
| 536 |
+
self.num_points_list,
|
| 537 |
+
self.decoder_method,
|
| 538 |
+
)
|
| 539 |
+
|
| 540 |
+
return output, attention_weights
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
class DFineGate(nn.Module):
|
| 544 |
+
def __init__(self, d_model: int):
|
| 545 |
+
super().__init__()
|
| 546 |
+
self.gate = nn.Linear(2 * d_model, 2 * d_model)
|
| 547 |
+
self.norm = nn.LayerNorm(d_model)
|
| 548 |
+
|
| 549 |
+
def forward(self, second_residual: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 550 |
+
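        # Learned gating: two sigmoid gates blend the residual branch with the cross-attention
        # output before the LayerNorm, instead of a plain additive skip connection.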
gate_input = torch.cat([second_residual, hidden_states], dim=-1)
|
| 551 |
+
gates = torch.sigmoid(self.gate(gate_input))
|
| 552 |
+
gate1, gate2 = gates.chunk(2, dim=-1)
|
| 553 |
+
hidden_states = self.norm(gate1 * second_residual + gate2 * hidden_states)
|
| 554 |
+
return hidden_states
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
class DFineDecoderLayer(RTDetrDecoderLayer):
|
| 558 |
+
def __init__(self, config: DFineConfig):
|
| 559 |
+
super().__init__(config)
|
| 560 |
+
|
| 561 |
+
# override the encoder attention module with d-fine version
|
| 562 |
+
self.encoder_attn = DFineMultiscaleDeformableAttention(config=config)
|
| 563 |
+
# gate
|
| 564 |
+
self.gateway = DFineGate(config.d_model)
|
| 565 |
+
|
| 566 |
+
del self.encoder_attn_layer_norm
|
| 567 |
+
|
| 568 |
+
def forward(
|
| 569 |
+
self,
|
| 570 |
+
hidden_states: torch.Tensor,
|
| 571 |
+
position_embeddings: Optional[torch.Tensor] = None,
|
| 572 |
+
reference_points=None,
|
| 573 |
+
spatial_shapes=None,
|
| 574 |
+
spatial_shapes_list=None,
|
| 575 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 576 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
| 577 |
+
output_attentions: Optional[bool] = False,
|
| 578 |
+
) -> tuple[torch.Tensor, Any, Any]:
|
| 579 |
+
# Self Attention
|
| 580 |
+
hidden_states_2, self_attn_weights = self.self_attn(
|
| 581 |
+
hidden_states=hidden_states,
|
| 582 |
+
attention_mask=encoder_attention_mask,
|
| 583 |
+
position_embeddings=position_embeddings,
|
| 584 |
+
output_attentions=output_attentions,
|
| 585 |
+
)
|
| 586 |
+
|
| 587 |
+
hidden_states_2 = nn.functional.dropout(hidden_states_2, p=self.dropout, training=self.training)
|
| 588 |
+
hidden_states = hidden_states + hidden_states_2
|
| 589 |
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
| 590 |
+
residual = hidden_states
|
| 591 |
+
|
| 592 |
+
# Cross-Attention
|
| 593 |
+
cross_attn_weights = None
|
| 594 |
+
hidden_states = hidden_states if position_embeddings is None else hidden_states + position_embeddings
|
| 595 |
+
hidden_states_2, cross_attn_weights = self.encoder_attn(
|
| 596 |
+
hidden_states=hidden_states,
|
| 597 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 598 |
+
reference_points=reference_points,
|
| 599 |
+
spatial_shapes=spatial_shapes,
|
| 600 |
+
spatial_shapes_list=spatial_shapes_list,
|
| 601 |
+
)
|
| 602 |
+
|
| 603 |
+
hidden_states_2 = nn.functional.dropout(hidden_states_2, p=self.dropout, training=self.training)
|
| 604 |
+
hidden_states = self.gateway(residual, hidden_states_2)
|
| 605 |
+
|
| 606 |
+
# Fully Connected
|
| 607 |
+
hidden_states_2 = self.activation_fn(self.fc1(hidden_states))
|
| 608 |
+
hidden_states_2 = nn.functional.dropout(hidden_states_2, p=self.activation_dropout, training=self.training)
|
| 609 |
+
hidden_states_2 = self.fc2(hidden_states_2)
|
| 610 |
+
hidden_states_2 = nn.functional.dropout(hidden_states_2, p=self.dropout, training=self.training)
|
| 611 |
+
hidden_states = hidden_states + hidden_states_2
|
| 612 |
+
hidden_states = self.final_layer_norm(hidden_states.clamp(min=-65504, max=65504))
|
| 613 |
+
|
| 614 |
+
outputs = (hidden_states,)
|
| 615 |
+
|
| 616 |
+
if output_attentions:
|
| 617 |
+
outputs += (self_attn_weights, cross_attn_weights)
|
| 618 |
+
|
| 619 |
+
return outputs
|
| 620 |
+
|
| 621 |
+
|
| 622 |
+
class DFinePreTrainedModel(RTDetrPreTrainedModel):
|
| 623 |
+
def _init_weights(self, module):
|
| 624 |
+
# initialize linear layer bias value according to a given probability value.
|
| 625 |
+
if isinstance(module, (DFineForObjectDetection, DFineDecoder)):
|
| 626 |
+
if module.class_embed is not None:
|
| 627 |
+
for layer in module.class_embed:
|
| 628 |
+
prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
|
| 629 |
+
bias = float(-math.log((1 - prior_prob) / prior_prob))
|
| 630 |
+
nn.init.xavier_uniform_(layer.weight)
|
| 631 |
+
nn.init.constant_(layer.bias, bias)
|
| 632 |
+
|
| 633 |
+
if module.bbox_embed is not None:
|
| 634 |
+
for layer in module.bbox_embed:
|
| 635 |
+
nn.init.constant_(layer.layers[-1].weight, 0)
|
| 636 |
+
nn.init.constant_(layer.layers[-1].bias, 0)
|
| 637 |
+
|
| 638 |
+
if isinstance(module, DFineMultiscaleDeformableAttention):
|
| 639 |
+
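            # Zero the offset weights and set the bias so that, at initialization, each attention
            # head starts sampling along its own direction (angles evenly spaced across heads),
            # with points placed at increasing radii (1, 2, ...) within every feature level.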
nn.init.constant_(module.sampling_offsets.weight.data, 0.0)
|
| 640 |
+
default_dtype = torch.get_default_dtype()
|
| 641 |
+
thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * (
|
| 642 |
+
2.0 * math.pi / module.n_heads
|
| 643 |
+
)
|
| 644 |
+
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
|
| 645 |
+
grid_init = grid_init / grid_init.abs().max(-1, keepdim=True).values
|
| 646 |
+
grid_init = grid_init.reshape(module.n_heads, 1, 2).tile([1, sum(module.num_points_list), 1])
|
| 647 |
+
scaling = torch.concat([torch.arange(1, n + 1) for n in module.num_points_list]).reshape(1, -1, 1)
|
| 648 |
+
grid_init *= scaling
|
| 649 |
+
with torch.no_grad():
|
| 650 |
+
module.sampling_offsets.bias.data[...] = grid_init.flatten()
|
| 651 |
+
|
| 652 |
+
nn.init.constant_(module.attention_weights.weight.data, 0.0)
|
| 653 |
+
nn.init.constant_(module.attention_weights.bias.data, 0.0)
|
| 654 |
+
|
| 655 |
+
if isinstance(module, DFineModel):
|
| 656 |
+
prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
|
| 657 |
+
bias = float(-math.log((1 - prior_prob) / prior_prob))
|
| 658 |
+
nn.init.xavier_uniform_(module.enc_score_head.weight)
|
| 659 |
+
nn.init.constant_(module.enc_score_head.bias, bias)
|
| 660 |
+
|
| 661 |
+
if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
|
| 662 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 663 |
+
if module.bias is not None:
|
| 664 |
+
module.bias.data.zero_()
|
| 665 |
+
|
| 666 |
+
if isinstance(module, DFineGate):
|
| 667 |
+
bias = float(-math.log((1 - 0.5) / 0.5))
|
| 668 |
+
init.constant_(module.gate.bias, bias)
|
| 669 |
+
init.constant_(module.gate.weight, 0)
|
| 670 |
+
|
| 671 |
+
if isinstance(module, DFineLQE):
|
| 672 |
+
init.constant_(module.reg_conf.layers[-1].bias, 0)
|
| 673 |
+
init.constant_(module.reg_conf.layers[-1].weight, 0)
|
| 674 |
+
|
| 675 |
+
if hasattr(module, "weight_embedding") and self.config.learn_initial_query:
|
| 676 |
+
nn.init.xavier_uniform_(module.weight_embedding.weight)
|
| 677 |
+
if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0:
|
| 678 |
+
nn.init.xavier_uniform_(module.denoising_class_embed.weight)
|
| 679 |
+
|
| 680 |
+
|
| 681 |
+
class DFineIntegral(nn.Module):
|
| 682 |
+
"""
|
| 683 |
+
A static layer that calculates integral results from a distribution.
|
| 684 |
+
|
| 685 |
+
This layer computes the target location using the formula: `sum{Pr(n) * W(n)}`,
|
| 686 |
+
where Pr(n) is the softmax probability vector representing the discrete
|
| 687 |
+
distribution, and W(n) is the non-uniform Weighting Function.
|
| 688 |
+
|
| 689 |
+
Args:
|
| 690 |
+
max_num_bins (int): Max number of the discrete bins. Default is 32.
|
| 691 |
+
It can be adjusted based on the dataset or task requirements.
|
| 692 |
+
"""
|
| 693 |
+
|
| 694 |
+
def __init__(self, config: DFineConfig):
|
| 695 |
+
super().__init__()
|
| 696 |
+
self.max_num_bins = config.max_num_bins
|
| 697 |
+
|
| 698 |
+
def forward(self, pred_corners: torch.Tensor, project: torch.Tensor) -> torch.Tensor:
|
| 699 |
+
batch_size, num_queries, _ = pred_corners.shape
|
| 700 |
+
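        # Each corner is predicted as a discrete distribution over (max_num_bins + 1) bins; the
        # softmax followed by F.linear with the projection vector W(n) computes the expectation
        # sum{Pr(n) * W(n)} for each of the four box edges.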
pred_corners = F.softmax(pred_corners.reshape(-1, self.max_num_bins + 1), dim=1)
|
| 701 |
+
pred_corners = F.linear(pred_corners, project.to(pred_corners.device)).reshape(-1, 4)
|
| 702 |
+
pred_corners = pred_corners.reshape(batch_size, num_queries, -1)
|
| 703 |
+
return pred_corners
|
| 704 |
+
|
| 705 |
+
|
| 706 |
+
class DFineDecoderOutput(RTDetrDecoderOutput):
|
| 707 |
+
pass
|
| 708 |
+
|
| 709 |
+
|
| 710 |
+
class DFineDecoder(RTDetrDecoder):
|
| 711 |
+
"""
|
| 712 |
+
D-FINE Decoder implementing Fine-grained Distribution Refinement (FDR).
|
| 713 |
+
|
| 714 |
+
This decoder refines object detection predictions through iterative updates across multiple layers,
|
| 715 |
+
utilizing attention mechanisms, location quality estimators, and distribution refinement techniques
|
| 716 |
+
to improve bounding box accuracy and robustness.
|
| 717 |
+
"""
|
| 718 |
+
|
| 719 |
+
def __init__(self, config: DFineConfig):
|
| 720 |
+
self.eval_idx = config.eval_idx if config.eval_idx >= 0 else config.decoder_layers + config.eval_idx
|
| 721 |
+
super().__init__(config=config)
|
| 722 |
+
self.reg_scale = nn.Parameter(torch.tensor([config.reg_scale]), requires_grad=False)
|
| 723 |
+
self.max_num_bins = config.max_num_bins
|
| 724 |
+
self.d_model = config.d_model
|
| 725 |
+
self.layer_scale = config.layer_scale
|
| 726 |
+
self.pre_bbox_head = DFineMLP(config.hidden_size, config.hidden_size, 4, 3)
|
| 727 |
+
self.integral = DFineIntegral(config)
|
| 728 |
+
self.num_head = config.decoder_attention_heads
|
| 729 |
+
self.up = nn.Parameter(torch.tensor([config.up]), requires_grad=False)
|
| 730 |
+
self.lqe_layers = nn.ModuleList([DFineLQE(config) for _ in range(config.decoder_layers)])
|
| 731 |
+
self.layers = nn.ModuleList(
|
| 732 |
+
[DFineDecoderLayer(config) for _ in range(config.decoder_layers)]
|
| 733 |
+
+ [DFineDecoderLayer(config) for _ in range(config.decoder_layers - self.eval_idx - 1)]
|
| 734 |
+
)
|
| 735 |
+
|
| 736 |
+
def forward(
|
| 737 |
+
self,
|
| 738 |
+
encoder_hidden_states: torch.Tensor,
|
| 739 |
+
reference_points: torch.Tensor,
|
| 740 |
+
inputs_embeds: torch.Tensor,
|
| 741 |
+
spatial_shapes,
|
| 742 |
+
level_start_index=None,
|
| 743 |
+
spatial_shapes_list=None,
|
| 744 |
+
output_hidden_states=None,
|
| 745 |
+
encoder_attention_mask=None,
|
| 746 |
+
memory_mask=None,
|
| 747 |
+
output_attentions=None,
|
| 748 |
+
return_dict=None,
|
| 749 |
+
) -> DFineDecoderOutput:
|
| 750 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 751 |
+
output_hidden_states = (
|
| 752 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 753 |
+
)
|
| 754 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 755 |
+
|
| 756 |
+
if inputs_embeds is not None:
|
| 757 |
+
hidden_states = inputs_embeds
|
| 758 |
+
|
| 759 |
+
# decoder layers
|
| 760 |
+
all_hidden_states = () if output_hidden_states else None
|
| 761 |
+
all_self_attns = () if output_attentions else None
|
| 762 |
+
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
| 763 |
+
intermediate = ()
|
| 764 |
+
intermediate_reference_points = ()
|
| 765 |
+
intermediate_logits = ()
|
| 766 |
+
intermediate_predicted_corners = ()
|
| 767 |
+
initial_reference_points = ()
|
| 768 |
+
|
| 769 |
+
output_detach = pred_corners_undetach = 0
|
| 770 |
+
|
| 771 |
+
project = weighting_function(self.max_num_bins, self.up, self.reg_scale)
|
| 772 |
+
ref_points_detach = F.sigmoid(reference_points)
|
| 773 |
+
|
| 774 |
+
for i, decoder_layer in enumerate(self.layers):
|
| 775 |
+
ref_points_input = ref_points_detach.unsqueeze(2)
|
| 776 |
+
query_pos_embed = self.query_pos_head(ref_points_detach).clamp(min=-10, max=10)
|
| 777 |
+
|
| 778 |
+
if output_hidden_states:
|
| 779 |
+
all_hidden_states += (hidden_states,)
|
| 780 |
+
|
| 781 |
+
output = decoder_layer(
|
| 782 |
+
hidden_states=hidden_states,
|
| 783 |
+
position_embeddings=query_pos_embed,
|
| 784 |
+
reference_points=ref_points_input,
|
| 785 |
+
spatial_shapes=spatial_shapes,
|
| 786 |
+
spatial_shapes_list=spatial_shapes_list,
|
| 787 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 788 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 789 |
+
output_attentions=output_attentions,
|
| 790 |
+
)
|
| 791 |
+
|
| 792 |
+
hidden_states = output[0]
|
| 793 |
+
|
| 794 |
+
if i == 0:
|
| 795 |
+
# Initial bounding box predictions with inverse sigmoid refinement
|
| 796 |
+
new_reference_points = F.sigmoid(self.pre_bbox_head(output[0]) + inverse_sigmoid(ref_points_detach))
|
| 797 |
+
ref_points_initial = new_reference_points.detach()
|
| 798 |
+
|
| 799 |
+
# Refine bounding box corners using FDR, integrating previous layer's corrections
|
| 800 |
+
if self.bbox_embed is not None:
|
| 801 |
+
pred_corners = self.bbox_embed[i](hidden_states + output_detach) + pred_corners_undetach
|
| 802 |
+
inter_ref_bbox = distance2bbox(
|
| 803 |
+
ref_points_initial, self.integral(pred_corners, project), self.reg_scale
|
| 804 |
+
)
|
| 805 |
+
pred_corners_undetach = pred_corners
|
| 806 |
+
ref_points_detach = inter_ref_bbox.detach()
|
| 807 |
+
|
| 808 |
+
output_detach = hidden_states.detach()
|
| 809 |
+
|
| 810 |
+
intermediate += (hidden_states,)
|
| 811 |
+
|
| 812 |
+
if self.class_embed is not None and (self.training or i == self.eval_idx):
|
| 813 |
+
scores = self.class_embed[i](hidden_states)
|
| 814 |
+
# Add initial logits and reference points with pre-bbox head
|
| 815 |
+
if i == 0:
|
| 816 |
+
intermediate_logits += (scores,)
|
| 817 |
+
intermediate_reference_points += (new_reference_points,)
|
| 818 |
+
# LQE does not affect the performance here.
|
| 819 |
+
scores = self.lqe_layers[i](scores, pred_corners)
|
| 820 |
+
intermediate_logits += (scores,)
|
| 821 |
+
intermediate_reference_points += (inter_ref_bbox,)
|
| 822 |
+
initial_reference_points += (ref_points_initial,)
|
| 823 |
+
intermediate_predicted_corners += (pred_corners,)
|
| 824 |
+
|
| 825 |
+
if output_attentions:
|
| 826 |
+
all_self_attns += (output[1],)
|
| 827 |
+
|
| 828 |
+
if encoder_hidden_states is not None:
|
| 829 |
+
all_cross_attentions += (output[2],)
|
| 830 |
+
|
| 831 |
+
# Keep batch_size as first dimension
|
| 832 |
+
intermediate = torch.stack(intermediate)
|
| 833 |
+
if self.class_embed is not None and self.bbox_embed is not None:
|
| 834 |
+
intermediate_logits = torch.stack(intermediate_logits, dim=1)
|
| 835 |
+
intermediate_predicted_corners = torch.stack(intermediate_predicted_corners, dim=1)
|
| 836 |
+
initial_reference_points = torch.stack(initial_reference_points, dim=1)
|
| 837 |
+
intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1)
|
| 838 |
+
|
| 839 |
+
# add hidden states from the last decoder layer
|
| 840 |
+
if output_hidden_states:
|
| 841 |
+
all_hidden_states += (hidden_states,)
|
| 842 |
+
|
| 843 |
+
if not return_dict:
|
| 844 |
+
return tuple(
|
| 845 |
+
v
|
| 846 |
+
for v in [
|
| 847 |
+
hidden_states,
|
| 848 |
+
intermediate,
|
| 849 |
+
intermediate_logits,
|
| 850 |
+
intermediate_reference_points,
|
| 851 |
+
intermediate_predicted_corners,
|
| 852 |
+
initial_reference_points,
|
| 853 |
+
all_hidden_states,
|
| 854 |
+
all_self_attns,
|
| 855 |
+
all_cross_attentions,
|
| 856 |
+
]
|
| 857 |
+
if v is not None
|
| 858 |
+
)
|
| 859 |
+
|
| 860 |
+
return DFineDecoderOutput(
|
| 861 |
+
last_hidden_state=hidden_states,
|
| 862 |
+
intermediate_hidden_states=intermediate,
|
| 863 |
+
intermediate_logits=intermediate_logits,
|
| 864 |
+
intermediate_reference_points=intermediate_reference_points,
|
| 865 |
+
intermediate_predicted_corners=intermediate_predicted_corners,
|
| 866 |
+
initial_reference_points=initial_reference_points,
|
| 867 |
+
hidden_states=all_hidden_states,
|
| 868 |
+
attentions=all_self_attns,
|
| 869 |
+
cross_attentions=all_cross_attentions,
|
| 870 |
+
)
|
| 871 |
+
|
| 872 |
+
|
| 873 |
+
class DFineModel(RTDetrModel):
|
| 874 |
+
def __init__(self, config: DFineConfig):
|
| 875 |
+
super().__init__(config)
|
| 876 |
+
del self.decoder_input_proj
|
| 877 |
+
self.encoder = DFineHybridEncoder(config=config)
|
| 878 |
+
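        # Rebuild the input projections that map the encoder feature maps to `d_model`: identity
        # when the channel counts already match, otherwise a 1x1 conv + BatchNorm, plus 3x3
        # stride-2 projections for any extra feature levels beyond the backbone outputs.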
num_backbone_outs = len(config.decoder_in_channels)
|
| 879 |
+
decoder_input_proj = []
|
| 880 |
+
in_channels = config.decoder_in_channels[-1]
|
| 881 |
+
for _ in range(num_backbone_outs):
|
| 882 |
+
if config.hidden_size == config.decoder_in_channels[-1]:
|
| 883 |
+
decoder_input_proj.append(nn.Identity())
|
| 884 |
+
else:
|
| 885 |
+
conv = nn.Conv2d(in_channels, config.d_model, kernel_size=1, bias=False)
|
| 886 |
+
batchnorm = nn.BatchNorm2d(config.d_model, config.batch_norm_eps)
|
| 887 |
+
decoder_input_proj.append(nn.Sequential(conv, batchnorm))
|
| 888 |
+
for _ in range(config.num_feature_levels - num_backbone_outs):
|
| 889 |
+
if config.hidden_size == config.decoder_in_channels[-1]:
|
| 890 |
+
decoder_input_proj.append(nn.Identity())
|
| 891 |
+
else:
|
| 892 |
+
conv = nn.Conv2d(in_channels, config.d_model, kernel_size=3, stride=2, padding=1, bias=False)
|
| 893 |
+
batchnorm = nn.BatchNorm2d(config.d_model, config.batch_norm_eps)
|
| 894 |
+
decoder_input_proj.append(nn.Sequential(conv, batchnorm))
|
| 895 |
+
self.decoder_input_proj = nn.ModuleList(decoder_input_proj)
|
| 896 |
+
self.decoder = DFineDecoder(config)
|
| 897 |
+
|
| 898 |
+
|
| 899 |
+
class DFineForObjectDetection(RTDetrForObjectDetection, DFinePreTrainedModel):
|
| 900 |
+
def __init__(self, config: DFineConfig):
|
| 901 |
+
DFinePreTrainedModel.__init__(self, config)
|
| 902 |
+
|
| 903 |
+
# D-FINE encoder-decoder model
|
| 904 |
+
self.eval_idx = config.eval_idx if config.eval_idx >= 0 else config.decoder_layers + config.eval_idx
|
| 905 |
+
self.model = DFineModel(config)
|
| 906 |
+
scaled_dim = round(config.layer_scale * config.hidden_size)
|
| 907 |
+
num_pred = config.decoder_layers
|
| 908 |
+
self.class_embed = nn.ModuleList([nn.Linear(config.d_model, config.num_labels) for _ in range(num_pred)])
|
| 909 |
+
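        # Each box-regression head predicts 4 * (max_num_bins + 1) logits per query, i.e. one
        # discrete distribution per box edge. Heads up to `eval_idx` use the full hidden size;
        # any later heads use the `layer_scale`-scaled dimension.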
self.bbox_embed = nn.ModuleList(
|
| 910 |
+
[
|
| 911 |
+
DFineMLP(config.hidden_size, config.hidden_size, 4 * (config.max_num_bins + 1), 3)
|
| 912 |
+
for _ in range(self.eval_idx + 1)
|
| 913 |
+
]
|
| 914 |
+
+ [
|
| 915 |
+
DFineMLP(scaled_dim, scaled_dim, 4 * (config.max_num_bins + 1), 3)
|
| 916 |
+
for _ in range(config.decoder_layers - self.eval_idx - 1)
|
| 917 |
+
]
|
| 918 |
+
)
|
| 919 |
+
|
| 920 |
+
# here self.model.decoder.bbox_embed is None, but self.bbox_embed is not
|
| 921 |
+
self.model.decoder.class_embed = self.class_embed
|
| 922 |
+
self.model.decoder.bbox_embed = self.bbox_embed
|
| 923 |
+
|
| 924 |
+
# Initialize weights and apply final processing
|
| 925 |
+
self.post_init()
|
| 926 |
+
|
| 927 |
+
def forward(**super_kwargs):
|
| 928 |
+
r"""
|
| 929 |
+
Example:
|
| 930 |
+
|
| 931 |
+
```python
|
| 932 |
+
>>> import torch
|
| 933 |
+
>>> from transformers.image_utils import load_image
|
| 934 |
+
>>> from transformers import AutoImageProcessor, DFineForObjectDetection
|
| 935 |
+
|
| 936 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 937 |
+
>>> image = load_image(url)
|
| 938 |
+
|
| 939 |
+
>>> image_processor = AutoImageProcessor.from_pretrained("ustc-community/dfine-xlarge-coco")
|
| 940 |
+
>>> model = DFineForObjectDetection.from_pretrained("ustc-community/dfine-xlarge-coco")
|
| 941 |
+
|
| 942 |
+
>>> # prepare image for the model
|
| 943 |
+
>>> inputs = image_processor(images=image, return_tensors="pt")
|
| 944 |
+
|
| 945 |
+
>>> # forward pass
|
| 946 |
+
>>> outputs = model(**inputs)
|
| 947 |
+
|
| 948 |
+
>>> logits = outputs.logits
|
| 949 |
+
>>> list(logits.shape)
|
| 950 |
+
[1, 300, 80]
|
| 951 |
+
|
| 952 |
+
>>> boxes = outputs.pred_boxes
|
| 953 |
+
>>> list(boxes.shape)
|
| 954 |
+
[1, 300, 4]
|
| 955 |
+
|
| 956 |
+
>>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
|
| 957 |
+
>>> target_sizes = torch.tensor([image.size[::-1]])
|
| 958 |
+
>>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)
|
| 959 |
+
>>> result = results[0] # first image in batch
|
| 960 |
+
|
| 961 |
+
>>> for score, label, box in zip(result["scores"], result["labels"], result["boxes"]):
|
| 962 |
+
... box = [round(i, 2) for i in box.tolist()]
|
| 963 |
+
... print(
|
| 964 |
+
... f"Detected {model.config.id2label[label.item()]} with confidence "
|
| 965 |
+
... f"{round(score.item(), 3)} at location {box}"
|
| 966 |
+
... )
|
| 967 |
+
Detected cat with confidence 0.958 at location [344.49, 23.4, 639.84, 374.27]
|
| 968 |
+
Detected cat with confidence 0.956 at location [11.71, 53.52, 316.64, 472.33]
|
| 969 |
+
Detected remote with confidence 0.947 at location [40.46, 73.7, 175.62, 117.57]
|
| 970 |
+
Detected sofa with confidence 0.918 at location [0.59, 1.88, 640.25, 474.74]
|
| 971 |
+
```
|
| 972 |
+
"""
|
| 973 |
+
super().forward(**super_kwargs)
|
| 974 |
+
|
| 975 |
+
|
| 976 |
+
def weighting_function(max_num_bins: int, up: torch.Tensor, reg_scale: int) -> torch.Tensor:
|
| 977 |
+
"""
|
| 978 |
+
Generates the non-uniform Weighting Function W(n) for bounding box regression.
|
| 979 |
+
|
| 980 |
+
Args:
|
| 981 |
+
max_num_bins (int): Max number of the discrete bins.
|
| 982 |
+
up (Tensor): Controls upper bounds of the sequence,
|
| 983 |
+
where maximum offset is ±up * H / W.
|
| 984 |
+
reg_scale (float): Controls the curvature of the Weighting Function.
|
| 985 |
+
Larger values result in flatter weights near the central axis W(max_num_bins/2)=0
|
| 986 |
+
and steeper weights at both ends.
|
| 987 |
+
Returns:
|
| 988 |
+
Tensor: Sequence of Weighting Function.
|
| 989 |
+
"""
|
| 990 |
+
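    # The bin edges are spaced geometrically: step**i - 1 grows exponentially with i, so the
    # resulting weights are dense around W(max_num_bins / 2) = 0 and spread out toward the two
    # outer values at +/- upper_bound2.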
upper_bound1 = abs(up[0]) * abs(reg_scale)
|
| 991 |
+
upper_bound2 = abs(up[0]) * abs(reg_scale) * 2
|
| 992 |
+
step = (upper_bound1 + 1) ** (2 / (max_num_bins - 2))
|
| 993 |
+
left_values = [-((step) ** i) + 1 for i in range(max_num_bins // 2 - 1, 0, -1)]
|
| 994 |
+
right_values = [(step) ** i - 1 for i in range(1, max_num_bins // 2)]
|
| 995 |
+
values = [-upper_bound2] + left_values + [torch.zeros_like(up[0][None])] + right_values + [upper_bound2]
|
| 996 |
+
values = torch.cat(values, 0)
|
| 997 |
+
return values
|
| 998 |
+
|
| 999 |
+
|
| 1000 |
+
class DFineMLPPredictionHead(RTDetrMLPPredictionHead):
|
| 1001 |
+
pass
|
| 1002 |
+
|
| 1003 |
+
|
| 1004 |
+
def distance2bbox(points, distance: torch.Tensor, reg_scale: float) -> torch.Tensor:
|
| 1005 |
+
"""
|
| 1006 |
+
Decodes edge-distances into bounding box coordinates.
|
| 1007 |
+
|
| 1008 |
+
Args:
|
| 1009 |
+
points (`torch.Tensor`):
|
| 1010 |
+
(batch_size, num_boxes, 4) or (num_boxes, 4) format, representing [x_center, y_center, width, height]
|
| 1011 |
+
distance (`torch.Tensor`):
|
| 1012 |
+
(batch_size, num_boxes, 4) or (num_boxes, 4), representing distances from the point to the left, top, right, and bottom boundaries.
|
| 1013 |
+
reg_scale (`float`):
|
| 1014 |
+
Controls the curvature of the Weighting Function.
|
| 1015 |
+
Returns:
|
| 1016 |
+
`torch.Tensor`: Bounding boxes in (batch_size, num_boxes, 4) or (num_boxes, 4) format, representing [x_center, y_center, width, height]
|
| 1017 |
+
"""
|
| 1018 |
+
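    # With zero predicted distances the decoded box reduces to the reference box itself
    # (cx - 0.5 * reg_scale * w / reg_scale = cx - w / 2); each predicted distance then shifts
    # its edge in units of (box size / reg_scale).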
reg_scale = abs(reg_scale)
|
| 1019 |
+
top_left_x = points[..., 0] - (0.5 * reg_scale + distance[..., 0]) * (points[..., 2] / reg_scale)
|
| 1020 |
+
top_left_y = points[..., 1] - (0.5 * reg_scale + distance[..., 1]) * (points[..., 3] / reg_scale)
|
| 1021 |
+
bottom_right_x = points[..., 0] + (0.5 * reg_scale + distance[..., 2]) * (points[..., 2] / reg_scale)
|
| 1022 |
+
bottom_right_y = points[..., 1] + (0.5 * reg_scale + distance[..., 3]) * (points[..., 3] / reg_scale)
|
| 1023 |
+
|
| 1024 |
+
bboxes = torch.stack([top_left_x, top_left_y, bottom_right_x, bottom_right_y], -1)
|
| 1025 |
+
|
| 1026 |
+
return corners_to_center_format(bboxes)
|
| 1027 |
+
|
| 1028 |
+
|
| 1029 |
+
class DFineMLP(nn.Module):
|
| 1030 |
+
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, act: str = "relu"):
|
| 1031 |
+
super().__init__()
|
| 1032 |
+
self.num_layers = num_layers
|
| 1033 |
+
hidden_dims = [hidden_dim] * (num_layers - 1)
|
| 1034 |
+
input_dims = [input_dim] + hidden_dims
|
| 1035 |
+
output_dims = hidden_dims + [output_dim]
|
| 1036 |
+
self.layers = nn.ModuleList(nn.Linear(in_dim, out_dim) for in_dim, out_dim in zip(input_dims, output_dims))
|
| 1037 |
+
self.act = ACT2CLS[act]()
|
| 1038 |
+
|
| 1039 |
+
def forward(self, stat_features: torch.Tensor) -> torch.Tensor:
|
| 1040 |
+
for i, layer in enumerate(self.layers):
|
| 1041 |
+
stat_features = self.act(layer(stat_features)) if i < self.num_layers - 1 else layer(stat_features)
|
| 1042 |
+
return stat_features
|
| 1043 |
+
|
| 1044 |
+
|
| 1045 |
+
class DFineLQE(nn.Module):
|
| 1046 |
+
def __init__(self, config: DFineConfig):
|
| 1047 |
+
super().__init__()
|
| 1048 |
+
self.top_prob_values = config.top_prob_values
|
| 1049 |
+
self.max_num_bins = config.max_num_bins
|
| 1050 |
+
self.reg_conf = DFineMLP(4 * (self.top_prob_values + 1), config.lqe_hidden_dim, 1, config.lqe_layers)
|
| 1051 |
+
|
| 1052 |
+
def forward(self, scores: torch.Tensor, pred_corners: torch.Tensor) -> torch.Tensor:
|
| 1053 |
+
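        # Location Quality Estimation: summarize each corner distribution by its top-k
        # probabilities and their mean, then map these statistics through a small MLP to a
        # per-query quality score that is added to the classification logits.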
batch_size, length, _ = pred_corners.size()
|
| 1054 |
+
prob = F.softmax(pred_corners.reshape(batch_size, length, 4, self.max_num_bins + 1), dim=-1)
|
| 1055 |
+
prob_topk, _ = prob.topk(self.top_prob_values, dim=-1)
|
| 1056 |
+
stat = torch.cat([prob_topk, prob_topk.mean(dim=-1, keepdim=True)], dim=-1)
|
| 1057 |
+
quality_score = self.reg_conf(stat.reshape(batch_size, length, -1))
|
| 1058 |
+
scores = scores + quality_score
|
| 1059 |
+
return scores
|
| 1060 |
+
|
| 1061 |
+
|
| 1062 |
+
class DFineConvNormLayer(RTDetrConvNormLayer):
|
| 1063 |
+
def __init__(
|
| 1064 |
+
self,
|
| 1065 |
+
config: DFineConfig,
|
| 1066 |
+
in_channels: int,
|
| 1067 |
+
out_channels: int,
|
| 1068 |
+
kernel_size: int,
|
| 1069 |
+
stride: int,
|
| 1070 |
+
groups: int = 1,
|
| 1071 |
+
padding: Optional[int] = None,
|
| 1072 |
+
activation: Optional[str] = None,
|
| 1073 |
+
):
|
| 1074 |
+
super().__init__(config, in_channels, out_channels, kernel_size, stride, padding=None, activation=activation)
|
| 1075 |
+
self.conv = nn.Conv2d(
|
| 1076 |
+
in_channels,
|
| 1077 |
+
out_channels,
|
| 1078 |
+
kernel_size,
|
| 1079 |
+
stride,
|
| 1080 |
+
groups=groups,
|
| 1081 |
+
padding=(kernel_size - 1) // 2 if padding is None else padding,
|
| 1082 |
+
bias=False,
|
| 1083 |
+
)
|
| 1084 |
+
|
| 1085 |
+
|
| 1086 |
+
class DFineRepVggBlock(RTDetrRepVggBlock):
|
| 1087 |
+
def __init__(self, config: DFineConfig, in_channels: int, out_channels: int):
|
| 1088 |
+
super().__init__(config)
|
| 1089 |
+
hidden_channels = in_channels
|
| 1090 |
+
self.conv1 = DFineConvNormLayer(config, hidden_channels, out_channels, 3, 1, padding=1)
|
| 1091 |
+
self.conv2 = DFineConvNormLayer(config, hidden_channels, out_channels, 1, 1, padding=0)
|
| 1092 |
+
|
| 1093 |
+
|
| 1094 |
+
class DFineCSPRepLayer(nn.Module):
|
| 1095 |
+
"""
|
| 1096 |
+
Cross Stage Partial (CSP) network layer with RepVGG blocks.
|
| 1097 |
+
"""
|
| 1098 |
+
|
| 1099 |
+
def __init__(
|
| 1100 |
+
self, config: DFineConfig, in_channels: int, out_channels: int, num_blocks: int, expansion: float = 1.0
|
| 1101 |
+
):
|
| 1102 |
+
super().__init__()
|
| 1103 |
+
in_channels = in_channels
|
| 1104 |
+
out_channels = out_channels
|
| 1105 |
+
activation = config.activation_function
|
| 1106 |
+
|
| 1107 |
+
hidden_channels = int(out_channels * expansion)
|
| 1108 |
+
self.conv1 = DFineConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation)
|
| 1109 |
+
self.conv2 = DFineConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation)
|
| 1110 |
+
self.bottlenecks = nn.ModuleList(
|
| 1111 |
+
[DFineRepVggBlock(config, hidden_channels, hidden_channels) for _ in range(num_blocks)]
|
| 1112 |
+
)
|
| 1113 |
+
if hidden_channels != out_channels:
|
| 1114 |
+
self.conv3 = DFineConvNormLayer(config, hidden_channels, out_channels, 1, 1, activation=activation)
|
| 1115 |
+
else:
|
| 1116 |
+
self.conv3 = nn.Identity()
|
| 1117 |
+
|
| 1118 |
+
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
| 1119 |
+
hidden_state_1 = self.conv1(hidden_state)
|
| 1120 |
+
for bottleneck in self.bottlenecks:
|
| 1121 |
+
hidden_state_1 = bottleneck(hidden_state_1)
|
| 1122 |
+
hidden_state_2 = self.conv2(hidden_state)
|
| 1123 |
+
hidden_state_3 = self.conv3(hidden_state_1 + hidden_state_2)
|
| 1124 |
+
return hidden_state_3
|
| 1125 |
+
|
| 1126 |
+
|
| 1127 |
+
class DFineRepNCSPELAN4(nn.Module):
|
| 1128 |
+
def __init__(self, config: DFineConfig, act: str = "silu", numb_blocks: int = 3):
|
| 1129 |
+
super().__init__()
|
| 1130 |
+
conv1_dim = config.encoder_hidden_dim * 2
|
| 1131 |
+
conv2_dim = config.encoder_hidden_dim
|
| 1132 |
+
conv3_dim = config.encoder_hidden_dim * 2
|
| 1133 |
+
conv4_dim = round(config.hidden_expansion * config.encoder_hidden_dim // 2)
|
| 1134 |
+
self.conv_dim = conv3_dim // 2
|
| 1135 |
+
self.conv1 = DFineConvNormLayer(config, conv1_dim, conv3_dim, 1, 1, activation=act)
|
| 1136 |
+
self.csp_rep1 = DFineCSPRepLayer(config, conv3_dim // 2, conv4_dim, num_blocks=numb_blocks)
|
| 1137 |
+
self.conv2 = DFineConvNormLayer(config, conv4_dim, conv4_dim, 3, 1, activation=act)
|
| 1138 |
+
self.csp_rep2 = DFineCSPRepLayer(config, conv4_dim, conv4_dim, num_blocks=numb_blocks)
|
| 1139 |
+
self.conv3 = DFineConvNormLayer(config, conv4_dim, conv4_dim, 3, 1, activation=act)
|
| 1140 |
+
self.conv4 = DFineConvNormLayer(config, conv3_dim + (2 * conv4_dim), conv2_dim, 1, 1, activation=act)
|
| 1141 |
+
|
| 1142 |
+
def forward(self, input_features: torch.Tensor) -> torch.Tensor:
|
| 1143 |
+
# Split initial features into two branches after first convolution
|
| 1144 |
+
split_features = list(self.conv1(input_features).split((self.conv_dim, self.conv_dim), 1))
|
| 1145 |
+
|
| 1146 |
+
# Process branches sequentially
|
| 1147 |
+
branch1 = self.csp_rep1(split_features[-1])
|
| 1148 |
+
branch1 = self.conv2(branch1)
|
| 1149 |
+
branch2 = self.csp_rep2(branch1)
|
| 1150 |
+
branch2 = self.conv3(branch2)
|
| 1151 |
+
|
| 1152 |
+
split_features.extend([branch1, branch2])
|
| 1153 |
+
merged_features = torch.cat(split_features, 1)
|
| 1154 |
+
merged_features = self.conv4(merged_features)
|
| 1155 |
+
return merged_features
|
| 1156 |
+
|
| 1157 |
+
|
| 1158 |
+
class DFineSCDown(nn.Module):
|
| 1159 |
+
def __init__(self, config: DFineConfig, kernel_size: int, stride: int):
|
| 1160 |
+
super().__init__()
|
| 1161 |
+
self.conv1 = DFineConvNormLayer(config, config.encoder_hidden_dim, config.encoder_hidden_dim, 1, 1)
|
| 1162 |
+
self.conv2 = DFineConvNormLayer(
|
| 1163 |
+
config,
|
| 1164 |
+
config.encoder_hidden_dim,
|
| 1165 |
+
config.encoder_hidden_dim,
|
| 1166 |
+
kernel_size,
|
| 1167 |
+
stride,
|
| 1168 |
+
config.encoder_hidden_dim,
|
| 1169 |
+
)
|
| 1170 |
+
|
| 1171 |
+
def forward(self, input_features: torch.Tensor) -> torch.Tensor:
|
| 1172 |
+
input_features = self.conv1(input_features)
|
| 1173 |
+
input_features = self.conv2(input_features)
|
| 1174 |
+
return input_features
|
| 1175 |
+
|
| 1176 |
+
|
| 1177 |
+
class DFineEncoder(RTDetrEncoder):
|
| 1178 |
+
pass
|
| 1179 |
+
|
| 1180 |
+
|
| 1181 |
+
class DFineHybridEncoder(RTDetrHybridEncoder):
|
| 1182 |
+
def __init__(self, config: DFineConfig):
|
| 1183 |
+
nn.Module.__init__(self)
|
| 1184 |
+
self.config = config
|
| 1185 |
+
self.in_channels = config.encoder_in_channels
|
| 1186 |
+
self.num_fpn_stages = len(self.in_channels) - 1
|
| 1187 |
+
self.feat_strides = config.feat_strides
|
| 1188 |
+
self.encoder_hidden_dim = config.encoder_hidden_dim
|
| 1189 |
+
self.encode_proj_layers = config.encode_proj_layers
|
| 1190 |
+
self.positional_encoding_temperature = config.positional_encoding_temperature
|
| 1191 |
+
self.eval_size = config.eval_size
|
| 1192 |
+
self.out_channels = [self.encoder_hidden_dim for _ in self.in_channels]
|
| 1193 |
+
self.out_strides = self.feat_strides
|
| 1194 |
+
|
| 1195 |
+
# encoder transformer
|
| 1196 |
+
self.encoder = nn.ModuleList([DFineEncoder(config) for _ in range(len(self.encode_proj_layers))])
|
| 1197 |
+
# top-down fpn
|
| 1198 |
+
self.lateral_convs = nn.ModuleList()
|
| 1199 |
+
self.fpn_blocks = nn.ModuleList()
|
| 1200 |
+
for _ in range(len(self.in_channels) - 1, 0, -1):
|
| 1201 |
+
lateral_layer = DFineConvNormLayer(config, self.encoder_hidden_dim, self.encoder_hidden_dim, 1, 1)
|
| 1202 |
+
self.lateral_convs.append(lateral_layer)
|
| 1203 |
+
num_blocks = round(3 * config.depth_mult)
|
| 1204 |
+
fpn_layer = DFineRepNCSPELAN4(config, numb_blocks=num_blocks)
|
| 1205 |
+
self.fpn_blocks.append(fpn_layer)
|
| 1206 |
+
|
| 1207 |
+
# bottom-up pan
|
| 1208 |
+
self.downsample_convs = nn.ModuleList()
|
| 1209 |
+
self.pan_blocks = nn.ModuleList()
|
| 1210 |
+
for _ in range(len(self.in_channels) - 1):
|
| 1211 |
+
self.downsample_convs.append(DFineSCDown(config, 3, 2))
|
| 1212 |
+
num_blocks = round(3 * config.depth_mult)
|
| 1213 |
+
self.pan_blocks.append(DFineRepNCSPELAN4(config, numb_blocks=num_blocks))
|
| 1214 |
+
|
| 1215 |
+
|
| 1216 |
+
__all__ = [
|
| 1217 |
+
"DFineConfig",
|
| 1218 |
+
"DFineModel",
|
| 1219 |
+
"DFinePreTrainedModel",
|
| 1220 |
+
"DFineForObjectDetection",
|
| 1221 |
+
]
|
phivenv/Lib/site-packages/transformers/models/depth_pro/__init__.py
ADDED
|
@@ -0,0 +1,29 @@
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ...utils import _LazyModule
|
| 17 |
+
from ...utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_depth_pro import *
|
| 22 |
+
from .image_processing_depth_pro import *
|
| 23 |
+
from .image_processing_depth_pro_fast import *
|
| 24 |
+
from .modeling_depth_pro import *
|
| 25 |
+
else:
|
| 26 |
+
import sys
|
| 27 |
+
|
| 28 |
+
_file = globals()["__file__"]
|
| 29 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
phivenv/Lib/site-packages/transformers/models/depth_pro/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (618 Bytes)
|
phivenv/Lib/site-packages/transformers/models/depth_pro/__pycache__/configuration_depth_pro.cpython-39.pyc
ADDED
|
Binary file (7.61 kB)
|
phivenv/Lib/site-packages/transformers/models/depth_pro/__pycache__/image_processing_depth_pro.cpython-39.pyc
ADDED
|
Binary file (14.6 kB)
|
phivenv/Lib/site-packages/transformers/models/depth_pro/__pycache__/image_processing_depth_pro_fast.cpython-39.pyc
ADDED
|
Binary file (4.79 kB)
|
phivenv/Lib/site-packages/transformers/models/depth_pro/__pycache__/modeling_depth_pro.cpython-39.pyc
ADDED
|
Binary file (27.7 kB)
|
phivenv/Lib/site-packages/transformers/models/depth_pro/configuration_depth_pro.py
ADDED
|
@@ -0,0 +1,205 @@
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""DepthPro model configuration"""
|
| 16 |
+
|
| 17 |
+
from copy import deepcopy
|
| 18 |
+
|
| 19 |
+
from ...configuration_utils import PretrainedConfig
|
| 20 |
+
from ...utils import logging
|
| 21 |
+
from ..auto.configuration_auto import CONFIG_MAPPING, AutoConfig
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
logger = logging.get_logger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class DepthProConfig(PretrainedConfig):
|
| 28 |
+
r"""
|
| 29 |
+
This is the configuration class to store the configuration of a [`DepthProModel`]. It is used to instantiate a
|
| 30 |
+
DepthPro model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
| 31 |
+
with the defaults will yield a similar configuration to that of the DepthPro
|
| 32 |
+
[apple/DepthPro](https://huggingface.co/apple/DepthPro) architecture.
|
| 33 |
+
|
| 34 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 35 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
fusion_hidden_size (`int`, *optional*, defaults to 256):
|
| 39 |
+
The number of channels before fusion.
|
| 40 |
+
patch_size (`int`, *optional*, defaults to 384):
|
| 41 |
+
The size (resolution) of each patch. This is also the `image_size` for the backbone model.
|
| 42 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 43 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 44 |
+
intermediate_hook_ids (`list[int]`, *optional*, defaults to `[11, 5]`):
|
| 45 |
+
Indices of the intermediate hidden states from the patch encoder to use for fusion.
|
| 46 |
+
intermediate_feature_dims (`list[int]`, *optional*, defaults to `[256, 256]`):
|
| 47 |
+
Hidden state dimensions during upsampling for each intermediate hidden state in `intermediate_hook_ids`.
|
| 48 |
+
scaled_images_ratios (`list[float]`, *optional*, defaults to `[0.25, 0.5, 1]`):
|
| 49 |
+
Ratios of scaled images to be used by the patch encoder.
|
| 50 |
+
scaled_images_overlap_ratios (`list[float]`, *optional*, defaults to `[0.0, 0.5, 0.25]`):
|
| 51 |
+
Overlap ratios between patches for each scaled image in `scaled_images_ratios`.
|
| 52 |
+
scaled_images_feature_dims (`list[int]`, *optional*, defaults to `[1024, 1024, 512]`):
|
| 53 |
+
Hidden state dimensions during upsampling for each scaled image in `scaled_images_ratios`.
|
| 54 |
+
merge_padding_value (`int`, *optional*, defaults to 3):
|
| 55 |
+
When merging smaller patches back to the image size, overlapping sections of this size are removed.
|
| 56 |
+
use_batch_norm_in_fusion_residual (`bool`, *optional*, defaults to `False`):
|
| 57 |
+
Whether to use batch normalization in the pre-activate residual units of the fusion blocks.
|
| 58 |
+
use_bias_in_fusion_residual (`bool`, *optional*, defaults to `True`):
|
| 59 |
+
Whether to use bias in the pre-activate residual units of the fusion blocks.
|
| 60 |
+
use_fov_model (`bool`, *optional*, defaults to `False`):
|
| 61 |
+
Whether to use `DepthProFovModel` to generate the field of view.
|
| 62 |
+
num_fov_head_layers (`int`, *optional*, defaults to 2):
|
| 63 |
+
Number of convolution layers in the head of `DepthProFovModel`.
|
| 64 |
+
image_model_config (`Union[dict[str, Any], PretrainedConfig]`, *optional*):
|
| 65 |
+
The configuration of the image encoder model, which is loaded using the [`AutoModel`] API.
|
| 66 |
+
By default, the Dinov2 model is used as the backbone.
|
| 67 |
+
patch_model_config (`Union[dict[str, Any], PretrainedConfig]`, *optional*):
|
| 68 |
+
The configuration of the patch encoder model, which is loaded using the [`AutoModel`] API.
|
| 69 |
+
By default, Dinov2 model is used as backbone.
|
| 70 |
+
fov_model_config (`Union[dict[str, Any], PretrainedConfig]`, *optional*):
|
| 71 |
+
The configuration of the fov encoder model, which is loaded using the [`AutoModel`] API.
|
| 72 |
+
By default, Dinov2 model is used as backbone.
|
| 73 |
+
|
| 74 |
+
Example:
|
| 75 |
+
|
| 76 |
+
```python
|
| 77 |
+
>>> from transformers import DepthProConfig, DepthProModel
|
| 78 |
+
|
| 79 |
+
>>> # Initializing a DepthPro apple/DepthPro style configuration
|
| 80 |
+
>>> configuration = DepthProConfig()
|
| 81 |
+
|
| 82 |
+
>>> # Initializing a model (with random weights) from the apple/DepthPro style configuration
|
| 83 |
+
>>> model = DepthProModel(configuration)
|
| 84 |
+
|
| 85 |
+
>>> # Accessing the model configuration
|
| 86 |
+
>>> configuration = model.config
|
| 87 |
+
```"""
|
| 88 |
+
|
| 89 |
+
model_type = "depth_pro"
|
| 90 |
+
sub_configs = {"image_model_config": AutoConfig, "patch_model_config": AutoConfig, "fov_model_config": AutoConfig}
|
| 91 |
+
|
| 92 |
+
def __init__(
|
| 93 |
+
self,
|
| 94 |
+
fusion_hidden_size=256,
|
| 95 |
+
patch_size=384,
|
| 96 |
+
initializer_range=0.02,
|
| 97 |
+
intermediate_hook_ids=[11, 5],
|
| 98 |
+
intermediate_feature_dims=[256, 256],
|
| 99 |
+
scaled_images_ratios=[0.25, 0.5, 1],
|
| 100 |
+
scaled_images_overlap_ratios=[0.0, 0.5, 0.25],
|
| 101 |
+
scaled_images_feature_dims=[1024, 1024, 512],
|
| 102 |
+
merge_padding_value=3,
|
| 103 |
+
use_batch_norm_in_fusion_residual=False,
|
| 104 |
+
use_bias_in_fusion_residual=True,
|
| 105 |
+
use_fov_model=False,
|
| 106 |
+
num_fov_head_layers=2,
|
| 107 |
+
image_model_config=None,
|
| 108 |
+
patch_model_config=None,
|
| 109 |
+
fov_model_config=None,
|
| 110 |
+
**kwargs,
|
| 111 |
+
):
|
| 112 |
+
super().__init__(**kwargs)
|
| 113 |
+
|
| 114 |
+
# scaled_images_ratios is sorted
|
| 115 |
+
if scaled_images_ratios != sorted(scaled_images_ratios):
|
| 116 |
+
raise ValueError(
|
| 117 |
+
f"Values in scaled_images_ratios={scaled_images_ratios} should be sorted from low to high"
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
# scaled_images_ratios, scaled_images_overlap_ratios, scaled_images_feature_dims should be consistent
|
| 121 |
+
if not (len(scaled_images_ratios) == len(scaled_images_overlap_ratios) == len(scaled_images_feature_dims)):
|
| 122 |
+
raise ValueError(
|
| 123 |
+
f"len(scaled_images_ratios)={len(scaled_images_ratios)} and "
|
| 124 |
+
f"len(scaled_images_overlap_ratios)={len(scaled_images_overlap_ratios)} and "
|
| 125 |
+
f"len(scaled_images_feature_dims)={len(scaled_images_feature_dims)}, "
|
| 126 |
+
f"should match in config."
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
# intermediate_hook_ids, intermediate_feature_dims should be consistent
|
| 130 |
+
if not (len(intermediate_hook_ids) == len(intermediate_feature_dims)):
|
| 131 |
+
raise ValueError(
|
| 132 |
+
f"len(intermediate_hook_ids)={len(intermediate_hook_ids)} and "
|
| 133 |
+
f"len(intermediate_feature_dims)={len(intermediate_feature_dims)}, "
|
| 134 |
+
f"should match in config."
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
# fusion_hidden_size should be consistent with num_fov_head_layers
|
| 138 |
+
if fusion_hidden_size // 2**num_fov_head_layers == 0:
|
| 139 |
+
raise ValueError(
|
| 140 |
+
f"fusion_hidden_size={fusion_hidden_size} should be consistent with num_fov_head_layers={num_fov_head_layers} "
|
| 141 |
+
"i.e fusion_hidden_size // 2**num_fov_head_layers > 0"
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
self.fusion_hidden_size = fusion_hidden_size
|
| 145 |
+
self.patch_size = patch_size
|
| 146 |
+
self.initializer_range = initializer_range
|
| 147 |
+
self.use_batch_norm_in_fusion_residual = use_batch_norm_in_fusion_residual
|
| 148 |
+
self.use_bias_in_fusion_residual = use_bias_in_fusion_residual
|
| 149 |
+
self.use_fov_model = use_fov_model
|
| 150 |
+
self.num_fov_head_layers = num_fov_head_layers
|
| 151 |
+
self.intermediate_hook_ids = intermediate_hook_ids
|
| 152 |
+
self.intermediate_feature_dims = intermediate_feature_dims
|
| 153 |
+
self.scaled_images_ratios = scaled_images_ratios
|
| 154 |
+
self.scaled_images_overlap_ratios = scaled_images_overlap_ratios
|
| 155 |
+
self.scaled_images_feature_dims = scaled_images_feature_dims
|
| 156 |
+
self.merge_padding_value = merge_padding_value
|
| 157 |
+
self.image_model_config = image_model_config
|
| 158 |
+
self.patch_model_config = patch_model_config
|
| 159 |
+
self.fov_model_config = fov_model_config
|
| 160 |
+
|
| 161 |
+
for sub_config_key in self.sub_configs:
|
| 162 |
+
sub_config = getattr(self, sub_config_key)
|
| 163 |
+
|
| 164 |
+
if sub_config is None:
|
| 165 |
+
sub_config = CONFIG_MAPPING["dinov2"](image_size=patch_size)
|
| 166 |
+
logger.info(
|
| 167 |
+
f"`{sub_config_key}` is `None`. Initializing `{sub_config_key}` with the `Dinov2Config` "
|
| 168 |
+
f"with default values except `{sub_config_key}.image_size` is set to `config.patch_size`."
|
| 169 |
+
)
|
| 170 |
+
elif isinstance(sub_config, dict):
|
| 171 |
+
sub_config = deepcopy(sub_config)
|
| 172 |
+
if "model_type" not in sub_config:
|
| 173 |
+
raise KeyError(
|
| 174 |
+
f"The `model_type` key is missing in the `{sub_config_key}` dictionary. Please provide the model type."
|
| 175 |
+
)
|
| 176 |
+
elif sub_config["model_type"] not in CONFIG_MAPPING:
|
| 177 |
+
raise ValueError(
|
| 178 |
+
f"The model type `{sub_config['model_type']}` in `{sub_config_key}` is not supported. Please provide a valid model type."
|
| 179 |
+
)
|
| 180 |
+
image_size = sub_config.get("image_size")
|
| 181 |
+
if image_size != patch_size:
|
| 182 |
+
logger.info(
|
| 183 |
+
f"The `image_size` in `{sub_config_key}` is set to `{image_size}`, "
|
| 184 |
+
f"but it does not match the required `patch_size` of `{patch_size}`. "
|
| 185 |
+
f"Updating `image_size` to `{patch_size}` for consistency. "
|
| 186 |
+
f"Ensure that `image_size` aligns with `patch_size` in the configuration."
|
| 187 |
+
)
|
| 188 |
+
sub_config.update({"image_size": patch_size})
|
| 189 |
+
sub_config = CONFIG_MAPPING[sub_config["model_type"]](**sub_config)
|
| 190 |
+
elif isinstance(sub_config, PretrainedConfig):
|
| 191 |
+
sub_config = sub_config
|
| 192 |
+
image_size = getattr(sub_config, "image_size", None)
|
| 193 |
+
if image_size != patch_size:
|
| 194 |
+
raise ValueError(
|
| 195 |
+
f"`config.{sub_config_key}.image_size={image_size}` should match `config.patch_size={patch_size}`."
|
| 196 |
+
)
|
| 197 |
+
else:
|
| 198 |
+
raise TypeError(
|
| 199 |
+
f"Invalid type for `sub_config`. Expected `PretrainedConfig`, `dict`, or `None`, but got {type(sub_config)}."
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
setattr(self, sub_config_key, sub_config)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
__all__ = ["DepthProConfig"]
|
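For context, a minimal sketch of how the sub-config handling above behaves when the sub-configs are passed as plain dicts. It assumes a transformers build that ships DepthPro; the dinov2 hyper-parameters are illustrative assumptions, not values taken from the released checkpoint.

from transformers import DepthProConfig

# Illustrative dinov2 settings (assumed, not taken from apple/DepthPro)
backbone_kwargs = {"model_type": "dinov2", "image_size": 384, "patch_size": 16}

config = DepthProConfig(
    patch_size=384,
    image_model_config=dict(backbone_kwargs),
    patch_model_config=dict(backbone_kwargs),
    fov_model_config=dict(backbone_kwargs),
    use_fov_model=True,
)
# Each dict is resolved through CONFIG_MAPPING["dinov2"] and its image_size is
# forced to match config.patch_size, as done in __init__ above.
print(type(config.patch_model_config).__name__)  # Dinov2Config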
phivenv/Lib/site-packages/transformers/models/depth_pro/image_processing_depth_pro.py
ADDED
@@ -0,0 +1,389 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for DepthPro."""

from typing import TYPE_CHECKING, Optional, Union

import numpy as np

from ...utils.import_utils import requires


if TYPE_CHECKING:
    from .modeling_depth_pro import DepthProDepthEstimatorOutput

from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import to_channel_dimension_format
from ...image_utils import (
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    infer_channel_dimension_format,
    is_scaled_image,
    is_torch_available,
    make_list_of_images,
    to_numpy_array,
    valid_images,
)
from ...utils import TensorType, filter_out_non_signature_kwargs, is_torchvision_available, logging, requires_backends


if is_torch_available():
    import torch

if is_torchvision_available():
    from ...image_utils import pil_torch_interpolation_mapping


logger = logging.get_logger(__name__)


@requires(backends=("torchvision", "torch"))
class DepthProImageProcessor(BaseImageProcessor):
    r"""
    Constructs a DepthPro image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
            size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
        size (`dict`, *optional*, defaults to `{"height": 1536, "width": 1536}`):
            Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
            method.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
            Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
            `preprocess` method.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
            parameter in the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
            `preprocess` method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
            method.
        image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Optional[dict[str, int]] = None,
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, list[float]]] = None,
        image_std: Optional[Union[float, list[float]]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        size = size if size is not None else {"height": 1536, "width": 1536}
        size = get_size_dict(size)
        self.do_resize = do_resize
        self.do_rescale = do_rescale
        self.do_normalize = do_normalize
        self.size = size
        self.resample = resample
        self.rescale_factor = rescale_factor
        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD

    def resize(
        self,
        image: np.ndarray,
        size: dict[str, int],
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Resize an image to `(size["height"], size["width"])`.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`dict[str, int]`):
                Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
                `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
            data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the output image. If unset, the channel dimension format of the input
                image is used. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

        Returns:
            `np.ndarray`: The resized images.
        """
        requires_backends(self, "torch")

        size = get_size_dict(size)
        if "height" not in size or "width" not in size:
            raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
        output_size = (size["height"], size["width"])

        # we use torch interpolation instead of image.resize because DepthProImageProcessor
        # rescales, then normalizes, which may cause some values to become negative, before resizing the image.
        # image.resize expects all values to be in range [0, 1] or [0, 255] and throws an exception otherwise,
        # however pytorch interpolation works with negative values.
        # relevant issue here: https://github.com/huggingface/transformers/issues/34920
        # input should be (B, C, H, W)
        image_tensor = torch.from_numpy(image).unsqueeze(0)
        resized_image = torch.nn.functional.interpolate(
            input=image_tensor,
            size=output_size,
            mode=pil_torch_interpolation_mapping[resample].value,
        )
        resized_image = resized_image.squeeze(0).numpy()
        return resized_image

    def _validate_input_arguments(
        self,
        do_resize: bool,
        size: dict[str, int],
        resample: PILImageResampling,
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Union[float, list[float]],
        image_std: Union[float, list[float]],
        data_format: Union[str, ChannelDimension],
    ):
        if do_resize and None in (size, resample):
            raise ValueError("Size and resample must be specified if do_resize is True.")

        if do_rescale and rescale_factor is None:
            raise ValueError("Rescale factor must be specified if do_rescale is True.")

        if do_normalize and None in (image_mean, image_std):
            raise ValueError("Image mean and standard deviation must be specified if do_normalize is True.")

    @filter_out_non_signature_kwargs()
    def preprocess(
        self,
        images: ImageInput,
        do_resize: Optional[bool] = None,
        size: Optional[dict[str, int]] = None,
        resample: Optional[PILImageResampling] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, list[float]]] = None,
        image_std: Optional[Union[float, list[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """
        Preprocess an image or batch of images.

        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`dict[str, int]`, *optional*, defaults to `self.size`):
                Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after
                resizing.
            resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`):
                `PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BILINEAR`. Only has
                an effect if `do_resize` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image values between [0 - 1].
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to use if `do_normalize` is set to `True`.
            image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use if `do_normalize` is set to `True`.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        resample = resample if resample is not None else self.resample
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std

        size = size if size is not None else self.size

        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )
        self._validate_input_arguments(
            do_resize=do_resize,
            size=size,
            resample=resample,
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            data_format=data_format,
        )

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if is_scaled_image(images[0]) and do_rescale:
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        all_images = []
        for image in images:
            if do_rescale:
                image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)

            if do_normalize:
                image = self.normalize(
                    image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
                )

            # depth-pro rescales and normalizes the image before resizing it
            # uses torch interpolation which requires ChannelDimension.FIRST
            if do_resize:
                image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_channel_dim=input_data_format)
                image = self.resize(image=image, size=size, resample=resample)
                image = to_channel_dimension_format(image, data_format, input_channel_dim=ChannelDimension.FIRST)
            else:
                image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)

            all_images.append(image)

        data = {"pixel_values": all_images}
        return BatchFeature(data=data, tensor_type=return_tensors)

    def post_process_depth_estimation(
        self,
        outputs: "DepthProDepthEstimatorOutput",
        target_sizes: Optional[Union[TensorType, list[tuple[int, int]], None]] = None,
    ) -> list[dict[str, TensorType]]:
        """
        Post-processes the raw depth predictions from the model to generate
        final depth predictions, which are calibrated using the field of view if provided
        and resized to specified target sizes if provided.

        Args:
            outputs ([`DepthProDepthEstimatorOutput`]):
                Raw outputs of the model.
            target_sizes (`Optional[Union[TensorType, list[tuple[int, int]], None]]`, *optional*, defaults to `None`):
                Target sizes to resize the depth predictions. Can be a tensor of shape `(batch_size, 2)`
                or a list of tuples `(height, width)` for each image in the batch. If `None`, no resizing
                is performed.

        Returns:
            `list[dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth
            predictions, and field of view (degrees) and focal length (pixels) if `field_of_view` is given in `outputs`.

        Raises:
            `ValueError`:
                If the lengths of `predicted_depths`, `fovs`, or `target_sizes` are mismatched.
        """
        requires_backends(self, "torch")

        predicted_depth = outputs.predicted_depth
        fov = outputs.field_of_view

        batch_size = len(predicted_depth)

        if target_sizes is not None and batch_size != len(target_sizes):
            raise ValueError(
                "Make sure that you pass in as many fov values as the batch dimension of the predicted depth"
            )

        results = []
        fov = [None] * batch_size if fov is None else fov
        target_sizes = [None] * batch_size if target_sizes is None else target_sizes
        for depth, fov_value, target_size in zip(predicted_depth, fov, target_sizes):
            focal_length = None
            if target_size is not None:
                # scale image w.r.t fov
                if fov_value is not None:
                    width = target_size[1]
                    focal_length = 0.5 * width / torch.tan(0.5 * torch.deg2rad(fov_value))
                    depth = depth * width / focal_length

                # interpolate
                depth = torch.nn.functional.interpolate(
                    # input should be (B, C, H, W)
                    input=depth.unsqueeze(0).unsqueeze(1),
                    size=target_size,
                    mode=pil_torch_interpolation_mapping[self.resample].value,
                ).squeeze()

            # inverse the depth
            depth = 1.0 / torch.clamp(depth, min=1e-4, max=1e4)

            results.append(
                {
                    "predicted_depth": depth,
                    "field_of_view": fov_value,
                    "focal_length": focal_length,
                }
            )

        return results


__all__ = ["DepthProImageProcessor"]
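A minimal end-to-end usage sketch of the processor defined above. It assumes a transformers build that ships DepthPro; the checkpoint id `apple/DepthPro-hf` and the local file name are assumptions, substitute whatever you actually use.

import torch
from PIL import Image
from transformers import DepthProForDepthEstimation, DepthProImageProcessor

image = Image.open("example.jpg").convert("RGB")  # hypothetical local file
processor = DepthProImageProcessor()
model = DepthProForDepthEstimation.from_pretrained("apple/DepthPro-hf")  # assumed checkpoint id

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# post_process_depth_estimation resizes the prediction back to the original
# (height, width) and, when the FOV head is enabled, also returns fov/focal length.
result = processor.post_process_depth_estimation(outputs, target_sizes=[image.size[::-1]])[0]
depth = result["predicted_depth"]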
phivenv/Lib/site-packages/transformers/models/depth_pro/image_processing_depth_pro_fast.py
ADDED
@@ -0,0 +1,177 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for DepthPro."""

from typing import TYPE_CHECKING, Optional, Union

from ...image_processing_base import BatchFeature
from ...image_processing_utils_fast import BaseImageProcessorFast, group_images_by_shape, reorder_images
from ...image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, PILImageResampling, SizeDict
from ...utils import (
    TensorType,
    auto_docstring,
    is_torch_available,
    is_torchvision_available,
    is_torchvision_v2_available,
    logging,
    requires_backends,
)
from ...utils.import_utils import requires


if TYPE_CHECKING:
    from .modeling_depth_pro import DepthProDepthEstimatorOutput

logger = logging.get_logger(__name__)


if is_torch_available():
    import torch


if is_torchvision_available():
    from ...image_utils import pil_torch_interpolation_mapping

    if is_torchvision_v2_available():
        from torchvision.transforms.v2 import functional as F
    else:
        from torchvision.transforms import functional as F


@auto_docstring
@requires(backends=("torchvision", "torch"))
class DepthProImageProcessorFast(BaseImageProcessorFast):
    resample = PILImageResampling.BILINEAR
    image_mean = IMAGENET_STANDARD_MEAN
    image_std = IMAGENET_STANDARD_STD
    size = {"height": 1536, "width": 1536}
    do_resize = True
    do_rescale = True
    do_normalize = True

    # DepthPro resizes image after rescaling and normalizing,
    # which makes it different from BaseImageProcessorFast._preprocess
    def _preprocess(
        self,
        images: list["torch.Tensor"],
        do_resize: bool,
        size: SizeDict,
        interpolation: Optional["F.InterpolationMode"],
        do_center_crop: bool,
        crop_size: SizeDict,
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
        disable_grouping: Optional[bool],
        return_tensors: Optional[Union[str, TensorType]],
    ) -> BatchFeature:
        # Group images by size for batched scaling
        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            # Fused rescale and normalize
            stacked_images = self.rescale_and_normalize(
                stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
            )
            if do_resize:
                stacked_images = self.resize(
                    image=stacked_images,
                    size=size,
                    interpolation=interpolation,
                    antialias=False,
                )
            processed_images_grouped[shape] = stacked_images

        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

        return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)

    # Copied from transformers.models.depth_pro.image_processing_depth_pro.DepthProImageProcessor.post_process_depth_estimation
    def post_process_depth_estimation(
        self,
        outputs: "DepthProDepthEstimatorOutput",
        target_sizes: Optional[Union[TensorType, list[tuple[int, int]], None]] = None,
    ) -> list[dict[str, TensorType]]:
        """
        Post-processes the raw depth predictions from the model to generate
        final depth predictions, which are calibrated using the field of view if provided
        and resized to specified target sizes if provided.

        Args:
            outputs ([`DepthProDepthEstimatorOutput`]):
                Raw outputs of the model.
            target_sizes (`Optional[Union[TensorType, list[tuple[int, int]], None]]`, *optional*, defaults to `None`):
                Target sizes to resize the depth predictions. Can be a tensor of shape `(batch_size, 2)`
                or a list of tuples `(height, width)` for each image in the batch. If `None`, no resizing
                is performed.

        Returns:
            `list[dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth
            predictions, and field of view (degrees) and focal length (pixels) if `field_of_view` is given in `outputs`.

        Raises:
            `ValueError`:
                If the lengths of `predicted_depths`, `fovs`, or `target_sizes` are mismatched.
        """
        requires_backends(self, "torch")

        predicted_depth = outputs.predicted_depth
        fov = outputs.field_of_view

        batch_size = len(predicted_depth)

        if target_sizes is not None and batch_size != len(target_sizes):
            raise ValueError(
                "Make sure that you pass in as many fov values as the batch dimension of the predicted depth"
            )

        results = []
        fov = [None] * batch_size if fov is None else fov
        target_sizes = [None] * batch_size if target_sizes is None else target_sizes
        for depth, fov_value, target_size in zip(predicted_depth, fov, target_sizes):
            focal_length = None
            if target_size is not None:
                # scale image w.r.t fov
                if fov_value is not None:
                    width = target_size[1]
                    focal_length = 0.5 * width / torch.tan(0.5 * torch.deg2rad(fov_value))
                    depth = depth * width / focal_length

                # interpolate
                depth = torch.nn.functional.interpolate(
                    # input should be (B, C, H, W)
                    input=depth.unsqueeze(0).unsqueeze(1),
                    size=target_size,
                    mode=pil_torch_interpolation_mapping[self.resample].value,
                ).squeeze()

            # inverse the depth
            depth = 1.0 / torch.clamp(depth, min=1e-4, max=1e4)

            results.append(
                {
                    "predicted_depth": depth,
                    "field_of_view": fov_value,
                    "focal_length": focal_length,
                }
            )

        return results


__all__ = ["DepthProImageProcessorFast"]
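The design point in `_preprocess` above is that, unlike the base fast processor, rescale and normalize run before the resize, mirroring the slow processor and the upstream issue it references. A minimal calling sketch, assuming a transformers build that ships DepthPro; the random uint8 tensors are stand-ins for real images.

import torch
from transformers import DepthProImageProcessorFast

processor = DepthProImageProcessorFast()
# Two fake channels-first uint8 images of different content but equal shape (stand-ins).
images = [torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8) for _ in range(2)]
batch = processor(images=images, return_tensors="pt")
print(batch["pixel_values"].shape)  # expected: torch.Size([2, 3, 1536, 1536])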
phivenv/Lib/site-packages/transformers/models/depth_pro/modeling_depth_pro.py
ADDED
@@ -0,0 +1,1132 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 The Apple Research Team Authors and The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch DepthPro model."""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import Optional, Union
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
from torch import nn
|
| 24 |
+
|
| 25 |
+
from ...modeling_utils import PreTrainedModel
|
| 26 |
+
from ...utils import ModelOutput, auto_docstring, logging, torch_int
|
| 27 |
+
from ..auto import AutoModel
|
| 28 |
+
from .configuration_depth_pro import DepthProConfig
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
logger = logging.get_logger(__name__)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dataclass
|
| 35 |
+
@auto_docstring(
|
| 36 |
+
custom_intro="""
|
| 37 |
+
Base class for DepthPro's outputs.
|
| 38 |
+
"""
|
| 39 |
+
)
|
| 40 |
+
class DepthProOutput(ModelOutput):
|
| 41 |
+
r"""
|
| 42 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, n_patches_per_batch, sequence_length, hidden_size)`):
|
| 43 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
| 44 |
+
features (`Union[torch.FloatTensor, List[torch.FloatTensor]]`, *optional*):
|
| 45 |
+
Features from encoders. Can be a single feature or a list of features.
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
last_hidden_state: Optional[torch.FloatTensor] = None
|
| 49 |
+
features: Union[torch.FloatTensor, list[torch.FloatTensor]] = None
|
| 50 |
+
hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
|
| 51 |
+
attentions: Optional[tuple[torch.FloatTensor, ...]] = None
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@dataclass
|
| 55 |
+
@auto_docstring(
|
| 56 |
+
custom_intro="""
|
| 57 |
+
Base class for DepthProForDepthEstimation's output.
|
| 58 |
+
"""
|
| 59 |
+
)
|
| 60 |
+
class DepthProDepthEstimatorOutput(ModelOutput):
|
| 61 |
+
r"""
|
| 62 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
| 63 |
+
Classification (or regression if config.num_labels==1) loss.
|
| 64 |
+
field_of_view (`torch.FloatTensor` of shape `(batch_size,)`, *optional*, returned when `use_fov_model` is provided):
|
| 65 |
+
Field of View Scaler.
|
| 66 |
+
"""
|
| 67 |
+
|
| 68 |
+
loss: Optional[torch.FloatTensor] = None
|
| 69 |
+
predicted_depth: Optional[torch.FloatTensor] = None
|
| 70 |
+
field_of_view: Optional[torch.FloatTensor] = None
|
| 71 |
+
hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
|
| 72 |
+
attentions: Optional[tuple[torch.FloatTensor, ...]] = None
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def split_to_patches(pixel_values: torch.Tensor, patch_size: int, overlap_ratio: float) -> torch.Tensor:
|
| 76 |
+
"""Creates Patches from Batch."""
|
| 77 |
+
batch_size, num_channels, height, width = pixel_values.shape
|
| 78 |
+
|
| 79 |
+
if height == width == patch_size:
|
| 80 |
+
# create patches only if scaled image is not already equal to patch size
|
| 81 |
+
return pixel_values
|
| 82 |
+
|
| 83 |
+
stride = torch_int(patch_size * (1 - overlap_ratio))
|
| 84 |
+
|
| 85 |
+
patches = F.unfold(pixel_values, kernel_size=(patch_size, patch_size), stride=(stride, stride))
|
| 86 |
+
patches = patches.permute(2, 0, 1)
|
| 87 |
+
patches = patches.reshape(-1, num_channels, patch_size, patch_size)
|
| 88 |
+
|
| 89 |
+
return patches
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def reshape_features(hidden_states: torch.Tensor) -> torch.Tensor:
|
| 93 |
+
"""Discard class token and reshape 1D feature map to a 2D grid."""
|
| 94 |
+
n_samples, seq_len, hidden_size = hidden_states.shape
|
| 95 |
+
size = torch_int(seq_len**0.5)
|
| 96 |
+
|
| 97 |
+
hidden_states = hidden_states[:, -(size**2) :, :] # remove special tokens if there are any
|
| 98 |
+
hidden_states = hidden_states.reshape(n_samples, size, size, hidden_size)
|
| 99 |
+
hidden_states = hidden_states.permute(0, 3, 1, 2)
|
| 100 |
+
|
| 101 |
+
return hidden_states
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def merge_patches(patches: torch.Tensor, batch_size: int, padding: int) -> torch.Tensor:
|
| 105 |
+
"""Merges smaller patches into image-like feature map."""
|
| 106 |
+
n_patches, hidden_size, out_size, out_size = patches.shape
|
| 107 |
+
n_patches_per_batch = n_patches // batch_size
|
| 108 |
+
sqrt_n_patches_per_batch = torch_int(n_patches_per_batch**0.5)
|
| 109 |
+
new_out_size = sqrt_n_patches_per_batch * out_size
|
| 110 |
+
|
| 111 |
+
if n_patches == batch_size:
|
| 112 |
+
# merge only if the patches were created from scaled image
|
| 113 |
+
# patches are not created when scaled image size is equal to patch size
|
| 114 |
+
return patches
|
| 115 |
+
|
| 116 |
+
if n_patches_per_batch < 4:
|
| 117 |
+
# for each batch, at least 4 small patches are required to
|
| 118 |
+
# recreate a large square patch from merging them and later padding is applied
|
| 119 |
+
# 3 x (8x8) patches becomes 1 x ( 8x8 ) patch (extra patch ignored, no padding)
|
| 120 |
+
# 4 x (8x8) patches becomes 1 x (16x16) patch (padding later)
|
| 121 |
+
# 5 x (8x8) patches becomes 1 x (16x16) patch (extra patch ignored, padding later)
|
| 122 |
+
# 9 x (8x8) patches becomes 1 x (24x24) patch (padding later)
|
| 123 |
+
# thus the following code only rearranges the patches and removes extra ones
|
| 124 |
+
padding = 0
|
| 125 |
+
|
| 126 |
+
# make sure padding is not large enough to remove more than half of the patch
|
| 127 |
+
padding = min(out_size // 4, padding)
|
| 128 |
+
|
| 129 |
+
if padding == 0:
|
| 130 |
+
# faster when no padding is required
|
| 131 |
+
merged = patches.reshape(n_patches_per_batch, batch_size, hidden_size, out_size, out_size)
|
| 132 |
+
merged = merged.permute(1, 2, 0, 3, 4)
|
| 133 |
+
merged = merged[:, :, : sqrt_n_patches_per_batch**2, :, :]
|
| 134 |
+
merged = merged.reshape(
|
| 135 |
+
batch_size, hidden_size, sqrt_n_patches_per_batch, sqrt_n_patches_per_batch, out_size, out_size
|
| 136 |
+
)
|
| 137 |
+
merged = merged.permute(0, 1, 2, 4, 3, 5)
|
| 138 |
+
merged = merged.reshape(batch_size, hidden_size, new_out_size, new_out_size)
|
| 139 |
+
else:
|
| 140 |
+
# padding example:
|
| 141 |
+
# let out_size = 8, new_out_size = 32, padding = 2
|
| 142 |
+
# each patch is separated by "|"
|
| 143 |
+
# and padding is applied to the merging edges of each patch
|
| 144 |
+
# 00 01 02 03 04 05 06 07 | 08 09 10 11 12 13 14 15 | 16 17 18 19 20 21 22 23 | 24 25 26 27 28 29 30 31
|
| 145 |
+
# 00 01 02 03 04 05 -- -- | -- -- 10 11 12 13 -- -- | -- -- 18 19 20 21 -- -- | -- -- 26 27 28 29 30 31
|
| 146 |
+
i = 0
|
| 147 |
+
boxes = []
|
| 148 |
+
for h in range(sqrt_n_patches_per_batch):
|
| 149 |
+
boxes_in_row = []
|
| 150 |
+
for w in range(sqrt_n_patches_per_batch):
|
| 151 |
+
box = patches[batch_size * i : batch_size * (i + 1)]
|
| 152 |
+
|
| 153 |
+
# collect paddings
|
| 154 |
+
paddings = [0, 0, 0, 0]
|
| 155 |
+
if h != 0:
|
| 156 |
+
# remove pad from height if box is not at top border
|
| 157 |
+
paddings[0] = padding
|
| 158 |
+
if w != 0:
|
| 159 |
+
# remove pad from width if box is not at left border
|
| 160 |
+
paddings[2] = padding
|
| 161 |
+
if h != sqrt_n_patches_per_batch - 1:
|
| 162 |
+
# remove pad from height if box is not at bottom border
|
| 163 |
+
paddings[1] = padding
|
| 164 |
+
if w != sqrt_n_patches_per_batch - 1:
|
| 165 |
+
# remove pad from width if box is not at right border
|
| 166 |
+
paddings[3] = padding
|
| 167 |
+
|
| 168 |
+
# remove paddings
|
| 169 |
+
_, _, box_h, box_w = box.shape
|
| 170 |
+
pad_top, pad_bottom, pad_left, pad_right = paddings
|
| 171 |
+
box = box[:, :, pad_top : box_h - pad_bottom, pad_left : box_w - pad_right]
|
| 172 |
+
|
| 173 |
+
boxes_in_row.append(box)
|
| 174 |
+
i += 1
|
| 175 |
+
boxes_in_row = torch.cat(boxes_in_row, dim=-1)
|
| 176 |
+
boxes.append(boxes_in_row)
|
| 177 |
+
merged = torch.cat(boxes, dim=-2)
|
| 178 |
+
|
| 179 |
+
return merged
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def reconstruct_feature_maps(
|
| 183 |
+
hidden_state: torch.Tensor, batch_size: int, padding: int, output_size: tuple[float, float]
|
| 184 |
+
) -> torch.Tensor:
|
| 185 |
+
"""
|
| 186 |
+
Reconstructs feature maps from the hidden state produced by any of the encoder. Converts the hidden state of shape
|
| 187 |
+
`(n_patches_per_batch * batch_size, seq_len, hidden_size)` to feature maps of shape
|
| 188 |
+
`(batch_size, hidden_size, output_size[0], output_size[1])`.
|
| 189 |
+
|
| 190 |
+
Args:
|
| 191 |
+
hidden_state (torch.Tensor): Input tensor of shape `(n_patches_per_batch * batch_size, seq_len, hidden_size)`
|
| 192 |
+
representing the encoded patches.
|
| 193 |
+
batch_size (int): The number of samples in a batch.
|
| 194 |
+
padding (int): The amount of padding to be removed when merging patches.
|
| 195 |
+
output_size (tuple[float, float]): The desired output size for the feature maps, specified as `(height, width)`.
|
| 196 |
+
|
| 197 |
+
Returns:
|
| 198 |
+
torch.Tensor: Reconstructed feature maps of shape `(batch_size, hidden_size, output_size[0], output_size[1])`.
|
| 199 |
+
"""
|
| 200 |
+
# reshape back to image like
|
| 201 |
+
features = reshape_features(hidden_state)
|
| 202 |
+
|
| 203 |
+
# merge all patches in a batch to create one large patch per batch
|
| 204 |
+
features = merge_patches(
|
| 205 |
+
features,
|
| 206 |
+
batch_size=batch_size,
|
| 207 |
+
padding=padding,
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
# interpolate patches to base size
|
| 211 |
+
features = F.interpolate(
|
| 212 |
+
features,
|
| 213 |
+
size=output_size,
|
| 214 |
+
mode="bilinear",
|
| 215 |
+
align_corners=False,
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
return features
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
class DepthProPatchEncoder(nn.Module):
|
| 222 |
+
def __init__(self, config: DepthProConfig):
|
| 223 |
+
super().__init__()
|
| 224 |
+
self.config = config
|
| 225 |
+
|
| 226 |
+
self.intermediate_hook_ids = config.intermediate_hook_ids
|
| 227 |
+
self.intermediate_feature_dims = config.intermediate_feature_dims
|
| 228 |
+
self.scaled_images_ratios = config.scaled_images_ratios
|
| 229 |
+
self.scaled_images_overlap_ratios = config.scaled_images_overlap_ratios
|
| 230 |
+
self.scaled_images_feature_dims = config.scaled_images_feature_dims
|
| 231 |
+
self.merge_padding_value = config.merge_padding_value
|
| 232 |
+
|
| 233 |
+
self.n_scaled_images = len(config.scaled_images_ratios)
|
| 234 |
+
self.n_intermediate_hooks = len(config.intermediate_hook_ids)
|
| 235 |
+
self.out_size = config.image_model_config.image_size // config.image_model_config.patch_size
|
| 236 |
+
|
| 237 |
+
self.model = AutoModel.from_config(config.patch_model_config)
|
| 238 |
+
|
| 239 |
+
def forward(
|
| 240 |
+
self,
|
| 241 |
+
pixel_values: torch.Tensor,
|
| 242 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 243 |
+
) -> list[torch.Tensor]:
|
| 244 |
+
batch_size, num_channels, height, width = pixel_values.shape
|
| 245 |
+
|
| 246 |
+
if min(self.scaled_images_ratios) * min(height, width) < self.config.patch_size:
|
| 247 |
+
raise ValueError(
|
| 248 |
+
f"Image size {height}x{width} is too small to be scaled "
|
| 249 |
+
f"with scaled_images_ratios={self.scaled_images_ratios} "
|
| 250 |
+
f"when patch_size={self.config.patch_size}."
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
# STEP 1: create 3-level image
|
| 254 |
+
|
| 255 |
+
scaled_images = []
|
| 256 |
+
for ratio in self.scaled_images_ratios:
|
| 257 |
+
scaled_images.append(
|
| 258 |
+
F.interpolate(
|
| 259 |
+
pixel_values,
|
| 260 |
+
scale_factor=ratio,
|
| 261 |
+
mode="bilinear",
|
| 262 |
+
align_corners=False,
|
| 263 |
+
)
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
# STEP 2: create patches
|
| 267 |
+
|
| 268 |
+
for i in range(self.n_scaled_images):
|
| 269 |
+
scaled_images[i] = split_to_patches(
|
| 270 |
+
scaled_images[i],
|
| 271 |
+
patch_size=self.config.patch_size,
|
| 272 |
+
overlap_ratio=self.scaled_images_overlap_ratios[i],
|
| 273 |
+
)
|
| 274 |
+
n_patches_per_scaled_image = [len(i) for i in scaled_images]
|
| 275 |
+
patches = torch.cat(scaled_images[::-1], dim=0) # -1 as patch encoder expects high res patches first
|
| 276 |
+
|
| 277 |
+
# STEP 3: apply patch encoder
|
| 278 |
+
|
| 279 |
+
encodings = self.model(
|
| 280 |
+
# each patch is processed as a separate batch
|
| 281 |
+
patches,
|
| 282 |
+
head_mask=head_mask,
|
| 283 |
+
# required for intermediate features
|
| 284 |
+
output_hidden_states=self.n_intermediate_hooks > 0,
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
scaled_images_last_hidden_state = torch.split_with_sizes(encodings[0], n_patches_per_scaled_image[::-1])
|
| 288 |
+
# reverse ([::-1]) because the patch encoder returns high-res patches first, while we need low-res first
|
| 289 |
+
scaled_images_last_hidden_state = scaled_images_last_hidden_state[::-1]
|
| 290 |
+
|
| 291 |
+
# calculate base height and width
|
| 292 |
+
# base height and width are the dimensions of the lowest resolution features
|
| 293 |
+
exponent_value = torch_int(math.log2(width / self.out_size))
|
| 294 |
+
base_height = height // 2**exponent_value
|
| 295 |
+
base_width = width // 2**exponent_value
|
| 296 |
+
|
| 297 |
+
# STEP 4: get patch features (high_res, med_res, low_res) - (3-5) in diagram
|
| 298 |
+
|
| 299 |
+
scaled_images_features = []
|
| 300 |
+
for i in range(self.n_scaled_images):
|
| 301 |
+
hidden_state = scaled_images_last_hidden_state[i]
|
| 302 |
+
batch_size = batch_size
|
| 303 |
+
padding = torch_int(self.merge_padding_value * (1 / self.scaled_images_ratios[i]))
|
| 304 |
+
output_height = base_height * 2**i
|
| 305 |
+
output_width = base_width * 2**i
|
| 306 |
+
features = reconstruct_feature_maps(
|
| 307 |
+
hidden_state,
|
| 308 |
+
batch_size=batch_size,
|
| 309 |
+
padding=padding,
|
| 310 |
+
output_size=(output_height, output_width),
|
| 311 |
+
)
|
| 312 |
+
scaled_images_features.append(features)
|
| 313 |
+
|
| 314 |
+
# STEP 5: get intermediate features - (1-2) in diagram
|
| 315 |
+
|
| 316 |
+
intermediate_features = []
|
| 317 |
+
for i in range(self.n_intermediate_hooks):
|
| 318 |
+
# +1 to correct index position as hidden_states contain embedding output as well
|
| 319 |
+
hidden_state = encodings[2][self.intermediate_hook_ids[i] + 1]
|
| 320 |
+
padding = torch_int(self.merge_padding_value * (1 / self.scaled_images_ratios[-1]))
|
| 321 |
+
output_height = base_height * 2 ** (self.n_scaled_images - 1)
|
| 322 |
+
output_width = base_width * 2 ** (self.n_scaled_images - 1)
|
| 323 |
+
features = reconstruct_feature_maps(
|
| 324 |
+
hidden_state,
|
| 325 |
+
batch_size=batch_size,
|
| 326 |
+
padding=padding,
|
| 327 |
+
output_size=(output_height, output_width),
|
| 328 |
+
)
|
| 329 |
+
intermediate_features.append(features)
|
| 330 |
+
|
| 331 |
+
# STEP 7: combine all features
|
| 332 |
+
features = [*scaled_images_features, *intermediate_features]
|
| 333 |
+
|
| 334 |
+
return features
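# Sketch of the returned list (the configuration sizes here are assumptions):
# with three entries in `scaled_images_ratios` and two in
# `intermediate_hook_ids`, the patch encoder returns
#     [low_res, mid_res, high_res, intermediate_0, intermediate_1]
# i.e. n_scaled_images + n_intermediate_hooks = 5 maps. The i-th scaled-image
# map has spatial size (base_height * 2**i, base_width * 2**i), and every
# intermediate map is reconstructed at the highest of those resolutions.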
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
class DepthProImageEncoder(nn.Module):
|
| 338 |
+
def __init__(self, config: DepthProConfig):
|
| 339 |
+
super().__init__()
|
| 340 |
+
self.config = config
|
| 341 |
+
self.out_size = config.image_model_config.image_size // config.image_model_config.patch_size
|
| 342 |
+
|
| 343 |
+
self.model = AutoModel.from_config(config.image_model_config)
|
| 344 |
+
|
| 345 |
+
def forward(
|
| 346 |
+
self,
|
| 347 |
+
pixel_values: torch.Tensor,
|
| 348 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 349 |
+
output_attentions: bool = False,
|
| 350 |
+
output_hidden_states: bool = False,
|
| 351 |
+
return_dict: bool = True,
|
| 352 |
+
) -> Union[tuple, DepthProOutput]:
|
| 353 |
+
batch_size, num_channels, height, width = pixel_values.shape
|
| 354 |
+
|
| 355 |
+
# scale the image for image_encoder
|
| 356 |
+
size = self.config.image_model_config.image_size
|
| 357 |
+
pixel_values = F.interpolate(
|
| 358 |
+
pixel_values,
|
| 359 |
+
size=(size, size),
|
| 360 |
+
mode="bilinear",
|
| 361 |
+
align_corners=False,
|
| 362 |
+
)
|
| 363 |
+
encodings = self.model(
|
| 364 |
+
pixel_values=pixel_values,
|
| 365 |
+
head_mask=head_mask,
|
| 366 |
+
output_attentions=output_attentions,
|
| 367 |
+
output_hidden_states=output_hidden_states,
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
# calculate base height and width
|
| 371 |
+
# base height and width are the dimensions of the lowest resolution features
|
| 372 |
+
exponent_value = torch_int(math.log2(width / self.out_size))
|
| 373 |
+
base_height = height // 2**exponent_value
|
| 374 |
+
base_width = width // 2**exponent_value
|
| 375 |
+
|
| 376 |
+
features = reconstruct_feature_maps(
|
| 377 |
+
encodings[0],
|
| 378 |
+
batch_size=batch_size,
|
| 379 |
+
padding=0,
|
| 380 |
+
output_size=(base_height, base_width),
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
if not return_dict:
|
| 384 |
+
return (encodings[0], features) + encodings[2:]  # skip encodings[1] (the pooler output)
|
| 385 |
+
|
| 386 |
+
return DepthProOutput(
|
| 387 |
+
last_hidden_state=encodings.last_hidden_state,
|
| 388 |
+
features=features,
|
| 389 |
+
hidden_states=encodings.hidden_states,
|
| 390 |
+
attentions=encodings.attentions,
|
| 391 |
+
)
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
class DepthProEncoder(nn.Module):
|
| 395 |
+
def __init__(self, config: DepthProConfig):
|
| 396 |
+
super().__init__()
|
| 397 |
+
self.config = config
|
| 398 |
+
self.intermediate_hook_ids = config.intermediate_hook_ids
|
| 399 |
+
self.intermediate_feature_dims = config.intermediate_feature_dims
|
| 400 |
+
self.scaled_images_ratios = config.scaled_images_ratios
|
| 401 |
+
self.scaled_images_overlap_ratios = config.scaled_images_overlap_ratios
|
| 402 |
+
self.scaled_images_feature_dims = config.scaled_images_feature_dims
|
| 403 |
+
self.merge_padding_value = config.merge_padding_value
|
| 404 |
+
|
| 405 |
+
self.n_scaled_images = len(self.scaled_images_ratios)
|
| 406 |
+
self.n_intermediate_hooks = len(self.intermediate_hook_ids)
|
| 407 |
+
|
| 408 |
+
self.patch_encoder = DepthProPatchEncoder(config)
|
| 409 |
+
self.image_encoder = DepthProImageEncoder(config)
|
| 410 |
+
|
| 411 |
+
def forward(
|
| 412 |
+
self,
|
| 413 |
+
pixel_values: torch.Tensor,
|
| 414 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 415 |
+
output_attentions: bool = False,
|
| 416 |
+
output_hidden_states: bool = False,
|
| 417 |
+
return_dict: bool = True,
|
| 418 |
+
) -> Union[tuple, DepthProOutput]:
|
| 419 |
+
batch_size, num_channels, height, width = pixel_values.shape
|
| 420 |
+
|
| 421 |
+
patch_features = self.patch_encoder(
|
| 422 |
+
pixel_values,
|
| 423 |
+
head_mask=head_mask,
|
| 424 |
+
)
|
| 425 |
+
image_encodings = self.image_encoder(
|
| 426 |
+
pixel_values,
|
| 427 |
+
head_mask=head_mask,
|
| 428 |
+
output_attentions=output_attentions,
|
| 429 |
+
output_hidden_states=output_hidden_states,
|
| 430 |
+
return_dict=return_dict,
|
| 431 |
+
)
|
| 432 |
+
image_features = image_encodings[1] # index 1 contains features
|
| 433 |
+
|
| 434 |
+
features = [image_features, *patch_features]
|
| 435 |
+
|
| 436 |
+
if not return_dict:
|
| 437 |
+
return (image_encodings[0], features) + image_encodings[2:]
|
| 438 |
+
|
| 439 |
+
return DepthProOutput(
|
| 440 |
+
last_hidden_state=image_encodings.last_hidden_state,
|
| 441 |
+
features=features,
|
| 442 |
+
hidden_states=image_encodings.hidden_states,
|
| 443 |
+
attentions=image_encodings.attentions,
|
| 444 |
+
)
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
class DepthProFeatureUpsampleBlock(nn.Module):
|
| 448 |
+
def __init__(
|
| 449 |
+
self,
|
| 450 |
+
config: DepthProConfig,
|
| 451 |
+
input_dims: int,
|
| 452 |
+
intermediate_dims: int,
|
| 453 |
+
output_dims: int,
|
| 454 |
+
n_upsample_layers: int,
|
| 455 |
+
use_proj: bool = True,
|
| 456 |
+
bias: bool = False,
|
| 457 |
+
):
|
| 458 |
+
super().__init__()
|
| 459 |
+
self.config = config
|
| 460 |
+
self.layers = nn.ModuleList()
|
| 461 |
+
|
| 462 |
+
# create first projection layer
|
| 463 |
+
if use_proj:
|
| 464 |
+
proj = nn.Conv2d(
|
| 465 |
+
in_channels=input_dims,
|
| 466 |
+
out_channels=intermediate_dims,
|
| 467 |
+
kernel_size=1,
|
| 468 |
+
stride=1,
|
| 469 |
+
padding=0,
|
| 470 |
+
bias=bias,
|
| 471 |
+
)
|
| 472 |
+
self.layers.append(proj)
|
| 473 |
+
|
| 474 |
+
# create following upsample layers
|
| 475 |
+
for i in range(n_upsample_layers):
|
| 476 |
+
in_channels = intermediate_dims if i == 0 else output_dims
|
| 477 |
+
layer = nn.ConvTranspose2d(
|
| 478 |
+
in_channels=in_channels,
|
| 479 |
+
out_channels=output_dims,
|
| 480 |
+
kernel_size=2,
|
| 481 |
+
stride=2,
|
| 482 |
+
padding=0,
|
| 483 |
+
bias=bias,
|
| 484 |
+
)
|
| 485 |
+
self.layers.append(layer)
|
| 486 |
+
|
| 487 |
+
def forward(self, features: torch.Tensor) -> torch.Tensor:
|
| 488 |
+
for layer in self.layers:
|
| 489 |
+
features = layer(features)
|
| 490 |
+
return features
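# Upsampling sketch (hypothetical dimensions): each ConvTranspose2d above uses
# kernel_size=2 and stride=2, so every upsample layer doubles the spatial size.
# With use_proj=True, input_dims=1024, intermediate_dims=256, output_dims=256
# and n_upsample_layers=2, a (1, 1024, 24, 24) input is first projected to 256
# channels and then upsampled twice:
#
#     block = DepthProFeatureUpsampleBlock(
#         config, input_dims=1024, intermediate_dims=256, output_dims=256,
#         n_upsample_layers=2,
#     )
#     # block(torch.randn(1, 1024, 24, 24)).shape -> torch.Size([1, 256, 96, 96])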
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
class DepthProFeatureUpsample(nn.Module):
|
| 494 |
+
def __init__(self, config: DepthProConfig):
|
| 495 |
+
super().__init__()
|
| 496 |
+
self.config = config
|
| 497 |
+
self.n_scaled_images = len(self.config.scaled_images_ratios)
|
| 498 |
+
self.n_intermediate_hooks = len(self.config.intermediate_hook_ids)
|
| 499 |
+
|
| 500 |
+
# for image_features
|
| 501 |
+
self.image_block = DepthProFeatureUpsampleBlock(
|
| 502 |
+
config=config,
|
| 503 |
+
input_dims=config.image_model_config.hidden_size,
|
| 504 |
+
intermediate_dims=config.image_model_config.hidden_size,
|
| 505 |
+
output_dims=config.scaled_images_feature_dims[0],
|
| 506 |
+
n_upsample_layers=1,
|
| 507 |
+
use_proj=False,
|
| 508 |
+
bias=True,
|
| 509 |
+
)
|
| 510 |
+
|
| 511 |
+
# for scaled_images_features
|
| 512 |
+
self.scaled_images = nn.ModuleList()
|
| 513 |
+
for i, feature_dims in enumerate(config.scaled_images_feature_dims):
|
| 514 |
+
block = DepthProFeatureUpsampleBlock(
|
| 515 |
+
config=config,
|
| 516 |
+
input_dims=config.patch_model_config.hidden_size,
|
| 517 |
+
intermediate_dims=feature_dims,
|
| 518 |
+
output_dims=feature_dims,
|
| 519 |
+
n_upsample_layers=1,
|
| 520 |
+
)
|
| 521 |
+
self.scaled_images.append(block)
|
| 522 |
+
|
| 523 |
+
# for intermediate_features
|
| 524 |
+
self.intermediate = nn.ModuleList()
|
| 525 |
+
for i, feature_dims in enumerate(config.intermediate_feature_dims):
|
| 526 |
+
intermediate_dims = config.fusion_hidden_size if i == 0 else feature_dims
|
| 527 |
+
block = DepthProFeatureUpsampleBlock(
|
| 528 |
+
config=config,
|
| 529 |
+
input_dims=config.patch_model_config.hidden_size,
|
| 530 |
+
intermediate_dims=intermediate_dims,
|
| 531 |
+
output_dims=feature_dims,
|
| 532 |
+
n_upsample_layers=2 + i,
|
| 533 |
+
)
|
| 534 |
+
self.intermediate.append(block)
|
| 535 |
+
|
| 536 |
+
def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
|
| 537 |
+
features[0] = self.image_block(features[0])
|
| 538 |
+
|
| 539 |
+
for i in range(self.n_scaled_images):
|
| 540 |
+
features[i + 1] = self.scaled_images[i](features[i + 1])
|
| 541 |
+
|
| 542 |
+
for i in range(self.n_intermediate_hooks):
|
| 543 |
+
features[self.n_scaled_images + i + 1] = self.intermediate[i](features[self.n_scaled_images + i + 1])
|
| 544 |
+
|
| 545 |
+
return features
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
class DepthProFeatureProjection(nn.Module):
|
| 549 |
+
def __init__(self, config: DepthProConfig):
|
| 550 |
+
super().__init__()
|
| 551 |
+
self.config = config
|
| 552 |
+
|
| 553 |
+
combined_feature_dims = config.scaled_images_feature_dims + config.intermediate_feature_dims
|
| 554 |
+
self.projections = nn.ModuleList()
|
| 555 |
+
for i, in_channels in enumerate(combined_feature_dims):
|
| 556 |
+
if i == len(combined_feature_dims) - 1 and in_channels == config.fusion_hidden_size:
|
| 557 |
+
# projection for last layer can be ignored if input and output channels already match
|
| 558 |
+
self.projections.append(nn.Identity())
|
| 559 |
+
else:
|
| 560 |
+
self.projections.append(
|
| 561 |
+
nn.Conv2d(
|
| 562 |
+
in_channels=in_channels,
|
| 563 |
+
out_channels=config.fusion_hidden_size,
|
| 564 |
+
kernel_size=3,
|
| 565 |
+
stride=1,
|
| 566 |
+
padding=1,
|
| 567 |
+
bias=False,
|
| 568 |
+
)
|
| 569 |
+
)
|
| 570 |
+
|
| 571 |
+
def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
|
| 572 |
+
projected_features = []
|
| 573 |
+
for i, projection in enumerate(self.projections):
|
| 574 |
+
upsampled_feature = projection(features[i])
|
| 575 |
+
projected_features.append(upsampled_feature)
|
| 576 |
+
return projected_features
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
class DepthProNeck(nn.Module):
|
| 580 |
+
def __init__(self, config: DepthProConfig):
|
| 581 |
+
super().__init__()
|
| 582 |
+
self.config = config
|
| 583 |
+
|
| 584 |
+
self.feature_upsample = DepthProFeatureUpsample(config)
|
| 585 |
+
self.fuse_image_with_low_res = nn.Conv2d(
|
| 586 |
+
in_channels=config.scaled_images_feature_dims[0] * 2,
|
| 587 |
+
out_channels=config.scaled_images_feature_dims[0],
|
| 588 |
+
kernel_size=1,
|
| 589 |
+
stride=1,
|
| 590 |
+
padding=0,
|
| 591 |
+
bias=True,
|
| 592 |
+
)
|
| 593 |
+
self.feature_projection = DepthProFeatureProjection(config)
|
| 594 |
+
|
| 595 |
+
def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
|
| 596 |
+
features = self.feature_upsample(features)
|
| 597 |
+
# global features = low res features + image features
|
| 598 |
+
global_features = torch.cat((features[1], features[0]), dim=1)
|
| 599 |
+
global_features = self.fuse_image_with_low_res(global_features)
|
| 600 |
+
features = [global_features, *features[2:]]
|
| 601 |
+
features = self.feature_projection(features)
|
| 602 |
+
return features
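# Fusion sketch: after upsampling, the image-level map (features[0]) and the
# lowest-resolution patch map (features[1]) share the same shape, so they are
# concatenated along the channel axis and reduced back to
# scaled_images_feature_dims[0] channels by the 1x1 convolution above. The neck
# therefore returns one map fewer than it received, each projected to
# `config.fusion_hidden_size` channels for the fusion stage.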
|
| 603 |
+
|
| 604 |
+
|
| 605 |
+
# General docstring
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
@auto_docstring
|
| 609 |
+
class DepthProPreTrainedModel(PreTrainedModel):
|
| 610 |
+
config: DepthProConfig
|
| 611 |
+
base_model_prefix = "depth_pro"
|
| 612 |
+
main_input_name = "pixel_values"
|
| 613 |
+
supports_gradient_checkpointing = True
|
| 614 |
+
_supports_sdpa = True
|
| 615 |
+
_no_split_modules = ["DepthProPreActResidualLayer"]
|
| 616 |
+
_keys_to_ignore_on_load_unexpected = ["fov_model.*"]
|
| 617 |
+
|
| 618 |
+
def _init_weights(self, module):
|
| 619 |
+
"""Initialize the weights"""
|
| 620 |
+
if isinstance(module, nn.Linear):
|
| 621 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
| 622 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
| 623 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 624 |
+
if module.bias is not None:
|
| 625 |
+
module.bias.data.zero_()
|
| 626 |
+
elif isinstance(module, nn.LayerNorm):
|
| 627 |
+
module.bias.data.zero_()
|
| 628 |
+
module.weight.data.fill_(1.0)
|
| 629 |
+
elif isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
|
| 630 |
+
nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
|
| 631 |
+
if module.bias is not None:
|
| 632 |
+
module.bias.data.zero_()
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
@auto_docstring
|
| 636 |
+
class DepthProModel(DepthProPreTrainedModel):
|
| 637 |
+
def __init__(self, config):
|
| 638 |
+
super().__init__(config)
|
| 639 |
+
self.config = config
|
| 640 |
+
self.encoder = DepthProEncoder(config)
|
| 641 |
+
self.neck = DepthProNeck(config)
|
| 642 |
+
# Initialize weights and apply final processing
|
| 643 |
+
self.post_init()
|
| 644 |
+
|
| 645 |
+
def get_input_embeddings(self):
|
| 646 |
+
return self.encoder.image_encoder.model.get_input_embeddings()
|
| 647 |
+
|
| 648 |
+
@auto_docstring
|
| 649 |
+
def forward(
|
| 650 |
+
self,
|
| 651 |
+
pixel_values: torch.FloatTensor,
|
| 652 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 653 |
+
output_attentions: Optional[bool] = None,
|
| 654 |
+
output_hidden_states: Optional[bool] = None,
|
| 655 |
+
return_dict: Optional[bool] = None,
|
| 656 |
+
) -> Union[tuple, DepthProOutput]:
|
| 657 |
+
r"""
|
| 658 |
+
Examples:
|
| 659 |
+
|
| 660 |
+
```python
|
| 661 |
+
>>> import torch
|
| 662 |
+
>>> from PIL import Image
|
| 663 |
+
>>> import requests
|
| 664 |
+
>>> from transformers import AutoProcessor, DepthProModel
|
| 665 |
+
|
| 666 |
+
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
| 667 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 668 |
+
|
| 669 |
+
>>> checkpoint = "apple/DepthPro-hf"
|
| 670 |
+
>>> processor = AutoProcessor.from_pretrained(checkpoint)
|
| 671 |
+
>>> model = DepthProModel.from_pretrained(checkpoint)
|
| 672 |
+
|
| 673 |
+
>>> # prepare image for the model
|
| 674 |
+
>>> inputs = processor(images=image, return_tensors="pt")
|
| 675 |
+
|
| 676 |
+
>>> with torch.no_grad():
|
| 677 |
+
... output = model(**inputs)
|
| 678 |
+
|
| 679 |
+
>>> output.last_hidden_state.shape
|
| 680 |
+
torch.Size([1, 35, 577, 1024])
|
| 681 |
+
```"""
|
| 682 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 683 |
+
output_hidden_states = (
|
| 684 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 685 |
+
)
|
| 686 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 687 |
+
|
| 688 |
+
encodings = self.encoder(
|
| 689 |
+
pixel_values,
|
| 690 |
+
head_mask=head_mask,
|
| 691 |
+
output_attentions=output_attentions,
|
| 692 |
+
output_hidden_states=output_hidden_states,
|
| 693 |
+
return_dict=return_dict,
|
| 694 |
+
)
|
| 695 |
+
features = encodings[1] # index 1 contains features
|
| 696 |
+
features = self.neck(features)
|
| 697 |
+
|
| 698 |
+
if not return_dict:
|
| 699 |
+
return (encodings[0], features) + encodings[2:]
|
| 700 |
+
|
| 701 |
+
return DepthProOutput(
|
| 702 |
+
last_hidden_state=encodings.last_hidden_state,
|
| 703 |
+
features=features,
|
| 704 |
+
hidden_states=encodings.hidden_states,
|
| 705 |
+
attentions=encodings.attentions,
|
| 706 |
+
)
|
| 707 |
+
|
| 708 |
+
|
| 709 |
+
# Copied from transformers.models.dpt.modeling_dpt.DPTPreActResidualLayer DPT->DepthPro
|
| 710 |
+
class DepthProPreActResidualLayer(nn.Module):
|
| 711 |
+
"""
|
| 712 |
+
ResidualConvUnit, pre-activate residual unit.
|
| 713 |
+
|
| 714 |
+
Args:
|
| 715 |
+
config ([`DepthProConfig`]):
|
| 716 |
+
Model configuration class defining the model architecture.
|
| 717 |
+
"""
|
| 718 |
+
|
| 719 |
+
def __init__(self, config: DepthProConfig):
|
| 720 |
+
super().__init__()
|
| 721 |
+
|
| 722 |
+
self.use_batch_norm = config.use_batch_norm_in_fusion_residual
|
| 723 |
+
use_bias_in_fusion_residual = (
|
| 724 |
+
config.use_bias_in_fusion_residual
|
| 725 |
+
if config.use_bias_in_fusion_residual is not None
|
| 726 |
+
else not self.use_batch_norm
|
| 727 |
+
)
|
| 728 |
+
|
| 729 |
+
self.activation1 = nn.ReLU()
|
| 730 |
+
self.convolution1 = nn.Conv2d(
|
| 731 |
+
config.fusion_hidden_size,
|
| 732 |
+
config.fusion_hidden_size,
|
| 733 |
+
kernel_size=3,
|
| 734 |
+
stride=1,
|
| 735 |
+
padding=1,
|
| 736 |
+
bias=use_bias_in_fusion_residual,
|
| 737 |
+
)
|
| 738 |
+
|
| 739 |
+
self.activation2 = nn.ReLU()
|
| 740 |
+
self.convolution2 = nn.Conv2d(
|
| 741 |
+
config.fusion_hidden_size,
|
| 742 |
+
config.fusion_hidden_size,
|
| 743 |
+
kernel_size=3,
|
| 744 |
+
stride=1,
|
| 745 |
+
padding=1,
|
| 746 |
+
bias=use_bias_in_fusion_residual,
|
| 747 |
+
)
|
| 748 |
+
|
| 749 |
+
if self.use_batch_norm:
|
| 750 |
+
self.batch_norm1 = nn.BatchNorm2d(config.fusion_hidden_size)
|
| 751 |
+
self.batch_norm2 = nn.BatchNorm2d(config.fusion_hidden_size)
|
| 752 |
+
|
| 753 |
+
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
| 754 |
+
residual = hidden_state
|
| 755 |
+
hidden_state = self.activation1(hidden_state)
|
| 756 |
+
|
| 757 |
+
hidden_state = self.convolution1(hidden_state)
|
| 758 |
+
|
| 759 |
+
if self.use_batch_norm:
|
| 760 |
+
hidden_state = self.batch_norm1(hidden_state)
|
| 761 |
+
|
| 762 |
+
hidden_state = self.activation2(hidden_state)
|
| 763 |
+
hidden_state = self.convolution2(hidden_state)
|
| 764 |
+
|
| 765 |
+
if self.use_batch_norm:
|
| 766 |
+
hidden_state = self.batch_norm2(hidden_state)
|
| 767 |
+
|
| 768 |
+
return hidden_state + residual
|
| 769 |
+
|
| 770 |
+
|
| 771 |
+
# Modified from transformers.models.dpt.modeling_dpt.DPTFeatureFusionLayer
|
| 772 |
+
# except it uses deconv and skip_add and needs no interpolation
|
| 773 |
+
class DepthProFeatureFusionLayer(nn.Module):
|
| 774 |
+
def __init__(self, config: DepthProConfig, use_deconv: bool = True):
|
| 775 |
+
super().__init__()
|
| 776 |
+
self.config = config
|
| 777 |
+
self.use_deconv = use_deconv
|
| 778 |
+
|
| 779 |
+
self.residual_layer1 = DepthProPreActResidualLayer(config)
|
| 780 |
+
self.residual_layer2 = DepthProPreActResidualLayer(config)
|
| 781 |
+
|
| 782 |
+
if self.use_deconv:
|
| 783 |
+
self.deconv = nn.ConvTranspose2d(
|
| 784 |
+
in_channels=config.fusion_hidden_size,
|
| 785 |
+
out_channels=config.fusion_hidden_size,
|
| 786 |
+
kernel_size=2,
|
| 787 |
+
stride=2,
|
| 788 |
+
padding=0,
|
| 789 |
+
bias=False,
|
| 790 |
+
)
|
| 791 |
+
|
| 792 |
+
self.projection = nn.Conv2d(config.fusion_hidden_size, config.fusion_hidden_size, kernel_size=1, bias=True)
|
| 793 |
+
|
| 794 |
+
def forward(self, hidden_state: torch.Tensor, residual: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 795 |
+
if residual is not None:
|
| 796 |
+
residual = self.residual_layer1(residual)
|
| 797 |
+
hidden_state = hidden_state + residual
|
| 798 |
+
|
| 799 |
+
hidden_state = self.residual_layer2(hidden_state)
|
| 800 |
+
if self.use_deconv:
|
| 801 |
+
hidden_state = self.deconv(hidden_state)
|
| 802 |
+
hidden_state = self.projection(hidden_state)
|
| 803 |
+
|
| 804 |
+
return hidden_state
|
| 805 |
+
|
| 806 |
+
|
| 807 |
+
# Modified from transformers.models.dpt.modeling_dpt.DPTFeatureFusionStage with DPT->DepthPro
|
| 808 |
+
# with deconv and reversed layers
|
| 809 |
+
class DepthProFeatureFusionStage(nn.Module):
|
| 810 |
+
def __init__(self, config):
|
| 811 |
+
super().__init__()
|
| 812 |
+
self.config = config
|
| 813 |
+
|
| 814 |
+
self.num_layers = len(config.intermediate_hook_ids) + len(config.scaled_images_ratios)
|
| 815 |
+
self.intermediate = nn.ModuleList()
|
| 816 |
+
for _ in range(self.num_layers - 1):
|
| 817 |
+
self.intermediate.append(DepthProFeatureFusionLayer(config))
|
| 818 |
+
|
| 819 |
+
# final layer does not require deconvolution
|
| 820 |
+
self.final = DepthProFeatureFusionLayer(config, use_deconv=False)
|
| 821 |
+
|
| 822 |
+
def forward(self, hidden_states: list[torch.Tensor]) -> list[torch.Tensor]:
|
| 823 |
+
if self.num_layers != len(hidden_states):
|
| 824 |
+
raise ValueError(
|
| 825 |
+
f"num_layers={self.num_layers} in DepthProFeatureFusionStage"
|
| 826 |
+
f"does not match len(hidden_states)={len(hidden_states)}"
|
| 827 |
+
)
|
| 828 |
+
|
| 829 |
+
fused_hidden_states = []
|
| 830 |
+
fused_hidden_state = None
|
| 831 |
+
for hidden_state, layer in zip(hidden_states[:-1], self.intermediate):
|
| 832 |
+
if fused_hidden_state is None:
|
| 833 |
+
# first layer only uses the last hidden_state
|
| 834 |
+
fused_hidden_state = layer(hidden_state)
|
| 835 |
+
else:
|
| 836 |
+
fused_hidden_state = layer(fused_hidden_state, hidden_state)
|
| 837 |
+
fused_hidden_states.append(fused_hidden_state)
|
| 838 |
+
|
| 839 |
+
hidden_state = hidden_states[-1]
|
| 840 |
+
fused_hidden_state = self.final(fused_hidden_state, hidden_state)
|
| 841 |
+
fused_hidden_states.append(fused_hidden_state)
|
| 842 |
+
|
| 843 |
+
return fused_hidden_states
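# Fusion-order sketch: `hidden_states` is expected lowest resolution first. The
# first intermediate layer consumes only hidden_states[0] and doubles its
# resolution with its deconv; each later layer adds the next map (twice the
# size of the previous one, matching the preceding deconv output) and upsamples
# again, while the final layer fuses the highest-resolution map without a
# deconv, so fused_hidden_states[-1] keeps that resolution.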
|
| 844 |
+
|
| 845 |
+
|
| 846 |
+
class DepthProFovEncoder(nn.Module):
|
| 847 |
+
def __init__(self, config: DepthProConfig):
|
| 848 |
+
super().__init__()
|
| 849 |
+
self.config = config
|
| 850 |
+
self.out_size = config.image_model_config.image_size // config.image_model_config.patch_size
|
| 851 |
+
|
| 852 |
+
self.model = AutoModel.from_config(config.fov_model_config)
|
| 853 |
+
self.neck = nn.Linear(config.fov_model_config.hidden_size, config.fusion_hidden_size // 2)
|
| 854 |
+
|
| 855 |
+
def forward(
|
| 856 |
+
self,
|
| 857 |
+
pixel_values: torch.Tensor,
|
| 858 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 859 |
+
) -> torch.Tensor:
|
| 860 |
+
batch_size, num_channels, height, width = pixel_values.shape
|
| 861 |
+
|
| 862 |
+
# scale the image for fov_encoder
|
| 863 |
+
size = self.config.fov_model_config.image_size
|
| 864 |
+
pixel_values = F.interpolate(
|
| 865 |
+
pixel_values,
|
| 866 |
+
size=(size, size),
|
| 867 |
+
mode="bilinear",
|
| 868 |
+
align_corners=False,
|
| 869 |
+
)
|
| 870 |
+
encodings = self.model(
|
| 871 |
+
pixel_values=pixel_values,
|
| 872 |
+
head_mask=head_mask,
|
| 873 |
+
)
|
| 874 |
+
hidden_state = encodings[0]
|
| 875 |
+
hidden_state = self.neck(hidden_state)
|
| 876 |
+
|
| 877 |
+
# calculate base height and width
|
| 878 |
+
# base height and width are the dimensions of the lowest resolution features
|
| 879 |
+
exponent_value = torch_int(math.log2(width / self.out_size))
|
| 880 |
+
base_height = height // 2**exponent_value
|
| 881 |
+
base_width = width // 2**exponent_value
|
| 882 |
+
|
| 883 |
+
features = reconstruct_feature_maps(
|
| 884 |
+
hidden_state,
|
| 885 |
+
batch_size=batch_size,
|
| 886 |
+
padding=0,
|
| 887 |
+
output_size=(base_height, base_width),
|
| 888 |
+
)
|
| 889 |
+
|
| 890 |
+
return features
|
| 891 |
+
|
| 892 |
+
|
| 893 |
+
class DepthProFovHead(nn.Module):
|
| 894 |
+
def __init__(self, config: DepthProConfig):
|
| 895 |
+
super().__init__()
|
| 896 |
+
self.config = config
|
| 897 |
+
self.fusion_hidden_size = config.fusion_hidden_size
|
| 898 |
+
self.out_size = config.image_model_config.image_size // config.image_model_config.patch_size
|
| 899 |
+
|
| 900 |
+
# create initial head layers
|
| 901 |
+
self.layers = nn.ModuleList()
|
| 902 |
+
for i in range(config.num_fov_head_layers):
|
| 903 |
+
self.layers.append(
|
| 904 |
+
nn.Conv2d(
|
| 905 |
+
math.ceil(self.fusion_hidden_size / 2 ** (i + 1)),
|
| 906 |
+
math.ceil(self.fusion_hidden_size / 2 ** (i + 2)),
|
| 907 |
+
kernel_size=3,
|
| 908 |
+
stride=2,
|
| 909 |
+
padding=1,
|
| 910 |
+
)
|
| 911 |
+
)
|
| 912 |
+
self.layers.append(nn.ReLU(True))
|
| 913 |
+
# calculate expected shapes to finally generate a scalar output from final head layer
|
| 914 |
+
final_in_channels = math.ceil(self.fusion_hidden_size / 2 ** (config.num_fov_head_layers + 1))
|
| 915 |
+
final_kernel_size = torch_int((self.out_size - 1) / 2**config.num_fov_head_layers + 1)
|
| 916 |
+
self.layers.append(
|
| 917 |
+
nn.Conv2d(
|
| 918 |
+
in_channels=final_in_channels, out_channels=1, kernel_size=final_kernel_size, stride=1, padding=0
|
| 919 |
+
)
|
| 920 |
+
)
|
| 921 |
+
|
| 922 |
+
def forward(self, features: torch.Tensor) -> torch.Tensor:
|
| 923 |
+
features = F.interpolate(
|
| 924 |
+
features,
|
| 925 |
+
size=(self.out_size, self.out_size),
|
| 926 |
+
mode="bilinear",
|
| 927 |
+
align_corners=False,
|
| 928 |
+
)
|
| 929 |
+
for layer in self.layers:
|
| 930 |
+
features = layer(features)
|
| 931 |
+
return features
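# Worked example of the head geometry (assuming out_size=24 and
# num_fov_head_layers=2; illustrative values, not a specific checkpoint): the
# two stride-2, kernel-3, padding-1 convolutions shrink the 24x24 input to
# 12x12 and then 6x6, while
#     final_kernel_size = int((24 - 1) / 2**2 + 1) = 6
# so the last convolution covers the whole 6x6 map and collapses it to a single
# 1x1 logit per image, which DepthProFovModel later flattens into one
# field-of-view value per batch element.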
|
| 932 |
+
|
| 933 |
+
|
| 934 |
+
class DepthProFovModel(nn.Module):
|
| 935 |
+
def __init__(self, config: DepthProConfig):
|
| 936 |
+
super().__init__()
|
| 937 |
+
self.config = config
|
| 938 |
+
self.fusion_hidden_size = config.fusion_hidden_size
|
| 939 |
+
|
| 940 |
+
self.fov_encoder = DepthProFovEncoder(config)
|
| 941 |
+
self.conv = nn.Conv2d(
|
| 942 |
+
self.fusion_hidden_size, self.fusion_hidden_size // 2, kernel_size=3, stride=2, padding=1
|
| 943 |
+
)
|
| 944 |
+
self.activation = nn.ReLU(inplace=True)
|
| 945 |
+
self.head = DepthProFovHead(config)
|
| 946 |
+
|
| 947 |
+
def forward(
|
| 948 |
+
self,
|
| 949 |
+
pixel_values: torch.Tensor,
|
| 950 |
+
global_features: torch.Tensor,
|
| 951 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 952 |
+
) -> torch.Tensor:
|
| 953 |
+
fov_features = self.fov_encoder(pixel_values, head_mask)
|
| 954 |
+
|
| 955 |
+
global_features = self.conv(global_features)
|
| 956 |
+
global_features = self.activation(global_features)
|
| 957 |
+
|
| 958 |
+
fov_features = fov_features + global_features
|
| 959 |
+
fov_output = self.head(fov_features)
|
| 960 |
+
fov_output = fov_output.flatten()
|
| 961 |
+
|
| 962 |
+
return fov_output
|
| 963 |
+
|
| 964 |
+
|
| 965 |
+
class DepthProDepthEstimationHead(nn.Module):
|
| 966 |
+
"""
|
| 967 |
+
The DepthProDepthEstimationHead module serves as the output head for depth estimation tasks.
|
| 968 |
+
This module comprises a sequence of convolutional and transposed convolutional layers
|
| 969 |
+
that process the feature map from the fusion stage to produce a single-channel depth map.
|
| 970 |
+
Key operations include dimensionality reduction and upsampling to match the input resolution.
|
| 971 |
+
"""
|
| 972 |
+
|
| 973 |
+
def __init__(self, config):
|
| 974 |
+
super().__init__()
|
| 975 |
+
self.config = config
|
| 976 |
+
|
| 977 |
+
features = config.fusion_hidden_size
|
| 978 |
+
self.layers = nn.ModuleList(
|
| 979 |
+
[
|
| 980 |
+
nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
|
| 981 |
+
nn.ConvTranspose2d(
|
| 982 |
+
in_channels=features // 2,
|
| 983 |
+
out_channels=features // 2,
|
| 984 |
+
kernel_size=2,
|
| 985 |
+
stride=2,
|
| 986 |
+
padding=0,
|
| 987 |
+
bias=True,
|
| 988 |
+
),
|
| 989 |
+
nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
|
| 990 |
+
nn.ReLU(True),
|
| 991 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
| 992 |
+
nn.ReLU(),
|
| 993 |
+
]
|
| 994 |
+
)
|
| 995 |
+
|
| 996 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 997 |
+
for layer in self.layers:
|
| 998 |
+
hidden_states = layer(hidden_states)
|
| 999 |
+
|
| 1000 |
+
predicted_depth = hidden_states.squeeze(dim=1)
|
| 1001 |
+
return predicted_depth
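# Shape sketch (assuming config.fusion_hidden_size=256, an illustrative value):
# a fused map of shape (1, 256, H, W) passes through
#     Conv2d 256 -> 128, ConvTranspose2d (2x spatial), Conv2d 128 -> 32, ReLU,
#     Conv2d 32 -> 1, ReLU
# and the channel dimension is squeezed, so `predicted_depth` has shape
# (1, 2 * H, 2 * W).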
|
| 1002 |
+
|
| 1003 |
+
|
| 1004 |
+
@auto_docstring(
|
| 1005 |
+
custom_intro="""
|
| 1006 |
+
DepthPro Model with a depth estimation head on top (consisting of 3 convolutional layers).
|
| 1007 |
+
"""
|
| 1008 |
+
)
|
| 1009 |
+
class DepthProForDepthEstimation(DepthProPreTrainedModel):
|
| 1010 |
+
def __init__(self, config, use_fov_model=None):
|
| 1011 |
+
r"""
|
| 1012 |
+
use_fov_model (bool, *optional*):
|
| 1013 |
+
Whether to use the field of view model.
|
| 1014 |
+
"""
|
| 1015 |
+
super().__init__(config)
|
| 1016 |
+
self.config = config
|
| 1017 |
+
self.use_fov_model = use_fov_model if use_fov_model is not None else self.config.use_fov_model
|
| 1018 |
+
|
| 1019 |
+
# dinov2 (vit) like encoders
|
| 1020 |
+
self.depth_pro = DepthProModel(config)
|
| 1021 |
+
|
| 1022 |
+
# dpt (vit) like fusion stage
|
| 1023 |
+
self.fusion_stage = DepthProFeatureFusionStage(config)
|
| 1024 |
+
|
| 1025 |
+
# depth estimation head
|
| 1026 |
+
self.head = DepthProDepthEstimationHead(config)
|
| 1027 |
+
|
| 1028 |
+
# dinov2 (vit) like encoder
|
| 1029 |
+
self.fov_model = DepthProFovModel(config) if self.use_fov_model else None
|
| 1030 |
+
|
| 1031 |
+
# Initialize weights and apply final processing
|
| 1032 |
+
self.post_init()
|
| 1033 |
+
|
| 1034 |
+
@auto_docstring
|
| 1035 |
+
def forward(
|
| 1036 |
+
self,
|
| 1037 |
+
pixel_values: torch.FloatTensor,
|
| 1038 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 1039 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1040 |
+
output_attentions: Optional[bool] = None,
|
| 1041 |
+
output_hidden_states: Optional[bool] = None,
|
| 1042 |
+
return_dict: Optional[bool] = None,
|
| 1043 |
+
) -> Union[tuple[torch.Tensor], DepthProDepthEstimatorOutput]:
|
| 1044 |
+
r"""
|
| 1045 |
+
labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
|
| 1046 |
+
Ground truth depth estimation maps for computing the loss.
|
| 1047 |
+
|
| 1048 |
+
Examples:
|
| 1049 |
+
|
| 1050 |
+
```python
|
| 1051 |
+
>>> from transformers import AutoImageProcessor, DepthProForDepthEstimation
|
| 1052 |
+
>>> import torch
|
| 1053 |
+
>>> from PIL import Image
|
| 1054 |
+
>>> import requests
|
| 1055 |
+
|
| 1056 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 1057 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 1058 |
+
|
| 1059 |
+
>>> checkpoint = "apple/DepthPro-hf"
|
| 1060 |
+
>>> processor = AutoImageProcessor.from_pretrained(checkpoint)
|
| 1061 |
+
>>> model = DepthProForDepthEstimation.from_pretrained(checkpoint)
|
| 1062 |
+
|
| 1063 |
+
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 1064 |
+
>>> model.to(device)
|
| 1065 |
+
|
| 1066 |
+
>>> # prepare image for the model
|
| 1067 |
+
>>> inputs = processor(images=image, return_tensors="pt").to(device)
|
| 1068 |
+
|
| 1069 |
+
>>> with torch.no_grad():
|
| 1070 |
+
... outputs = model(**inputs)
|
| 1071 |
+
|
| 1072 |
+
>>> # interpolate to original size
|
| 1073 |
+
>>> post_processed_output = processor.post_process_depth_estimation(
|
| 1074 |
+
... outputs, target_sizes=[(image.height, image.width)],
|
| 1075 |
+
... )
|
| 1076 |
+
|
| 1077 |
+
>>> # get the field of view (fov) predictions
|
| 1078 |
+
>>> field_of_view = post_processed_output[0]["field_of_view"]
|
| 1079 |
+
>>> focal_length = post_processed_output[0]["focal_length"]
|
| 1080 |
+
|
| 1081 |
+
>>> # visualize the prediction
|
| 1082 |
+
>>> predicted_depth = post_processed_output[0]["predicted_depth"]
|
| 1083 |
+
>>> depth = predicted_depth * 255 / predicted_depth.max()
|
| 1084 |
+
>>> depth = depth.detach().cpu().numpy()
|
| 1085 |
+
>>> depth = Image.fromarray(depth.astype("uint8"))
|
| 1086 |
+
```"""
|
| 1087 |
+
loss = None
|
| 1088 |
+
if labels is not None:
|
| 1089 |
+
raise NotImplementedError("Training is not implemented yet")
|
| 1090 |
+
|
| 1091 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1092 |
+
output_hidden_states = (
|
| 1093 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1094 |
+
)
|
| 1095 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1096 |
+
|
| 1097 |
+
depth_pro_outputs = self.depth_pro(
|
| 1098 |
+
pixel_values=pixel_values,
|
| 1099 |
+
head_mask=head_mask,
|
| 1100 |
+
output_attentions=output_attentions,
|
| 1101 |
+
output_hidden_states=output_hidden_states,
|
| 1102 |
+
return_dict=True,
|
| 1103 |
+
)
|
| 1104 |
+
features = depth_pro_outputs.features
|
| 1105 |
+
fused_hidden_states = self.fusion_stage(features)
|
| 1106 |
+
predicted_depth = self.head(fused_hidden_states[-1])
|
| 1107 |
+
|
| 1108 |
+
if self.use_fov_model:
|
| 1109 |
+
# detach the encoder features so the FOV head does not backpropagate into the encoder
|
| 1110 |
+
features_for_fov = features[0].detach()
|
| 1111 |
+
fov = self.fov_model(
|
| 1112 |
+
pixel_values=pixel_values,
|
| 1113 |
+
global_features=features_for_fov,
|
| 1114 |
+
head_mask=head_mask,
|
| 1115 |
+
)
|
| 1116 |
+
else:
|
| 1117 |
+
fov = None
|
| 1118 |
+
|
| 1119 |
+
if not return_dict:
|
| 1120 |
+
outputs = [loss, predicted_depth, fov, depth_pro_outputs.hidden_states, depth_pro_outputs.attentions]
|
| 1121 |
+
return tuple(v for v in outputs if v is not None)
|
| 1122 |
+
|
| 1123 |
+
return DepthProDepthEstimatorOutput(
|
| 1124 |
+
loss=loss,
|
| 1125 |
+
predicted_depth=predicted_depth,
|
| 1126 |
+
field_of_view=fov,
|
| 1127 |
+
hidden_states=depth_pro_outputs.hidden_states,
|
| 1128 |
+
attentions=depth_pro_outputs.attentions,
|
| 1129 |
+
)
|
| 1130 |
+
|
| 1131 |
+
|
| 1132 |
+
__all__ = ["DepthProPreTrainedModel", "DepthProModel", "DepthProForDepthEstimation"]
|
phivenv/Lib/site-packages/transformers/models/detr/__init__.py
ADDED
|
@@ -0,0 +1,31 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_detr import *
    from .feature_extraction_detr import *
    from .image_processing_detr import *
    from .image_processing_detr_fast import *
    from .modeling_detr import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
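# Usage sketch of the lazy-import pattern above (standard transformers
# behavior, not DETR-specific): at runtime the module object is replaced by a
# `_LazyModule`, so submodules are only imported when an attribute is first
# accessed, e.g.
#
#     from transformers.models.detr import DetrImageProcessor  # resolves image_processing_detr lazily
#
# while the TYPE_CHECKING branch keeps the eager star-imports visible to static
# type checkers and IDEs.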
phivenv/Lib/site-packages/transformers/models/detr/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (628 Bytes).

phivenv/Lib/site-packages/transformers/models/detr/__pycache__/configuration_detr.cpython-39.pyc
ADDED
Binary file (11.9 kB).

phivenv/Lib/site-packages/transformers/models/detr/__pycache__/feature_extraction_detr.cpython-39.pyc
ADDED
Binary file (1.4 kB).

phivenv/Lib/site-packages/transformers/models/detr/__pycache__/image_processing_detr.cpython-39.pyc
ADDED
Binary file (71.3 kB).

phivenv/Lib/site-packages/transformers/models/detr/__pycache__/image_processing_detr_fast.cpython-39.pyc
ADDED
Binary file (43.1 kB).

phivenv/Lib/site-packages/transformers/models/detr/__pycache__/modeling_detr.cpython-39.pyc
ADDED
Binary file (53.3 kB).

phivenv/Lib/site-packages/transformers/models/detr/configuration_detr.py
ADDED
|
@@ -0,0 +1,297 @@
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2021 Facebook AI Research and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""DETR model configuration"""
|
| 16 |
+
|
| 17 |
+
from collections import OrderedDict
|
| 18 |
+
from collections.abc import Mapping
|
| 19 |
+
|
| 20 |
+
from packaging import version
|
| 21 |
+
|
| 22 |
+
from ...configuration_utils import PretrainedConfig
|
| 23 |
+
from ...onnx import OnnxConfig
|
| 24 |
+
from ...utils import logging
|
| 25 |
+
from ...utils.backbone_utils import verify_backbone_config_arguments
|
| 26 |
+
from ..auto import CONFIG_MAPPING
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
logger = logging.get_logger(__name__)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class DetrConfig(PretrainedConfig):
|
| 33 |
+
r"""
|
| 34 |
+
This is the configuration class to store the configuration of a [`DetrModel`]. It is used to instantiate a DETR
|
| 35 |
+
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
| 36 |
+
defaults will yield a similar configuration to that of the DETR
|
| 37 |
+
[facebook/detr-resnet-50](https://huggingface.co/facebook/detr-resnet-50) architecture.
|
| 38 |
+
|
| 39 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 40 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
use_timm_backbone (`bool`, *optional*, defaults to `True`):
|
| 44 |
+
Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`]
|
| 45 |
+
API.
|
| 46 |
+
backbone_config (`PretrainedConfig` or `dict`, *optional*):
|
| 47 |
+
The configuration of the backbone model. Only used in case `use_timm_backbone` is set to `False` in which
|
| 48 |
+
case it will default to `ResNetConfig()`.
|
| 49 |
+
num_channels (`int`, *optional*, defaults to 3):
|
| 50 |
+
The number of input channels.
|
| 51 |
+
num_queries (`int`, *optional*, defaults to 100):
|
| 52 |
+
Number of object queries, i.e. detection slots. This is the maximal number of objects [`DetrModel`] can
|
| 53 |
+
detect in a single image. For COCO, we recommend 100 queries.
|
| 54 |
+
d_model (`int`, *optional*, defaults to 256):
|
| 55 |
+
Dimension of the model's hidden representations: it sets the size of the encoder and decoder layers as well as the projection layers in the decoder.
|
| 56 |
+
encoder_layers (`int`, *optional*, defaults to 6):
|
| 57 |
+
Number of encoder layers.
|
| 58 |
+
decoder_layers (`int`, *optional*, defaults to 6):
|
| 59 |
+
Number of decoder layers.
|
| 60 |
+
encoder_attention_heads (`int`, *optional*, defaults to 8):
|
| 61 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 62 |
+
decoder_attention_heads (`int`, *optional*, defaults to 8):
|
| 63 |
+
Number of attention heads for each attention layer in the Transformer decoder.
|
| 64 |
+
decoder_ffn_dim (`int`, *optional*, defaults to 2048):
|
| 65 |
+
Dimension of the "intermediate" (often named feed-forward) layer in decoder.
|
| 66 |
+
encoder_ffn_dim (`int`, *optional*, defaults to 2048):
|
| 67 |
+
Dimension of the "intermediate" (often named feed-forward) layer in decoder.
|
| 68 |
+
activation_function (`str` or `function`, *optional*, defaults to `"relu"`):
|
| 69 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
| 70 |
+
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
| 71 |
+
dropout (`float`, *optional*, defaults to 0.1):
|
| 72 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
| 73 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 74 |
+
The dropout ratio for the attention probabilities.
|
| 75 |
+
activation_dropout (`float`, *optional*, defaults to 0.0):
|
| 76 |
+
The dropout ratio for activations inside the fully connected layer.
|
| 77 |
+
init_std (`float`, *optional*, defaults to 0.02):
|
| 78 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 79 |
+
init_xavier_std (`float`, *optional*, defaults to 1):
|
| 80 |
+
The scaling factor used for the Xavier initialization gain in the HM Attention map module.
|
| 81 |
+
encoder_layerdrop (`float`, *optional*, defaults to 0.0):
|
| 82 |
+
The LayerDrop probability for the encoder. See the [LayerDrop paper](https://huggingface.co/papers/1909.11556)
|
| 83 |
+
for more details.
|
| 84 |
+
decoder_layerdrop (`float`, *optional*, defaults to 0.0):
|
| 85 |
+
The LayerDrop probability for the decoder. See the [LayerDrop paper](https://huggingface.co/papers/1909.11556)
|
| 86 |
+
for more details.
|
| 87 |
+
auxiliary_loss (`bool`, *optional*, defaults to `False`):
|
| 88 |
+
Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
|
| 89 |
+
position_embedding_type (`str`, *optional*, defaults to `"sine"`):
|
| 90 |
+
Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`.
|
| 91 |
+
backbone (`str`, *optional*, defaults to `"resnet50"`):
|
| 92 |
+
Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
|
| 93 |
+
will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
|
| 94 |
+
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
|
| 95 |
+
use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
|
| 96 |
+
Whether to use pretrained weights for the backbone.
|
| 97 |
+
backbone_kwargs (`dict`, *optional*):
|
| 98 |
+
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
|
| 99 |
+
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
|
| 100 |
+
dilation (`bool`, *optional*, defaults to `False`):
|
| 101 |
+
Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
|
| 102 |
+
`use_timm_backbone` = `True`.
|
| 103 |
+
class_cost (`float`, *optional*, defaults to 1):
|
| 104 |
+
Relative weight of the classification error in the Hungarian matching cost.
|
| 105 |
+
bbox_cost (`float`, *optional*, defaults to 5):
|
| 106 |
+
Relative weight of the L1 error of the bounding box coordinates in the Hungarian matching cost.
|
| 107 |
+
giou_cost (`float`, *optional*, defaults to 2):
|
| 108 |
+
Relative weight of the generalized IoU loss of the bounding box in the Hungarian matching cost.
|
| 109 |
+
mask_loss_coefficient (`float`, *optional*, defaults to 1):
|
| 110 |
+
Relative weight of the Focal loss in the panoptic segmentation loss.
|
| 111 |
+
dice_loss_coefficient (`float`, *optional*, defaults to 1):
|
| 112 |
+
Relative weight of the DICE/F-1 loss in the panoptic segmentation loss.
|
| 113 |
+
bbox_loss_coefficient (`float`, *optional*, defaults to 5):
|
| 114 |
+
Relative weight of the L1 bounding box loss in the object detection loss.
|
| 115 |
+
giou_loss_coefficient (`float`, *optional*, defaults to 2):
|
| 116 |
+
Relative weight of the generalized IoU loss in the object detection loss.
|
| 117 |
+
eos_coefficient (`float`, *optional*, defaults to 0.1):
|
| 118 |
+
Relative classification weight of the 'no-object' class in the object detection loss.
|
| 119 |
+
|
| 120 |
+
Examples:
|
| 121 |
+
|
| 122 |
+
```python
|
| 123 |
+
>>> from transformers import DetrConfig, DetrModel
|
| 124 |
+
|
| 125 |
+
>>> # Initializing a DETR facebook/detr-resnet-50 style configuration
|
| 126 |
+
>>> configuration = DetrConfig()
|
| 127 |
+
|
| 128 |
+
>>> # Initializing a model (with random weights) from the facebook/detr-resnet-50 style configuration
|
| 129 |
+
>>> model = DetrModel(configuration)
|
| 130 |
+
|
| 131 |
+
>>> # Accessing the model configuration
|
| 132 |
+
>>> configuration = model.config
|
| 133 |
+
```"""
|
| 134 |
+
|
| 135 |
+
model_type = "detr"
|
| 136 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 137 |
+
attribute_map = {
|
| 138 |
+
"hidden_size": "d_model",
|
| 139 |
+
"num_attention_heads": "encoder_attention_heads",
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
def __init__(
|
| 143 |
+
self,
|
| 144 |
+
use_timm_backbone=True,
|
| 145 |
+
backbone_config=None,
|
| 146 |
+
num_channels=3,
|
| 147 |
+
num_queries=100,
|
| 148 |
+
encoder_layers=6,
|
| 149 |
+
encoder_ffn_dim=2048,
|
| 150 |
+
encoder_attention_heads=8,
|
| 151 |
+
decoder_layers=6,
|
| 152 |
+
decoder_ffn_dim=2048,
|
| 153 |
+
decoder_attention_heads=8,
|
| 154 |
+
encoder_layerdrop=0.0,
|
| 155 |
+
decoder_layerdrop=0.0,
|
| 156 |
+
is_encoder_decoder=True,
|
| 157 |
+
activation_function="relu",
|
| 158 |
+
d_model=256,
|
| 159 |
+
dropout=0.1,
|
| 160 |
+
attention_dropout=0.0,
|
| 161 |
+
activation_dropout=0.0,
|
| 162 |
+
init_std=0.02,
|
| 163 |
+
init_xavier_std=1.0,
|
| 164 |
+
auxiliary_loss=False,
|
| 165 |
+
position_embedding_type="sine",
|
| 166 |
+
backbone="resnet50",
|
| 167 |
+
use_pretrained_backbone=True,
|
| 168 |
+
backbone_kwargs=None,
|
| 169 |
+
dilation=False,
|
| 170 |
+
class_cost=1,
|
| 171 |
+
bbox_cost=5,
|
| 172 |
+
giou_cost=2,
|
| 173 |
+
mask_loss_coefficient=1,
|
| 174 |
+
dice_loss_coefficient=1,
|
| 175 |
+
bbox_loss_coefficient=5,
|
| 176 |
+
giou_loss_coefficient=2,
|
| 177 |
+
eos_coefficient=0.1,
|
| 178 |
+
**kwargs,
|
| 179 |
+
):
|
| 180 |
+
# We default to values which were previously hard-coded in the model. This enables configurability of the config
|
| 181 |
+
# while keeping the default behavior the same.
|
| 182 |
+
if use_timm_backbone and backbone_kwargs is None:
|
| 183 |
+
backbone_kwargs = {}
|
| 184 |
+
if dilation:
|
| 185 |
+
backbone_kwargs["output_stride"] = 16
|
| 186 |
+
backbone_kwargs["out_indices"] = [1, 2, 3, 4]
|
| 187 |
+
backbone_kwargs["in_chans"] = num_channels
|
| 188 |
+
# Backwards compatibility
|
| 189 |
+
elif not use_timm_backbone and backbone in (None, "resnet50"):
|
| 190 |
+
if backbone_config is None:
|
| 191 |
+
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
|
| 192 |
+
backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
|
| 193 |
+
elif isinstance(backbone_config, dict):
|
| 194 |
+
backbone_model_type = backbone_config.get("model_type")
|
| 195 |
+
config_class = CONFIG_MAPPING[backbone_model_type]
|
| 196 |
+
backbone_config = config_class.from_dict(backbone_config)
|
| 197 |
+
backbone = None
|
| 198 |
+
# set timm attributes to None
|
| 199 |
+
dilation = None
|
| 200 |
+
|
| 201 |
+
verify_backbone_config_arguments(
|
| 202 |
+
use_timm_backbone=use_timm_backbone,
|
| 203 |
+
use_pretrained_backbone=use_pretrained_backbone,
|
| 204 |
+
backbone=backbone,
|
| 205 |
+
backbone_config=backbone_config,
|
| 206 |
+
backbone_kwargs=backbone_kwargs,
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
self.use_timm_backbone = use_timm_backbone
|
| 210 |
+
self.backbone_config = backbone_config
|
| 211 |
+
self.num_channels = num_channels
|
| 212 |
+
self.num_queries = num_queries
|
| 213 |
+
self.d_model = d_model
|
| 214 |
+
self.encoder_ffn_dim = encoder_ffn_dim
|
| 215 |
+
self.encoder_layers = encoder_layers
|
| 216 |
+
self.encoder_attention_heads = encoder_attention_heads
|
| 217 |
+
self.decoder_ffn_dim = decoder_ffn_dim
|
| 218 |
+
self.decoder_layers = decoder_layers
|
| 219 |
+
self.decoder_attention_heads = decoder_attention_heads
|
| 220 |
+
self.dropout = dropout
|
| 221 |
+
self.attention_dropout = attention_dropout
|
| 222 |
+
self.activation_dropout = activation_dropout
|
| 223 |
+
self.activation_function = activation_function
|
| 224 |
+
self.init_std = init_std
|
| 225 |
+
self.init_xavier_std = init_xavier_std
|
| 226 |
+
self.encoder_layerdrop = encoder_layerdrop
|
| 227 |
+
self.decoder_layerdrop = decoder_layerdrop
|
| 228 |
+
self.num_hidden_layers = encoder_layers
|
| 229 |
+
self.auxiliary_loss = auxiliary_loss
|
| 230 |
+
self.position_embedding_type = position_embedding_type
|
| 231 |
+
self.backbone = backbone
|
| 232 |
+
self.use_pretrained_backbone = use_pretrained_backbone
|
| 233 |
+
self.backbone_kwargs = backbone_kwargs
|
| 234 |
+
self.dilation = dilation
|
| 235 |
+
# Hungarian matcher
|
| 236 |
+
self.class_cost = class_cost
|
| 237 |
+
self.bbox_cost = bbox_cost
|
| 238 |
+
self.giou_cost = giou_cost
|
| 239 |
+
# Loss coefficients
|
| 240 |
+
self.mask_loss_coefficient = mask_loss_coefficient
|
| 241 |
+
self.dice_loss_coefficient = dice_loss_coefficient
|
| 242 |
+
self.bbox_loss_coefficient = bbox_loss_coefficient
|
| 243 |
+
self.giou_loss_coefficient = giou_loss_coefficient
|
| 244 |
+
self.eos_coefficient = eos_coefficient
|
| 245 |
+
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
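# Worked example of the backbone defaulting above (values follow directly from
# this __init__): DetrConfig(dilation=True) with the default
# use_timm_backbone=True and num_channels=3 fills backbone_kwargs with
#     {"output_stride": 16, "out_indices": [1, 2, 3, 4], "in_chans": 3}
# whereas use_timm_backbone=False with backbone left as "resnet50" (or None)
# builds a ResNet backbone config via CONFIG_MAPPING["resnet"](out_features=["stage4"])
# and resets `backbone` and `dilation` to None.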
|
| 246 |
+
|
| 247 |
+
@property
|
| 248 |
+
def num_attention_heads(self) -> int:
|
| 249 |
+
return self.encoder_attention_heads
|
| 250 |
+
|
| 251 |
+
@property
|
| 252 |
+
def hidden_size(self) -> int:
|
| 253 |
+
return self.d_model
|
| 254 |
+
|
| 255 |
+
@property
|
| 256 |
+
def sub_configs(self):
|
| 257 |
+
return (
|
| 258 |
+
{"backbone_config": type(self.backbone_config)}
|
| 259 |
+
if getattr(self, "backbone_config", None) is not None
|
| 260 |
+
else {}
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
@classmethod
|
| 264 |
+
def from_backbone_config(cls, backbone_config: PretrainedConfig, **kwargs):
|
| 265 |
+
"""Instantiate a [`DetrConfig`] (or a derived class) from a pre-trained backbone model configuration.
|
| 266 |
+
|
| 267 |
+
Args:
|
| 268 |
+
backbone_config ([`PretrainedConfig`]):
|
| 269 |
+
The backbone configuration.
|
| 270 |
+
Returns:
|
| 271 |
+
[`DetrConfig`]: An instance of a configuration object
|
| 272 |
+
"""
|
| 273 |
+
return cls(backbone_config=backbone_config, **kwargs)
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
class DetrOnnxConfig(OnnxConfig):
|
| 277 |
+
torch_onnx_minimum_version = version.parse("1.11")
|
| 278 |
+
|
| 279 |
+
@property
|
| 280 |
+
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
| 281 |
+
return OrderedDict(
|
| 282 |
+
[
|
| 283 |
+
("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
|
| 284 |
+
("pixel_mask", {0: "batch"}),
|
| 285 |
+
]
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
@property
|
| 289 |
+
def atol_for_validation(self) -> float:
|
| 290 |
+
return 1e-5
|
| 291 |
+
|
| 292 |
+
@property
|
| 293 |
+
def default_onnx_opset(self) -> int:
|
| 294 |
+
return 12
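# Usage sketch (relies only on the properties defined in this class): for a
# DETR checkpoint, DetrOnnxConfig(config).inputs returns an OrderedDict of
# dynamic axes,
#     {"pixel_values": {0: "batch", 1: "num_channels", 2: "height", 3: "width"},
#      "pixel_mask": {0: "batch"}}
# which the ONNX export utilities use to mark the batch and image dimensions as
# dynamic, with opset 12 and a validation tolerance of 1e-5.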
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
__all__ = ["DetrConfig", "DetrOnnxConfig"]
|
phivenv/Lib/site-packages/transformers/models/detr/feature_extraction_detr.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Feature extractor class for DETR."""

import warnings

from ...image_transforms import rgb_to_id as _rgb_to_id
from ...utils import logging
from ...utils.import_utils import requires
from .image_processing_detr import DetrImageProcessor


logger = logging.get_logger(__name__)


def rgb_to_id(x):
    warnings.warn(
        "rgb_to_id has moved and will not be importable from this module from v5. "
        "Please import from transformers.image_transforms instead.",
        FutureWarning,
    )
    return _rgb_to_id(x)


@requires(backends=("vision",))
class DetrFeatureExtractor(DetrImageProcessor):
    def __init__(self, *args, **kwargs) -> None:
        warnings.warn(
            "The class DetrFeatureExtractor is deprecated and will be removed in version 5 of Transformers."
            " Please use DetrImageProcessor instead.",
            FutureWarning,
        )
        super().__init__(*args, **kwargs)


__all__ = ["DetrFeatureExtractor"]
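
# --- Editor's note: illustrative usage sketch, not part of the original file. ---
# DetrFeatureExtractor is only a deprecation shim over DetrImageProcessor: constructing it
# emits a FutureWarning and every method is inherited from the image processor, e.g.:
# >>> import warnings
# >>> from transformers import DetrFeatureExtractor, DetrImageProcessor
# >>> with warnings.catch_warnings(record=True):
# ...     warnings.simplefilter("always")
# ...     extractor = DetrFeatureExtractor()
# >>> isinstance(extractor, DetrImageProcessor)
# True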
phivenv/Lib/site-packages/transformers/models/detr/image_processing_detr.py ADDED
@@ -0,0 +1,2049 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for DETR."""

import io
import pathlib
from collections import defaultdict
from collections.abc import Iterable
from typing import Any, Callable, Optional, Union

import numpy as np

from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
    PaddingMode,
    center_to_corners_format,
    corners_to_center_format,
    id_to_rgb,
    pad,
    rescale,
    resize,
    rgb_to_id,
    to_channel_dimension_format,
)
from ...image_utils import (
    IMAGENET_DEFAULT_MEAN,
    IMAGENET_DEFAULT_STD,
    AnnotationFormat,
    AnnotationType,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    get_image_size,
    infer_channel_dimension_format,
    is_scaled_image,
    make_list_of_images,
    to_numpy_array,
    valid_images,
    validate_annotations,
    validate_kwargs,
    validate_preprocess_arguments,
)
from ...utils import (
    TensorType,
    is_flax_available,
    is_jax_tensor,
    is_scipy_available,
    is_tf_available,
    is_tf_tensor,
    is_torch_available,
    is_torch_tensor,
    is_vision_available,
    logging,
)
from ...utils.import_utils import requires


if is_torch_available():
    import torch
    from torch import nn


if is_vision_available():
    import PIL


if is_scipy_available():
    import scipy.special
    import scipy.stats


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION, AnnotationFormat.COCO_PANOPTIC)


# From the original repo: https://github.com/facebookresearch/detr/blob/3af9fa878e73b6894ce3596450a8d9b89d918ca9/datasets/transforms.py#L76
def get_size_with_aspect_ratio(image_size, size, max_size=None) -> tuple[int, int]:
    """
    Computes the output image size given the input image size and the desired output size.

    Args:
        image_size (`tuple[int, int]`):
            The input image size.
        size (`int`):
            The desired output size.
        max_size (`int`, *optional*):
            The maximum allowed output size.
    """
    height, width = image_size
    raw_size = None
    if max_size is not None:
        min_original_size = float(min((height, width)))
        max_original_size = float(max((height, width)))
        if max_original_size / min_original_size * size > max_size:
            raw_size = max_size * min_original_size / max_original_size
            size = int(round(raw_size))

    if (height <= width and height == size) or (width <= height and width == size):
        oh, ow = height, width
    elif width < height:
        ow = size
        if max_size is not None and raw_size is not None:
            oh = int(raw_size * height / width)
        else:
            oh = int(size * height / width)
    else:
        oh = size
        if max_size is not None and raw_size is not None:
            ow = int(raw_size * width / height)
        else:
            ow = int(size * width / height)

    return (oh, ow)
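
# --- Editor's note: illustrative sketch, not part of the original file. ---
# With the DETR defaults (shortest edge 800, longest edge 1333), a 480x640 image is scaled so
# that its shorter side becomes 800, while a very wide image is capped by max_size:
# >>> get_size_with_aspect_ratio((480, 640), 800, max_size=1333)
# (800, 1066)
# >>> get_size_with_aspect_ratio((400, 1600), 800, max_size=1333)
# (333, 1333)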


def get_image_size_for_max_height_width(
    input_image: np.ndarray,
    max_height: int,
    max_width: int,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> tuple[int, int]:
    """
    Computes the output image size given the input image and the maximum allowed height and width. Keep aspect ratio.
    Important, even if image_height < max_height and image_width < max_width, the image will be resized
    to at least one of the edges be equal to max_height or max_width.

    For example:
        - input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50)
        - input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400)

    Args:
        input_image (`np.ndarray`):
            The image to resize.
        max_height (`int`):
            The maximum allowed height.
        max_width (`int`):
            The maximum allowed width.
        input_data_format (`ChannelDimension` or `str`, *optional*):
            The channel dimension format of the input image. If not provided, it will be inferred from the input image.
    """
    image_size = get_image_size(input_image, input_data_format)
    height, width = image_size
    height_scale = max_height / height
    width_scale = max_width / width
    min_scale = min(height_scale, width_scale)
    new_height = int(height * min_scale)
    new_width = int(width * min_scale)
    return new_height, new_width


def get_resize_output_image_size(
    input_image: np.ndarray,
    size: Union[int, tuple[int, int], list[int]],
    max_size: Optional[int] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> tuple[int, int]:
    """
    Computes the output image size given the input image size and the desired output size. If the desired output size
    is a tuple or list, the output image size is returned as is. If the desired output size is an integer, the output
    image size is computed by keeping the aspect ratio of the input image size.

    Args:
        input_image (`np.ndarray`):
            The image to resize.
        size (`int` or `tuple[int, int]` or `list[int]`):
            The desired output size.
        max_size (`int`, *optional*):
            The maximum allowed output size.
        input_data_format (`ChannelDimension` or `str`, *optional*):
            The channel dimension format of the input image. If not provided, it will be inferred from the input image.
    """
    image_size = get_image_size(input_image, input_data_format)
    if isinstance(size, (list, tuple)):
        return size

    return get_size_with_aspect_ratio(image_size, size, max_size)


def get_numpy_to_framework_fn(arr) -> Callable:
    """
    Returns a function that converts a numpy array to the framework of the input array.

    Args:
        arr (`np.ndarray`): The array to convert.
    """
    if isinstance(arr, np.ndarray):
        return np.array
    if is_tf_available() and is_tf_tensor(arr):
        import tensorflow as tf

        return tf.convert_to_tensor
    if is_torch_available() and is_torch_tensor(arr):
        import torch

        return torch.tensor
    if is_flax_available() and is_jax_tensor(arr):
        import jax.numpy as jnp

        return jnp.array
    raise ValueError(f"Cannot convert arrays of type {type(arr)}")


def safe_squeeze(arr: np.ndarray, axis: Optional[int] = None) -> np.ndarray:
    """
    Squeezes an array, but only if the axis specified has dim 1.
    """
    if axis is None:
        return arr.squeeze()

    try:
        return arr.squeeze(axis=axis)
    except ValueError:
        return arr


def normalize_annotation(annotation: dict, image_size: tuple[int, int]) -> dict:
    image_height, image_width = image_size
    norm_annotation = {}
    for key, value in annotation.items():
        if key == "boxes":
            boxes = value
            boxes = corners_to_center_format(boxes)
            boxes /= np.asarray([image_width, image_height, image_width, image_height], dtype=np.float32)
            norm_annotation[key] = boxes
        else:
            norm_annotation[key] = value
    return norm_annotation
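
# --- Editor's note: illustrative sketch, not part of the original file. ---
# normalize_annotation rescales corner-format boxes to center format in [0, 1]. For a
# 100x200 (height x width) image, a box spanning x in [50, 150] and y in [25, 75] becomes:
# >>> ann = {"boxes": np.array([[50.0, 25.0, 150.0, 75.0]], dtype=np.float32)}
# >>> normalize_annotation(ann, image_size=(100, 200))["boxes"]
# array([[0.5, 0.5, 0.5, 0.5]], dtype=float32)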


# Copied from transformers.models.vilt.image_processing_vilt.max_across_indices
def max_across_indices(values: Iterable[Any]) -> list[Any]:
    """
    Return the maximum value across all indices of an iterable of values.
    """
    return [max(values_i) for values_i in zip(*values)]


# Copied from transformers.models.vilt.image_processing_vilt.get_max_height_width
def get_max_height_width(
    images: list[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
) -> list[int]:
    """
    Get the maximum height and width across all images in a batch.
    """
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(images[0])

    if input_data_format == ChannelDimension.FIRST:
        _, max_height, max_width = max_across_indices([img.shape for img in images])
    elif input_data_format == ChannelDimension.LAST:
        max_height, max_width, _ = max_across_indices([img.shape for img in images])
    else:
        raise ValueError(f"Invalid channel dimension format: {input_data_format}")
    return (max_height, max_width)


# Copied from transformers.models.vilt.image_processing_vilt.make_pixel_mask
def make_pixel_mask(
    image: np.ndarray, output_size: tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
) -> np.ndarray:
    """
    Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.

    Args:
        image (`np.ndarray`):
            Image to make the pixel mask for.
        output_size (`tuple[int, int]`):
            Output size of the mask.
    """
    input_height, input_width = get_image_size(image, channel_dim=input_data_format)
    mask = np.zeros(output_size, dtype=np.int64)
    mask[:input_height, :input_width] = 1
    return mask
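
# --- Editor's note: illustrative sketch, not part of the original file. ---
# make_pixel_mask marks the valid (unpadded) region of an image placed in a larger canvas,
# e.g. a 2x3 channels-first image padded to a 4x4 canvas:
# >>> image = np.zeros((3, 2, 3))  # (num_channels, height, width)
# >>> make_pixel_mask(image, output_size=(4, 4), input_data_format=ChannelDimension.FIRST)
# array([[1, 1, 1, 0],
#        [1, 1, 1, 0],
#        [0, 0, 0, 0],
#        [0, 0, 0, 0]])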


# inspired by https://github.com/facebookresearch/detr/blob/master/datasets/coco.py#L33
def convert_coco_poly_to_mask(segmentations, height: int, width: int) -> np.ndarray:
    """
    Convert a COCO polygon annotation to a mask.

    Args:
        segmentations (`list[list[float]]`):
            List of polygons, each polygon represented by a list of x-y coordinates.
        height (`int`):
            Height of the mask.
        width (`int`):
            Width of the mask.
    """
    try:
        from pycocotools import mask as coco_mask
    except ImportError:
        raise ImportError("Pycocotools is not installed in your environment.")

    masks = []
    for polygons in segmentations:
        rles = coco_mask.frPyObjects(polygons, height, width)
        mask = coco_mask.decode(rles)
        if len(mask.shape) < 3:
            mask = mask[..., None]
        mask = np.asarray(mask, dtype=np.uint8)
        mask = np.any(mask, axis=2)
        masks.append(mask)
    if masks:
        masks = np.stack(masks, axis=0)
    else:
        masks = np.zeros((0, height, width), dtype=np.uint8)

    return masks


# inspired by https://github.com/facebookresearch/detr/blob/master/datasets/coco.py#L50
def prepare_coco_detection_annotation(
    image,
    target,
    return_segmentation_masks: bool = False,
    input_data_format: Optional[Union[ChannelDimension, str]] = None,
):
    """
    Convert the target in COCO format into the format expected by DETR.
    """
    image_height, image_width = get_image_size(image, channel_dim=input_data_format)

    image_id = target["image_id"]
    image_id = np.asarray([image_id], dtype=np.int64)

    # Get all COCO annotations for the given image.
    annotations = target["annotations"]
    annotations = [obj for obj in annotations if "iscrowd" not in obj or obj["iscrowd"] == 0]

    classes = [obj["category_id"] for obj in annotations]
    classes = np.asarray(classes, dtype=np.int64)

    # for conversion to coco api
    area = np.asarray([obj["area"] for obj in annotations], dtype=np.float32)
    iscrowd = np.asarray([obj.get("iscrowd", 0) for obj in annotations], dtype=np.int64)

    boxes = [obj["bbox"] for obj in annotations]
    # guard against no boxes via resizing
    boxes = np.asarray(boxes, dtype=np.float32).reshape(-1, 4)
    boxes[:, 2:] += boxes[:, :2]
    boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width)
    boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height)

    keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])

    new_target = {}
    new_target["image_id"] = image_id
    new_target["class_labels"] = classes[keep]
    new_target["boxes"] = boxes[keep]
    new_target["area"] = area[keep]
    new_target["iscrowd"] = iscrowd[keep]
    new_target["orig_size"] = np.asarray([int(image_height), int(image_width)], dtype=np.int64)

    if annotations and "keypoints" in annotations[0]:
        keypoints = [obj["keypoints"] for obj in annotations]
        # Converting the filtered keypoints list to a numpy array
        keypoints = np.asarray(keypoints, dtype=np.float32)
        # Apply the keep mask here to filter the relevant annotations
        keypoints = keypoints[keep]
        num_keypoints = keypoints.shape[0]
        keypoints = keypoints.reshape((-1, 3)) if num_keypoints else keypoints
        new_target["keypoints"] = keypoints

    if return_segmentation_masks:
        segmentation_masks = [obj["segmentation"] for obj in annotations]
        masks = convert_coco_poly_to_mask(segmentation_masks, image_height, image_width)
        new_target["masks"] = masks[keep]

    return new_target
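
# --- Editor's note: illustrative sketch, not part of the original file. ---
# A minimal COCO-detection target for a 100x100 image: `bbox` is (x, y, width, height) and is
# converted to corner format, clipped to the image, and filtered to boxes with positive area:
# >>> target = {
# ...     "image_id": 1,
# ...     "annotations": [{"category_id": 7, "area": 400.0, "bbox": [10.0, 20.0, 20.0, 20.0], "iscrowd": 0}],
# ... }
# >>> out = prepare_coco_detection_annotation(np.zeros((3, 100, 100)), target)
# >>> out["boxes"]
# array([[10., 20., 30., 40.]], dtype=float32)
# >>> out["class_labels"]
# array([7])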


def masks_to_boxes(masks: np.ndarray) -> np.ndarray:
    """
    Compute the bounding boxes around the provided panoptic segmentation masks.

    Args:
        masks: masks in format `[number_masks, height, width]` where N is the number of masks

    Returns:
        boxes: bounding boxes in format `[number_masks, 4]` in xyxy format
    """
    if masks.size == 0:
        return np.zeros((0, 4))

    h, w = masks.shape[-2:]
    y = np.arange(0, h, dtype=np.float32)
    x = np.arange(0, w, dtype=np.float32)
    # see https://github.com/pytorch/pytorch/issues/50276
    y, x = np.meshgrid(y, x, indexing="ij")

    x_mask = masks * np.expand_dims(x, axis=0)
    x_max = x_mask.reshape(x_mask.shape[0], -1).max(-1)
    x = np.ma.array(x_mask, mask=~(np.array(masks, dtype=bool)))
    x_min = x.filled(fill_value=1e8)
    x_min = x_min.reshape(x_min.shape[0], -1).min(-1)

    y_mask = masks * np.expand_dims(y, axis=0)
    y_max = y_mask.reshape(x_mask.shape[0], -1).max(-1)
    y = np.ma.array(y_mask, mask=~(np.array(masks, dtype=bool)))
    y_min = y.filled(fill_value=1e8)
    y_min = y_min.reshape(y_min.shape[0], -1).min(-1)

    return np.stack([x_min, y_min, x_max, y_max], 1)
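
# --- Editor's note: illustrative sketch, not part of the original file. ---
# For a single rectangular mask covering rows 2..3 and columns 1..4 of a 6x6 grid,
# masks_to_boxes recovers the enclosing box in (x_min, y_min, x_max, y_max) order:
# >>> mask = np.zeros((1, 6, 6), dtype=np.uint8)
# >>> mask[0, 2:4, 1:5] = 1
# >>> masks_to_boxes(mask)
# array([[1., 2., 4., 3.]], dtype=float32)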


def prepare_coco_panoptic_annotation(
    image: np.ndarray,
    target: dict,
    masks_path: Union[str, pathlib.Path],
    return_masks: bool = True,
    input_data_format: Union[ChannelDimension, str] = None,
) -> dict:
    """
    Prepare a coco panoptic annotation for DETR.
    """
    image_height, image_width = get_image_size(image, channel_dim=input_data_format)
    annotation_path = pathlib.Path(masks_path) / target["file_name"]

    new_target = {}
    new_target["image_id"] = np.asarray([target["image_id"] if "image_id" in target else target["id"]], dtype=np.int64)
    new_target["size"] = np.asarray([image_height, image_width], dtype=np.int64)
    new_target["orig_size"] = np.asarray([image_height, image_width], dtype=np.int64)

    if "segments_info" in target:
        masks = np.asarray(PIL.Image.open(annotation_path), dtype=np.uint32)
        masks = rgb_to_id(masks)

        ids = np.array([segment_info["id"] for segment_info in target["segments_info"]])
        masks = masks == ids[:, None, None]
        masks = masks.astype(np.uint8)
        if return_masks:
            new_target["masks"] = masks
        new_target["boxes"] = masks_to_boxes(masks)
        new_target["class_labels"] = np.array(
            [segment_info["category_id"] for segment_info in target["segments_info"]], dtype=np.int64
        )
        new_target["iscrowd"] = np.asarray(
            [segment_info["iscrowd"] for segment_info in target["segments_info"]], dtype=np.int64
        )
        new_target["area"] = np.asarray(
            [segment_info["area"] for segment_info in target["segments_info"]], dtype=np.float32
        )

    return new_target


def get_segmentation_image(
    masks: np.ndarray, input_size: tuple, target_size: tuple, stuff_equiv_classes, deduplicate=False
):
    h, w = input_size
    final_h, final_w = target_size

    m_id = scipy.special.softmax(masks.transpose(0, 1), -1)

    if m_id.shape[-1] == 0:
        # We didn't detect any mask :(
        m_id = np.zeros((h, w), dtype=np.int64)
    else:
        m_id = m_id.argmax(-1).reshape(h, w)

    if deduplicate:
        # Merge the masks corresponding to the same stuff class
        for equiv in stuff_equiv_classes.values():
            for eq_id in equiv:
                m_id[m_id == eq_id] = equiv[0]

    seg_img = id_to_rgb(m_id)
    seg_img = resize(seg_img, (final_w, final_h), resample=PILImageResampling.NEAREST)
    return seg_img


def get_mask_area(seg_img: np.ndarray, target_size: tuple[int, int], n_classes: int) -> np.ndarray:
    final_h, final_w = target_size
    np_seg_img = seg_img.astype(np.uint8)
    np_seg_img = np_seg_img.reshape(final_h, final_w, 3)
    m_id = rgb_to_id(np_seg_img)
    area = [(m_id == i).sum() for i in range(n_classes)]
    return area


def score_labels_from_class_probabilities(logits: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    probs = scipy.special.softmax(logits, axis=-1)
    labels = probs.argmax(-1, keepdims=True)
    scores = np.take_along_axis(probs, labels, axis=-1)
    scores, labels = scores.squeeze(-1), labels.squeeze(-1)
    return scores, labels


def post_process_panoptic_sample(
    out_logits: np.ndarray,
    masks: np.ndarray,
    boxes: np.ndarray,
    processed_size: tuple[int, int],
    target_size: tuple[int, int],
    is_thing_map: dict,
    threshold=0.85,
) -> dict:
    """
    Converts the output of [`DetrForSegmentation`] into panoptic segmentation predictions for a single sample.

    Args:
        out_logits (`torch.Tensor`):
            The logits for this sample.
        masks (`torch.Tensor`):
            The predicted segmentation masks for this sample.
        boxes (`torch.Tensor`):
            The predicted bounding boxes for this sample. The boxes are in the normalized format `(center_x, center_y,
            width, height)` and values between `[0, 1]`, relative to the size the image (disregarding padding).
        processed_size (`tuple[int, int]`):
            The processed size of the image `(height, width)`, as returned by the preprocessing step i.e. the size
            after data augmentation but before batching.
        target_size (`tuple[int, int]`):
            The target size of the image, `(height, width)` corresponding to the requested final size of the
            prediction.
        is_thing_map (`Dict`):
            A dictionary mapping class indices to a boolean value indicating whether the class is a thing or not.
        threshold (`float`, *optional*, defaults to 0.85):
            The threshold used to binarize the segmentation masks.
    """
    # we filter empty queries and detection below threshold
    scores, labels = score_labels_from_class_probabilities(out_logits)
    keep = (labels != out_logits.shape[-1] - 1) & (scores > threshold)

    cur_scores = scores[keep]
    cur_classes = labels[keep]
    cur_boxes = center_to_corners_format(boxes[keep])

    if len(cur_boxes) != len(cur_classes):
        raise ValueError("Not as many boxes as there are classes")

    cur_masks = masks[keep]
    cur_masks = resize(cur_masks[:, None], processed_size, resample=PILImageResampling.BILINEAR)
    cur_masks = safe_squeeze(cur_masks, 1)
    b, h, w = cur_masks.shape

    # It may be that we have several predicted masks for the same stuff class.
    # In the following, we track the list of masks ids for each stuff class (they are merged later on)
    cur_masks = cur_masks.reshape(b, -1)
    stuff_equiv_classes = defaultdict(list)
    for k, label in enumerate(cur_classes):
        if not is_thing_map[label]:
            stuff_equiv_classes[label].append(k)

    seg_img = get_segmentation_image(cur_masks, processed_size, target_size, stuff_equiv_classes, deduplicate=True)
    area = get_mask_area(cur_masks, processed_size, n_classes=len(cur_scores))

    # We filter out any mask that is too small
    if cur_classes.size() > 0:
        # We know filter empty masks as long as we find some
        filtered_small = np.array([a <= 4 for a in area], dtype=bool)
        while filtered_small.any():
            cur_masks = cur_masks[~filtered_small]
            cur_scores = cur_scores[~filtered_small]
            cur_classes = cur_classes[~filtered_small]
            seg_img = get_segmentation_image(cur_masks, (h, w), target_size, stuff_equiv_classes, deduplicate=True)
            area = get_mask_area(seg_img, target_size, n_classes=len(cur_scores))
            filtered_small = np.array([a <= 4 for a in area], dtype=bool)
    else:
        cur_classes = np.ones((1, 1), dtype=np.int64)

    segments_info = [
        {"id": i, "isthing": is_thing_map[cat], "category_id": int(cat), "area": a}
        for i, (cat, a) in enumerate(zip(cur_classes, area))
    ]
    del cur_classes

    with io.BytesIO() as out:
        PIL.Image.fromarray(seg_img).save(out, format="PNG")
        predictions = {"png_string": out.getvalue(), "segments_info": segments_info}

    return predictions


def resize_annotation(
    annotation: dict[str, Any],
    orig_size: tuple[int, int],
    target_size: tuple[int, int],
    threshold: float = 0.5,
    resample: PILImageResampling = PILImageResampling.NEAREST,
):
    """
    Resizes an annotation to a target size.

    Args:
        annotation (`dict[str, Any]`):
            The annotation dictionary.
        orig_size (`tuple[int, int]`):
            The original size of the input image.
        target_size (`tuple[int, int]`):
            The target size of the image, as returned by the preprocessing `resize` step.
        threshold (`float`, *optional*, defaults to 0.5):
            The threshold used to binarize the segmentation masks.
        resample (`PILImageResampling`, defaults to `PILImageResampling.NEAREST`):
            The resampling filter to use when resizing the masks.
    """
    ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(target_size, orig_size))
    ratio_height, ratio_width = ratios

    new_annotation = {}
    new_annotation["size"] = target_size

    for key, value in annotation.items():
        if key == "boxes":
            boxes = value
            scaled_boxes = boxes * np.asarray([ratio_width, ratio_height, ratio_width, ratio_height], dtype=np.float32)
            new_annotation["boxes"] = scaled_boxes
        elif key == "area":
            area = value
            scaled_area = area * (ratio_width * ratio_height)
            new_annotation["area"] = scaled_area
        elif key == "masks":
            masks = value[:, None]
            masks = np.array([resize(mask, target_size, resample=resample) for mask in masks])
            masks = masks.astype(np.float32)
            masks = masks[:, 0] > threshold
            new_annotation["masks"] = masks
        elif key == "size":
            new_annotation["size"] = target_size
        else:
            new_annotation[key] = value

    return new_annotation
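
# --- Editor's note: illustrative sketch, not part of the original file. ---
# resize_annotation rescales boxes and areas by the height/width ratios between the original
# and resized image, e.g. doubling both sides of a 100x100 image:
# >>> ann = {"boxes": np.array([[10.0, 20.0, 30.0, 40.0]], dtype=np.float32), "area": np.array([200.0])}
# >>> resized = resize_annotation(ann, orig_size=(100, 100), target_size=(200, 200))
# >>> resized["boxes"], resized["area"]
# (array([[20., 40., 60., 80.]], dtype=float32), array([800.]))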


# TODO - (Amy) make compatible with other frameworks
def binary_mask_to_rle(mask):
    """
    Converts given binary mask of shape `(height, width)` to the run-length encoding (RLE) format.

    Args:
        mask (`torch.Tensor` or `numpy.array`):
            A binary mask tensor of shape `(height, width)` where 0 denotes background and 1 denotes the target
            segment_id or class_id.
    Returns:
        `List`: Run-length encoded list of the binary mask. Refer to COCO API for more information about the RLE
        format.
    """
    if is_torch_tensor(mask):
        mask = mask.numpy()

    pixels = mask.flatten()
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return list(runs)
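
# --- Editor's note: illustrative sketch, not part of the original file. ---
# The RLE produced here is a flat list of alternating (start, run-length) values over the
# flattened mask, with 1-based start indices. For a 2x2 mask whose second row is foreground:
# >>> binary_mask_to_rle(np.array([[0, 0], [1, 1]]))
# [3, 2]  # run starts at flat position 3 and is 2 pixels long (values may be numpy scalars)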


# TODO - (Amy) make compatible with other frameworks
def convert_segmentation_to_rle(segmentation):
    """
    Converts given segmentation map of shape `(height, width)` to the run-length encoding (RLE) format.

    Args:
        segmentation (`torch.Tensor` or `numpy.array`):
            A segmentation map of shape `(height, width)` where each value denotes a segment or class id.
    Returns:
        `list[List]`: A list of lists, where each list is the run-length encoding of a segment / class id.
    """
    segment_ids = torch.unique(segmentation)

    run_length_encodings = []
    for idx in segment_ids:
        mask = torch.where(segmentation == idx, 1, 0)
        rle = binary_mask_to_rle(mask)
        run_length_encodings.append(rle)

    return run_length_encodings


def remove_low_and_no_objects(masks, scores, labels, object_mask_threshold, num_labels):
    """
    Binarize the given masks using `object_mask_threshold`, it returns the associated values of `masks`, `scores` and
    `labels`.

    Args:
        masks (`torch.Tensor`):
            A tensor of shape `(num_queries, height, width)`.
        scores (`torch.Tensor`):
            A tensor of shape `(num_queries)`.
        labels (`torch.Tensor`):
            A tensor of shape `(num_queries)`.
        object_mask_threshold (`float`):
            A number between 0 and 1 used to binarize the masks.
    Raises:
        `ValueError`: Raised when the first dimension doesn't match in all input tensors.
    Returns:
        `tuple[`torch.Tensor`, `torch.Tensor`, `torch.Tensor`]`: The `masks`, `scores` and `labels` without the region
        < `object_mask_threshold`.
    """
    if not (masks.shape[0] == scores.shape[0] == labels.shape[0]):
        raise ValueError("mask, scores and labels must have the same shape!")

    to_keep = labels.ne(num_labels) & (scores > object_mask_threshold)

    return masks[to_keep], scores[to_keep], labels[to_keep]


def check_segment_validity(mask_labels, mask_probs, k, mask_threshold=0.5, overlap_mask_area_threshold=0.8):
    # Get the mask associated with the k class
    mask_k = mask_labels == k
    mask_k_area = mask_k.sum()

    # Compute the area of all the stuff in query k
    original_area = (mask_probs[k] >= mask_threshold).sum()
    mask_exists = mask_k_area > 0 and original_area > 0

    # Eliminate disconnected tiny segments
    if mask_exists:
        area_ratio = mask_k_area / original_area
        if not area_ratio.item() > overlap_mask_area_threshold:
            mask_exists = False

    return mask_exists, mask_k


def compute_segments(
    mask_probs,
    pred_scores,
    pred_labels,
    mask_threshold: float = 0.5,
    overlap_mask_area_threshold: float = 0.8,
    label_ids_to_fuse: Optional[set[int]] = None,
    target_size: Optional[tuple[int, int]] = None,
):
    height = mask_probs.shape[1] if target_size is None else target_size[0]
    width = mask_probs.shape[2] if target_size is None else target_size[1]

    segmentation = torch.zeros((height, width), dtype=torch.int32, device=mask_probs.device)
    segments: list[dict] = []

    if target_size is not None:
        mask_probs = nn.functional.interpolate(
            mask_probs.unsqueeze(0), size=target_size, mode="bilinear", align_corners=False
        )[0]

    current_segment_id = 0

    # Weigh each mask by its prediction score
    mask_probs *= pred_scores.view(-1, 1, 1)
    mask_labels = mask_probs.argmax(0)  # [height, width]

    # Keep track of instances of each class
    stuff_memory_list: dict[str, int] = {}
    for k in range(pred_labels.shape[0]):
        pred_class = pred_labels[k].item()
        should_fuse = pred_class in label_ids_to_fuse

        # Check if mask exists and large enough to be a segment
        mask_exists, mask_k = check_segment_validity(
            mask_labels, mask_probs, k, mask_threshold, overlap_mask_area_threshold
        )

        if mask_exists:
            if pred_class in stuff_memory_list:
                current_segment_id = stuff_memory_list[pred_class]
            else:
                current_segment_id += 1

            # Add current object segment to final segmentation map
            segmentation[mask_k] = current_segment_id
            segment_score = round(pred_scores[k].item(), 6)
            segments.append(
                {
                    "id": current_segment_id,
                    "label_id": pred_class,
                    "was_fused": should_fuse,
                    "score": segment_score,
                }
            )
            if should_fuse:
                stuff_memory_list[pred_class] = current_segment_id

    return segmentation, segments


@requires(backends=("vision",))
class DetrImageProcessor(BaseImageProcessor):
    r"""
    Constructs a Detr image processor.

    Args:
        format (`str`, *optional*, defaults to `"coco_detection"`):
            Data format of the annotations. One of "coco_detection" or "coco_panoptic".
        do_resize (`bool`, *optional*, defaults to `True`):
            Controls whether to resize the image's `(height, width)` dimensions to the specified `size`. Can be
            overridden by the `do_resize` parameter in the `preprocess` method.
        size (`dict[str, int]` *optional*, defaults to `{"shortest_edge": 800, "longest_edge": 1333}`):
            Size of the image's `(height, width)` dimensions after resizing. Can be overridden by the `size` parameter
            in the `preprocess` method. Available options are:
                - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
                  Do NOT keep the aspect ratio.
                - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
                  the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
                  less or equal to `longest_edge`.
                - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
                  aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
                  `max_width`.
        resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
            Resampling filter to use if resizing the image.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Controls whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
            `do_rescale` parameter in the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
            `preprocess` method.
        do_normalize (`bool`, *optional*, defaults to True):
            Controls whether to normalize the image. Can be overridden by the `do_normalize` parameter in the
            `preprocess` method.
        image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
            Mean values to use when normalizing the image. Can be a single value or a list of values, one for each
            channel. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
            Standard deviation values to use when normalizing the image. Can be a single value or a list of values, one
            for each channel. Can be overridden by the `image_std` parameter in the `preprocess` method.
        do_convert_annotations (`bool`, *optional*, defaults to `True`):
            Controls whether to convert the annotations to the format expected by the DETR model. Converts the
            bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`.
            Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method.
        do_pad (`bool`, *optional*, defaults to `True`):
            Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess`
            method. If `True`, padding will be applied to the bottom and right of the image with zeros.
            If `pad_size` is provided, the image will be padded to the specified dimensions.
            Otherwise, the image will be padded to the maximum height and width of the batch.
        pad_size (`dict[str, int]`, *optional*):
            The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
            provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
            height and width in the batch.
    """

    model_input_names = ["pixel_values", "pixel_mask"]

    def __init__(
        self,
        format: Union[str, AnnotationFormat] = AnnotationFormat.COCO_DETECTION,
        do_resize: bool = True,
        size: Optional[dict[str, int]] = None,
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, list[float]]] = None,
        image_std: Optional[Union[float, list[float]]] = None,
        do_convert_annotations: Optional[bool] = None,
        do_pad: bool = True,
        pad_size: Optional[dict[str, int]] = None,
        **kwargs,
    ) -> None:
        if "pad_and_return_pixel_mask" in kwargs:
            do_pad = kwargs.pop("pad_and_return_pixel_mask")

        if "max_size" in kwargs:
            logger.warning_once(
                "The `max_size` parameter is deprecated and will be removed in v4.26. "
                "Please specify in `size['longest_edge'] instead`.",
            )
            max_size = kwargs.pop("max_size")
        else:
            max_size = None if size is None else 1333

        size = size if size is not None else {"shortest_edge": 800, "longest_edge": 1333}
        size = get_size_dict(size, max_size=max_size, default_to_square=False)

        # Backwards compatibility
        if do_convert_annotations is None:
            do_convert_annotations = do_normalize

        super().__init__(**kwargs)
        self.format = format
        self.do_resize = do_resize
        self.size = size
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.do_convert_annotations = do_convert_annotations
        self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
        self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
        self.do_pad = do_pad
        self.pad_size = pad_size
        self._valid_processor_keys = [
            "images",
            "annotations",
            "return_segmentation_masks",
            "masks_path",
            "do_resize",
            "size",
            "resample",
            "do_rescale",
            "rescale_factor",
            "do_normalize",
            "do_convert_annotations",
            "image_mean",
            "image_std",
            "do_pad",
            "pad_size",
            "format",
            "return_tensors",
            "data_format",
            "input_data_format",
        ]

    @classmethod
    def from_dict(cls, image_processor_dict: dict[str, Any], **kwargs):
        """
        Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is
        created using from_dict and kwargs e.g. `DetrImageProcessor.from_pretrained(checkpoint, size=600,
        max_size=800)`
        """
        image_processor_dict = image_processor_dict.copy()
        if "max_size" in kwargs:
            image_processor_dict["max_size"] = kwargs.pop("max_size")
        if "pad_and_return_pixel_mask" in kwargs:
            image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask")
        return super().from_dict(image_processor_dict, **kwargs)
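
    # --- Editor's note: illustrative usage sketch, not part of the original file. ---
    # Typical end-to-end use of this processor (assuming PyTorch is installed and `image`
    # is an RGB PIL image supplied by the caller):
    # >>> from transformers import DetrImageProcessor
    # >>> processor = DetrImageProcessor()  # defaults: shortest_edge=800, longest_edge=1333
    # >>> inputs = processor(images=image, return_tensors="pt")
    # >>> sorted(inputs.keys())
    # ['pixel_mask', 'pixel_values']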
| 928 |
+
|
| 929 |
+
    def prepare_annotation(
        self,
        image: np.ndarray,
        target: dict,
        format: Optional[AnnotationFormat] = None,
        return_segmentation_masks: Optional[bool] = None,
        masks_path: Optional[Union[str, pathlib.Path]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> dict:
        """
        Prepare an annotation for feeding into DETR model.
        """
        format = format if format is not None else self.format

        if format == AnnotationFormat.COCO_DETECTION:
            return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks
            target = prepare_coco_detection_annotation(
                image, target, return_segmentation_masks, input_data_format=input_data_format
            )
        elif format == AnnotationFormat.COCO_PANOPTIC:
            return_segmentation_masks = True if return_segmentation_masks is None else return_segmentation_masks
            target = prepare_coco_panoptic_annotation(
                image,
                target,
                masks_path=masks_path,
                return_masks=return_segmentation_masks,
                input_data_format=input_data_format,
            )
        else:
            raise ValueError(f"Format {format} is not supported.")
        return target

    def resize(
        self,
        image: np.ndarray,
        size: dict[str, int],
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        data_format: Optional[ChannelDimension] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Resize the image to the given size. Size can be `min_size` (scalar) or `(height, width)` tuple. If size is an
        int, smaller edge of the image will be matched to this number.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`dict[str, int]`):
                Size of the image's `(height, width)` dimensions after resizing. Available options are:
                    - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
                      Do NOT keep the aspect ratio.
                    - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
                      the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
                      less or equal to `longest_edge`.
                    - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
                      aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
                      `max_width`.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
                Resampling filter to use if resizing the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the output image. If unset, the channel dimension format of the input
                image is used.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.
        """
        if "max_size" in kwargs:
            logger.warning_once(
                "The `max_size` parameter is deprecated and will be removed in v4.26. "
                "Please specify in `size['longest_edge']` instead.",
            )
            max_size = kwargs.pop("max_size")
        else:
            max_size = None
        size = get_size_dict(size, max_size=max_size, default_to_square=False)
        if "shortest_edge" in size and "longest_edge" in size:
            new_size = get_resize_output_image_size(
                image, size["shortest_edge"], size["longest_edge"], input_data_format=input_data_format
            )
        elif "max_height" in size and "max_width" in size:
            new_size = get_image_size_for_max_height_width(
                image, size["max_height"], size["max_width"], input_data_format=input_data_format
            )
        elif "height" in size and "width" in size:
            new_size = (size["height"], size["width"])
        else:
            raise ValueError(
                "Size must contain 'height' and 'width' keys, 'shortest_edge' and 'longest_edge' keys, or"
                f" 'max_height' and 'max_width' keys. Got {size.keys()}."
            )
        image = resize(
            image,
            size=new_size,
            resample=resample,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )
        return image

    def resize_annotation(
        self,
        annotation,
        orig_size,
        size,
        resample: PILImageResampling = PILImageResampling.NEAREST,
    ) -> dict:
        """
        Resize the annotation to match the resized image. If size is an int, smaller edge of the mask will be matched
        to this number.
        """
        return resize_annotation(annotation, orig_size=orig_size, target_size=size, resample=resample)

    # TODO (Amy) - update to use `rescale_factor` instead of `scale`
    def rescale(
        self,
        image: np.ndarray,
        rescale_factor: float,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """
        Rescale the image by the given factor. image = image * rescale_factor.

        Args:
            image (`np.ndarray`):
                Image to rescale.
            rescale_factor (`float`):
                The value to use for rescaling.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the output image. If unset, the channel dimension format of the input
                image is used. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            input_data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the input image. If unset, is inferred from the input image. Can be
                one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
        """
        return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format)

    def normalize_annotation(self, annotation: dict, image_size: tuple[int, int]) -> dict:
        """
        Normalize the boxes in the annotation from `[top_left_x, top_left_y, bottom_right_x, bottom_right_y]` to
        `[center_x, center_y, width, height]` format and from absolute to relative pixel values.
        """
        return normalize_annotation(annotation, image_size=image_size)

    def _update_annotation_for_padded_image(
        self,
        annotation: dict,
        input_image_size: tuple[int, int],
        output_image_size: tuple[int, int],
        padding,
        update_bboxes,
    ) -> dict:
        """
        Update the annotation for a padded image.
        """
        new_annotation = {}
        new_annotation["size"] = output_image_size

        for key, value in annotation.items():
            if key == "masks":
                masks = value
                masks = pad(
                    masks,
                    padding,
                    mode=PaddingMode.CONSTANT,
                    constant_values=0,
                    input_data_format=ChannelDimension.FIRST,
                )
                masks = safe_squeeze(masks, 1)
                new_annotation["masks"] = masks
            elif key == "boxes" and update_bboxes:
                boxes = value
                boxes *= np.asarray(
                    [
                        input_image_size[1] / output_image_size[1],
                        input_image_size[0] / output_image_size[0],
                        input_image_size[1] / output_image_size[1],
                        input_image_size[0] / output_image_size[0],
                    ]
                )
                new_annotation["boxes"] = boxes
            elif key == "size":
                new_annotation["size"] = output_image_size
            else:
                new_annotation[key] = value
        return new_annotation

    def _pad_image(
        self,
        image: np.ndarray,
        output_size: tuple[int, int],
        annotation: Optional[dict[str, Any]] = None,
        constant_values: Union[float, Iterable[float]] = 0,
        data_format: Optional[ChannelDimension] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        update_bboxes: bool = True,
    ) -> np.ndarray:
        """
        Pad an image with zeros to the given size.
        """
        input_height, input_width = get_image_size(image, channel_dim=input_data_format)
        output_height, output_width = output_size

        pad_bottom = output_height - input_height
        pad_right = output_width - input_width
        padding = ((0, pad_bottom), (0, pad_right))
        padded_image = pad(
            image,
            padding,
            mode=PaddingMode.CONSTANT,
            constant_values=constant_values,
            data_format=data_format,
            input_data_format=input_data_format,
        )
        if annotation is not None:
            annotation = self._update_annotation_for_padded_image(
                annotation, (input_height, input_width), (output_height, output_width), padding, update_bboxes
            )
        return padded_image, annotation

    def pad(
        self,
        images: list[np.ndarray],
        annotations: Optional[Union[AnnotationType, list[AnnotationType]]] = None,
        constant_values: Union[float, Iterable[float]] = 0,
        return_pixel_mask: bool = True,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        update_bboxes: bool = True,
        pad_size: Optional[dict[str, int]] = None,
    ) -> BatchFeature:
        """
        Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width
        in the batch and optionally returns their corresponding pixel mask.

        Args:
            images (list[`np.ndarray`]):
                Images to pad.
            annotations (`AnnotationType` or `list[AnnotationType]`, *optional*):
                Annotations to transform according to the padding that is applied to the images.
            constant_values (`float` or `Iterable[float]`, *optional*):
                The value to use for the padding if `mode` is `"constant"`.
            return_pixel_mask (`bool`, *optional*, defaults to `True`):
                Whether to return a pixel mask.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                    - Unset: Return a list of `np.ndarray`.
                    - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                    - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                    - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.
            update_bboxes (`bool`, *optional*, defaults to `True`):
                Whether to update the bounding boxes in the annotations to match the padded images. If the
                bounding boxes have not been converted to relative coordinates and `(centre_x, centre_y, width, height)`
                format, the bounding boxes will not be updated.
            pad_size (`dict[str, int]`, *optional*):
                The size `{"height": int, "width": int}` to pad the images to. Must be larger than any image size
                provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
                height and width in the batch.
        """
        pad_size = pad_size if pad_size is not None else self.pad_size
        if pad_size is not None:
            padded_size = (pad_size["height"], pad_size["width"])
        else:
            padded_size = get_max_height_width(images, input_data_format=input_data_format)

        annotation_list = annotations if annotations is not None else [None] * len(images)
        padded_images = []
        padded_annotations = []
        for image, annotation in zip(images, annotation_list):
            padded_image, padded_annotation = self._pad_image(
                image,
                padded_size,
                annotation,
                constant_values=constant_values,
                data_format=data_format,
                input_data_format=input_data_format,
                update_bboxes=update_bboxes,
            )
            padded_images.append(padded_image)
            padded_annotations.append(padded_annotation)

        data = {"pixel_values": padded_images}

        if return_pixel_mask:
            masks = [
                make_pixel_mask(image=image, output_size=padded_size, input_data_format=input_data_format)
                for image in images
            ]
            data["pixel_mask"] = masks

        encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)

        if annotations is not None:
            encoded_inputs["labels"] = [
                BatchFeature(annotation, tensor_type=return_tensors) for annotation in padded_annotations
            ]

        return encoded_inputs

    def preprocess(
        self,
        images: ImageInput,
        annotations: Optional[Union[AnnotationType, list[AnnotationType]]] = None,
        return_segmentation_masks: Optional[bool] = None,
        masks_path: Optional[Union[str, pathlib.Path]] = None,
        do_resize: Optional[bool] = None,
        size: Optional[dict[str, int]] = None,
        resample=None,  # PILImageResampling
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[Union[int, float]] = None,
        do_normalize: Optional[bool] = None,
        do_convert_annotations: Optional[bool] = None,
        image_mean: Optional[Union[float, list[float]]] = None,
        image_std: Optional[Union[float, list[float]]] = None,
        do_pad: Optional[bool] = None,
        format: Optional[Union[str, AnnotationFormat]] = None,
        return_tensors: Optional[Union[TensorType, str]] = None,
        data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        pad_size: Optional[dict[str, int]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Preprocess an image or a batch of images so that it can be used by the model.

        Args:
            images (`ImageInput`):
                Image or batch of images to preprocess. Expects a single or batch of images with pixel values ranging
                from 0 to 255. If passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            annotations (`AnnotationType` or `list[AnnotationType]`, *optional*):
                List of annotations associated with the image or batch of images. If annotation is for object
                detection, the annotations should be a dictionary with the following keys:
                - "image_id" (`int`): The image id.
                - "annotations" (`list[Dict]`): List of annotations for an image. Each annotation should be a
                  dictionary. An image can have no annotations, in which case the list should be empty.
                If annotation is for segmentation, the annotations should be a dictionary with the following keys:
                - "image_id" (`int`): The image id.
                - "segments_info" (`list[Dict]`): List of segments for an image. Each segment should be a dictionary.
                  An image can have no segments, in which case the list should be empty.
                - "file_name" (`str`): The file name of the image.
            return_segmentation_masks (`bool`, *optional*, defaults to self.return_segmentation_masks):
                Whether to return segmentation masks.
            masks_path (`str` or `pathlib.Path`, *optional*):
                Path to the directory containing the segmentation masks.
            do_resize (`bool`, *optional*, defaults to self.do_resize):
                Whether to resize the image.
            size (`dict[str, int]`, *optional*, defaults to self.size):
                Size of the image's `(height, width)` dimensions after resizing. Available options are:
                    - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
                      Do NOT keep the aspect ratio.
                    - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
                      the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
                      less or equal to `longest_edge`.
                    - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
                      aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
                      `max_width`.
            resample (`PILImageResampling`, *optional*, defaults to self.resample):
                Resampling filter to use when resizing the image.
            do_rescale (`bool`, *optional*, defaults to self.do_rescale):
                Whether to rescale the image.
            rescale_factor (`float`, *optional*, defaults to self.rescale_factor):
                Rescale factor to use when rescaling the image.
            do_normalize (`bool`, *optional*, defaults to self.do_normalize):
                Whether to normalize the image.
            do_convert_annotations (`bool`, *optional*, defaults to self.do_convert_annotations):
                Whether to convert the annotations to the format expected by the model. Converts the bounding
                boxes from the format `(top_left_x, top_left_y, width, height)` to `(center_x, center_y, width, height)`
                and in relative coordinates.
            image_mean (`float` or `list[float]`, *optional*, defaults to self.image_mean):
                Mean to use when normalizing the image.
            image_std (`float` or `list[float]`, *optional*, defaults to self.image_std):
                Standard deviation to use when normalizing the image.
            do_pad (`bool`, *optional*, defaults to self.do_pad):
                Whether to pad the image. If `True`, padding will be applied to the bottom and right of
                the image with zeros. If `pad_size` is provided, the image will be padded to the specified
                dimensions. Otherwise, the image will be padded to the maximum height and width of the batch.
            format (`str` or `AnnotationFormat`, *optional*, defaults to self.format):
                Format of the annotations.
            return_tensors (`str` or `TensorType`, *optional*, defaults to self.return_tensors):
                Type of tensors to return. If `None`, will return the list of images.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
            pad_size (`dict[str, int]`, *optional*):
                The size `{"height": int, "width": int}` to pad the images to. Must be larger than any image size
                provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
                height and width in the batch.
        """
        if "pad_and_return_pixel_mask" in kwargs:
            logger.warning_once(
                "The `pad_and_return_pixel_mask` argument is deprecated and will be removed in a future version, "
                "use `do_pad` instead."
            )
            do_pad = kwargs.pop("pad_and_return_pixel_mask")

        if "max_size" in kwargs:
            logger.warning_once(
                "The `max_size` argument is deprecated and will be removed in a future version, use"
                " `size['longest_edge']` instead."
            )
            size = kwargs.pop("max_size")

        do_resize = self.do_resize if do_resize is None else do_resize
        size = self.size if size is None else size
        size = get_size_dict(size=size, default_to_square=False)
        resample = self.resample if resample is None else resample
        do_rescale = self.do_rescale if do_rescale is None else do_rescale
        rescale_factor = self.rescale_factor if rescale_factor is None else rescale_factor
        do_normalize = self.do_normalize if do_normalize is None else do_normalize
        image_mean = self.image_mean if image_mean is None else image_mean
        image_std = self.image_std if image_std is None else image_std
        do_convert_annotations = (
            self.do_convert_annotations if do_convert_annotations is None else do_convert_annotations
        )
        do_pad = self.do_pad if do_pad is None else do_pad
        pad_size = self.pad_size if pad_size is None else pad_size
        format = self.format if format is None else format

        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )
        validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)

        # Here, the pad() method pads to the maximum of (width, height). It does not need to be validated.
        validate_preprocess_arguments(
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_resize=do_resize,
            size=size,
            resample=resample,
        )

        if annotations is not None and isinstance(annotations, dict):
            annotations = [annotations]

        if annotations is not None and len(images) != len(annotations):
            raise ValueError(
                f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match."
            )

        format = AnnotationFormat(format)
        if annotations is not None:
            validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations)

        if (
            masks_path is not None
            and format == AnnotationFormat.COCO_PANOPTIC
            and not isinstance(masks_path, (pathlib.Path, str))
        ):
            raise ValueError(
                "The path to the directory containing the mask PNG files should be provided as a"
                f" `pathlib.Path` or string object, but is {type(masks_path)} instead."
            )

        # All transformations expect numpy arrays
        images = [to_numpy_array(image) for image in images]

        if do_rescale and is_scaled_image(images[0]):
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        # prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image)
        if annotations is not None:
            prepared_images = []
            prepared_annotations = []
            for image, target in zip(images, annotations):
                target = self.prepare_annotation(
                    image,
                    target,
                    format,
                    return_segmentation_masks=return_segmentation_masks,
                    masks_path=masks_path,
                    input_data_format=input_data_format,
                )
                prepared_images.append(image)
                prepared_annotations.append(target)
            images = prepared_images
            annotations = prepared_annotations
            del prepared_images, prepared_annotations

        # transformations
        if do_resize:
            if annotations is not None:
                resized_images, resized_annotations = [], []
                for image, target in zip(images, annotations):
                    orig_size = get_image_size(image, input_data_format)
                    resized_image = self.resize(
                        image, size=size, resample=resample, input_data_format=input_data_format
                    )
                    resized_annotation = self.resize_annotation(
                        target, orig_size, get_image_size(resized_image, input_data_format)
                    )
                    resized_images.append(resized_image)
                    resized_annotations.append(resized_annotation)
                images = resized_images
                annotations = resized_annotations
                del resized_images, resized_annotations
            else:
                images = [
                    self.resize(image, size=size, resample=resample, input_data_format=input_data_format)
                    for image in images
                ]

        if do_rescale:
            images = [self.rescale(image, rescale_factor, input_data_format=input_data_format) for image in images]

        if do_normalize:
            images = [
                self.normalize(image, image_mean, image_std, input_data_format=input_data_format) for image in images
            ]

        if do_convert_annotations and annotations is not None:
            annotations = [
                self.normalize_annotation(annotation, get_image_size(image, input_data_format))
                for annotation, image in zip(annotations, images)
            ]

        if do_pad:
            # Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...}
            encoded_inputs = self.pad(
                images,
                annotations=annotations,
                return_pixel_mask=True,
                data_format=data_format,
                input_data_format=input_data_format,
                update_bboxes=do_convert_annotations,
                return_tensors=return_tensors,
                pad_size=pad_size,
            )
        else:
            images = [
                to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
                for image in images
            ]
            encoded_inputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
            if annotations is not None:
                encoded_inputs["labels"] = [
                    BatchFeature(annotation, tensor_type=return_tensors) for annotation in annotations
                ]

        return encoded_inputs

    # POSTPROCESSING METHODS - TODO: add support for other frameworks
    # inspired by https://github.com/facebookresearch/detr/blob/master/models/detr.py#L258
    def post_process(self, outputs, target_sizes):
        """
        Converts the raw output of [`DetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
        bottom_right_x, bottom_right_y) format. Only supports PyTorch.

        Args:
            outputs ([`DetrObjectDetectionOutput`]):
                Raw outputs of the model.
            target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
                Tensor containing the size (height, width) of each image of the batch. For evaluation, this must be the
                original image size (before any data augmentation). For visualization, this should be the image size
                after data augmentation, but before padding.
        Returns:
            `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
            in the batch as predicted by the model.
        """
        logger.warning_once(
            "`post_process` is deprecated and will be removed in v5 of Transformers, please use"
            " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.",
        )

        out_logits, out_bbox = outputs.logits, outputs.pred_boxes

        if len(out_logits) != len(target_sizes):
            raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
        if target_sizes.shape[1] != 2:
            raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")

        prob = nn.functional.softmax(out_logits, -1)
        scores, labels = prob[..., :-1].max(-1)

        # convert to [x0, y0, x1, y1] format
        boxes = center_to_corners_format(out_bbox)
        # and from relative [0, 1] to absolute [0, height] coordinates
        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
        boxes = boxes * scale_fct[:, None, :]

        results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)]
        return results

    def post_process_segmentation(self, outputs, target_sizes, threshold=0.9, mask_threshold=0.5):
        """
        Converts the output of [`DetrForSegmentation`] into image segmentation predictions. Only supports PyTorch.

        Args:
            outputs ([`DetrSegmentationOutput`]):
                Raw outputs of the model.
            target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `list[Tuple]` of length `batch_size`):
                Torch Tensor (or list) corresponding to the requested final size (h, w) of each prediction.
            threshold (`float`, *optional*, defaults to 0.9):
                Threshold to use to filter out queries.
            mask_threshold (`float`, *optional*, defaults to 0.5):
                Threshold to use when turning the predicted masks into binary values.
        Returns:
            `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, and masks for an image
            in the batch as predicted by the model.
        """
        logger.warning_once(
            "`post_process_segmentation` is deprecated and will be removed in v5 of Transformers, please use"
            " `post_process_semantic_segmentation`.",
        )
        out_logits, raw_masks = outputs.logits, outputs.pred_masks
        empty_label = out_logits.shape[-1] - 1
        preds = []

        def to_tuple(tup):
            if isinstance(tup, tuple):
                return tup
            return tuple(tup.tolist())

        for cur_logits, cur_masks, size in zip(out_logits, raw_masks, target_sizes):
            # we filter empty queries and detection below threshold
            cur_scores, cur_labels = cur_logits.softmax(-1).max(-1)
            keep = cur_labels.ne(empty_label) & (cur_scores > threshold)
            cur_scores = cur_scores[keep]
            cur_labels = cur_labels[keep]
            cur_masks = cur_masks[keep]
            cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
            cur_masks = (cur_masks.sigmoid() > mask_threshold) * 1

            predictions = {"scores": cur_scores, "labels": cur_labels, "masks": cur_masks}
            preds.append(predictions)
        return preds

    # inspired by https://github.com/facebookresearch/detr/blob/master/models/segmentation.py#L218
    def post_process_instance(self, results, outputs, orig_target_sizes, max_target_sizes, threshold=0.5):
        """
        Converts the output of [`DetrForSegmentation`] into actual instance segmentation predictions. Only supports
        PyTorch.

        Args:
            results (`list[Dict]`):
                Results list obtained by [`~DetrImageProcessor.post_process`], to which "masks" results will be added.
            outputs ([`DetrSegmentationOutput`]):
                Raw outputs of the model.
            orig_target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
                Tensor containing the size (h, w) of each image of the batch. For evaluation, this must be the original
                image size (before any data augmentation).
            max_target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
                Tensor containing the maximum size (h, w) of each image of the batch. For evaluation, this must be the
                original image size (before any data augmentation).
            threshold (`float`, *optional*, defaults to 0.5):
                Threshold to use when turning the predicted masks into binary values.
        Returns:
            `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, boxes and masks for an
            image in the batch as predicted by the model.
        """
        logger.warning_once(
            "`post_process_instance` is deprecated and will be removed in v5 of Transformers, please use"
            " `post_process_instance_segmentation`.",
        )

        if len(orig_target_sizes) != len(max_target_sizes):
            raise ValueError("Make sure to pass in as many orig_target_sizes as max_target_sizes")
        max_h, max_w = max_target_sizes.max(0)[0].tolist()
        outputs_masks = outputs.pred_masks.squeeze(2)
        outputs_masks = nn.functional.interpolate(
            outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False
        )
        outputs_masks = (outputs_masks.sigmoid() > threshold).cpu()

        for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)):
            img_h, img_w = t[0], t[1]
            results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1)
            results[i]["masks"] = nn.functional.interpolate(
                results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest"
            ).byte()

        return results

    # inspired by https://github.com/facebookresearch/detr/blob/master/models/segmentation.py#L241
    def post_process_panoptic(self, outputs, processed_sizes, target_sizes=None, is_thing_map=None, threshold=0.85):
        """
        Converts the output of [`DetrForSegmentation`] into actual panoptic predictions. Only supports PyTorch.

        Args:
            outputs ([`DetrSegmentationOutput`]):
                Raw outputs of the model.
            processed_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `list[Tuple]` of length `batch_size`):
                Torch Tensor (or list) containing the size (h, w) of each image of the batch, i.e. the size after data
                augmentation but before batching.
            target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `list[Tuple]` of length `batch_size`, *optional*):
                Torch Tensor (or list) corresponding to the requested final size `(height, width)` of each prediction.
                If left to None, it will default to the `processed_sizes`.
            is_thing_map (`torch.Tensor` of shape `(batch_size, 2)`, *optional*):
                Dictionary mapping class indices to either True or False, depending on whether or not they are a thing.
                If not set, defaults to the `is_thing_map` of COCO panoptic.
            threshold (`float`, *optional*, defaults to 0.85):
                Threshold to use to filter out queries.
        Returns:
            `list[Dict]`: A list of dictionaries, each dictionary containing a PNG string and segments_info values for
            an image in the batch as predicted by the model.
        """
        logger.warning_once(
            "`post_process_panoptic` is deprecated and will be removed in v5 of Transformers, please use"
            " `post_process_panoptic_segmentation`.",
        )
        if target_sizes is None:
            target_sizes = processed_sizes
        if len(processed_sizes) != len(target_sizes):
            raise ValueError("Make sure to pass in as many processed_sizes as target_sizes")

        if is_thing_map is None:
            # default to is_thing_map of COCO panoptic
            is_thing_map = {i: i <= 90 for i in range(201)}

        out_logits, raw_masks, raw_boxes = outputs.logits, outputs.pred_masks, outputs.pred_boxes
        if not len(out_logits) == len(raw_masks) == len(target_sizes):
            raise ValueError(
                "Make sure that you pass in as many target sizes as the batch dimension of the logits and masks"
            )
        empty_label = out_logits.shape[-1] - 1
        preds = []

        def to_tuple(tup):
            if isinstance(tup, tuple):
                return tup
            return tuple(tup.tolist())

        for cur_logits, cur_masks, cur_boxes, size, target_size in zip(
            out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes
        ):
            # we filter empty queries and detection below threshold
            cur_scores, cur_labels = cur_logits.softmax(-1).max(-1)
            keep = cur_labels.ne(empty_label) & (cur_scores > threshold)
            cur_scores = cur_scores[keep]
            cur_labels = cur_labels[keep]
            cur_masks = cur_masks[keep]
            cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
            cur_boxes = center_to_corners_format(cur_boxes[keep])

            h, w = cur_masks.shape[-2:]
            if len(cur_boxes) != len(cur_labels):
                raise ValueError("Not as many boxes as there are classes")

            # It may be that we have several predicted masks for the same stuff class.
            # In the following, we track the list of masks ids for each stuff class (they are merged later on)
            cur_masks = cur_masks.flatten(1)
            stuff_equiv_classes = defaultdict(lambda: [])
            for k, label in enumerate(cur_labels):
                if not is_thing_map[label.item()]:
                    stuff_equiv_classes[label.item()].append(k)

            def get_ids_area(masks, scores, dedup=False):
                # This helper function creates the final panoptic segmentation image
                # It also returns the area of the masks that appears on the image

                m_id = masks.transpose(0, 1).softmax(-1)

                if m_id.shape[-1] == 0:
                    # We didn't detect any mask :(
                    m_id = torch.zeros((h, w), dtype=torch.long, device=m_id.device)
                else:
                    m_id = m_id.argmax(-1).view(h, w)

                if dedup:
                    # Merge the masks corresponding to the same stuff class
                    for equiv in stuff_equiv_classes.values():
                        if len(equiv) > 1:
                            for eq_id in equiv:
                                m_id.masked_fill_(m_id.eq(eq_id), equiv[0])

                final_h, final_w = to_tuple(target_size)

                seg_img = PIL.Image.fromarray(id_to_rgb(m_id.view(h, w).cpu().numpy()))
                seg_img = seg_img.resize(size=(final_w, final_h), resample=PILImageResampling.NEAREST)

                np_seg_img = torch.ByteTensor(torch.ByteStorage.from_buffer(seg_img.tobytes()))
                np_seg_img = np_seg_img.view(final_h, final_w, 3)
                np_seg_img = np_seg_img.numpy()

                m_id = torch.from_numpy(rgb_to_id(np_seg_img))

                area = []
                for i in range(len(scores)):
                    area.append(m_id.eq(i).sum().item())
                return area, seg_img

            area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True)
            if cur_labels.numel() > 0:
                # We now filter out empty masks, as long as we keep finding some
                while True:
                    filtered_small = torch.as_tensor(
                        [area[i] <= 4 for i, c in enumerate(cur_labels)], dtype=torch.bool, device=keep.device
                    )
                    if filtered_small.any().item():
                        cur_scores = cur_scores[~filtered_small]
                        cur_labels = cur_labels[~filtered_small]
                        cur_masks = cur_masks[~filtered_small]
                        area, seg_img = get_ids_area(cur_masks, cur_scores)
                    else:
                        break

            else:
                cur_labels = torch.ones(1, dtype=torch.long, device=cur_labels.device)

            segments_info = []
            for i, a in enumerate(area):
                cat = cur_labels[i].item()
                segments_info.append({"id": i, "isthing": is_thing_map[cat], "category_id": cat, "area": a})
            del cur_labels

            with io.BytesIO() as out:
                seg_img.save(out, format="PNG")
                predictions = {"png_string": out.getvalue(), "segments_info": segments_info}
            preds.append(predictions)
        return preds

    # inspired by https://github.com/facebookresearch/detr/blob/master/models/detr.py#L258
    def post_process_object_detection(
        self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, list[tuple]] = None
    ):
        """
        Converts the raw output of [`DetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
        bottom_right_x, bottom_right_y) format. Only supports PyTorch.

        Args:
            outputs ([`DetrObjectDetectionOutput`]):
                Raw outputs of the model.
            threshold (`float`, *optional*):
                Score threshold to keep object detection predictions.
            target_sizes (`torch.Tensor` or `list[tuple[int, int]]`, *optional*):
                Tensor of shape `(batch_size, 2)` or list of tuples (`tuple[int, int]`) containing the target size
                `(height, width)` of each image in the batch. If unset, predictions will not be resized.
        Returns:
            `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
            in the batch as predicted by the model.
        """
        out_logits, out_bbox = outputs.logits, outputs.pred_boxes

        if target_sizes is not None:
            if len(out_logits) != len(target_sizes):
                raise ValueError(
                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
                )

        prob = nn.functional.softmax(out_logits, -1)
        scores, labels = prob[..., :-1].max(-1)

        # Convert to [x0, y0, x1, y1] format
        boxes = center_to_corners_format(out_bbox)

        # Convert from relative [0, 1] to absolute [0, height] coordinates
        if target_sizes is not None:
            if isinstance(target_sizes, list):
                img_h = torch.Tensor([i[0] for i in target_sizes])
                img_w = torch.Tensor([i[1] for i in target_sizes])
            else:
                img_h, img_w = target_sizes.unbind(1)

            scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
            boxes = boxes * scale_fct[:, None, :]

        results = []
        for s, l, b in zip(scores, labels, boxes):
            score = s[s > threshold]
            label = l[s > threshold]
            box = b[s > threshold]
            results.append({"scores": score, "labels": label, "boxes": box})

        return results

    def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple[int, int]]] = None):
        """
        Converts the output of [`DetrForSegmentation`] into semantic segmentation maps. Only supports PyTorch.

        Args:
            outputs ([`DetrForSegmentation`]):
                Raw outputs of the model.
            target_sizes (`list[tuple[int, int]]`, *optional*):
                A list of tuples (`tuple[int, int]`) containing the target size (height, width) of each image in the
                batch. If unset, predictions will not be resized.
        Returns:
            `list[torch.Tensor]`:
                A list of length `batch_size`, where each item is a semantic segmentation map of shape (height, width)
                corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each
                `torch.Tensor` corresponds to a semantic class id.
        """
        class_queries_logits = outputs.logits  # [batch_size, num_queries, num_classes+1]
        masks_queries_logits = outputs.pred_masks  # [batch_size, num_queries, height, width]

        # Remove the null class `[..., :-1]`
        masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1]
        masks_probs = masks_queries_logits.sigmoid()  # [batch_size, num_queries, height, width]

        # Semantic segmentation logits of shape (batch_size, num_classes, height, width)
        segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
        batch_size = class_queries_logits.shape[0]

        # Resize logits and compute semantic segmentation maps
        if target_sizes is not None:
            if batch_size != len(target_sizes):
                raise ValueError(
                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
                )

            semantic_segmentation = []
            for idx in range(batch_size):
                resized_logits = nn.functional.interpolate(
                    segmentation[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
                )
                semantic_map = resized_logits[0].argmax(dim=0)
                semantic_segmentation.append(semantic_map)
        else:
            semantic_segmentation = segmentation.argmax(dim=1)
            semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]

        return semantic_segmentation

# inspired by https://github.com/facebookresearch/detr/blob/master/models/segmentation.py#L218
|
| 1876 |
+
def post_process_instance_segmentation(
|
| 1877 |
+
self,
|
| 1878 |
+
outputs,
|
| 1879 |
+
threshold: float = 0.5,
|
| 1880 |
+
mask_threshold: float = 0.5,
|
| 1881 |
+
overlap_mask_area_threshold: float = 0.8,
|
| 1882 |
+
target_sizes: Optional[list[tuple[int, int]]] = None,
|
| 1883 |
+
return_coco_annotation: Optional[bool] = False,
|
| 1884 |
+
) -> list[dict]:
|
| 1885 |
+
"""
|
| 1886 |
+
Converts the output of [`DetrForSegmentation`] into instance segmentation predictions. Only supports PyTorch.
|
| 1887 |
+
|
| 1888 |
+
Args:
|
| 1889 |
+
outputs ([`DetrForSegmentation`]):
|
| 1890 |
+
Raw outputs of the model.
|
| 1891 |
+
threshold (`float`, *optional*, defaults to 0.5):
|
| 1892 |
+
The probability score threshold to keep predicted instance masks.
|
| 1893 |
+
mask_threshold (`float`, *optional*, defaults to 0.5):
|
| 1894 |
+
Threshold to use when turning the predicted masks into binary values.
|
| 1895 |
+
overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
|
| 1896 |
+
The overlap mask area threshold to merge or discard small disconnected parts within each binary
|
| 1897 |
+
instance mask.
|
| 1898 |
+
target_sizes (`list[Tuple]`, *optional*):
|
| 1899 |
+
List of length (batch_size), where each list item (`tuple[int, int]]`) corresponds to the requested
|
| 1900 |
+
final size (height, width) of each prediction. If unset, predictions will not be resized.
|
| 1901 |
+
return_coco_annotation (`bool`, *optional*):
|
| 1902 |
+
Defaults to `False`. If set to `True`, segmentation maps are returned in COCO run-length encoding (RLE)
|
| 1903 |
+
format.
|
| 1904 |
+
Returns:
|
| 1905 |
+
`list[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
|
| 1906 |
+
- **segmentation** -- A tensor of shape `(height, width)` where each pixel represents a `segment_id` or
|
| 1907 |
+
`list[List]` run-length encoding (RLE) of the segmentation map if return_coco_annotation is set to
|
| 1908 |
+
`True`. Set to `None` if no mask if found above `threshold`.
|
| 1909 |
+
- **segments_info** -- A dictionary that contains additional information on each segment.
|
| 1910 |
+
- **id** -- An integer representing the `segment_id`.
|
| 1911 |
+
- **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
|
| 1912 |
+
- **score** -- Prediction score of segment with `segment_id`.
|
| 1913 |
+
"""
|
| 1914 |
+
class_queries_logits = outputs.logits # [batch_size, num_queries, num_classes+1]
|
| 1915 |
+
masks_queries_logits = outputs.pred_masks # [batch_size, num_queries, height, width]
|
| 1916 |
+
|
| 1917 |
+
batch_size = class_queries_logits.shape[0]
|
| 1918 |
+
num_labels = class_queries_logits.shape[-1] - 1
|
| 1919 |
+
|
| 1920 |
+
mask_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width]
|
| 1921 |
+
|
| 1922 |
+
# Predicted label and score of each query (batch_size, num_queries)
|
| 1923 |
+
pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1)
|
| 1924 |
+
|
| 1925 |
+
# Loop over items in batch size
|
| 1926 |
+
results: list[dict[str, TensorType]] = []
|
| 1927 |
+
|
| 1928 |
+
for i in range(batch_size):
|
| 1929 |
+
mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects(
|
| 1930 |
+
mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels
|
| 1931 |
+
)
|
| 1932 |
+
|
| 1933 |
+
# No mask found
|
| 1934 |
+
if mask_probs_item.shape[0] <= 0:
|
| 1935 |
+
height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:]
|
| 1936 |
+
segmentation = torch.zeros((height, width)) - 1
|
| 1937 |
+
                results.append({"segmentation": segmentation, "segments_info": []})
                continue

            # Get segmentation map and segment information of batch item
            target_size = target_sizes[i] if target_sizes is not None else None
            segmentation, segments = compute_segments(
                mask_probs=mask_probs_item,
                pred_scores=pred_scores_item,
                pred_labels=pred_labels_item,
                mask_threshold=mask_threshold,
                overlap_mask_area_threshold=overlap_mask_area_threshold,
                label_ids_to_fuse=[],
                target_size=target_size,
            )

            # Return segmentation map in run-length encoding (RLE) format
            if return_coco_annotation:
                segmentation = convert_segmentation_to_rle(segmentation)

            results.append({"segmentation": segmentation, "segments_info": segments})
        return results

    # inspired by https://github.com/facebookresearch/detr/blob/master/models/segmentation.py#L241
    def post_process_panoptic_segmentation(
        self,
        outputs,
        threshold: float = 0.5,
        mask_threshold: float = 0.5,
        overlap_mask_area_threshold: float = 0.8,
        label_ids_to_fuse: Optional[set[int]] = None,
        target_sizes: Optional[list[tuple[int, int]]] = None,
    ) -> list[dict]:
        """
        Converts the output of [`DetrForSegmentation`] into image panoptic segmentation predictions. Only supports
        PyTorch.

        Args:
            outputs ([`DetrForSegmentation`]):
                The outputs from [`DetrForSegmentation`].
            threshold (`float`, *optional*, defaults to 0.5):
                The probability score threshold to keep predicted instance masks.
            mask_threshold (`float`, *optional*, defaults to 0.5):
                Threshold to use when turning the predicted masks into binary values.
            overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
                The overlap mask area threshold to merge or discard small disconnected parts within each binary
                instance mask.
            label_ids_to_fuse (`Set[int]`, *optional*):
                The labels in this state will have all their instances be fused together. For instance we could say
                there can only be one sky in an image, but several persons, so the label ID for sky would be in that
                set, but not the one for person.
            target_sizes (`list[Tuple]`, *optional*):
                List of length (batch_size), where each list item (`tuple[int, int]]`) corresponds to the requested
                final size (height, width) of each prediction in batch. If unset, predictions will not be resized.
        Returns:
            `list[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
            - **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id` or
              `None` if no mask if found above `threshold`. If `target_sizes` is specified, segmentation is resized to
              the corresponding `target_sizes` entry.
            - **segments_info** -- A dictionary that contains additional information on each segment.
                - **id** -- an integer representing the `segment_id`.
                - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
                - **was_fused** -- a boolean, `True` if `label_id` was in `label_ids_to_fuse`, `False` otherwise.
                  Multiple instances of the same class / label were fused and assigned a single `segment_id`.
                - **score** -- Prediction score of segment with `segment_id`.
        """

        if label_ids_to_fuse is None:
            logger.warning_once("`label_ids_to_fuse` unset. No instance will be fused.")
            label_ids_to_fuse = set()

        class_queries_logits = outputs.logits  # [batch_size, num_queries, num_classes+1]
        masks_queries_logits = outputs.pred_masks  # [batch_size, num_queries, height, width]

        batch_size = class_queries_logits.shape[0]
        num_labels = class_queries_logits.shape[-1] - 1

        mask_probs = masks_queries_logits.sigmoid()  # [batch_size, num_queries, height, width]

        # Predicted label and score of each query (batch_size, num_queries)
        pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1)

        # Loop over items in batch size
        results: list[dict[str, TensorType]] = []

        for i in range(batch_size):
            mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects(
                mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels
            )

            # No mask found
            if mask_probs_item.shape[0] <= 0:
                height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:]
                segmentation = torch.zeros((height, width)) - 1
                results.append({"segmentation": segmentation, "segments_info": []})
                continue

            # Get segmentation map and segment information of batch item
            target_size = target_sizes[i] if target_sizes is not None else None
            segmentation, segments = compute_segments(
                mask_probs=mask_probs_item,
                pred_scores=pred_scores_item,
                pred_labels=pred_labels_item,
                mask_threshold=mask_threshold,
                overlap_mask_area_threshold=overlap_mask_area_threshold,
                label_ids_to_fuse=label_ids_to_fuse,
                target_size=target_size,
            )

            results.append({"segmentation": segmentation, "segments_info": segments})
        return results


__all__ = ["DetrImageProcessor"]
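The panoptic post-processor defined above is normally driven from a DETR segmentation checkpoint. A minimal usage sketch follows (an editorial illustration, not part of the uploaded file); it assumes a PIL `image` is already loaded and uses the `facebook/detr-resnet-50-panoptic` checkpoint as one possible choice:

import torch
from transformers import DetrForSegmentation, DetrImageProcessor

processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50-panoptic")
model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50-panoptic")

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# One dict per image: a (height, width) segment-id map plus per-segment metadata
panoptic = processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
print(panoptic["segmentation"].shape, len(panoptic["segments_info"]))
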
phivenv/Lib/site-packages/transformers/models/detr/image_processing_detr_fast.py
ADDED
@@ -0,0 +1,1291 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for DETR."""

import io
import pathlib
from collections import defaultdict
from typing import Any, Optional, Union

from ...image_processing_utils import BatchFeature, get_size_dict
from ...image_processing_utils_fast import (
    BaseImageProcessorFast,
    DefaultFastImageProcessorKwargs,
    SizeDict,
    get_image_size_for_max_height_width,
    get_max_height_width,
    safe_squeeze,
)
from ...image_transforms import center_to_corners_format, corners_to_center_format, id_to_rgb
from ...image_utils import (
    IMAGENET_DEFAULT_MEAN,
    IMAGENET_DEFAULT_STD,
    AnnotationFormat,
    AnnotationType,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    get_image_size,
    validate_annotations,
)
from ...processing_utils import Unpack
from ...utils import (
    TensorType,
    auto_docstring,
    is_torch_available,
    is_torchvision_available,
    is_torchvision_v2_available,
    is_vision_available,
    logging,
)
from ...utils.import_utils import requires
from .image_processing_detr import (
    compute_segments,
    convert_segmentation_to_rle,
    get_size_with_aspect_ratio,
    remove_low_and_no_objects,
)


if is_torch_available():
    import torch
    from torch import nn

if is_vision_available():
    import PIL


if is_torchvision_v2_available():
    from torchvision.io import read_image
    from torchvision.transforms.v2 import functional as F

elif is_torchvision_available():
    from torchvision.io import read_image
    from torchvision.transforms import functional as F


logger = logging.get_logger(__name__)

SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION, AnnotationFormat.COCO_PANOPTIC)


# inspired by https://github.com/facebookresearch/detr/blob/master/datasets/coco.py#L33
def convert_coco_poly_to_mask(segmentations, height: int, width: int, device: torch.device) -> torch.Tensor:
    """
    Convert a COCO polygon annotation to a mask.

    Args:
        segmentations (`list[list[float]]`):
            List of polygons, each polygon represented by a list of x-y coordinates.
        height (`int`):
            Height of the mask.
        width (`int`):
            Width of the mask.
    """
    try:
        from pycocotools import mask as coco_mask
    except ImportError:
        raise ImportError("Pycocotools is not installed in your environment.")

    masks = []
    for polygons in segmentations:
        rles = coco_mask.frPyObjects(polygons, height, width)
        mask = coco_mask.decode(rles)
        if len(mask.shape) < 3:
            mask = mask[..., None]
        mask = torch.as_tensor(mask, dtype=torch.uint8, device=device)
        mask = torch.any(mask, axis=2)
        masks.append(mask)
    if masks:
        masks = torch.stack(masks, axis=0)
    else:
        masks = torch.zeros((0, height, width), dtype=torch.uint8, device=device)

    return masks

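# Editorial illustration (not part of the uploaded file): assuming pycocotools is installed, a single
# square polygon is converted into one boolean mask per segmentation entry, e.g.:
# >>> polys = [[[2.0, 2.0, 6.0, 2.0, 6.0, 6.0, 2.0, 6.0]]]
# >>> convert_coco_poly_to_mask(polys, height=8, width=8, device=torch.device("cpu")).shape
# torch.Size([1, 8, 8])
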
# inspired by https://github.com/facebookresearch/detr/blob/master/datasets/coco.py#L50
def prepare_coco_detection_annotation(
    image,
    target,
    return_segmentation_masks: bool = False,
    input_data_format: Optional[Union[ChannelDimension, str]] = None,
):
    """
    Convert the target in COCO format into the format expected by DETR.
    """
    image_height, image_width = image.size()[-2:]

    image_id = target["image_id"]
    image_id = torch.as_tensor([image_id], dtype=torch.int64, device=image.device)

    # Get all COCO annotations for the given image.
    annotations = target["annotations"]
    classes = []
    area = []
    boxes = []
    keypoints = []
    for obj in annotations:
        if "iscrowd" not in obj or obj["iscrowd"] == 0:
            classes.append(obj["category_id"])
            area.append(obj["area"])
            boxes.append(obj["bbox"])
            if "keypoints" in obj:
                keypoints.append(obj["keypoints"])

    classes = torch.as_tensor(classes, dtype=torch.int64, device=image.device)
    area = torch.as_tensor(area, dtype=torch.float32, device=image.device)
    iscrowd = torch.zeros_like(classes, dtype=torch.int64, device=image.device)
    # guard against no boxes via resizing
    boxes = torch.as_tensor(boxes, dtype=torch.float32, device=image.device).reshape(-1, 4)
    boxes[:, 2:] += boxes[:, :2]
    boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width)
    boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height)

    keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])

    new_target = {
        "image_id": image_id,
        "class_labels": classes[keep],
        "boxes": boxes[keep],
        "area": area[keep],
        "iscrowd": iscrowd[keep],
        "orig_size": torch.as_tensor([int(image_height), int(image_width)], dtype=torch.int64, device=image.device),
    }

    if keypoints:
        keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=image.device)
        # Apply the keep mask here to filter the relevant annotations
        keypoints = keypoints[keep]
        num_keypoints = keypoints.shape[0]
        keypoints = keypoints.reshape((-1, 3)) if num_keypoints else keypoints
        new_target["keypoints"] = keypoints

    if return_segmentation_masks:
        segmentation_masks = [obj["segmentation"] for obj in annotations]
        masks = convert_coco_poly_to_mask(segmentation_masks, image_height, image_width, device=image.device)
        new_target["masks"] = masks[keep]

    return new_target

def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
    """
    Compute the bounding boxes around the provided panoptic segmentation masks.

    Args:
        masks: masks in format `[number_masks, height, width]` where N is the number of masks

    Returns:
        boxes: bounding boxes in format `[number_masks, 4]` in xyxy format
    """
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device)

    h, w = masks.shape[-2:]
    y = torch.arange(0, h, dtype=torch.float32, device=masks.device)
    x = torch.arange(0, w, dtype=torch.float32, device=masks.device)
    # see https://github.com/pytorch/pytorch/issues/50276
    y, x = torch.meshgrid(y, x, indexing="ij")

    x_mask = masks * torch.unsqueeze(x, 0)
    x_max = x_mask.view(x_mask.shape[0], -1).max(-1)[0]
    x_min = (
        torch.where(masks, x.unsqueeze(0), torch.tensor(1e8, device=masks.device)).view(masks.shape[0], -1).min(-1)[0]
    )

    y_mask = masks * torch.unsqueeze(y, 0)
    y_max = y_mask.view(y_mask.shape[0], -1).max(-1)[0]
    y_min = (
        torch.where(masks, y.unsqueeze(0), torch.tensor(1e8, device=masks.device)).view(masks.shape[0], -1).min(-1)[0]
    )

    return torch.stack([x_min, y_min, x_max, y_max], 1)

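# Editorial illustration (not part of the uploaded file): a toy mask that is True on rows 1..3 and
# columns 2..5 yields the xyxy box [x_min, y_min, x_max, y_max] = [2, 1, 5, 3]:
# >>> toy = torch.zeros(1, 5, 8, dtype=torch.bool)
# >>> toy[0, 1:4, 2:6] = True
# >>> masks_to_boxes(toy)
# tensor([[2., 1., 5., 3.]])
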
# 2 functions below adapted from https://github.com/cocodataset/panopticapi/blob/master/panopticapi/utils.py
# Copyright (c) 2018, Alexander Kirillov
# All rights reserved.
def rgb_to_id(color):
    """
    Converts RGB color to unique ID.
    """
    if isinstance(color, torch.Tensor) and len(color.shape) == 3:
        if color.dtype == torch.uint8:
            color = color.to(torch.int32)
        return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2]
    return int(color[0] + 256 * color[1] + 256 * 256 * color[2])

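# Editorial illustration (not part of the uploaded file): the segment id is the little-endian
# interpretation of the RGB triplet, id = R + 256 * G + 256 * 256 * B, so for example:
# >>> rgb_to_id([10, 1, 0])
# 266
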
def prepare_coco_panoptic_annotation(
    image: torch.Tensor,
    target: dict,
    masks_path: Union[str, pathlib.Path],
    return_masks: bool = True,
    input_data_format: Union[ChannelDimension, str] = None,
) -> dict:
    """
    Prepare a coco panoptic annotation for DETR.
    """
    image_height, image_width = get_image_size(image, channel_dim=input_data_format)
    annotation_path = pathlib.Path(masks_path) / target["file_name"]

    new_target = {}
    new_target["image_id"] = torch.as_tensor(
        [target["image_id"] if "image_id" in target else target["id"]], dtype=torch.int64, device=image.device
    )
    new_target["size"] = torch.as_tensor([image_height, image_width], dtype=torch.int64, device=image.device)
    new_target["orig_size"] = torch.as_tensor([image_height, image_width], dtype=torch.int64, device=image.device)

    if "segments_info" in target:
        masks = read_image(annotation_path).permute(1, 2, 0).to(dtype=torch.int32, device=image.device)
        masks = rgb_to_id(masks)

        ids = torch.as_tensor([segment_info["id"] for segment_info in target["segments_info"]], device=image.device)
        masks = masks == ids[:, None, None]
        masks = masks.to(torch.bool)
        if return_masks:
            new_target["masks"] = masks
        new_target["boxes"] = masks_to_boxes(masks)
        new_target["class_labels"] = torch.as_tensor(
            [segment_info["category_id"] for segment_info in target["segments_info"]],
            dtype=torch.int64,
            device=image.device,
        )
        new_target["iscrowd"] = torch.as_tensor(
            [segment_info["iscrowd"] for segment_info in target["segments_info"]],
            dtype=torch.int64,
            device=image.device,
        )
        new_target["area"] = torch.as_tensor(
            [segment_info["area"] for segment_info in target["segments_info"]],
            dtype=torch.float32,
            device=image.device,
        )

    return new_target


class DetrFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
    r"""
    format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
        Data format of the annotations. One of "coco_detection" or "coco_panoptic".
    do_convert_annotations (`bool`, *optional*, defaults to `True`):
        Controls whether to convert the annotations to the format expected by the DETR model. Converts the
        bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`.
        Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method.
    do_pad (`bool`, *optional*, defaults to `True`):
        Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess`
        method. If `True`, padding will be applied to the bottom and right of the image with zeros.
        If `pad_size` is provided, the image will be padded to the specified dimensions.
        Otherwise, the image will be padded to the maximum height and width of the batch.
    pad_size (`dict[str, int]`, *optional*):
        The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
        provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
        height and width in the batch.
    return_segmentation_masks (`bool`, *optional*, defaults to `False`):
        Whether to return segmentation masks.
    """

    format: Optional[Union[str, AnnotationFormat]]
    do_convert_annotations: Optional[bool]
    do_pad: Optional[bool]
    pad_size: Optional[dict[str, int]]
    return_segmentation_masks: Optional[bool]


@auto_docstring
@requires(backends=("torchvision", "torch"))
class DetrImageProcessorFast(BaseImageProcessorFast):
    resample = PILImageResampling.BILINEAR
    image_mean = IMAGENET_DEFAULT_MEAN
    image_std = IMAGENET_DEFAULT_STD
    format = AnnotationFormat.COCO_DETECTION
    do_resize = True
    do_rescale = True
    do_normalize = True
    do_pad = True
    size = {"shortest_edge": 800, "longest_edge": 1333}
    default_to_square = False
    model_input_names = ["pixel_values", "pixel_mask"]
    valid_kwargs = DetrFastImageProcessorKwargs

    def __init__(self, **kwargs: Unpack[DetrFastImageProcessorKwargs]) -> None:
        if "pad_and_return_pixel_mask" in kwargs:
            kwargs["do_pad"] = kwargs.pop("pad_and_return_pixel_mask")

        size = kwargs.pop("size", None)
        if "max_size" in kwargs:
            logger.warning_once(
                "The `max_size` parameter is deprecated and will be removed in v4.26. "
                "Please specify in `size['longest_edge'] instead`.",
            )
            max_size = kwargs.pop("max_size")
        else:
            max_size = None if size is None else 1333

        size = size if size is not None else {"shortest_edge": 800, "longest_edge": 1333}
        self.size = get_size_dict(size, max_size=max_size, default_to_square=False)

        # Backwards compatibility
        do_convert_annotations = kwargs.get("do_convert_annotations")
        do_normalize = kwargs.get("do_normalize")
        if do_convert_annotations is None and getattr(self, "do_convert_annotations", None) is None:
            self.do_convert_annotations = do_normalize if do_normalize is not None else self.do_normalize

        super().__init__(**kwargs)

    @classmethod
    def from_dict(cls, image_processor_dict: dict[str, Any], **kwargs):
        """
        Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is
        created using from_dict and kwargs e.g. `DetrImageProcessorFast.from_pretrained(checkpoint, size=600,
        max_size=800)`
        """
        image_processor_dict = image_processor_dict.copy()
        if "max_size" in kwargs:
            image_processor_dict["max_size"] = kwargs.pop("max_size")
        if "pad_and_return_pixel_mask" in kwargs:
            image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask")
        return super().from_dict(image_processor_dict, **kwargs)

    def prepare_annotation(
        self,
        image: torch.Tensor,
        target: dict,
        format: Optional[AnnotationFormat] = None,
        return_segmentation_masks: Optional[bool] = None,
        masks_path: Optional[Union[str, pathlib.Path]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> dict:
        """
        Prepare an annotation for feeding into DETR model.
        """
        format = format if format is not None else self.format

        if format == AnnotationFormat.COCO_DETECTION:
            return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks
            target = prepare_coco_detection_annotation(
                image, target, return_segmentation_masks, input_data_format=input_data_format
            )
        elif format == AnnotationFormat.COCO_PANOPTIC:
            return_segmentation_masks = True if return_segmentation_masks is None else return_segmentation_masks
            target = prepare_coco_panoptic_annotation(
                image,
                target,
                masks_path=masks_path,
                return_masks=return_segmentation_masks,
                input_data_format=input_data_format,
            )
        else:
            raise ValueError(f"Format {format} is not supported.")
        return target

    def resize(
        self,
        image: torch.Tensor,
        size: SizeDict,
        interpolation: "F.InterpolationMode" = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Resize the image to the given size. Size can be `min_size` (scalar) or `(height, width)` tuple. If size is an
        int, smaller edge of the image will be matched to this number.

        Args:
            image (`torch.Tensor`):
                Image to resize.
            size (`SizeDict`):
                Size of the image's `(height, width)` dimensions after resizing. Available options are:
                    - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
                      Do NOT keep the aspect ratio.
                    - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
                      the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
                      less or equal to `longest_edge`.
                    - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
                      aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
                      `max_width`.
            interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
                Resampling filter to use if resizing the image.
        """
        interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
        if size.shortest_edge and size.longest_edge:
            # Resize the image so that the shortest edge or the longest edge is of the given size
            # while maintaining the aspect ratio of the original image.
            new_size = get_size_with_aspect_ratio(
                image.size()[-2:],
                size["shortest_edge"],
                size["longest_edge"],
            )
        elif size.max_height and size.max_width:
            new_size = get_image_size_for_max_height_width(image.size()[-2:], size["max_height"], size["max_width"])
        elif size.height and size.width:
            new_size = (size["height"], size["width"])
        else:
            raise ValueError(
                "Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got"
                f" {size.keys()}."
            )

        image = F.resize(
            image,
            size=new_size,
            interpolation=interpolation,
            **kwargs,
        )
        return image

    def resize_annotation(
        self,
        annotation: dict[str, Any],
        orig_size: tuple[int, int],
        target_size: tuple[int, int],
        threshold: float = 0.5,
        interpolation: "F.InterpolationMode" = None,
    ):
        """
        Resizes an annotation to a target size.

        Args:
            annotation (`dict[str, Any]`):
                The annotation dictionary.
            orig_size (`tuple[int, int]`):
                The original size of the input image.
            target_size (`tuple[int, int]`):
                The target size of the image, as returned by the preprocessing `resize` step.
            threshold (`float`, *optional*, defaults to 0.5):
                The threshold used to binarize the segmentation masks.
            resample (`InterpolationMode`, defaults to `F.InterpolationMode.NEAREST_EXACT`):
                The resampling filter to use when resizing the masks.
        """
        interpolation = (
            interpolation
            if interpolation is not None
            else F.InterpolationMode.NEAREST_EXACT
            if is_torchvision_v2_available()
            else F.InterpolationMode.NEAREST
        )
        ratio_height, ratio_width = [target / orig for target, orig in zip(target_size, orig_size)]

        new_annotation = {}
        new_annotation["size"] = target_size

        for key, value in annotation.items():
            if key == "boxes":
                boxes = value
                scaled_boxes = boxes * torch.as_tensor(
                    [ratio_width, ratio_height, ratio_width, ratio_height], dtype=torch.float32, device=boxes.device
                )
                new_annotation["boxes"] = scaled_boxes
            elif key == "area":
                area = value
                scaled_area = area * (ratio_width * ratio_height)
                new_annotation["area"] = scaled_area
            elif key == "masks":
                masks = value[:, None]
                masks = [F.resize(mask, target_size, interpolation=interpolation) for mask in masks]
                masks = torch.stack(masks).to(torch.float32)
                masks = masks[:, 0] > threshold
                new_annotation["masks"] = masks
            elif key == "size":
                new_annotation["size"] = target_size
            else:
                new_annotation[key] = value

        return new_annotation

    def normalize_annotation(self, annotation: dict, image_size: tuple[int, int]) -> dict:
        image_height, image_width = image_size
        norm_annotation = {}
        for key, value in annotation.items():
            if key == "boxes":
                boxes = value
                boxes = corners_to_center_format(boxes)
                boxes /= torch.as_tensor(
                    [image_width, image_height, image_width, image_height], dtype=torch.float32, device=boxes.device
                )
                norm_annotation[key] = boxes
            else:
                norm_annotation[key] = value
        return norm_annotation

    def _update_annotation_for_padded_image(
        self,
        annotation: dict,
        input_image_size: tuple[int, int],
        output_image_size: tuple[int, int],
        padding,
        update_bboxes,
    ) -> dict:
        """
        Update the annotation for a padded image.
        """
        new_annotation = {}
        new_annotation["size"] = output_image_size
        ratio_height, ratio_width = (input / output for output, input in zip(output_image_size, input_image_size))

        for key, value in annotation.items():
            if key == "masks":
                masks = value
                masks = F.pad(
                    masks,
                    padding,
                    fill=0,
                )
                masks = safe_squeeze(masks, 1)
                new_annotation["masks"] = masks
            elif key == "boxes" and update_bboxes:
                boxes = value
                boxes *= torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height], device=boxes.device)
                new_annotation["boxes"] = boxes
            elif key == "size":
                new_annotation["size"] = output_image_size
            else:
                new_annotation[key] = value
        return new_annotation

    def pad(
        self,
        image: torch.Tensor,
        padded_size: tuple[int, int],
        annotation: Optional[dict[str, Any]] = None,
        update_bboxes: bool = True,
        fill: int = 0,
    ):
        original_size = image.size()[-2:]
        padding_bottom = padded_size[0] - original_size[0]
        padding_right = padded_size[1] - original_size[1]
        if padding_bottom < 0 or padding_right < 0:
            raise ValueError(
                f"Padding dimensions are negative. Please make sure that the padded size is larger than the "
                f"original size. Got padded size: {padded_size}, original size: {original_size}."
            )
        if original_size != padded_size:
            padding = [0, 0, padding_right, padding_bottom]
            image = F.pad(image, padding, fill=fill)
            if annotation is not None:
                annotation = self._update_annotation_for_padded_image(
                    annotation, original_size, padded_size, padding, update_bboxes
                )

        # Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
        pixel_mask = torch.zeros(padded_size, dtype=torch.int64, device=image.device)
        pixel_mask[: original_size[0], : original_size[1]] = 1

        return image, pixel_mask, annotation

    @auto_docstring
    def preprocess(
        self,
        images: ImageInput,
        annotations: Optional[Union[AnnotationType, list[AnnotationType]]] = None,
        masks_path: Optional[Union[str, pathlib.Path]] = None,
        **kwargs: Unpack[DetrFastImageProcessorKwargs],
    ) -> BatchFeature:
        r"""
        annotations (`AnnotationType` or `list[AnnotationType]`, *optional*):
            List of annotations associated with the image or batch of images. If annotation is for object
            detection, the annotations should be a dictionary with the following keys:
            - "image_id" (`int`): The image id.
            - "annotations" (`list[Dict]`): List of annotations for an image. Each annotation should be a
              dictionary. An image can have no annotations, in which case the list should be empty.
            If annotation is for segmentation, the annotations should be a dictionary with the following keys:
            - "image_id" (`int`): The image id.
            - "segments_info" (`list[Dict]`): List of segments for an image. Each segment should be a dictionary.
              An image can have no segments, in which case the list should be empty.
            - "file_name" (`str`): The file name of the image.
        masks_path (`str` or `pathlib.Path`, *optional*):
            Path to the directory containing the segmentation masks.
        """
        if "pad_and_return_pixel_mask" in kwargs:
            kwargs["do_pad"] = kwargs.pop("pad_and_return_pixel_mask")
            logger.warning_once(
                "The `pad_and_return_pixel_mask` argument is deprecated and will be removed in a future version, "
                "use `do_pad` instead."
            )

        if "max_size" in kwargs:
            logger.warning_once(
                "The `max_size` argument is deprecated and will be removed in a future version, use"
                " `size['longest_edge']` instead."
            )
            kwargs["size"] = kwargs.pop("max_size")

        return super().preprocess(images, annotations, masks_path, **kwargs)

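    # Editorial illustration (not part of the uploaded file): an assumed call pattern for `preprocess`
    # with a COCO-detection style annotation; the checkpoint name and the exact output keys are
    # assumptions and may vary by version:
    # >>> from transformers import DetrImageProcessorFast
    # >>> processor = DetrImageProcessorFast.from_pretrained("facebook/detr-resnet-50")
    # >>> ann = {"image_id": 1, "annotations": [{"category_id": 1, "bbox": [10, 10, 50, 80], "area": 4000, "iscrowd": 0}]}
    # >>> inputs = processor(images=image, annotations=ann, return_tensors="pt")
    # >>> sorted(inputs.keys())
    # ['labels', 'pixel_mask', 'pixel_values']
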
def _preprocess(
|
| 629 |
+
self,
|
| 630 |
+
images: list["torch.Tensor"],
|
| 631 |
+
annotations: Optional[Union[AnnotationType, list[AnnotationType]]],
|
| 632 |
+
masks_path: Optional[Union[str, pathlib.Path]],
|
| 633 |
+
return_segmentation_masks: bool,
|
| 634 |
+
do_resize: bool,
|
| 635 |
+
size: SizeDict,
|
| 636 |
+
interpolation: Optional["F.InterpolationMode"],
|
| 637 |
+
do_rescale: bool,
|
| 638 |
+
rescale_factor: float,
|
| 639 |
+
do_normalize: bool,
|
| 640 |
+
do_convert_annotations: bool,
|
| 641 |
+
image_mean: Optional[Union[float, list[float]]],
|
| 642 |
+
image_std: Optional[Union[float, list[float]]],
|
| 643 |
+
do_pad: bool,
|
| 644 |
+
pad_size: Optional[dict[str, int]],
|
| 645 |
+
format: Optional[Union[str, AnnotationFormat]],
|
| 646 |
+
return_tensors: Optional[Union[str, TensorType]],
|
| 647 |
+
**kwargs,
|
| 648 |
+
) -> BatchFeature:
|
| 649 |
+
"""
|
| 650 |
+
Preprocess an image or a batch of images so that it can be used by the model.
|
| 651 |
+
"""
|
| 652 |
+
if annotations is not None and isinstance(annotations, dict):
|
| 653 |
+
annotations = [annotations]
|
| 654 |
+
|
| 655 |
+
if annotations is not None and len(images) != len(annotations):
|
| 656 |
+
raise ValueError(
|
| 657 |
+
f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match."
|
| 658 |
+
)
|
| 659 |
+
|
| 660 |
+
format = AnnotationFormat(format)
|
| 661 |
+
if annotations is not None:
|
| 662 |
+
validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations)
|
| 663 |
+
|
| 664 |
+
if (
|
| 665 |
+
masks_path is not None
|
| 666 |
+
and format == AnnotationFormat.COCO_PANOPTIC
|
| 667 |
+
and not isinstance(masks_path, (pathlib.Path, str))
|
| 668 |
+
):
|
| 669 |
+
raise ValueError(
|
| 670 |
+
"The path to the directory containing the mask PNG files should be provided as a"
|
| 671 |
+
f" `pathlib.Path` or string object, but is {type(masks_path)} instead."
|
| 672 |
+
)
|
| 673 |
+
|
| 674 |
+
data = {}
|
| 675 |
+
|
| 676 |
+
processed_images = []
|
| 677 |
+
processed_annotations = []
|
| 678 |
+
pixel_masks = [] # Initialize pixel_masks here
|
| 679 |
+
for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)):
|
| 680 |
+
# prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image)
|
| 681 |
+
if annotations is not None:
|
| 682 |
+
annotation = self.prepare_annotation(
|
| 683 |
+
image,
|
| 684 |
+
annotation,
|
| 685 |
+
format,
|
| 686 |
+
return_segmentation_masks=return_segmentation_masks,
|
| 687 |
+
masks_path=masks_path,
|
| 688 |
+
input_data_format=ChannelDimension.FIRST,
|
| 689 |
+
)
|
| 690 |
+
|
| 691 |
+
if do_resize:
|
| 692 |
+
resized_image = self.resize(image, size=size, interpolation=interpolation)
|
| 693 |
+
if annotations is not None:
|
| 694 |
+
annotation = self.resize_annotation(
|
| 695 |
+
annotation,
|
| 696 |
+
orig_size=image.size()[-2:],
|
| 697 |
+
target_size=resized_image.size()[-2:],
|
| 698 |
+
)
|
| 699 |
+
image = resized_image
|
| 700 |
+
# Fused rescale and normalize
|
| 701 |
+
image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std)
|
| 702 |
+
if do_convert_annotations and annotations is not None:
|
| 703 |
+
annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST))
|
| 704 |
+
|
| 705 |
+
processed_images.append(image)
|
| 706 |
+
processed_annotations.append(annotation)
|
| 707 |
+
images = processed_images
|
| 708 |
+
annotations = processed_annotations if annotations is not None else None
|
| 709 |
+
|
| 710 |
+
if do_pad:
|
| 711 |
+
# depends on all resized image shapes so we need another loop
|
| 712 |
+
if pad_size is not None:
|
| 713 |
+
padded_size = (pad_size["height"], pad_size["width"])
|
| 714 |
+
else:
|
| 715 |
+
padded_size = get_max_height_width(images)
|
| 716 |
+
|
| 717 |
+
padded_images = []
|
| 718 |
+
padded_annotations = []
|
| 719 |
+
for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)):
|
| 720 |
+
# Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...}
|
| 721 |
+
if padded_size == image.size()[-2:]:
|
| 722 |
+
padded_images.append(image)
|
| 723 |
+
pixel_masks.append(torch.ones(padded_size, dtype=torch.int64, device=image.device))
|
| 724 |
+
padded_annotations.append(annotation)
|
| 725 |
+
continue
|
| 726 |
+
image, pixel_mask, annotation = self.pad(
|
| 727 |
+
image, padded_size, annotation=annotation, update_bboxes=do_convert_annotations
|
| 728 |
+
)
|
| 729 |
+
padded_images.append(image)
|
| 730 |
+
padded_annotations.append(annotation)
|
| 731 |
+
pixel_masks.append(pixel_mask)
|
| 732 |
+
images = padded_images
|
| 733 |
+
annotations = padded_annotations if annotations is not None else None
|
| 734 |
+
data.update({"pixel_mask": torch.stack(pixel_masks, dim=0)})
|
| 735 |
+
|
| 736 |
+
data.update({"pixel_values": torch.stack(images, dim=0)})
|
| 737 |
+
encoded_inputs = BatchFeature(data, tensor_type=return_tensors)
|
| 738 |
+
if annotations is not None:
|
| 739 |
+
encoded_inputs["labels"] = [
|
| 740 |
+
BatchFeature(annotation, tensor_type=return_tensors) for annotation in annotations
|
| 741 |
+
]
|
| 742 |
+
return encoded_inputs
|
| 743 |
+
|
| 744 |
+
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process
|
| 745 |
+
def post_process(self, outputs, target_sizes):
|
| 746 |
+
"""
|
| 747 |
+
Converts the raw output of [`DetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
|
| 748 |
+
bottom_right_x, bottom_right_y) format. Only supports PyTorch.
|
| 749 |
+
|
| 750 |
+
Args:
|
| 751 |
+
outputs ([`DetrObjectDetectionOutput`]):
|
| 752 |
+
Raw outputs of the model.
|
| 753 |
+
target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
|
| 754 |
+
Tensor containing the size (height, width) of each image of the batch. For evaluation, this must be the
|
| 755 |
+
original image size (before any data augmentation). For visualization, this should be the image size
|
| 756 |
+
after data augment, but before padding.
|
| 757 |
+
Returns:
|
| 758 |
+
`list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
|
| 759 |
+
in the batch as predicted by the model.
|
| 760 |
+
"""
|
| 761 |
+
logger.warning_once(
|
| 762 |
+
"`post_process` is deprecated and will be removed in v5 of Transformers, please use"
|
| 763 |
+
" `post_process_object_detection` instead, with `threshold=0.` for equivalent results.",
|
| 764 |
+
)
|
| 765 |
+
|
| 766 |
+
out_logits, out_bbox = outputs.logits, outputs.pred_boxes
|
| 767 |
+
|
| 768 |
+
if len(out_logits) != len(target_sizes):
|
| 769 |
+
raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
|
| 770 |
+
if target_sizes.shape[1] != 2:
|
| 771 |
+
raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
|
| 772 |
+
|
| 773 |
+
prob = nn.functional.softmax(out_logits, -1)
|
| 774 |
+
scores, labels = prob[..., :-1].max(-1)
|
| 775 |
+
|
| 776 |
+
# convert to [x0, y0, x1, y1] format
|
| 777 |
+
boxes = center_to_corners_format(out_bbox)
|
| 778 |
+
# and from relative [0, 1] to absolute [0, height] coordinates
|
| 779 |
+
img_h, img_w = target_sizes.unbind(1)
|
| 780 |
+
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
|
| 781 |
+
boxes = boxes * scale_fct[:, None, :]
|
| 782 |
+
|
| 783 |
+
results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)]
|
| 784 |
+
return results
|
| 785 |
+
|
| 786 |
+
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process_segmentation
|
| 787 |
+
def post_process_segmentation(self, outputs, target_sizes, threshold=0.9, mask_threshold=0.5):
|
| 788 |
+
"""
|
| 789 |
+
Converts the output of [`DetrForSegmentation`] into image segmentation predictions. Only supports PyTorch.
|
| 790 |
+
|
| 791 |
+
Args:
|
| 792 |
+
outputs ([`DetrSegmentationOutput`]):
|
| 793 |
+
Raw outputs of the model.
|
| 794 |
+
target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `list[Tuple]` of length `batch_size`):
|
| 795 |
+
Torch Tensor (or list) corresponding to the requested final size (h, w) of each prediction.
|
| 796 |
+
threshold (`float`, *optional*, defaults to 0.9):
|
| 797 |
+
Threshold to use to filter out queries.
|
| 798 |
+
mask_threshold (`float`, *optional*, defaults to 0.5):
|
| 799 |
+
Threshold to use when turning the predicted masks into binary values.
|
| 800 |
+
Returns:
|
| 801 |
+
`list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, and masks for an image
|
| 802 |
+
in the batch as predicted by the model.
|
| 803 |
+
"""
|
| 804 |
+
logger.warning_once(
|
| 805 |
+
"`post_process_segmentation` is deprecated and will be removed in v5 of Transformers, please use"
|
| 806 |
+
" `post_process_semantic_segmentation`.",
|
| 807 |
+
)
|
| 808 |
+
out_logits, raw_masks = outputs.logits, outputs.pred_masks
|
| 809 |
+
empty_label = out_logits.shape[-1] - 1
|
| 810 |
+
preds = []
|
| 811 |
+
|
| 812 |
+
def to_tuple(tup):
|
| 813 |
+
if isinstance(tup, tuple):
|
| 814 |
+
return tup
|
| 815 |
+
return tuple(tup.tolist())
|
| 816 |
+
|
| 817 |
+
for cur_logits, cur_masks, size in zip(out_logits, raw_masks, target_sizes):
|
| 818 |
+
# we filter empty queries and detection below threshold
|
| 819 |
+
cur_scores, cur_labels = cur_logits.softmax(-1).max(-1)
|
| 820 |
+
keep = cur_labels.ne(empty_label) & (cur_scores > threshold)
|
| 821 |
+
cur_scores = cur_scores[keep]
|
| 822 |
+
cur_labels = cur_labels[keep]
|
| 823 |
+
cur_masks = cur_masks[keep]
|
| 824 |
+
cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
|
| 825 |
+
cur_masks = (cur_masks.sigmoid() > mask_threshold) * 1
|
| 826 |
+
|
| 827 |
+
predictions = {"scores": cur_scores, "labels": cur_labels, "masks": cur_masks}
|
| 828 |
+
preds.append(predictions)
|
| 829 |
+
return preds
|
| 830 |
+
|
| 831 |
+
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process_instance
|
| 832 |
+
def post_process_instance(self, results, outputs, orig_target_sizes, max_target_sizes, threshold=0.5):
|
| 833 |
+
"""
|
| 834 |
+
Converts the output of [`DetrForSegmentation`] into actual instance segmentation predictions. Only supports
|
| 835 |
+
PyTorch.
|
| 836 |
+
|
| 837 |
+
Args:
|
| 838 |
+
results (`list[Dict]`):
|
| 839 |
+
Results list obtained by [`~DetrImageProcessor.post_process`], to which "masks" results will be added.
|
| 840 |
+
outputs ([`DetrSegmentationOutput`]):
|
| 841 |
+
Raw outputs of the model.
|
| 842 |
+
orig_target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
|
| 843 |
+
Tensor containing the size (h, w) of each image of the batch. For evaluation, this must be the original
|
| 844 |
+
image size (before any data augmentation).
|
| 845 |
+
max_target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
|
| 846 |
+
Tensor containing the maximum size (h, w) of each image of the batch. For evaluation, this must be the
|
| 847 |
+
original image size (before any data augmentation).
|
| 848 |
+
threshold (`float`, *optional*, defaults to 0.5):
|
| 849 |
+
Threshold to use when turning the predicted masks into binary values.
|
| 850 |
+
Returns:
|
| 851 |
+
`list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, boxes and masks for an
|
| 852 |
+
image in the batch as predicted by the model.
|
| 853 |
+
"""
|
| 854 |
+
logger.warning_once(
|
| 855 |
+
"`post_process_instance` is deprecated and will be removed in v5 of Transformers, please use"
|
| 856 |
+
" `post_process_instance_segmentation`.",
|
| 857 |
+
)
|
| 858 |
+
|
| 859 |
+
if len(orig_target_sizes) != len(max_target_sizes):
|
| 860 |
+
raise ValueError("Make sure to pass in as many orig_target_sizes as max_target_sizes")
|
| 861 |
+
max_h, max_w = max_target_sizes.max(0)[0].tolist()
|
| 862 |
+
outputs_masks = outputs.pred_masks.squeeze(2)
|
| 863 |
+
outputs_masks = nn.functional.interpolate(
|
| 864 |
+
outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False
|
| 865 |
+
)
|
| 866 |
+
outputs_masks = (outputs_masks.sigmoid() > threshold).cpu()
|
| 867 |
+
|
| 868 |
+
for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)):
|
| 869 |
+
img_h, img_w = t[0], t[1]
|
| 870 |
+
results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1)
|
| 871 |
+
results[i]["masks"] = nn.functional.interpolate(
|
| 872 |
+
results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest"
|
| 873 |
+
).byte()
|
| 874 |
+
|
| 875 |
+
return results
|
| 876 |
+
|
| 877 |
+
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process_panoptic
|
| 878 |
+
def post_process_panoptic(self, outputs, processed_sizes, target_sizes=None, is_thing_map=None, threshold=0.85):
|
| 879 |
+
"""
|
| 880 |
+
Converts the output of [`DetrForSegmentation`] into actual panoptic predictions. Only supports PyTorch.
|
| 881 |
+
|
| 882 |
+
Args:
|
| 883 |
+
outputs ([`DetrSegmentationOutput`]):
|
| 884 |
+
Raw outputs of the model.
|
| 885 |
+
processed_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `list[Tuple]` of length `batch_size`):
|
| 886 |
+
Torch Tensor (or list) containing the size (h, w) of each image of the batch, i.e. the size after data
|
| 887 |
+
augmentation but before batching.
|
| 888 |
+
target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `list[Tuple]` of length `batch_size`, *optional*):
|
| 889 |
+
Torch Tensor (or list) corresponding to the requested final size `(height, width)` of each prediction.
|
| 890 |
+
If left to None, it will default to the `processed_sizes`.
|
| 891 |
+
is_thing_map (`torch.Tensor` of shape `(batch_size, 2)`, *optional*):
|
| 892 |
+
Dictionary mapping class indices to either True or False, depending on whether or not they are a thing.
|
| 893 |
+
If not set, defaults to the `is_thing_map` of COCO panoptic.
|
| 894 |
+
threshold (`float`, *optional*, defaults to 0.85):
|
| 895 |
+
Threshold to use to filter out queries.
|
| 896 |
+
Returns:
|
| 897 |
+
`list[Dict]`: A list of dictionaries, each dictionary containing a PNG string and segments_info values for
|
| 898 |
+
an image in the batch as predicted by the model.
|
| 899 |
+
"""
|
| 900 |
+
logger.warning_once(
|
| 901 |
+
"`post_process_panoptic is deprecated and will be removed in v5 of Transformers, please use"
|
| 902 |
+
" `post_process_panoptic_segmentation`.",
|
| 903 |
+
)
|
| 904 |
+
if target_sizes is None:
|
| 905 |
+
target_sizes = processed_sizes
|
| 906 |
+
if len(processed_sizes) != len(target_sizes):
|
| 907 |
+
raise ValueError("Make sure to pass in as many processed_sizes as target_sizes")
|
| 908 |
+
|
| 909 |
+
if is_thing_map is None:
|
| 910 |
+
# default to is_thing_map of COCO panoptic
|
| 911 |
+
is_thing_map = {i: i <= 90 for i in range(201)}
|
| 912 |
+
|
| 913 |
+
out_logits, raw_masks, raw_boxes = outputs.logits, outputs.pred_masks, outputs.pred_boxes
|
| 914 |
+
if not len(out_logits) == len(raw_masks) == len(target_sizes):
|
| 915 |
+
raise ValueError(
|
| 916 |
+
"Make sure that you pass in as many target sizes as the batch dimension of the logits and masks"
|
| 917 |
+
)
|
| 918 |
+
empty_label = out_logits.shape[-1] - 1
|
| 919 |
+
preds = []
|
| 920 |
+
|
| 921 |
+
def to_tuple(tup):
|
| 922 |
+
if isinstance(tup, tuple):
|
| 923 |
+
return tup
|
| 924 |
+
return tuple(tup.tolist())
|
| 925 |
+
|
| 926 |
+
for cur_logits, cur_masks, cur_boxes, size, target_size in zip(
|
| 927 |
+
out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes
|
| 928 |
+
):
|
| 929 |
+
# we filter empty queries and detection below threshold
|
| 930 |
+
cur_scores, cur_labels = cur_logits.softmax(-1).max(-1)
|
| 931 |
+
keep = cur_labels.ne(empty_label) & (cur_scores > threshold)
|
| 932 |
+
cur_scores = cur_scores[keep]
|
| 933 |
+
cur_labels = cur_labels[keep]
|
| 934 |
+
cur_masks = cur_masks[keep]
|
| 935 |
+
cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
|
| 936 |
+
cur_boxes = center_to_corners_format(cur_boxes[keep])
|
| 937 |
+
|
| 938 |
+
h, w = cur_masks.shape[-2:]
|
| 939 |
+
if len(cur_boxes) != len(cur_labels):
|
| 940 |
+
raise ValueError("Not as many boxes as there are classes")
|
| 941 |
+
|
| 942 |
+
# It may be that we have several predicted masks for the same stuff class.
|
| 943 |
+
# In the following, we track the list of masks ids for each stuff class (they are merged later on)
|
| 944 |
+
cur_masks = cur_masks.flatten(1)
|
| 945 |
+
stuff_equiv_classes = defaultdict(lambda: [])
|
| 946 |
+
for k, label in enumerate(cur_labels):
|
| 947 |
+
if not is_thing_map[label.item()]:
|
| 948 |
+
stuff_equiv_classes[label.item()].append(k)
|
| 949 |
+
|
| 950 |
+
def get_ids_area(masks, scores, dedup=False):
|
| 951 |
+
# This helper function creates the final panoptic segmentation image
|
| 952 |
+
# It also returns the area of the masks that appears on the image
|
| 953 |
+
|
| 954 |
+
m_id = masks.transpose(0, 1).softmax(-1)
|
| 955 |
+
|
| 956 |
+
if m_id.shape[-1] == 0:
|
| 957 |
+
# We didn't detect any mask :(
|
| 958 |
+
m_id = torch.zeros((h, w), dtype=torch.long, device=m_id.device)
|
| 959 |
+
else:
|
| 960 |
+
m_id = m_id.argmax(-1).view(h, w)
|
| 961 |
+
|
| 962 |
+
if dedup:
|
| 963 |
+
# Merge the masks corresponding to the same stuff class
|
| 964 |
+
for equiv in stuff_equiv_classes.values():
|
| 965 |
+
if len(equiv) > 1:
|
| 966 |
+
for eq_id in equiv:
|
| 967 |
+
m_id.masked_fill_(m_id.eq(eq_id), equiv[0])
|
| 968 |
+
|
| 969 |
+
final_h, final_w = to_tuple(target_size)
|
| 970 |
+
|
| 971 |
+
seg_img = PIL.Image.fromarray(id_to_rgb(m_id.view(h, w).cpu().numpy()))
|
| 972 |
+
seg_img = seg_img.resize(size=(final_w, final_h), resample=PILImageResampling.NEAREST)
|
| 973 |
+
|
| 974 |
+
np_seg_img = torch.ByteTensor(torch.ByteStorage.from_buffer(seg_img.tobytes()))
|
| 975 |
+
np_seg_img = np_seg_img.view(final_h, final_w, 3)
|
| 976 |
+
np_seg_img = np_seg_img.numpy()
|
| 977 |
+
|
| 978 |
+
m_id = torch.from_numpy(rgb_to_id(np_seg_img))
|
| 979 |
+
|
| 980 |
+
area = []
|
| 981 |
+
for i in range(len(scores)):
|
| 982 |
+
area.append(m_id.eq(i).sum().item())
|
| 983 |
+
return area, seg_img
|
| 984 |
+
|
| 985 |
+
area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True)
|
| 986 |
+
if cur_labels.numel() > 0:
|
| 987 |
+
# We now filter empty masks as long as we find some
|
| 988 |
+
while True:
|
| 989 |
+
filtered_small = torch.as_tensor(
|
| 990 |
+
[area[i] <= 4 for i, c in enumerate(cur_labels)], dtype=torch.bool, device=keep.device
|
| 991 |
+
)
|
| 992 |
+
if filtered_small.any().item():
|
| 993 |
+
cur_scores = cur_scores[~filtered_small]
|
| 994 |
+
cur_labels = cur_labels[~filtered_small]
|
| 995 |
+
cur_masks = cur_masks[~filtered_small]
|
| 996 |
+
area, seg_img = get_ids_area(cur_masks, cur_scores)
|
| 997 |
+
else:
|
| 998 |
+
break
|
| 999 |
+
|
| 1000 |
+
else:
|
| 1001 |
+
cur_labels = torch.ones(1, dtype=torch.long, device=cur_labels.device)
|
| 1002 |
+
|
| 1003 |
+
segments_info = []
|
| 1004 |
+
for i, a in enumerate(area):
|
| 1005 |
+
cat = cur_labels[i].item()
|
| 1006 |
+
segments_info.append({"id": i, "isthing": is_thing_map[cat], "category_id": cat, "area": a})
|
| 1007 |
+
del cur_labels
|
| 1008 |
+
|
| 1009 |
+
with io.BytesIO() as out:
|
| 1010 |
+
seg_img.save(out, format="PNG")
|
| 1011 |
+
predictions = {"png_string": out.getvalue(), "segments_info": segments_info}
|
| 1012 |
+
preds.append(predictions)
|
| 1013 |
+
return preds
|
| 1014 |
+
|
| 1015 |
+
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process_object_detection
|
| 1016 |
+
def post_process_object_detection(
|
| 1017 |
+
self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, list[tuple]] = None
|
| 1018 |
+
):
|
| 1019 |
+
"""
|
| 1020 |
+
Converts the raw output of [`DetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
|
| 1021 |
+
bottom_right_x, bottom_right_y) format. Only supports PyTorch.
|
| 1022 |
+
|
| 1023 |
+
Args:
|
| 1024 |
+
outputs ([`DetrObjectDetectionOutput`]):
|
| 1025 |
+
Raw outputs of the model.
|
| 1026 |
+
threshold (`float`, *optional*):
|
| 1027 |
+
Score threshold to keep object detection predictions.
|
| 1028 |
+
target_sizes (`torch.Tensor` or `list[tuple[int, int]]`, *optional*):
|
| 1029 |
+
Tensor of shape `(batch_size, 2)` or list of tuples (`tuple[int, int]`) containing the target size
|
| 1030 |
+
`(height, width)` of each image in the batch. If unset, predictions will not be resized.
|
| 1031 |
+
Returns:
|
| 1032 |
+
`list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
|
| 1033 |
+
in the batch as predicted by the model.
|
| 1034 |
+
"""
|
| 1035 |
+
out_logits, out_bbox = outputs.logits, outputs.pred_boxes
|
| 1036 |
+
|
| 1037 |
+
if target_sizes is not None:
|
| 1038 |
+
if len(out_logits) != len(target_sizes):
|
| 1039 |
+
raise ValueError(
|
| 1040 |
+
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
|
| 1041 |
+
)
|
| 1042 |
+
|
| 1043 |
+
prob = nn.functional.softmax(out_logits, -1)
|
| 1044 |
+
scores, labels = prob[..., :-1].max(-1)
|
| 1045 |
+
|
| 1046 |
+
# Convert to [x0, y0, x1, y1] format
|
| 1047 |
+
boxes = center_to_corners_format(out_bbox)
|
| 1048 |
+
|
| 1049 |
+
# Convert from relative [0, 1] to absolute [0, height] coordinates
|
| 1050 |
+
if target_sizes is not None:
|
| 1051 |
+
if isinstance(target_sizes, list):
|
| 1052 |
+
img_h = torch.Tensor([i[0] for i in target_sizes])
|
| 1053 |
+
img_w = torch.Tensor([i[1] for i in target_sizes])
|
| 1054 |
+
else:
|
| 1055 |
+
img_h, img_w = target_sizes.unbind(1)
|
| 1056 |
+
|
| 1057 |
+
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
|
| 1058 |
+
boxes = boxes * scale_fct[:, None, :]
|
| 1059 |
+
|
| 1060 |
+
results = []
|
| 1061 |
+
for s, l, b in zip(scores, labels, boxes):
|
| 1062 |
+
score = s[s > threshold]
|
| 1063 |
+
label = l[s > threshold]
|
| 1064 |
+
box = b[s > threshold]
|
| 1065 |
+
results.append({"scores": score, "labels": label, "boxes": box})
|
| 1066 |
+
|
| 1067 |
+
return results
|
| 1068 |
+
|
| 1069 |
+
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process_semantic_segmentation
|
| 1070 |
+
def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple[int, int]]] = None):
|
| 1071 |
+
"""
|
| 1072 |
+
Converts the output of [`DetrForSegmentation`] into semantic segmentation maps. Only supports PyTorch.
|
| 1073 |
+
|
| 1074 |
+
Args:
|
| 1075 |
+
outputs ([`DetrForSegmentation`]):
|
| 1076 |
+
Raw outputs of the model.
|
| 1077 |
+
target_sizes (`list[tuple[int, int]]`, *optional*):
|
| 1078 |
+
A list of tuples (`tuple[int, int]`) containing the target size (height, width) of each image in the
|
| 1079 |
+
batch. If unset, predictions will not be resized.
|
| 1080 |
+
Returns:
|
| 1081 |
+
`list[torch.Tensor]`:
|
| 1082 |
+
A list of length `batch_size`, where each item is a semantic segmentation map of shape (height, width)
|
| 1083 |
+
corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each
|
| 1084 |
+
`torch.Tensor` correspond to a semantic class id.
|
| 1085 |
+
"""
|
| 1086 |
+
class_queries_logits = outputs.logits # [batch_size, num_queries, num_classes+1]
|
| 1087 |
+
masks_queries_logits = outputs.pred_masks # [batch_size, num_queries, height, width]
|
| 1088 |
+
|
| 1089 |
+
# Remove the null class `[..., :-1]`
|
| 1090 |
+
masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1]
|
| 1091 |
+
masks_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width]
|
| 1092 |
+
|
| 1093 |
+
# Semantic segmentation logits of shape (batch_size, num_classes, height, width)
|
| 1094 |
+
segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
|
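# Note (added for clarity; not part of the upstream file): the einsum above computes
# segmentation[b, c, h, w] = sum_q masks_classes[b, q, c] * masks_probs[b, q, h, w], i.e. each
# query's mask probabilities weighted by that query's class probabilities; the per-pixel argmax
# over c further below then yields the semantic class id.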
| 1095 |
+
batch_size = class_queries_logits.shape[0]
|
| 1096 |
+
|
| 1097 |
+
# Resize logits and compute semantic segmentation maps
|
| 1098 |
+
if target_sizes is not None:
|
| 1099 |
+
if batch_size != len(target_sizes):
|
| 1100 |
+
raise ValueError(
|
| 1101 |
+
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
|
| 1102 |
+
)
|
| 1103 |
+
|
| 1104 |
+
semantic_segmentation = []
|
| 1105 |
+
for idx in range(batch_size):
|
| 1106 |
+
resized_logits = nn.functional.interpolate(
|
| 1107 |
+
segmentation[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
|
| 1108 |
+
)
|
| 1109 |
+
semantic_map = resized_logits[0].argmax(dim=0)
|
| 1110 |
+
semantic_segmentation.append(semantic_map)
|
| 1111 |
+
else:
|
| 1112 |
+
semantic_segmentation = segmentation.argmax(dim=1)
|
| 1113 |
+
semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
|
| 1114 |
+
|
| 1115 |
+
return semantic_segmentation
|
| 1116 |
+
|
| 1117 |
+
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process_instance_segmentation
|
| 1118 |
+
def post_process_instance_segmentation(
|
| 1119 |
+
self,
|
| 1120 |
+
outputs,
|
| 1121 |
+
threshold: float = 0.5,
|
| 1122 |
+
mask_threshold: float = 0.5,
|
| 1123 |
+
overlap_mask_area_threshold: float = 0.8,
|
| 1124 |
+
target_sizes: Optional[list[tuple[int, int]]] = None,
|
| 1125 |
+
return_coco_annotation: Optional[bool] = False,
|
| 1126 |
+
) -> list[dict]:
|
| 1127 |
+
"""
|
| 1128 |
+
Converts the output of [`DetrForSegmentation`] into instance segmentation predictions. Only supports PyTorch.
|
| 1129 |
+
|
| 1130 |
+
Args:
|
| 1131 |
+
outputs ([`DetrForSegmentation`]):
|
| 1132 |
+
Raw outputs of the model.
|
| 1133 |
+
threshold (`float`, *optional*, defaults to 0.5):
|
| 1134 |
+
The probability score threshold to keep predicted instance masks.
|
| 1135 |
+
mask_threshold (`float`, *optional*, defaults to 0.5):
|
| 1136 |
+
Threshold to use when turning the predicted masks into binary values.
|
| 1137 |
+
overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
|
| 1138 |
+
The overlap mask area threshold to merge or discard small disconnected parts within each binary
|
| 1139 |
+
instance mask.
|
| 1140 |
+
target_sizes (`list[Tuple]`, *optional*):
|
| 1141 |
+
List of length (batch_size), where each list item (`tuple[int, int]]`) corresponds to the requested
|
| 1142 |
+
final size (height, width) of each prediction. If unset, predictions will not be resized.
|
| 1143 |
+
return_coco_annotation (`bool`, *optional*):
|
| 1144 |
+
Defaults to `False`. If set to `True`, segmentation maps are returned in COCO run-length encoding (RLE)
|
| 1145 |
+
format.
|
| 1146 |
+
Returns:
|
| 1147 |
+
`list[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
|
| 1148 |
+
- **segmentation** -- A tensor of shape `(height, width)` where each pixel represents a `segment_id` or
|
| 1149 |
+
`list[List]` run-length encoding (RLE) of the segmentation map if return_coco_annotation is set to
|
| 1150 |
+
`True`. Set to `None` if no mask is found above `threshold`.
|
| 1151 |
+
- **segments_info** -- A dictionary that contains additional information on each segment.
|
| 1152 |
+
- **id** -- An integer representing the `segment_id`.
|
| 1153 |
+
- **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
|
| 1154 |
+
- **score** -- Prediction score of segment with `segment_id`.
|
| 1155 |
+
"""
|
| 1156 |
+
class_queries_logits = outputs.logits # [batch_size, num_queries, num_classes+1]
|
| 1157 |
+
masks_queries_logits = outputs.pred_masks # [batch_size, num_queries, height, width]
|
| 1158 |
+
|
| 1159 |
+
batch_size = class_queries_logits.shape[0]
|
| 1160 |
+
num_labels = class_queries_logits.shape[-1] - 1
|
| 1161 |
+
|
| 1162 |
+
mask_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width]
|
| 1163 |
+
|
| 1164 |
+
# Predicted label and score of each query (batch_size, num_queries)
|
| 1165 |
+
pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1)
|
| 1166 |
+
|
| 1167 |
+
# Loop over items in batch size
|
| 1168 |
+
results: list[dict[str, TensorType]] = []
|
| 1169 |
+
|
| 1170 |
+
for i in range(batch_size):
|
| 1171 |
+
mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects(
|
| 1172 |
+
mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels
|
| 1173 |
+
)
|
| 1174 |
+
|
| 1175 |
+
# No mask found
|
| 1176 |
+
if mask_probs_item.shape[0] <= 0:
|
| 1177 |
+
height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:]
|
| 1178 |
+
segmentation = torch.zeros((height, width)) - 1
|
| 1179 |
+
results.append({"segmentation": segmentation, "segments_info": []})
|
| 1180 |
+
continue
|
| 1181 |
+
|
| 1182 |
+
# Get segmentation map and segment information of batch item
|
| 1183 |
+
target_size = target_sizes[i] if target_sizes is not None else None
|
| 1184 |
+
segmentation, segments = compute_segments(
|
| 1185 |
+
mask_probs=mask_probs_item,
|
| 1186 |
+
pred_scores=pred_scores_item,
|
| 1187 |
+
pred_labels=pred_labels_item,
|
| 1188 |
+
mask_threshold=mask_threshold,
|
| 1189 |
+
overlap_mask_area_threshold=overlap_mask_area_threshold,
|
| 1190 |
+
label_ids_to_fuse=[],
|
| 1191 |
+
target_size=target_size,
|
| 1192 |
+
)
|
| 1193 |
+
|
| 1194 |
+
# Return segmentation map in run-length encoding (RLE) format
|
| 1195 |
+
if return_coco_annotation:
|
| 1196 |
+
segmentation = convert_segmentation_to_rle(segmentation)
|
| 1197 |
+
|
| 1198 |
+
results.append({"segmentation": segmentation, "segments_info": segments})
|
| 1199 |
+
return results
|
| 1200 |
+
|
| 1201 |
+
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.post_process_panoptic_segmentation
|
| 1202 |
+
def post_process_panoptic_segmentation(
|
| 1203 |
+
self,
|
| 1204 |
+
outputs,
|
| 1205 |
+
threshold: float = 0.5,
|
| 1206 |
+
mask_threshold: float = 0.5,
|
| 1207 |
+
overlap_mask_area_threshold: float = 0.8,
|
| 1208 |
+
label_ids_to_fuse: Optional[set[int]] = None,
|
| 1209 |
+
target_sizes: Optional[list[tuple[int, int]]] = None,
|
| 1210 |
+
) -> list[dict]:
|
| 1211 |
+
"""
|
| 1212 |
+
Converts the output of [`DetrForSegmentation`] into image panoptic segmentation predictions. Only supports
|
| 1213 |
+
PyTorch.
|
| 1214 |
+
|
| 1215 |
+
Args:
|
| 1216 |
+
outputs ([`DetrForSegmentation`]):
|
| 1217 |
+
The outputs from [`DetrForSegmentation`].
|
| 1218 |
+
threshold (`float`, *optional*, defaults to 0.5):
|
| 1219 |
+
The probability score threshold to keep predicted instance masks.
|
| 1220 |
+
mask_threshold (`float`, *optional*, defaults to 0.5):
|
| 1221 |
+
Threshold to use when turning the predicted masks into binary values.
|
| 1222 |
+
overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
|
| 1223 |
+
The overlap mask area threshold to merge or discard small disconnected parts within each binary
|
| 1224 |
+
instance mask.
|
| 1225 |
+
label_ids_to_fuse (`Set[int]`, *optional*):
|
| 1226 |
+
The labels in this set will have all their instances fused together. For instance, we could say
|
| 1227 |
+
there can only be one sky in an image, but several persons, so the label ID for sky would be in that
|
| 1228 |
+
set, but not the one for person.
|
| 1229 |
+
target_sizes (`list[Tuple]`, *optional*):
|
| 1230 |
+
List of length (batch_size), where each list item (`tuple[int, int]]`) corresponds to the requested
|
| 1231 |
+
final size (height, width) of each prediction in batch. If unset, predictions will not be resized.
|
| 1232 |
+
Returns:
|
| 1233 |
+
`list[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
|
| 1234 |
+
- **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id` or
|
| 1235 |
+
`None` if no mask is found above `threshold`. If `target_sizes` is specified, segmentation is resized to
|
| 1236 |
+
the corresponding `target_sizes` entry.
|
| 1237 |
+
- **segments_info** -- A dictionary that contains additional information on each segment.
|
| 1238 |
+
- **id** -- an integer representing the `segment_id`.
|
| 1239 |
+
- **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
|
| 1240 |
+
- **was_fused** -- a boolean, `True` if `label_id` was in `label_ids_to_fuse`, `False` otherwise.
|
| 1241 |
+
Multiple instances of the same class / label were fused and assigned a single `segment_id`.
|
| 1242 |
+
- **score** -- Prediction score of segment with `segment_id`.
|
| 1243 |
+
"""
|
| 1244 |
+
|
| 1245 |
+
if label_ids_to_fuse is None:
|
| 1246 |
+
logger.warning_once("`label_ids_to_fuse` unset. No instance will be fused.")
|
| 1247 |
+
label_ids_to_fuse = set()
|
| 1248 |
+
|
| 1249 |
+
class_queries_logits = outputs.logits # [batch_size, num_queries, num_classes+1]
|
| 1250 |
+
masks_queries_logits = outputs.pred_masks # [batch_size, num_queries, height, width]
|
| 1251 |
+
|
| 1252 |
+
batch_size = class_queries_logits.shape[0]
|
| 1253 |
+
num_labels = class_queries_logits.shape[-1] - 1
|
| 1254 |
+
|
| 1255 |
+
mask_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width]
|
| 1256 |
+
|
| 1257 |
+
# Predicted label and score of each query (batch_size, num_queries)
|
| 1258 |
+
pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1)
|
| 1259 |
+
|
| 1260 |
+
# Loop over items in batch size
|
| 1261 |
+
results: list[dict[str, TensorType]] = []
|
| 1262 |
+
|
| 1263 |
+
for i in range(batch_size):
|
| 1264 |
+
mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects(
|
| 1265 |
+
mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels
|
| 1266 |
+
)
|
| 1267 |
+
|
| 1268 |
+
# No mask found
|
| 1269 |
+
if mask_probs_item.shape[0] <= 0:
|
| 1270 |
+
height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:]
|
| 1271 |
+
segmentation = torch.zeros((height, width)) - 1
|
| 1272 |
+
results.append({"segmentation": segmentation, "segments_info": []})
|
| 1273 |
+
continue
|
| 1274 |
+
|
| 1275 |
+
# Get segmentation map and segment information of batch item
|
| 1276 |
+
target_size = target_sizes[i] if target_sizes is not None else None
|
| 1277 |
+
segmentation, segments = compute_segments(
|
| 1278 |
+
mask_probs=mask_probs_item,
|
| 1279 |
+
pred_scores=pred_scores_item,
|
| 1280 |
+
pred_labels=pred_labels_item,
|
| 1281 |
+
mask_threshold=mask_threshold,
|
| 1282 |
+
overlap_mask_area_threshold=overlap_mask_area_threshold,
|
| 1283 |
+
label_ids_to_fuse=label_ids_to_fuse,
|
| 1284 |
+
target_size=target_size,
|
| 1285 |
+
)
|
| 1286 |
+
|
| 1287 |
+
results.append({"segmentation": segmentation, "segments_info": segments})
|
| 1288 |
+
return results
|
| 1289 |
+
|
| 1290 |
+
|
| 1291 |
+
__all__ = ["DetrImageProcessorFast"]
|
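For orientation, here is a minimal, hypothetical usage sketch of the post-processing API listed above (it is not part of the added file); the checkpoint name, image path, and score threshold are illustrative assumptions.

import torch
from PIL import Image
from transformers import AutoImageProcessor, DetrForObjectDetection

# Illustrative inputs: a public DETR checkpoint and a local image file (both assumed).
image = Image.open("example.jpg")
processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# target_sizes carries (height, width) per image so boxes come back in absolute pixel coordinates.
target_sizes = torch.tensor([image.size[::-1]])
results = processor.post_process_object_detection(outputs, threshold=0.5, target_sizes=target_sizes)[0]
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
    print(model.config.id2label[label.item()], round(score.item(), 3), [round(v, 1) for v in box.tolist()])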
phivenv/Lib/site-packages/transformers/models/detr/modeling_detr.py
ADDED
|
@@ -0,0 +1,1693 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2021 Facebook AI Research The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch DETR model."""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import Optional, Union
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
from torch import Tensor, nn
|
| 23 |
+
|
| 24 |
+
from ...activations import ACT2FN
|
| 25 |
+
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
|
| 26 |
+
from ...modeling_layers import GradientCheckpointingLayer
|
| 27 |
+
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput
|
| 28 |
+
from ...modeling_utils import PreTrainedModel
|
| 29 |
+
from ...utils import (
|
| 30 |
+
ModelOutput,
|
| 31 |
+
auto_docstring,
|
| 32 |
+
is_timm_available,
|
| 33 |
+
logging,
|
| 34 |
+
requires_backends,
|
| 35 |
+
)
|
| 36 |
+
from ...utils.backbone_utils import load_backbone
|
| 37 |
+
from .configuration_detr import DetrConfig
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
if is_timm_available():
|
| 41 |
+
from timm import create_model
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
logger = logging.get_logger(__name__)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@dataclass
|
| 48 |
+
@auto_docstring(
|
| 49 |
+
custom_intro="""
|
| 50 |
+
Base class for outputs of the DETR decoder. This class adds one attribute to BaseModelOutputWithCrossAttentions,
|
| 51 |
+
namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them
|
| 52 |
+
gone through a layernorm. This is useful when training the model with auxiliary decoding losses.
|
| 53 |
+
"""
|
| 54 |
+
)
|
| 55 |
+
class DetrDecoderOutput(BaseModelOutputWithCrossAttentions):
|
| 56 |
+
r"""
|
| 57 |
+
cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
|
| 58 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 59 |
+
sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
|
| 60 |
+
used to compute the weighted average in the cross-attention heads.
|
| 61 |
+
intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
|
| 62 |
+
Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
|
| 63 |
+
layernorm.
|
| 64 |
+
"""
|
| 65 |
+
|
| 66 |
+
intermediate_hidden_states: Optional[torch.FloatTensor] = None
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
@dataclass
|
| 70 |
+
@auto_docstring(
|
| 71 |
+
custom_intro="""
|
| 72 |
+
Base class for outputs of the DETR encoder-decoder model. This class adds one attribute to Seq2SeqModelOutput,
|
| 73 |
+
namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them
|
| 74 |
+
gone through a layernorm. This is useful when training the model with auxiliary decoding losses.
|
| 75 |
+
"""
|
| 76 |
+
)
|
| 77 |
+
class DetrModelOutput(Seq2SeqModelOutput):
|
| 78 |
+
r"""
|
| 79 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
| 80 |
+
Sequence of hidden-states at the output of the last layer of the decoder of the model.
|
| 81 |
+
intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, sequence_length, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
|
| 82 |
+
Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
|
| 83 |
+
layernorm.
|
| 84 |
+
"""
|
| 85 |
+
|
| 86 |
+
intermediate_hidden_states: Optional[torch.FloatTensor] = None
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
@dataclass
|
| 90 |
+
@auto_docstring(
|
| 91 |
+
custom_intro="""
|
| 92 |
+
Output type of [`DetrForObjectDetection`].
|
| 93 |
+
"""
|
| 94 |
+
)
|
| 95 |
+
class DetrObjectDetectionOutput(ModelOutput):
|
| 96 |
+
r"""
|
| 97 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)):
|
| 98 |
+
Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
|
| 99 |
+
bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
|
| 100 |
+
scale-invariant IoU loss.
|
| 101 |
+
loss_dict (`Dict`, *optional*):
|
| 102 |
+
A dictionary containing the individual losses. Useful for logging.
|
| 103 |
+
logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
|
| 104 |
+
Classification logits (including no-object) for all queries.
|
| 105 |
+
pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
|
| 106 |
+
Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
|
| 107 |
+
values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
|
| 108 |
+
possible padding). You can use [`~DetrImageProcessor.post_process_object_detection`] to retrieve the
|
| 109 |
+
unnormalized bounding boxes.
|
| 110 |
+
auxiliary_outputs (`list[Dict]`, *optional*):
|
| 111 |
+
Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
|
| 112 |
+
and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
|
| 113 |
+
`pred_boxes`) for each decoder layer.
|
| 114 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 115 |
+
Sequence of hidden-states at the output of the last layer of the decoder of the model.
|
| 116 |
+
"""
|
| 117 |
+
|
| 118 |
+
loss: Optional[torch.FloatTensor] = None
|
| 119 |
+
loss_dict: Optional[dict] = None
|
| 120 |
+
logits: Optional[torch.FloatTensor] = None
|
| 121 |
+
pred_boxes: Optional[torch.FloatTensor] = None
|
| 122 |
+
auxiliary_outputs: Optional[list[dict]] = None
|
| 123 |
+
last_hidden_state: Optional[torch.FloatTensor] = None
|
| 124 |
+
decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
|
| 125 |
+
decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
|
| 126 |
+
cross_attentions: Optional[tuple[torch.FloatTensor]] = None
|
| 127 |
+
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
|
| 128 |
+
encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
|
| 129 |
+
encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
@dataclass
|
| 133 |
+
@auto_docstring(
|
| 134 |
+
custom_intro="""
|
| 135 |
+
Output type of [`DetrForSegmentation`].
|
| 136 |
+
"""
|
| 137 |
+
)
|
| 138 |
+
class DetrSegmentationOutput(ModelOutput):
|
| 139 |
+
r"""
|
| 140 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)):
|
| 141 |
+
Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
|
| 142 |
+
bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
|
| 143 |
+
scale-invariant IoU loss.
|
| 144 |
+
loss_dict (`Dict`, *optional*):
|
| 145 |
+
A dictionary containing the individual losses. Useful for logging.
|
| 146 |
+
logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
|
| 147 |
+
Classification logits (including no-object) for all queries.
|
| 148 |
+
pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
|
| 149 |
+
Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
|
| 150 |
+
values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
|
| 151 |
+
possible padding). You can use [`~DetrImageProcessor.post_process_object_detection`] to retrieve the
|
| 152 |
+
unnormalized bounding boxes.
|
| 153 |
+
pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height/4, width/4)`):
|
| 154 |
+
Segmentation masks logits for all queries. See also
|
| 155 |
+
[`~DetrImageProcessor.post_process_semantic_segmentation`] or
|
| 156 |
+
[`~DetrImageProcessor.post_process_instance_segmentation`]
|
| 157 |
+
[`~DetrImageProcessor.post_process_panoptic_segmentation`] to evaluate semantic, instance and panoptic
|
| 158 |
+
segmentation masks respectively.
|
| 159 |
+
auxiliary_outputs (`list[Dict]`, *optional*):
|
| 160 |
+
Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
|
| 161 |
+
and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
|
| 162 |
+
`pred_boxes`) for each decoder layer.
|
| 163 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 164 |
+
Sequence of hidden-states at the output of the last layer of the decoder of the model.
|
| 165 |
+
"""
|
| 166 |
+
|
| 167 |
+
loss: Optional[torch.FloatTensor] = None
|
| 168 |
+
loss_dict: Optional[dict] = None
|
| 169 |
+
logits: Optional[torch.FloatTensor] = None
|
| 170 |
+
pred_boxes: Optional[torch.FloatTensor] = None
|
| 171 |
+
pred_masks: Optional[torch.FloatTensor] = None
|
| 172 |
+
auxiliary_outputs: Optional[list[dict]] = None
|
| 173 |
+
last_hidden_state: Optional[torch.FloatTensor] = None
|
| 174 |
+
decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
|
| 175 |
+
decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
|
| 176 |
+
cross_attentions: Optional[tuple[torch.FloatTensor]] = None
|
| 177 |
+
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
|
| 178 |
+
encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
|
| 179 |
+
encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
# BELOW: utilities copied from
|
| 183 |
+
# https://github.com/facebookresearch/detr/blob/master/backbone.py
|
| 184 |
+
class DetrFrozenBatchNorm2d(nn.Module):
|
| 185 |
+
"""
|
| 186 |
+
BatchNorm2d where the batch statistics and the affine parameters are fixed.
|
| 187 |
+
|
| 188 |
+
Copy-paste from torchvision.misc.ops with added eps before rsqrt, without which any other models than
|
| 189 |
+
torchvision.models.resnet[18,34,50,101] produce nans.
|
| 190 |
+
"""
|
| 191 |
+
|
| 192 |
+
def __init__(self, n):
|
| 193 |
+
super().__init__()
|
| 194 |
+
self.register_buffer("weight", torch.ones(n))
|
| 195 |
+
self.register_buffer("bias", torch.zeros(n))
|
| 196 |
+
self.register_buffer("running_mean", torch.zeros(n))
|
| 197 |
+
self.register_buffer("running_var", torch.ones(n))
|
| 198 |
+
|
| 199 |
+
def _load_from_state_dict(
|
| 200 |
+
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
| 201 |
+
):
|
| 202 |
+
num_batches_tracked_key = prefix + "num_batches_tracked"
|
| 203 |
+
if num_batches_tracked_key in state_dict:
|
| 204 |
+
del state_dict[num_batches_tracked_key]
|
| 205 |
+
|
| 206 |
+
super()._load_from_state_dict(
|
| 207 |
+
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
def forward(self, x):
|
| 211 |
+
# move reshapes to the beginning
|
| 212 |
+
# to make it user-friendly
|
| 213 |
+
weight = self.weight.reshape(1, -1, 1, 1)
|
| 214 |
+
bias = self.bias.reshape(1, -1, 1, 1)
|
| 215 |
+
running_var = self.running_var.reshape(1, -1, 1, 1)
|
| 216 |
+
running_mean = self.running_mean.reshape(1, -1, 1, 1)
|
| 217 |
+
epsilon = 1e-5
|
| 218 |
+
scale = weight * (running_var + epsilon).rsqrt()
|
| 219 |
+
bias = bias - running_mean * scale
|
| 220 |
+
return x * scale + bias
|
| 221 |
+
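# Note (added for clarity; not part of the upstream file): the affine form returned above is
# algebraically the standard batch-norm transform
#     y = weight * (x - running_mean) / sqrt(running_var + epsilon) + bias,
# rewritten so that a single `scale` and shifted `bias` are applied, with the registered buffers
# acting as frozen statistics rather than values updated during training.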
|
| 222 |
+
|
| 223 |
+
def replace_batch_norm(model):
|
| 224 |
+
r"""
|
| 225 |
+
Recursively replace all `torch.nn.BatchNorm2d` with `DetrFrozenBatchNorm2d`.
|
| 226 |
+
|
| 227 |
+
Args:
|
| 228 |
+
model (torch.nn.Module):
|
| 229 |
+
input model
|
| 230 |
+
"""
|
| 231 |
+
for name, module in model.named_children():
|
| 232 |
+
if isinstance(module, nn.BatchNorm2d):
|
| 233 |
+
new_module = DetrFrozenBatchNorm2d(module.num_features)
|
| 234 |
+
|
| 235 |
+
if module.weight.device != torch.device("meta"):
|
| 236 |
+
new_module.weight.data.copy_(module.weight)
|
| 237 |
+
new_module.bias.data.copy_(module.bias)
|
| 238 |
+
new_module.running_mean.data.copy_(module.running_mean)
|
| 239 |
+
new_module.running_var.data.copy_(module.running_var)
|
| 240 |
+
|
| 241 |
+
model._modules[name] = new_module
|
| 242 |
+
|
| 243 |
+
if len(list(module.children())) > 0:
|
| 244 |
+
replace_batch_norm(module)
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
class DetrConvEncoder(nn.Module):
|
| 248 |
+
"""
|
| 249 |
+
Convolutional backbone, using either the AutoBackbone API or one from the timm library.
|
| 250 |
+
|
| 251 |
+
nn.BatchNorm2d layers are replaced by DetrFrozenBatchNorm2d as defined above.
|
| 252 |
+
|
| 253 |
+
"""
|
| 254 |
+
|
| 255 |
+
def __init__(self, config):
|
| 256 |
+
super().__init__()
|
| 257 |
+
|
| 258 |
+
self.config = config
|
| 259 |
+
|
| 260 |
+
# For backwards compatibility we have to use the timm library directly instead of the AutoBackbone API
|
| 261 |
+
if config.use_timm_backbone:
|
| 262 |
+
# We default to values which were previously hard-coded. This enables configurability from the config
|
| 263 |
+
# using backbone arguments, while keeping the default behavior the same.
|
| 264 |
+
requires_backends(self, ["timm"])
|
| 265 |
+
kwargs = getattr(config, "backbone_kwargs", {})
|
| 266 |
+
kwargs = {} if kwargs is None else kwargs.copy()
|
| 267 |
+
out_indices = kwargs.pop("out_indices", (1, 2, 3, 4))
|
| 268 |
+
num_channels = kwargs.pop("in_chans", config.num_channels)
|
| 269 |
+
if config.dilation:
|
| 270 |
+
kwargs["output_stride"] = kwargs.get("output_stride", 16)
|
| 271 |
+
backbone = create_model(
|
| 272 |
+
config.backbone,
|
| 273 |
+
pretrained=config.use_pretrained_backbone,
|
| 274 |
+
features_only=True,
|
| 275 |
+
out_indices=out_indices,
|
| 276 |
+
in_chans=num_channels,
|
| 277 |
+
**kwargs,
|
| 278 |
+
)
|
| 279 |
+
else:
|
| 280 |
+
backbone = load_backbone(config)
|
| 281 |
+
|
| 282 |
+
# replace batch norm by frozen batch norm
|
| 283 |
+
with torch.no_grad():
|
| 284 |
+
replace_batch_norm(backbone)
|
| 285 |
+
self.model = backbone
|
| 286 |
+
self.intermediate_channel_sizes = (
|
| 287 |
+
self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
backbone_model_type = None
|
| 291 |
+
if config.backbone is not None:
|
| 292 |
+
backbone_model_type = config.backbone
|
| 293 |
+
elif config.backbone_config is not None:
|
| 294 |
+
backbone_model_type = config.backbone_config.model_type
|
| 295 |
+
else:
|
| 296 |
+
raise ValueError("Either `backbone` or `backbone_config` should be provided in the config")
|
| 297 |
+
|
| 298 |
+
if "resnet" in backbone_model_type:
|
| 299 |
+
for name, parameter in self.model.named_parameters():
|
| 300 |
+
if config.use_timm_backbone:
|
| 301 |
+
if "layer2" not in name and "layer3" not in name and "layer4" not in name:
|
| 302 |
+
parameter.requires_grad_(False)
|
| 303 |
+
else:
|
| 304 |
+
if "stage.1" not in name and "stage.2" not in name and "stage.3" not in name:
|
| 305 |
+
parameter.requires_grad_(False)
|
| 306 |
+
|
| 307 |
+
def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
|
| 308 |
+
# send pixel_values through the model to get list of feature maps
|
| 309 |
+
features = self.model(pixel_values) if self.config.use_timm_backbone else self.model(pixel_values).feature_maps
|
| 310 |
+
|
| 311 |
+
out = []
|
| 312 |
+
for feature_map in features:
|
| 313 |
+
# downsample pixel_mask to match shape of corresponding feature_map
|
| 314 |
+
mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0]
|
| 315 |
+
out.append((feature_map, mask))
|
| 316 |
+
return out
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
class DetrConvModel(nn.Module):
|
| 320 |
+
"""
|
| 321 |
+
This module adds 2D position embeddings to all intermediate feature maps of the convolutional encoder.
|
| 322 |
+
"""
|
| 323 |
+
|
| 324 |
+
def __init__(self, conv_encoder, position_embedding):
|
| 325 |
+
super().__init__()
|
| 326 |
+
self.conv_encoder = conv_encoder
|
| 327 |
+
self.position_embedding = position_embedding
|
| 328 |
+
|
| 329 |
+
def forward(self, pixel_values, pixel_mask):
|
| 330 |
+
# send pixel_values and pixel_mask through backbone to get list of (feature_map, pixel_mask) tuples
|
| 331 |
+
out = self.conv_encoder(pixel_values, pixel_mask)
|
| 332 |
+
pos = []
|
| 333 |
+
for feature_map, mask in out:
|
| 334 |
+
# position encoding
|
| 335 |
+
pos.append(self.position_embedding(feature_map, mask).to(feature_map.dtype))
|
| 336 |
+
|
| 337 |
+
return out, pos
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
class DetrSinePositionEmbedding(nn.Module):
|
| 341 |
+
"""
|
| 342 |
+
This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
|
| 343 |
+
need paper, generalized to work on images.
|
| 344 |
+
"""
|
| 345 |
+
|
| 346 |
+
def __init__(self, embedding_dim=64, temperature=10000, normalize=False, scale=None):
|
| 347 |
+
super().__init__()
|
| 348 |
+
self.embedding_dim = embedding_dim
|
| 349 |
+
self.temperature = temperature
|
| 350 |
+
self.normalize = normalize
|
| 351 |
+
if scale is not None and normalize is False:
|
| 352 |
+
raise ValueError("normalize should be True if scale is passed")
|
| 353 |
+
if scale is None:
|
| 354 |
+
scale = 2 * math.pi
|
| 355 |
+
self.scale = scale
|
| 356 |
+
|
| 357 |
+
def forward(self, pixel_values, pixel_mask):
|
| 358 |
+
if pixel_mask is None:
|
| 359 |
+
raise ValueError("No pixel mask provided")
|
| 360 |
+
y_embed = pixel_mask.cumsum(1, dtype=torch.float32)
|
| 361 |
+
x_embed = pixel_mask.cumsum(2, dtype=torch.float32)
|
| 362 |
+
if self.normalize:
|
| 363 |
+
y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale
|
| 364 |
+
x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale
|
| 365 |
+
|
| 366 |
+
dim_t = torch.arange(self.embedding_dim, dtype=torch.int64, device=pixel_values.device).float()
|
| 367 |
+
dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim)
|
| 368 |
+
|
| 369 |
+
pos_x = x_embed[:, :, :, None] / dim_t
|
| 370 |
+
pos_y = y_embed[:, :, :, None] / dim_t
|
| 371 |
+
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
| 372 |
+
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
| 373 |
+
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
| 374 |
+
return pos
|
| 375 |
+
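# Note (added for clarity; not part of the upstream file): with d = embedding_dim and T = temperature,
# each spatial axis above is encoded as pos[..., 2i] = sin(coord / T^(2i/d)) and
# pos[..., 2i+1] = cos(coord / T^(2i/d)), i.e. the sinusoidal scheme of "Attention Is All You Need"
# applied to the (optionally normalized) cumulative y and x coordinates of the pixel mask, then
# concatenated channel-wise into a (batch, 2*d, height, width) position map.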
|
| 376 |
+
|
| 377 |
+
class DetrLearnedPositionEmbedding(nn.Module):
|
| 378 |
+
"""
|
| 379 |
+
This module learns positional embeddings up to a fixed maximum size.
|
| 380 |
+
"""
|
| 381 |
+
|
| 382 |
+
def __init__(self, embedding_dim=256):
|
| 383 |
+
super().__init__()
|
| 384 |
+
self.row_embeddings = nn.Embedding(50, embedding_dim)
|
| 385 |
+
self.column_embeddings = nn.Embedding(50, embedding_dim)
|
| 386 |
+
|
| 387 |
+
def forward(self, pixel_values, pixel_mask=None):
|
| 388 |
+
height, width = pixel_values.shape[-2:]
|
| 389 |
+
width_values = torch.arange(width, device=pixel_values.device)
|
| 390 |
+
height_values = torch.arange(height, device=pixel_values.device)
|
| 391 |
+
x_emb = self.column_embeddings(width_values)
|
| 392 |
+
y_emb = self.row_embeddings(height_values)
|
| 393 |
+
pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1)
|
| 394 |
+
pos = pos.permute(2, 0, 1)
|
| 395 |
+
pos = pos.unsqueeze(0)
|
| 396 |
+
pos = pos.repeat(pixel_values.shape[0], 1, 1, 1)
|
| 397 |
+
return pos
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
def build_position_encoding(config):
|
| 401 |
+
n_steps = config.d_model // 2
|
| 402 |
+
if config.position_embedding_type == "sine":
|
| 403 |
+
# TODO find a better way of exposing other arguments
|
| 404 |
+
position_embedding = DetrSinePositionEmbedding(n_steps, normalize=True)
|
| 405 |
+
elif config.position_embedding_type == "learned":
|
| 406 |
+
position_embedding = DetrLearnedPositionEmbedding(n_steps)
|
| 407 |
+
else:
|
| 408 |
+
raise ValueError(f"Not supported {config.position_embedding_type}")
|
| 409 |
+
|
| 410 |
+
return position_embedding
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
class DetrAttention(nn.Module):
|
| 414 |
+
"""
|
| 415 |
+
Multi-headed attention from 'Attention Is All You Need' paper.
|
| 416 |
+
|
| 417 |
+
Here, we add position embeddings to the queries and keys (as explained in the DETR paper).
|
| 418 |
+
"""
|
| 419 |
+
|
| 420 |
+
def __init__(
|
| 421 |
+
self,
|
| 422 |
+
embed_dim: int,
|
| 423 |
+
num_heads: int,
|
| 424 |
+
dropout: float = 0.0,
|
| 425 |
+
bias: bool = True,
|
| 426 |
+
):
|
| 427 |
+
super().__init__()
|
| 428 |
+
self.embed_dim = embed_dim
|
| 429 |
+
self.num_heads = num_heads
|
| 430 |
+
self.dropout = dropout
|
| 431 |
+
self.head_dim = embed_dim // num_heads
|
| 432 |
+
if self.head_dim * num_heads != self.embed_dim:
|
| 433 |
+
raise ValueError(
|
| 434 |
+
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
| 435 |
+
f" {num_heads})."
|
| 436 |
+
)
|
| 437 |
+
self.scaling = self.head_dim**-0.5
|
| 438 |
+
|
| 439 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
| 440 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
| 441 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
| 442 |
+
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
| 443 |
+
|
| 444 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
|
| 445 |
+
return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
| 446 |
+
|
| 447 |
+
def with_pos_embed(self, tensor: torch.Tensor, object_queries: Optional[Tensor]):
|
| 448 |
+
return tensor if object_queries is None else tensor + object_queries
|
| 449 |
+
|
| 450 |
+
def forward(
|
| 451 |
+
self,
|
| 452 |
+
hidden_states: torch.Tensor,
|
| 453 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 454 |
+
object_queries: Optional[torch.Tensor] = None,
|
| 455 |
+
key_value_states: Optional[torch.Tensor] = None,
|
| 456 |
+
spatial_position_embeddings: Optional[torch.Tensor] = None,
|
| 457 |
+
output_attentions: bool = False,
|
| 458 |
+
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
| 459 |
+
"""Input shape: Batch x Time x Channel"""
|
| 460 |
+
# if key_value_states are provided this layer is used as a cross-attention layer
|
| 461 |
+
# for the decoder
|
| 462 |
+
is_cross_attention = key_value_states is not None
|
| 463 |
+
batch_size, target_len, embed_dim = hidden_states.size()
|
| 464 |
+
|
| 465 |
+
# add position embeddings to the hidden states before projecting to queries and keys
|
| 466 |
+
if object_queries is not None:
|
| 467 |
+
hidden_states_original = hidden_states
|
| 468 |
+
hidden_states = self.with_pos_embed(hidden_states, object_queries)
|
| 469 |
+
|
| 470 |
+
# add key-value position embeddings to the key value states
|
| 471 |
+
if spatial_position_embeddings is not None:
|
| 472 |
+
key_value_states_original = key_value_states
|
| 473 |
+
key_value_states = self.with_pos_embed(key_value_states, spatial_position_embeddings)
|
| 474 |
+
|
| 475 |
+
# get query proj
|
| 476 |
+
query_states = self.q_proj(hidden_states) * self.scaling
|
| 477 |
+
# get key, value proj
|
| 478 |
+
if is_cross_attention:
|
| 479 |
+
# cross_attentions
|
| 480 |
+
key_states = self._shape(self.k_proj(key_value_states), -1, batch_size)
|
| 481 |
+
value_states = self._shape(self.v_proj(key_value_states_original), -1, batch_size)
|
| 482 |
+
else:
|
| 483 |
+
# self_attention
|
| 484 |
+
key_states = self._shape(self.k_proj(hidden_states), -1, batch_size)
|
| 485 |
+
value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size)
|
| 486 |
+
|
| 487 |
+
proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
|
| 488 |
+
query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape)
|
| 489 |
+
key_states = key_states.view(*proj_shape)
|
| 490 |
+
value_states = value_states.view(*proj_shape)
|
| 491 |
+
|
| 492 |
+
source_len = key_states.size(1)
|
| 493 |
+
|
| 494 |
+
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
| 495 |
+
|
| 496 |
+
if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
|
| 497 |
+
raise ValueError(
|
| 498 |
+
f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
|
| 499 |
+
f" {attn_weights.size()}"
|
| 500 |
+
)
|
| 501 |
+
|
| 502 |
+
if attention_mask is not None:
|
| 503 |
+
if attention_mask.size() != (batch_size, 1, target_len, source_len):
|
| 504 |
+
raise ValueError(
|
| 505 |
+
f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
|
| 506 |
+
f" {attention_mask.size()}"
|
| 507 |
+
)
|
| 508 |
+
if attention_mask.dtype == torch.bool:
|
| 509 |
+
attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
|
| 510 |
+
attention_mask, -torch.inf
|
| 511 |
+
)
|
| 512 |
+
attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
|
| 513 |
+
attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)
|
| 514 |
+
|
| 515 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
| 516 |
+
|
| 517 |
+
if output_attentions:
|
| 518 |
+
# this operation is a bit awkward, but it's required to
|
| 519 |
+
# make sure that attn_weights keeps its gradient.
|
| 520 |
+
# In order to do so, attn_weights have to reshaped
|
| 521 |
+
# twice and have to be reused in the following
|
| 522 |
+
attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
|
| 523 |
+
attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
|
| 524 |
+
else:
|
| 525 |
+
attn_weights_reshaped = None
|
| 526 |
+
|
| 527 |
+
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
| 528 |
+
|
| 529 |
+
attn_output = torch.bmm(attn_probs, value_states)
|
| 530 |
+
|
| 531 |
+
if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim):
|
| 532 |
+
raise ValueError(
|
| 533 |
+
f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is"
|
| 534 |
+
f" {attn_output.size()}"
|
| 535 |
+
)
|
| 536 |
+
|
| 537 |
+
attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim)
|
| 538 |
+
attn_output = attn_output.transpose(1, 2)
|
| 539 |
+
attn_output = attn_output.reshape(batch_size, target_len, embed_dim)
|
| 540 |
+
|
| 541 |
+
attn_output = self.out_proj(attn_output)
|
| 542 |
+
|
| 543 |
+
return attn_output, attn_weights_reshaped
|
| 544 |
+
|
| 545 |
+
|
class DetrEncoderLayer(nn.Module):
    def __init__(self, config: DetrConfig):
        super().__init__()
        self.embed_dim = config.d_model
        self.self_attn = DetrAttention(
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        object_queries: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ):
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                values.
            object_queries (`torch.FloatTensor`, *optional*):
                Object queries (also called content embeddings), to be added to the hidden states.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            object_queries=object_queries,
            output_attentions=output_attentions,
        )

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)

        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        if self.training:
            if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
                clamp_value = torch.finfo(hidden_states.dtype).max - 1000
                hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs

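For readers who want to poke at a single encoder block in isolation, the sketch below instantiates `DetrEncoderLayer` directly and runs a dummy forward pass. The config values and tensor shapes are arbitrary, and the import path assumes the installed transformers package shown in this listing.

```python
import torch
from transformers import DetrConfig
from transformers.models.detr.modeling_detr import DetrEncoderLayer

# Small, arbitrary config purely for illustration.
config = DetrConfig(d_model=256, encoder_attention_heads=8, encoder_ffn_dim=512)
layer = DetrEncoderLayer(config).eval()

batch, seq_len = 2, 49  # e.g. a 7x7 feature map flattened to 49 tokens
hidden_states = torch.randn(batch, seq_len, config.d_model)
object_queries = torch.randn(batch, seq_len, config.d_model)  # positional embeddings added inside attention

with torch.no_grad():
    (out,) = layer(hidden_states, attention_mask=None, object_queries=object_queries)

print(out.shape)  # torch.Size([2, 49, 256])
```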
class DetrDecoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: DetrConfig):
        super().__init__()
        self.embed_dim = config.d_model

        self.self_attn = DetrAttention(
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout

        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.encoder_attn = DetrAttention(
            self.embed_dim,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        object_queries: Optional[torch.Tensor] = None,
        query_position_embeddings: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ):
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                values.
            object_queries (`torch.FloatTensor`, *optional*):
                object queries that are added to the hidden states
                in the cross-attention layer.
            query_position_embeddings (`torch.FloatTensor`, *optional*):
                position embeddings that are added to the queries and keys
                in the self-attention layer.
            encoder_hidden_states (`torch.FloatTensor`):
                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            object_queries=query_position_embeddings,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Cross-Attention Block
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            residual = hidden_states

            hidden_states, cross_attn_weights = self.encoder_attn(
                hidden_states=hidden_states,
                object_queries=query_position_embeddings,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                spatial_position_embeddings=object_queries,
                output_attentions=output_attentions,
            )

            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
            hidden_states = residual + hidden_states
            hidden_states = self.encoder_attn_layer_norm(hidden_states)

        # Fully Connected
        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        return outputs

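A similar standalone check works for the decoder block: self-attention over the queries, then cross-attention into a stand-in encoder memory. The shapes and config values below are illustrative only, and the import path again assumes the package layout shown in this listing.

```python
import torch
from transformers import DetrConfig
from transformers.models.detr.modeling_detr import DetrDecoderLayer

config = DetrConfig(d_model=256, decoder_attention_heads=8, decoder_ffn_dim=512)
layer = DetrDecoderLayer(config).eval()

batch, num_queries, memory_len = 2, 100, 49
queries = torch.randn(batch, num_queries, config.d_model)
query_pos = torch.randn(batch, num_queries, config.d_model)  # added to q/k in self-attention
memory = torch.randn(batch, memory_len, config.d_model)      # fake encoder output
memory_pos = torch.randn(batch, memory_len, config.d_model)  # spatial position embeddings for the keys

with torch.no_grad():
    out, self_attn, cross_attn = layer(
        queries,
        object_queries=memory_pos,
        query_position_embeddings=query_pos,
        encoder_hidden_states=memory,
        output_attentions=True,
    )

print(out.shape, self_attn.shape, cross_attn.shape)
# expected: torch.Size([2, 100, 256]) torch.Size([2, 8, 100, 100]) torch.Size([2, 8, 100, 49])
```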
@auto_docstring
class DetrPreTrainedModel(PreTrainedModel):
    config: DetrConfig
    base_model_prefix = "model"
    main_input_name = "pixel_values"
    _no_split_modules = [r"DetrConvEncoder", r"DetrEncoderLayer", r"DetrDecoderLayer"]

    def _init_weights(self, module):
        std = self.config.init_std
        xavier_std = self.config.init_xavier_std

        if isinstance(module, DetrMHAttentionMap):
            nn.init.zeros_(module.k_linear.bias)
            nn.init.zeros_(module.q_linear.bias)
            nn.init.xavier_uniform_(module.k_linear.weight, gain=xavier_std)
            nn.init.xavier_uniform_(module.q_linear.weight, gain=xavier_std)
        elif isinstance(module, DetrLearnedPositionEmbedding):
            nn.init.uniform_(module.row_embeddings.weight)
            nn.init.uniform_(module.column_embeddings.weight)
        if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


class DetrEncoder(DetrPreTrainedModel):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`DetrEncoderLayer`].

    The encoder updates the flattened feature map through multiple self-attention layers.

    Small tweak for DETR:

    - object_queries are added to the forward pass.

    Args:
        config: DetrConfig
    """

    def __init__(self, config: DetrConfig):
        super().__init__(config)

        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop

        self.layers = nn.ModuleList([DetrEncoderLayer(config) for _ in range(config.encoder_layers)])

        # in the original DETR, no layernorm is used at the end of the encoder, as "normalize_before" is set to False by default

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        inputs_embeds=None,
        attention_mask=None,
        object_queries=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.

            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`:

                - 1 for pixel features that are real (i.e. **not masked**),
                - 0 for pixel features that are padding (i.e. **masked**).

                [What are attention masks?](../glossary#attention-mask)

            object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Object queries that are added to the queries in each self-attention layer.

            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        hidden_states = inputs_embeds
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        # expand attention_mask
        if attention_mask is not None:
            # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
            attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        for i, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
            to_drop = False
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:  # skip the layer
                    to_drop = True

            if to_drop:
                layer_outputs = (None, None)
            else:
                # we add object_queries as extra input to the encoder_layer
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    object_queries=object_queries,
                    output_attentions=output_attentions,
                )

                hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )

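The LayerDrop logic used in the encoder loop above (and again in the decoder below) is easy to isolate. The following is a hypothetical, minimal version of the pattern, not tied to the classes in this file.

```python
import torch
import torch.nn as nn


class TinyStackWithLayerDrop(nn.Module):
    """Hypothetical stack that skips whole layers with probability `layerdrop` during training."""

    def __init__(self, num_layers=6, dim=32, layerdrop=0.1):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))
        self.layerdrop = layerdrop

    def forward(self, x):
        for layer in self.layers:
            # Same pattern as DetrEncoder: draw once per layer, skip the layer if below the threshold.
            if self.training and torch.rand([]) < self.layerdrop:
                continue
            x = layer(x)
        return x


model = TinyStackWithLayerDrop()
model.train()
out = model(torch.randn(4, 32))  # some layers may be skipped
model.eval()
out = model(torch.randn(4, 32))  # all layers always run at inference
```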
class DetrDecoder(DetrPreTrainedModel):
    """
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DetrDecoderLayer`].

    The decoder updates the query embeddings through multiple self-attention and cross-attention layers.

    Some small tweaks for DETR:

    - object_queries and query_position_embeddings are added to the forward pass.
    - if self.config.auxiliary_loss is set to True, also returns a stack of activations from all decoding layers.

    Args:
        config: DetrConfig
    """

    def __init__(self, config: DetrConfig):
        super().__init__(config)
        self.dropout = config.dropout
        self.layerdrop = config.decoder_layerdrop

        self.layers = nn.ModuleList([DetrDecoderLayer(config) for _ in range(config.decoder_layers)])
        # in DETR, the decoder uses layernorm after the last decoder layer output
        self.layernorm = nn.LayerNorm(config.d_model)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        inputs_embeds=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        object_queries=None,
        query_position_embeddings=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                The query embeddings that are passed into the decoder.

            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on certain queries. Mask values selected in `[0, 1]`:

                - 1 for queries that are **not masked**,
                - 0 for queries that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
                of the decoder.
            encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
                Mask to avoid performing cross-attention on padding pixel_values of the encoder. Mask values selected
                in `[0, 1]`:

                - 1 for pixels that are real (i.e. **not masked**),
                - 0 for pixels that are padding (i.e. **masked**).

            object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Object queries that are added to the queries and keys in each cross-attention layer.
            query_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
                Position embeddings that are added to the values and keys in each self-attention layer.

            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is not None:
            hidden_states = inputs_embeds
            input_shape = inputs_embeds.size()[:-1]

        combined_attention_mask = None

        if attention_mask is not None and combined_attention_mask is not None:
            # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
            combined_attention_mask = combined_attention_mask + _prepare_4d_attention_mask(
                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            )

        # expand encoder attention mask
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
            encoder_attention_mask = _prepare_4d_attention_mask(
                encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            )

        # optional intermediate hidden states
        intermediate = () if self.config.auxiliary_loss else None

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

        for idx, decoder_layer in enumerate(self.layers):
            # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:
                    continue

            layer_outputs = decoder_layer(
                hidden_states,
                combined_attention_mask,
                object_queries,
                query_position_embeddings,
                encoder_hidden_states,  # as a positional argument for gradient checkpointing
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
            )

            hidden_states = layer_outputs[0]

            if self.config.auxiliary_loss:
                hidden_states = self.layernorm(hidden_states)
                intermediate += (hidden_states,)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        # finally, apply layernorm
        hidden_states = self.layernorm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # stack intermediate decoder activations
        if self.config.auxiliary_loss:
            intermediate = torch.stack(intermediate)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions, intermediate]
                if v is not None
            )
        return DetrDecoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
            intermediate_hidden_states=intermediate,
        )

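When `auxiliary_loss` is enabled, the decoder collects the layer-normed output of every decoder layer and stacks it, so downstream heads can be applied per layer. A shape-only sketch of that bookkeeping, with illustrative sizes and a hypothetical head:

```python
import torch

num_layers, batch, num_queries, d_model = 6, 2, 100, 256

# Stand-ins for the per-layer decoder outputs collected in `intermediate` above.
intermediate = tuple(torch.randn(batch, num_queries, d_model) for _ in range(num_layers))

stacked = torch.stack(intermediate)  # what DetrDecoder returns as intermediate_hidden_states
print(stacked.shape)  # torch.Size([6, 2, 100, 256])

# A hypothetical per-layer classification head can be applied over the leading layer dimension in one call.
head = torch.nn.Linear(d_model, 92)  # e.g. 91 COCO classes + "no object"
per_layer_logits = head(stacked)
print(per_layer_logits.shape)  # torch.Size([6, 2, 100, 92])
```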
@auto_docstring(
    custom_intro="""
    The bare DETR Model (consisting of a backbone and encoder-decoder Transformer) outputting raw hidden-states without
    any specific head on top.
    """
)
class DetrModel(DetrPreTrainedModel):
    def __init__(self, config: DetrConfig):
        super().__init__(config)

        # Create backbone + positional encoding
        backbone = DetrConvEncoder(config)
        object_queries = build_position_encoding(config)
        self.backbone = DetrConvModel(backbone, object_queries)

        # Create projection layer
        self.input_projection = nn.Conv2d(backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1)

        self.query_position_embeddings = nn.Embedding(config.num_queries, config.d_model)

        self.encoder = DetrEncoder(config)
        self.decoder = DetrDecoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_encoder(self):
        return self.encoder

    def freeze_backbone(self):
        for name, param in self.backbone.conv_encoder.model.named_parameters():
            param.requires_grad_(False)

    def unfreeze_backbone(self):
        for name, param in self.backbone.conv_encoder.model.named_parameters():
            param.requires_grad_(True)

    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        pixel_mask: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.FloatTensor] = None,
        encoder_outputs: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.FloatTensor], DetrModelOutput]:
        r"""
        decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
            Not used by default. Can be used to mask object queries.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
            can choose to directly pass a flattened representation of an image.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
            Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
            embedded representation.

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, DetrModel
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
        >>> model = DetrModel.from_pretrained("facebook/detr-resnet-50")

        >>> # prepare image for the model
        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> # forward pass
        >>> outputs = model(**inputs)

        >>> # the last hidden states are the final query embeddings of the Transformer decoder
        >>> # these are of shape (batch_size, num_queries, hidden_size)
        >>> last_hidden_states = outputs.last_hidden_state
        >>> list(last_hidden_states.shape)
        [1, 100, 256]
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, num_channels, height, width = pixel_values.shape
        device = pixel_values.device

        if pixel_mask is None:
            pixel_mask = torch.ones((batch_size, height, width), device=device)

        # First, send pixel_values + pixel_mask through the backbone to obtain the features
        # pixel_values should be of shape (batch_size, num_channels, height, width)
        # pixel_mask should be of shape (batch_size, height, width)
        features, object_queries_list = self.backbone(pixel_values, pixel_mask)

        # get final feature map and downsampled mask
        feature_map, mask = features[-1]

        if mask is None:
            raise ValueError("Backbone does not return downsampled pixel mask")

        # Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
        projected_feature_map = self.input_projection(feature_map)

        # Third, flatten the feature map + position embeddings of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
        # In other words, turn their shape into (batch_size, sequence_length, hidden_size)
        flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
        object_queries = object_queries_list[-1].flatten(2).permute(0, 2, 1)

        flattened_mask = mask.flatten(1)

        # Fourth, send flattened_features + flattened_mask + position embeddings through the encoder
        # flattened_features is a Tensor of shape (batch_size, height*width, hidden_size)
        # flattened_mask is a Tensor of shape (batch_size, height*width)
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                inputs_embeds=flattened_features,
                attention_mask=flattened_mask,
                object_queries=object_queries,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        # Fifth, send query embeddings + object_queries through the decoder (which is conditioned on the encoder output)
        query_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(batch_size, 1, 1)
        queries = torch.zeros_like(query_position_embeddings)

        # decoder outputs consist of (dec_features, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            inputs_embeds=queries,
            attention_mask=None,
            object_queries=object_queries,
            query_position_embeddings=query_position_embeddings,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=flattened_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return DetrModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
            intermediate_hidden_states=decoder_outputs.intermediate_hidden_states,
        )

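The feature-map flattening that feeds the encoder (steps "Second" and "Third" in the comments above) can be reproduced on its own. The spatial size below merely mimics a ResNet-50 stride-32 output for a 224x224 image; all numbers are assumptions for illustration.

```python
import torch
import torch.nn as nn

batch, backbone_channels, h, w = 1, 2048, 7, 7
d_model = 256

feature_map = torch.randn(batch, backbone_channels, h, w)   # final backbone feature map
pixel_mask = torch.ones(batch, h, w, dtype=torch.bool)      # downsampled mask (all positions valid here)

input_projection = nn.Conv2d(backbone_channels, d_model, kernel_size=1)

projected = input_projection(feature_map)                   # (1, 256, 7, 7)
flattened_features = projected.flatten(2).permute(0, 2, 1)  # (1, 49, 256) -> encoder tokens
flattened_mask = pixel_mask.flatten(1)                      # (1, 49)

print(flattened_features.shape, flattened_mask.shape)
```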
# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
class DetrMLPPredictionHead(nn.Module):
    """
    Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
    height and width of a bounding box w.r.t. an image.

    Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x

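The prediction head above is just stacked linear layers with ReLU in between. A quick standalone shape check, importing the class from the module in this listing (shapes chosen only for illustration):

```python
import torch
from transformers.models.detr.modeling_detr import DetrMLPPredictionHead

head = DetrMLPPredictionHead(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)

queries = torch.randn(2, 100, 256)  # (batch, num_queries, d_model)
boxes = head(queries).sigmoid()     # normalized (center_x, center_y, width, height) in [0, 1]
print(boxes.shape)                  # torch.Size([2, 100, 4])
```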
@auto_docstring(
    custom_intro="""
    DETR Model (consisting of a backbone and encoder-decoder Transformer) with object detection heads on top, for tasks
    such as COCO detection.
    """
)
class DetrForObjectDetection(DetrPreTrainedModel):
    def __init__(self, config: DetrConfig):
        super().__init__(config)

        # DETR encoder-decoder model
        self.model = DetrModel(config)

        # Object detection heads
        self.class_labels_classifier = nn.Linear(
            config.d_model, config.num_labels + 1
        )  # We add one for the "no object" class
        self.bbox_predictor = DetrMLPPredictionHead(
            input_dim=config.d_model, hidden_dim=config.d_model, output_dim=4, num_layers=3
        )

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        pixel_mask: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.FloatTensor] = None,
        encoder_outputs: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[list[dict]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.FloatTensor], DetrObjectDetectionOutput]:
        r"""
        decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
            Not used by default. Can be used to mask object queries.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
            can choose to directly pass a flattened representation of an image.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
            Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
            embedded representation.
        labels (`list[Dict]` of len `(batch_size,)`, *optional*):
            Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
            following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
            respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
            in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, DetrForObjectDetection
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
        >>> model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
        >>> target_sizes = torch.tensor([image.size[::-1]])
        >>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[
        ...     0
        ... ]

        >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        ...     box = [round(i, 2) for i in box.tolist()]
        ...     print(
        ...         f"Detected {model.config.id2label[label.item()]} with confidence "
        ...         f"{round(score.item(), 3)} at location {box}"
        ...     )
        Detected remote with confidence 0.998 at location [40.16, 70.81, 175.55, 117.98]
        Detected remote with confidence 0.996 at location [333.24, 72.55, 368.33, 187.66]
        Detected couch with confidence 0.995 at location [-0.02, 1.15, 639.73, 473.76]
        Detected cat with confidence 0.999 at location [13.24, 52.05, 314.02, 470.93]
        Detected cat with confidence 0.999 at location [345.4, 23.85, 640.37, 368.72]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # First, send images through the DETR base model to obtain encoder + decoder outputs
        outputs = self.model(
            pixel_values,
            pixel_mask=pixel_mask,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_outputs,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # class logits + predicted bounding boxes
        logits = self.class_labels_classifier(sequence_output)
        pred_boxes = self.bbox_predictor(sequence_output).sigmoid()

        loss, loss_dict, auxiliary_outputs = None, None, None
        if labels is not None:
            outputs_class, outputs_coord = None, None
            if self.config.auxiliary_loss:
                intermediate = outputs.intermediate_hidden_states if return_dict else outputs[4]
                outputs_class = self.class_labels_classifier(intermediate)
                outputs_coord = self.bbox_predictor(intermediate).sigmoid()
            loss, loss_dict, auxiliary_outputs = self.loss_function(
                logits, labels, self.device, pred_boxes, self.config, outputs_class, outputs_coord
            )

        if not return_dict:
            if auxiliary_outputs is not None:
                output = (logits, pred_boxes) + auxiliary_outputs + outputs
            else:
                output = (logits, pred_boxes) + outputs
            return ((loss, loss_dict) + output) if loss is not None else output

        return DetrObjectDetectionOutput(
            loss=loss,
            loss_dict=loss_dict,
            logits=logits,
            pred_boxes=pred_boxes,
            auxiliary_outputs=auxiliary_outputs,
            last_hidden_state=outputs.last_hidden_state,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )

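The post-processing used in the docstring example rescales the normalized `(center_x, center_y, width, height)` boxes returned by `pred_boxes` to corner format in pixel coordinates. A hand-rolled sketch of that conversion (not the library implementation) is worth having on hand:

```python
import torch


def center_to_corners(boxes_cxcywh: torch.Tensor) -> torch.Tensor:
    """Convert normalized (center_x, center_y, width, height) boxes to (xmin, ymin, xmax, ymax)."""
    cx, cy, w, h = boxes_cxcywh.unbind(-1)
    return torch.stack([cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h], dim=-1)


pred_boxes = torch.tensor([[0.5, 0.5, 0.2, 0.4]])  # one box, normalized cxcywh
corners = center_to_corners(pred_boxes)            # still normalized
image_height, image_width = 480, 640
scale = torch.tensor([image_width, image_height, image_width, image_height])
print(corners * scale)  # tensor([[256., 144., 384., 336.]])
```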
@auto_docstring(
    custom_intro="""
    DETR Model (consisting of a backbone and encoder-decoder Transformer) with a segmentation head on top, for tasks
    such as COCO panoptic.
    """
)
class DetrForSegmentation(DetrPreTrainedModel):
    def __init__(self, config: DetrConfig):
        super().__init__(config)

        # object detection model
        self.detr = DetrForObjectDetection(config)

        # segmentation head
        hidden_size, number_of_heads = config.d_model, config.encoder_attention_heads
        intermediate_channel_sizes = self.detr.model.backbone.conv_encoder.intermediate_channel_sizes

        self.mask_head = DetrMaskHeadSmallConv(
            hidden_size + number_of_heads, intermediate_channel_sizes[::-1][-3:], hidden_size
        )

        self.bbox_attention = DetrMHAttentionMap(
            hidden_size, hidden_size, number_of_heads, dropout=0.0, std=config.init_xavier_std
        )
        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        pixel_mask: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.FloatTensor] = None,
        encoder_outputs: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[list[dict]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.FloatTensor], DetrSegmentationOutput]:
        r"""
        decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
            Not used by default. Can be used to mask object queries.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
            can choose to directly pass a flattened representation of an image.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
            Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
            embedded representation.
        labels (`list[Dict]` of len `(batch_size,)`, *optional*):
            Labels for computing the bipartite matching loss, DICE/F-1 loss and Focal loss. List of dicts, each
            dictionary containing at least the following 3 keys: 'class_labels', 'boxes' and 'masks' (the class labels,
            bounding boxes and segmentation masks of an image in the batch respectively). The class labels themselves
            should be a `torch.LongTensor` of len `(number of bounding boxes in the image,)`, the boxes a
            `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)` and the masks a
            `torch.FloatTensor` of shape `(number of bounding boxes in the image, height, width)`.

        Examples:

        ```python
        >>> import io
        >>> import requests
        >>> from PIL import Image
        >>> import torch
        >>> import numpy

        >>> from transformers import AutoImageProcessor, DetrForSegmentation
        >>> from transformers.image_transforms import rgb_to_id

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50-panoptic")
        >>> model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50-panoptic")

        >>> # prepare image for the model
        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> # forward pass
        >>> outputs = model(**inputs)

        >>> # Use the `post_process_panoptic_segmentation` method of the `image_processor` to retrieve post-processed panoptic segmentation maps
        >>> # Segmentation results are returned as a list of dictionaries
        >>> result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[(300, 500)])

        >>> # A tensor of shape (height, width) where each value denotes a segment id, filled with -1 if no segment is found
        >>> panoptic_seg = result[0]["segmentation"]
        >>> # Get prediction score and segment_id to class_id mapping of each segment
        >>> panoptic_segments_info = result[0]["segments_info"]
        ```"""

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, num_channels, height, width = pixel_values.shape
        device = pixel_values.device

        if pixel_mask is None:
            pixel_mask = torch.ones((batch_size, height, width), device=device)

        # First, get list of feature maps and position embeddings
        features, object_queries_list = self.detr.model.backbone(pixel_values, pixel_mask=pixel_mask)

        # Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
        feature_map, mask = features[-1]
        batch_size, num_channels, height, width = feature_map.shape
        projected_feature_map = self.detr.model.input_projection(feature_map)

        # Third, flatten the feature map + position embeddings of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
        # In other words, turn their shape into (batch_size, sequence_length, hidden_size)
        flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
        object_queries = object_queries_list[-1].flatten(2).permute(0, 2, 1)

        flattened_mask = mask.flatten(1)

        # Fourth, send flattened_features + flattened_mask + position embeddings through the encoder
        # flattened_features is a Tensor of shape (batch_size, height*width, hidden_size)
        # flattened_mask is a Tensor of shape (batch_size, height*width)
        if encoder_outputs is None:
            encoder_outputs = self.detr.model.encoder(
                inputs_embeds=flattened_features,
                attention_mask=flattened_mask,
                object_queries=object_queries,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        # Fifth, send query embeddings + position embeddings through the decoder (which is conditioned on the encoder output)
        query_position_embeddings = self.detr.model.query_position_embeddings.weight.unsqueeze(0).repeat(
            batch_size, 1, 1
        )
        queries = torch.zeros_like(query_position_embeddings)

        # decoder outputs consist of (dec_features, dec_hidden, dec_attn)
        decoder_outputs = self.detr.model.decoder(
            inputs_embeds=queries,
            attention_mask=None,
            object_queries=object_queries,
            query_position_embeddings=query_position_embeddings,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=flattened_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        # Sixth, compute logits, pred_boxes and pred_masks
        logits = self.detr.class_labels_classifier(sequence_output)
        pred_boxes = self.detr.bbox_predictor(sequence_output).sigmoid()

        memory = encoder_outputs[0].permute(0, 2, 1).view(batch_size, self.config.d_model, height, width)
        mask = flattened_mask.view(batch_size, height, width)

        # FIXME h_boxes takes the last one computed, keep this in mind
        # important: we need to reverse the mask, since in the original implementation the mask works reversed
        # bbox_mask is of shape (batch_size, num_queries, number_of_attention_heads in bbox_attention, height/32, width/32)
        bbox_mask = self.bbox_attention(sequence_output, memory, mask=~mask)

        seg_masks = self.mask_head(projected_feature_map, bbox_mask, [features[2][0], features[1][0], features[0][0]])

        pred_masks = seg_masks.view(batch_size, self.detr.config.num_queries, seg_masks.shape[-2], seg_masks.shape[-1])

        loss, loss_dict, auxiliary_outputs = None, None, None
        if labels is not None:
            outputs_class, outputs_coord = None, None
            if self.config.auxiliary_loss:
                intermediate = decoder_outputs.intermediate_hidden_states if return_dict else decoder_outputs[-1]
                outputs_class = self.detr.class_labels_classifier(intermediate)
                outputs_coord = self.detr.bbox_predictor(intermediate).sigmoid()
            loss, loss_dict, auxiliary_outputs = self.loss_function(
                logits, labels, device, pred_boxes, pred_masks, self.config, outputs_class, outputs_coord
            )

        if not return_dict:
            if auxiliary_outputs is not None:
                output = (logits, pred_boxes, pred_masks) + auxiliary_outputs + decoder_outputs + encoder_outputs
            else:
                output = (logits, pred_boxes, pred_masks) + decoder_outputs + encoder_outputs
            return ((loss, loss_dict) + output) if loss is not None else output

        return DetrSegmentationOutput(
            loss=loss,
            loss_dict=loss_dict,
            logits=logits,
            pred_boxes=pred_boxes,
            pred_masks=pred_masks,
            auxiliary_outputs=auxiliary_outputs,
            last_hidden_state=decoder_outputs.last_hidden_state,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )


def _expand(tensor, length: int):
    return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1)


# taken from https://github.com/facebookresearch/detr/blob/master/models/segmentation.py
class DetrMaskHeadSmallConv(nn.Module):
    """
    Simple convolutional head, using group norm. Upsampling is done using an FPN approach.
    """

    def __init__(self, dim, fpn_dims, context_dim):
        super().__init__()

        if dim % 8 != 0:
            raise ValueError(
                "The hidden_size + number of attention heads must be divisible by 8 as the number of groups in"
                " GroupNorm is set to 8"
            )

        inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64]

        self.lay1 = nn.Conv2d(dim, dim, 3, padding=1)
        self.gn1 = nn.GroupNorm(8, dim)
        self.lay2 = nn.Conv2d(dim, inter_dims[1], 3, padding=1)
        self.gn2 = nn.GroupNorm(min(8, inter_dims[1]), inter_dims[1])
        self.lay3 = nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1)
        self.gn3 = nn.GroupNorm(min(8, inter_dims[2]), inter_dims[2])
        self.lay4 = nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1)
        self.gn4 = nn.GroupNorm(min(8, inter_dims[3]), inter_dims[3])
        self.lay5 = nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1)
        self.gn5 = nn.GroupNorm(min(8, inter_dims[4]), inter_dims[4])
        self.out_lay = nn.Conv2d(inter_dims[4], 1, 3, padding=1)

        self.dim = dim

        self.adapter1 = nn.Conv2d(fpn_dims[0], inter_dims[1], 1)
        self.adapter2 = nn.Conv2d(fpn_dims[1], inter_dims[2], 1)
        self.adapter3 = nn.Conv2d(fpn_dims[2], inter_dims[3], 1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight, a=1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x: Tensor, bbox_mask: Tensor, fpns: list[Tensor]):
        # here we concatenate x, the projected feature map, of shape (batch_size, d_model, height/32, width/32) with
        # the bbox_mask = the attention maps of shape (batch_size, n_queries, n_heads, height/32, width/32).
        # We expand the projected feature map to match the number of heads.
        x = torch.cat([_expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1)

        x = self.lay1(x)
        x = self.gn1(x)
        x = nn.functional.relu(x)
        x = self.lay2(x)
        x = self.gn2(x)
        x = nn.functional.relu(x)

        cur_fpn = self.adapter1(fpns[0])
        if cur_fpn.size(0) != x.size(0):
            cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
        x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        x = self.lay3(x)
        x = self.gn3(x)
        x = nn.functional.relu(x)

        cur_fpn = self.adapter2(fpns[1])
        if cur_fpn.size(0) != x.size(0):
            cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
        x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        x = self.lay4(x)
        x = self.gn4(x)
        x = nn.functional.relu(x)

        cur_fpn = self.adapter3(fpns[2])
        if cur_fpn.size(0) != x.size(0):
            cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
        x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        x = self.lay5(x)
        x = self.gn5(x)
        x = nn.functional.relu(x)

        x = self.out_lay(x)
        return x

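`_expand` is the small helper that replicates the projected feature map once per query so it can be concatenated with the per-query attention maps at the top of `DetrMaskHeadSmallConv.forward`. A shape-only sketch with illustrative sizes:

```python
import torch


def _expand(tensor, length: int):
    return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1)


batch, d_model, h, w = 2, 256, 7, 7
num_queries, num_heads = 100, 8

projected_feature_map = torch.randn(batch, d_model, h, w)
bbox_mask = torch.randn(batch, num_queries, num_heads, h, w)

# Same concatenation as in the mask head: one copy of the feature map per query, plus its attention maps.
x = torch.cat([_expand(projected_feature_map, num_queries), bbox_mask.flatten(0, 1)], 1)
print(x.shape)  # torch.Size([200, 264, 7, 7]) -> (batch * num_queries, d_model + num_heads, h, w)
```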
| 1660 |
+
class DetrMHAttentionMap(nn.Module):
|
| 1661 |
+
"""This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""
|
| 1662 |
+
|
| 1663 |
+
def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True, std=None):
|
| 1664 |
+
super().__init__()
|
| 1665 |
+
self.num_heads = num_heads
|
| 1666 |
+
self.hidden_dim = hidden_dim
|
| 1667 |
+
self.dropout = nn.Dropout(dropout)
|
| 1668 |
+
|
| 1669 |
+
self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
|
| 1670 |
+
self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
|
| 1671 |
+
|
| 1672 |
+
self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5
|
| 1673 |
+
|
| 1674 |
+
def forward(self, q, k, mask: Optional[Tensor] = None):
|
| 1675 |
+
q = self.q_linear(q)
|
| 1676 |
+
k = nn.functional.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias)
|
| 1677 |
+
queries_per_head = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads)
|
| 1678 |
+
keys_per_head = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1])
|
| 1679 |
+
weights = torch.einsum("bqnc,bnchw->bqnhw", queries_per_head * self.normalize_fact, keys_per_head)
|
| 1680 |
+
|
| 1681 |
+
if mask is not None:
|
| 1682 |
+
weights = weights.masked_fill(mask.unsqueeze(1).unsqueeze(1), torch.finfo(weights.dtype).min)
|
| 1683 |
+
weights = nn.functional.softmax(weights.flatten(2), dim=-1).view(weights.size())
|
| 1684 |
+
weights = self.dropout(weights)
|
| 1685 |
+
return weights
|
| 1686 |
+
|
| 1687 |
+
|
| 1688 |
+
__all__ = [
|
| 1689 |
+
"DetrForObjectDetection",
|
| 1690 |
+
"DetrForSegmentation",
|
| 1691 |
+
"DetrModel",
|
| 1692 |
+
"DetrPreTrainedModel",
|
| 1693 |
+
]
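The attention module above only returns softmaxed per-query, per-head attention maps over the feature map (no multiplication by values); the mask head then fuses these maps with FPN features. A minimal standalone sketch of the shape flow produced by the same einsum, using plain tensors and hypothetical sizes (this is an illustration, not part of the uploaded file):

```python
import torch

# Hypothetical sizes: DETR-style 100 queries, 8 heads, 256-d hidden, ~1/32 feature map
batch_size, num_queries, hidden_dim, num_heads = 2, 100, 256, 8
height, width = 25, 34

queries_per_head = torch.randn(batch_size, num_queries, num_heads, hidden_dim // num_heads)
keys_per_head = torch.randn(batch_size, num_heads, hidden_dim // num_heads, height, width)

# Same contraction as DetrMHAttentionMap.forward: one attention map per query and head
weights = torch.einsum("bqnc,bnchw->bqnhw", queries_per_head, keys_per_head)
print(weights.shape)  # torch.Size([2, 100, 8, 25, 34])
```
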
|
phivenv/Lib/site-packages/transformers/models/dia/__init__.py
ADDED
@@ -0,0 +1,31 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_dia import *
    from .feature_extraction_dia import *
    from .generation_dia import *
    from .modeling_dia import *
    from .processing_dia import *
    from .tokenization_dia import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)

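The `_LazyModule` indirection above defers importing the heavy submodules until an attribute is actually requested. A toy sketch of the same idea using PEP 562 module-level `__getattr__` (an illustration of the pattern only, not the transformers implementation):

```python
# __init__.py of a hypothetical package -- lazy submodule loading via PEP 562
import importlib

_LAZY_ATTRS = {"DiaConfig": ".configuration_dia"}  # attribute name -> submodule holding it

def __getattr__(name):
    # Only invoked when `name` is not found through the normal module lookup.
    if name in _LAZY_ATTRS:
        module = importlib.import_module(_LAZY_ATTRS[name], __package__)
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```
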
phivenv/Lib/site-packages/transformers/models/dia/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (633 Bytes)

phivenv/Lib/site-packages/transformers/models/dia/__pycache__/configuration_dia.cpython-39.pyc
ADDED
Binary file (18.3 kB)

phivenv/Lib/site-packages/transformers/models/dia/__pycache__/feature_extraction_dia.cpython-39.pyc
ADDED
Binary file (6.62 kB)

phivenv/Lib/site-packages/transformers/models/dia/__pycache__/generation_dia.cpython-39.pyc
ADDED
Binary file (9.73 kB)

phivenv/Lib/site-packages/transformers/models/dia/__pycache__/modeling_dia.cpython-39.pyc
ADDED
Binary file (28 kB)

phivenv/Lib/site-packages/transformers/models/dia/__pycache__/modular_dia.cpython-39.pyc
ADDED
Binary file (21.4 kB)

phivenv/Lib/site-packages/transformers/models/dia/__pycache__/processing_dia.cpython-39.pyc
ADDED
Binary file (12.7 kB)

phivenv/Lib/site-packages/transformers/models/dia/__pycache__/tokenization_dia.cpython-39.pyc
ADDED
Binary file (4.4 kB)

phivenv/Lib/site-packages/transformers/models/dia/configuration_dia.py
ADDED
@@ -0,0 +1,376 @@
# coding=utf-8
# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dia model configuration"""

from typing import Optional

from ...configuration_utils import PretrainedConfig
from ...modeling_rope_utils import rope_config_validation
from ...utils import logging


logger = logging.get_logger(__name__)


class DiaEncoderConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`DiaEncoder`]. It is used to instantiate a Dia
    encoder according to the specified arguments, defining the encoder architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        max_position_embeddings (`int`, *optional*, defaults to 1024):
            The maximum sequence length that this model might ever be used with.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        hidden_size (`int`, *optional*, defaults to 1024):
            Dimensionality of the encoder layers and the pooler layer.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 16):
            Number of key and value heads for each attention layer in the Transformer encoder.
        head_dim (`int`, *optional*, defaults to 128):
            Dimensionality of the attention head.
        intermediate_size (`int`, *optional*, defaults to 4096):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
        norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the normalization layers.
        vocab_size (`int`, *optional*, defaults to 256):
            Vocabulary size of the Dia model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`DiaModel`].
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"swish"` and `"gelu_new"` are supported.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
            accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
    """

    model_type = "dia_encoder"

    def __init__(
        self,
        max_position_embeddings: int = 1024,
        num_hidden_layers: int = 12,
        hidden_size: int = 1024,
        num_attention_heads: int = 16,
        num_key_value_heads: int = 16,
        head_dim: int = 128,
        intermediate_size: int = 4096,
        norm_eps: float = 1e-5,
        vocab_size: int = 256,
        hidden_act: str = "silu",
        rope_theta: float = 10000.0,
        rope_scaling: Optional[dict] = None,
        initializer_range: float = 0.02,
        **kwargs,
    ):
        self.max_position_embeddings = max_position_embeddings
        self.num_hidden_layers = num_hidden_layers
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        self.norm_eps = norm_eps
        self.vocab_size = vocab_size
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, copy it to 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)
        self.initializer_range = initializer_range
        super().__init__(**kwargs)


class DiaDecoderConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`DiaDecoder`]. It is used to instantiate a Dia
    decoder according to the specified arguments, defining the decoder architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        max_position_embeddings (`int`, *optional*, defaults to 3072):
            The maximum sequence length that this model might ever be used with.
        num_hidden_layers (`int`, *optional*, defaults to 18):
            Number of hidden layers in the Transformer decoder.
        hidden_size (`int`, *optional*, defaults to 2048):
            Dimensionality of the decoder layers and the pooler layer.
        intermediate_size (`int`, *optional*, defaults to 8192):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 4):
            Number of key and value heads for each attention layer in the Transformer decoder.
        head_dim (`int`, *optional*, defaults to 128):
            Dimensionality of the attention head.
        cross_num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each cross-attention layer in the Transformer decoder.
        cross_head_dim (`int`, *optional*, defaults to 128):
            Dimensionality of the cross-attention head.
        cross_num_key_value_heads (`int`, *optional*, defaults to 16):
            Number of key and value heads for each cross-attention layer in the Transformer decoder.
        cross_hidden_size (`int`, *optional*, defaults to 1024):
            Dimensionality of the cross-attention layers.
        norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the normalization layers.
        vocab_size (`int`, *optional*, defaults to 1028):
            Vocabulary size of the Dia model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`DiaModel`].
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder. If string, `"gelu"`, `"relu"`,
            `"swish"` and `"gelu_new"` are supported.
        num_channels (`int`, *optional*, defaults to 9):
            Number of channels for the Dia decoder.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
            accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        is_encoder_decoder (`bool`, *optional*, defaults to `True`):
            Indicating that this model is part of an encoder-decoder architecture.
    """

    model_type = "dia_decoder"

    def __init__(
        self,
        max_position_embeddings: int = 3072,
        num_hidden_layers: int = 18,
        hidden_size: int = 2048,
        intermediate_size: int = 8192,
        num_attention_heads: int = 16,
        num_key_value_heads: int = 4,
        head_dim: int = 128,
        cross_num_attention_heads: int = 16,
        cross_head_dim: int = 128,
        cross_num_key_value_heads: int = 16,
        cross_hidden_size: int = 1024,
        norm_eps: float = 1e-5,
        vocab_size: int = 1028,
        hidden_act: str = "silu",
        num_channels: int = 9,
        rope_theta: float = 10000.0,
        rope_scaling: Optional[dict] = None,
        initializer_range: float = 0.02,
        use_cache: bool = True,
        is_encoder_decoder: bool = True,
        **kwargs,
    ):
        self.max_position_embeddings = max_position_embeddings
        self.num_hidden_layers = num_hidden_layers
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        self.cross_num_key_value_heads = cross_num_key_value_heads
        self.cross_num_attention_heads = cross_num_attention_heads
        self.cross_head_dim = cross_head_dim
        self.cross_hidden_size = cross_hidden_size
        self.norm_eps = norm_eps
        self.vocab_size = vocab_size
        self.hidden_act = hidden_act
        self.num_channels = num_channels
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, copy it to 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)
        self.initializer_range = initializer_range
        self.use_cache = use_cache
        super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)


class DiaConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`DiaModel`]. It is used to instantiate a
    Dia model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the
    [nari-labs/Dia-1.6B](https://huggingface.co/nari-labs/Dia-1.6B) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        encoder_config (`DiaEncoderConfig`, *optional*):
            Configuration for the encoder part of the model. If not provided, a default `DiaEncoderConfig` will be used.
        decoder_config (`DiaDecoderConfig`, *optional*):
            Configuration for the decoder part of the model. If not provided, a default `DiaDecoderConfig` will be used.
        norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the normalization layers.
        is_encoder_decoder (`bool`, *optional*, defaults to `True`):
            Indicating that this model uses an encoder-decoder architecture.
        pad_token_id (`int`, *optional*, defaults to 1025):
            Padding token id.
        eos_token_id (`int`, *optional*, defaults to 1024):
            End of stream token id.
        bos_token_id (`int`, *optional*, defaults to 1026):
            Beginning of stream token id.
        delay_pattern (`list[int]`, *optional*, defaults to `[0, 8, 9, 10, 11, 12, 13, 14, 15]`):
            The delay pattern for the decoder. The length of this list must match `decoder_config.num_channels`.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).

    Example:

    ```python
    >>> from transformers import DiaConfig, DiaModel

    >>> # Initializing a DiaConfig with default values
    >>> configuration = DiaConfig()

    >>> # Initializing a DiaModel (with random weights) from the configuration
    >>> model = DiaModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "dia"
    keys_to_ignore_at_inference = ["past_key_values"]
    sub_configs = {"encoder_config": DiaEncoderConfig, "decoder_config": DiaDecoderConfig}

    def __init__(
        self,
        encoder_config: Optional[DiaEncoderConfig] = None,
        decoder_config: Optional[DiaDecoderConfig] = None,
        norm_eps: float = 1e-5,
        is_encoder_decoder: bool = True,
        pad_token_id: int = 1025,
        eos_token_id: int = 1024,
        bos_token_id: int = 1026,
        delay_pattern: Optional[list[int]] = None,
        initializer_range: float = 0.02,
        use_cache: bool = True,
        **kwargs,
    ):
        if isinstance(encoder_config, dict):
            encoder_config = DiaEncoderConfig(**encoder_config)
        if isinstance(decoder_config, dict):
            decoder_config = DiaDecoderConfig(**decoder_config)
        self.encoder_config = encoder_config if encoder_config is not None else DiaEncoderConfig()
        self.decoder_config = decoder_config if decoder_config is not None else DiaDecoderConfig()
        self.norm_eps = norm_eps
        self.delay_pattern = delay_pattern if delay_pattern is not None else [0, 8, 9, 10, 11, 12, 13, 14, 15]
        self.initializer_range = initializer_range
        self.use_cache = use_cache

        assert self.decoder_config.num_channels == len(self.delay_pattern), (
            "Number of channels must match delay pattern length."
        )

        super().__init__(
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            bos_token_id=bos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            **kwargs,
        )

    def get_text_config(self, *args, **kwargs):
        """Defaulting to audio config as it's the decoder in this case which is usually the text backbone"""
        return self.decoder_config


__all__ = ["DiaConfig", "DiaEncoderConfig", "DiaDecoderConfig"]

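Based on the configuration classes above, a short usage sketch: `DiaConfig` asserts that `decoder_config.num_channels` equals `len(delay_pattern)`, and `get_text_config()` returns the decoder config. The top-level `transformers` import path is assumed from the `__all__` exports shown above.

```python
from transformers import DiaConfig, DiaDecoderConfig, DiaEncoderConfig

# Sub-configs can be passed as objects (plain dicts are also accepted and converted above).
decoder = DiaDecoderConfig(num_channels=9)
config = DiaConfig(
    encoder_config=DiaEncoderConfig(),
    decoder_config=decoder,
    delay_pattern=[0, 8, 9, 10, 11, 12, 13, 14, 15],  # 9 entries, matching num_channels
)

print(config.get_text_config().vocab_size)  # decoder config is returned here (1028 by default)
```
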
phivenv/Lib/site-packages/transformers/models/dia/feature_extraction_dia.py
ADDED
@@ -0,0 +1,183 @@
# coding=utf-8
# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Feature extractor class for Dia"""

from typing import Optional, Union

import numpy as np

from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature
from ...utils import PaddingStrategy, TensorType, logging


logger = logging.get_logger(__name__)


class DiaFeatureExtractor(SequenceFeatureExtractor):
    r"""
    Constructs a Dia feature extractor.

    This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
    most of the main methods. Users should refer to this superclass for more information regarding those methods.

    Args:
        feature_size (`int`, *optional*, defaults to 1):
            The feature dimension of the extracted features. Use 1 for mono, 2 for stereo.
        sampling_rate (`int`, *optional*, defaults to 16000):
            The sampling rate at which the audio waveform should be digitalized, expressed in hertz (Hz).
        padding_value (`float`, *optional*, defaults to 0.0):
            The value that is used for padding.
        hop_length (`int`, *optional*, defaults to 512):
            Overlap length between successive windows.
    """

    model_input_names = ["input_values", "n_quantizers"]

    def __init__(
        self,
        feature_size: int = 1,
        sampling_rate: int = 16000,
        padding_value: float = 0.0,
        hop_length: int = 512,
        **kwargs,
    ):
        super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
        self.hop_length = hop_length

    def __call__(
        self,
        raw_audio: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]],
        padding: Optional[Union[bool, str, PaddingStrategy]] = None,
        truncation: Optional[bool] = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        sampling_rate: Optional[int] = None,
    ) -> BatchFeature:
        """
        Main method to featurize and prepare for the model one or several sequence(s).

        Args:
            raw_audio (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`):
                The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float
                values, a list of numpy arrays or a list of list of float values. The numpy array must be of shape
                `(num_samples,)` for mono audio (`feature_size = 1`), or `(2, num_samples)` for stereo audio
                (`feature_size = 2`).
            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
                Select a strategy to pad the returned sequences (according to the model's padding side and padding
                index) among:

                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
                  sequence is provided).
                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
                  acceptable input length for the model if that argument is not provided.
                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
                  lengths).
            truncation (`bool`, *optional*, defaults to `False`):
                Activates truncation to cut input sequences longer than `max_length` to `max_length`.
            max_length (`int`, *optional*):
                Maximum length of the returned list and optionally padding length (see above).
            return_tensors (`str` or [`~utils.TensorType`], *optional*, default to 'pt'):
                If set, will return tensors instead of list of python integers. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return Numpy `np.ndarray` objects.
            sampling_rate (`int`, *optional*):
                The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass
                `sampling_rate` at the forward call to prevent silent errors.
        """
        if sampling_rate is not None:
            if sampling_rate != self.sampling_rate:
                raise ValueError(
                    f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
                    f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with"
                    f" {self.sampling_rate} and not {sampling_rate}."
                )
        else:
            logger.warning(
                f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
                "Failing to do so can result in silent errors that might be hard to debug."
            )

        if padding and truncation:
            raise ValueError("Both padding and truncation were set. Make sure you only set one.")
        elif padding is None:
            # by default let's pad the inputs
            padding = True

        is_batched = bool(
            isinstance(raw_audio, (list, tuple)) and (isinstance(raw_audio[0], (np.ndarray, tuple, list)))
        )

        if is_batched:
            raw_audio = [np.asarray(audio, dtype=np.float32).T for audio in raw_audio]
        elif not is_batched and not isinstance(raw_audio, np.ndarray):
            raw_audio = np.asarray(raw_audio, dtype=np.float32)
        elif isinstance(raw_audio, np.ndarray) and raw_audio.dtype is np.dtype(np.float64):
            raw_audio = raw_audio.astype(np.float32)

        # always return batch
        if not is_batched:
            raw_audio = [np.asarray(raw_audio).T]

        # convert stereo to mono if necessary, unique to Dia
        for idx, example in enumerate(raw_audio):
            if self.feature_size == 2 and example.ndim == 2:
                raw_audio[idx] = np.mean(example, -1)

        # verify inputs are valid
        for idx, example in enumerate(raw_audio):
            if example.ndim > 2:
                raise ValueError(f"Expected input shape (channels, length) but got shape {example.shape}")
            if self.feature_size == 1 and example.ndim != 1:
                raise ValueError(f"Expected mono audio but example has {example.shape[-1]} channels")
            if self.feature_size == 2 and example.ndim != 1:  # note the conversion before
                raise ValueError(f"Expected stereo audio but example has {example.shape[-1]} channels")

        input_values = BatchFeature({"input_values": raw_audio})

        # temporarily treat it as if we were mono as we also convert stereo to mono
        original_feature_size = self.feature_size
        self.feature_size = 1

        # normal padding on batch
        padded_inputs = self.pad(
            input_values,
            max_length=max_length,
            truncation=truncation,
            padding=padding,
            return_attention_mask=True,
            pad_to_multiple_of=self.hop_length,
        )
        padded_inputs["padding_mask"] = padded_inputs.pop("attention_mask")

        input_values = []
        for example in padded_inputs.pop("input_values"):
            if self.feature_size == 1:
                example = example[..., None]
            input_values.append(example.T)

        padded_inputs["input_values"] = input_values
        if return_tensors is not None:
            padded_inputs = padded_inputs.convert_to_tensors(return_tensors)

        # rewrite back to the original feature size
        self.feature_size = original_feature_size

        return padded_inputs


__all__ = ["DiaFeatureExtractor"]

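A usage sketch for the feature extractor above, with synthetic mono audio; the top-level import path is assumed. Inputs are padded up to a multiple of `hop_length` and the attention mask is returned under the `padding_mask` key:

```python
import numpy as np
from transformers import DiaFeatureExtractor

extractor = DiaFeatureExtractor()  # defaults shown above: feature_size=1, sampling_rate=16000, hop_length=512
audio = np.random.randn(30_000).astype(np.float32)  # mono waveform, shape (num_samples,)

features = extractor(audio, sampling_rate=16000, return_tensors="np")
print(features["input_values"].shape)  # (1, 1, padded_len), padded_len a multiple of hop_length
print(features["padding_mask"].shape)  # (1, padded_len)
```
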
phivenv/Lib/site-packages/transformers/models/dia/generation_dia.py
ADDED
@@ -0,0 +1,464 @@
# coding=utf-8
# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Optional, Union

import torch
import torch.distributed as dist

from ...generation.logits_process import (
    DiaClassifierFreeGuidanceLogitsProcessor,
    DiaEOSChannelFilterLogitsProcessor,
    DiaEOSDelayPatternLogitsProcessor,
    LogitsProcessorList,
    TemperatureLogitsWarper,
)
from ...generation.stopping_criteria import StoppingCriteriaList
from ...generation.streamers import BaseStreamer
from ...generation.utils import GenerateOutput, GenerationConfig, GenerationMixin, GenerationMode
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_utils import PreTrainedModel
from ...utils import logging


logger = logging.get_logger(__name__)


class DiaGenerationMixin(GenerationMixin):
    # Indicates CFG which needs preparation to be properly handled by repeats
    _uses_cfg = None

    def _get_logits_processor(
        self,
        generation_config: GenerationConfig,
        input_ids_seq_length: Optional[int] = None,
        encoder_input_ids: torch.LongTensor = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        device: Optional[str] = None,
        model_kwargs: Optional[dict[str, Any]] = None,
        negative_prompt_ids: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    ) -> LogitsProcessorList:
        # Need either custom order or custom processor instead
        # (Temporarily disabling those for the super function)
        original_guidance_scale = generation_config.guidance_scale
        original_temperature = generation_config.temperature
        generation_config.guidance_scale = None
        generation_config.temperature = None

        # Get base processors and those we can integrate easily
        custom_processors = LogitsProcessorList()

        if original_temperature is not None and original_temperature != 1.0:
            custom_processors.append(TemperatureLogitsWarper(original_temperature))

        custom_processors.append(
            DiaEOSChannelFilterLogitsProcessor(
                num_channels=len(self.config.delay_pattern),
                eos_token_id=self.config.eos_token_id,
            )
        )

        merged_processors = super()._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=encoder_input_ids,
            prefix_allowed_tokens_fn=None,
            logits_processor=custom_processors,
            device=device,
            model_kwargs=model_kwargs,
            negative_prompt_ids=negative_prompt_ids,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
        )

        # Custom processors we need at specific positions
        if original_guidance_scale is not None and original_guidance_scale != 1:
            cfg_processor = DiaClassifierFreeGuidanceLogitsProcessor(
                guidance_scale=original_guidance_scale,
                guidance_top_k=generation_config.top_k,
            )
            merged_processors.insert(0, cfg_processor)

        merged_processors.append(
            DiaEOSDelayPatternLogitsProcessor(
                delay_pattern=self.config.delay_pattern,
                eos_token_id=self.config.eos_token_id,
                max_generation_len=generation_config.max_length,
                device=device,
            )
        )

        # Re-enable the temporarily disabled values
        generation_config.guidance_scale = original_guidance_scale
        generation_config.temperature = original_temperature

        return merged_processors

    def _prepare_generation_config(
        self, generation_config: Optional[GenerationConfig], use_model_defaults: Optional[bool] = None, **kwargs: dict
    ) -> tuple[GenerationConfig, dict]:
        generation_config, model_kwargs = super()._prepare_generation_config(
            generation_config, use_model_defaults, **kwargs
        )

        # We allow generation up to max length + max delay pattern
        # (will revert back to max length after generation)
        generation_config.max_length += max(self.config.delay_pattern)

        # Internal flag to indicate CFG that needs to prepare unconditioned input
        self._uses_cfg = generation_config.guidance_scale is not None and generation_config.guidance_scale != 1

        return generation_config, model_kwargs

    def _prepare_model_inputs(
        self,
        inputs: Optional[torch.Tensor] = None,
        bos_token_id: Optional[torch.Tensor] = None,
        model_kwargs: Optional[dict[str, torch.Tensor]] = None,
    ) -> tuple[torch.Tensor, Optional[str], dict[str, torch.Tensor]]:
        inputs, input_name, model_kwargs = super()._prepare_model_inputs(
            inputs=inputs,
            bos_token_id=bos_token_id,
            model_kwargs=model_kwargs,
        )

        # If CFG is requested we fill in the unconditioned parts
        if self._uses_cfg:
            unconditioned_inputs = torch.zeros_like(inputs)
            inputs = torch.cat([inputs, unconditioned_inputs], dim=0)

            if model_kwargs.get("attention_mask", None) is not None:
                model_kwargs["attention_mask"] = model_kwargs["attention_mask"].repeat(2, 1)

        return inputs, input_name, model_kwargs

    def _prepare_decoder_input_ids_for_generation(
        self,
        batch_size: int,
        model_input_name: str,
        model_kwargs: dict[str, torch.Tensor],
        decoder_start_token_id: torch.Tensor,
        device: Optional[torch.device] = None,
    ) -> tuple[torch.LongTensor, dict[str, torch.Tensor]]:
        """Prepares `decoder_input_ids` for generation with encoder-decoder models"""
        # 1. Check whether the user has defined `decoder_input_ids` and `decoder_attention_mask`; if not error out
        decoder_input_ids = decoder_attention_mask = None
        if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
            decoder_input_ids = model_kwargs.pop("decoder_input_ids")
        if model_kwargs is not None and "decoder_attention_mask" in model_kwargs:
            decoder_attention_mask = model_kwargs.pop("decoder_attention_mask")

        # We allow generating without preparation (no proper delay) but discourage it
        if decoder_input_ids is None or decoder_attention_mask is None:
            logger.warning_once(
                "In order to generate with Dia, we need the processed audio input: Got `decoder_input_ids`:"
                f" {decoder_input_ids is not None} and got `decoder_attention_mask`={decoder_attention_mask is not None}."
                f" This can be achieved via the [`DiaProcessor`] but now defaulting to non-delayed generation."
            )

            num_channels = self.config.decoder_config.num_channels
            real_batch_size = batch_size // 2 if self._uses_cfg else batch_size

            if decoder_input_ids is None:
                decoder_input_ids = torch.full(
                    (real_batch_size, 1, num_channels), decoder_start_token_id, dtype=torch.long, device=device
                )

            decoder_attention_mask = torch.ones(
                size=(real_batch_size, decoder_input_ids.shape[1]), dtype=torch.long, device=device
            )

        # 2. Determine the valid input and what works as mask within the input
        delay_mask = decoder_input_ids.long()
        valid_input_size = (
            decoder_input_ids.shape[1] - (decoder_input_ids[:, :, 0] == self.config.pad_token_id).sum(dim=-1).max()
        )
        decoder_input_ids = delay_mask[:, :valid_input_size].transpose(1, 2).long()
        decoder_attention_mask = decoder_attention_mask[:, :valid_input_size].long()

        # 3. Overwrite into model kwargs
        model_kwargs["decoder_attention_mask"] = decoder_attention_mask
        model_kwargs["decoder_delay_mask"] = delay_mask

        return decoder_input_ids, model_kwargs

    def prepare_inputs_for_generation(
        self,
        input_ids,
        encoder_outputs=None,  # Using this to easily get the batch size
        decoder_delay_mask=None,
        **kwargs,
    ):
        # Reshape decoder input_ids to 3D to be compile friendly and to fit the expected model input shape
        batch_size = encoder_outputs[0].shape[0] // 2 if self._uses_cfg else encoder_outputs[0].shape[0]
        input_ids = input_ids.reshape(batch_size, self.config.decoder_config.num_channels, -1).transpose(1, 2)

        # Base method handles most things except CFG and the delay pattern mask
        model_inputs = super().prepare_inputs_for_generation(input_ids, encoder_outputs=encoder_outputs, **kwargs)

        # Post processing for CFG and overwriting via delay pattern mask
        # 1. Delay pattern mask -- force tokens if not allowed to predict (!= pad_token in mask)
        model_inputs["decoder_input_ids"] = self.apply_delay_mask(
            input_ids, self.config.pad_token_id, decoder_delay_mask
        )

        # Depending on cache usage we need to pass all or just one
        if model_inputs.get("use_cache", False) and model_inputs["cache_position"][0] > 0:
            model_inputs["decoder_input_ids"] = model_inputs["decoder_input_ids"][:, -1, :][:, None, :]

        # Be compile friendly
        model_inputs["decoder_input_ids"] = model_inputs["decoder_input_ids"].contiguous()

        # 2. Apply CFG duplication if needed
        if self._uses_cfg:
            for key in ["decoder_input_ids", "decoder_attention_mask", "decoder_position_ids"]:
                if model_inputs.get(key, None) is not None:
                    # double first dimension and keep everything else the same
                    repeat_pattern = tuple([2] + [1] * (model_inputs[key].ndim - 1))
                    model_inputs[key] = model_inputs[key].repeat(*repeat_pattern)

        return model_inputs

    @staticmethod
    def apply_delay_mask(input_ids: torch.Tensor, pad_id: int, delay_mask: Optional[torch.Tensor]) -> torch.Tensor:
        if delay_mask is None:
            return input_ids

        mask_len = min(input_ids.shape[1], delay_mask.shape[1])
        valid_mask = delay_mask[:, :mask_len, :]
        valid_input = input_ids[:, :mask_len, :]

        # Overwrite the respective parts of the input
        input_ids[:, :mask_len, :] = torch.where(valid_mask == pad_id, valid_input, valid_mask)

        return input_ids

    def _main_generate_loop(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
        negative_prompt_ids: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        use_model_defaults: Optional[bool] = None,
        custom_generate: Optional[str] = None,
        **kwargs,
    ):
        # ********** mostly taken from main generate function up to calling the different methods (see NOTE) **********
        # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
        tokenizer = kwargs.pop("tokenizer", None)  # Pull this out first, we only use it for stopping criteria
        assistant_tokenizer = kwargs.pop("assistant_tokenizer", None)  # only used for assisted generation

        generation_config, model_kwargs = self._prepare_generation_config(
            generation_config, use_model_defaults, **kwargs
        )
        self._validate_model_kwargs(model_kwargs.copy())
        self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer)

        # 2. Set generation parameters if not already defined
        if synced_gpus is None:
            synced_gpus = (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1

        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

        # 3. Define model inputs
        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
        inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
            inputs, generation_config.bos_token_id, model_kwargs
        )
        batch_size = inputs_tensor.shape[0]

        device = inputs_tensor.device
        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)

        # 4. Define other model kwargs
        if "encoder_outputs" not in model_kwargs:
            # if model is encoder decoder encoder_outputs are created and added to `model_kwargs`
            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
                inputs_tensor, model_kwargs, model_input_name, generation_config
            )

        # 5. Prepare `input_ids` which will be used for auto-regressive generation
        input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
            batch_size=batch_size,
            model_input_name=model_input_name,
            model_kwargs=model_kwargs,
            decoder_start_token_id=generation_config._decoder_start_token_tensor,
            device=inputs_tensor.device,
        )

        if generation_config.token_healing:
            input_ids = self.heal_tokens(input_ids, tokenizer)

        if streamer is not None:
            streamer.put(input_ids.cpu())

        # 6. Prepare `max_length` depending on other stopping criteria.
        # NOTE: incorrect `input_ids.shape[1]` previously
        input_ids_length = input_ids.shape[-1]
        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
        generation_config = self._prepare_generated_length(
            generation_config=generation_config,
            has_default_max_length=has_default_max_length,
            has_default_min_length=has_default_min_length,
            model_input_name=model_input_name,
            inputs_tensor=inputs_tensor,
            input_ids_length=input_ids_length,
        )

        # If the model supports `logits_to_keep` in forward(), set it to 1 to avoid computing the whole
        # logit matrix. This can save a lot of memory during the first forward pass. Note that assisted decoding
        # dynamically overrides this value as it can need more than the last token logits
        if self._supports_logits_to_keep() and "logits_to_keep" not in model_kwargs:
            model_kwargs["logits_to_keep"] = 1

        self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

        # 7. Prepare the cache.
        # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
        # - different models have a different cache name expected by the model (default = "past_key_values")
        # - `max_length`, prepared above, is used to determine the maximum cache length
        max_cache_length = generation_config.max_length - 1
        if (
            inputs_tensor.shape[1] != input_ids_length
            and model_input_name == "inputs_embeds"
            and not self.config.is_encoder_decoder
        ):
            max_cache_length += inputs_tensor.shape[1]
        self._prepare_cache_for_generation(
            generation_config, model_kwargs, assistant_model, batch_size, max_cache_length
        )

        # 8. determine generation mode
        generation_mode = generation_config.get_generation_mode(assistant_model)

        if streamer is not None and (generation_config.num_beams > 1):
            raise ValueError(
                "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
            )

        # 9. prepare logits processors and stopping criteria
        prepared_logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_length,
            encoder_input_ids=inputs_tensor,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
            device=inputs_tensor.device,
            model_kwargs=model_kwargs,
            negative_prompt_ids=negative_prompt_ids,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
        )
        prepared_stopping_criteria = self._get_stopping_criteria(
            generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
        )

        # Set model_kwargs `use_cache` so we can use it later in forward runs
        model_kwargs["use_cache"] = generation_config.use_cache
        # ******************* taken from main generate function up to calling the different methods *******************

        # Prepare inner 2D logic in generation loop
        input_ids = input_ids.reshape(-1, input_ids.shape[-1])

        # 10. go into different generation modes
        if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
            # 11. expand input_ids with `num_return_sequences` additional sequences per batch
| 387 |
+
if generation_config.num_return_sequences > 1:
|
| 388 |
+
raise ValueError("`num_return_sequences>1` is incompatible with Dia.")
|
| 389 |
+
|
| 390 |
+
# 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
|
| 391 |
+
return self._sample(
|
| 392 |
+
input_ids,
|
| 393 |
+
logits_processor=prepared_logits_processor,
|
| 394 |
+
stopping_criteria=prepared_stopping_criteria,
|
| 395 |
+
generation_config=generation_config,
|
| 396 |
+
synced_gpus=synced_gpus,
|
| 397 |
+
streamer=streamer,
|
| 398 |
+
**model_kwargs,
|
| 399 |
+
)
|
| 400 |
+
else:
|
| 401 |
+
raise ValueError(
|
| 402 |
+
"Got incompatible mode for generation, should be one of greedy or sampling. "
|
| 403 |
+
"Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`."
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
@torch.no_grad()
|
| 407 |
+
def generate(
|
| 408 |
+
self,
|
| 409 |
+
inputs: Optional[torch.Tensor] = None,
|
| 410 |
+
generation_config: Optional[GenerationConfig] = None,
|
| 411 |
+
logits_processor: Optional[LogitsProcessorList] = None,
|
| 412 |
+
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
| 413 |
+
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
|
| 414 |
+
synced_gpus: Optional[bool] = None,
|
| 415 |
+
assistant_model: Optional["PreTrainedModel"] = None,
|
| 416 |
+
streamer: Optional["BaseStreamer"] = None,
|
| 417 |
+
negative_prompt_ids: Optional[torch.Tensor] = None,
|
| 418 |
+
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
| 419 |
+
use_model_defaults: Optional[bool] = None,
|
| 420 |
+
custom_generate: Optional[str] = None,
|
| 421 |
+
**kwargs,
|
| 422 |
+
) -> Union[GenerateOutput, torch.LongTensor]:
|
| 423 |
+
# We expect the initial input ids to be the complete mask (delayed input)
|
| 424 |
+
delay_mask = kwargs.get("decoder_input_ids")
|
| 425 |
+
if delay_mask is not None:
|
| 426 |
+
delay_mask = delay_mask.clone()
|
| 427 |
+
|
| 428 |
+
output = self._main_generate_loop(
|
| 429 |
+
inputs=inputs,
|
| 430 |
+
generation_config=generation_config,
|
| 431 |
+
logits_processor=logits_processor,
|
| 432 |
+
stopping_criteria=stopping_criteria,
|
| 433 |
+
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
| 434 |
+
synced_gpus=synced_gpus,
|
| 435 |
+
assistant_model=assistant_model,
|
| 436 |
+
streamer=streamer,
|
| 437 |
+
negative_prompt_ids=negative_prompt_ids,
|
| 438 |
+
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
| 439 |
+
use_model_defaults=use_model_defaults,
|
| 440 |
+
custom_generate=custom_generate,
|
| 441 |
+
**kwargs,
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
return_dict_in_generate = not isinstance(output, torch.Tensor)
|
| 445 |
+
|
| 446 |
+
if return_dict_in_generate:
|
| 447 |
+
output_sequences = output.sequences
|
| 448 |
+
else:
|
| 449 |
+
output_sequences = output
|
| 450 |
+
|
| 451 |
+
# Reshape from 2D (bsz * channels, seq_len) to 3D (bsz, seq_len, channels)
|
| 452 |
+
num_channels = self.config.decoder_config.num_channels
|
| 453 |
+
bsz = output_sequences.shape[0] // num_channels
|
| 454 |
+
output_sequences = output_sequences.reshape(bsz, num_channels, -1).transpose(1, 2)
|
| 455 |
+
|
| 456 |
+
# Apply delay mask
|
| 457 |
+
output_sequences = self.apply_delay_mask(output_sequences, self.config.pad_token_id, delay_mask)
|
| 458 |
+
|
| 459 |
+
if return_dict_in_generate:
|
| 460 |
+
output.sequences = output_sequences
|
| 461 |
+
else:
|
| 462 |
+
output = output_sequences
|
| 463 |
+
|
| 464 |
+
return output
|
phivenv/Lib/site-packages/transformers/models/dia/modeling_dia.py
ADDED
@@ -0,0 +1,958 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/dia/modular_dia.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_dia.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional, Union

import torch
from torch import nn

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...integrations import use_kernel_forward_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
    TransformersKwargs,
    auto_docstring,
    can_return_tuple,
    is_torch_flex_attn_available,
    is_torchdynamo_compiling,
    logging,
)
from ...utils.deprecation import deprecate_kwarg
from .configuration_dia import DiaConfig, DiaDecoderConfig, DiaEncoderConfig
from .generation_dia import DiaGenerationMixin


if is_torch_flex_attn_available():
    from ...integrations.flex_attention import make_flex_block_causal_mask


logger = logging.get_logger(__name__)


@auto_docstring
class DiaPreTrainedModel(PreTrainedModel):
    config: DiaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = True
    main_input_name = "input_ids"
    _no_split_modules = ["DiaEncoderLayer", "DiaDecoderLayer"]


class DiaMultiChannelEmbedding(nn.Module):
    """In order to efficiently compute the audio embedding from the 9 different channels,
    we vectorize the embedding process by using a single embedding layer and an offset.
    Example:
    - num_embeds = 4
    - vocab_size = 8
    - num_channels = 3
    We would have offsets = [0, 8, 16]
    If audio_codes = [0, 1, 2, 3], [1, 3, 4, 7], [5, 6, 7, 8],
    then tokens = audio_codes + offsets
                = [0, 1, 2, 3, 9, 11, 12, 15, 21, 22, 23, 24]
    This allows us to use a single embedding layer for all channels.
    """

    def __init__(self, config: DiaDecoderConfig):
        super().__init__()
        self.embed = nn.Embedding(config.vocab_size * config.num_channels, config.hidden_size)
        self.hidden_size = config.hidden_size
        self.num_channels = config.num_channels
        offsets = torch.arange(config.num_channels, dtype=torch.long) * config.vocab_size  # (C,)
        self.register_buffer("offsets", offsets, persistent=False)

    def forward(self, audio_codes: torch.Tensor) -> torch.Tensor:
        tokens = (audio_codes + self.offsets.to(audio_codes.device)).squeeze(1)
        embeds = self.embed(tokens).view(tokens.shape[0], audio_codes.shape[1], -1, self.hidden_size)
        return embeds.sum(dim=2)


class DiaMLP(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.config = config
        self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.activation_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        up_states = self.gate_up_proj(hidden_states)

        gate, up_states = up_states.chunk(2, dim=-1)
        up_states = up_states * self.activation_fn(gate)

        return self.down_proj(up_states)


@use_kernel_forward_from_hub("RMSNorm")
class DiaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        DiaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


class DiaRotaryEmbedding(nn.Module):
    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: DiaConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights

| 251 |
+
class DiaSelfAttention(nn.Module):
|
| 252 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 253 |
+
|
| 254 |
+
def __init__(self, config: Union[DiaEncoderConfig, DiaDecoderConfig], layer_idx: int, is_causal: bool = False):
|
| 255 |
+
super().__init__()
|
| 256 |
+
self.config = config
|
| 257 |
+
self.layer_idx = layer_idx
|
| 258 |
+
self.hidden_size = config.hidden_size
|
| 259 |
+
self.num_heads = self.config.num_attention_heads
|
| 260 |
+
self.num_key_value_heads = self.config.num_key_value_heads or self.num_heads
|
| 261 |
+
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
| 262 |
+
self.head_dim = getattr(config, "head_dim", config.hidden_size // self.num_heads)
|
| 263 |
+
self.scaling = 1
|
| 264 |
+
self.attention_dropout = 0.0
|
| 265 |
+
self.is_causal = is_causal
|
| 266 |
+
|
| 267 |
+
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
| 268 |
+
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
|
| 269 |
+
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
|
| 270 |
+
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
| 271 |
+
|
| 272 |
+
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
|
| 273 |
+
def forward(
|
| 274 |
+
self,
|
| 275 |
+
hidden_states: torch.Tensor,
|
| 276 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 277 |
+
attention_mask: Optional[torch.Tensor],
|
| 278 |
+
past_key_values: Optional[Cache] = None,
|
| 279 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 280 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 281 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 282 |
+
input_shape = hidden_states.shape[:-1]
|
| 283 |
+
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 284 |
+
|
| 285 |
+
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 286 |
+
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 287 |
+
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 288 |
+
|
| 289 |
+
cos, sin = position_embeddings
|
| 290 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 291 |
+
|
| 292 |
+
if past_key_values is not None:
|
| 293 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 294 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 295 |
+
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 296 |
+
|
| 297 |
+
attention_interface: Callable = eager_attention_forward
|
| 298 |
+
if self.config._attn_implementation != "eager":
|
| 299 |
+
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 300 |
+
|
| 301 |
+
attn_output, attn_weights = attention_interface(
|
| 302 |
+
self,
|
| 303 |
+
query_states,
|
| 304 |
+
key_states,
|
| 305 |
+
value_states,
|
| 306 |
+
attention_mask,
|
| 307 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 308 |
+
scaling=self.scaling,
|
| 309 |
+
**kwargs,
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 313 |
+
attn_output = self.o_proj(attn_output)
|
| 314 |
+
return attn_output, attn_weights
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
class DiaCrossAttention(nn.Module):
|
| 318 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 319 |
+
|
| 320 |
+
def __init__(self, config: DiaDecoderConfig, layer_idx: int):
|
| 321 |
+
super().__init__()
|
| 322 |
+
self.config = config
|
| 323 |
+
self.layer_idx = layer_idx
|
| 324 |
+
self.hidden_size = config.hidden_size
|
| 325 |
+
self.cross_hidden_size = config.cross_hidden_size
|
| 326 |
+
self.num_heads = self.config.cross_num_attention_heads
|
| 327 |
+
self.num_key_value_heads = self.config.cross_num_key_value_heads
|
| 328 |
+
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
| 329 |
+
self.head_dim = config.cross_head_dim
|
| 330 |
+
self.scaling = 1
|
| 331 |
+
self.attention_dropout = 0.0
|
| 332 |
+
self.is_causal = False
|
| 333 |
+
|
| 334 |
+
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
| 335 |
+
self.k_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
|
| 336 |
+
self.v_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
|
| 337 |
+
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
| 338 |
+
|
| 339 |
+
def forward(
|
| 340 |
+
self,
|
| 341 |
+
hidden_states: torch.Tensor,
|
| 342 |
+
cross_attention_states: torch.Tensor,
|
| 343 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 344 |
+
past_key_values: Optional[EncoderDecoderCache] = None,
|
| 345 |
+
**kwargs: Unpack[FlashAttentionKwargs],
|
| 346 |
+
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 347 |
+
input_shape = hidden_states.shape[:-1]
|
| 348 |
+
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 349 |
+
cross_shape = (*cross_attention_states.shape[:-1], -1, self.head_dim)
|
| 350 |
+
|
| 351 |
+
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 352 |
+
|
| 353 |
+
is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False
|
| 354 |
+
if past_key_values is not None and is_updated:
|
| 355 |
+
# reuse k,v, cross_attentions
|
| 356 |
+
key_states = past_key_values.cross_attention_cache.layers[self.layer_idx].keys
|
| 357 |
+
value_states = past_key_values.cross_attention_cache.layers[self.layer_idx].values
|
| 358 |
+
else:
|
| 359 |
+
key_states = self.k_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
|
| 360 |
+
value_states = self.v_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
|
| 361 |
+
|
| 362 |
+
if past_key_values is not None:
|
| 363 |
+
# save all states to the cache
|
| 364 |
+
key_states, value_states = past_key_values.cross_attention_cache.update(
|
| 365 |
+
key_states,
|
| 366 |
+
value_states,
|
| 367 |
+
self.layer_idx,
|
| 368 |
+
)
|
| 369 |
+
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
|
| 370 |
+
past_key_values.is_updated[self.layer_idx] = True
|
| 371 |
+
|
| 372 |
+
attention_interface: Callable = eager_attention_forward
|
| 373 |
+
if self.config._attn_implementation != "eager":
|
| 374 |
+
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 375 |
+
|
| 376 |
+
attn_output, attn_weights = attention_interface(
|
| 377 |
+
self,
|
| 378 |
+
query_states,
|
| 379 |
+
key_states,
|
| 380 |
+
value_states,
|
| 381 |
+
attention_mask,
|
| 382 |
+
scaling=self.scaling,
|
| 383 |
+
**kwargs,
|
| 384 |
+
)
|
| 385 |
+
|
| 386 |
+
attn_output = attn_output.reshape((*input_shape, -1)).contiguous()
|
| 387 |
+
attn_output = self.o_proj(attn_output)
|
| 388 |
+
return attn_output, attn_weights
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
class DiaEncoderLayer(GradientCheckpointingLayer):
|
| 392 |
+
def __init__(self, config: DiaEncoderConfig, layer_idx: int):
|
| 393 |
+
super().__init__()
|
| 394 |
+
self.pre_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
|
| 395 |
+
self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=False)
|
| 396 |
+
self.post_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
|
| 397 |
+
self.mlp = DiaMLP(config)
|
| 398 |
+
|
| 399 |
+
def forward(
|
| 400 |
+
self,
|
| 401 |
+
hidden_states: torch.Tensor,
|
| 402 |
+
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
| 403 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 404 |
+
**kwargs: Unpack[FlashAttentionKwargs],
|
| 405 |
+
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 406 |
+
residual = hidden_states
|
| 407 |
+
normed_states = self.pre_sa_norm(hidden_states)
|
| 408 |
+
self_attn_output, self_attn_weights = self.self_attention(
|
| 409 |
+
normed_states,
|
| 410 |
+
position_embeddings=position_embeddings,
|
| 411 |
+
attention_mask=attention_mask,
|
| 412 |
+
**kwargs,
|
| 413 |
+
)
|
| 414 |
+
hidden_states = residual + self_attn_output
|
| 415 |
+
|
| 416 |
+
residual = hidden_states
|
| 417 |
+
normed_states = self.post_sa_norm(hidden_states)
|
| 418 |
+
mlp_out = self.mlp(normed_states)
|
| 419 |
+
hidden_states = residual + mlp_out
|
| 420 |
+
|
| 421 |
+
return hidden_states, self_attn_weights
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
class DiaEncoder(DiaPreTrainedModel):
|
| 425 |
+
def __init__(self, config: DiaEncoderConfig):
|
| 426 |
+
super().__init__(config)
|
| 427 |
+
self.config = config
|
| 428 |
+
|
| 429 |
+
self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
|
| 430 |
+
self.layers = nn.ModuleList(
|
| 431 |
+
[DiaEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 432 |
+
)
|
| 433 |
+
self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
|
| 434 |
+
self.rotary_embeddings = DiaRotaryEmbedding(config)
|
| 435 |
+
|
| 436 |
+
@auto_docstring
|
| 437 |
+
@can_return_tuple
|
| 438 |
+
def forward(
|
| 439 |
+
self,
|
| 440 |
+
input_ids: torch.Tensor,
|
| 441 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 442 |
+
output_attentions: Optional[bool] = False,
|
| 443 |
+
output_hidden_states: Optional[bool] = False,
|
| 444 |
+
**kwargs: Unpack[FlashAttentionKwargs],
|
| 445 |
+
) -> Union[BaseModelOutput, tuple]:
|
| 446 |
+
hidden_states = self.embedding(input_ids)
|
| 447 |
+
|
| 448 |
+
# RoPE
|
| 449 |
+
# Note: We expect right padding and hence always generate
|
| 450 |
+
# the position ids on the fly to reduce preparation overhead
|
| 451 |
+
position_ids = torch.arange(input_ids.shape[-1], device=input_ids.device)[None, :]
|
| 452 |
+
position_embeddings = self.rotary_embeddings(hidden_states, position_ids)
|
| 453 |
+
|
| 454 |
+
attention_mask = self._update_full_mask(
|
| 455 |
+
attention_mask,
|
| 456 |
+
hidden_states,
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
encoder_states = () if output_hidden_states else None
|
| 460 |
+
all_attentions = () if output_attentions else None
|
| 461 |
+
|
| 462 |
+
for encoder_layer in self.layers:
|
| 463 |
+
if output_hidden_states:
|
| 464 |
+
encoder_states = encoder_states + (hidden_states,)
|
| 465 |
+
|
| 466 |
+
layer_outputs = encoder_layer(
|
| 467 |
+
hidden_states,
|
| 468 |
+
position_embeddings=position_embeddings,
|
| 469 |
+
attention_mask=attention_mask,
|
| 470 |
+
**kwargs,
|
| 471 |
+
)
|
| 472 |
+
hidden_states = layer_outputs[0]
|
| 473 |
+
|
| 474 |
+
if output_attentions:
|
| 475 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
| 476 |
+
|
| 477 |
+
hidden_states = self.norm(hidden_states)
|
| 478 |
+
|
| 479 |
+
if output_hidden_states:
|
| 480 |
+
encoder_states += (hidden_states,)
|
| 481 |
+
|
| 482 |
+
return BaseModelOutput(
|
| 483 |
+
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
| 484 |
+
)
|
| 485 |
+
|
| 486 |
+
# Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
|
| 487 |
+
def _update_full_mask(
|
| 488 |
+
self,
|
| 489 |
+
attention_mask: Union[torch.Tensor, None],
|
| 490 |
+
inputs_embeds: torch.Tensor,
|
| 491 |
+
):
|
| 492 |
+
if attention_mask is not None:
|
| 493 |
+
if self.config._attn_implementation == "flash_attention_2":
|
| 494 |
+
attention_mask = attention_mask if 0 in attention_mask else None
|
| 495 |
+
elif self.config._attn_implementation == "sdpa":
|
| 496 |
+
# output_attentions=True & head_mask can not be supported when using SDPA, fall back to
|
| 497 |
+
# the manual implementation that requires a 4D causal mask in all cases.
|
| 498 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
| 499 |
+
attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
|
| 500 |
+
elif self.config._attn_implementation == "flex_attention":
|
| 501 |
+
if isinstance(attention_mask, torch.Tensor):
|
| 502 |
+
attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
|
| 503 |
+
else:
|
| 504 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
| 505 |
+
attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
|
| 506 |
+
|
| 507 |
+
return attention_mask
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
class DiaDecoderLayer(GradientCheckpointingLayer):
|
| 511 |
+
def __init__(self, config: DiaDecoderConfig, layer_idx: int):
|
| 512 |
+
super().__init__()
|
| 513 |
+
self.embed_dim = config.hidden_size
|
| 514 |
+
self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=True)
|
| 515 |
+
self.cross_attention = DiaCrossAttention(config, layer_idx)
|
| 516 |
+
self.pre_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
|
| 517 |
+
self.pre_ca_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
|
| 518 |
+
self.pre_mlp_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
|
| 519 |
+
self.mlp = DiaMLP(config)
|
| 520 |
+
|
| 521 |
+
def forward(
|
| 522 |
+
self,
|
| 523 |
+
hidden_states: torch.Tensor,
|
| 524 |
+
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
| 525 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 526 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 527 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
| 528 |
+
past_key_values: Optional[EncoderDecoderCache] = None,
|
| 529 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 530 |
+
**kwargs,
|
| 531 |
+
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
| 532 |
+
self_attn_cache = past_key_values
|
| 533 |
+
if isinstance(self_attn_cache, EncoderDecoderCache):
|
| 534 |
+
self_attn_cache = self_attn_cache.self_attention_cache
|
| 535 |
+
|
| 536 |
+
residual = hidden_states
|
| 537 |
+
normed_states = self.pre_sa_norm(hidden_states)
|
| 538 |
+
self_attn_output, self_attn_weights = self.self_attention(
|
| 539 |
+
normed_states,
|
| 540 |
+
position_embeddings,
|
| 541 |
+
attention_mask,
|
| 542 |
+
# Needs to be an arg in order to function properly
|
| 543 |
+
# on inplace operations to be carried (e.g. compile)
|
| 544 |
+
self_attn_cache,
|
| 545 |
+
cache_position=cache_position,
|
| 546 |
+
**kwargs,
|
| 547 |
+
)
|
| 548 |
+
hidden_states = residual + self_attn_output
|
| 549 |
+
|
| 550 |
+
residual = hidden_states
|
| 551 |
+
normed_states = self.pre_ca_norm(hidden_states)
|
| 552 |
+
cross_states, cross_attn_weights = self.cross_attention(
|
| 553 |
+
normed_states,
|
| 554 |
+
encoder_hidden_states,
|
| 555 |
+
attention_mask=encoder_attention_mask,
|
| 556 |
+
past_key_values=past_key_values,
|
| 557 |
+
**kwargs,
|
| 558 |
+
)
|
| 559 |
+
hidden_states = residual + cross_states
|
| 560 |
+
|
| 561 |
+
residual = hidden_states
|
| 562 |
+
normed_states = self.pre_mlp_norm(hidden_states)
|
| 563 |
+
mlp_out = self.mlp(normed_states)
|
| 564 |
+
hidden_states = residual + mlp_out
|
| 565 |
+
|
| 566 |
+
return hidden_states, self_attn_weights, cross_attn_weights
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
class DiaDecoder(DiaPreTrainedModel):
|
| 570 |
+
"""Transformer Decoder Stack using DenseGeneral."""
|
| 571 |
+
|
| 572 |
+
def __init__(self, config: DiaDecoderConfig):
|
| 573 |
+
super().__init__(config)
|
| 574 |
+
self.num_channels = config.num_channels
|
| 575 |
+
self.vocab_size = config.vocab_size
|
| 576 |
+
self.embeddings = DiaMultiChannelEmbedding(config)
|
| 577 |
+
self.rotary_embeddings = DiaRotaryEmbedding(config)
|
| 578 |
+
self.layers = nn.ModuleList(
|
| 579 |
+
[DiaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 580 |
+
)
|
| 581 |
+
self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
|
| 582 |
+
|
| 583 |
+
@auto_docstring
|
| 584 |
+
@can_return_tuple
|
| 585 |
+
def forward(
|
| 586 |
+
self,
|
| 587 |
+
input_ids: torch.Tensor,
|
| 588 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 589 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 590 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
| 591 |
+
encoder_attention_mask: Optional[torch.LongTensor] = None,
|
| 592 |
+
past_key_values: Optional[EncoderDecoderCache] = None,
|
| 593 |
+
output_attentions: Optional[bool] = False,
|
| 594 |
+
output_hidden_states: Optional[bool] = False,
|
| 595 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 596 |
+
**kwargs,
|
| 597 |
+
) -> Union[BaseModelOutputWithPastAndCrossAttentions, tuple]:
|
| 598 |
+
r"""
|
| 599 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`):
|
| 600 |
+
The original `decoder_input_ids` in 3D shape to facilitate more efficient computations.
|
| 601 |
+
|
| 602 |
+
[What are input IDs?](../glossary#input-ids)
|
| 603 |
+
"""
|
| 604 |
+
|
| 605 |
+
batch_size, seq_length = input_ids.size()[:-1]
|
| 606 |
+
past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 607 |
+
if cache_position is None:
|
| 608 |
+
cache_position = torch.arange(
|
| 609 |
+
past_key_values_length, past_key_values_length + seq_length, device=input_ids.device
|
| 610 |
+
)
|
| 611 |
+
if position_ids is None:
|
| 612 |
+
position_ids = cache_position[None, :]
|
| 613 |
+
|
| 614 |
+
# RoPE
|
| 615 |
+
hidden_states = self.embeddings(input_ids)
|
| 616 |
+
position_embeddings = self.rotary_embeddings(hidden_states, position_ids)
|
| 617 |
+
|
| 618 |
+
if attention_mask is None and not is_torchdynamo_compiling():
|
| 619 |
+
# required mask seq length can be calculated via length of past cache
|
| 620 |
+
mask_seq_length = past_key_values_length + seq_length
|
| 621 |
+
attention_mask = torch.ones(batch_size, mask_seq_length, device=input_ids.device)
|
| 622 |
+
|
| 623 |
+
attention_mask = create_causal_mask(
|
| 624 |
+
config=self.config,
|
| 625 |
+
input_embeds=hidden_states,
|
| 626 |
+
attention_mask=attention_mask,
|
| 627 |
+
cache_position=cache_position,
|
| 628 |
+
past_key_values=past_key_values,
|
| 629 |
+
position_ids=position_ids,
|
| 630 |
+
)
|
| 631 |
+
encoder_attention_mask = self._update_cross_attn_mask(
|
| 632 |
+
encoder_hidden_states,
|
| 633 |
+
encoder_attention_mask,
|
| 634 |
+
hidden_states.shape[:2],
|
| 635 |
+
hidden_states,
|
| 636 |
+
)
|
| 637 |
+
|
| 638 |
+
all_hidden_states = () if output_hidden_states else None
|
| 639 |
+
all_self_attns = () if output_attentions else None
|
| 640 |
+
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
| 641 |
+
|
| 642 |
+
for layer in self.layers:
|
| 643 |
+
if output_hidden_states:
|
| 644 |
+
all_hidden_states += (hidden_states,)
|
| 645 |
+
|
| 646 |
+
layer_outputs = layer(
|
| 647 |
+
hidden_states,
|
| 648 |
+
position_embeddings,
|
| 649 |
+
attention_mask,
|
| 650 |
+
encoder_hidden_states,
|
| 651 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 652 |
+
past_key_values=past_key_values,
|
| 653 |
+
cache_position=cache_position,
|
| 654 |
+
**kwargs,
|
| 655 |
+
)
|
| 656 |
+
hidden_states = layer_outputs[0]
|
| 657 |
+
|
| 658 |
+
if output_attentions:
|
| 659 |
+
all_self_attns = all_self_attns + (layer_outputs[1],)
|
| 660 |
+
|
| 661 |
+
if encoder_hidden_states is not None:
|
| 662 |
+
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
| 663 |
+
|
| 664 |
+
hidden_states = self.norm(hidden_states)
|
| 665 |
+
|
| 666 |
+
if output_hidden_states:
|
| 667 |
+
all_hidden_states += (hidden_states,)
|
| 668 |
+
|
| 669 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
| 670 |
+
last_hidden_state=hidden_states,
|
| 671 |
+
past_key_values=past_key_values,
|
| 672 |
+
hidden_states=all_hidden_states,
|
| 673 |
+
attentions=all_self_attns,
|
| 674 |
+
cross_attentions=all_cross_attentions,
|
| 675 |
+
)
|
| 676 |
+
|
| 677 |
+
# Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask
|
| 678 |
+
def _update_cross_attn_mask(
|
| 679 |
+
self,
|
| 680 |
+
encoder_hidden_states: Union[torch.Tensor, None],
|
| 681 |
+
encoder_attention_mask: Union[torch.Tensor, None],
|
| 682 |
+
input_shape: torch.Size,
|
| 683 |
+
inputs_embeds: torch.Tensor,
|
| 684 |
+
):
|
| 685 |
+
# expand encoder attention mask
|
| 686 |
+
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
| 687 |
+
if self.config._attn_implementation == "flash_attention_2":
|
| 688 |
+
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
|
| 689 |
+
elif self.config._attn_implementation == "sdpa":
|
| 690 |
+
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
|
| 691 |
+
# the manual implementation that requires a 4D causal mask in all cases.
|
| 692 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
| 693 |
+
encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
|
| 694 |
+
encoder_attention_mask,
|
| 695 |
+
inputs_embeds.dtype,
|
| 696 |
+
tgt_len=input_shape[-1],
|
| 697 |
+
)
|
| 698 |
+
elif self.config._attn_implementation == "flex_attention":
|
| 699 |
+
if isinstance(encoder_attention_mask, torch.Tensor):
|
| 700 |
+
encoder_attention_mask = make_flex_block_causal_mask(
|
| 701 |
+
encoder_attention_mask,
|
| 702 |
+
query_length=input_shape[-1],
|
| 703 |
+
is_causal=False,
|
| 704 |
+
)
|
| 705 |
+
else:
|
| 706 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
| 707 |
+
encoder_attention_mask = _prepare_4d_attention_mask(
|
| 708 |
+
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
| 709 |
+
)
|
| 710 |
+
|
| 711 |
+
return encoder_attention_mask
|
| 712 |
+
|
| 713 |
+
|
| 714 |
+
@auto_docstring(
|
| 715 |
+
custom_intro="""
|
| 716 |
+
The bare Dia model outputting raw hidden-states without any specific head on top.
|
| 717 |
+
"""
|
| 718 |
+
)
|
| 719 |
+
class DiaModel(DiaPreTrainedModel):
|
| 720 |
+
def __init__(self, config: DiaConfig):
|
| 721 |
+
super().__init__(config)
|
| 722 |
+
self.config = config
|
| 723 |
+
self.encoder = DiaEncoder(config.encoder_config)
|
| 724 |
+
self.decoder = DiaDecoder(config.decoder_config)
|
| 725 |
+
self.post_init()
|
| 726 |
+
|
| 727 |
+
def get_encoder(self):
|
| 728 |
+
return self.encoder
|
| 729 |
+
|
| 730 |
+
@auto_docstring
|
| 731 |
+
@can_return_tuple
|
| 732 |
+
def forward(
|
| 733 |
+
self,
|
| 734 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 735 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 736 |
+
decoder_input_ids: Optional[torch.LongTensor] = None,
|
| 737 |
+
decoder_position_ids: Optional[torch.LongTensor] = None,
|
| 738 |
+
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
| 739 |
+
encoder_outputs: Optional[Union[BaseModelOutput, tuple]] = None,
|
| 740 |
+
past_key_values: Optional[EncoderDecoderCache] = None,
|
| 741 |
+
use_cache: Optional[bool] = None,
|
| 742 |
+
output_attentions: Optional[bool] = None,
|
| 743 |
+
output_hidden_states: Optional[bool] = None,
|
| 744 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 745 |
+
**kwargs,
|
| 746 |
+
) -> Union[tuple, Seq2SeqModelOutput]:
|
| 747 |
+
r"""
|
| 748 |
+
decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)
|
| 749 |
+
or (batch_size, target_sequence_length, num_codebooks)`, *optional*):
|
| 750 |
+
1. (batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where
|
| 751 |
+
the audio input codebooks are flattened into the batch dimension. This also aligns with the flat-
|
| 752 |
+
tened audio logits which are used to calculate the loss.
|
| 753 |
+
|
| 754 |
+
2. (batch_size, sequence_length, num_codebooks): corresponds to the internally used shape of
|
| 755 |
+
Dia to calculate embeddings and subsequent steps more efficiently.
|
| 756 |
+
|
| 757 |
+
If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape
|
| 758 |
+
`(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See
|
| 759 |
+
[`DiaProcessor.__call__`] for more details.
|
| 760 |
+
|
| 761 |
+
[What are decoder input IDs?](../glossary#decoder-input-ids)
|
| 762 |
+
decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
|
| 763 |
+
Indices of positions of each input sequence tokens in the position embeddings.
|
| 764 |
+
Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`.
|
| 765 |
+
|
| 766 |
+
[What are position IDs?](../glossary#position-ids)
|
| 767 |
+
"""
|
| 768 |
+
|
| 769 |
+
if input_ids is None and encoder_outputs is None:
|
| 770 |
+
raise ValueError(
|
| 771 |
+
"You should either provide text ids or the cached text encodings. Neither has been found."
|
| 772 |
+
)
|
| 773 |
+
|
| 774 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 775 |
+
output_hidden_states = (
|
| 776 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 777 |
+
)
|
| 778 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 779 |
+
|
| 780 |
+
if self.is_gradient_checkpointing and self.training:
|
| 781 |
+
if use_cache:
|
| 782 |
+
logger.warning_once(
|
| 783 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
| 784 |
+
)
|
| 785 |
+
use_cache = False
|
| 786 |
+
|
| 787 |
+
if use_cache and past_key_values is None:
|
| 788 |
+
past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
|
| 789 |
+
|
| 790 |
+
if encoder_outputs is None:
|
| 791 |
+
encoder_outputs = self.encoder(
|
| 792 |
+
input_ids=input_ids,
|
| 793 |
+
attention_mask=attention_mask,
|
| 794 |
+
output_attentions=output_attentions,
|
| 795 |
+
output_hidden_states=output_hidden_states,
|
| 796 |
+
**kwargs,
|
| 797 |
+
)
|
| 798 |
+
# If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput
|
| 799 |
+
elif not isinstance(encoder_outputs, BaseModelOutput):
|
| 800 |
+
encoder_outputs = BaseModelOutput(
|
| 801 |
+
last_hidden_state=encoder_outputs[0],
|
| 802 |
+
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
|
| 803 |
+
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
|
| 804 |
+
)
|
| 805 |
+
|
| 806 |
+
# On default we initialize the decoder with bos tokens if nothing has been provided
|
| 807 |
+
bsz, seq_len, channels = (encoder_outputs[0].shape[0], -1, self.config.decoder_config.num_channels)
|
| 808 |
+
if decoder_input_ids is None:
|
| 809 |
+
decoder_input_ids = torch.full(
|
| 810 |
+
size=(bsz, 1, channels), fill_value=self.config.bos_token_id, device=self.device
|
| 811 |
+
)
|
| 812 |
+
# Ensure 3D
|
| 813 |
+
if decoder_input_ids.ndim == 2:
|
| 814 |
+
decoder_input_ids = decoder_input_ids.reshape(bsz, channels, seq_len).transpose(1, 2)
|
| 815 |
+
|
| 816 |
+
decoder_outputs = self.decoder(
|
| 817 |
+
input_ids=decoder_input_ids,
|
| 818 |
+
position_ids=decoder_position_ids,
|
| 819 |
+
attention_mask=decoder_attention_mask,
|
| 820 |
+
encoder_hidden_states=encoder_outputs[0],
|
| 821 |
+
encoder_attention_mask=attention_mask,
|
| 822 |
+
past_key_values=past_key_values,
|
| 823 |
+
output_attentions=output_attentions,
|
| 824 |
+
output_hidden_states=output_hidden_states,
|
| 825 |
+
use_cache=use_cache,
|
| 826 |
+
cache_position=cache_position,
|
| 827 |
+
**kwargs,
|
| 828 |
+
)
|
| 829 |
+
|
| 830 |
+
return Seq2SeqModelOutput(
|
| 831 |
+
last_hidden_state=decoder_outputs.last_hidden_state,
|
| 832 |
+
past_key_values=decoder_outputs.past_key_values,
|
| 833 |
+
decoder_hidden_states=decoder_outputs.hidden_states,
|
| 834 |
+
decoder_attentions=decoder_outputs.attentions,
|
| 835 |
+
cross_attentions=decoder_outputs.cross_attentions,
|
| 836 |
+
encoder_last_hidden_state=encoder_outputs[0],
|
| 837 |
+
encoder_hidden_states=encoder_outputs.hidden_states,
|
| 838 |
+
encoder_attentions=encoder_outputs.attentions,
|
| 839 |
+
)
|
| 840 |
+
|
| 841 |
+
|
| 842 |
+
@auto_docstring(
|
| 843 |
+
custom_intro="""
|
| 844 |
+
The Dia model consisting of a (byte) text encoder and audio decoder with a prediction head on top.
|
| 845 |
+
"""
|
| 846 |
+
)
|
| 847 |
+
class DiaForConditionalGeneration(DiaPreTrainedModel, DiaGenerationMixin):
|
| 848 |
+
base_model_prefix = "model"
|
| 849 |
+
|
| 850 |
+
def __init__(self, config: DiaConfig):
|
| 851 |
+
super().__init__(config)
|
| 852 |
+
self.config = config
|
| 853 |
+
self.model = DiaModel(config)
|
| 854 |
+
|
| 855 |
+
self.num_channels = config.decoder_config.num_channels
|
| 856 |
+
self.vocab_size = config.decoder_config.vocab_size
|
| 857 |
+
self.logits_dense = nn.Linear(
|
| 858 |
+
config.decoder_config.hidden_size, (self.num_channels * self.vocab_size), bias=False
|
| 859 |
+
)
|
| 860 |
+
self.loss_type = "ForMaskedLM"
|
| 861 |
+
|
| 862 |
+
# Initialize weights and apply final processing
|
| 863 |
+
self.post_init()
|
| 864 |
+
|
| 865 |
+
def get_encoder(self):
|
| 866 |
+
return self.model.get_encoder()
|
| 867 |
+
|
| 868 |
+
def get_decoder(self):
|
| 869 |
+
return self.model.get_decoder()
|
| 870 |
+
|
| 871 |
+
@auto_docstring
|
| 872 |
+
@can_return_tuple
|
| 873 |
+
def forward(
|
| 874 |
+
self,
|
| 875 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 876 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 877 |
+
decoder_input_ids: Optional[torch.LongTensor] = None,
|
| 878 |
+
decoder_position_ids: Optional[torch.LongTensor] = None,
|
| 879 |
+
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
| 880 |
+
encoder_outputs: Optional[Union[BaseModelOutput, tuple]] = None,
|
| 881 |
+
past_key_values: Optional[EncoderDecoderCache] = None,
|
| 882 |
+
use_cache: Optional[bool] = None,
|
| 883 |
+
output_attentions: Optional[bool] = None,
|
| 884 |
+
output_hidden_states: Optional[bool] = None,
|
| 885 |
+
labels: Optional[torch.LongTensor] = None,
|
| 886 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 887 |
+
**kwargs,
|
| 888 |
+
) -> Union[tuple, Seq2SeqLMOutput]:
|
| 889 |
+
r"""
|
| 890 |
+
decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)
|
| 891 |
+
or (batch_size, target_sequence_length, num_codebooks)`, *optional*):
|
| 892 |
+
1. (batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where
|
| 893 |
+
the audio input codebooks are flattened into the batch dimension. This also aligns with the flat-
|
| 894 |
+
tened audio logits which are used to calculate the loss.
|
| 895 |
+
|
| 896 |
+
2. (batch_size, sequence_length, num_codebooks): corresponds to the internally used shape of
|
| 897 |
+
Dia to calculate embeddings and subsequent steps more efficiently.
|
| 898 |
+
|
| 899 |
+
If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape
|
| 900 |
+
`(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See
|
| 901 |
+
[`DiaProcessor.__call__`] for more details.
|
| 902 |
+
|
| 903 |
+
[What are decoder input IDs?](../glossary#decoder-input-ids)
|
| 904 |
+
decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
|
| 905 |
+
Indices of positions of each input sequence tokens in the position embeddings.
|
| 906 |
+
Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`.
|
| 907 |
+
|
| 908 |
+
[What are position IDs?](../glossary#position-ids)
|
| 909 |
+
labels (`torch.LongTensor` of shape `(batch_size * num_codebooks,)`, *optional*):
|
| 910 |
+
Labels for computing the masked language modeling loss. Indices should either be in
|
| 911 |
+
`[0, ..., config.decoder_config.vocab_size - 1]` or -100. Tokens with indices set to `-100`
|
| 912 |
+
are ignored (masked).
|
| 913 |
+
"""
|
| 914 |
+
|
| 915 |
+
outputs = self.model(
|
| 916 |
+
input_ids=input_ids,
|
| 917 |
+
attention_mask=attention_mask,
|
| 918 |
+
decoder_input_ids=decoder_input_ids,
|
| 919 |
+
decoder_position_ids=decoder_position_ids,
|
| 920 |
+
decoder_attention_mask=decoder_attention_mask,
|
| 921 |
+
encoder_outputs=encoder_outputs,
|
| 922 |
+
past_key_values=past_key_values,
|
| 923 |
+
use_cache=use_cache,
|
| 924 |
+
output_attentions=output_attentions,
|
| 925 |
+
output_hidden_states=output_hidden_states,
|
| 926 |
+
cache_position=cache_position,
|
| 927 |
+
**kwargs,
|
| 928 |
+
)
|
| 929 |
+
|
| 930 |
+
last_hidden_state = outputs[0]
|
| 931 |
+
batch_size = last_hidden_state.shape[0]
|
| 932 |
+
# 3D <-> 2D makes it necessary to prioritize channel dim
|
| 933 |
+
audio_logits = (
|
| 934 |
+
self.logits_dense(last_hidden_state)
|
| 935 |
+
.view((batch_size, -1, self.num_channels, self.vocab_size))
|
| 936 |
+
.transpose(1, 2)
|
| 937 |
+
.contiguous()
|
| 938 |
+
.view(batch_size * self.num_channels, -1, self.vocab_size)
|
| 939 |
+
)
|
| 940 |
+
|
| 941 |
+
loss = None
|
| 942 |
+
if labels is not None:
|
| 943 |
+
loss = self.loss_function(logits=audio_logits, labels=labels, vocab_size=self.vocab_size, **kwargs)
|
| 944 |
+
|
| 945 |
+
return Seq2SeqLMOutput(
|
| 946 |
+
loss=loss,
|
| 947 |
+
logits=audio_logits,
|
| 948 |
+
past_key_values=outputs.past_key_values,
|
| 949 |
+
decoder_hidden_states=outputs.decoder_hidden_states,
|
| 950 |
+
decoder_attentions=outputs.decoder_attentions,
|
| 951 |
+
cross_attentions=outputs.cross_attentions,
|
| 952 |
+
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
| 953 |
+
encoder_hidden_states=outputs.encoder_hidden_states,
|
| 954 |
+
encoder_attentions=outputs.encoder_attentions,
|
| 955 |
+
)
|
| 956 |
+
|
| 957 |
+
|
| 958 |
+
__all__ = ["DiaModel", "DiaPreTrainedModel", "DiaForConditionalGeneration"]
|
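
A minimal usage sketch of the generation entry point in generation_dia.py above. The checkpoint id, the processor call signature, and the decode step are illustrative assumptions, not taken from this diff; only the 3D output shape of `generate()` is grounded in the code above.

# Illustrative only: checkpoint id and processor usage are assumptions.
import torch
from transformers import AutoProcessor, DiaForConditionalGeneration

model_id = "nari-labs/Dia-1.6B"  # hypothetical checkpoint id
processor = AutoProcessor.from_pretrained(model_id)
model = DiaForConditionalGeneration.from_pretrained(model_id)

# The processor is assumed to build the text ids (and optional delayed decoder_input_ids) expected by generate().
inputs = processor(text=["[S1] Hello there."], padding=True, return_tensors="pt")
with torch.no_grad():
    # Per DiaGenerationMixin.generate above, the returned audio codes are reshaped to (batch, seq_len, num_channels).
    audio_codes = model.generate(**inputs, max_new_tokens=256)
# Decoding codes back to a waveform is assumed to go through the processor.
audio = processor.batch_decode(audio_codes)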