Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitattributes +1 -0
- .venv/lib/python3.11/site-packages/torchvision/_C.so +3 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/_optical_flow.py +490 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/country211.py +58 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/folder.py +337 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/gtsrb.py +103 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/inaturalist.py +242 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/lsun.py +168 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/sbu.py +110 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/svhn.py +130 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/widerface.py +197 -0
- .venv/lib/python3.11/site-packages/torchvision/io/__init__.py +76 -0
- .venv/lib/python3.11/site-packages/torchvision/io/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/io/__pycache__/_load_gpu_decoder.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/io/__pycache__/_video_opt.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/io/__pycache__/image.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/io/__pycache__/video.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/io/__pycache__/video_reader.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/io/_load_gpu_decoder.py +8 -0
- .venv/lib/python3.11/site-packages/torchvision/io/_video_opt.py +513 -0
- .venv/lib/python3.11/site-packages/torchvision/io/image.py +436 -0
- .venv/lib/python3.11/site-packages/torchvision/io/video.py +438 -0
- .venv/lib/python3.11/site-packages/torchvision/io/video_reader.py +294 -0
- .venv/lib/python3.11/site-packages/torchvision/models/detection/__init__.py +7 -0
- .venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/anchor_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/backbone_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/faster_rcnn.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/fcos.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/generalized_rcnn.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/image_list.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/keypoint_rcnn.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/mask_rcnn.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/retinanet.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/roi_heads.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/rpn.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/ssd.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/ssdlite.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/transform.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/models/detection/_utils.py +540 -0
- .venv/lib/python3.11/site-packages/torchvision/models/detection/anchor_utils.py +268 -0
- .venv/lib/python3.11/site-packages/torchvision/models/detection/backbone_utils.py +244 -0
- .venv/lib/python3.11/site-packages/torchvision/models/detection/faster_rcnn.py +846 -0
- .venv/lib/python3.11/site-packages/torchvision/models/detection/fcos.py +775 -0
- .venv/lib/python3.11/site-packages/torchvision/models/detection/generalized_rcnn.py +118 -0
- .venv/lib/python3.11/site-packages/torchvision/models/detection/image_list.py +25 -0
- .venv/lib/python3.11/site-packages/torchvision/models/detection/keypoint_rcnn.py +474 -0
- .venv/lib/python3.11/site-packages/torchvision/models/detection/mask_rcnn.py +590 -0
- .venv/lib/python3.11/site-packages/torchvision/models/detection/retinanet.py +903 -0
.gitattributes
CHANGED
@@ -345,3 +345,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/
 .venv/lib/python3.11/site-packages/distlib/t64-arm.exe filter=lfs diff=lfs merge=lfs -text
 .venv/lib/python3.11/site-packages/multidict/_multidict.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
 .venv/lib/python3.11/site-packages/torchvision/image.so filter=lfs diff=lfs merge=lfs -text
+.venv/lib/python3.11/site-packages/torchvision/_C.so filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/torchvision/_C.so
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fb7e1b7570bd8fc14f9497793f89e188ccf161d7c14ca1f236e00368779ee609
size 7746688
.venv/lib/python3.11/site-packages/torchvision/datasets/_optical_flow.py
ADDED
@@ -0,0 +1,490 @@
import itertools
import os
from abc import ABC, abstractmethod
from glob import glob
from pathlib import Path
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import torch
from PIL import Image

from ..io.image import decode_png, read_file
from .utils import _read_pfm, verify_str_arg
from .vision import VisionDataset

T1 = Tuple[Image.Image, Image.Image, Optional[np.ndarray], Optional[np.ndarray]]
T2 = Tuple[Image.Image, Image.Image, Optional[np.ndarray]]


__all__ = (
    "KittiFlow",
    "Sintel",
    "FlyingThings3D",
    "FlyingChairs",
    "HD1K",
)


class FlowDataset(ABC, VisionDataset):
    # Some datasets like Kitti have a built-in valid_flow_mask, indicating which flow values are valid
    # For those we return (img1, img2, flow, valid_flow_mask), and for the rest we return (img1, img2, flow),
    # and it's up to whatever consumes the dataset to decide what valid_flow_mask should be.
    _has_builtin_flow_mask = False

    def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None:

        super().__init__(root=root)
        self.transforms = transforms

        self._flow_list: List[str] = []
        self._image_list: List[List[str]] = []

    def _read_img(self, file_name: str) -> Image.Image:
        img = Image.open(file_name)
        if img.mode != "RGB":
            img = img.convert("RGB")  # type: ignore[assignment]
        return img

    @abstractmethod
    def _read_flow(self, file_name: str):
        # Return the flow or a tuple with the flow and the valid_flow_mask if _has_builtin_flow_mask is True
        pass

    def __getitem__(self, index: int) -> Union[T1, T2]:

        img1 = self._read_img(self._image_list[index][0])
        img2 = self._read_img(self._image_list[index][1])

        if self._flow_list:  # it will be empty for some dataset when split="test"
            flow = self._read_flow(self._flow_list[index])
            if self._has_builtin_flow_mask:
                flow, valid_flow_mask = flow
            else:
                valid_flow_mask = None
        else:
            flow = valid_flow_mask = None

        if self.transforms is not None:
            img1, img2, flow, valid_flow_mask = self.transforms(img1, img2, flow, valid_flow_mask)

        if self._has_builtin_flow_mask or valid_flow_mask is not None:
            # The `or valid_flow_mask is not None` part is here because the mask can be generated within a transform
            return img1, img2, flow, valid_flow_mask
        else:
            return img1, img2, flow

    def __len__(self) -> int:
        return len(self._image_list)

    def __rmul__(self, v: int) -> torch.utils.data.ConcatDataset:
        return torch.utils.data.ConcatDataset([self] * v)


class Sintel(FlowDataset):
    """`Sintel <http://sintel.is.tue.mpg.de/>`_ Dataset for optical flow.

    The dataset is expected to have the following structure: ::

        root
            Sintel
                testing
                    clean
                        scene_1
                        scene_2
                        ...
                    final
                        scene_1
                        scene_2
                        ...
                training
                    clean
                        scene_1
                        scene_2
                        ...
                    final
                        scene_1
                        scene_2
                        ...
                    flow
                        scene_1
                        scene_2
                        ...

    Args:
        root (str or ``pathlib.Path``): Root directory of the Sintel Dataset.
        split (string, optional): The dataset split, either "train" (default) or "test"
        pass_name (string, optional): The pass to use, either "clean" (default), "final", or "both". See link above for
            details on the different passes.
        transforms (callable, optional): A function/transform that takes in
            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
            ``valid_flow_mask`` is expected for consistency with other datasets which
            return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
    """

    def __init__(
        self,
        root: Union[str, Path],
        split: str = "train",
        pass_name: str = "clean",
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root=root, transforms=transforms)

        verify_str_arg(split, "split", valid_values=("train", "test"))
        verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both"))
        passes = ["clean", "final"] if pass_name == "both" else [pass_name]

        root = Path(root) / "Sintel"
        flow_root = root / "training" / "flow"

        for pass_name in passes:
            split_dir = "training" if split == "train" else split
            image_root = root / split_dir / pass_name
            for scene in os.listdir(image_root):
                image_list = sorted(glob(str(image_root / scene / "*.png")))
                for i in range(len(image_list) - 1):
                    self._image_list += [[image_list[i], image_list[i + 1]]]

                if split == "train":
                    self._flow_list += sorted(glob(str(flow_root / scene / "*.flo")))

    def __getitem__(self, index: int) -> Union[T1, T2]:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 3-tuple with ``(img1, img2, flow)``.
            The flow is a numpy array of shape (2, H, W) and the images are PIL images.
            ``flow`` is None if ``split="test"``.
            If a valid flow mask is generated within the ``transforms`` parameter,
            a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
        """
        return super().__getitem__(index)

    def _read_flow(self, file_name: str) -> np.ndarray:
        return _read_flo(file_name)


class KittiFlow(FlowDataset):
    """`KITTI <http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow>`__ dataset for optical flow (2015).

    The dataset is expected to have the following structure: ::

        root
            KittiFlow
                testing
                    image_2
                training
                    image_2
                    flow_occ

    Args:
        root (str or ``pathlib.Path``): Root directory of the KittiFlow Dataset.
        split (string, optional): The dataset split, either "train" (default) or "test"
        transforms (callable, optional): A function/transform that takes in
            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
    """

    _has_builtin_flow_mask = True

    def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
        super().__init__(root=root, transforms=transforms)

        verify_str_arg(split, "split", valid_values=("train", "test"))

        root = Path(root) / "KittiFlow" / (split + "ing")
        images1 = sorted(glob(str(root / "image_2" / "*_10.png")))
        images2 = sorted(glob(str(root / "image_2" / "*_11.png")))

        if not images1 or not images2:
            raise FileNotFoundError(
                "Could not find the Kitti flow images. Please make sure the directory structure is correct."
            )

        for img1, img2 in zip(images1, images2):
            self._image_list += [[img1, img2]]

        if split == "train":
            self._flow_list = sorted(glob(str(root / "flow_occ" / "*_10.png")))

    def __getitem__(self, index: int) -> Union[T1, T2]:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 4-tuple with ``(img1, img2, flow, valid_flow_mask)``
            where ``valid_flow_mask`` is a numpy boolean mask of shape (H, W)
            indicating which flow values are valid. The flow is a numpy array of
            shape (2, H, W) and the images are PIL images. ``flow`` and ``valid_flow_mask`` are None if
            ``split="test"``.
        """
        return super().__getitem__(index)

    def _read_flow(self, file_name: str) -> Tuple[np.ndarray, np.ndarray]:
        return _read_16bits_png_with_flow_and_valid_mask(file_name)


class FlyingChairs(FlowDataset):
    """`FlyingChairs <https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs>`_ Dataset for optical flow.

    You will also need to download the FlyingChairs_train_val.txt file from the dataset page.

    The dataset is expected to have the following structure: ::

        root
            FlyingChairs
                data
                    00001_flow.flo
                    00001_img1.ppm
                    00001_img2.ppm
                    ...
                FlyingChairs_train_val.txt


    Args:
        root (str or ``pathlib.Path``): Root directory of the FlyingChairs Dataset.
        split (string, optional): The dataset split, either "train" (default) or "val"
        transforms (callable, optional): A function/transform that takes in
            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
            ``valid_flow_mask`` is expected for consistency with other datasets which
            return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
    """

    def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
        super().__init__(root=root, transforms=transforms)

        verify_str_arg(split, "split", valid_values=("train", "val"))

        root = Path(root) / "FlyingChairs"
        images = sorted(glob(str(root / "data" / "*.ppm")))
        flows = sorted(glob(str(root / "data" / "*.flo")))

        split_file_name = "FlyingChairs_train_val.txt"

        if not os.path.exists(root / split_file_name):
            raise FileNotFoundError(
                "The FlyingChairs_train_val.txt file was not found - please download it from the dataset page (see docstring)."
            )

        split_list = np.loadtxt(str(root / split_file_name), dtype=np.int32)
        for i in range(len(flows)):
            split_id = split_list[i]
            if (split == "train" and split_id == 1) or (split == "val" and split_id == 2):
                self._flow_list += [flows[i]]
                self._image_list += [[images[2 * i], images[2 * i + 1]]]

    def __getitem__(self, index: int) -> Union[T1, T2]:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 3-tuple with ``(img1, img2, flow)``.
            The flow is a numpy array of shape (2, H, W) and the images are PIL images.
            ``flow`` is None if ``split="val"``.
            If a valid flow mask is generated within the ``transforms`` parameter,
            a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
        """
        return super().__getitem__(index)

    def _read_flow(self, file_name: str) -> np.ndarray:
        return _read_flo(file_name)


class FlyingThings3D(FlowDataset):
    """`FlyingThings3D <https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html>`_ dataset for optical flow.

    The dataset is expected to have the following structure: ::

        root
            FlyingThings3D
                frames_cleanpass
                    TEST
                    TRAIN
                frames_finalpass
                    TEST
                    TRAIN
                optical_flow
                    TEST
                    TRAIN

    Args:
        root (str or ``pathlib.Path``): Root directory of the intel FlyingThings3D Dataset.
        split (string, optional): The dataset split, either "train" (default) or "test"
        pass_name (string, optional): The pass to use, either "clean" (default) or "final" or "both". See link above for
            details on the different passes.
        camera (string, optional): Which camera to return images from. Can be either "left" (default) or "right" or "both".
        transforms (callable, optional): A function/transform that takes in
            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
            ``valid_flow_mask`` is expected for consistency with other datasets which
            return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
    """

    def __init__(
        self,
        root: Union[str, Path],
        split: str = "train",
        pass_name: str = "clean",
        camera: str = "left",
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root=root, transforms=transforms)

        verify_str_arg(split, "split", valid_values=("train", "test"))
        split = split.upper()

        verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both"))
        passes = {
            "clean": ["frames_cleanpass"],
            "final": ["frames_finalpass"],
            "both": ["frames_cleanpass", "frames_finalpass"],
        }[pass_name]

        verify_str_arg(camera, "camera", valid_values=("left", "right", "both"))
        cameras = ["left", "right"] if camera == "both" else [camera]

        root = Path(root) / "FlyingThings3D"

        directions = ("into_future", "into_past")
        for pass_name, camera, direction in itertools.product(passes, cameras, directions):
            image_dirs = sorted(glob(str(root / pass_name / split / "*/*")))
            image_dirs = sorted(Path(image_dir) / camera for image_dir in image_dirs)

            flow_dirs = sorted(glob(str(root / "optical_flow" / split / "*/*")))
            flow_dirs = sorted(Path(flow_dir) / direction / camera for flow_dir in flow_dirs)

            if not image_dirs or not flow_dirs:
                raise FileNotFoundError(
                    "Could not find the FlyingThings3D flow images. "
                    "Please make sure the directory structure is correct."
                )

            for image_dir, flow_dir in zip(image_dirs, flow_dirs):
                images = sorted(glob(str(image_dir / "*.png")))
                flows = sorted(glob(str(flow_dir / "*.pfm")))
                for i in range(len(flows) - 1):
                    if direction == "into_future":
                        self._image_list += [[images[i], images[i + 1]]]
                        self._flow_list += [flows[i]]
                    elif direction == "into_past":
                        self._image_list += [[images[i + 1], images[i]]]
                        self._flow_list += [flows[i + 1]]

    def __getitem__(self, index: int) -> Union[T1, T2]:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 3-tuple with ``(img1, img2, flow)``.
            The flow is a numpy array of shape (2, H, W) and the images are PIL images.
            ``flow`` is None if ``split="test"``.
            If a valid flow mask is generated within the ``transforms`` parameter,
            a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
        """
        return super().__getitem__(index)

    def _read_flow(self, file_name: str) -> np.ndarray:
        return _read_pfm(file_name)


class HD1K(FlowDataset):
    """`HD1K <http://hci-benchmark.iwr.uni-heidelberg.de/>`__ dataset for optical flow.

    The dataset is expected to have the following structure: ::

        root
            hd1k
                hd1k_challenge
                    image_2
                hd1k_flow_gt
                    flow_occ
                hd1k_input
                    image_2

    Args:
        root (str or ``pathlib.Path``): Root directory of the HD1K Dataset.
        split (string, optional): The dataset split, either "train" (default) or "test"
        transforms (callable, optional): A function/transform that takes in
            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
    """

    _has_builtin_flow_mask = True

    def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
        super().__init__(root=root, transforms=transforms)

        verify_str_arg(split, "split", valid_values=("train", "test"))

        root = Path(root) / "hd1k"
        if split == "train":
            # There are 36 "sequences" and we don't want seq i to overlap with seq i + 1, so we need this for loop
            for seq_idx in range(36):
                flows = sorted(glob(str(root / "hd1k_flow_gt" / "flow_occ" / f"{seq_idx:06d}_*.png")))
                images = sorted(glob(str(root / "hd1k_input" / "image_2" / f"{seq_idx:06d}_*.png")))
                for i in range(len(flows) - 1):
                    self._flow_list += [flows[i]]
                    self._image_list += [[images[i], images[i + 1]]]
        else:
            images1 = sorted(glob(str(root / "hd1k_challenge" / "image_2" / "*10.png")))
            images2 = sorted(glob(str(root / "hd1k_challenge" / "image_2" / "*11.png")))
            for image1, image2 in zip(images1, images2):
                self._image_list += [[image1, image2]]

        if not self._image_list:
            raise FileNotFoundError(
                "Could not find the HD1K images. Please make sure the directory structure is correct."
            )

    def _read_flow(self, file_name: str) -> Tuple[np.ndarray, np.ndarray]:
        return _read_16bits_png_with_flow_and_valid_mask(file_name)

    def __getitem__(self, index: int) -> Union[T1, T2]:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` where ``valid_flow_mask``
            is a numpy boolean mask of shape (H, W)
            indicating which flow values are valid. The flow is a numpy array of
            shape (2, H, W) and the images are PIL images. ``flow`` and ``valid_flow_mask`` are None if
            ``split="test"``.
        """
        return super().__getitem__(index)


def _read_flo(file_name: str) -> np.ndarray:
    """Read .flo file in Middlebury format"""
    # Code adapted from:
    # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
    # Everything needs to be in little Endian according to
    # https://vision.middlebury.edu/flow/code/flow-code/README.txt
    with open(file_name, "rb") as f:
        magic = np.fromfile(f, "c", count=4).tobytes()
        if magic != b"PIEH":
            raise ValueError("Magic number incorrect. Invalid .flo file")

        w = int(np.fromfile(f, "<i4", count=1))
        h = int(np.fromfile(f, "<i4", count=1))
        data = np.fromfile(f, "<f4", count=2 * w * h)
        return data.reshape(h, w, 2).transpose(2, 0, 1)


def _read_16bits_png_with_flow_and_valid_mask(file_name: str) -> Tuple[np.ndarray, np.ndarray]:

    flow_and_valid = decode_png(read_file(file_name)).to(torch.float32)
    flow, valid_flow_mask = flow_and_valid[:2, :, :], flow_and_valid[2, :, :]
    flow = (flow - 2**15) / 64  # This conversion is explained somewhere on the kitti archive
    valid_flow_mask = valid_flow_mask.bool()

    # For consistency with other datasets, we convert to numpy
    return flow.numpy(), valid_flow_mask.numpy()
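The flow datasets above all share the FlowDataset interface: indexing yields either (img1, img2, flow) or, when a built-in validity mask exists, (img1, img2, flow, valid_flow_mask). A minimal usage sketch based on the classes shown above; the Sintel and KITTI files are assumed to already sit under ./data, and transforms is left as None:

from torchvision.datasets import KittiFlow, Sintel

# Assumes ./data/Sintel/training/{clean,flow}/... already exists on disk.
train_set = Sintel(root="./data", split="train", pass_name="clean")
img1, img2, flow = train_set[0]          # PIL images plus a (2, H, W) numpy flow field
print(len(train_set), flow.shape)

# KittiFlow has _has_builtin_flow_mask = True, so it returns a 4-tuple.
# Assumes ./data/KittiFlow/training/{image_2,flow_occ}/ is populated.
kitti = KittiFlow(root="./data", split="train")
img1, img2, flow, valid = kitti[0]       # valid is an (H, W) boolean numpy mask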
.venv/lib/python3.11/site-packages/torchvision/datasets/country211.py
ADDED
@@ -0,0 +1,58 @@
from pathlib import Path
from typing import Callable, Optional, Union

from .folder import ImageFolder
from .utils import download_and_extract_archive, verify_str_arg


class Country211(ImageFolder):
    """`The Country211 Data Set <https://github.com/openai/CLIP/blob/main/data/country211.md>`_ from OpenAI.

    This dataset was built by filtering the images from the YFCC100m dataset
    that have GPS coordinate corresponding to a ISO-3166 country code. The
    dataset is balanced by sampling 150 train images, 50 validation images, and
    100 test images for each country.

    Args:
        root (str or ``pathlib.Path``): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``"train"`` (default), ``"valid"`` and ``"test"``.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        download (bool, optional): If True, downloads the dataset from the internet and puts it into
            ``root/country211/``. If dataset is already downloaded, it is not downloaded again.
    """

    _URL = "https://openaipublic.azureedge.net/clip/data/country211.tgz"
    _MD5 = "84988d7644798601126c29e9877aab6a"

    def __init__(
        self,
        root: Union[str, Path],
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        self._split = verify_str_arg(split, "split", ("train", "valid", "test"))

        root = Path(root).expanduser()
        self.root = str(root)
        self._base_folder = root / "country211"

        if download:
            self._download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        super().__init__(str(self._base_folder / self._split), transform=transform, target_transform=target_transform)
        self.root = str(root)

    def _check_exists(self) -> bool:
        return self._base_folder.exists() and self._base_folder.is_dir()

    def _download(self) -> None:
        if self._check_exists():
            return
        download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5)
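Country211 is just an ImageFolder over root/country211/<split>, so it plugs straight into a DataLoader. A minimal sketch, assuming download=True is acceptable (the archive is large) and standard torchvision transforms are used:

import torch
from torchvision import transforms
from torchvision.datasets import Country211

preprocess = transforms.Compose(
    [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()]
)
train_set = Country211(root="./data", split="train", transform=preprocess, download=True)
loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True, num_workers=4)
images, labels = next(iter(loader))  # labels index into train_set.classes (the country-code folders)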
.venv/lib/python3.11/site-packages/torchvision/datasets/folder.py
ADDED
@@ -0,0 +1,337 @@
import os
import os.path
from pathlib import Path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union

from PIL import Image

from .vision import VisionDataset


def has_file_allowed_extension(filename: str, extensions: Union[str, Tuple[str, ...]]) -> bool:
    """Checks if a file is an allowed extension.

    Args:
        filename (string): path to a file
        extensions (tuple of strings): extensions to consider (lowercase)

    Returns:
        bool: True if the filename ends with one of given extensions
    """
    return filename.lower().endswith(extensions if isinstance(extensions, str) else tuple(extensions))


def is_image_file(filename: str) -> bool:
    """Checks if a file is an allowed image extension.

    Args:
        filename (string): path to a file

    Returns:
        bool: True if the filename ends with a known image extension
    """
    return has_file_allowed_extension(filename, IMG_EXTENSIONS)


def find_classes(directory: Union[str, Path]) -> Tuple[List[str], Dict[str, int]]:
    """Finds the class folders in a dataset.

    See :class:`DatasetFolder` for details.
    """
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx


def make_dataset(
    directory: Union[str, Path],
    class_to_idx: Optional[Dict[str, int]] = None,
    extensions: Optional[Union[str, Tuple[str, ...]]] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
    allow_empty: bool = False,
) -> List[Tuple[str, int]]:
    """Generates a list of samples of a form (path_to_sample, class).

    See :class:`DatasetFolder` for details.

    Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function
    by default.
    """
    directory = os.path.expanduser(directory)

    if class_to_idx is None:
        _, class_to_idx = find_classes(directory)
    elif not class_to_idx:
        raise ValueError("'class_to_index' must have at least one entry to collect any samples.")

    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

    if extensions is not None:

        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, extensions)  # type: ignore[arg-type]

    is_valid_file = cast(Callable[[str], bool], is_valid_file)

    instances = []
    available_classes = set()
    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if is_valid_file(path):
                    item = path, class_index
                    instances.append(item)

                    if target_class not in available_classes:
                        available_classes.add(target_class)

    empty_classes = set(class_to_idx.keys()) - available_classes
    if empty_classes and not allow_empty:
        msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
        if extensions is not None:
            msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
        raise FileNotFoundError(msg)

    return instances


class DatasetFolder(VisionDataset):
    """A generic data loader.

    This default directory structure can be customized by overriding the
    :meth:`find_classes` method.

    Args:
        root (str or ``pathlib.Path``): Root directory path.
        loader (callable): A function to load a sample given its path.
        extensions (tuple[string]): A list of allowed extensions.
            both extensions and is_valid_file should not be passed.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        is_valid_file (callable, optional): A function that takes path of a file
            and check if the file is a valid file (used to check of corrupt files)
            both extensions and is_valid_file should not be passed.
        allow_empty(bool, optional): If True, empty folders are considered to be valid classes.
            An error is raised on empty folders if False (default).

    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        samples (list): List of (sample path, class_index) tuples
        targets (list): The class_index value for each image in the dataset
    """

    def __init__(
        self,
        root: Union[str, Path],
        loader: Callable[[str], Any],
        extensions: Optional[Tuple[str, ...]] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
        allow_empty: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        classes, class_to_idx = self.find_classes(self.root)
        samples = self.make_dataset(
            self.root,
            class_to_idx=class_to_idx,
            extensions=extensions,
            is_valid_file=is_valid_file,
            allow_empty=allow_empty,
        )

        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

    @staticmethod
    def make_dataset(
        directory: Union[str, Path],
        class_to_idx: Dict[str, int],
        extensions: Optional[Tuple[str, ...]] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
        allow_empty: bool = False,
    ) -> List[Tuple[str, int]]:
        """Generates a list of samples of a form (path_to_sample, class).

        This can be overridden to e.g. read files from a compressed zip file instead of from the disk.

        Args:
            directory (str): root dataset directory, corresponding to ``self.root``.
            class_to_idx (Dict[str, int]): Dictionary mapping class name to class index.
            extensions (optional): A list of allowed extensions.
                Either extensions or is_valid_file should be passed. Defaults to None.
            is_valid_file (optional): A function that takes path of a file
                and checks if the file is a valid file
                (used to check of corrupt files) both extensions and
                is_valid_file should not be passed. Defaults to None.
            allow_empty(bool, optional): If True, empty folders are considered to be valid classes.
                An error is raised on empty folders if False (default).

        Raises:
            ValueError: In case ``class_to_idx`` is empty.
            ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.
            FileNotFoundError: In case no valid file was found for any class.

        Returns:
            List[Tuple[str, int]]: samples of a form (path_to_sample, class)
        """
        if class_to_idx is None:
            # prevent potential bug since make_dataset() would use the class_to_idx logic of the
            # find_classes() function, instead of using that of the find_classes() method, which
            # is potentially overridden and thus could have a different logic.
            raise ValueError("The class_to_idx parameter cannot be None.")
        return make_dataset(
            directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file, allow_empty=allow_empty
        )

    def find_classes(self, directory: Union[str, Path]) -> Tuple[List[str], Dict[str, int]]:
        """Find the class folders in a dataset structured as follows::

            directory/
            ├── class_x
            │   ├── xxx.ext
            │   ├── xxy.ext
            │   └── ...
            │       └── xxz.ext
            └── class_y
                ├── 123.ext
                ├── nsdf3.ext
                └── ...
                └── asd932_.ext

        This method can be overridden to only consider
        a subset of classes, or to adapt to a different dataset directory structure.

        Args:
            directory(str): Root directory path, corresponding to ``self.root``

        Raises:
            FileNotFoundError: If ``dir`` has no class folders.

        Returns:
            (Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index.
        """
        return find_classes(directory)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self) -> int:
        return len(self.samples)


IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")


def pil_loader(path: str) -> Image.Image:
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, "rb") as f:
        img = Image.open(f)
        return img.convert("RGB")


# TODO: specify the return type
def accimage_loader(path: str) -> Any:
    import accimage

    try:
        return accimage.Image(path)
    except OSError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)


def default_loader(path: str) -> Any:
    from torchvision import get_image_backend

    if get_image_backend() == "accimage":
        return accimage_loader(path)
    else:
        return pil_loader(path)


class ImageFolder(DatasetFolder):
    """A generic data loader where the images are arranged in this way by default: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/[...]/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/[...]/asd932_.png

    This class inherits from :class:`~torchvision.datasets.DatasetFolder` so
    the same methods can be overridden to customize the dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory path.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        is_valid_file (callable, optional): A function that takes path of an Image file
            and check if the file is a valid file (used to check of corrupt files)
        allow_empty(bool, optional): If True, empty folders are considered to be valid classes.
            An error is raised on empty folders if False (default).

    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples
    """

    def __init__(
        self,
        root: Union[str, Path],
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
        allow_empty: bool = False,
    ):
        super().__init__(
            root,
            loader,
            IMG_EXTENSIONS if is_valid_file is None else None,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
            allow_empty=allow_empty,
        )
        self.imgs = self.samples
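DatasetFolder and ImageFolder derive the class list from the sub-directory names and pair each file with its class index. A minimal sketch of both entry points, assuming a directory tree like ./data/pets/{cat,dog}/*.png already exists; the npy_loader helper and the ./data/arrays layout are hypothetical and only illustrate the loader hook:

import numpy as np
from torchvision.datasets import DatasetFolder, ImageFolder

# ImageFolder: classes come from the folder names, samples from the image files.
dataset = ImageFolder(root="./data/pets")
print(dataset.classes, dataset.class_to_idx)  # e.g. ['cat', 'dog'], {'cat': 0, 'dog': 1}
img, target = dataset[0]                      # PIL image, class index

# DatasetFolder with a custom loader (hypothetical .npy samples under ./data/arrays/<class>/*.npy).
def npy_loader(path: str) -> np.ndarray:
    return np.load(path)

npy_dataset = DatasetFolder(root="./data/arrays", loader=npy_loader, extensions=(".npy",))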
.venv/lib/python3.11/site-packages/torchvision/datasets/gtsrb.py
ADDED
@@ -0,0 +1,103 @@
import csv
import pathlib
from typing import Any, Callable, Optional, Tuple, Union

import PIL

from .folder import make_dataset
from .utils import download_and_extract_archive, verify_str_arg
from .vision import VisionDataset


class GTSRB(VisionDataset):
    """`German Traffic Sign Recognition Benchmark (GTSRB) <https://benchmark.ini.rub.de/>`_ Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """

    def __init__(
        self,
        root: Union[str, pathlib.Path],
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:

        super().__init__(root, transform=transform, target_transform=target_transform)

        self._split = verify_str_arg(split, "split", ("train", "test"))
        self._base_folder = pathlib.Path(root) / "gtsrb"
        self._target_folder = (
            self._base_folder / "GTSRB" / ("Training" if self._split == "train" else "Final_Test/Images")
        )

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        if self._split == "train":
            samples = make_dataset(str(self._target_folder), extensions=(".ppm",))
        else:
            with open(self._base_folder / "GT-final_test.csv") as csv_file:
                samples = [
                    (str(self._target_folder / row["Filename"]), int(row["ClassId"]))
                    for row in csv.DictReader(csv_file, delimiter=";", skipinitialspace=True)
                ]

        self._samples = samples
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self) -> int:
        return len(self._samples)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:

        path, target = self._samples[index]
        sample = PIL.Image.open(path).convert("RGB")

        if self.transform is not None:
            sample = self.transform(sample)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def _check_exists(self) -> bool:
        return self._target_folder.is_dir()

    def download(self) -> None:
        if self._check_exists():
            return

        base_url = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/"

        if self._split == "train":
            download_and_extract_archive(
                f"{base_url}GTSRB-Training_fixed.zip",
                download_root=str(self._base_folder),
                md5="513f3c79a4c5141765e10e952eaa2478",
            )
        else:
            download_and_extract_archive(
                f"{base_url}GTSRB_Final_Test_Images.zip",
                download_root=str(self._base_folder),
                md5="c7e4e6327067d32654124b0fe9e82185",
            )
            download_and_extract_archive(
                f"{base_url}GTSRB_Final_Test_GT.zip",
                download_root=str(self._base_folder),
                md5="fe31e9c9270bbcd7b84b7f21a9d9d9e5",
            )
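GTSRB builds its training sample list with make_dataset over the .ppm class folders and reads the test labels from GT-final_test.csv. A minimal sketch, assuming download=True and network access; the transform converts images to fixed-size tensors so they can be batched:

from torchvision import transforms
from torchvision.datasets import GTSRB

to_tensor = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor()])
train_set = GTSRB(root="./data", split="train", transform=to_tensor, download=True)
test_set = GTSRB(root="./data", split="test", transform=to_tensor, download=True)
image, class_id = train_set[0]  # 3x32x32 tensor, integer traffic-sign class
print(len(train_set), len(test_set), class_id)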
.venv/lib/python3.11/site-packages/torchvision/datasets/inaturalist.py
ADDED
@@ -0,0 +1,242 @@
import os
import os.path
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from PIL import Image

from .utils import download_and_extract_archive, verify_str_arg
from .vision import VisionDataset

CATEGORIES_2021 = ["kingdom", "phylum", "class", "order", "family", "genus"]

DATASET_URLS = {
    "2017": "https://ml-inat-competition-datasets.s3.amazonaws.com/2017/train_val_images.tar.gz",
    "2018": "https://ml-inat-competition-datasets.s3.amazonaws.com/2018/train_val2018.tar.gz",
    "2019": "https://ml-inat-competition-datasets.s3.amazonaws.com/2019/train_val2019.tar.gz",
    "2021_train": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train.tar.gz",
    "2021_train_mini": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train_mini.tar.gz",
    "2021_valid": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/val.tar.gz",
}

DATASET_MD5 = {
    "2017": "7c784ea5e424efaec655bd392f87301f",
    "2018": "b1c6952ce38f31868cc50ea72d066cc3",
    "2019": "c60a6e2962c9b8ccbd458d12c8582644",
    "2021_train": "e0526d53c7f7b2e3167b2b43bb2690ed",
    "2021_train_mini": "db6ed8330e634445efc8fec83ae81442",
    "2021_valid": "f6f6e0e242e3d4c9569ba56400938afc",
}


class INaturalist(VisionDataset):
    """`iNaturalist <https://github.com/visipedia/inat_comp>`_ Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where the image files are stored.
            This class does not require/use annotation files.
        version (string, optional): Which version of the dataset to download/use. One of
            '2017', '2018', '2019', '2021_train', '2021_train_mini', '2021_valid'.
            Default: `2021_train`.
        target_type (string or list, optional): Type of target to use, for 2021 versions, one of:

            - ``full``: the full category (species)
            - ``kingdom``: e.g. "Animalia"
            - ``phylum``: e.g. "Arthropoda"
            - ``class``: e.g. "Insecta"
            - ``order``: e.g. "Coleoptera"
            - ``family``: e.g. "Cleridae"
            - ``genus``: e.g. "Trichodes"

            for 2017-2019 versions, one of:

            - ``full``: the full (numeric) category
            - ``super``: the super category, e.g. "Amphibians"

            Can also be a list to output a tuple with all specified target types.
            Defaults to ``full``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """

    def __init__(
        self,
        root: Union[str, Path],
        version: str = "2021_train",
        target_type: Union[List[str], str] = "full",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        self.version = verify_str_arg(version, "version", DATASET_URLS.keys())

        super().__init__(os.path.join(root, version), transform=transform, target_transform=target_transform)

        os.makedirs(root, exist_ok=True)
        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        self.all_categories: List[str] = []

        # map: category type -> name of category -> index
        self.categories_index: Dict[str, Dict[str, int]] = {}

        # list indexed by category id, containing mapping from category type -> index
        self.categories_map: List[Dict[str, int]] = []

        if not isinstance(target_type, list):
            target_type = [target_type]
        if self.version[:4] == "2021":
            self.target_type = [verify_str_arg(t, "target_type", ("full", *CATEGORIES_2021)) for t in target_type]
            self._init_2021()
        else:
            self.target_type = [verify_str_arg(t, "target_type", ("full", "super")) for t in target_type]
            self._init_pre2021()

        # index of all files: (full category id, filename)
        self.index: List[Tuple[int, str]] = []

        for dir_index, dir_name in enumerate(self.all_categories):
            files = os.listdir(os.path.join(self.root, dir_name))
            for fname in files:
                self.index.append((dir_index, fname))

    def _init_2021(self) -> None:
        """Initialize based on 2021 layout"""

        self.all_categories = sorted(os.listdir(self.root))

        # map: category type -> name of category -> index
        self.categories_index = {k: {} for k in CATEGORIES_2021}

        for dir_index, dir_name in enumerate(self.all_categories):
            pieces = dir_name.split("_")
            if len(pieces) != 8:
                raise RuntimeError(f"Unexpected category name {dir_name}, wrong number of pieces")
            if pieces[0] != f"{dir_index:05d}":
                raise RuntimeError(f"Unexpected category id {pieces[0]}, expecting {dir_index:05d}")
            cat_map = {}
            for cat, name in zip(CATEGORIES_2021, pieces[1:7]):
                if name in self.categories_index[cat]:
                    cat_id = self.categories_index[cat][name]
                else:
                    cat_id = len(self.categories_index[cat])
                    self.categories_index[cat][name] = cat_id
                cat_map[cat] = cat_id
            self.categories_map.append(cat_map)

    def _init_pre2021(self) -> None:
        """Initialize based on 2017-2019 layout"""

        # map: category type -> name of category -> index
        self.categories_index = {"super": {}}

        cat_index = 0
        super_categories = sorted(os.listdir(self.root))
        for sindex, scat in enumerate(super_categories):
            self.categories_index["super"][scat] = sindex
            subcategories = sorted(os.listdir(os.path.join(self.root, scat)))
            for subcat in subcategories:
                if self.version == "2017":
                    # this version does not use ids as directory names
                    subcat_i = cat_index
                    cat_index += 1
                else:
                    try:
                        subcat_i = int(subcat)
                    except ValueError:
                        raise RuntimeError(f"Unexpected non-numeric dir name: {subcat}")
                if subcat_i >= len(self.categories_map):
                    old_len = len(self.categories_map)
                    self.categories_map.extend([{}] * (subcat_i - old_len + 1))
                    self.all_categories.extend([""] * (subcat_i - old_len + 1))
                if self.categories_map[subcat_i]:
                    raise RuntimeError(f"Duplicate category {subcat}")
                self.categories_map[subcat_i] = {"super": sindex}
                self.all_categories[subcat_i] = os.path.join(scat, subcat)

        # validate the dictionary
        for cindex, c in enumerate(self.categories_map):
            if not c:
                raise RuntimeError(f"Missing category {cindex}")

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where the type of target specified by target_type.
        """

        cat_id, fname = self.index[index]
        img = Image.open(os.path.join(self.root, self.all_categories[cat_id], fname))

        target: Any = []
        for t in self.target_type:
            if t == "full":
                target.append(cat_id)
            else:
                target.append(self.categories_map[cat_id][t])
        target = tuple(target) if len(target) > 1 else target[0]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        return len(self.index)

    def category_name(self, category_type: str, category_id: int) -> str:
        """
        Args:
            category_type(str): one of "full", "kingdom", "phylum", "class", "order", "family", "genus" or "super"
            category_id(int): an index (class id) from this category

        Returns:
            the name of the category
        """
        if category_type == "full":
            return self.all_categories[category_id]
        else:
            if category_type not in self.categories_index:
                raise ValueError(f"Invalid category type '{category_type}'")
            else:
                for name, id in self.categories_index[category_type].items():
                    if id == category_id:
                        return name
                raise ValueError(f"Invalid category id {category_id} for {category_type}")

    def _check_integrity(self) -> bool:
        return os.path.exists(self.root) and len(os.listdir(self.root)) > 0

    def download(self) -> None:
        if self._check_integrity():
            raise RuntimeError(
                f"The directory {self.root} already exists. "
                f"If you want to re-download or re-extract the images, delete the directory."
            )

        base_root = os.path.dirname(self.root)

        download_and_extract_archive(
            DATASET_URLS[self.version], base_root, filename=f"{self.version}.tgz", md5=DATASET_MD5[self.version]
        )

        orig_dir_name = os.path.join(base_root, os.path.basename(DATASET_URLS[self.version]).rstrip(".tar.gz"))
        if not os.path.exists(orig_dir_name):
            raise RuntimeError(f"Unable to find downloaded files at {orig_dir_name}")
        os.rename(orig_dir_name, self.root)
        print(f"Dataset version '{self.version}' has been downloaded and prepared for use")
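For reference, a minimal usage sketch for the INaturalist class added above, based on its constructor and docstring; the root path is hypothetical and the data is assumed to be already extracted:

    from torchvision.datasets import INaturalist

    # Hypothetical root; version and target_type values follow the docstring above.
    ds = INaturalist(root="./data", version="2021_valid", target_type=["full", "kingdom"], download=False)
    img, (species_id, kingdom_id) = ds[0]
    print(ds.category_name("kingdom", kingdom_id))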
.venv/lib/python3.11/site-packages/torchvision/datasets/lsun.py
ADDED
@@ -0,0 +1,168 @@
import io
import os.path
import pickle
import string
from collections.abc import Iterable
from pathlib import Path
from typing import Any, Callable, cast, List, Optional, Tuple, Union

from PIL import Image

from .utils import iterable_to_str, verify_str_arg
from .vision import VisionDataset


class LSUNClass(VisionDataset):
    def __init__(
        self, root: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None
    ) -> None:
        import lmdb

        super().__init__(root, transform=transform, target_transform=target_transform)

        self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False, readahead=False, meminit=False)
        with self.env.begin(write=False) as txn:
            self.length = txn.stat()["entries"]
        cache_file = "_cache_" + "".join(c for c in root if c in string.ascii_letters)
        if os.path.isfile(cache_file):
            self.keys = pickle.load(open(cache_file, "rb"))
        else:
            with self.env.begin(write=False) as txn:
                self.keys = [key for key in txn.cursor().iternext(keys=True, values=False)]
            pickle.dump(self.keys, open(cache_file, "wb"))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        img, target = None, None
        env = self.env
        with env.begin(write=False) as txn:
            imgbuf = txn.get(self.keys[index])

        buf = io.BytesIO()
        buf.write(imgbuf)
        buf.seek(0)
        img = Image.open(buf).convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        return self.length


class LSUN(VisionDataset):
    """`LSUN <https://www.yf.io/p/lsun>`_ dataset.

    You will need to install the ``lmdb`` package to use this dataset: run
    ``pip install lmdb``

    Args:
        root (str or ``pathlib.Path``): Root directory for the database files.
        classes (string or list): One of {'train', 'val', 'test'} or a list of
            categories to load. e,g. ['bedroom_train', 'church_outdoor_train'].
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    def __init__(
        self,
        root: Union[str, Path],
        classes: Union[str, List[str]] = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.classes = self._verify_classes(classes)

        # for each class, create an LSUNClassDataset
        self.dbs = []
        for c in self.classes:
            self.dbs.append(LSUNClass(root=os.path.join(root, f"{c}_lmdb"), transform=transform))

        self.indices = []
        count = 0
        for db in self.dbs:
            count += len(db)
            self.indices.append(count)

        self.length = count

    def _verify_classes(self, classes: Union[str, List[str]]) -> List[str]:
        categories = [
            "bedroom",
            "bridge",
            "church_outdoor",
            "classroom",
            "conference_room",
            "dining_room",
            "kitchen",
            "living_room",
            "restaurant",
            "tower",
        ]
        dset_opts = ["train", "val", "test"]

        try:
            classes = cast(str, classes)
            verify_str_arg(classes, "classes", dset_opts)
            if classes == "test":
                classes = [classes]
            else:
                classes = [c + "_" + classes for c in categories]
        except ValueError:
            if not isinstance(classes, Iterable):
                msg = "Expected type str or Iterable for argument classes, but got type {}."
                raise ValueError(msg.format(type(classes)))

            classes = list(classes)
            msg_fmtstr_type = "Expected type str for elements in argument classes, but got type {}."
            for c in classes:
                verify_str_arg(c, custom_msg=msg_fmtstr_type.format(type(c)))
                c_short = c.split("_")
                category, dset_opt = "_".join(c_short[:-1]), c_short[-1]

                msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}."
                msg = msg_fmtstr.format(category, "LSUN class", iterable_to_str(categories))
                verify_str_arg(category, valid_values=categories, custom_msg=msg)

                msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts))
                verify_str_arg(dset_opt, valid_values=dset_opts, custom_msg=msg)

        return classes

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target) where target is the index of the target category.
        """
        target = 0
        sub = 0
        for ind in self.indices:
            if index < ind:
                break
            target += 1
            sub = ind

        db = self.dbs[target]
        index = index - sub

        if self.target_transform is not None:
            target = self.target_transform(target)

        img, _ = db[index]
        return img, target

    def __len__(self) -> int:
        return self.length

    def extra_repr(self) -> str:
        return "Classes: {classes}".format(**self.__dict__)
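For reference, a minimal usage sketch for the LSUN class added above; the root path is hypothetical, and the lmdb databases are assumed to be downloaded separately as noted in the docstring:

    from torchvision.datasets import LSUN

    # Requires `pip install lmdb` and pre-downloaded <category>_lmdb folders under root.
    ds = LSUN(root="./lsun", classes=["bedroom_train", "church_outdoor_train"])
    img, class_index = ds[0]  # class_index is the position of the category in `classes`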
.venv/lib/python3.11/site-packages/torchvision/datasets/sbu.py
ADDED
@@ -0,0 +1,110 @@
import os
from pathlib import Path
from typing import Any, Callable, Optional, Tuple, Union

from PIL import Image

from .utils import check_integrity, download_and_extract_archive, download_url
from .vision import VisionDataset


class SBU(VisionDataset):
    """`SBU Captioned Photo <http://www.cs.virginia.edu/~vicente/sbucaptions/>`_ Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where tarball
            ``SBUCaptionedPhotoDataset.tar.gz`` exists.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """

    url = "https://www.cs.rice.edu/~vo9/sbucaptions/SBUCaptionedPhotoDataset.tar.gz"
    filename = "SBUCaptionedPhotoDataset.tar.gz"
    md5_checksum = "9aec147b3488753cf758b4d493422285"

    def __init__(
        self,
        root: Union[str, Path],
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = True,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        # Read the caption for each photo
        self.photos = []
        self.captions = []

        file1 = os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_urls.txt")
        file2 = os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_captions.txt")

        for line1, line2 in zip(open(file1), open(file2)):
            url = line1.rstrip()
            photo = os.path.basename(url)
            filename = os.path.join(self.root, "dataset", photo)
            if os.path.exists(filename):
                caption = line2.rstrip()
                self.photos.append(photo)
                self.captions.append(caption)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is a caption for the photo.
        """
        filename = os.path.join(self.root, "dataset", self.photos[index])
        img = Image.open(filename).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)

        target = self.captions[index]
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        """The number of photos in the dataset."""
        return len(self.photos)

    def _check_integrity(self) -> bool:
        """Check the md5 checksum of the downloaded tarball."""
        root = self.root
        fpath = os.path.join(root, self.filename)
        if not check_integrity(fpath, self.md5_checksum):
            return False
        return True

    def download(self) -> None:
        """Download and extract the tarball, and download each individual photo."""

        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        download_and_extract_archive(self.url, self.root, self.root, self.filename, self.md5_checksum)

        # Download individual photos
        with open(os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_urls.txt")) as fh:
            for line in fh:
                url = line.rstrip()
                try:
                    download_url(url, os.path.join(self.root, "dataset"))
                except OSError:
                    # The images point to public images on Flickr.
                    # Note: Images might be removed by users at anytime.
                    pass
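For reference, a minimal usage sketch for the SBU class added above; the root path is hypothetical:

    from torchvision.datasets import SBU

    # download=True fetches the tarball and then each Flickr photo that is still available.
    ds = SBU(root="./sbu", download=True)
    img, caption = ds[0]  # caption is the raw text string paired with the photo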
.venv/lib/python3.11/site-packages/torchvision/datasets/svhn.py
ADDED
@@ -0,0 +1,130 @@
import os.path
from pathlib import Path
from typing import Any, Callable, Optional, Tuple, Union

import numpy as np
from PIL import Image

from .utils import check_integrity, download_url, verify_str_arg
from .vision import VisionDataset


class SVHN(VisionDataset):
    """`SVHN <http://ufldl.stanford.edu/housenumbers/>`_ Dataset.
    Note: The SVHN dataset assigns the label `10` to the digit `0`. However, in this Dataset,
    we assign the label `0` to the digit `0` to be compatible with PyTorch loss functions which
    expect the class labels to be in the range `[0, C-1]`

    .. warning::

        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load data from `.mat` format.

    Args:
        root (str or ``pathlib.Path``): Root directory of the dataset where the data is stored.
        split (string): One of {'train', 'test', 'extra'}.
            Accordingly dataset is selected. 'extra' is Extra training set.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    """

    split_list = {
        "train": [
            "http://ufldl.stanford.edu/housenumbers/train_32x32.mat",
            "train_32x32.mat",
            "e26dedcc434d2e4c54c9b2d4a06d8373",
        ],
        "test": [
            "http://ufldl.stanford.edu/housenumbers/test_32x32.mat",
            "test_32x32.mat",
            "eb5a983be6a315427106f1b164d9cef3",
        ],
        "extra": [
            "http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
            "extra_32x32.mat",
            "a93ce644f1a588dc4d68dda5feec44a7",
        ],
    }

    def __init__(
        self,
        root: Union[str, Path],
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.split = verify_str_arg(split, "split", tuple(self.split_list.keys()))
        self.url = self.split_list[split][0]
        self.filename = self.split_list[split][1]
        self.file_md5 = self.split_list[split][2]

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        # import here rather than at top of file because this is
        # an optional dependency for torchvision
        import scipy.io as sio

        # reading(loading) mat file as array
        loaded_mat = sio.loadmat(os.path.join(self.root, self.filename))

        self.data = loaded_mat["X"]
        # loading from the .mat file gives an np.ndarray of type np.uint8
        # converting to np.int64, so that we have a LongTensor after
        # the conversion from the numpy array
        # the squeeze is needed to obtain a 1D tensor
        self.labels = loaded_mat["y"].astype(np.int64).squeeze()

        # the svhn dataset assigns the class label "10" to the digit 0
        # this makes it inconsistent with several loss functions
        # which expect the class labels to be in the range [0, C-1]
        np.place(self.labels, self.labels == 10, 0)
        self.data = np.transpose(self.data, (3, 2, 0, 1))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.labels[index])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(np.transpose(img, (1, 2, 0)))

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        return len(self.data)

    def _check_integrity(self) -> bool:
        root = self.root
        md5 = self.split_list[self.split][2]
        fpath = os.path.join(root, self.filename)
        return check_integrity(fpath, md5)

    def download(self) -> None:
        md5 = self.split_list[self.split][2]
        download_url(self.url, self.root, self.filename, md5)

    def extra_repr(self) -> str:
        return "Split: {split}".format(**self.__dict__)
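For reference, a minimal usage sketch for the SVHN class added above; the root path is hypothetical and scipy must be installed to read the .mat files:

    import numpy as np
    from torchvision.datasets import SVHN

    ds = SVHN(root="./svhn", split="train", download=True)
    img, label = ds[0]           # PIL image, int label with digit "0" mapped to class 0
    print(np.unique(ds.labels))  # 0..9 after the np.place remapping in __init__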
.venv/lib/python3.11/site-packages/torchvision/datasets/widerface.py
ADDED
@@ -0,0 +1,197 @@
import os
from os.path import abspath, expanduser
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from PIL import Image

from .utils import download_and_extract_archive, download_file_from_google_drive, extract_archive, verify_str_arg
from .vision import VisionDataset


class WIDERFace(VisionDataset):
    """`WIDERFace <http://shuoyang1213.me/WIDERFACE/>`_ Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory where images and annotations are downloaded to.
            Expects the following folder structure if download=False:

            .. code::

                <root>
                    └── widerface
                        ├── wider_face_split ('wider_face_split.zip' if compressed)
                        ├── WIDER_train ('WIDER_train.zip' if compressed)
                        ├── WIDER_val ('WIDER_val.zip' if compressed)
                        └── WIDER_test ('WIDER_test.zip' if compressed)
        split (string): The dataset split to use. One of {``train``, ``val``, ``test``}.
            Defaults to ``train``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    .. warning::

        To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.

    """

    BASE_FOLDER = "widerface"
    FILE_LIST = [
        # File ID                             MD5 Hash                            Filename
        ("15hGDLhsx8bLgLcIRD5DhYt5iBxnjNF1M", "3fedf70df600953d25982bcd13d91ba2", "WIDER_train.zip"),
        ("1GUCogbp16PMGa39thoMMeWxp7Rp5oM8Q", "dfa7d7e790efa35df3788964cf0bbaea", "WIDER_val.zip"),
        ("1HIfDbVEWKmsYKJZm4lchTBDLW5N7dY5T", "e5d8f4248ed24c334bbd12f49c29dd40", "WIDER_test.zip"),
    ]
    ANNOTATIONS_FILE = (
        "http://shuoyang1213.me/WIDERFACE/support/bbx_annotation/wider_face_split.zip",
        "0e3767bcf0e326556d407bf5bff5d27c",
        "wider_face_split.zip",
    )

    def __init__(
        self,
        root: Union[str, Path],
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(
            root=os.path.join(root, self.BASE_FOLDER), transform=transform, target_transform=target_transform
        )
        # check arguments
        self.split = verify_str_arg(split, "split", ("train", "val", "test"))

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download and prepare it")

        self.img_info: List[Dict[str, Union[str, Dict[str, torch.Tensor]]]] = []
        if self.split in ("train", "val"):
            self.parse_train_val_annotations_file()
        else:
            self.parse_test_annotations_file()

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is a dict of annotations for all faces in the image.
            target=None for the test split.
        """

        # stay consistent with other datasets and return a PIL Image
        img = Image.open(self.img_info[index]["img_path"])  # type: ignore[arg-type]

        if self.transform is not None:
            img = self.transform(img)

        target = None if self.split == "test" else self.img_info[index]["annotations"]
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        return len(self.img_info)

    def extra_repr(self) -> str:
        lines = ["Split: {split}"]
        return "\n".join(lines).format(**self.__dict__)

    def parse_train_val_annotations_file(self) -> None:
        filename = "wider_face_train_bbx_gt.txt" if self.split == "train" else "wider_face_val_bbx_gt.txt"
        filepath = os.path.join(self.root, "wider_face_split", filename)

        with open(filepath) as f:
            lines = f.readlines()
            file_name_line, num_boxes_line, box_annotation_line = True, False, False
            num_boxes, box_counter = 0, 0
            labels = []
            for line in lines:
                line = line.rstrip()
                if file_name_line:
                    img_path = os.path.join(self.root, "WIDER_" + self.split, "images", line)
                    img_path = abspath(expanduser(img_path))
                    file_name_line = False
                    num_boxes_line = True
                elif num_boxes_line:
                    num_boxes = int(line)
                    num_boxes_line = False
                    box_annotation_line = True
                elif box_annotation_line:
                    box_counter += 1
                    line_split = line.split(" ")
                    line_values = [int(x) for x in line_split]
                    labels.append(line_values)
                    if box_counter >= num_boxes:
                        box_annotation_line = False
                        file_name_line = True
                        labels_tensor = torch.tensor(labels)
                        self.img_info.append(
                            {
                                "img_path": img_path,
                                "annotations": {
                                    "bbox": labels_tensor[:, 0:4].clone(),  # x, y, width, height
                                    "blur": labels_tensor[:, 4].clone(),
                                    "expression": labels_tensor[:, 5].clone(),
                                    "illumination": labels_tensor[:, 6].clone(),
                                    "occlusion": labels_tensor[:, 7].clone(),
                                    "pose": labels_tensor[:, 8].clone(),
                                    "invalid": labels_tensor[:, 9].clone(),
                                },
                            }
                        )
                        box_counter = 0
                        labels.clear()
                else:
                    raise RuntimeError(f"Error parsing annotation file {filepath}")

    def parse_test_annotations_file(self) -> None:
        filepath = os.path.join(self.root, "wider_face_split", "wider_face_test_filelist.txt")
        filepath = abspath(expanduser(filepath))
        with open(filepath) as f:
            lines = f.readlines()
            for line in lines:
                line = line.rstrip()
                img_path = os.path.join(self.root, "WIDER_test", "images", line)
                img_path = abspath(expanduser(img_path))
                self.img_info.append({"img_path": img_path})

    def _check_integrity(self) -> bool:
        # Allow original archive to be deleted (zip). Only need the extracted images
        all_files = self.FILE_LIST.copy()
        all_files.append(self.ANNOTATIONS_FILE)
        for (_, md5, filename) in all_files:
            file, ext = os.path.splitext(filename)
            extracted_dir = os.path.join(self.root, file)
            if not os.path.exists(extracted_dir):
                return False
        return True

    def download(self) -> None:
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        # download and extract image data
        for (file_id, md5, filename) in self.FILE_LIST:
            download_file_from_google_drive(file_id, self.root, filename, md5)
            filepath = os.path.join(self.root, filename)
            extract_archive(filepath)

        # download and extract annotation files
        download_and_extract_archive(
            url=self.ANNOTATIONS_FILE[0], download_root=self.root, md5=self.ANNOTATIONS_FILE[1]
        )
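For reference, a minimal usage sketch for the WIDERFace class added above; the root path is hypothetical:

    from torchvision.datasets import WIDERFace

    # download=True needs gdown for the Google Drive image archives.
    ds = WIDERFace(root="./data", split="val", download=True)
    img, target = ds[0]
    print(target["bbox"].shape)  # one row per face: x, y, width, height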
.venv/lib/python3.11/site-packages/torchvision/io/__init__.py
ADDED
@@ -0,0 +1,76 @@
from typing import Any, Dict, Iterator

import torch

from ..utils import _log_api_usage_once

try:
    from ._load_gpu_decoder import _HAS_GPU_VIDEO_DECODER
except ModuleNotFoundError:
    _HAS_GPU_VIDEO_DECODER = False

from ._video_opt import (
    _HAS_CPU_VIDEO_DECODER,
    _HAS_VIDEO_OPT,
    _probe_video_from_file,
    _probe_video_from_memory,
    _read_video_from_file,
    _read_video_from_memory,
    _read_video_timestamps_from_file,
    _read_video_timestamps_from_memory,
    Timebase,
    VideoMetaData,
)
from .image import (
    decode_gif,
    decode_image,
    decode_jpeg,
    decode_png,
    decode_webp,
    encode_jpeg,
    encode_png,
    ImageReadMode,
    read_file,
    read_image,
    write_file,
    write_jpeg,
    write_png,
)
from .video import read_video, read_video_timestamps, write_video
from .video_reader import VideoReader


__all__ = [
    "write_video",
    "read_video",
    "read_video_timestamps",
    "_read_video_from_file",
    "_read_video_timestamps_from_file",
    "_probe_video_from_file",
    "_read_video_from_memory",
    "_read_video_timestamps_from_memory",
    "_probe_video_from_memory",
    "_HAS_CPU_VIDEO_DECODER",
    "_HAS_VIDEO_OPT",
    "_HAS_GPU_VIDEO_DECODER",
    "_read_video_clip_from_memory",
    "_read_video_meta_data",
    "VideoMetaData",
    "Timebase",
    "ImageReadMode",
    "decode_image",
    "decode_jpeg",
    "decode_png",
    "decode_heic",
    "decode_webp",
    "decode_gif",
    "encode_jpeg",
    "encode_png",
    "read_file",
    "read_image",
    "write_file",
    "write_jpeg",
    "write_png",
    "Video",
    "VideoReader",
]
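For reference, a minimal sketch using the public image I/O functions re-exported by the torchvision.io package above; the file paths are hypothetical:

    import torch
    from torchvision.io import ImageReadMode, read_image, write_png

    # read_image decodes an image file into a uint8 CHW tensor.
    img = read_image("example.png", mode=ImageReadMode.RGB)
    assert img.dtype == torch.uint8 and img.ndim == 3
    write_png(img, "copy.png")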
.venv/lib/python3.11/site-packages/torchvision/io/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (1.92 kB)
.venv/lib/python3.11/site-packages/torchvision/io/__pycache__/_load_gpu_decoder.cpython-311.pyc
ADDED
Binary file (464 Bytes)
.venv/lib/python3.11/site-packages/torchvision/io/__pycache__/_video_opt.cpython-311.pyc
ADDED
Binary file (23.9 kB)
.venv/lib/python3.11/site-packages/torchvision/io/__pycache__/image.cpython-311.pyc
ADDED
Binary file (24.9 kB)
.venv/lib/python3.11/site-packages/torchvision/io/__pycache__/video.cpython-311.pyc
ADDED
Binary file (21.8 kB)
.venv/lib/python3.11/site-packages/torchvision/io/__pycache__/video_reader.cpython-311.pyc
ADDED
Binary file (15.4 kB)
.venv/lib/python3.11/site-packages/torchvision/io/_load_gpu_decoder.py
ADDED
@@ -0,0 +1,8 @@
from ..extension import _load_library


try:
    _load_library("gpu_decoder")
    _HAS_GPU_VIDEO_DECODER = True
except (ImportError, OSError):
    _HAS_GPU_VIDEO_DECODER = False
.venv/lib/python3.11/site-packages/torchvision/io/_video_opt.py
ADDED
@@ -0,0 +1,513 @@
import math
import warnings
from fractions import Fraction
from typing import Dict, List, Optional, Tuple, Union

import torch

from ..extension import _load_library


try:
    _load_library("video_reader")
    _HAS_CPU_VIDEO_DECODER = True
except (ImportError, OSError):
    _HAS_CPU_VIDEO_DECODER = False

_HAS_VIDEO_OPT = _HAS_CPU_VIDEO_DECODER  # For BC
default_timebase = Fraction(0, 1)


# simple class for torch scripting
# the complex Fraction class from fractions module is not scriptable
class Timebase:
    __annotations__ = {"numerator": int, "denominator": int}
    __slots__ = ["numerator", "denominator"]

    def __init__(
        self,
        numerator: int,
        denominator: int,
    ) -> None:
        self.numerator = numerator
        self.denominator = denominator


class VideoMetaData:
    __annotations__ = {
        "has_video": bool,
        "video_timebase": Timebase,
        "video_duration": float,
        "video_fps": float,
        "has_audio": bool,
        "audio_timebase": Timebase,
        "audio_duration": float,
        "audio_sample_rate": float,
    }
    __slots__ = [
        "has_video",
        "video_timebase",
        "video_duration",
        "video_fps",
        "has_audio",
        "audio_timebase",
        "audio_duration",
        "audio_sample_rate",
    ]

    def __init__(self) -> None:
        self.has_video = False
        self.video_timebase = Timebase(0, 1)
        self.video_duration = 0.0
        self.video_fps = 0.0
        self.has_audio = False
        self.audio_timebase = Timebase(0, 1)
        self.audio_duration = 0.0
        self.audio_sample_rate = 0.0


def _validate_pts(pts_range: Tuple[int, int]) -> None:

    if pts_range[0] > pts_range[1] > 0:
        raise ValueError(
            f"Start pts should not be smaller than end pts, got start pts: {pts_range[0]} and end pts: {pts_range[1]}"
        )


def _fill_info(
    vtimebase: torch.Tensor,
    vfps: torch.Tensor,
    vduration: torch.Tensor,
    atimebase: torch.Tensor,
    asample_rate: torch.Tensor,
    aduration: torch.Tensor,
) -> VideoMetaData:
    """
    Build update VideoMetaData struct with info about the video
    """
    meta = VideoMetaData()
    if vtimebase.numel() > 0:
        meta.video_timebase = Timebase(int(vtimebase[0].item()), int(vtimebase[1].item()))
        timebase = vtimebase[0].item() / float(vtimebase[1].item())
        if vduration.numel() > 0:
            meta.has_video = True
            meta.video_duration = float(vduration.item()) * timebase
    if vfps.numel() > 0:
        meta.video_fps = float(vfps.item())
    if atimebase.numel() > 0:
        meta.audio_timebase = Timebase(int(atimebase[0].item()), int(atimebase[1].item()))
        timebase = atimebase[0].item() / float(atimebase[1].item())
        if aduration.numel() > 0:
            meta.has_audio = True
            meta.audio_duration = float(aduration.item()) * timebase
    if asample_rate.numel() > 0:
        meta.audio_sample_rate = float(asample_rate.item())

    return meta


def _align_audio_frames(
    aframes: torch.Tensor, aframe_pts: torch.Tensor, audio_pts_range: Tuple[int, int]
) -> torch.Tensor:
    start, end = aframe_pts[0], aframe_pts[-1]
    num_samples = aframes.size(0)
    step_per_aframe = float(end - start + 1) / float(num_samples)
    s_idx = 0
    e_idx = num_samples
    if start < audio_pts_range[0]:
        s_idx = int((audio_pts_range[0] - start) / step_per_aframe)
    if audio_pts_range[1] != -1 and end > audio_pts_range[1]:
        e_idx = int((audio_pts_range[1] - end) / step_per_aframe)
    return aframes[s_idx:e_idx, :]


def _read_video_from_file(
    filename: str,
    seek_frame_margin: float = 0.25,
    read_video_stream: bool = True,
    video_width: int = 0,
    video_height: int = 0,
    video_min_dimension: int = 0,
    video_max_dimension: int = 0,
    video_pts_range: Tuple[int, int] = (0, -1),
    video_timebase: Fraction = default_timebase,
    read_audio_stream: bool = True,
    audio_samples: int = 0,
    audio_channels: int = 0,
    audio_pts_range: Tuple[int, int] = (0, -1),
    audio_timebase: Fraction = default_timebase,
) -> Tuple[torch.Tensor, torch.Tensor, VideoMetaData]:
    """
    Reads a video from a file, returning both the video frames and the audio frames

    Args:
        filename (str): path to the video file
        seek_frame_margin (double, optional): seeking frame in the stream is imprecise. Thus,
            when video_start_pts is specified, we seek the pts earlier by seek_frame_margin seconds
        read_video_stream (int, optional): whether read video stream. If yes, set to 1. Otherwise, 0
        video_width/video_height/video_min_dimension/video_max_dimension (int): together decide
            the size of decoded frames:

            - When video_width = 0, video_height = 0, video_min_dimension = 0,
                and video_max_dimension = 0, keep the original frame resolution
            - When video_width = 0, video_height = 0, video_min_dimension != 0,
                and video_max_dimension = 0, keep the aspect ratio and resize the
                frame so that shorter edge size is video_min_dimension
            - When video_width = 0, video_height = 0, video_min_dimension = 0,
                and video_max_dimension != 0, keep the aspect ratio and resize
                the frame so that longer edge size is video_max_dimension
            - When video_width = 0, video_height = 0, video_min_dimension != 0,
                and video_max_dimension != 0, resize the frame so that shorter
                edge size is video_min_dimension, and longer edge size is
                video_max_dimension. The aspect ratio may not be preserved
            - When video_width = 0, video_height != 0, video_min_dimension = 0,
                and video_max_dimension = 0, keep the aspect ratio and resize
                the frame so that frame video_height is $video_height
            - When video_width != 0, video_height == 0, video_min_dimension = 0,
                and video_max_dimension = 0, keep the aspect ratio and resize
                the frame so that frame video_width is $video_width
            - When video_width != 0, video_height != 0, video_min_dimension = 0,
                and video_max_dimension = 0, resize the frame so that frame
                video_width and video_height are set to $video_width and
                $video_height, respectively
        video_pts_range (list(int), optional): the start and end presentation timestamp of video stream
        video_timebase (Fraction, optional): a Fraction rational number which denotes timebase in video stream
        read_audio_stream (int, optional): whether read audio stream. If yes, set to 1. Otherwise, 0
        audio_samples (int, optional): audio sampling rate
        audio_channels (int optional): audio channels
        audio_pts_range (list(int), optional): the start and end presentation timestamp of audio stream
        audio_timebase (Fraction, optional): a Fraction rational number which denotes time base in audio stream

    Returns
        vframes (Tensor[T, H, W, C]): the `T` video frames
        aframes (Tensor[L, K]): the audio frames, where `L` is the number of points and
            `K` is the number of audio_channels
        info (Dict): metadata for the video and audio. Can contain the fields video_fps (float)
            and audio_fps (int)
    """
    _validate_pts(video_pts_range)
    _validate_pts(audio_pts_range)

    result = torch.ops.video_reader.read_video_from_file(
        filename,
        seek_frame_margin,
        0,  # getPtsOnly
        read_video_stream,
        video_width,
        video_height,
        video_min_dimension,
        video_max_dimension,
        video_pts_range[0],
        video_pts_range[1],
        video_timebase.numerator,
        video_timebase.denominator,
        read_audio_stream,
        audio_samples,
        audio_channels,
        audio_pts_range[0],
        audio_pts_range[1],
        audio_timebase.numerator,
        audio_timebase.denominator,
    )
    vframes, _vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, atimebase, asample_rate, aduration = result
    info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
    if aframes.numel() > 0:
        # when audio stream is found
        aframes = _align_audio_frames(aframes, aframe_pts, audio_pts_range)
    return vframes, aframes, info


def _read_video_timestamps_from_file(filename: str) -> Tuple[List[int], List[int], VideoMetaData]:
    """
    Decode all video- and audio frames in the video. Only pts
    (presentation timestamp) is returned. The actual frame pixel data is not
    copied. Thus, it is much faster than read_video(...)
    """
    result = torch.ops.video_reader.read_video_from_file(
        filename,
        0,  # seek_frame_margin
        1,  # getPtsOnly
        1,  # read_video_stream
        0,  # video_width
        0,  # video_height
        0,  # video_min_dimension
        0,  # video_max_dimension
        0,  # video_start_pts
        -1,  # video_end_pts
        0,  # video_timebase_num
        1,  # video_timebase_den
        1,  # read_audio_stream
        0,  # audio_samples
        0,  # audio_channels
        0,  # audio_start_pts
        -1,  # audio_end_pts
        0,  # audio_timebase_num
        1,  # audio_timebase_den
    )
    _vframes, vframe_pts, vtimebase, vfps, vduration, _aframes, aframe_pts, atimebase, asample_rate, aduration = result
    info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)

    vframe_pts = vframe_pts.numpy().tolist()
    aframe_pts = aframe_pts.numpy().tolist()
    return vframe_pts, aframe_pts, info


def _probe_video_from_file(filename: str) -> VideoMetaData:
    """
    Probe a video file and return VideoMetaData with info about the video
    """
    result = torch.ops.video_reader.probe_video_from_file(filename)
    vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result
    info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
    return info


def _read_video_from_memory(
    video_data: torch.Tensor,
    seek_frame_margin: float = 0.25,
    read_video_stream: int = 1,
    video_width: int = 0,
    video_height: int = 0,
    video_min_dimension: int = 0,
    video_max_dimension: int = 0,
    video_pts_range: Tuple[int, int] = (0, -1),
    video_timebase_numerator: int = 0,
    video_timebase_denominator: int = 1,
    read_audio_stream: int = 1,
    audio_samples: int = 0,
    audio_channels: int = 0,
    audio_pts_range: Tuple[int, int] = (0, -1),
    audio_timebase_numerator: int = 0,
    audio_timebase_denominator: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Reads a video from memory, returning both the video frames as the audio frames
    This function is torchscriptable.

    Args:
        video_data (data type could be 1) torch.Tensor, dtype=torch.int8 or 2) python bytes):
            compressed video content stored in either 1) torch.Tensor 2) python bytes
        seek_frame_margin (double, optional): seeking frame in the stream is imprecise.
            Thus, when video_start_pts is specified, we seek the pts earlier by seek_frame_margin seconds
        read_video_stream (int, optional): whether read video stream. If yes, set to 1. Otherwise, 0
        video_width/video_height/video_min_dimension/video_max_dimension (int): together decide
            the size of decoded frames:

            - When video_width = 0, video_height = 0, video_min_dimension = 0,
                and video_max_dimension = 0, keep the original frame resolution
            - When video_width = 0, video_height = 0, video_min_dimension != 0,
                and video_max_dimension = 0, keep the aspect ratio and resize the
                frame so that shorter edge size is video_min_dimension
            - When video_width = 0, video_height = 0, video_min_dimension = 0,
|
| 302 |
+
and video_max_dimension != 0, keep the aspect ratio and resize
|
| 303 |
+
the frame so that longer edge size is video_max_dimension
|
| 304 |
+
- When video_width = 0, video_height = 0, video_min_dimension != 0,
|
| 305 |
+
and video_max_dimension != 0, resize the frame so that shorter
|
| 306 |
+
edge size is video_min_dimension, and longer edge size is
|
| 307 |
+
video_max_dimension. The aspect ratio may not be preserved
|
| 308 |
+
- When video_width = 0, video_height != 0, video_min_dimension = 0,
|
| 309 |
+
and video_max_dimension = 0, keep the aspect ratio and resize
|
| 310 |
+
the frame so that frame video_height is $video_height
|
| 311 |
+
- When video_width != 0, video_height == 0, video_min_dimension = 0,
|
| 312 |
+
and video_max_dimension = 0, keep the aspect ratio and resize
|
| 313 |
+
the frame so that frame video_width is $video_width
|
| 314 |
+
- When video_width != 0, video_height != 0, video_min_dimension = 0,
|
| 315 |
+
and video_max_dimension = 0, resize the frame so that frame
|
| 316 |
+
video_width and video_height are set to $video_width and
|
| 317 |
+
$video_height, respectively
|
| 318 |
+
video_pts_range (list(int), optional): the start and end presentation timestamp of video stream
|
| 319 |
+
video_timebase_numerator / video_timebase_denominator (float, optional): a rational
|
| 320 |
+
number which denotes timebase in video stream
|
| 321 |
+
read_audio_stream (int, optional): whether read audio stream. If yes, set to 1. Otherwise, 0
|
| 322 |
+
audio_samples (int, optional): audio sampling rate
|
| 323 |
+
audio_channels (int optional): audio audio_channels
|
| 324 |
+
audio_pts_range (list(int), optional): the start and end presentation timestamp of audio stream
|
| 325 |
+
audio_timebase_numerator / audio_timebase_denominator (float, optional):
|
| 326 |
+
a rational number which denotes time base in audio stream
|
| 327 |
+
|
| 328 |
+
Returns:
|
| 329 |
+
vframes (Tensor[T, H, W, C]): the `T` video frames
|
| 330 |
+
aframes (Tensor[L, K]): the audio frames, where `L` is the number of points and
|
| 331 |
+
`K` is the number of channels
|
| 332 |
+
"""
|
| 333 |
+
|
| 334 |
+
_validate_pts(video_pts_range)
|
| 335 |
+
_validate_pts(audio_pts_range)
|
| 336 |
+
|
| 337 |
+
if not isinstance(video_data, torch.Tensor):
|
| 338 |
+
with warnings.catch_warnings():
|
| 339 |
+
# Ignore the warning because we actually don't modify the buffer in this function
|
| 340 |
+
warnings.filterwarnings("ignore", message="The given buffer is not writable")
|
| 341 |
+
video_data = torch.frombuffer(video_data, dtype=torch.uint8)
|
| 342 |
+
|
| 343 |
+
result = torch.ops.video_reader.read_video_from_memory(
|
| 344 |
+
video_data,
|
| 345 |
+
seek_frame_margin,
|
| 346 |
+
0, # getPtsOnly
|
| 347 |
+
read_video_stream,
|
| 348 |
+
video_width,
|
| 349 |
+
video_height,
|
| 350 |
+
video_min_dimension,
|
| 351 |
+
video_max_dimension,
|
| 352 |
+
video_pts_range[0],
|
| 353 |
+
video_pts_range[1],
|
| 354 |
+
video_timebase_numerator,
|
| 355 |
+
video_timebase_denominator,
|
| 356 |
+
read_audio_stream,
|
| 357 |
+
audio_samples,
|
| 358 |
+
audio_channels,
|
| 359 |
+
audio_pts_range[0],
|
| 360 |
+
audio_pts_range[1],
|
| 361 |
+
audio_timebase_numerator,
|
| 362 |
+
audio_timebase_denominator,
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
vframes, _vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, atimebase, asample_rate, aduration = result
|
| 366 |
+
|
| 367 |
+
if aframes.numel() > 0:
|
| 368 |
+
# when audio stream is found
|
| 369 |
+
aframes = _align_audio_frames(aframes, aframe_pts, audio_pts_range)
|
| 370 |
+
|
| 371 |
+
return vframes, aframes
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def _read_video_timestamps_from_memory(
|
| 375 |
+
video_data: torch.Tensor,
|
| 376 |
+
) -> Tuple[List[int], List[int], VideoMetaData]:
|
| 377 |
+
"""
|
| 378 |
+
Decode all frames in the video. Only pts (presentation timestamp) is returned.
|
| 379 |
+
The actual frame pixel data is not copied. Thus, read_video_timestamps(...)
|
| 380 |
+
is much faster than read_video(...)
|
| 381 |
+
"""
|
| 382 |
+
if not isinstance(video_data, torch.Tensor):
|
| 383 |
+
with warnings.catch_warnings():
|
| 384 |
+
# Ignore the warning because we actually don't modify the buffer in this function
|
| 385 |
+
warnings.filterwarnings("ignore", message="The given buffer is not writable")
|
| 386 |
+
video_data = torch.frombuffer(video_data, dtype=torch.uint8)
|
| 387 |
+
result = torch.ops.video_reader.read_video_from_memory(
|
| 388 |
+
video_data,
|
| 389 |
+
0, # seek_frame_margin
|
| 390 |
+
1, # getPtsOnly
|
| 391 |
+
1, # read_video_stream
|
| 392 |
+
0, # video_width
|
| 393 |
+
0, # video_height
|
| 394 |
+
0, # video_min_dimension
|
| 395 |
+
0, # video_max_dimension
|
| 396 |
+
0, # video_start_pts
|
| 397 |
+
-1, # video_end_pts
|
| 398 |
+
0, # video_timebase_num
|
| 399 |
+
1, # video_timebase_den
|
| 400 |
+
1, # read_audio_stream
|
| 401 |
+
0, # audio_samples
|
| 402 |
+
0, # audio_channels
|
| 403 |
+
0, # audio_start_pts
|
| 404 |
+
-1, # audio_end_pts
|
| 405 |
+
0, # audio_timebase_num
|
| 406 |
+
1, # audio_timebase_den
|
| 407 |
+
)
|
| 408 |
+
_vframes, vframe_pts, vtimebase, vfps, vduration, _aframes, aframe_pts, atimebase, asample_rate, aduration = result
|
| 409 |
+
info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
|
| 410 |
+
|
| 411 |
+
vframe_pts = vframe_pts.numpy().tolist()
|
| 412 |
+
aframe_pts = aframe_pts.numpy().tolist()
|
| 413 |
+
return vframe_pts, aframe_pts, info
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
def _probe_video_from_memory(
|
| 417 |
+
video_data: torch.Tensor,
|
| 418 |
+
) -> VideoMetaData:
|
| 419 |
+
"""
|
| 420 |
+
Probe a video in memory and return VideoMetaData with info about the video
|
| 421 |
+
This function is torchscriptable
|
| 422 |
+
"""
|
| 423 |
+
if not isinstance(video_data, torch.Tensor):
|
| 424 |
+
with warnings.catch_warnings():
|
| 425 |
+
# Ignore the warning because we actually don't modify the buffer in this function
|
| 426 |
+
warnings.filterwarnings("ignore", message="The given buffer is not writable")
|
| 427 |
+
video_data = torch.frombuffer(video_data, dtype=torch.uint8)
|
| 428 |
+
result = torch.ops.video_reader.probe_video_from_memory(video_data)
|
| 429 |
+
vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result
|
| 430 |
+
info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
|
| 431 |
+
return info
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
def _read_video(
|
| 435 |
+
filename: str,
|
| 436 |
+
start_pts: Union[float, Fraction] = 0,
|
| 437 |
+
end_pts: Optional[Union[float, Fraction]] = None,
|
| 438 |
+
pts_unit: str = "pts",
|
| 439 |
+
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, float]]:
|
| 440 |
+
if end_pts is None:
|
| 441 |
+
end_pts = float("inf")
|
| 442 |
+
|
| 443 |
+
if pts_unit == "pts":
|
| 444 |
+
warnings.warn(
|
| 445 |
+
"The pts_unit 'pts' gives wrong results and will be removed in a "
|
| 446 |
+
+ "follow-up version. Please use pts_unit 'sec'."
|
| 447 |
+
)
|
| 448 |
+
|
| 449 |
+
info = _probe_video_from_file(filename)
|
| 450 |
+
|
| 451 |
+
has_video = info.has_video
|
| 452 |
+
has_audio = info.has_audio
|
| 453 |
+
|
| 454 |
+
def get_pts(time_base):
|
| 455 |
+
start_offset = start_pts
|
| 456 |
+
end_offset = end_pts
|
| 457 |
+
if pts_unit == "sec":
|
| 458 |
+
start_offset = int(math.floor(start_pts * (1 / time_base)))
|
| 459 |
+
if end_offset != float("inf"):
|
| 460 |
+
end_offset = int(math.ceil(end_pts * (1 / time_base)))
|
| 461 |
+
if end_offset == float("inf"):
|
| 462 |
+
end_offset = -1
|
| 463 |
+
return start_offset, end_offset
|
| 464 |
+
|
| 465 |
+
video_pts_range = (0, -1)
|
| 466 |
+
video_timebase = default_timebase
|
| 467 |
+
if has_video:
|
| 468 |
+
video_timebase = Fraction(info.video_timebase.numerator, info.video_timebase.denominator)
|
| 469 |
+
video_pts_range = get_pts(video_timebase)
|
| 470 |
+
|
| 471 |
+
audio_pts_range = (0, -1)
|
| 472 |
+
audio_timebase = default_timebase
|
| 473 |
+
if has_audio:
|
| 474 |
+
audio_timebase = Fraction(info.audio_timebase.numerator, info.audio_timebase.denominator)
|
| 475 |
+
audio_pts_range = get_pts(audio_timebase)
|
| 476 |
+
|
| 477 |
+
vframes, aframes, info = _read_video_from_file(
|
| 478 |
+
filename,
|
| 479 |
+
read_video_stream=True,
|
| 480 |
+
video_pts_range=video_pts_range,
|
| 481 |
+
video_timebase=video_timebase,
|
| 482 |
+
read_audio_stream=True,
|
| 483 |
+
audio_pts_range=audio_pts_range,
|
| 484 |
+
audio_timebase=audio_timebase,
|
| 485 |
+
)
|
| 486 |
+
_info = {}
|
| 487 |
+
if has_video:
|
| 488 |
+
_info["video_fps"] = info.video_fps
|
| 489 |
+
if has_audio:
|
| 490 |
+
_info["audio_fps"] = info.audio_sample_rate
|
| 491 |
+
|
| 492 |
+
return vframes, aframes, _info
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
def _read_video_timestamps(
|
| 496 |
+
filename: str, pts_unit: str = "pts"
|
| 497 |
+
) -> Tuple[Union[List[int], List[Fraction]], Optional[float]]:
|
| 498 |
+
if pts_unit == "pts":
|
| 499 |
+
warnings.warn(
|
| 500 |
+
"The pts_unit 'pts' gives wrong results and will be removed in a "
|
| 501 |
+
+ "follow-up version. Please use pts_unit 'sec'."
|
| 502 |
+
)
|
| 503 |
+
|
| 504 |
+
pts: Union[List[int], List[Fraction]]
|
| 505 |
+
pts, _, info = _read_video_timestamps_from_file(filename)
|
| 506 |
+
|
| 507 |
+
if pts_unit == "sec":
|
| 508 |
+
video_time_base = Fraction(info.video_timebase.numerator, info.video_timebase.denominator)
|
| 509 |
+
pts = [x * video_time_base for x in pts]
|
| 510 |
+
|
| 511 |
+
video_fps = info.video_fps if info.has_video else None
|
| 512 |
+
|
| 513 |
+
return pts, video_fps
|
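
A minimal usage sketch of the private helpers above, assuming the compiled `video_reader` extension is available; "video.mp4" is a hypothetical path, and these underscore-prefixed functions are normally reached indirectly through `torchvision.io.read_video` with the "video_reader" backend.

# Sketch only: exercises the private _video_opt helpers defined above.
from torchvision.io import _video_opt

meta = _video_opt._probe_video_from_file("video.mp4")  # VideoMetaData with fps, duration, timebases
print(meta.has_video, meta.video_fps)

# Decode the first two seconds; 'sec' avoids the deprecated 'pts' unit warning.
vframes, aframes, info = _video_opt._read_video("video.mp4", start_pts=0, end_pts=2.0, pts_unit="sec")
print(vframes.shape)  # [T, H, W, C], uint8

# Presentation timestamps only (no pixel data is copied), expressed in seconds.
pts, fps = _video_opt._read_video_timestamps("video.mp4", pts_unit="sec")
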
.venv/lib/python3.11/site-packages/torchvision/io/image.py
ADDED
@@ -0,0 +1,436 @@
from enum import Enum
from typing import List, Union
from warnings import warn

import torch

from ..extension import _load_library
from ..utils import _log_api_usage_once


try:
    _load_library("image")
except (ImportError, OSError) as e:
    warn(
        f"Failed to load image Python extension: '{e}' "
        f"If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. "
        f"Otherwise, there might be something wrong with your environment. "
        f"Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?"
    )


class ImageReadMode(Enum):
    """Allow automatic conversion to RGB, RGBA, etc while decoding.

    .. note::

        You don't need to use this struct, you can just pass strings to all
        ``mode`` parameters, e.g. ``mode="RGB"``.

    The different available modes are the following.

    - UNCHANGED: loads the image as-is
    - RGB: converts to RGB
    - RGBA: converts to RGB with transparency (also aliased as RGB_ALPHA)
    - GRAY: converts to grayscale
    - GRAY_ALPHA: converts to grayscale with transparency

    .. note::

        Some decoders won't support all possible values, e.g. GRAY and
        GRAY_ALPHA are only supported for PNG and JPEG images.
    """

    UNCHANGED = 0
    GRAY = 1
    GRAY_ALPHA = 2
    RGB = 3
    RGB_ALPHA = 4
    RGBA = RGB_ALPHA  # Alias for convenience


def read_file(path: str) -> torch.Tensor:
    """
    Return the bytes contents of a file as a uint8 1D Tensor.

    Args:
        path (str or ``pathlib.Path``): the path to the file to be read

    Returns:
        data (Tensor)
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(read_file)
    data = torch.ops.image.read_file(str(path))
    return data


def write_file(filename: str, data: torch.Tensor) -> None:
    """
    Write the content of an uint8 1D tensor to a file.

    Args:
        filename (str or ``pathlib.Path``): the path to the file to be written
        data (Tensor): the contents to be written to the output file
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(write_file)
    torch.ops.image.write_file(str(filename), data)


def decode_png(
    input: torch.Tensor,
    mode: ImageReadMode = ImageReadMode.UNCHANGED,
    apply_exif_orientation: bool = False,
) -> torch.Tensor:
    """
    Decodes a PNG image into a 3 dimensional RGB or grayscale Tensor.

    The values of the output tensor are in uint8 in [0, 255] for most cases. If
    the image is a 16-bit png, then the output tensor is uint16 in [0, 65535]
    (supported from torchvision ``0.21``). Since uint16 support is limited in
    pytorch, we recommend calling
    :func:`torchvision.transforms.v2.functional.to_dtype()` with ``scale=True``
    after this function to convert the decoded image into a uint8 or float
    tensor.

    Args:
        input (Tensor[1]): a one dimensional uint8 tensor containing
            the raw bytes of the PNG image.
        mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
            Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode`
            for available modes.
        apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
            Default: False.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(decode_png)
    if isinstance(mode, str):
        mode = ImageReadMode[mode.upper()]
    output = torch.ops.image.decode_png(input, mode.value, apply_exif_orientation)
    return output


def encode_png(input: torch.Tensor, compression_level: int = 6) -> torch.Tensor:
    """
    Takes an input tensor in CHW layout and returns a buffer with the contents
    of its corresponding PNG file.

    Args:
        input (Tensor[channels, image_height, image_width]): int8 image tensor of
            ``c`` channels, where ``c`` must be 3 or 1.
        compression_level (int): Compression factor for the resulting file, it must be a number
            between 0 and 9. Default: 6

    Returns:
        Tensor[1]: A one dimensional int8 tensor that contains the raw bytes of the
            PNG file.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(encode_png)
    output = torch.ops.image.encode_png(input, compression_level)
    return output


def write_png(input: torch.Tensor, filename: str, compression_level: int = 6):
    """
    Takes an input tensor in CHW layout (or HW in the case of grayscale images)
    and saves it in a PNG file.

    Args:
        input (Tensor[channels, image_height, image_width]): int8 image tensor of
            ``c`` channels, where ``c`` must be 1 or 3.
        filename (str or ``pathlib.Path``): Path to save the image.
        compression_level (int): Compression factor for the resulting file, it must be a number
            between 0 and 9. Default: 6
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(write_png)
    output = encode_png(input, compression_level)
    write_file(filename, output)


def decode_jpeg(
    input: Union[torch.Tensor, List[torch.Tensor]],
    mode: ImageReadMode = ImageReadMode.UNCHANGED,
    device: Union[str, torch.device] = "cpu",
    apply_exif_orientation: bool = False,
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Decode JPEG image(s) into 3D RGB or grayscale Tensor(s), on CPU or CUDA.

    The values of the output tensor are uint8 between 0 and 255.

    .. note::
        When using a CUDA device, passing a list of tensors is more efficient than repeated individual calls to ``decode_jpeg``.
        When using CPU the performance is equivalent.
        The CUDA version of this function has explicitly been designed with thread-safety in mind.
        This function does not return partial results in case of an error.

    Args:
        input (Tensor[1] or list[Tensor[1]]): a (list of) one dimensional uint8 tensor(s) containing
            the raw bytes of the JPEG image. The tensor(s) must be on CPU,
            regardless of the ``device`` parameter.
        mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
            Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode`
            for available modes.
        device (str or torch.device): The device on which the decoded image will
            be stored. If a cuda device is specified, the image will be decoded
            with `nvjpeg <https://developer.nvidia.com/nvjpeg>`_. This is only
            supported for CUDA version >= 10.1

            .. betastatus:: device parameter

            .. warning::
                There is a memory leak in the nvjpeg library for CUDA versions < 11.6.
                Make sure to rely on CUDA 11.6 or above before using ``device="cuda"``.
        apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
            Default: False. Only implemented for JPEG format on CPU.

    Returns:
        output (Tensor[image_channels, image_height, image_width] or list[Tensor[image_channels, image_height, image_width]]):
            The values of the output tensor(s) are uint8 between 0 and 255.
            ``output.device`` will be set to the specified ``device``

    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(decode_jpeg)
    if isinstance(device, str):
        device = torch.device(device)
    if isinstance(mode, str):
        mode = ImageReadMode[mode.upper()]

    if isinstance(input, list):
        if len(input) == 0:
            raise ValueError("Input list must contain at least one element")
        if not all(isinstance(t, torch.Tensor) for t in input):
            raise ValueError("All elements of the input list must be tensors.")
        if not all(t.device.type == "cpu" for t in input):
            raise ValueError("Input list must contain tensors on CPU.")
        if device.type == "cuda":
            return torch.ops.image.decode_jpegs_cuda(input, mode.value, device)
        else:
            return [torch.ops.image.decode_jpeg(img, mode.value, apply_exif_orientation) for img in input]

    else:  # input is tensor
        if input.device.type != "cpu":
            raise ValueError("Input tensor must be a CPU tensor")
        if device.type == "cuda":
            return torch.ops.image.decode_jpegs_cuda([input], mode.value, device)[0]
        else:
            return torch.ops.image.decode_jpeg(input, mode.value, apply_exif_orientation)


def encode_jpeg(
    input: Union[torch.Tensor, List[torch.Tensor]], quality: int = 75
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Encode RGB tensor(s) into raw encoded jpeg bytes, on CPU or CUDA.

    .. note::
        Passing a list of CUDA tensors is more efficient than repeated individual calls to ``encode_jpeg``.
        For CPU tensors the performance is equivalent.

    Args:
        input (Tensor[channels, image_height, image_width] or List[Tensor[channels, image_height, image_width]]):
            (list of) uint8 image tensor(s) of ``c`` channels, where ``c`` must be 1 or 3
        quality (int): Quality of the resulting JPEG file(s). Must be a number between
            1 and 100. Default: 75

    Returns:
        output (Tensor[1] or list[Tensor[1]]): A (list of) one dimensional uint8 tensor(s) that contain the raw bytes of the JPEG file.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(encode_jpeg)
    if quality < 1 or quality > 100:
        raise ValueError("Image quality should be a positive number between 1 and 100")
    if isinstance(input, list):
        if not input:
            raise ValueError("encode_jpeg requires at least one input tensor when a list is passed")
        if input[0].device.type == "cuda":
            return torch.ops.image.encode_jpegs_cuda(input, quality)
        else:
            return [torch.ops.image.encode_jpeg(image, quality) for image in input]
    else:  # single input tensor
        if input.device.type == "cuda":
            return torch.ops.image.encode_jpegs_cuda([input], quality)[0]
        else:
            return torch.ops.image.encode_jpeg(input, quality)


def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):
    """
    Takes an input tensor in CHW layout and saves it in a JPEG file.

    Args:
        input (Tensor[channels, image_height, image_width]): int8 image tensor of ``c``
            channels, where ``c`` must be 1 or 3.
        filename (str or ``pathlib.Path``): Path to save the image.
        quality (int): Quality of the resulting JPEG file, it must be a number
            between 1 and 100. Default: 75
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(write_jpeg)
    output = encode_jpeg(input, quality)
    assert isinstance(output, torch.Tensor)  # Needed for torchscript
    write_file(filename, output)


def decode_image(
    input: Union[torch.Tensor, str],
    mode: ImageReadMode = ImageReadMode.UNCHANGED,
    apply_exif_orientation: bool = False,
) -> torch.Tensor:
    """Decode an image into a uint8 tensor, from a path or from raw encoded bytes.

    Currently supported image formats are jpeg, png, gif and webp.

    The values of the output tensor are in uint8 in [0, 255] for most cases.

    If the image is a 16-bit png, then the output tensor is uint16 in [0, 65535]
    (supported from torchvision ``0.21``). Since uint16 support is limited in
    pytorch, we recommend calling
    :func:`torchvision.transforms.v2.functional.to_dtype()` with ``scale=True``
    after this function to convert the decoded image into a uint8 or float
    tensor.

    Args:
        input (Tensor or str or ``pathlib.Path``): The image to decode. If a
            tensor is passed, it must be one dimensional uint8 tensor containing
            the raw bytes of the image. Otherwise, this must be a path to the image file.
        mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
            Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode`
            for available modes.
        apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
            Only applies to JPEG and PNG images. Default: False.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(decode_image)
    if not isinstance(input, torch.Tensor):
        input = read_file(str(input))
    if isinstance(mode, str):
        mode = ImageReadMode[mode.upper()]
    output = torch.ops.image.decode_image(input, mode.value, apply_exif_orientation)
    return output


def read_image(
    path: str,
    mode: ImageReadMode = ImageReadMode.UNCHANGED,
    apply_exif_orientation: bool = False,
) -> torch.Tensor:
    """[OBSOLETE] Use :func:`~torchvision.io.decode_image` instead."""
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(read_image)
    data = read_file(path)
    return decode_image(data, mode, apply_exif_orientation=apply_exif_orientation)


def decode_gif(input: torch.Tensor) -> torch.Tensor:
    """
    Decode a GIF image into a 3 or 4 dimensional RGB Tensor.

    The values of the output tensor are uint8 between 0 and 255.
    The output tensor has shape ``(C, H, W)`` if there is only one image in the
    GIF, and ``(N, C, H, W)`` if there are ``N`` images.

    Args:
        input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
            the raw bytes of the GIF image.

    Returns:
        output (Tensor[image_channels, image_height, image_width] or Tensor[num_images, image_channels, image_height, image_width])
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(decode_gif)
    return torch.ops.image.decode_gif(input)


def decode_webp(
    input: torch.Tensor,
    mode: ImageReadMode = ImageReadMode.UNCHANGED,
) -> torch.Tensor:
    """
    Decode a WEBP image into a 3 dimensional RGB[A] Tensor.

    The values of the output tensor are uint8 between 0 and 255.

    Args:
        input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
            the raw bytes of the WEBP image.
        mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
            Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode`
            for available modes.

    Returns:
        Decoded image (Tensor[image_channels, image_height, image_width])
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(decode_webp)
    if isinstance(mode, str):
        mode = ImageReadMode[mode.upper()]
    return torch.ops.image.decode_webp(input, mode.value)


def _decode_avif(
    input: torch.Tensor,
    mode: ImageReadMode = ImageReadMode.UNCHANGED,
) -> torch.Tensor:
    """
    Decode an AVIF image into a 3 dimensional RGB[A] Tensor.

    The values of the output tensor are in uint8 in [0, 255] for most images. If
    the image has a bit-depth of more than 8, then the output tensor is uint16
    in [0, 65535]. Since uint16 support is limited in pytorch, we recommend
    calling :func:`torchvision.transforms.v2.functional.to_dtype()` with
    ``scale=True`` after this function to convert the decoded image into a uint8
    or float tensor.

    Args:
        input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
            the raw bytes of the AVIF image.
        mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
            Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode`
            for available modes.

    Returns:
        Decoded image (Tensor[image_channels, image_height, image_width])
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(_decode_avif)
    if isinstance(mode, str):
        mode = ImageReadMode[mode.upper()]
    return torch.ops.image.decode_avif(input, mode.value)


def _decode_heic(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Decode an HEIC image into a 3 dimensional RGB[A] Tensor.

    The values of the output tensor are in uint8 in [0, 255] for most images. If
    the image has a bit-depth of more than 8, then the output tensor is uint16
    in [0, 65535]. Since uint16 support is limited in pytorch, we recommend
    calling :func:`torchvision.transforms.v2.functional.to_dtype()` with
    ``scale=True`` after this function to convert the decoded image into a uint8
    or float tensor.

    Args:
        input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
            the raw bytes of the HEIC image.
        mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
            Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode`
            for available modes.

    Returns:
        Decoded image (Tensor[image_channels, image_height, image_width])
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(_decode_heic)
    if isinstance(mode, str):
        mode = ImageReadMode[mode.upper()]
    return torch.ops.image.decode_heic(input, mode.value)
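
A short usage sketch of the public decode/encode entry points above, assuming the image extension loaded successfully; "input.jpg" and "out.png" are hypothetical paths.

# Sketch only: round-trips an image through the helpers defined above.
from torchvision.io import decode_image, encode_jpeg, write_png

img = decode_image("input.jpg", mode="RGB")      # uint8 tensor of shape [3, H, W]
jpeg_bytes = encode_jpeg(img, quality=90)        # 1-D uint8 tensor of raw JPEG bytes
write_png(img, "out.png", compression_level=6)   # re-encode the same frame as PNG
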
.venv/lib/python3.11/site-packages/torchvision/io/video.py
ADDED
@@ -0,0 +1,438 @@
| 1 |
+
import gc
|
| 2 |
+
import math
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
import warnings
|
| 6 |
+
from fractions import Fraction
|
| 7 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
from ..utils import _log_api_usage_once
|
| 13 |
+
from . import _video_opt
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
import av
|
| 17 |
+
|
| 18 |
+
av.logging.set_level(av.logging.ERROR)
|
| 19 |
+
if not hasattr(av.video.frame.VideoFrame, "pict_type"):
|
| 20 |
+
av = ImportError(
|
| 21 |
+
"""\
|
| 22 |
+
Your version of PyAV is too old for the necessary video operations in torchvision.
|
| 23 |
+
If you are on Python 3.5, you will have to build from source (the conda-forge
|
| 24 |
+
packages are not up-to-date). See
|
| 25 |
+
https://github.com/mikeboers/PyAV#installation for instructions on how to
|
| 26 |
+
install PyAV on your system.
|
| 27 |
+
"""
|
| 28 |
+
)
|
| 29 |
+
except ImportError:
|
| 30 |
+
av = ImportError(
|
| 31 |
+
"""\
|
| 32 |
+
PyAV is not installed, and is necessary for the video operations in torchvision.
|
| 33 |
+
See https://github.com/mikeboers/PyAV#installation for instructions on how to
|
| 34 |
+
install PyAV on your system.
|
| 35 |
+
"""
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _check_av_available() -> None:
|
| 40 |
+
if isinstance(av, Exception):
|
| 41 |
+
raise av
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _av_available() -> bool:
|
| 45 |
+
return not isinstance(av, Exception)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# PyAV has some reference cycles
|
| 49 |
+
_CALLED_TIMES = 0
|
| 50 |
+
_GC_COLLECTION_INTERVAL = 10
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def write_video(
|
| 54 |
+
filename: str,
|
| 55 |
+
video_array: torch.Tensor,
|
| 56 |
+
fps: float,
|
| 57 |
+
video_codec: str = "libx264",
|
| 58 |
+
options: Optional[Dict[str, Any]] = None,
|
| 59 |
+
audio_array: Optional[torch.Tensor] = None,
|
| 60 |
+
audio_fps: Optional[float] = None,
|
| 61 |
+
audio_codec: Optional[str] = None,
|
| 62 |
+
audio_options: Optional[Dict[str, Any]] = None,
|
| 63 |
+
) -> None:
|
| 64 |
+
"""
|
| 65 |
+
Writes a 4d tensor in [T, H, W, C] format in a video file
|
| 66 |
+
|
| 67 |
+
.. warning::
|
| 68 |
+
|
| 69 |
+
In the near future, we intend to centralize PyTorch's video decoding
|
| 70 |
+
capabilities within the `torchcodec
|
| 71 |
+
<https://github.com/pytorch/torchcodec>`_ project. We encourage you to
|
| 72 |
+
try it out and share your feedback, as the torchvision video decoders
|
| 73 |
+
will eventually be deprecated.
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
filename (str): path where the video will be saved
|
| 77 |
+
video_array (Tensor[T, H, W, C]): tensor containing the individual frames,
|
| 78 |
+
as a uint8 tensor in [T, H, W, C] format
|
| 79 |
+
fps (Number): video frames per second
|
| 80 |
+
video_codec (str): the name of the video codec, i.e. "libx264", "h264", etc.
|
| 81 |
+
options (Dict): dictionary containing options to be passed into the PyAV video stream
|
| 82 |
+
audio_array (Tensor[C, N]): tensor containing the audio, where C is the number of channels
|
| 83 |
+
and N is the number of samples
|
| 84 |
+
audio_fps (Number): audio sample rate, typically 44100 or 48000
|
| 85 |
+
audio_codec (str): the name of the audio codec, i.e. "mp3", "aac", etc.
|
| 86 |
+
audio_options (Dict): dictionary containing options to be passed into the PyAV audio stream
|
| 87 |
+
"""
|
| 88 |
+
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
|
| 89 |
+
_log_api_usage_once(write_video)
|
| 90 |
+
_check_av_available()
|
| 91 |
+
video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy(force=True)
|
| 92 |
+
|
| 93 |
+
# PyAV does not support floating point numbers with decimal point
|
| 94 |
+
# and will throw OverflowException in case this is not the case
|
| 95 |
+
if isinstance(fps, float):
|
| 96 |
+
fps = np.round(fps)
|
| 97 |
+
|
| 98 |
+
with av.open(filename, mode="w") as container:
|
| 99 |
+
stream = container.add_stream(video_codec, rate=fps)
|
| 100 |
+
stream.width = video_array.shape[2]
|
| 101 |
+
stream.height = video_array.shape[1]
|
| 102 |
+
stream.pix_fmt = "yuv420p" if video_codec != "libx264rgb" else "rgb24"
|
| 103 |
+
stream.options = options or {}
|
| 104 |
+
|
| 105 |
+
if audio_array is not None:
|
| 106 |
+
audio_format_dtypes = {
|
| 107 |
+
"dbl": "<f8",
|
| 108 |
+
"dblp": "<f8",
|
| 109 |
+
"flt": "<f4",
|
| 110 |
+
"fltp": "<f4",
|
| 111 |
+
"s16": "<i2",
|
| 112 |
+
"s16p": "<i2",
|
| 113 |
+
"s32": "<i4",
|
| 114 |
+
"s32p": "<i4",
|
| 115 |
+
"u8": "u1",
|
| 116 |
+
"u8p": "u1",
|
| 117 |
+
}
|
| 118 |
+
a_stream = container.add_stream(audio_codec, rate=audio_fps)
|
| 119 |
+
a_stream.options = audio_options or {}
|
| 120 |
+
|
| 121 |
+
num_channels = audio_array.shape[0]
|
| 122 |
+
audio_layout = "stereo" if num_channels > 1 else "mono"
|
| 123 |
+
audio_sample_fmt = container.streams.audio[0].format.name
|
| 124 |
+
|
| 125 |
+
format_dtype = np.dtype(audio_format_dtypes[audio_sample_fmt])
|
| 126 |
+
audio_array = torch.as_tensor(audio_array).numpy(force=True).astype(format_dtype)
|
| 127 |
+
|
| 128 |
+
frame = av.AudioFrame.from_ndarray(audio_array, format=audio_sample_fmt, layout=audio_layout)
|
| 129 |
+
|
| 130 |
+
frame.sample_rate = audio_fps
|
| 131 |
+
|
| 132 |
+
for packet in a_stream.encode(frame):
|
| 133 |
+
container.mux(packet)
|
| 134 |
+
|
| 135 |
+
for packet in a_stream.encode():
|
| 136 |
+
container.mux(packet)
|
| 137 |
+
|
| 138 |
+
for img in video_array:
|
| 139 |
+
frame = av.VideoFrame.from_ndarray(img, format="rgb24")
|
| 140 |
+
frame.pict_type = "NONE"
|
| 141 |
+
for packet in stream.encode(frame):
|
| 142 |
+
container.mux(packet)
|
| 143 |
+
|
| 144 |
+
# Flush stream
|
| 145 |
+
for packet in stream.encode():
|
| 146 |
+
container.mux(packet)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def _read_from_stream(
|
| 150 |
+
container: "av.container.Container",
|
| 151 |
+
start_offset: float,
|
| 152 |
+
end_offset: float,
|
| 153 |
+
pts_unit: str,
|
| 154 |
+
stream: "av.stream.Stream",
|
| 155 |
+
stream_name: Dict[str, Optional[Union[int, Tuple[int, ...], List[int]]]],
|
| 156 |
+
) -> List["av.frame.Frame"]:
|
| 157 |
+
global _CALLED_TIMES, _GC_COLLECTION_INTERVAL
|
| 158 |
+
_CALLED_TIMES += 1
|
| 159 |
+
if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
|
| 160 |
+
gc.collect()
|
| 161 |
+
|
| 162 |
+
if pts_unit == "sec":
|
| 163 |
+
# TODO: we should change all of this from ground up to simply take
|
| 164 |
+
# sec and convert to MS in C++
|
| 165 |
+
start_offset = int(math.floor(start_offset * (1 / stream.time_base)))
|
| 166 |
+
if end_offset != float("inf"):
|
| 167 |
+
end_offset = int(math.ceil(end_offset * (1 / stream.time_base)))
|
| 168 |
+
else:
|
| 169 |
+
warnings.warn("The pts_unit 'pts' gives wrong results. Please use pts_unit 'sec'.")
|
| 170 |
+
|
| 171 |
+
frames = {}
|
| 172 |
+
should_buffer = True
|
| 173 |
+
max_buffer_size = 5
|
| 174 |
+
if stream.type == "video":
|
| 175 |
+
# DivX-style packed B-frames can have out-of-order pts (2 frames in a single pkt)
|
| 176 |
+
# so need to buffer some extra frames to sort everything
|
| 177 |
+
# properly
|
| 178 |
+
extradata = stream.codec_context.extradata
|
| 179 |
+
# overly complicated way of finding if `divx_packed` is set, following
|
| 180 |
+
# https://github.com/FFmpeg/FFmpeg/commit/d5a21172283572af587b3d939eba0091484d3263
|
| 181 |
+
if extradata and b"DivX" in extradata:
|
| 182 |
+
# can't use regex directly because of some weird characters sometimes...
|
| 183 |
+
pos = extradata.find(b"DivX")
|
| 184 |
+
d = extradata[pos:]
|
| 185 |
+
o = re.search(rb"DivX(\d+)Build(\d+)(\w)", d)
|
| 186 |
+
if o is None:
|
| 187 |
+
o = re.search(rb"DivX(\d+)b(\d+)(\w)", d)
|
| 188 |
+
if o is not None:
|
| 189 |
+
should_buffer = o.group(3) == b"p"
|
| 190 |
+
seek_offset = start_offset
|
| 191 |
+
# some files don't seek to the right location, so better be safe here
|
| 192 |
+
seek_offset = max(seek_offset - 1, 0)
|
| 193 |
+
if should_buffer:
|
| 194 |
+
# FIXME this is kind of a hack, but we will jump to the previous keyframe
|
| 195 |
+
# so this will be safe
|
| 196 |
+
seek_offset = max(seek_offset - max_buffer_size, 0)
|
| 197 |
+
try:
|
| 198 |
+
# TODO check if stream needs to always be the video stream here or not
|
| 199 |
+
container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
|
| 200 |
+
except av.AVError:
|
| 201 |
+
# TODO add some warnings in this case
|
| 202 |
+
# print("Corrupted file?", container.name)
|
| 203 |
+
return []
|
| 204 |
+
buffer_count = 0
|
| 205 |
+
try:
|
| 206 |
+
for _idx, frame in enumerate(container.decode(**stream_name)):
|
| 207 |
+
frames[frame.pts] = frame
|
| 208 |
+
if frame.pts >= end_offset:
|
| 209 |
+
if should_buffer and buffer_count < max_buffer_size:
|
| 210 |
+
buffer_count += 1
|
| 211 |
+
continue
|
| 212 |
+
break
|
| 213 |
+
except av.AVError:
|
| 214 |
+
# TODO add a warning
|
| 215 |
+
pass
|
| 216 |
+
# ensure that the results are sorted wrt the pts
|
| 217 |
+
result = [frames[i] for i in sorted(frames) if start_offset <= frames[i].pts <= end_offset]
|
| 218 |
+
if len(frames) > 0 and start_offset > 0 and start_offset not in frames:
|
| 219 |
+
# if there is no frame that exactly matches the pts of start_offset
|
| 220 |
+
# add the last frame smaller than start_offset, to guarantee that
|
| 221 |
+
# we will have all the necessary data. This is most useful for audio
|
| 222 |
+
preceding_frames = [i for i in frames if i < start_offset]
|
| 223 |
+
if len(preceding_frames) > 0:
|
| 224 |
+
first_frame_pts = max(preceding_frames)
|
| 225 |
+
result.insert(0, frames[first_frame_pts])
|
| 226 |
+
return result
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def _align_audio_frames(
|
| 230 |
+
aframes: torch.Tensor, audio_frames: List["av.frame.Frame"], ref_start: int, ref_end: float
|
| 231 |
+
) -> torch.Tensor:
|
| 232 |
+
start, end = audio_frames[0].pts, audio_frames[-1].pts
|
| 233 |
+
total_aframes = aframes.shape[1]
|
| 234 |
+
step_per_aframe = (end - start + 1) / total_aframes
|
| 235 |
+
s_idx = 0
|
| 236 |
+
e_idx = total_aframes
|
| 237 |
+
if start < ref_start:
|
| 238 |
+
s_idx = int((ref_start - start) / step_per_aframe)
|
| 239 |
+
if end > ref_end:
|
| 240 |
+
e_idx = int((ref_end - end) / step_per_aframe)
|
| 241 |
+
return aframes[:, s_idx:e_idx]
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def read_video(
|
| 245 |
+
filename: str,
|
| 246 |
+
start_pts: Union[float, Fraction] = 0,
|
| 247 |
+
end_pts: Optional[Union[float, Fraction]] = None,
|
| 248 |
+
pts_unit: str = "pts",
|
| 249 |
+
output_format: str = "THWC",
|
| 250 |
+
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
|
| 251 |
+
"""
|
| 252 |
+
Reads a video from a file, returning both the video frames and the audio frames
|
| 253 |
+
|
| 254 |
+
.. warning::
|
| 255 |
+
|
| 256 |
+
In the near future, we intend to centralize PyTorch's video decoding
|
| 257 |
+
capabilities within the `torchcodec
|
| 258 |
+
<https://github.com/pytorch/torchcodec>`_ project. We encourage you to
|
| 259 |
+
try it out and share your feedback, as the torchvision video decoders
|
| 260 |
+
will eventually be deprecated.
|
| 261 |
+
|
| 262 |
+
Args:
|
| 263 |
+
filename (str): path to the video file. If using the pyav backend, this can be whatever ``av.open`` accepts.
|
| 264 |
+
start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
|
| 265 |
+
The start presentation time of the video
|
| 266 |
+
end_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
|
| 267 |
+
The end presentation time
|
| 268 |
+
pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted,
|
| 269 |
+
either 'pts' or 'sec'. Defaults to 'pts'.
|
| 270 |
+
output_format (str, optional): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".
|
| 271 |
+
|
| 272 |
+
Returns:
|
| 273 |
+
vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames
|
| 274 |
+
aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
|
| 275 |
+
info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int)
|
| 276 |
+
"""
|
| 277 |
+
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
|
| 278 |
+
_log_api_usage_once(read_video)
|
| 279 |
+
|
| 280 |
+
output_format = output_format.upper()
|
| 281 |
+
if output_format not in ("THWC", "TCHW"):
|
| 282 |
+
raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")
|
| 283 |
+
|
| 284 |
+
from torchvision import get_video_backend
|
| 285 |
+
|
| 286 |
+
if get_video_backend() != "pyav":
|
| 287 |
+
if not os.path.exists(filename):
|
| 288 |
+
raise RuntimeError(f"File not found: {filename}")
|
| 289 |
+
vframes, aframes, info = _video_opt._read_video(filename, start_pts, end_pts, pts_unit)
|
| 290 |
+
else:
|
| 291 |
+
_check_av_available()
|
| 292 |
+
|
| 293 |
+
if end_pts is None:
|
| 294 |
+
end_pts = float("inf")
|
| 295 |
+
|
| 296 |
+
if end_pts < start_pts:
|
| 297 |
+
raise ValueError(
|
| 298 |
+
f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}"
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
info = {}
|
| 302 |
+
video_frames = []
|
| 303 |
+
audio_frames = []
|
| 304 |
+
audio_timebase = _video_opt.default_timebase
|
| 305 |
+
|
| 306 |
+
try:
|
| 307 |
+
with av.open(filename, metadata_errors="ignore") as container:
|
| 308 |
+
if container.streams.audio:
|
| 309 |
+
audio_timebase = container.streams.audio[0].time_base
|
| 310 |
+
if container.streams.video:
|
| 311 |
+
video_frames = _read_from_stream(
|
| 312 |
+
container,
|
| 313 |
+
start_pts,
|
| 314 |
+
end_pts,
|
| 315 |
+
pts_unit,
|
| 316 |
+
container.streams.video[0],
|
| 317 |
+
{"video": 0},
|
| 318 |
+
)
|
| 319 |
+
video_fps = container.streams.video[0].average_rate
|
| 320 |
+
# guard against potentially corrupted files
|
| 321 |
+
if video_fps is not None:
|
| 322 |
+
info["video_fps"] = float(video_fps)
|
| 323 |
+
|
| 324 |
+
if container.streams.audio:
|
| 325 |
+
audio_frames = _read_from_stream(
|
| 326 |
+
container,
|
| 327 |
+
start_pts,
|
| 328 |
+
end_pts,
|
| 329 |
+
pts_unit,
|
| 330 |
+
container.streams.audio[0],
|
| 331 |
+
{"audio": 0},
|
| 332 |
+
)
|
| 333 |
+
info["audio_fps"] = container.streams.audio[0].rate
|
| 334 |
+
|
| 335 |
+
except av.AVError:
|
| 336 |
+
# TODO raise a warning?
|
| 337 |
+
pass
|
| 338 |
+
|
| 339 |
+
vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]
|
| 340 |
+
aframes_list = [frame.to_ndarray() for frame in audio_frames]
|
| 341 |
+
|
| 342 |
+
if vframes_list:
|
| 343 |
+
vframes = torch.as_tensor(np.stack(vframes_list))
|
| 344 |
+
else:
|
| 345 |
+
vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)
|
| 346 |
+
|
| 347 |
+
if aframes_list:
|
| 348 |
+
aframes = np.concatenate(aframes_list, 1)
|
| 349 |
+
aframes = torch.as_tensor(aframes)
|
| 350 |
+
if pts_unit == "sec":
|
| 351 |
+
start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
|
| 352 |
+
if end_pts != float("inf"):
|
| 353 |
+
end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
|
| 354 |
+
aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
|
| 355 |
+
else:
|
| 356 |
+
aframes = torch.empty((1, 0), dtype=torch.float32)
|
| 357 |
+
|
| 358 |
+
if output_format == "TCHW":
|
| 359 |
+
# [T,H,W,C] --> [T,C,H,W]
|
| 360 |
+
vframes = vframes.permute(0, 3, 1, 2)
|
| 361 |
+
|
| 362 |
+
return vframes, aframes, info
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
def _can_read_timestamps_from_packets(container: "av.container.Container") -> bool:
|
| 366 |
+
extradata = container.streams[0].codec_context.extradata
|
| 367 |
+
if extradata is None:
|
| 368 |
+
return False
|
| 369 |
+
if b"Lavc" in extradata:
|
| 370 |
+
return True
|
| 371 |
+
return False
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def _decode_video_timestamps(container: "av.container.Container") -> List[int]:
|
| 375 |
+
if _can_read_timestamps_from_packets(container):
|
| 376 |
+
# fast path
|
| 377 |
+
return [x.pts for x in container.demux(video=0) if x.pts is not None]
|
| 378 |
+
else:
|
| 379 |
+
return [x.pts for x in container.decode(video=0) if x.pts is not None]
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
def read_video_timestamps(filename: str, pts_unit: str = "pts") -> Tuple[List[int], Optional[float]]:
|
| 383 |
+
"""
|
| 384 |
+
List the video frames timestamps.
|
| 385 |
+
|
| 386 |
+
.. warning::
|
| 387 |
+
|
| 388 |
+
In the near future, we intend to centralize PyTorch's video decoding
|
| 389 |
+
capabilities within the `torchcodec
|
| 390 |
+
<https://github.com/pytorch/torchcodec>`_ project. We encourage you to
|
| 391 |
+
try it out and share your feedback, as the torchvision video decoders
|
| 392 |
+
will eventually be deprecated.
|
| 393 |
+
|
| 394 |
+
Note that the function decodes the whole video frame-by-frame.
|
| 395 |
+
|
| 396 |
+
Args:
|
| 397 |
+
filename (str): path to the video file
|
| 398 |
+
pts_unit (str, optional): unit in which timestamp values will be returned
|
| 399 |
+
either 'pts' or 'sec'. Defaults to 'pts'.
|
| 400 |
+
|
| 401 |
+
Returns:
|
| 402 |
+
pts (List[int] if pts_unit = 'pts', List[Fraction] if pts_unit = 'sec'):
|
| 403 |
+
presentation timestamps for each one of the frames in the video.
|
| 404 |
+
video_fps (float, optional): the frame rate for the video
|
| 405 |
+
|
| 406 |
+
"""
|
| 407 |
+
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
|
| 408 |
+
_log_api_usage_once(read_video_timestamps)
|
| 409 |
+
from torchvision import get_video_backend
|
| 410 |
+
|
| 411 |
+
if get_video_backend() != "pyav":
|
| 412 |
+
return _video_opt._read_video_timestamps(filename, pts_unit)
|
| 413 |
+
|
| 414 |
+
_check_av_available()
|
| 415 |
+
|
| 416 |
+
video_fps = None
|
| 417 |
+
pts = []
|
| 418 |
+
|
| 419 |
+
try:
|
| 420 |
+
with av.open(filename, metadata_errors="ignore") as container:
|
| 421 |
+
if container.streams.video:
|
| 422 |
+
video_stream = container.streams.video[0]
|
| 423 |
+
video_time_base = video_stream.time_base
|
| 424 |
+
try:
|
| 425 |
+
pts = _decode_video_timestamps(container)
|
| 426 |
+
except av.AVError:
|
| 427 |
+
warnings.warn(f"Failed decoding frames for file {filename}")
|
| 428 |
+
video_fps = float(video_stream.average_rate)
|
| 429 |
+
except av.AVError as e:
|
| 430 |
+
msg = f"Failed to open container for {filename}; Caught error: {e}"
|
| 431 |
+
warnings.warn(msg, RuntimeWarning)
|
| 432 |
+
|
| 433 |
+
pts.sort()
|
| 434 |
+
|
| 435 |
+
if pts_unit == "sec":
|
| 436 |
+
pts = [x * video_time_base for x in pts]
|
| 437 |
+
|
| 438 |
+
return pts, video_fps
|
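Note: a minimal usage sketch for the decoding helpers above (not part of the diff). The path "clip.mp4" is a placeholder, and the calls assume the public torchvision.io entry points that wrap this module.

# Decode timestamps and a short segment with the pyav backend shown above.
import torch
import torchvision

torchvision.set_video_backend("pyav")

# Presentation timestamps of every frame, plus the average frame rate.
pts, fps = torchvision.io.read_video_timestamps("clip.mp4", pts_unit="sec")

# Decode the first two seconds; vframes is [T, C, H, W] with output_format="TCHW".
vframes, aframes, info = torchvision.io.read_video(
    "clip.mp4", start_pts=0, end_pts=2, pts_unit="sec", output_format="TCHW"
)
print(len(pts), fps, vframes.shape, info)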
.venv/lib/python3.11/site-packages/torchvision/io/video_reader.py
ADDED
|
@@ -0,0 +1,294 @@
import io
import warnings

from typing import Any, Dict, Iterator

import torch

from ..utils import _log_api_usage_once

from ._video_opt import _HAS_CPU_VIDEO_DECODER

if _HAS_CPU_VIDEO_DECODER:

    def _has_video_opt() -> bool:
        return True

else:

    def _has_video_opt() -> bool:
        return False


try:
    import av

    av.logging.set_level(av.logging.ERROR)
    if not hasattr(av.video.frame.VideoFrame, "pict_type"):
        av = ImportError(
            """\
Your version of PyAV is too old for the necessary video operations in torchvision.
If you are on Python 3.5, you will have to build from source (the conda-forge
packages are not up-to-date). See
https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
"""
        )
except ImportError:
    av = ImportError(
        """\
PyAV is not installed, and is necessary for the video operations in torchvision.
See https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
"""
    )


class VideoReader:
    """
    Fine-grained video-reading API.
    Supports frame-by-frame reading of various streams from a single video
    container. Much like previous video_reader API it supports the following
    backends: video_reader, pyav, and cuda.
    Backends can be set via `torchvision.set_video_backend` function.

    .. warning::

        In the near future, we intend to centralize PyTorch's video decoding
        capabilities within the `torchcodec
        <https://github.com/pytorch/torchcodec>`_ project. We encourage you to
        try it out and share your feedback, as the torchvision video decoders
        will eventually be deprecated.

    .. betastatus:: VideoReader class

    Example:
        The following examples creates a :mod:`VideoReader` object, seeks into 2s
        point, and returns a single frame::

            import torchvision
            video_path = "path_to_a_test_video"
            reader = torchvision.io.VideoReader(video_path, "video")
            reader.seek(2.0)
            frame = next(reader)

        :mod:`VideoReader` implements the iterable API, which makes it suitable to
        using it in conjunction with :mod:`itertools` for more advanced reading.
        As such, we can use a :mod:`VideoReader` instance inside for loops::

            reader.seek(2)
            for frame in reader:
                frames.append(frame['data'])
            # additionally, `seek` implements a fluent API, so we can do
            for frame in reader.seek(2):
                frames.append(frame['data'])

        With :mod:`itertools`, we can read all frames between 2 and 5 seconds with the
        following code::

            for frame in itertools.takewhile(lambda x: x['pts'] <= 5, reader.seek(2)):
                frames.append(frame['data'])

        and similarly, reading 10 frames after the 2s timestamp can be achieved
        as follows::

            for frame in itertools.islice(reader.seek(2), 10):
                frames.append(frame['data'])

    .. note::

        Each stream descriptor consists of two parts: stream type (e.g. 'video') and
        a unique stream id (which are determined by the video encoding).
        In this way, if the video container contains multiple
        streams of the same type, users can access the one they want.
        If only stream type is passed, the decoder auto-detects first stream of that type.

    Args:
        src (string, bytes object, or tensor): The media source.
            If string-type, it must be a file path supported by FFMPEG.
            If bytes, should be an in-memory representation of a file supported by FFMPEG.
            If Tensor, it is interpreted internally as byte buffer.
            It must be one-dimensional, of type ``torch.uint8``.

        stream (string, optional): descriptor of the required stream, followed by the stream id,
            in the format ``{stream_type}:{stream_id}``. Defaults to ``"video:0"``.
            Currently available options include ``['video', 'audio']``

        num_threads (int, optional): number of threads used by the codec to decode video.
            Default value (0) enables multithreading with codec-dependent heuristic. The performance
            will depend on the version of FFMPEG codecs supported.
    """

    def __init__(
        self,
        src: str,
        stream: str = "video",
        num_threads: int = 0,
    ) -> None:
        _log_api_usage_once(self)
        from .. import get_video_backend

        self.backend = get_video_backend()
        if isinstance(src, str):
            if not src:
                raise ValueError("src cannot be empty")
        elif isinstance(src, bytes):
            if self.backend in ["cuda"]:
                raise RuntimeError(
                    "VideoReader cannot be initialized from bytes object when using cuda or pyav backend."
                )
            elif self.backend == "pyav":
                src = io.BytesIO(src)
            else:
                with warnings.catch_warnings():
                    # Ignore the warning because we actually don't modify the buffer in this function
                    warnings.filterwarnings("ignore", message="The given buffer is not writable")
                    src = torch.frombuffer(src, dtype=torch.uint8)
        elif isinstance(src, torch.Tensor):
            if self.backend in ["cuda", "pyav"]:
                raise RuntimeError(
                    "VideoReader cannot be initialized from Tensor object when using cuda or pyav backend."
                )
        else:
            raise ValueError(f"src must be either string, Tensor or bytes object. Got {type(src)}")

        if self.backend == "cuda":
            device = torch.device("cuda")
            self._c = torch.classes.torchvision.GPUDecoder(src, device)

        elif self.backend == "video_reader":
            if isinstance(src, str):
                self._c = torch.classes.torchvision.Video(src, stream, num_threads)
            elif isinstance(src, torch.Tensor):
                self._c = torch.classes.torchvision.Video("", "", 0)
                self._c.init_from_memory(src, stream, num_threads)

        elif self.backend == "pyav":
            self.container = av.open(src, metadata_errors="ignore")
            # TODO: load metadata
            stream_type = stream.split(":")[0]
            stream_id = 0 if len(stream.split(":")) == 1 else int(stream.split(":")[1])
            self.pyav_stream = {stream_type: stream_id}
            self._c = self.container.decode(**self.pyav_stream)

            # TODO: add extradata exception

        else:
            raise RuntimeError("Unknown video backend: {}".format(self.backend))

    def __next__(self) -> Dict[str, Any]:
        """Decodes and returns the next frame of the current stream.
        Frames are encoded as a dict with mandatory
        data and pts fields, where data is a tensor, and pts is a
        presentation timestamp of the frame expressed in seconds
        as a float.

        Returns:
            (dict): a dictionary and containing decoded frame (``data``)
            and corresponding timestamp (``pts``) in seconds

        """
        if self.backend == "cuda":
            frame = self._c.next()
            if frame.numel() == 0:
                raise StopIteration
            return {"data": frame, "pts": None}
        elif self.backend == "video_reader":
            frame, pts = self._c.next()
        else:
            try:
                frame = next(self._c)
                pts = float(frame.pts * frame.time_base)
                if "video" in self.pyav_stream:
                    frame = torch.as_tensor(frame.to_rgb().to_ndarray()).permute(2, 0, 1)
                elif "audio" in self.pyav_stream:
                    frame = torch.as_tensor(frame.to_ndarray()).permute(1, 0)
                else:
                    frame = None
            except av.error.EOFError:
                raise StopIteration

        if frame.numel() == 0:
            raise StopIteration

        return {"data": frame, "pts": pts}

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        return self

    def seek(self, time_s: float, keyframes_only: bool = False) -> "VideoReader":
        """Seek within current stream.

        Args:
            time_s (float): seek time in seconds
            keyframes_only (bool): allow to seek only to keyframes

        .. note::
            Current implementation is the so-called precise seek. This
            means following seek, call to :mod:`next()` will return the
            frame with the exact timestamp if it exists or
            the first frame with timestamp larger than ``time_s``.
        """
        if self.backend in ["cuda", "video_reader"]:
            self._c.seek(time_s, keyframes_only)
        else:
            # handle special case as pyav doesn't catch it
            if time_s < 0:
                time_s = 0
            temp_str = self.container.streams.get(**self.pyav_stream)[0]
            offset = int(round(time_s / temp_str.time_base))
            if not keyframes_only:
                warnings.warn("Accurate seek is not implemented for pyav backend")
            self.container.seek(offset, backward=True, any_frame=False, stream=temp_str)
            self._c = self.container.decode(**self.pyav_stream)
        return self

    def get_metadata(self) -> Dict[str, Any]:
        """Returns video metadata

        Returns:
            (dict): dictionary containing duration and frame rate for every stream
        """
        if self.backend == "pyav":
            metadata = {}  # type: Dict[str, Any]
            for stream in self.container.streams:
                if stream.type not in metadata:
                    if stream.type == "video":
                        rate_n = "fps"
                    else:
                        rate_n = "framerate"
                    metadata[stream.type] = {rate_n: [], "duration": []}

                rate = getattr(stream, "average_rate", None) or stream.sample_rate

                metadata[stream.type]["duration"].append(float(stream.duration * stream.time_base))
                metadata[stream.type][rate_n].append(float(rate))
            return metadata
        return self._c.get_metadata()

    def set_current_stream(self, stream: str) -> bool:
        """Set current stream.
        Explicitly define the stream we are operating on.

        Args:
            stream (string): descriptor of the required stream. Defaults to ``"video:0"``
                Currently available stream types include ``['video', 'audio']``.
                Each descriptor consists of two parts: stream type (e.g. 'video') and
                a unique stream id (which are determined by video encoding).
                In this way, if the video container contains multiple
                streams of the same type, users can access the one they want.
                If only stream type is passed, the decoder auto-detects first stream
                of that type and returns it.

        Returns:
            (bool): True on success, False otherwise
        """
        if self.backend == "cuda":
            warnings.warn("GPU decoding only works with video stream.")
        if self.backend == "pyav":
            stream_type = stream.split(":")[0]
            stream_id = 0 if len(stream.split(":")) == 1 else int(stream.split(":")[1])
            self.pyav_stream = {stream_type: stream_id}
            self._c = self.container.decode(**self.pyav_stream)
            return True
        return self._c.set_current_stream(stream)
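Note: a short sketch of the metadata and stream-selection API documented above (not part of the diff). The path "clip.mp4" is a placeholder; availability of an audio stream depends on the file.

import torchvision

reader = torchvision.io.VideoReader("clip.mp4", "video")
meta = reader.get_metadata()      # e.g. {"video": {"fps": [...], "duration": [...]}, ...}
print(meta)

if "audio" in meta:
    reader.set_current_stream("audio:0")
    sample = next(reader)         # {"data": Tensor, "pts": float}
    print(sample["data"].shape, sample["pts"])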
.venv/lib/python3.11/site-packages/torchvision/models/detection/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
from .faster_rcnn import *
from .fcos import *
from .keypoint_rcnn import *
from .mask_rcnn import *
from .retinanet import *
from .ssd import *
from .ssdlite import *
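Note: the star-imports above re-export the detection model builders defined in the imported modules; a brief hedged sketch of how they are typically consumed (builder names come from those modules, weights are left unset here):

from torchvision.models.detection import fasterrcnn_resnet50_fpn, retinanet_resnet50_fpn

model = fasterrcnn_resnet50_fpn(weights=None, num_classes=91)
model.eval()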
.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (414 Bytes).
.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/_utils.cpython-311.pyc
ADDED
Binary file (28.3 kB).
.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/anchor_utils.cpython-311.pyc
ADDED
Binary file (18.5 kB).
.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/backbone_utils.cpython-311.pyc
ADDED
Binary file (13.9 kB).
.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/faster_rcnn.cpython-311.pyc
ADDED
Binary file (39.1 kB).
.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/fcos.cpython-311.pyc
ADDED
Binary file (42.6 kB).
.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/generalized_rcnn.cpython-311.pyc
ADDED
Binary file (6.29 kB).
.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/image_list.cpython-311.pyc
ADDED
Binary file (1.65 kB).
.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/keypoint_rcnn.cpython-311.pyc
ADDED
Binary file (23.6 kB).
.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/mask_rcnn.cpython-311.pyc
ADDED
Binary file (28 kB).
.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/retinanet.cpython-311.pyc
ADDED
Binary file (42.7 kB).
.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/roi_heads.cpython-311.pyc
ADDED
Binary file (45.6 kB).
.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/rpn.cpython-311.pyc
ADDED
Binary file (20.1 kB).
.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/ssd.cpython-311.pyc
ADDED
Binary file (37.2 kB).
.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/ssdlite.cpython-311.pyc
ADDED
Binary file (18.2 kB).
.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/transform.cpython-311.pyc
ADDED
Binary file (19.6 kB).
.venv/lib/python3.11/site-packages/torchvision/models/detection/_utils.py
ADDED
|
@@ -0,0 +1,540 @@
import math
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torchvision.ops import complete_box_iou_loss, distance_box_iou_loss, FrozenBatchNorm2d, generalized_box_iou_loss


class BalancedPositiveNegativeSampler:
    """
    This class samples batches, ensuring that they contain a fixed proportion of positives
    """

    def __init__(self, batch_size_per_image: int, positive_fraction: float) -> None:
        """
        Args:
            batch_size_per_image (int): number of elements to be selected per image
            positive_fraction (float): percentage of positive elements per batch
        """
        self.batch_size_per_image = batch_size_per_image
        self.positive_fraction = positive_fraction

    def __call__(self, matched_idxs: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
        """
        Args:
            matched_idxs: list of tensors containing -1, 0 or positive values.
                Each tensor corresponds to a specific image.
                -1 values are ignored, 0 are considered as negatives and > 0 as
                positives.

        Returns:
            pos_idx (list[tensor])
            neg_idx (list[tensor])

        Returns two lists of binary masks for each image.
        The first list contains the positive elements that were selected,
        and the second list the negative example.
        """
        pos_idx = []
        neg_idx = []
        for matched_idxs_per_image in matched_idxs:
            positive = torch.where(matched_idxs_per_image >= 1)[0]
            negative = torch.where(matched_idxs_per_image == 0)[0]

            num_pos = int(self.batch_size_per_image * self.positive_fraction)
            # protect against not enough positive examples
            num_pos = min(positive.numel(), num_pos)
            num_neg = self.batch_size_per_image - num_pos
            # protect against not enough negative examples
            num_neg = min(negative.numel(), num_neg)

            # randomly select positive and negative examples
            perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
            perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]

            pos_idx_per_image = positive[perm1]
            neg_idx_per_image = negative[perm2]

            # create binary mask from indices
            pos_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8)
            neg_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8)

            pos_idx_per_image_mask[pos_idx_per_image] = 1
            neg_idx_per_image_mask[neg_idx_per_image] = 1

            pos_idx.append(pos_idx_per_image_mask)
            neg_idx.append(neg_idx_per_image_mask)

        return pos_idx, neg_idx


@torch.jit._script_if_tracing
def encode_boxes(reference_boxes: Tensor, proposals: Tensor, weights: Tensor) -> Tensor:
    """
    Encode a set of proposals with respect to some
    reference boxes

    Args:
        reference_boxes (Tensor): reference boxes
        proposals (Tensor): boxes to be encoded
        weights (Tensor[4]): the weights for ``(x, y, w, h)``
    """

    # perform some unpacking to make it JIT-fusion friendly
    wx = weights[0]
    wy = weights[1]
    ww = weights[2]
    wh = weights[3]

    proposals_x1 = proposals[:, 0].unsqueeze(1)
    proposals_y1 = proposals[:, 1].unsqueeze(1)
    proposals_x2 = proposals[:, 2].unsqueeze(1)
    proposals_y2 = proposals[:, 3].unsqueeze(1)

    reference_boxes_x1 = reference_boxes[:, 0].unsqueeze(1)
    reference_boxes_y1 = reference_boxes[:, 1].unsqueeze(1)
    reference_boxes_x2 = reference_boxes[:, 2].unsqueeze(1)
    reference_boxes_y2 = reference_boxes[:, 3].unsqueeze(1)

    # implementation starts here
    ex_widths = proposals_x2 - proposals_x1
    ex_heights = proposals_y2 - proposals_y1
    ex_ctr_x = proposals_x1 + 0.5 * ex_widths
    ex_ctr_y = proposals_y1 + 0.5 * ex_heights

    gt_widths = reference_boxes_x2 - reference_boxes_x1
    gt_heights = reference_boxes_y2 - reference_boxes_y1
    gt_ctr_x = reference_boxes_x1 + 0.5 * gt_widths
    gt_ctr_y = reference_boxes_y1 + 0.5 * gt_heights

    targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths
    targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights
    targets_dw = ww * torch.log(gt_widths / ex_widths)
    targets_dh = wh * torch.log(gt_heights / ex_heights)

    targets = torch.cat((targets_dx, targets_dy, targets_dw, targets_dh), dim=1)
    return targets


class BoxCoder:
    """
    This class encodes and decodes a set of bounding boxes into
    the representation used for training the regressors.
    """

    def __init__(
        self, weights: Tuple[float, float, float, float], bbox_xform_clip: float = math.log(1000.0 / 16)
    ) -> None:
        """
        Args:
            weights (4-element tuple)
            bbox_xform_clip (float)
        """
        self.weights = weights
        self.bbox_xform_clip = bbox_xform_clip

    def encode(self, reference_boxes: List[Tensor], proposals: List[Tensor]) -> List[Tensor]:
        boxes_per_image = [len(b) for b in reference_boxes]
        reference_boxes = torch.cat(reference_boxes, dim=0)
        proposals = torch.cat(proposals, dim=0)
        targets = self.encode_single(reference_boxes, proposals)
        return targets.split(boxes_per_image, 0)

    def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
        """
        Encode a set of proposals with respect to some
        reference boxes

        Args:
            reference_boxes (Tensor): reference boxes
            proposals (Tensor): boxes to be encoded
        """
        dtype = reference_boxes.dtype
        device = reference_boxes.device
        weights = torch.as_tensor(self.weights, dtype=dtype, device=device)
        targets = encode_boxes(reference_boxes, proposals, weights)

        return targets

    def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
        torch._assert(
            isinstance(boxes, (list, tuple)),
            "This function expects boxes of type list or tuple.",
        )
        torch._assert(
            isinstance(rel_codes, torch.Tensor),
            "This function expects rel_codes of type torch.Tensor.",
        )
        boxes_per_image = [b.size(0) for b in boxes]
        concat_boxes = torch.cat(boxes, dim=0)
        box_sum = 0
        for val in boxes_per_image:
            box_sum += val
        if box_sum > 0:
            rel_codes = rel_codes.reshape(box_sum, -1)
        pred_boxes = self.decode_single(rel_codes, concat_boxes)
        if box_sum > 0:
            pred_boxes = pred_boxes.reshape(box_sum, -1, 4)
        return pred_boxes

    def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
        """
        From a set of original boxes and encoded relative box offsets,
        get the decoded boxes.

        Args:
            rel_codes (Tensor): encoded boxes
            boxes (Tensor): reference boxes.
        """

        boxes = boxes.to(rel_codes.dtype)

        widths = boxes[:, 2] - boxes[:, 0]
        heights = boxes[:, 3] - boxes[:, 1]
        ctr_x = boxes[:, 0] + 0.5 * widths
        ctr_y = boxes[:, 1] + 0.5 * heights

        wx, wy, ww, wh = self.weights
        dx = rel_codes[:, 0::4] / wx
        dy = rel_codes[:, 1::4] / wy
        dw = rel_codes[:, 2::4] / ww
        dh = rel_codes[:, 3::4] / wh

        # Prevent sending too large values into torch.exp()
        dw = torch.clamp(dw, max=self.bbox_xform_clip)
        dh = torch.clamp(dh, max=self.bbox_xform_clip)

        pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
        pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
        pred_w = torch.exp(dw) * widths[:, None]
        pred_h = torch.exp(dh) * heights[:, None]

        # Distance from center to box's corner.
        c_to_c_h = torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h
        c_to_c_w = torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w

        pred_boxes1 = pred_ctr_x - c_to_c_w
        pred_boxes2 = pred_ctr_y - c_to_c_h
        pred_boxes3 = pred_ctr_x + c_to_c_w
        pred_boxes4 = pred_ctr_y + c_to_c_h
        pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=2).flatten(1)
        return pred_boxes


class BoxLinearCoder:
    """
    The linear box-to-box transform defined in FCOS. The transformation is parameterized
    by the distance from the center of (square) src box to 4 edges of the target box.
    """

    def __init__(self, normalize_by_size: bool = True) -> None:
        """
        Args:
            normalize_by_size (bool): normalize deltas by the size of src (anchor) boxes.
        """
        self.normalize_by_size = normalize_by_size

    def encode(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
        """
        Encode a set of proposals with respect to some reference boxes

        Args:
            reference_boxes (Tensor): reference boxes
            proposals (Tensor): boxes to be encoded

        Returns:
            Tensor: the encoded relative box offsets that can be used to
            decode the boxes.

        """

        # get the center of reference_boxes
        reference_boxes_ctr_x = 0.5 * (reference_boxes[..., 0] + reference_boxes[..., 2])
        reference_boxes_ctr_y = 0.5 * (reference_boxes[..., 1] + reference_boxes[..., 3])

        # get box regression transformation deltas
        target_l = reference_boxes_ctr_x - proposals[..., 0]
        target_t = reference_boxes_ctr_y - proposals[..., 1]
        target_r = proposals[..., 2] - reference_boxes_ctr_x
        target_b = proposals[..., 3] - reference_boxes_ctr_y

        targets = torch.stack((target_l, target_t, target_r, target_b), dim=-1)

        if self.normalize_by_size:
            reference_boxes_w = reference_boxes[..., 2] - reference_boxes[..., 0]
            reference_boxes_h = reference_boxes[..., 3] - reference_boxes[..., 1]
            reference_boxes_size = torch.stack(
                (reference_boxes_w, reference_boxes_h, reference_boxes_w, reference_boxes_h), dim=-1
            )
            targets = targets / reference_boxes_size
        return targets

    def decode(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:

        """
        From a set of original boxes and encoded relative box offsets,
        get the decoded boxes.

        Args:
            rel_codes (Tensor): encoded boxes
            boxes (Tensor): reference boxes.

        Returns:
            Tensor: the predicted boxes with the encoded relative box offsets.

        .. note::
            This method assumes that ``rel_codes`` and ``boxes`` have same size for 0th dimension. i.e. ``len(rel_codes) == len(boxes)``.

        """

        boxes = boxes.to(dtype=rel_codes.dtype)

        ctr_x = 0.5 * (boxes[..., 0] + boxes[..., 2])
        ctr_y = 0.5 * (boxes[..., 1] + boxes[..., 3])

        if self.normalize_by_size:
            boxes_w = boxes[..., 2] - boxes[..., 0]
            boxes_h = boxes[..., 3] - boxes[..., 1]

            list_box_size = torch.stack((boxes_w, boxes_h, boxes_w, boxes_h), dim=-1)
            rel_codes = rel_codes * list_box_size

        pred_boxes1 = ctr_x - rel_codes[..., 0]
        pred_boxes2 = ctr_y - rel_codes[..., 1]
        pred_boxes3 = ctr_x + rel_codes[..., 2]
        pred_boxes4 = ctr_y + rel_codes[..., 3]

        pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=-1)
        return pred_boxes


class Matcher:
    """
    This class assigns to each predicted "element" (e.g., a box) a ground-truth
    element. Each predicted element will have exactly zero or one matches; each
    ground-truth element may be assigned to zero or more predicted elements.

    Matching is based on the MxN match_quality_matrix, that characterizes how well
    each (ground-truth, predicted)-pair match. For example, if the elements are
    boxes, the matrix may contain box IoU overlap values.

    The matcher returns a tensor of size N containing the index of the ground-truth
    element m that matches to prediction n. If there is no match, a negative value
    is returned.
    """

    BELOW_LOW_THRESHOLD = -1
    BETWEEN_THRESHOLDS = -2

    __annotations__ = {
        "BELOW_LOW_THRESHOLD": int,
        "BETWEEN_THRESHOLDS": int,
    }

    def __init__(self, high_threshold: float, low_threshold: float, allow_low_quality_matches: bool = False) -> None:
        """
        Args:
            high_threshold (float): quality values greater than or equal to
                this value are candidate matches.
            low_threshold (float): a lower quality threshold used to stratify
                matches into three levels:
                1) matches >= high_threshold
                2) BETWEEN_THRESHOLDS matches in [low_threshold, high_threshold)
                3) BELOW_LOW_THRESHOLD matches in [0, low_threshold)
            allow_low_quality_matches (bool): if True, produce additional matches
                for predictions that have only low-quality match candidates. See
                set_low_quality_matches_ for more details.
        """
        self.BELOW_LOW_THRESHOLD = -1
        self.BETWEEN_THRESHOLDS = -2
        torch._assert(low_threshold <= high_threshold, "low_threshold should be <= high_threshold")
        self.high_threshold = high_threshold
        self.low_threshold = low_threshold
        self.allow_low_quality_matches = allow_low_quality_matches

    def __call__(self, match_quality_matrix: Tensor) -> Tensor:
        """
        Args:
            match_quality_matrix (Tensor[float]): an MxN tensor, containing the
                pairwise quality between M ground-truth elements and N predicted elements.

        Returns:
            matches (Tensor[int64]): an N tensor where N[i] is a matched gt in
                [0, M - 1] or a negative value indicating that prediction i could not
                be matched.
        """
        if match_quality_matrix.numel() == 0:
            # empty targets or proposals not supported during training
            if match_quality_matrix.shape[0] == 0:
                raise ValueError("No ground-truth boxes available for one of the images during training")
            else:
                raise ValueError("No proposal boxes available for one of the images during training")

        # match_quality_matrix is M (gt) x N (predicted)
        # Max over gt elements (dim 0) to find best gt candidate for each prediction
        matched_vals, matches = match_quality_matrix.max(dim=0)
        if self.allow_low_quality_matches:
            all_matches = matches.clone()
        else:
            all_matches = None  # type: ignore[assignment]

        # Assign candidate matches with low quality to negative (unassigned) values
        below_low_threshold = matched_vals < self.low_threshold
        between_thresholds = (matched_vals >= self.low_threshold) & (matched_vals < self.high_threshold)
        matches[below_low_threshold] = self.BELOW_LOW_THRESHOLD
        matches[between_thresholds] = self.BETWEEN_THRESHOLDS

        if self.allow_low_quality_matches:
            if all_matches is None:
                torch._assert(False, "all_matches should not be None")
            else:
                self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)

        return matches

    def set_low_quality_matches_(self, matches: Tensor, all_matches: Tensor, match_quality_matrix: Tensor) -> None:
        """
        Produce additional matches for predictions that have only low-quality matches.
        Specifically, for each ground-truth find the set of predictions that have
        maximum overlap with it (including ties); for each prediction in that set, if
        it is unmatched, then match it to the ground-truth with which it has the highest
        quality value.
        """
        # For each gt, find the prediction with which it has the highest quality
        highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
        # Find the highest quality match available, even if it is low, including ties
        gt_pred_pairs_of_highest_quality = torch.where(match_quality_matrix == highest_quality_foreach_gt[:, None])
        # Example gt_pred_pairs_of_highest_quality:
        # (tensor([0, 1, 1, 2, 2, 3, 3, 4, 5, 5]),
        #  tensor([39796, 32055, 32070, 39190, 40255, 40390, 41455, 45470, 45325, 46390]))
        # Each element in the first tensor is a gt index, and each element in second tensor is a prediction index
        # Note how gt items 1, 2, 3, and 5 each have two ties

        pred_inds_to_update = gt_pred_pairs_of_highest_quality[1]
        matches[pred_inds_to_update] = all_matches[pred_inds_to_update]


class SSDMatcher(Matcher):
    def __init__(self, threshold: float) -> None:
        super().__init__(threshold, threshold, allow_low_quality_matches=False)

    def __call__(self, match_quality_matrix: Tensor) -> Tensor:
        matches = super().__call__(match_quality_matrix)

        # For each gt, find the prediction with which it has the highest quality
        _, highest_quality_pred_foreach_gt = match_quality_matrix.max(dim=1)
        matches[highest_quality_pred_foreach_gt] = torch.arange(
            highest_quality_pred_foreach_gt.size(0), dtype=torch.int64, device=highest_quality_pred_foreach_gt.device
        )

        return matches


def overwrite_eps(model: nn.Module, eps: float) -> None:
    """
    This method overwrites the default eps values of all the
    FrozenBatchNorm2d layers of the model with the provided value.
    This is necessary to address the BC-breaking change introduced
    by the bug-fix at pytorch/vision#2933. The overwrite is applied
    only when the pretrained weights are loaded to maintain compatibility
    with previous versions.

    Args:
        model (nn.Module): The model on which we perform the overwrite.
        eps (float): The new value of eps.
    """
    for module in model.modules():
        if isinstance(module, FrozenBatchNorm2d):
            module.eps = eps


def retrieve_out_channels(model: nn.Module, size: Tuple[int, int]) -> List[int]:
    """
    This method retrieves the number of output channels of a specific model.

    Args:
        model (nn.Module): The model for which we estimate the out_channels.
            It should return a single Tensor or an OrderedDict[Tensor].
        size (Tuple[int, int]): The size (wxh) of the input.

    Returns:
        out_channels (List[int]): A list of the output channels of the model.
    """
    in_training = model.training
    model.eval()

    with torch.no_grad():
        # Use dummy data to retrieve the feature map sizes to avoid hard-coding their values
        device = next(model.parameters()).device
        tmp_img = torch.zeros((1, 3, size[1], size[0]), device=device)
        features = model(tmp_img)
        if isinstance(features, torch.Tensor):
            features = OrderedDict([("0", features)])
        out_channels = [x.size(1) for x in features.values()]

    if in_training:
        model.train()

    return out_channels


@torch.jit.unused
def _fake_cast_onnx(v: Tensor) -> int:
    return v  # type: ignore[return-value]


def _topk_min(input: Tensor, orig_kval: int, axis: int) -> int:
    """
    ONNX spec requires the k-value to be less than or equal to the number of inputs along
    provided dim. Certain models use the number of elements along a particular axis instead of K
    if K exceeds the number of elements along that axis. Previously, python's min() function was
    used to determine whether to use the provided k-value or the specified dim axis value.

    However, in cases where the model is being exported in tracing mode, python min() is
    static causing the model to be traced incorrectly and eventually fail at the topk node.
    In order to avoid this situation, in tracing mode, torch.min() is used instead.

    Args:
        input (Tensor): The original input tensor.
        orig_kval (int): The provided k-value.
        axis(int): Axis along which we retrieve the input size.

    Returns:
        min_kval (int): Appropriately selected k-value.
    """
    if not torch.jit.is_tracing():
        return min(orig_kval, input.size(axis))
    axis_dim_val = torch._shape_as_tensor(input)[axis].unsqueeze(0)
    min_kval = torch.min(torch.cat((torch.tensor([orig_kval], dtype=axis_dim_val.dtype), axis_dim_val), 0))
    return _fake_cast_onnx(min_kval)


def _box_loss(
    type: str,
    box_coder: BoxCoder,
    anchors_per_image: Tensor,
    matched_gt_boxes_per_image: Tensor,
    bbox_regression_per_image: Tensor,
    cnf: Optional[Dict[str, float]] = None,
) -> Tensor:
    torch._assert(type in ["l1", "smooth_l1", "ciou", "diou", "giou"], f"Unsupported loss: {type}")

    if type == "l1":
        target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
        return F.l1_loss(bbox_regression_per_image, target_regression, reduction="sum")
    elif type == "smooth_l1":
        target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
        beta = cnf["beta"] if cnf is not None and "beta" in cnf else 1.0
        return F.smooth_l1_loss(bbox_regression_per_image, target_regression, reduction="sum", beta=beta)
    else:
        bbox_per_image = box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
        eps = cnf["eps"] if cnf is not None and "eps" in cnf else 1e-7
        if type == "ciou":
            return complete_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
        if type == "diou":
            return distance_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
        # otherwise giou
        return generalized_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
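Note: a small self-contained sketch of the helpers above (not part of the diff): a BoxCoder encode/decode round-trip and a Matcher assignment driven by an IoU matrix. It imports directly from the private module shown here, and the box values are made-up toy data.

import torch
from torchvision.models.detection._utils import BoxCoder, Matcher
from torchvision.ops import box_iou

anchors = torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 20.0, 20.0]])
gt = torch.tensor([[1.0, 1.0, 11.0, 11.0]])

coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
deltas = coder.encode_single(gt.expand(2, 4), anchors)   # one regression target per anchor
decoded = coder.decode_single(deltas, anchors)           # decoding the deltas recovers the targets
print(torch.allclose(decoded, gt.expand(2, 4), atol=1e-4))

matcher = Matcher(high_threshold=0.5, low_threshold=0.3, allow_low_quality_matches=True)
matches = matcher(box_iou(gt, anchors))                  # M x N quality matrix (M gt rows, N anchors)
print(matches)                                           # >= 0 means matched gt index, -1 / -2 otherwise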
.venv/lib/python3.11/site-packages/torchvision/models/detection/anchor_utils.py
ADDED
|
@@ -0,0 +1,268 @@
| 1 |
+
import math
|
| 2 |
+
from typing import List, Optional
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn, Tensor
|
| 6 |
+
|
| 7 |
+
from .image_list import ImageList
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class AnchorGenerator(nn.Module):
|
| 11 |
+
"""
|
| 12 |
+
Module that generates anchors for a set of feature maps and
|
| 13 |
+
image sizes.
|
| 14 |
+
|
| 15 |
+
The module support computing anchors at multiple sizes and aspect ratios
|
| 16 |
+
per feature map. This module assumes aspect ratio = height / width for
|
| 17 |
+
each anchor.
|
| 18 |
+
|
| 19 |
+
sizes and aspect_ratios should have the same number of elements, and it should
|
| 20 |
+
correspond to the number of feature maps.
|
| 21 |
+
|
| 22 |
+
sizes[i] and aspect_ratios[i] can have an arbitrary number of elements,
|
| 23 |
+
and AnchorGenerator will output a set of sizes[i] * aspect_ratios[i] anchors
|
| 24 |
+
per spatial location for feature map i.
|
| 25 |
+
|
| 26 |
+
Args:
|
| 27 |
+
sizes (Tuple[Tuple[int]]):
|
| 28 |
+
aspect_ratios (Tuple[Tuple[float]]):
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
__annotations__ = {
|
| 32 |
+
"cell_anchors": List[torch.Tensor],
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
def __init__(
|
| 36 |
+
self,
|
| 37 |
+
sizes=((128, 256, 512),),
|
| 38 |
+
aspect_ratios=((0.5, 1.0, 2.0),),
|
| 39 |
+
):
|
| 40 |
+
super().__init__()
|
| 41 |
+
|
| 42 |
+
if not isinstance(sizes[0], (list, tuple)):
|
| 43 |
+
# TODO change this
|
| 44 |
+
sizes = tuple((s,) for s in sizes)
|
| 45 |
+
if not isinstance(aspect_ratios[0], (list, tuple)):
|
| 46 |
+
aspect_ratios = (aspect_ratios,) * len(sizes)
|
| 47 |
+
|
| 48 |
+
self.sizes = sizes
|
| 49 |
+
self.aspect_ratios = aspect_ratios
|
| 50 |
+
self.cell_anchors = [
|
| 51 |
+
self.generate_anchors(size, aspect_ratio) for size, aspect_ratio in zip(sizes, aspect_ratios)
|
| 52 |
+
]
|
| 53 |
+
|
| 54 |
+
# TODO: https://github.com/pytorch/pytorch/issues/26792
|
| 55 |
+
# For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
|
| 56 |
+
# (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios)
|
| 57 |
+
# This method assumes aspect ratio = height / width for an anchor.
|
| 58 |
+
def generate_anchors(
|
| 59 |
+
self,
|
| 60 |
+
scales: List[int],
|
| 61 |
+
aspect_ratios: List[float],
|
| 62 |
+
dtype: torch.dtype = torch.float32,
|
| 63 |
+
device: torch.device = torch.device("cpu"),
|
| 64 |
+
) -> Tensor:
|
| 65 |
+
scales = torch.as_tensor(scales, dtype=dtype, device=device)
|
| 66 |
+
aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
|
| 67 |
+
h_ratios = torch.sqrt(aspect_ratios)
|
| 68 |
+
w_ratios = 1 / h_ratios
|
| 69 |
+
|
| 70 |
+
ws = (w_ratios[:, None] * scales[None, :]).view(-1)
|
| 71 |
+
hs = (h_ratios[:, None] * scales[None, :]).view(-1)
|
| 72 |
+
|
| 73 |
+
base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2
|
| 74 |
+
return base_anchors.round()
|
| 75 |
+
|
| 76 |
+
def set_cell_anchors(self, dtype: torch.dtype, device: torch.device):
|
| 77 |
+
self.cell_anchors = [cell_anchor.to(dtype=dtype, device=device) for cell_anchor in self.cell_anchors]
|
| 78 |
+
|
| 79 |
+
def num_anchors_per_location(self) -> List[int]:
|
| 80 |
+
return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]
|
| 81 |
+
|
| 82 |
+
# For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
|
| 83 |
+
# output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a.
|
| 84 |
+
def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]) -> List[Tensor]:
|
| 85 |
+
anchors = []
|
| 86 |
+
cell_anchors = self.cell_anchors
|
| 87 |
+
torch._assert(cell_anchors is not None, "cell_anchors should not be None")
|
| 88 |
+
torch._assert(
|
| 89 |
+
len(grid_sizes) == len(strides) == len(cell_anchors),
|
| 90 |
+
"Anchors should be Tuple[Tuple[int]] because each feature "
|
| 91 |
+
"map could potentially have different sizes and aspect ratios. "
|
| 92 |
+
"There needs to be a match between the number of "
|
| 93 |
+
"feature maps passed and the number of sizes / aspect ratios specified.",
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
for size, stride, base_anchors in zip(grid_sizes, strides, cell_anchors):
|
| 97 |
+
grid_height, grid_width = size
|
| 98 |
+
stride_height, stride_width = stride
|
| 99 |
+
device = base_anchors.device
|
| 100 |
+
|
| 101 |
+
# For output anchor, compute [x_center, y_center, x_center, y_center]
|
| 102 |
+
shifts_x = torch.arange(0, grid_width, dtype=torch.int32, device=device) * stride_width
|
| 103 |
+
shifts_y = torch.arange(0, grid_height, dtype=torch.int32, device=device) * stride_height
|
| 104 |
+
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
|
| 105 |
+
shift_x = shift_x.reshape(-1)
|
| 106 |
+
shift_y = shift_y.reshape(-1)
|
| 107 |
+
shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
|
| 108 |
+
|
| 109 |
+
# For every (base anchor, output anchor) pair,
|
| 110 |
+
# offset each zero-centered base anchor by the center of the output anchor.
|
| 111 |
+
anchors.append((shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4))
|
| 112 |
+
|
| 113 |
+
return anchors
|
| 114 |
+
|
| 115 |
+
def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]:
|
| 116 |
+
grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
|
| 117 |
+
image_size = image_list.tensors.shape[-2:]
|
| 118 |
+
dtype, device = feature_maps[0].dtype, feature_maps[0].device
|
| 119 |
+
strides = [
|
| 120 |
+
[
|
| 121 |
+
torch.empty((), dtype=torch.int64, device=device).fill_(image_size[0] // g[0]),
|
| 122 |
+
torch.empty((), dtype=torch.int64, device=device).fill_(image_size[1] // g[1]),
|
| 123 |
+
]
|
| 124 |
+
for g in grid_sizes
|
| 125 |
+
]
|
| 126 |
+
self.set_cell_anchors(dtype, device)
|
| 127 |
+
anchors_over_all_feature_maps = self.grid_anchors(grid_sizes, strides)
|
| 128 |
+
anchors: List[List[torch.Tensor]] = []
|
| 129 |
+
for _ in range(len(image_list.image_sizes)):
|
| 130 |
+
anchors_in_image = [anchors_per_feature_map for anchors_per_feature_map in anchors_over_all_feature_maps]
|
| 131 |
+
anchors.append(anchors_in_image)
|
| 132 |
+
anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]
|
| 133 |
+
return anchors
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class DefaultBoxGenerator(nn.Module):
    """
    This module generates the default boxes of SSD for a set of feature maps and image sizes.

    Args:
        aspect_ratios (List[List[int]]): A list with all the aspect ratios used in each feature map.
        min_ratio (float): The minimum scale :math:`\text{s}_{\text{min}}` of the default boxes used in the estimation
            of the scales of each feature map. It is used only if the ``scales`` parameter is not provided.
        max_ratio (float): The maximum scale :math:`\text{s}_{\text{max}}` of the default boxes used in the estimation
            of the scales of each feature map. It is used only if the ``scales`` parameter is not provided.
        scales (List[float], optional): The scales of the default boxes. If not provided it will be estimated using
            the ``min_ratio`` and ``max_ratio`` parameters.
        steps (List[int], optional): It's a hyper-parameter that affects the tiling of default boxes. If not provided
            it will be estimated from the data.
        clip (bool): Whether the standardized values of default boxes should be clipped between 0 and 1. The clipping
            is applied while the boxes are encoded in format ``(cx, cy, w, h)``.
    """

    def __init__(
        self,
        aspect_ratios: List[List[int]],
        min_ratio: float = 0.15,
        max_ratio: float = 0.9,
        scales: Optional[List[float]] = None,
        steps: Optional[List[int]] = None,
        clip: bool = True,
    ):
        super().__init__()
        if steps is not None and len(aspect_ratios) != len(steps):
            raise ValueError("aspect_ratios and steps should have the same length")
        self.aspect_ratios = aspect_ratios
        self.steps = steps
        self.clip = clip
        num_outputs = len(aspect_ratios)

        # Estimation of default boxes scales
        if scales is None:
            if num_outputs > 1:
                range_ratio = max_ratio - min_ratio
                self.scales = [min_ratio + range_ratio * k / (num_outputs - 1.0) for k in range(num_outputs)]
                self.scales.append(1.0)
            else:
                self.scales = [min_ratio, max_ratio]
        else:
            self.scales = scales

        self._wh_pairs = self._generate_wh_pairs(num_outputs)

    def _generate_wh_pairs(
        self, num_outputs: int, dtype: torch.dtype = torch.float32, device: torch.device = torch.device("cpu")
    ) -> List[Tensor]:
        _wh_pairs: List[Tensor] = []
        for k in range(num_outputs):
            # Adding the 2 default width-height pairs for aspect ratio 1 and scale s'k
            s_k = self.scales[k]
            s_prime_k = math.sqrt(self.scales[k] * self.scales[k + 1])
            wh_pairs = [[s_k, s_k], [s_prime_k, s_prime_k]]

            # Adding 2 pairs for each aspect ratio of the feature map k
            for ar in self.aspect_ratios[k]:
                sq_ar = math.sqrt(ar)
                w = self.scales[k] * sq_ar
                h = self.scales[k] / sq_ar
                wh_pairs.extend([[w, h], [h, w]])

            _wh_pairs.append(torch.as_tensor(wh_pairs, dtype=dtype, device=device))
        return _wh_pairs

    def num_anchors_per_location(self) -> List[int]:
        # Estimate num of anchors based on aspect ratios: 2 default boxes + 2 * ratios of feature map.
        return [2 + 2 * len(r) for r in self.aspect_ratios]

    # Default Boxes calculation based on page 6 of SSD paper
    def _grid_default_boxes(
        self, grid_sizes: List[List[int]], image_size: List[int], dtype: torch.dtype = torch.float32
    ) -> Tensor:
        default_boxes = []
        for k, f_k in enumerate(grid_sizes):
            # Now add the default boxes for each width-height pair
            if self.steps is not None:
                x_f_k = image_size[1] / self.steps[k]
                y_f_k = image_size[0] / self.steps[k]
            else:
                y_f_k, x_f_k = f_k

            shifts_x = ((torch.arange(0, f_k[1]) + 0.5) / x_f_k).to(dtype=dtype)
            shifts_y = ((torch.arange(0, f_k[0]) + 0.5) / y_f_k).to(dtype=dtype)
            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
            shift_x = shift_x.reshape(-1)
            shift_y = shift_y.reshape(-1)

            shifts = torch.stack((shift_x, shift_y) * len(self._wh_pairs[k]), dim=-1).reshape(-1, 2)
            # Clipping the default boxes while the boxes are encoded in format (cx, cy, w, h)
            _wh_pair = self._wh_pairs[k].clamp(min=0, max=1) if self.clip else self._wh_pairs[k]
            wh_pairs = _wh_pair.repeat((f_k[0] * f_k[1]), 1)

            default_box = torch.cat((shifts, wh_pairs), dim=1)

            default_boxes.append(default_box)

        return torch.cat(default_boxes, dim=0)

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"aspect_ratios={self.aspect_ratios}"
            f", clip={self.clip}"
            f", scales={self.scales}"
            f", steps={self.steps}"
            ")"
        )
        return s

    def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]:
        grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
        image_size = image_list.tensors.shape[-2:]
        dtype, device = feature_maps[0].dtype, feature_maps[0].device
        default_boxes = self._grid_default_boxes(grid_sizes, image_size, dtype=dtype)
        default_boxes = default_boxes.to(device)

        dboxes = []
        x_y_size = torch.tensor([image_size[1], image_size[0]], device=default_boxes.device)
        for _ in image_list.image_sizes:
            dboxes_in_image = default_boxes
            dboxes_in_image = torch.cat(
                [
                    (dboxes_in_image[:, :2] - 0.5 * dboxes_in_image[:, 2:]) * x_y_size,
                    (dboxes_in_image[:, :2] + 0.5 * dboxes_in_image[:, 2:]) * x_y_size,
                ],
                -1,
            )
            dboxes.append(dboxes_in_image)
        return dboxes
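As a quick check on the scale-estimation logic in DefaultBoxGenerator, the following minimal sketch (illustrative only, assuming the standard torchvision import path) shows the interpolated scales and per-location box counts for an SSD300-style configuration:

# Illustrative sketch, assuming torchvision.models.detection.anchor_utils is importable.
from torchvision.models.detection.anchor_utils import DefaultBoxGenerator

aspect_ratios = [[2], [2, 3], [2, 3], [2, 3], [2], [2]]
generator = DefaultBoxGenerator(aspect_ratios, min_ratio=0.15, max_ratio=0.9)
# Linear interpolation between min_ratio and max_ratio, plus the extra 1.0 appended at the end:
print(generator.scales)                      # ~[0.15, 0.3, 0.45, 0.6, 0.75, 0.9, 1.0]
print(generator.num_anchors_per_location())  # [4, 6, 6, 6, 4, 4]
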
.venv/lib/python3.11/site-packages/torchvision/models/detection/backbone_utils.py
ADDED
@@ -0,0 +1,244 @@
import warnings
from typing import Callable, Dict, List, Optional, Union

from torch import nn, Tensor
from torchvision.ops import misc as misc_nn_ops
from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool

from .. import mobilenet, resnet
from .._api import _get_enum_from_fn, WeightsEnum
from .._utils import handle_legacy_interface, IntermediateLayerGetter


class BackboneWithFPN(nn.Module):
    """
    Adds a FPN on top of a model.
    Internally, it uses torchvision.models._utils.IntermediateLayerGetter to
    extract a submodel that returns the feature maps specified in return_layers.
    The same limitations of IntermediateLayerGetter apply here.
    Args:
        backbone (nn.Module)
        return_layers (Dict[name, new_name]): a dict containing the names
            of the modules for which the activations will be returned as
            the key of the dict, and the value of the dict is the name
            of the returned activation (which the user can specify).
        in_channels_list (List[int]): number of channels for each feature map
            that is returned, in the order they are present in the OrderedDict
        out_channels (int): number of channels in the FPN.
        norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
    Attributes:
        out_channels (int): the number of channels in the FPN
    """

    def __init__(
        self,
        backbone: nn.Module,
        return_layers: Dict[str, str],
        in_channels_list: List[int],
        out_channels: int,
        extra_blocks: Optional[ExtraFPNBlock] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()

        if extra_blocks is None:
            extra_blocks = LastLevelMaxPool()

        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.fpn = FeaturePyramidNetwork(
            in_channels_list=in_channels_list,
            out_channels=out_channels,
            extra_blocks=extra_blocks,
            norm_layer=norm_layer,
        )
        self.out_channels = out_channels

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        x = self.body(x)
        x = self.fpn(x)
        return x


@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: _get_enum_from_fn(resnet.__dict__[kwargs["backbone_name"]])["IMAGENET1K_V1"],
    ),
)
def resnet_fpn_backbone(
    *,
    backbone_name: str,
    weights: Optional[WeightsEnum],
    norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
    trainable_layers: int = 3,
    returned_layers: Optional[List[int]] = None,
    extra_blocks: Optional[ExtraFPNBlock] = None,
) -> BackboneWithFPN:
    """
    Constructs a specified ResNet backbone with FPN on top. Freezes the specified number of layers in the backbone.

    Examples::

        >>> import torch
        >>> from torchvision.models import ResNet50_Weights
        >>> from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
        >>> backbone = resnet_fpn_backbone(backbone_name='resnet50', weights=ResNet50_Weights.DEFAULT, trainable_layers=3)
        >>> # get some dummy image
        >>> x = torch.rand(1,3,64,64)
        >>> # compute the output
        >>> output = backbone(x)
        >>> print([(k, v.shape) for k, v in output.items()])
        >>> # returns
        >>>   [('0', torch.Size([1, 256, 16, 16])),
        >>>    ('1', torch.Size([1, 256, 8, 8])),
        >>>    ('2', torch.Size([1, 256, 4, 4])),
        >>>    ('3', torch.Size([1, 256, 2, 2])),
        >>>    ('pool', torch.Size([1, 256, 1, 1]))]

    Args:
        backbone_name (string): resnet architecture. Possible values are 'resnet18', 'resnet34', 'resnet50',
            'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2'
        weights (WeightsEnum, optional): The pretrained weights for the model
        norm_layer (callable): it is recommended to use the default value. For details visit:
            (https://github.com/facebookresearch/maskrcnn-benchmark/issues/267)
        trainable_layers (int): number of trainable (not frozen) layers starting from final block.
            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
        returned_layers (list of int): The layers of the network to return. Each entry must be in ``[1, 4]``.
            By default, all layers are returned.
        extra_blocks (ExtraFPNBlock or None): if provided, extra operations will
            be performed. It is expected to take the fpn features, the original
            features and the names of the original features as input, and returns
            a new list of feature maps and their corresponding names. By
            default, a ``LastLevelMaxPool`` is used.
    """
    backbone = resnet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer)
    return _resnet_fpn_extractor(backbone, trainable_layers, returned_layers, extra_blocks)


def _resnet_fpn_extractor(
    backbone: resnet.ResNet,
    trainable_layers: int,
    returned_layers: Optional[List[int]] = None,
    extra_blocks: Optional[ExtraFPNBlock] = None,
    norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> BackboneWithFPN:

    # select layers that won't be frozen
    if trainable_layers < 0 or trainable_layers > 5:
        raise ValueError(f"Trainable layers should be in the range [0,5], got {trainable_layers}")
    layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
    if trainable_layers == 5:
        layers_to_train.append("bn1")
    for name, parameter in backbone.named_parameters():
        if all([not name.startswith(layer) for layer in layers_to_train]):
            parameter.requires_grad_(False)

    if extra_blocks is None:
        extra_blocks = LastLevelMaxPool()

    if returned_layers is None:
        returned_layers = [1, 2, 3, 4]
    if min(returned_layers) <= 0 or max(returned_layers) >= 5:
        raise ValueError(f"Each returned layer should be in the range [1,4]. Got {returned_layers}")
    return_layers = {f"layer{k}": str(v) for v, k in enumerate(returned_layers)}

    in_channels_stage2 = backbone.inplanes // 8
    in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers]
    out_channels = 256
    return BackboneWithFPN(
        backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks, norm_layer=norm_layer
    )

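The extractor above derives the FPN input widths from ``backbone.inplanes``; for a ResNet-50 that works out to in_channels_stage2 = 2048 // 8 = 256 and an in_channels_list of [256, 512, 1024, 2048]. A small sketch of that bookkeeping, assuming the private helper is called directly as shown in this file (illustrative only):

# Illustrative sketch; _resnet_fpn_extractor is the private helper defined above.
import torch
from torchvision.models import resnet50
from torchvision.models.detection.backbone_utils import _resnet_fpn_extractor

body = resnet50()  # no pretrained weights needed for a shape check
fpn_backbone = _resnet_fpn_extractor(body, trainable_layers=3)
feats = fpn_backbone(torch.rand(1, 3, 64, 64))
print(fpn_backbone.out_channels)  # 256, the common FPN width
print(list(feats.keys()))         # ['0', '1', '2', '3', 'pool']
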
def _validate_trainable_layers(
    is_trained: bool,
    trainable_backbone_layers: Optional[int],
    max_value: int,
    default_value: int,
) -> int:
    # don't freeze any layers if pretrained model or backbone is not used
    if not is_trained:
        if trainable_backbone_layers is not None:
            warnings.warn(
                "Changing trainable_backbone_layers has no effect if "
                "neither pretrained nor pretrained_backbone have been set to True, "
                f"falling back to trainable_backbone_layers={max_value} so that all layers are trainable"
            )
        trainable_backbone_layers = max_value

    # by default freeze first blocks
    if trainable_backbone_layers is None:
        trainable_backbone_layers = default_value
    if trainable_backbone_layers < 0 or trainable_backbone_layers > max_value:
        raise ValueError(
            f"Trainable backbone layers should be in the range [0,{max_value}], got {trainable_backbone_layers} "
        )
    return trainable_backbone_layers


@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: _get_enum_from_fn(mobilenet.__dict__[kwargs["backbone_name"]])["IMAGENET1K_V1"],
    ),
)
def mobilenet_backbone(
    *,
    backbone_name: str,
    weights: Optional[WeightsEnum],
    fpn: bool,
    norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
    trainable_layers: int = 2,
    returned_layers: Optional[List[int]] = None,
    extra_blocks: Optional[ExtraFPNBlock] = None,
) -> nn.Module:
    backbone = mobilenet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer)
    return _mobilenet_extractor(backbone, fpn, trainable_layers, returned_layers, extra_blocks)


def _mobilenet_extractor(
    backbone: Union[mobilenet.MobileNetV2, mobilenet.MobileNetV3],
    fpn: bool,
    trainable_layers: int,
    returned_layers: Optional[List[int]] = None,
    extra_blocks: Optional[ExtraFPNBlock] = None,
    norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> nn.Module:
    backbone = backbone.features
    # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
    # The first and last blocks are always included because they are the C0 (conv1) and Cn.
    stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
    num_stages = len(stage_indices)

    # find the index of the layer from which we won't freeze
    if trainable_layers < 0 or trainable_layers > num_stages:
        raise ValueError(f"Trainable layers should be in the range [0,{num_stages}], got {trainable_layers} ")
    freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]

    for b in backbone[:freeze_before]:
        for parameter in b.parameters():
            parameter.requires_grad_(False)

    out_channels = 256
    if fpn:
        if extra_blocks is None:
            extra_blocks = LastLevelMaxPool()

        if returned_layers is None:
            returned_layers = [num_stages - 2, num_stages - 1]
        if min(returned_layers) < 0 or max(returned_layers) >= num_stages:
            raise ValueError(f"Each returned layer should be in the range [0,{num_stages - 1}], got {returned_layers} ")
        return_layers = {f"{stage_indices[k]}": str(v) for v, k in enumerate(returned_layers)}

        in_channels_list = [backbone[stage_indices[i]].out_channels for i in returned_layers]
        return BackboneWithFPN(
            backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks, norm_layer=norm_layer
        )
    else:
        m = nn.Sequential(
            backbone,
            # depthwise linear combination of channels to reduce their size
            nn.Conv2d(backbone[-1].out_channels, out_channels, 1),
        )
        m.out_channels = out_channels  # type: ignore[assignment]
        return m
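A short sketch of how the mobilenet path above is typically exercised through the public ``mobilenet_backbone`` entry point defined in this file (illustrative only; parameter values here are arbitrary):

# Illustrative sketch using the mobilenet_backbone helper shown above.
import torch
from torchvision.models.detection.backbone_utils import mobilenet_backbone

backbone = mobilenet_backbone(backbone_name="mobilenet_v3_large", weights=None, fpn=True, trainable_layers=2)
feats = backbone(torch.rand(1, 3, 224, 224))
print([(name, f.shape) for name, f in feats.items()])  # two FPN levels plus the extra "pool" level
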
.venv/lib/python3.11/site-packages/torchvision/models/detection/faster_rcnn.py
ADDED
@@ -0,0 +1,846 @@
from typing import Any, Callable, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn
from torchvision.ops import MultiScaleRoIAlign

from ...ops import misc as misc_nn_ops
from ...transforms._presets import ObjectDetection
from .._api import register_model, Weights, WeightsEnum
from .._meta import _COCO_CATEGORIES
from .._utils import _ovewrite_value_param, handle_legacy_interface
from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights
from ..resnet import resnet50, ResNet50_Weights
from ._utils import overwrite_eps
from .anchor_utils import AnchorGenerator
from .backbone_utils import _mobilenet_extractor, _resnet_fpn_extractor, _validate_trainable_layers
from .generalized_rcnn import GeneralizedRCNN
from .roi_heads import RoIHeads
from .rpn import RegionProposalNetwork, RPNHead
from .transform import GeneralizedRCNNTransform


__all__ = [
    "FasterRCNN",
    "FasterRCNN_ResNet50_FPN_Weights",
    "FasterRCNN_ResNet50_FPN_V2_Weights",
    "FasterRCNN_MobileNet_V3_Large_FPN_Weights",
    "FasterRCNN_MobileNet_V3_Large_320_FPN_Weights",
    "fasterrcnn_resnet50_fpn",
    "fasterrcnn_resnet50_fpn_v2",
    "fasterrcnn_mobilenet_v3_large_fpn",
    "fasterrcnn_mobilenet_v3_large_320_fpn",
]


def _default_anchorgen():
    anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
    return AnchorGenerator(anchor_sizes, aspect_ratios)


class FasterRCNN(GeneralizedRCNN):
    """
    Implements Faster R-CNN.

    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
    image, and should be in 0-1 range. Different images can have different sizes.

    The behavior of the model changes depending on if it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (list of dictionary),
    containing:
        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the class label for each ground-truth box

    The model returns a Dict[Tensor] during training, containing the classification and regression
    losses for both the RPN and the R-CNN.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
    follows:
        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the predicted labels for each image
        - scores (Tensor[N]): the scores of each prediction

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            It should contain an out_channels attribute, which indicates the number of output
            channels that each feature map has (and it should be the same for all feature maps).
            The backbone should return a single Tensor or an OrderedDict[Tensor].
        num_classes (int): number of output classes of the model (including the background).
            If box_predictor is specified, num_classes should be None.
        min_size (int): Images are rescaled before feeding them to the backbone:
            we attempt to preserve the aspect ratio and scale the shorter edge
            to ``min_size``. If the resulting longer edge exceeds ``max_size``,
            then downscale so that the longer edge does not exceed ``max_size``.
            This may result in the shorter edge being lower than ``min_size``.
        max_size (int): See ``min_size``.
        image_mean (Tuple[float, float, float]): mean values used for input normalization.
            They are generally the mean values of the dataset on which the backbone has been trained
            on
        image_std (Tuple[float, float, float]): std values used for input normalization.
            They are generally the std values of the dataset on which the backbone has been trained on
        rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps.
        rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
        rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
        rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
        rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
        rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
        rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
        rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training of the RPN.
        rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training of the RPN.
        rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
            for computing the loss
        rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
            of the RPN
        rpn_score_thresh (float): only return proposals with an objectness score greater than rpn_score_thresh
        box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
            the locations indicated by the bounding boxes
        box_head (nn.Module): module that takes the cropped feature maps as input
        box_predictor (nn.Module): module that takes the output of box_head and returns the
            classification logits and box regression deltas.
        box_score_thresh (float): during inference, only return proposals with a classification score
            greater than box_score_thresh
        box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
        box_detections_per_img (int): maximum number of detections per image, for all classes.
        box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
            considered as positive during training of the classification head
        box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
            considered as negative during training of the classification head
        box_batch_size_per_image (int): number of proposals that are sampled during training of the
            classification head
        box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
            of the classification head
        bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
            bounding boxes

    Example::

        >>> import torch
        >>> import torchvision
        >>> from torchvision.models.detection import FasterRCNN
        >>> from torchvision.models.detection.rpn import AnchorGenerator
        >>> # load a pre-trained model for classification and return
        >>> # only the features
        >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
        >>> # FasterRCNN needs to know the number of
        >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
        >>> # so we need to add it here
        >>> backbone.out_channels = 1280
        >>>
        >>> # let's make the RPN generate 5 x 3 anchors per spatial
        >>> # location, with 5 different sizes and 3 different aspect
        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
        >>> # map could potentially have different sizes and
        >>> # aspect ratios
        >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
        >>>                                    aspect_ratios=((0.5, 1.0, 2.0),))
        >>>
        >>> # let's define what are the feature maps that we will
        >>> # use to perform the region of interest cropping, as well as
        >>> # the size of the crop after rescaling.
        >>> # if your backbone returns a Tensor, featmap_names is expected to
        >>> # be ['0']. More generally, the backbone should return an
        >>> # OrderedDict[Tensor], and in featmap_names you can choose which
        >>> # feature maps to use.
        >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
        >>>                                                 output_size=7,
        >>>                                                 sampling_ratio=2)
        >>>
        >>> # put the pieces together inside a FasterRCNN model
        >>> model = FasterRCNN(backbone,
        >>>                    num_classes=2,
        >>>                    rpn_anchor_generator=anchor_generator,
        >>>                    box_roi_pool=roi_pooler)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
    """

    def __init__(
        self,
        backbone,
        num_classes=None,
        # transform parameters
        min_size=800,
        max_size=1333,
        image_mean=None,
        image_std=None,
        # RPN parameters
        rpn_anchor_generator=None,
        rpn_head=None,
        rpn_pre_nms_top_n_train=2000,
        rpn_pre_nms_top_n_test=1000,
        rpn_post_nms_top_n_train=2000,
        rpn_post_nms_top_n_test=1000,
        rpn_nms_thresh=0.7,
        rpn_fg_iou_thresh=0.7,
        rpn_bg_iou_thresh=0.3,
        rpn_batch_size_per_image=256,
        rpn_positive_fraction=0.5,
        rpn_score_thresh=0.0,
        # Box parameters
        box_roi_pool=None,
        box_head=None,
        box_predictor=None,
        box_score_thresh=0.05,
        box_nms_thresh=0.5,
        box_detections_per_img=100,
        box_fg_iou_thresh=0.5,
        box_bg_iou_thresh=0.5,
        box_batch_size_per_image=512,
        box_positive_fraction=0.25,
        bbox_reg_weights=None,
        **kwargs,
    ):

        if not hasattr(backbone, "out_channels"):
            raise ValueError(
                "backbone should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)"
            )

        if not isinstance(rpn_anchor_generator, (AnchorGenerator, type(None))):
            raise TypeError(
                f"rpn_anchor_generator should be of type AnchorGenerator or None instead of {type(rpn_anchor_generator)}"
            )
        if not isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None))):
            raise TypeError(
                f"box_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(box_roi_pool)}"
            )

        if num_classes is not None:
            if box_predictor is not None:
                raise ValueError("num_classes should be None when box_predictor is specified")
        else:
            if box_predictor is None:
                raise ValueError("num_classes should not be None when box_predictor is not specified")

        out_channels = backbone.out_channels

        if rpn_anchor_generator is None:
            rpn_anchor_generator = _default_anchorgen()
        if rpn_head is None:
            rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])

        rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test)
        rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test)

        rpn = RegionProposalNetwork(
            rpn_anchor_generator,
            rpn_head,
            rpn_fg_iou_thresh,
            rpn_bg_iou_thresh,
            rpn_batch_size_per_image,
            rpn_positive_fraction,
            rpn_pre_nms_top_n,
            rpn_post_nms_top_n,
            rpn_nms_thresh,
            score_thresh=rpn_score_thresh,
        )

        if box_roi_pool is None:
            box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2)

        if box_head is None:
            resolution = box_roi_pool.output_size[0]
            representation_size = 1024
            box_head = TwoMLPHead(out_channels * resolution**2, representation_size)

        if box_predictor is None:
            representation_size = 1024
            box_predictor = FastRCNNPredictor(representation_size, num_classes)

        roi_heads = RoIHeads(
            # Box
            box_roi_pool,
            box_head,
            box_predictor,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
        )

        if image_mean is None:
            image_mean = [0.485, 0.456, 0.406]
        if image_std is None:
            image_std = [0.229, 0.224, 0.225]
        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)

        super().__init__(backbone, rpn, roi_heads, transform)

class TwoMLPHead(nn.Module):
    """
    Standard heads for FPN-based models

    Args:
        in_channels (int): number of input channels
        representation_size (int): size of the intermediate representation
    """

    def __init__(self, in_channels, representation_size):
        super().__init__()

        self.fc6 = nn.Linear(in_channels, representation_size)
        self.fc7 = nn.Linear(representation_size, representation_size)

    def forward(self, x):
        x = x.flatten(start_dim=1)

        x = F.relu(self.fc6(x))
        x = F.relu(self.fc7(x))

        return x


class FastRCNNConvFCHead(nn.Sequential):
    def __init__(
        self,
        input_size: Tuple[int, int, int],
        conv_layers: List[int],
        fc_layers: List[int],
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ):
        """
        Args:
            input_size (Tuple[int, int, int]): the input size in CHW format.
            conv_layers (list): feature dimensions of each Convolution layer
            fc_layers (list): feature dimensions of each FCN layer
            norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
        """
        in_channels, in_height, in_width = input_size

        blocks = []
        previous_channels = in_channels
        for current_channels in conv_layers:
            blocks.append(misc_nn_ops.Conv2dNormActivation(previous_channels, current_channels, norm_layer=norm_layer))
            previous_channels = current_channels
        blocks.append(nn.Flatten())
        previous_channels = previous_channels * in_height * in_width
        for current_channels in fc_layers:
            blocks.append(nn.Linear(previous_channels, current_channels))
            blocks.append(nn.ReLU(inplace=True))
            previous_channels = current_channels

        super().__init__(*blocks)
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)


class FastRCNNPredictor(nn.Module):
    """
    Standard classification + bounding box regression layers
    for Fast R-CNN.

    Args:
        in_channels (int): number of input channels
        num_classes (int): number of output classes (including background)
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.cls_score = nn.Linear(in_channels, num_classes)
        self.bbox_pred = nn.Linear(in_channels, num_classes * 4)

    def forward(self, x):
        if x.dim() == 4:
            torch._assert(
                list(x.shape[2:]) == [1, 1],
                f"x has the wrong shape, expecting the last two dimensions to be [1,1] instead of {list(x.shape[2:])}",
            )
        x = x.flatten(start_dim=1)
        scores = self.cls_score(x)
        bbox_deltas = self.bbox_pred(x)

        return scores, bbox_deltas

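FastRCNNPredictor is the piece that gets swapped out when fine-tuning on a custom number of classes; a minimal sketch of that common recipe, using only the classes and builders defined in this file (illustrative only):

# Illustrative fine-tuning sketch: replace the box predictor for a custom class count.
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=None)
in_features = model.roi_heads.box_predictor.cls_score.in_features  # 1024 from TwoMLPHead above
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=3)  # 2 classes + background
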
_COMMON_META = {
    "categories": _COCO_CATEGORIES,
    "min_size": (1, 1),
}


class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 41755286,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 37.0,
                }
            },
            "_ops": 134.38,
            "_file_size": 159.743,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1


class FasterRCNN_ResNet50_FPN_V2_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_v2_coco-dd69338a.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 43712278,
            "recipe": "https://github.com/pytorch/vision/pull/5763",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 46.7,
                }
            },
            "_ops": 280.371,
            "_file_size": 167.104,
            "_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""",
        },
    )
    DEFAULT = COCO_V1


class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 19386354,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 32.8,
                }
            },
            "_ops": 4.494,
            "_file_size": 74.239,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1


class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 19386354,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 22.8,
                }
            },
            "_ops": 0.719,
            "_file_size": 74.239,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1

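The Weights entries above bundle checkpoint URLs with metadata that can be inspected without downloading anything; a brief illustrative sketch:

# Illustrative sketch: reading the metadata attached to the weights enum above.
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights

w = FasterRCNN_ResNet50_FPN_Weights.DEFAULT  # alias for COCO_V1
print(w.meta["_metrics"]["COCO-val2017"]["box_map"])  # 37.0
print(len(w.meta["categories"]))                      # 91 COCO categories (background included)
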
@register_model()
@handle_legacy_interface(
    weights=("pretrained", FasterRCNN_ResNet50_FPN_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def fasterrcnn_resnet50_fpn(
    *,
    weights: Optional[FasterRCNN_ResNet50_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> FasterRCNN:
    """
    Faster R-CNN model with a ResNet-50-FPN backbone from the `Faster R-CNN: Towards Real-Time Object
    Detection with Region Proposal Networks <https://arxiv.org/abs/1506.01497>`__
    paper.

    .. betastatus:: detection module

    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
    image, and should be in ``0-1`` range. Different images can have different sizes.

    The behavior of the model changes depending on if it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (list of dictionary),
    containing:

        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (``Int64Tensor[N]``): the class label for each ground-truth box

    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
    losses for both the RPN and the R-CNN.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
    follows, where ``N`` is the number of detections:

        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (``Int64Tensor[N]``): the predicted labels for each detection
        - scores (``Tensor[N]``): the scores of each detection

    For more details on the output, you may refer to :ref:`instance_seg_output`.

    Faster R-CNN is exportable to ONNX for a fixed batch size with inputs images of fixed size.

    Example::

        >>> model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
        >>> # For training
        >>> images, boxes = torch.rand(4, 3, 600, 1200), torch.rand(4, 11, 4)
        >>> boxes[:, :, 2:4] = boxes[:, :, 0:2] + boxes[:, :, 2:4]
        >>> labels = torch.randint(1, 91, (4, 11))
        >>> images = list(image for image in images)
        >>> targets = []
        >>> for i in range(len(images)):
        >>>     d = {}
        >>>     d['boxes'] = boxes[i]
        >>>     d['labels'] = labels[i]
        >>>     targets.append(d)
        >>> output = model(images, targets)
        >>> # For inference
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
        >>>
        >>> # optionally, if you want to export the model to ONNX:
        >>> torch.onnx.export(model, x, "faster_rcnn.onnx", opset_version = 11)

    Args:
        weights (:class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
            pretrained weights for the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
            final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
            trainable. If ``None`` is passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights
        :members:
    """
    weights = FasterRCNN_ResNet50_FPN_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
    model = FasterRCNN(backbone, num_classes=num_classes, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
        if weights == FasterRCNN_ResNet50_FPN_Weights.COCO_V1:
            overwrite_eps(model, 0.0)

    return model


@register_model()
@handle_legacy_interface(
    weights=("pretrained", FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def fasterrcnn_resnet50_fpn_v2(
    *,
    weights: Optional[FasterRCNN_ResNet50_FPN_V2_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = None,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> FasterRCNN:
    """
    Constructs an improved Faster R-CNN model with a ResNet-50-FPN backbone from `Benchmarking Detection
    Transfer Learning with Vision Transformers <https://arxiv.org/abs/2111.11429>`__ paper.

    .. betastatus:: detection module

    It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
    :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
    details.

    Args:
        weights (:class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
            pretrained weights for the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
            final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
            trainable. If ``None`` is passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights
        :members:
    """
    weights = FasterRCNN_ResNet50_FPN_V2_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)

    backbone = resnet50(weights=weights_backbone, progress=progress)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers, norm_layer=nn.BatchNorm2d)
    rpn_anchor_generator = _default_anchorgen()
    rpn_head = RPNHead(backbone.out_channels, rpn_anchor_generator.num_anchors_per_location()[0], conv_depth=2)
    box_head = FastRCNNConvFCHead(
        (backbone.out_channels, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d
    )
    model = FasterRCNN(
        backbone,
        num_classes=num_classes,
        rpn_anchor_generator=rpn_anchor_generator,
        rpn_head=rpn_head,
        box_head=box_head,
        **kwargs,
    )

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model

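Usage of the v2 builder above mirrors the v1 builder; a minimal illustrative sketch (not a prescribed configuration, class count chosen arbitrarily):

# Illustrative sketch of calling the v2 builder defined above.
import torch
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2

model = fasterrcnn_resnet50_fpn_v2(weights=None, num_classes=5)  # 4 classes + background
model.eval()
with torch.no_grad():
    predictions = model([torch.rand(3, 320, 320)])
print(sorted(predictions[0].keys()))  # ['boxes', 'labels', 'scores']
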
def _fasterrcnn_mobilenet_v3_large_fpn(
    *,
    weights: Optional[Union[FasterRCNN_MobileNet_V3_Large_FPN_Weights, FasterRCNN_MobileNet_V3_Large_320_FPN_Weights]],
    progress: bool,
    num_classes: Optional[int],
    weights_backbone: Optional[MobileNet_V3_Large_Weights],
    trainable_backbone_layers: Optional[int],
    **kwargs: Any,
) -> FasterRCNN:
    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 6, 3)
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)
    anchor_sizes = (
        (
            32,
            64,
            128,
            256,
            512,
        ),
    ) * 3
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
    model = FasterRCNN(
        backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs
    )

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model


@register_model()
@handle_legacy_interface(
    weights=("pretrained", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
def fasterrcnn_mobilenet_v3_large_320_fpn(
    *,
    weights: Optional[FasterRCNN_MobileNet_V3_Large_320_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> FasterRCNN:
    """
    Low resolution Faster R-CNN model with a MobileNetV3-Large backbone tuned for mobile use cases.

    .. betastatus:: detection module

    It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
    :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
    details.

    Example::

        >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(weights=FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.DEFAULT)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Args:
        weights (:class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The
            pretrained weights for the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
            final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are
            trainable. If ``None`` is passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights
|
| 755 |
+
:members:
|
| 756 |
+
"""
|
| 757 |
+
weights = FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.verify(weights)
|
| 758 |
+
weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
|
| 759 |
+
|
| 760 |
+
defaults = {
|
| 761 |
+
"min_size": 320,
|
| 762 |
+
"max_size": 640,
|
| 763 |
+
"rpn_pre_nms_top_n_test": 150,
|
| 764 |
+
"rpn_post_nms_top_n_test": 150,
|
| 765 |
+
"rpn_score_thresh": 0.05,
|
| 766 |
+
}
|
| 767 |
+
|
| 768 |
+
kwargs = {**defaults, **kwargs}
|
| 769 |
+
return _fasterrcnn_mobilenet_v3_large_fpn(
|
| 770 |
+
weights=weights,
|
| 771 |
+
progress=progress,
|
| 772 |
+
num_classes=num_classes,
|
| 773 |
+
weights_backbone=weights_backbone,
|
| 774 |
+
trainable_backbone_layers=trainable_backbone_layers,
|
| 775 |
+
**kwargs,
|
| 776 |
+
)
|
| 777 |
+
|
| 778 |
+
|
| 779 |
+
@register_model()
|
| 780 |
+
@handle_legacy_interface(
|
| 781 |
+
weights=("pretrained", FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1),
|
| 782 |
+
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
|
| 783 |
+
)
|
| 784 |
+
def fasterrcnn_mobilenet_v3_large_fpn(
|
| 785 |
+
*,
|
| 786 |
+
weights: Optional[FasterRCNN_MobileNet_V3_Large_FPN_Weights] = None,
|
| 787 |
+
progress: bool = True,
|
| 788 |
+
num_classes: Optional[int] = None,
|
| 789 |
+
weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
|
| 790 |
+
trainable_backbone_layers: Optional[int] = None,
|
| 791 |
+
**kwargs: Any,
|
| 792 |
+
) -> FasterRCNN:
|
| 793 |
+
"""
|
| 794 |
+
Constructs a high resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone.
|
| 795 |
+
|
| 796 |
+
.. betastatus:: detection module
|
| 797 |
+
|
| 798 |
+
It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
|
| 799 |
+
:func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
|
| 800 |
+
details.
|
| 801 |
+
|
| 802 |
+
Example::
|
| 803 |
+
|
| 804 |
+
>>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT)
|
| 805 |
+
>>> model.eval()
|
| 806 |
+
>>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
|
| 807 |
+
>>> predictions = model(x)
|
| 808 |
+
|
| 809 |
+
Args:
|
| 810 |
+
weights (:class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights`, optional): The
|
| 811 |
+
pretrained weights to use. See
|
| 812 |
+
:class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights` below for
|
| 813 |
+
more details, and possible values. By default, no pre-trained
|
| 814 |
+
weights are used.
|
| 815 |
+
progress (bool, optional): If True, displays a progress bar of the
|
| 816 |
+
download to stderr. Default is True.
|
| 817 |
+
num_classes (int, optional): number of output classes of the model (including the background)
|
| 818 |
+
weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The
|
| 819 |
+
pretrained weights for the backbone.
|
| 820 |
+
trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
|
| 821 |
+
final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are
|
| 822 |
+
trainable. If ``None`` is passed (the default) this value is set to 3.
|
| 823 |
+
**kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
|
| 824 |
+
base class. Please refer to the `source code
|
| 825 |
+
<https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
|
| 826 |
+
for more details about this class.
|
| 827 |
+
|
| 828 |
+
.. autoclass:: torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights
|
| 829 |
+
:members:
|
| 830 |
+
"""
|
| 831 |
+
weights = FasterRCNN_MobileNet_V3_Large_FPN_Weights.verify(weights)
|
| 832 |
+
weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
|
| 833 |
+
|
| 834 |
+
defaults = {
|
| 835 |
+
"rpn_score_thresh": 0.05,
|
| 836 |
+
}
|
| 837 |
+
|
| 838 |
+
kwargs = {**defaults, **kwargs}
|
| 839 |
+
return _fasterrcnn_mobilenet_v3_large_fpn(
|
| 840 |
+
weights=weights,
|
| 841 |
+
progress=progress,
|
| 842 |
+
num_classes=num_classes,
|
| 843 |
+
weights_backbone=weights_backbone,
|
| 844 |
+
trainable_backbone_layers=trainable_backbone_layers,
|
| 845 |
+
**kwargs,
|
| 846 |
+
)
|
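Both MobileNetV3-Large builders funnel into ``_fasterrcnn_mobilenet_v3_large_fpn`` and differ only in the resolution and RPN defaults merged into ``kwargs``. A minimal fine-tuning sketch built on these entry points (the four-class dataset and the ``FastRCNNPredictor`` swap are illustrative assumptions, not part of this file):

import torch
import torchvision
from torchvision.models.detection import FasterRCNN_MobileNet_V3_Large_FPN_Weights
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# Start from COCO-pretrained weights.
model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(
    weights=FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT
)

# Replace the box predictor for a hypothetical dataset with 3 classes + background.
num_classes = 4
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

# Inference on a dummy image.
model.eval()
with torch.no_grad():
    predictions = model([torch.rand(3, 320, 320)])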
.venv/lib/python3.11/site-packages/torchvision/models/detection/fcos.py
ADDED
@@ -0,0 +1,775 @@
import math
import warnings
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from torch import nn, Tensor

from ...ops import boxes as box_ops, generalized_box_iou_loss, misc as misc_nn_ops, sigmoid_focal_loss
from ...ops.feature_pyramid_network import LastLevelP6P7
from ...transforms._presets import ObjectDetection
from ...utils import _log_api_usage_once
from .._api import register_model, Weights, WeightsEnum
from .._meta import _COCO_CATEGORIES
from .._utils import _ovewrite_value_param, handle_legacy_interface
from ..resnet import resnet50, ResNet50_Weights
from . import _utils as det_utils
from .anchor_utils import AnchorGenerator
from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
from .transform import GeneralizedRCNNTransform


__all__ = [
    "FCOS",
    "FCOS_ResNet50_FPN_Weights",
    "fcos_resnet50_fpn",
]


class FCOSHead(nn.Module):
    """
    A regression and classification head for use in FCOS.

    Args:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
        num_classes (int): number of classes to be predicted
        num_convs (Optional[int]): number of conv layer of head. Default: 4.
    """

    __annotations__ = {
        "box_coder": det_utils.BoxLinearCoder,
    }

    def __init__(self, in_channels: int, num_anchors: int, num_classes: int, num_convs: Optional[int] = 4) -> None:
        super().__init__()
        self.box_coder = det_utils.BoxLinearCoder(normalize_by_size=True)
        self.classification_head = FCOSClassificationHead(in_channels, num_anchors, num_classes, num_convs)
        self.regression_head = FCOSRegressionHead(in_channels, num_anchors, num_convs)

    def compute_loss(
        self,
        targets: List[Dict[str, Tensor]],
        head_outputs: Dict[str, Tensor],
        anchors: List[Tensor],
        matched_idxs: List[Tensor],
    ) -> Dict[str, Tensor]:

        cls_logits = head_outputs["cls_logits"]  # [N, HWA, C]
        bbox_regression = head_outputs["bbox_regression"]  # [N, HWA, 4]
        bbox_ctrness = head_outputs["bbox_ctrness"]  # [N, HWA, 1]

        all_gt_classes_targets = []
        all_gt_boxes_targets = []
        for targets_per_image, matched_idxs_per_image in zip(targets, matched_idxs):
            if len(targets_per_image["labels"]) == 0:
                gt_classes_targets = targets_per_image["labels"].new_zeros((len(matched_idxs_per_image),))
                gt_boxes_targets = targets_per_image["boxes"].new_zeros((len(matched_idxs_per_image), 4))
            else:
                gt_classes_targets = targets_per_image["labels"][matched_idxs_per_image.clip(min=0)]
                gt_boxes_targets = targets_per_image["boxes"][matched_idxs_per_image.clip(min=0)]
            gt_classes_targets[matched_idxs_per_image < 0] = -1  # background
            all_gt_classes_targets.append(gt_classes_targets)
            all_gt_boxes_targets.append(gt_boxes_targets)

        # List[Tensor] to Tensor conversion of `all_gt_boxes_target`, `all_gt_classes_targets` and `anchors`
        all_gt_boxes_targets, all_gt_classes_targets, anchors = (
            torch.stack(all_gt_boxes_targets),
            torch.stack(all_gt_classes_targets),
            torch.stack(anchors),
        )

        # compute foregroud
        foregroud_mask = all_gt_classes_targets >= 0
        num_foreground = foregroud_mask.sum().item()

        # classification loss
        gt_classes_targets = torch.zeros_like(cls_logits)
        gt_classes_targets[foregroud_mask, all_gt_classes_targets[foregroud_mask]] = 1.0
        loss_cls = sigmoid_focal_loss(cls_logits, gt_classes_targets, reduction="sum")

        # amp issue: pred_boxes need to convert float
        pred_boxes = self.box_coder.decode(bbox_regression, anchors)

        # regression loss: GIoU loss
        loss_bbox_reg = generalized_box_iou_loss(
            pred_boxes[foregroud_mask],
            all_gt_boxes_targets[foregroud_mask],
            reduction="sum",
        )

        # ctrness loss

        bbox_reg_targets = self.box_coder.encode(anchors, all_gt_boxes_targets)

        if len(bbox_reg_targets) == 0:
            gt_ctrness_targets = bbox_reg_targets.new_zeros(bbox_reg_targets.size()[:-1])
        else:
            left_right = bbox_reg_targets[:, :, [0, 2]]
            top_bottom = bbox_reg_targets[:, :, [1, 3]]
            gt_ctrness_targets = torch.sqrt(
                (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0])
                * (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
            )
        pred_centerness = bbox_ctrness.squeeze(dim=2)
        loss_bbox_ctrness = nn.functional.binary_cross_entropy_with_logits(
            pred_centerness[foregroud_mask], gt_ctrness_targets[foregroud_mask], reduction="sum"
        )

        return {
            "classification": loss_cls / max(1, num_foreground),
            "bbox_regression": loss_bbox_reg / max(1, num_foreground),
            "bbox_ctrness": loss_bbox_ctrness / max(1, num_foreground),
        }

    def forward(self, x: List[Tensor]) -> Dict[str, Tensor]:
        cls_logits = self.classification_head(x)
        bbox_regression, bbox_ctrness = self.regression_head(x)
        return {
            "cls_logits": cls_logits,
            "bbox_regression": bbox_regression,
            "bbox_ctrness": bbox_ctrness,
        }


class FCOSClassificationHead(nn.Module):
    """
    A classification head for use in FCOS.

    Args:
        in_channels (int): number of channels of the input feature.
        num_anchors (int): number of anchors to be predicted.
        num_classes (int): number of classes to be predicted.
        num_convs (Optional[int]): number of conv layer. Default: 4.
        prior_probability (Optional[float]): probability of prior. Default: 0.01.
        norm_layer: Module specifying the normalization layer to use.
    """

    def __init__(
        self,
        in_channels: int,
        num_anchors: int,
        num_classes: int,
        num_convs: int = 4,
        prior_probability: float = 0.01,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()

        self.num_classes = num_classes
        self.num_anchors = num_anchors

        if norm_layer is None:
            norm_layer = partial(nn.GroupNorm, 32)

        conv = []
        for _ in range(num_convs):
            conv.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1))
            conv.append(norm_layer(in_channels))
            conv.append(nn.ReLU())
        self.conv = nn.Sequential(*conv)

        for layer in self.conv.children():
            if isinstance(layer, nn.Conv2d):
                torch.nn.init.normal_(layer.weight, std=0.01)
                torch.nn.init.constant_(layer.bias, 0)

        self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
        torch.nn.init.normal_(self.cls_logits.weight, std=0.01)
        torch.nn.init.constant_(self.cls_logits.bias, -math.log((1 - prior_probability) / prior_probability))

    def forward(self, x: List[Tensor]) -> Tensor:
        all_cls_logits = []

        for features in x:
            cls_logits = self.conv(features)
            cls_logits = self.cls_logits(cls_logits)

            # Permute classification output from (N, A * K, H, W) to (N, HWA, K).
            N, _, H, W = cls_logits.shape
            cls_logits = cls_logits.view(N, -1, self.num_classes, H, W)
            cls_logits = cls_logits.permute(0, 3, 4, 1, 2)
            cls_logits = cls_logits.reshape(N, -1, self.num_classes)  # Size=(N, HWA, 4)

            all_cls_logits.append(cls_logits)

        return torch.cat(all_cls_logits, dim=1)


class FCOSRegressionHead(nn.Module):
    """
    A regression head for use in FCOS, which combines regression branch and center-ness branch.
    This can obtain better performance.

    Reference: `FCOS: A simple and strong anchor-free object detector <https://arxiv.org/abs/2006.09214>`_.

    Args:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
        num_convs (Optional[int]): number of conv layer. Default: 4.
        norm_layer: Module specifying the normalization layer to use.
    """

    def __init__(
        self,
        in_channels: int,
        num_anchors: int,
        num_convs: int = 4,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ):
        super().__init__()

        if norm_layer is None:
            norm_layer = partial(nn.GroupNorm, 32)

        conv = []
        for _ in range(num_convs):
            conv.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1))
            conv.append(norm_layer(in_channels))
            conv.append(nn.ReLU())
        self.conv = nn.Sequential(*conv)

        self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1)
        self.bbox_ctrness = nn.Conv2d(in_channels, num_anchors * 1, kernel_size=3, stride=1, padding=1)
        for layer in [self.bbox_reg, self.bbox_ctrness]:
            torch.nn.init.normal_(layer.weight, std=0.01)
            torch.nn.init.zeros_(layer.bias)

        for layer in self.conv.children():
            if isinstance(layer, nn.Conv2d):
                torch.nn.init.normal_(layer.weight, std=0.01)
                torch.nn.init.zeros_(layer.bias)

    def forward(self, x: List[Tensor]) -> Tuple[Tensor, Tensor]:
        all_bbox_regression = []
        all_bbox_ctrness = []

        for features in x:
            bbox_feature = self.conv(features)
            bbox_regression = nn.functional.relu(self.bbox_reg(bbox_feature))
            bbox_ctrness = self.bbox_ctrness(bbox_feature)

            # permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4).
            N, _, H, W = bbox_regression.shape
            bbox_regression = bbox_regression.view(N, -1, 4, H, W)
            bbox_regression = bbox_regression.permute(0, 3, 4, 1, 2)
            bbox_regression = bbox_regression.reshape(N, -1, 4)  # Size=(N, HWA, 4)
            all_bbox_regression.append(bbox_regression)

            # permute bbox ctrness output from (N, 1 * A, H, W) to (N, HWA, 1).
            bbox_ctrness = bbox_ctrness.view(N, -1, 1, H, W)
            bbox_ctrness = bbox_ctrness.permute(0, 3, 4, 1, 2)
            bbox_ctrness = bbox_ctrness.reshape(N, -1, 1)
            all_bbox_ctrness.append(bbox_ctrness)

        return torch.cat(all_bbox_regression, dim=1), torch.cat(all_bbox_ctrness, dim=1)


class FCOS(nn.Module):
    """
    Implements FCOS.

    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
    image, and should be in 0-1 range. Different images can have different sizes.

    The behavior of the model changes depending on if it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (list of dictionary),
    containing:
        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the class label for each ground-truth box

    The model returns a Dict[Tensor] during training, containing the classification, regression
    and centerness losses.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
    follows:
        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the predicted labels for each image
        - scores (Tensor[N]): the scores for each prediction

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            It should contain an out_channels attribute, which indicates the number of output
            channels that each feature map has (and it should be the same for all feature maps).
            The backbone should return a single Tensor or an OrderedDict[Tensor].
        num_classes (int): number of output classes of the model (including the background).
        min_size (int): Images are rescaled before feeding them to the backbone:
            we attempt to preserve the aspect ratio and scale the shorter edge
            to ``min_size``. If the resulting longer edge exceeds ``max_size``,
            then downscale so that the longer edge does not exceed ``max_size``.
            This may result in the shorter edge beeing lower than ``min_size``.
        max_size (int): See ``min_size``.
        image_mean (Tuple[float, float, float]): mean values used for input normalization.
            They are generally the mean values of the dataset on which the backbone has been trained
            on
        image_std (Tuple[float, float, float]): std values used for input normalization.
            They are generally the std values of the dataset on which the backbone has been trained on
        anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps. For FCOS, only set one anchor for per position of each level, the width and height equal to
            the stride of feature map, and set aspect ratio = 1.0, so the center of anchor is equivalent to the point
            in FCOS paper.
        head (nn.Module): Module run on top of the feature pyramid.
            Defaults to a module containing a classification and regression module.
        center_sampling_radius (int): radius of the "center" of a groundtruth box,
            within which all anchor points are labeled positive.
        score_thresh (float): Score threshold used for postprocessing the detections.
        nms_thresh (float): NMS threshold used for postprocessing the detections.
        detections_per_img (int): Number of best detections to keep after NMS.
        topk_candidates (int): Number of best detections to keep before NMS.

    Example:

        >>> import torch
        >>> import torchvision
        >>> from torchvision.models.detection import FCOS
        >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
        >>> # load a pre-trained model for classification and return
        >>> # only the features
        >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
        >>> # FCOS needs to know the number of
        >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
        >>> # so we need to add it here
        >>> backbone.out_channels = 1280
        >>>
        >>> # let's make the network generate 5 x 3 anchors per spatial
        >>> # location, with 5 different sizes and 3 different aspect
        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
        >>> # map could potentially have different sizes and
        >>> # aspect ratios
        >>> anchor_generator = AnchorGenerator(
        >>>     sizes=((8,), (16,), (32,), (64,), (128,)),
        >>>     aspect_ratios=((1.0,),)
        >>> )
        >>>
        >>> # put the pieces together inside a FCOS model
        >>> model = FCOS(
        >>>     backbone,
        >>>     num_classes=80,
        >>>     anchor_generator=anchor_generator,
        >>> )
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
    """

    __annotations__ = {
        "box_coder": det_utils.BoxLinearCoder,
    }

    def __init__(
        self,
        backbone: nn.Module,
        num_classes: int,
        # transform parameters
        min_size: int = 800,
        max_size: int = 1333,
        image_mean: Optional[List[float]] = None,
        image_std: Optional[List[float]] = None,
        # Anchor parameters
        anchor_generator: Optional[AnchorGenerator] = None,
        head: Optional[nn.Module] = None,
        center_sampling_radius: float = 1.5,
        score_thresh: float = 0.2,
        nms_thresh: float = 0.6,
        detections_per_img: int = 100,
        topk_candidates: int = 1000,
        **kwargs,
    ):
        super().__init__()
        _log_api_usage_once(self)

        if not hasattr(backbone, "out_channels"):
            raise ValueError(
                "backbone should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)"
            )
        self.backbone = backbone

        if not isinstance(anchor_generator, (AnchorGenerator, type(None))):
            raise TypeError(
                f"anchor_generator should be of type AnchorGenerator or None, instead got {type(anchor_generator)}"
            )

        if anchor_generator is None:
            anchor_sizes = ((8,), (16,), (32,), (64,), (128,))  # equal to strides of multi-level feature map
            aspect_ratios = ((1.0,),) * len(anchor_sizes)  # set only one anchor
            anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
        self.anchor_generator = anchor_generator
        if self.anchor_generator.num_anchors_per_location()[0] != 1:
            raise ValueError(
                f"anchor_generator.num_anchors_per_location()[0] should be 1 instead of {anchor_generator.num_anchors_per_location()[0]}"
            )

        if head is None:
            head = FCOSHead(backbone.out_channels, anchor_generator.num_anchors_per_location()[0], num_classes)
        self.head = head

        self.box_coder = det_utils.BoxLinearCoder(normalize_by_size=True)

        if image_mean is None:
            image_mean = [0.485, 0.456, 0.406]
        if image_std is None:
            image_std = [0.229, 0.224, 0.225]
        self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)

        self.center_sampling_radius = center_sampling_radius
        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.detections_per_img = detections_per_img
        self.topk_candidates = topk_candidates

        # used only on torchscript mode
        self._has_warned = False

    @torch.jit.unused
    def eager_outputs(
        self, losses: Dict[str, Tensor], detections: List[Dict[str, Tensor]]
    ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
        if self.training:
            return losses

        return detections

    def compute_loss(
        self,
        targets: List[Dict[str, Tensor]],
        head_outputs: Dict[str, Tensor],
        anchors: List[Tensor],
        num_anchors_per_level: List[int],
    ) -> Dict[str, Tensor]:
        matched_idxs = []
        for anchors_per_image, targets_per_image in zip(anchors, targets):
            if targets_per_image["boxes"].numel() == 0:
                matched_idxs.append(
                    torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device)
                )
                continue

            gt_boxes = targets_per_image["boxes"]
            gt_centers = (gt_boxes[:, :2] + gt_boxes[:, 2:]) / 2  # Nx2
            anchor_centers = (anchors_per_image[:, :2] + anchors_per_image[:, 2:]) / 2  # N
            anchor_sizes = anchors_per_image[:, 2] - anchors_per_image[:, 0]
            # center sampling: anchor point must be close enough to gt center.
            pairwise_match = (anchor_centers[:, None, :] - gt_centers[None, :, :]).abs_().max(
                dim=2
            ).values < self.center_sampling_radius * anchor_sizes[:, None]
            # compute pairwise distance between N points and M boxes
            x, y = anchor_centers.unsqueeze(dim=2).unbind(dim=1)  # (N, 1)
            x0, y0, x1, y1 = gt_boxes.unsqueeze(dim=0).unbind(dim=2)  # (1, M)
            pairwise_dist = torch.stack([x - x0, y - y0, x1 - x, y1 - y], dim=2)  # (N, M)

            # anchor point must be inside gt
            pairwise_match &= pairwise_dist.min(dim=2).values > 0

            # each anchor is only responsible for certain scale range.
            lower_bound = anchor_sizes * 4
            lower_bound[: num_anchors_per_level[0]] = 0
            upper_bound = anchor_sizes * 8
            upper_bound[-num_anchors_per_level[-1] :] = float("inf")
            pairwise_dist = pairwise_dist.max(dim=2).values
            pairwise_match &= (pairwise_dist > lower_bound[:, None]) & (pairwise_dist < upper_bound[:, None])

            # match the GT box with minimum area, if there are multiple GT matches
            gt_areas = (gt_boxes[:, 2] - gt_boxes[:, 0]) * (gt_boxes[:, 3] - gt_boxes[:, 1])  # N
            pairwise_match = pairwise_match.to(torch.float32) * (1e8 - gt_areas[None, :])
            min_values, matched_idx = pairwise_match.max(dim=1)  # R, per-anchor match
            matched_idx[min_values < 1e-5] = -1  # unmatched anchors are assigned -1

            matched_idxs.append(matched_idx)

        return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs)

    def postprocess_detections(
        self, head_outputs: Dict[str, List[Tensor]], anchors: List[List[Tensor]], image_shapes: List[Tuple[int, int]]
    ) -> List[Dict[str, Tensor]]:
        class_logits = head_outputs["cls_logits"]
        box_regression = head_outputs["bbox_regression"]
        box_ctrness = head_outputs["bbox_ctrness"]

        num_images = len(image_shapes)

        detections: List[Dict[str, Tensor]] = []

        for index in range(num_images):
            box_regression_per_image = [br[index] for br in box_regression]
            logits_per_image = [cl[index] for cl in class_logits]
            box_ctrness_per_image = [bc[index] for bc in box_ctrness]
            anchors_per_image, image_shape = anchors[index], image_shapes[index]

            image_boxes = []
            image_scores = []
            image_labels = []

            for box_regression_per_level, logits_per_level, box_ctrness_per_level, anchors_per_level in zip(
                box_regression_per_image, logits_per_image, box_ctrness_per_image, anchors_per_image
            ):
                num_classes = logits_per_level.shape[-1]

                # remove low scoring boxes
                scores_per_level = torch.sqrt(
                    torch.sigmoid(logits_per_level) * torch.sigmoid(box_ctrness_per_level)
                ).flatten()
                keep_idxs = scores_per_level > self.score_thresh
                scores_per_level = scores_per_level[keep_idxs]
                topk_idxs = torch.where(keep_idxs)[0]

                # keep only topk scoring predictions
                num_topk = det_utils._topk_min(topk_idxs, self.topk_candidates, 0)
                scores_per_level, idxs = scores_per_level.topk(num_topk)
                topk_idxs = topk_idxs[idxs]

                anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor")
                labels_per_level = topk_idxs % num_classes

                boxes_per_level = self.box_coder.decode(
                    box_regression_per_level[anchor_idxs], anchors_per_level[anchor_idxs]
                )
                boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape)

                image_boxes.append(boxes_per_level)
                image_scores.append(scores_per_level)
                image_labels.append(labels_per_level)

            image_boxes = torch.cat(image_boxes, dim=0)
            image_scores = torch.cat(image_scores, dim=0)
            image_labels = torch.cat(image_labels, dim=0)

            # non-maximum suppression
            keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
            keep = keep[: self.detections_per_img]

            detections.append(
                {
                    "boxes": image_boxes[keep],
                    "scores": image_scores[keep],
                    "labels": image_labels[keep],
                }
            )

        return detections

    def forward(
        self,
        images: List[Tensor],
        targets: Optional[List[Dict[str, Tensor]]] = None,
    ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
        """
        Args:
            images (list[Tensor]): images to be processed
            targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)

        Returns:
            result (list[BoxList] or dict[Tensor]): the output from the model.
                During training, it returns a dict[Tensor] which contains the losses.
                During testing, it returns list[BoxList] contains additional fields
                like `scores`, `labels` and `mask` (for Mask R-CNN models).
        """
        if self.training:

            if targets is None:
                torch._assert(False, "targets should not be none when in training mode")
            else:
                for target in targets:
                    boxes = target["boxes"]
                    torch._assert(isinstance(boxes, torch.Tensor), "Expected target boxes to be of type Tensor.")
                    torch._assert(
                        len(boxes.shape) == 2 and boxes.shape[-1] == 4,
                        f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.",
                    )

        original_image_sizes: List[Tuple[int, int]] = []
        for img in images:
            val = img.shape[-2:]
            torch._assert(
                len(val) == 2,
                f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
            )
            original_image_sizes.append((val[0], val[1]))

        # transform the input
        images, targets = self.transform(images, targets)

        # Check for degenerate boxes
        if targets is not None:
            for target_idx, target in enumerate(targets):
                boxes = target["boxes"]
                degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
                if degenerate_boxes.any():
                    # print the first degenerate box
                    bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                    degen_bb: List[float] = boxes[bb_idx].tolist()
                    torch._assert(
                        False,
                        f"All bounding boxes should have positive height and width. Found invalid box {degen_bb} for target at index {target_idx}.",
                    )

        # get the features from the backbone
        features = self.backbone(images.tensors)
        if isinstance(features, torch.Tensor):
            features = OrderedDict([("0", features)])

        features = list(features.values())

        # compute the fcos heads outputs using the features
        head_outputs = self.head(features)

        # create the set of anchors
        anchors = self.anchor_generator(images, features)
        # recover level sizes
        num_anchors_per_level = [x.size(2) * x.size(3) for x in features]

        losses = {}
        detections: List[Dict[str, Tensor]] = []
        if self.training:
            if targets is None:
                torch._assert(False, "targets should not be none when in training mode")
            else:
                # compute the losses
                losses = self.compute_loss(targets, head_outputs, anchors, num_anchors_per_level)
        else:
            # split outputs per level
            split_head_outputs: Dict[str, List[Tensor]] = {}
            for k in head_outputs:
                split_head_outputs[k] = list(head_outputs[k].split(num_anchors_per_level, dim=1))
            split_anchors = [list(a.split(num_anchors_per_level)) for a in anchors]

            # compute the detections
            detections = self.postprocess_detections(split_head_outputs, split_anchors, images.image_sizes)
            detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)

        if torch.jit.is_scripting():
            if not self._has_warned:
                warnings.warn("FCOS always returns a (Losses, Detections) tuple in scripting")
                self._has_warned = True
            return losses, detections
        return self.eager_outputs(losses, detections)


class FCOS_ResNet50_FPN_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fcos_resnet50_fpn_coco-99b0c9b7.pth",
        transforms=ObjectDetection,
        meta={
            "num_params": 32269600,
            "categories": _COCO_CATEGORIES,
            "min_size": (1, 1),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#fcos-resnet-50-fpn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 39.2,
                }
            },
            "_ops": 128.207,
            "_file_size": 123.608,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1


@register_model()
@handle_legacy_interface(
    weights=("pretrained", FCOS_ResNet50_FPN_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def fcos_resnet50_fpn(
    *,
    weights: Optional[FCOS_ResNet50_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> FCOS:
    """
    Constructs a FCOS model with a ResNet-50-FPN backbone.

    .. betastatus:: detection module

    Reference: `FCOS: Fully Convolutional One-Stage Object Detection <https://arxiv.org/abs/1904.01355>`_.
    `FCOS: A simple and strong anchor-free object detector <https://arxiv.org/abs/2006.09214>`_.

    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
    image, and should be in ``0-1`` range. Different images can have different sizes.

    The behavior of the model changes depending on if it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (list of dictionary),
    containing:

        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (``Int64Tensor[N]``): the class label for each ground-truth box

    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
    losses.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
    follows, where ``N`` is the number of detections:

        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (``Int64Tensor[N]``): the predicted labels for each detection
        - scores (``Tensor[N]``): the scores of each detection

    For more details on the output, you may refer to :ref:`instance_seg_output`.

    Example:

        >>> model = torchvision.models.detection.fcos_resnet50_fpn(weights=FCOS_ResNet50_FPN_Weights.DEFAULT)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Args:
        weights (:class:`~torchvision.models.detection.FCOS_ResNet50_FPN_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.detection.FCOS_ResNet50_FPN_Weights`
            below for more details, and possible values. By default, no
            pre-trained weights are used.
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for
            the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) resnet layers starting
            from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
            trainable. If ``None`` is passed (the default) this value is set to 3. Default: None
        **kwargs: parameters passed to the ``torchvision.models.detection.FCOS``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/fcos.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.detection.FCOS_ResNet50_FPN_Weights
        :members:
    """
    weights = FCOS_ResNet50_FPN_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    backbone = _resnet_fpn_extractor(
        backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)
    )
    model = FCOS(backbone, num_classes, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model
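``FCOS.forward`` above returns a loss dict in training mode and post-processed detections in eval mode. A minimal sketch of both paths through ``fcos_resnet50_fpn`` (the single box and label below are placeholder values, not part of this file):

import torch
from torchvision.models.detection import fcos_resnet50_fpn, FCOS_ResNet50_FPN_Weights

model = fcos_resnet50_fpn(weights=FCOS_ResNet50_FPN_Weights.DEFAULT)

images = [torch.rand(3, 480, 640)]
targets = [{
    "boxes": torch.tensor([[50.0, 60.0, 200.0, 220.0]]),  # [x1, y1, x2, y2]
    "labels": torch.tensor([1]),
}]

# Training path: dict with 'classification', 'bbox_regression', 'bbox_ctrness' losses.
model.train()
losses = model(images, targets)

# Inference path: one dict per image with 'boxes', 'scores', 'labels'.
model.eval()
with torch.no_grad():
    detections = model(images)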
.venv/lib/python3.11/site-packages/torchvision/models/detection/generalized_rcnn.py
ADDED
@@ -0,0 +1,118 @@
"""
Implements the Generalized R-CNN framework
"""

import warnings
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union

import torch
from torch import nn, Tensor

from ...utils import _log_api_usage_once


class GeneralizedRCNN(nn.Module):
    """
    Main class for Generalized R-CNN.

    Args:
        backbone (nn.Module):
        rpn (nn.Module):
        roi_heads (nn.Module): takes the features + the proposals from the RPN and computes
            detections / masks from it.
        transform (nn.Module): performs the data transformation from the inputs to feed into
            the model
    """

    def __init__(self, backbone: nn.Module, rpn: nn.Module, roi_heads: nn.Module, transform: nn.Module) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.transform = transform
        self.backbone = backbone
        self.rpn = rpn
        self.roi_heads = roi_heads
        # used only on torchscript mode
        self._has_warned = False

    @torch.jit.unused
    def eager_outputs(self, losses, detections):
        # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]
        if self.training:
            return losses

        return detections

    def forward(self, images, targets=None):
        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
        """
        Args:
            images (list[Tensor]): images to be processed
            targets (list[Dict[str, Tensor]]): ground-truth boxes present in the image (optional)

        Returns:
            result (list[BoxList] or dict[Tensor]): the output from the model.
                During training, it returns a dict[Tensor] which contains the losses.
                During testing, it returns list[BoxList] contains additional fields
                like `scores`, `labels` and `mask` (for Mask R-CNN models).

        """
        if self.training:
            if targets is None:
                torch._assert(False, "targets should not be none when in training mode")
            else:
                for target in targets:
                    boxes = target["boxes"]
                    if isinstance(boxes, torch.Tensor):
                        torch._assert(
                            len(boxes.shape) == 2 and boxes.shape[-1] == 4,
                            f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.",
                        )
                    else:
                        torch._assert(False, f"Expected target boxes to be of type Tensor, got {type(boxes)}.")

        original_image_sizes: List[Tuple[int, int]] = []
        for img in images:
            val = img.shape[-2:]
            torch._assert(
                len(val) == 2,
                f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
            )
            original_image_sizes.append((val[0], val[1]))

        images, targets = self.transform(images, targets)

        # Check for degenerate boxes
        # TODO: Move this to a function
        if targets is not None:
            for target_idx, target in enumerate(targets):
                boxes = target["boxes"]
                degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
                if degenerate_boxes.any():
                    # print the first degenerate box
                    bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                    degen_bb: List[float] = boxes[bb_idx].tolist()
                    torch._assert(
                        False,
                        "All bounding boxes should have positive height and width."
                        f" Found invalid box {degen_bb} for target at index {target_idx}.",
                    )

        features = self.backbone(images.tensors)
        if isinstance(features, torch.Tensor):
            features = OrderedDict([("0", features)])
        proposals, proposal_losses = self.rpn(images, features, targets)
        detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
        detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)  # type: ignore[operator]

        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)

        if torch.jit.is_scripting():
            if not self._has_warned:
                warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting")
                self._has_warned = True
            return losses, detections
        else:
            return self.eager_outputs(losses, detections)
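In eager mode ``GeneralizedRCNN.forward`` returns only losses during training and only detections during evaluation, while a scripted model always returns the ``(losses, detections)`` tuple it warns about. A short sketch of that difference, using Faster R-CNN (a subclass of this class) as the concrete model:

import torch
from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights

model = fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT).eval()
x = [torch.rand(3, 300, 400)]

detections = model(x)               # eager eval mode: List[Dict[str, Tensor]]

scripted = torch.jit.script(model)  # TorchScript path
losses, detections = scripted(x)    # scripting: always a (losses, detections) tuple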
.venv/lib/python3.11/site-packages/torchvision/models/detection/image_list.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import List, Tuple

import torch
from torch import Tensor


class ImageList:
    """
    Structure that holds a list of images (of possibly
    varying sizes) as a single tensor.
    This works by padding the images to the same size,
    and storing in a field the original sizes of each image

    Args:
        tensors (tensor): Tensor containing images.
        image_sizes (list[tuple[int, int]]): List of Tuples each containing size of images.
    """

    def __init__(self, tensors: Tensor, image_sizes: List[Tuple[int, int]]) -> None:
        self.tensors = tensors
        self.image_sizes = image_sizes

    def to(self, device: torch.device) -> "ImageList":
        cast_tensor = self.tensors.to(device)
        return ImageList(cast_tensor, self.image_sizes)
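ImageList is the container that the detection transform hands to the rest of the pipeline: one padded batch tensor plus the original per-image sizes. A rough sketch of building one by hand (two arbitrarily sized images; not taken from the diff):

import torch
from torchvision.models.detection.image_list import ImageList

imgs = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
sizes = [tuple(img.shape[-2:]) for img in imgs]            # original (H, W) of each image

max_h = max(h for h, _ in sizes)
max_w = max(w for _, w in sizes)
batched = torch.zeros(len(imgs), 3, max_h, max_w)
for i, img in enumerate(imgs):
    batched[i, :, : img.shape[-2], : img.shape[-1]] = img  # zero-pad to the common size

image_list = ImageList(batched, sizes)
print(image_list.tensors.shape)   # torch.Size([2, 3, 500, 400])
print(image_list.image_sizes)     # [(300, 400), (500, 400)]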
.venv/lib/python3.11/site-packages/torchvision/models/detection/keypoint_rcnn.py
ADDED
@@ -0,0 +1,474 @@
from typing import Any, Optional

import torch
from torch import nn
from torchvision.ops import MultiScaleRoIAlign

from ...ops import misc as misc_nn_ops
from ...transforms._presets import ObjectDetection
from .._api import register_model, Weights, WeightsEnum
from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
from .._utils import _ovewrite_value_param, handle_legacy_interface
from ..resnet import resnet50, ResNet50_Weights
from ._utils import overwrite_eps
from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
from .faster_rcnn import FasterRCNN


__all__ = [
    "KeypointRCNN",
    "KeypointRCNN_ResNet50_FPN_Weights",
    "keypointrcnn_resnet50_fpn",
]


class KeypointRCNN(FasterRCNN):
    """
    Implements Keypoint R-CNN.

    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
    image, and should be in 0-1 range. Different images can have different sizes.

    The behavior of the model changes depending on if it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (list of dictionary),
    containing:

        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the class label for each ground-truth box
        - keypoints (FloatTensor[N, K, 3]): the K keypoints location for each of the N instances, in the
          format [x, y, visibility], where visibility=0 means that the keypoint is not visible.

    The model returns a Dict[Tensor] during training, containing the classification and regression
    losses for both the RPN and the R-CNN, and the keypoint loss.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
    follows:

        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the predicted labels for each image
        - scores (Tensor[N]): the scores of each prediction
        - keypoints (FloatTensor[N, K, 3]): the locations of the predicted keypoints, in [x, y, v] format.

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            It should contain an out_channels attribute, which indicates the number of output
            channels that each feature map has (and it should be the same for all feature maps).
            The backbone should return a single Tensor or an OrderedDict[Tensor].
        num_classes (int): number of output classes of the model (including the background).
            If box_predictor is specified, num_classes should be None.
        min_size (int): Images are rescaled before feeding them to the backbone:
            we attempt to preserve the aspect ratio and scale the shorter edge
            to ``min_size``. If the resulting longer edge exceeds ``max_size``,
            then downscale so that the longer edge does not exceed ``max_size``.
            This may result in the shorter edge being lower than ``min_size``.
        max_size (int): See ``min_size``.
        image_mean (Tuple[float, float, float]): mean values used for input normalization.
            They are generally the mean values of the dataset on which the backbone has been trained
            on
        image_std (Tuple[float, float, float]): std values used for input normalization.
            They are generally the std values of the dataset on which the backbone has been trained on
        rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps.
        rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
        rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
        rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
        rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
        rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
        rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
        rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training of the RPN.
        rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training of the RPN.
        rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
            for computing the loss
        rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
            of the RPN
        rpn_score_thresh (float): only return proposals with an objectness score greater than rpn_score_thresh
        box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
            the locations indicated by the bounding boxes
        box_head (nn.Module): module that takes the cropped feature maps as input
        box_predictor (nn.Module): module that takes the output of box_head and returns the
            classification logits and box regression deltas.
        box_score_thresh (float): during inference, only return proposals with a classification score
            greater than box_score_thresh
        box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
        box_detections_per_img (int): maximum number of detections per image, for all classes.
        box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
            considered as positive during training of the classification head
        box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
            considered as negative during training of the classification head
        box_batch_size_per_image (int): number of proposals that are sampled during training of the
            classification head
        box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
            of the classification head
        bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
            bounding boxes
        keypoint_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
            the locations indicated by the bounding boxes, which will be used for the keypoint head.
        keypoint_head (nn.Module): module that takes the cropped feature maps as input
        keypoint_predictor (nn.Module): module that takes the output of the keypoint_head and returns the
            heatmap logits

    Example::

        >>> import torch
        >>> import torchvision
        >>> from torchvision.models.detection import KeypointRCNN
        >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
        >>>
        >>> # load a pre-trained model for classification and return
        >>> # only the features
        >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
        >>> # KeypointRCNN needs to know the number of
        >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
        >>> # so we need to add it here
        >>> backbone.out_channels = 1280
        >>>
        >>> # let's make the RPN generate 5 x 3 anchors per spatial
        >>> # location, with 5 different sizes and 3 different aspect
        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
        >>> # map could potentially have different sizes and
        >>> # aspect ratios
        >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
        >>>                                    aspect_ratios=((0.5, 1.0, 2.0),))
        >>>
        >>> # let's define what are the feature maps that we will
        >>> # use to perform the region of interest cropping, as well as
        >>> # the size of the crop after rescaling.
        >>> # if your backbone returns a Tensor, featmap_names is expected to
        >>> # be ['0']. More generally, the backbone should return an
        >>> # OrderedDict[Tensor], and in featmap_names you can choose which
        >>> # feature maps to use.
        >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
        >>>                                                 output_size=7,
        >>>                                                 sampling_ratio=2)
        >>>
        >>> keypoint_roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
        >>>                                                          output_size=14,
        >>>                                                          sampling_ratio=2)
        >>> # put the pieces together inside a KeypointRCNN model
        >>> model = KeypointRCNN(backbone,
        >>>                      num_classes=2,
        >>>                      rpn_anchor_generator=anchor_generator,
        >>>                      box_roi_pool=roi_pooler,
        >>>                      keypoint_roi_pool=keypoint_roi_pooler)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
    """

    def __init__(
        self,
        backbone,
        num_classes=None,
        # transform parameters
        min_size=None,
        max_size=1333,
        image_mean=None,
        image_std=None,
        # RPN parameters
        rpn_anchor_generator=None,
        rpn_head=None,
        rpn_pre_nms_top_n_train=2000,
        rpn_pre_nms_top_n_test=1000,
        rpn_post_nms_top_n_train=2000,
        rpn_post_nms_top_n_test=1000,
        rpn_nms_thresh=0.7,
        rpn_fg_iou_thresh=0.7,
        rpn_bg_iou_thresh=0.3,
        rpn_batch_size_per_image=256,
        rpn_positive_fraction=0.5,
        rpn_score_thresh=0.0,
        # Box parameters
        box_roi_pool=None,
        box_head=None,
        box_predictor=None,
        box_score_thresh=0.05,
        box_nms_thresh=0.5,
        box_detections_per_img=100,
        box_fg_iou_thresh=0.5,
        box_bg_iou_thresh=0.5,
        box_batch_size_per_image=512,
        box_positive_fraction=0.25,
        bbox_reg_weights=None,
        # keypoint parameters
        keypoint_roi_pool=None,
        keypoint_head=None,
        keypoint_predictor=None,
        num_keypoints=None,
        **kwargs,
    ):

        if not isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None))):
            raise TypeError(
                f"keypoint_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(keypoint_roi_pool)}"
            )
        if min_size is None:
            min_size = (640, 672, 704, 736, 768, 800)

        if num_keypoints is not None:
            if keypoint_predictor is not None:
                raise ValueError("num_keypoints should be None when keypoint_predictor is specified")
        else:
            num_keypoints = 17

        out_channels = backbone.out_channels

        if keypoint_roi_pool is None:
            keypoint_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2)

        if keypoint_head is None:
            keypoint_layers = tuple(512 for _ in range(8))
            keypoint_head = KeypointRCNNHeads(out_channels, keypoint_layers)

        if keypoint_predictor is None:
            keypoint_dim_reduced = 512  # == keypoint_layers[-1]
            keypoint_predictor = KeypointRCNNPredictor(keypoint_dim_reduced, num_keypoints)

        super().__init__(
            backbone,
            num_classes,
            # transform parameters
            min_size,
            max_size,
            image_mean,
            image_std,
            # RPN-specific parameters
            rpn_anchor_generator,
            rpn_head,
            rpn_pre_nms_top_n_train,
            rpn_pre_nms_top_n_test,
            rpn_post_nms_top_n_train,
            rpn_post_nms_top_n_test,
            rpn_nms_thresh,
            rpn_fg_iou_thresh,
            rpn_bg_iou_thresh,
            rpn_batch_size_per_image,
            rpn_positive_fraction,
            rpn_score_thresh,
            # Box parameters
            box_roi_pool,
            box_head,
            box_predictor,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
            **kwargs,
        )

        self.roi_heads.keypoint_roi_pool = keypoint_roi_pool
        self.roi_heads.keypoint_head = keypoint_head
        self.roi_heads.keypoint_predictor = keypoint_predictor


class KeypointRCNNHeads(nn.Sequential):
    def __init__(self, in_channels, layers):
        d = []
        next_feature = in_channels
        for out_channels in layers:
            d.append(nn.Conv2d(next_feature, out_channels, 3, stride=1, padding=1))
            d.append(nn.ReLU(inplace=True))
            next_feature = out_channels
        super().__init__(*d)
        for m in self.children():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                nn.init.constant_(m.bias, 0)


class KeypointRCNNPredictor(nn.Module):
    def __init__(self, in_channels, num_keypoints):
        super().__init__()
        input_features = in_channels
        deconv_kernel = 4
        self.kps_score_lowres = nn.ConvTranspose2d(
            input_features,
            num_keypoints,
            deconv_kernel,
            stride=2,
            padding=deconv_kernel // 2 - 1,
        )
        nn.init.kaiming_normal_(self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu")
        nn.init.constant_(self.kps_score_lowres.bias, 0)
        self.up_scale = 2
        self.out_channels = num_keypoints

    def forward(self, x):
        x = self.kps_score_lowres(x)
        return torch.nn.functional.interpolate(
            x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False, recompute_scale_factor=False
        )


_COMMON_META = {
    "categories": _COCO_PERSON_CATEGORIES,
    "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES,
    "min_size": (1, 1),
}


class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
    COCO_LEGACY = Weights(
        url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 59137258,
            "recipe": "https://github.com/pytorch/vision/issues/1606",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 50.6,
                    "kp_map": 61.1,
                }
            },
            "_ops": 133.924,
            "_file_size": 226.054,
            "_docs": """
                These weights were produced by following a similar training recipe as on the paper but use a checkpoint
                from an early epoch.
            """,
        },
    )
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 59137258,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#keypoint-r-cnn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 54.6,
                    "kp_map": 65.0,
                }
            },
            "_ops": 137.42,
            "_file_size": 226.054,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1


@register_model()
@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: KeypointRCNN_ResNet50_FPN_Weights.COCO_LEGACY
        if kwargs["pretrained"] == "legacy"
        else KeypointRCNN_ResNet50_FPN_Weights.COCO_V1,
    ),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def keypointrcnn_resnet50_fpn(
    *,
    weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    num_keypoints: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> KeypointRCNN:
    """
    Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone.

    .. betastatus:: detection module

    Reference: `Mask R-CNN <https://arxiv.org/abs/1703.06870>`__.

    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
    image, and should be in ``0-1`` range. Different images can have different sizes.

    The behavior of the model changes depending on if it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (list of dictionary),
    containing:

    - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
      ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
    - labels (``Int64Tensor[N]``): the class label for each ground-truth box
    - keypoints (``FloatTensor[N, K, 3]``): the ``K`` keypoints location for each of the ``N`` instances, in the
      format ``[x, y, visibility]``, where ``visibility=0`` means that the keypoint is not visible.

    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
    losses for both the RPN and the R-CNN, and the keypoint loss.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
    follows, where ``N`` is the number of detected instances:

    - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
      ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
    - labels (``Int64Tensor[N]``): the predicted labels for each instance
    - scores (``Tensor[N]``): the scores of each instance
    - keypoints (``FloatTensor[N, K, 3]``): the locations of the predicted keypoints, in ``[x, y, v]`` format.

    For more details on the output, you may refer to :ref:`instance_seg_output`.

    Keypoint R-CNN is exportable to ONNX for a fixed batch size with inputs images of fixed size.

    Example::

        >>> model = torchvision.models.detection.keypointrcnn_resnet50_fpn(weights=KeypointRCNN_ResNet50_FPN_Weights.DEFAULT)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
        >>>
        >>> # optionally, if you want to export the model to ONNX:
        >>> torch.onnx.export(model, x, "keypoint_rcnn.onnx", opset_version = 11)

    Args:
        weights (:class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`
            below for more details, and possible values. By default, no
            pre-trained weights are used.
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int, optional): number of output classes of the model (including the background)
        num_keypoints (int, optional): number of keypoints
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
            pretrained weights for the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
            passed (the default) this value is set to 3.

    .. autoclass:: torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights
        :members:
    """
    weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
        num_keypoints = _ovewrite_value_param("num_keypoints", num_keypoints, len(weights.meta["keypoint_names"]))
    else:
        if num_classes is None:
            num_classes = 2
        if num_keypoints is None:
            num_keypoints = 17

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
    model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
        if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1:
            overwrite_eps(model, 0.0)

    return model
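A brief, hedged usage sketch for the keypoint output format documented above. It uses an untrained model and a random input, so the numbers are meaningless; only the shapes and the [x, y, visibility] layout are the point.

import torch
from torchvision.models.detection import keypointrcnn_resnet50_fpn

# weights=None / weights_backbone=None avoids any download; 2 classes and 17 COCO keypoints
# mirror the defaults described in the docstring above.
model = keypointrcnn_resnet50_fpn(weights=None, weights_backbone=None,
                                  num_classes=2, num_keypoints=17).eval()

with torch.no_grad():
    pred = model([torch.rand(3, 480, 640)])[0]

print(pred["keypoints"].shape)     # [N, 17, 3]: each row is (x, y, visibility)
print(pred["scores"].shape)        # [N]: confidence per detected instance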
.venv/lib/python3.11/site-packages/torchvision/models/detection/mask_rcnn.py
ADDED
@@ -0,0 +1,590 @@
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
from typing import Any, Callable, Optional
|
| 3 |
+
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torchvision.ops import MultiScaleRoIAlign
|
| 6 |
+
|
| 7 |
+
from ...ops import misc as misc_nn_ops
|
| 8 |
+
from ...transforms._presets import ObjectDetection
|
| 9 |
+
from .._api import register_model, Weights, WeightsEnum
|
| 10 |
+
from .._meta import _COCO_CATEGORIES
|
| 11 |
+
from .._utils import _ovewrite_value_param, handle_legacy_interface
|
| 12 |
+
from ..resnet import resnet50, ResNet50_Weights
|
| 13 |
+
from ._utils import overwrite_eps
|
| 14 |
+
from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
|
| 15 |
+
from .faster_rcnn import _default_anchorgen, FasterRCNN, FastRCNNConvFCHead, RPNHead
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
__all__ = [
|
| 19 |
+
"MaskRCNN",
|
| 20 |
+
"MaskRCNN_ResNet50_FPN_Weights",
|
| 21 |
+
"MaskRCNN_ResNet50_FPN_V2_Weights",
|
| 22 |
+
"maskrcnn_resnet50_fpn",
|
| 23 |
+
"maskrcnn_resnet50_fpn_v2",
|
| 24 |
+
]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class MaskRCNN(FasterRCNN):
|
| 28 |
+
"""
|
| 29 |
+
Implements Mask R-CNN.
|
| 30 |
+
|
| 31 |
+
The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
|
| 32 |
+
image, and should be in 0-1 range. Different images can have different sizes.
|
| 33 |
+
|
| 34 |
+
The behavior of the model changes depending on if it is in training or evaluation mode.
|
| 35 |
+
|
| 36 |
+
During training, the model expects both the input tensors and targets (list of dictionary),
|
| 37 |
+
containing:
|
| 38 |
+
- boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
|
| 39 |
+
``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
|
| 40 |
+
- labels (Int64Tensor[N]): the class label for each ground-truth box
|
| 41 |
+
- masks (UInt8Tensor[N, H, W]): the segmentation binary masks for each instance
|
| 42 |
+
|
| 43 |
+
The model returns a Dict[Tensor] during training, containing the classification and regression
|
| 44 |
+
losses for both the RPN and the R-CNN, and the mask loss.
|
| 45 |
+
|
| 46 |
+
During inference, the model requires only the input tensors, and returns the post-processed
|
| 47 |
+
predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
|
| 48 |
+
follows:
|
| 49 |
+
- boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
|
| 50 |
+
``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
|
| 51 |
+
- labels (Int64Tensor[N]): the predicted labels for each image
|
| 52 |
+
- scores (Tensor[N]): the scores or each prediction
|
| 53 |
+
- masks (UInt8Tensor[N, 1, H, W]): the predicted masks for each instance, in 0-1 range. In order to
|
| 54 |
+
obtain the final segmentation masks, the soft masks can be thresholded, generally
|
| 55 |
+
with a value of 0.5 (mask >= 0.5)
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
backbone (nn.Module): the network used to compute the features for the model.
|
| 59 |
+
It should contain an out_channels attribute, which indicates the number of output
|
| 60 |
+
channels that each feature map has (and it should be the same for all feature maps).
|
| 61 |
+
The backbone should return a single Tensor or and OrderedDict[Tensor].
|
| 62 |
+
num_classes (int): number of output classes of the model (including the background).
|
| 63 |
+
If box_predictor is specified, num_classes should be None.
|
| 64 |
+
min_size (int): Images are rescaled before feeding them to the backbone:
|
| 65 |
+
we attempt to preserve the aspect ratio and scale the shorter edge
|
| 66 |
+
to ``min_size``. If the resulting longer edge exceeds ``max_size``,
|
| 67 |
+
then downscale so that the longer edge does not exceed ``max_size``.
|
| 68 |
+
This may result in the shorter edge beeing lower than ``min_size``.
|
| 69 |
+
max_size (int): See ``min_size``.
|
| 70 |
+
image_mean (Tuple[float, float, float]): mean values used for input normalization.
|
| 71 |
+
They are generally the mean values of the dataset on which the backbone has been trained
|
| 72 |
+
on
|
| 73 |
+
image_std (Tuple[float, float, float]): std values used for input normalization.
|
| 74 |
+
They are generally the std values of the dataset on which the backbone has been trained on
|
| 75 |
+
rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
|
| 76 |
+
maps.
|
| 77 |
+
rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
|
| 78 |
+
rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
|
| 79 |
+
rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
|
| 80 |
+
rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
|
| 81 |
+
rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
|
| 82 |
+
rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
|
| 83 |
+
rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
|
| 84 |
+
considered as positive during training of the RPN.
|
| 85 |
+
rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
|
| 86 |
+
considered as negative during training of the RPN.
|
| 87 |
+
rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
|
| 88 |
+
for computing the loss
|
| 89 |
+
rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
|
| 90 |
+
of the RPN
|
| 91 |
+
rpn_score_thresh (float): only return proposals with an objectness score greater than rpn_score_thresh
|
| 92 |
+
box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
|
| 93 |
+
the locations indicated by the bounding boxes
|
| 94 |
+
box_head (nn.Module): module that takes the cropped feature maps as input
|
| 95 |
+
box_predictor (nn.Module): module that takes the output of box_head and returns the
|
| 96 |
+
classification logits and box regression deltas.
|
| 97 |
+
box_score_thresh (float): during inference, only return proposals with a classification score
|
| 98 |
+
greater than box_score_thresh
|
| 99 |
+
box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
|
| 100 |
+
box_detections_per_img (int): maximum number of detections per image, for all classes.
|
| 101 |
+
box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
|
| 102 |
+
considered as positive during training of the classification head
|
| 103 |
+
box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
|
| 104 |
+
considered as negative during training of the classification head
|
| 105 |
+
box_batch_size_per_image (int): number of proposals that are sampled during training of the
|
| 106 |
+
classification head
|
| 107 |
+
box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
|
| 108 |
+
of the classification head
|
| 109 |
+
bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
|
| 110 |
+
bounding boxes
|
| 111 |
+
mask_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
|
| 112 |
+
the locations indicated by the bounding boxes, which will be used for the mask head.
|
| 113 |
+
mask_head (nn.Module): module that takes the cropped feature maps as input
|
| 114 |
+
mask_predictor (nn.Module): module that takes the output of the mask_head and returns the
|
| 115 |
+
segmentation mask logits
|
| 116 |
+
|
| 117 |
+
Example::
|
| 118 |
+
|
| 119 |
+
>>> import torch
|
| 120 |
+
>>> import torchvision
|
| 121 |
+
>>> from torchvision.models.detection import MaskRCNN
|
| 122 |
+
>>> from torchvision.models.detection.anchor_utils import AnchorGenerator
|
| 123 |
+
>>>
|
| 124 |
+
>>> # load a pre-trained model for classification and return
|
| 125 |
+
>>> # only the features
|
| 126 |
+
>>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
|
| 127 |
+
>>> # MaskRCNN needs to know the number of
|
| 128 |
+
>>> # output channels in a backbone. For mobilenet_v2, it's 1280
|
| 129 |
+
>>> # so we need to add it here,
|
| 130 |
+
>>> backbone.out_channels = 1280
|
| 131 |
+
>>>
|
| 132 |
+
>>> # let's make the RPN generate 5 x 3 anchors per spatial
|
| 133 |
+
>>> # location, with 5 different sizes and 3 different aspect
|
| 134 |
+
>>> # ratios. We have a Tuple[Tuple[int]] because each feature
|
| 135 |
+
>>> # map could potentially have different sizes and
|
| 136 |
+
>>> # aspect ratios
|
| 137 |
+
>>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
|
| 138 |
+
>>> aspect_ratios=((0.5, 1.0, 2.0),))
|
| 139 |
+
>>>
|
| 140 |
+
>>> # let's define what are the feature maps that we will
|
| 141 |
+
>>> # use to perform the region of interest cropping, as well as
|
| 142 |
+
>>> # the size of the crop after rescaling.
|
| 143 |
+
>>> # if your backbone returns a Tensor, featmap_names is expected to
|
| 144 |
+
>>> # be ['0']. More generally, the backbone should return an
|
| 145 |
+
>>> # OrderedDict[Tensor], and in featmap_names you can choose which
|
| 146 |
+
>>> # feature maps to use.
|
| 147 |
+
>>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
|
| 148 |
+
>>> output_size=7,
|
| 149 |
+
>>> sampling_ratio=2)
|
| 150 |
+
>>>
|
| 151 |
+
>>> mask_roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
|
| 152 |
+
>>> output_size=14,
|
| 153 |
+
>>> sampling_ratio=2)
|
| 154 |
+
>>> # put the pieces together inside a MaskRCNN model
|
| 155 |
+
>>> model = MaskRCNN(backbone,
|
| 156 |
+
>>> num_classes=2,
|
| 157 |
+
>>> rpn_anchor_generator=anchor_generator,
|
| 158 |
+
>>> box_roi_pool=roi_pooler,
|
| 159 |
+
>>> mask_roi_pool=mask_roi_pooler)
|
| 160 |
+
>>> model.eval()
|
| 161 |
+
>>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
|
| 162 |
+
>>> predictions = model(x)
|
| 163 |
+
"""
|
| 164 |
+
|
| 165 |
+
def __init__(
|
| 166 |
+
self,
|
| 167 |
+
backbone,
|
| 168 |
+
num_classes=None,
|
| 169 |
+
# transform parameters
|
| 170 |
+
min_size=800,
|
| 171 |
+
max_size=1333,
|
| 172 |
+
image_mean=None,
|
| 173 |
+
image_std=None,
|
| 174 |
+
# RPN parameters
|
| 175 |
+
rpn_anchor_generator=None,
|
| 176 |
+
rpn_head=None,
|
| 177 |
+
rpn_pre_nms_top_n_train=2000,
|
| 178 |
+
rpn_pre_nms_top_n_test=1000,
|
| 179 |
+
rpn_post_nms_top_n_train=2000,
|
| 180 |
+
rpn_post_nms_top_n_test=1000,
|
| 181 |
+
rpn_nms_thresh=0.7,
|
| 182 |
+
rpn_fg_iou_thresh=0.7,
|
| 183 |
+
rpn_bg_iou_thresh=0.3,
|
| 184 |
+
rpn_batch_size_per_image=256,
|
| 185 |
+
rpn_positive_fraction=0.5,
|
| 186 |
+
rpn_score_thresh=0.0,
|
| 187 |
+
# Box parameters
|
| 188 |
+
box_roi_pool=None,
|
| 189 |
+
box_head=None,
|
| 190 |
+
box_predictor=None,
|
| 191 |
+
box_score_thresh=0.05,
|
| 192 |
+
box_nms_thresh=0.5,
|
| 193 |
+
box_detections_per_img=100,
|
| 194 |
+
box_fg_iou_thresh=0.5,
|
| 195 |
+
box_bg_iou_thresh=0.5,
|
| 196 |
+
box_batch_size_per_image=512,
|
| 197 |
+
box_positive_fraction=0.25,
|
| 198 |
+
bbox_reg_weights=None,
|
| 199 |
+
# Mask parameters
|
| 200 |
+
mask_roi_pool=None,
|
| 201 |
+
mask_head=None,
|
| 202 |
+
mask_predictor=None,
|
| 203 |
+
**kwargs,
|
| 204 |
+
):
|
| 205 |
+
|
| 206 |
+
if not isinstance(mask_roi_pool, (MultiScaleRoIAlign, type(None))):
|
| 207 |
+
raise TypeError(
|
| 208 |
+
f"mask_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(mask_roi_pool)}"
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
if num_classes is not None:
|
| 212 |
+
if mask_predictor is not None:
|
| 213 |
+
raise ValueError("num_classes should be None when mask_predictor is specified")
|
| 214 |
+
|
| 215 |
+
out_channels = backbone.out_channels
|
| 216 |
+
|
| 217 |
+
if mask_roi_pool is None:
|
| 218 |
+
mask_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2)
|
| 219 |
+
|
| 220 |
+
if mask_head is None:
|
| 221 |
+
mask_layers = (256, 256, 256, 256)
|
| 222 |
+
mask_dilation = 1
|
| 223 |
+
mask_head = MaskRCNNHeads(out_channels, mask_layers, mask_dilation)
|
| 224 |
+
|
| 225 |
+
if mask_predictor is None:
|
| 226 |
+
mask_predictor_in_channels = 256 # == mask_layers[-1]
|
| 227 |
+
mask_dim_reduced = 256
|
| 228 |
+
mask_predictor = MaskRCNNPredictor(mask_predictor_in_channels, mask_dim_reduced, num_classes)
|
| 229 |
+
|
| 230 |
+
super().__init__(
|
| 231 |
+
backbone,
|
| 232 |
+
num_classes,
|
| 233 |
+
# transform parameters
|
| 234 |
+
min_size,
|
| 235 |
+
max_size,
|
| 236 |
+
image_mean,
|
| 237 |
+
image_std,
|
| 238 |
+
# RPN-specific parameters
|
| 239 |
+
rpn_anchor_generator,
|
| 240 |
+
rpn_head,
|
| 241 |
+
rpn_pre_nms_top_n_train,
|
| 242 |
+
rpn_pre_nms_top_n_test,
|
| 243 |
+
rpn_post_nms_top_n_train,
|
| 244 |
+
rpn_post_nms_top_n_test,
|
| 245 |
+
rpn_nms_thresh,
|
| 246 |
+
rpn_fg_iou_thresh,
|
| 247 |
+
rpn_bg_iou_thresh,
|
| 248 |
+
rpn_batch_size_per_image,
|
| 249 |
+
rpn_positive_fraction,
|
| 250 |
+
rpn_score_thresh,
|
| 251 |
+
# Box parameters
|
| 252 |
+
box_roi_pool,
|
| 253 |
+
box_head,
|
| 254 |
+
box_predictor,
|
| 255 |
+
box_score_thresh,
|
| 256 |
+
box_nms_thresh,
|
| 257 |
+
box_detections_per_img,
|
| 258 |
+
box_fg_iou_thresh,
|
| 259 |
+
box_bg_iou_thresh,
|
| 260 |
+
box_batch_size_per_image,
|
| 261 |
+
box_positive_fraction,
|
| 262 |
+
bbox_reg_weights,
|
| 263 |
+
**kwargs,
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
self.roi_heads.mask_roi_pool = mask_roi_pool
|
| 267 |
+
self.roi_heads.mask_head = mask_head
|
| 268 |
+
self.roi_heads.mask_predictor = mask_predictor
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
class MaskRCNNHeads(nn.Sequential):
|
| 272 |
+
_version = 2
|
| 273 |
+
|
| 274 |
+
def __init__(self, in_channels, layers, dilation, norm_layer: Optional[Callable[..., nn.Module]] = None):
|
| 275 |
+
"""
|
| 276 |
+
Args:
|
| 277 |
+
in_channels (int): number of input channels
|
| 278 |
+
layers (list): feature dimensions of each FCN layer
|
| 279 |
+
dilation (int): dilation rate of kernel
|
| 280 |
+
norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
|
| 281 |
+
"""
|
| 282 |
+
blocks = []
|
| 283 |
+
next_feature = in_channels
|
| 284 |
+
for layer_features in layers:
|
| 285 |
+
blocks.append(
|
| 286 |
+
misc_nn_ops.Conv2dNormActivation(
|
| 287 |
+
next_feature,
|
| 288 |
+
layer_features,
|
| 289 |
+
kernel_size=3,
|
| 290 |
+
stride=1,
|
| 291 |
+
padding=dilation,
|
| 292 |
+
dilation=dilation,
|
| 293 |
+
norm_layer=norm_layer,
|
| 294 |
+
)
|
| 295 |
+
)
|
| 296 |
+
next_feature = layer_features
|
| 297 |
+
|
| 298 |
+
super().__init__(*blocks)
|
| 299 |
+
for layer in self.modules():
|
| 300 |
+
if isinstance(layer, nn.Conv2d):
|
| 301 |
+
nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")
|
| 302 |
+
if layer.bias is not None:
|
| 303 |
+
nn.init.zeros_(layer.bias)
|
| 304 |
+
|
| 305 |
+
def _load_from_state_dict(
|
| 306 |
+
self,
|
| 307 |
+
state_dict,
|
| 308 |
+
prefix,
|
| 309 |
+
local_metadata,
|
| 310 |
+
strict,
|
| 311 |
+
missing_keys,
|
| 312 |
+
unexpected_keys,
|
| 313 |
+
error_msgs,
|
| 314 |
+
):
|
| 315 |
+
version = local_metadata.get("version", None)
|
| 316 |
+
|
| 317 |
+
if version is None or version < 2:
|
| 318 |
+
num_blocks = len(self)
|
| 319 |
+
for i in range(num_blocks):
|
| 320 |
+
for type in ["weight", "bias"]:
|
| 321 |
+
old_key = f"{prefix}mask_fcn{i+1}.{type}"
|
| 322 |
+
new_key = f"{prefix}{i}.0.{type}"
|
| 323 |
+
if old_key in state_dict:
|
| 324 |
+
state_dict[new_key] = state_dict.pop(old_key)
|
| 325 |
+
|
| 326 |
+
super()._load_from_state_dict(
|
| 327 |
+
state_dict,
|
| 328 |
+
prefix,
|
| 329 |
+
local_metadata,
|
| 330 |
+
strict,
|
| 331 |
+
missing_keys,
|
| 332 |
+
unexpected_keys,
|
| 333 |
+
error_msgs,
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
class MaskRCNNPredictor(nn.Sequential):
|
| 338 |
+
def __init__(self, in_channels, dim_reduced, num_classes):
|
| 339 |
+
super().__init__(
|
| 340 |
+
OrderedDict(
|
| 341 |
+
[
|
| 342 |
+
("conv5_mask", nn.ConvTranspose2d(in_channels, dim_reduced, 2, 2, 0)),
|
| 343 |
+
("relu", nn.ReLU(inplace=True)),
|
| 344 |
+
("mask_fcn_logits", nn.Conv2d(dim_reduced, num_classes, 1, 1, 0)),
|
| 345 |
+
]
|
| 346 |
+
)
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
for name, param in self.named_parameters():
|
| 350 |
+
if "weight" in name:
|
| 351 |
+
nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu")
|
| 352 |
+
# elif "bias" in name:
|
| 353 |
+
# nn.init.constant_(param, 0)
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
_COMMON_META = {
|
| 357 |
+
"categories": _COCO_CATEGORIES,
|
| 358 |
+
"min_size": (1, 1),
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum):
|
| 363 |
+
COCO_V1 = Weights(
|
| 364 |
+
url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth",
|
| 365 |
+
transforms=ObjectDetection,
|
| 366 |
+
meta={
|
| 367 |
+
**_COMMON_META,
|
| 368 |
+
"num_params": 44401393,
|
| 369 |
+
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#mask-r-cnn",
|
| 370 |
+
"_metrics": {
|
| 371 |
+
"COCO-val2017": {
|
| 372 |
+
"box_map": 37.9,
|
| 373 |
+
"mask_map": 34.6,
|
| 374 |
+
}
|
| 375 |
+
},
|
| 376 |
+
"_ops": 134.38,
|
| 377 |
+
"_file_size": 169.84,
|
| 378 |
+
"_docs": """These weights were produced by following a similar training recipe as on the paper.""",
|
| 379 |
+
},
|
| 380 |
+
)
|
| 381 |
+
DEFAULT = COCO_V1
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
class MaskRCNN_ResNet50_FPN_V2_Weights(WeightsEnum):
|
| 385 |
+
COCO_V1 = Weights(
|
| 386 |
+
url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_v2_coco-73cbd019.pth",
|
| 387 |
+
transforms=ObjectDetection,
|
| 388 |
+
meta={
|
| 389 |
+
**_COMMON_META,
|
| 390 |
+
"num_params": 46359409,
|
| 391 |
+
"recipe": "https://github.com/pytorch/vision/pull/5773",
|
| 392 |
+
"_metrics": {
|
| 393 |
+
"COCO-val2017": {
|
| 394 |
+
"box_map": 47.4,
|
| 395 |
+
"mask_map": 41.8,
|
| 396 |
+
}
|
| 397 |
+
},
|
| 398 |
+
"_ops": 333.577,
|
| 399 |
+
"_file_size": 177.219,
|
| 400 |
+
"_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""",
|
| 401 |
+
},
|
| 402 |
+
)
|
| 403 |
+
DEFAULT = COCO_V1
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
@register_model()
|
| 407 |
+
@handle_legacy_interface(
|
| 408 |
+
weights=("pretrained", MaskRCNN_ResNet50_FPN_Weights.COCO_V1),
|
| 409 |
+
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
|
| 410 |
+
)
|
| 411 |
+
def maskrcnn_resnet50_fpn(
|
| 412 |
+
*,
|
| 413 |
+
weights: Optional[MaskRCNN_ResNet50_FPN_Weights] = None,
|
| 414 |
+
progress: bool = True,
|
| 415 |
+
num_classes: Optional[int] = None,
|
| 416 |
+
weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
|
| 417 |
+
trainable_backbone_layers: Optional[int] = None,
|
| 418 |
+
**kwargs: Any,
|
| 419 |
+
) -> MaskRCNN:
|
| 420 |
+
"""Mask R-CNN model with a ResNet-50-FPN backbone from the `Mask R-CNN
|
| 421 |
+
<https://arxiv.org/abs/1703.06870>`_ paper.
|
| 422 |
+
|
| 423 |
+
.. betastatus:: detection module
|
| 424 |
+
|
| 425 |
+
The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
|
| 426 |
+
image, and should be in ``0-1`` range. Different images can have different sizes.
|
| 427 |
+
|
| 428 |
+
The behavior of the model changes depending on if it is in training or evaluation mode.
|
| 429 |
+
|
| 430 |
+
During training, the model expects both the input tensors and targets (list of dictionary),
|
| 431 |
+
containing:
|
| 432 |
+
|
| 433 |
+
- boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
|
| 434 |
+
``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
|
| 435 |
+
- labels (``Int64Tensor[N]``): the class label for each ground-truth box
|
| 436 |
+
- masks (``UInt8Tensor[N, H, W]``): the segmentation binary masks for each instance
|
| 437 |
+
|
| 438 |
+
The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
|
| 439 |
+
losses for both the RPN and the R-CNN, and the mask loss.
|
| 440 |
+
|
| 441 |
+
During inference, the model requires only the input tensors, and returns the post-processed
|
| 442 |
+
predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
|
| 443 |
+
follows, where ``N`` is the number of detected instances:
|
| 444 |
+
|
| 445 |
+
- boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
|
| 446 |
+
``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
|
| 447 |
+
- labels (``Int64Tensor[N]``): the predicted labels for each instance
|
| 448 |
+
- scores (``Tensor[N]``): the scores or each instance
|
| 449 |
+
- masks (``UInt8Tensor[N, 1, H, W]``): the predicted masks for each instance, in ``0-1`` range. In order to
|
| 450 |
+
obtain the final segmentation masks, the soft masks can be thresholded, generally
|
| 451 |
+
with a value of 0.5 (``mask >= 0.5``)
|
| 452 |
+
|
| 453 |
+
For more details on the output and on how to plot the masks, you may refer to :ref:`instance_seg_output`.
|
| 454 |
+
|
| 455 |
+
Mask R-CNN is exportable to ONNX for a fixed batch size with inputs images of fixed size.
|
| 456 |
+
|
| 457 |
+
Example::
|
| 458 |
+
|
| 459 |
+
>>> model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT)
|
| 460 |
+
>>> model.eval()
|
| 461 |
+
>>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
|
| 462 |
+
>>> predictions = model(x)
|
| 463 |
+
>>>
|
| 464 |
+
>>> # optionally, if you want to export the model to ONNX:
|
| 465 |
+
>>> torch.onnx.export(model, x, "mask_rcnn.onnx", opset_version = 11)
|
| 466 |
+
|
| 467 |
+
Args:
|
| 468 |
+
weights (:class:`~torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights`, optional): The
|
| 469 |
+
pretrained weights to use. See
|
| 470 |
+
:class:`~torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights` below for
|
| 471 |
+
more details, and possible values. By default, no pre-trained
|
| 472 |
+
weights are used.
|
| 473 |
+
progress (bool, optional): If True, displays a progress bar of the
|
| 474 |
+
download to stderr. Default is True.
|
| 475 |
+
num_classes (int, optional): number of output classes of the model (including the background)
|
| 476 |
+
weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
|
| 477 |
+
pretrained weights for the backbone.
|
| 478 |
+
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
            final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
            trainable. If ``None`` is passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``torchvision.models.detection.mask_rcnn.MaskRCNN``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/mask_rcnn.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights
        :members:
    """
    weights = MaskRCNN_ResNet50_FPN_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
    model = MaskRCNN(backbone, num_classes=num_classes, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
        if weights == MaskRCNN_ResNet50_FPN_Weights.COCO_V1:
            overwrite_eps(model, 0.0)

    return model


@register_model()
@handle_legacy_interface(
    weights=("pretrained", MaskRCNN_ResNet50_FPN_V2_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def maskrcnn_resnet50_fpn_v2(
    *,
    weights: Optional[MaskRCNN_ResNet50_FPN_V2_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = None,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> MaskRCNN:
    """Improved Mask R-CNN model with a ResNet-50-FPN backbone from the `Benchmarking Detection Transfer
    Learning with Vision Transformers <https://arxiv.org/abs/2111.11429>`_ paper.

    .. betastatus:: detection module

    See :func:`~torchvision.models.detection.maskrcnn_resnet50_fpn` for more details.

    Args:
        weights (:class:`~torchvision.models.detection.MaskRCNN_ResNet50_FPN_V2_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.detection.MaskRCNN_ResNet50_FPN_V2_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
            pretrained weights for the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
            final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
            trainable. If ``None`` is passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``torchvision.models.detection.mask_rcnn.MaskRCNN``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/mask_rcnn.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.detection.MaskRCNN_ResNet50_FPN_V2_Weights
        :members:
    """
    weights = MaskRCNN_ResNet50_FPN_V2_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)

    backbone = resnet50(weights=weights_backbone, progress=progress)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers, norm_layer=nn.BatchNorm2d)
    rpn_anchor_generator = _default_anchorgen()
    rpn_head = RPNHead(backbone.out_channels, rpn_anchor_generator.num_anchors_per_location()[0], conv_depth=2)
    box_head = FastRCNNConvFCHead(
        (backbone.out_channels, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d
    )
    mask_head = MaskRCNNHeads(backbone.out_channels, [256, 256, 256, 256], 1, norm_layer=nn.BatchNorm2d)
    model = MaskRCNN(
        backbone,
        num_classes=num_classes,
        rpn_anchor_generator=rpn_anchor_generator,
        rpn_head=rpn_head,
        box_head=box_head,
        mask_head=mask_head,
        **kwargs,
    )

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model
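

# Illustrative usage sketch (a hypothetical `_demo_*` helper, not part of the upstream module):
# shows the two common ways the builder above is called -- loading the COCO-pretrained weights,
# and building a fine-tuning variant. `num_classes=3` is an assumed dataset-specific value
# (background + 2 classes), not a library default.
def _demo_maskrcnn_v2_usage():
    # Fully pretrained detector: num_classes is then forced to the 91 COCO categories.
    pretrained = maskrcnn_resnet50_fpn_v2(weights=MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT)
    pretrained.eval()

    # Fine-tuning variant: ImageNet backbone only, custom head size, 3 trainable backbone layers.
    finetune = maskrcnn_resnet50_fpn_v2(
        weights_backbone=ResNet50_Weights.IMAGENET1K_V1,
        num_classes=3,
        trainable_backbone_layers=3,
    )
    return pretrained, finetune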
.venv/lib/python3.11/site-packages/torchvision/models/detection/retinanet.py
ADDED
@@ -0,0 +1,903 @@
import math
import warnings
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from torch import nn, Tensor

from ...ops import boxes as box_ops, misc as misc_nn_ops, sigmoid_focal_loss
from ...ops.feature_pyramid_network import LastLevelP6P7
from ...transforms._presets import ObjectDetection
from ...utils import _log_api_usage_once
from .._api import register_model, Weights, WeightsEnum
from .._meta import _COCO_CATEGORIES
from .._utils import _ovewrite_value_param, handle_legacy_interface
from ..resnet import resnet50, ResNet50_Weights
from . import _utils as det_utils
from ._utils import _box_loss, overwrite_eps
from .anchor_utils import AnchorGenerator
from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
from .transform import GeneralizedRCNNTransform


__all__ = [
    "RetinaNet",
    "RetinaNet_ResNet50_FPN_Weights",
    "RetinaNet_ResNet50_FPN_V2_Weights",
    "retinanet_resnet50_fpn",
    "retinanet_resnet50_fpn_v2",
]


def _sum(x: List[Tensor]) -> Tensor:
    res = x[0]
    for i in x[1:]:
        res = res + i
    return res


def _v1_to_v2_weights(state_dict, prefix):
    for i in range(4):
        for type in ["weight", "bias"]:
            old_key = f"{prefix}conv.{2*i}.{type}"
            new_key = f"{prefix}conv.{i}.0.{type}"
            if old_key in state_dict:
                state_dict[new_key] = state_dict.pop(old_key)


def _default_anchorgen():
    anchor_sizes = tuple((x, int(x * 2 ** (1.0 / 3)), int(x * 2 ** (2.0 / 3))) for x in [32, 64, 128, 256, 512])
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
    anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
    return anchor_generator

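
# Illustrative sketch (a hypothetical `_demo_*` helper, not in the upstream file): the generator
# built above uses 3 octave scales per level (x, x * 2^(1/3), x * 2^(2/3)) and 3 aspect ratios,
# so it predicts 9 anchors per spatial location on each of the 5 pyramid levels.
def _demo_default_anchorgen_layout():
    anchor_generator = _default_anchorgen()
    # Expected: [9, 9, 9, 9, 9]
    return anchor_generator.num_anchors_per_location()

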
class RetinaNetHead(nn.Module):
    """
    A regression and classification head for use in RetinaNet.

    Args:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
        num_classes (int): number of classes to be predicted
        norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
    """

    def __init__(self, in_channels, num_anchors, num_classes, norm_layer: Optional[Callable[..., nn.Module]] = None):
        super().__init__()
        self.classification_head = RetinaNetClassificationHead(
            in_channels, num_anchors, num_classes, norm_layer=norm_layer
        )
        self.regression_head = RetinaNetRegressionHead(in_channels, num_anchors, norm_layer=norm_layer)

    def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Dict[str, Tensor]
        return {
            "classification": self.classification_head.compute_loss(targets, head_outputs, matched_idxs),
            "bbox_regression": self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs),
        }

    def forward(self, x):
        # type: (List[Tensor]) -> Dict[str, Tensor]
        return {"cls_logits": self.classification_head(x), "bbox_regression": self.regression_head(x)}

class RetinaNetClassificationHead(nn.Module):
    """
    A classification head for use in RetinaNet.

    Args:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
        num_classes (int): number of classes to be predicted
        norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
    """

    _version = 2

    def __init__(
        self,
        in_channels,
        num_anchors,
        num_classes,
        prior_probability=0.01,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ):
        super().__init__()

        conv = []
        for _ in range(4):
            conv.append(misc_nn_ops.Conv2dNormActivation(in_channels, in_channels, norm_layer=norm_layer))
        self.conv = nn.Sequential(*conv)

        for layer in self.conv.modules():
            if isinstance(layer, nn.Conv2d):
                torch.nn.init.normal_(layer.weight, std=0.01)
                if layer.bias is not None:
                    torch.nn.init.constant_(layer.bias, 0)

        self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
        torch.nn.init.normal_(self.cls_logits.weight, std=0.01)
        torch.nn.init.constant_(self.cls_logits.bias, -math.log((1 - prior_probability) / prior_probability))

        self.num_classes = num_classes
        self.num_anchors = num_anchors

        # This is to fix using det_utils.Matcher.BETWEEN_THRESHOLDS in TorchScript.
        # TorchScript doesn't support class attributes.
        # https://github.com/pytorch/vision/pull/1697#issuecomment-630255584
        self.BETWEEN_THRESHOLDS = det_utils.Matcher.BETWEEN_THRESHOLDS

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            _v1_to_v2_weights(state_dict, prefix)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def compute_loss(self, targets, head_outputs, matched_idxs):
        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Tensor
        losses = []

        cls_logits = head_outputs["cls_logits"]

        for targets_per_image, cls_logits_per_image, matched_idxs_per_image in zip(targets, cls_logits, matched_idxs):
            # determine only the foreground
            foreground_idxs_per_image = matched_idxs_per_image >= 0
            num_foreground = foreground_idxs_per_image.sum()

            # create the target classification
            gt_classes_target = torch.zeros_like(cls_logits_per_image)
            gt_classes_target[
                foreground_idxs_per_image,
                targets_per_image["labels"][matched_idxs_per_image[foreground_idxs_per_image]],
            ] = 1.0

            # find indices for which anchors should be ignored
            valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS

            # compute the classification loss
            losses.append(
                sigmoid_focal_loss(
                    cls_logits_per_image[valid_idxs_per_image],
                    gt_classes_target[valid_idxs_per_image],
                    reduction="sum",
                )
                / max(1, num_foreground)
            )

        return _sum(losses) / len(targets)

    def forward(self, x):
        # type: (List[Tensor]) -> Tensor
        all_cls_logits = []

        for features in x:
            cls_logits = self.conv(features)
            cls_logits = self.cls_logits(cls_logits)

            # Permute classification output from (N, A * K, H, W) to (N, HWA, K).
            N, _, H, W = cls_logits.shape
            cls_logits = cls_logits.view(N, -1, self.num_classes, H, W)
            cls_logits = cls_logits.permute(0, 3, 4, 1, 2)
            cls_logits = cls_logits.reshape(N, -1, self.num_classes)  # Size=(N, HWA, K)

            all_cls_logits.append(cls_logits)

        return torch.cat(all_cls_logits, dim=1)

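
# Illustrative sketch (a hypothetical `_demo_*` helper, not in the upstream file): mirrors the
# reshape performed in RetinaNetClassificationHead.forward above, turning a per-level map of
# shape (N, A * K, H, W) into (N, H*W*A, K) so that logits from all pyramid levels can be
# concatenated along a single anchor dimension. The sizes below are arbitrary.
def _demo_cls_logits_reshape(num_classes: int = 91, num_anchors: int = 9):
    N, H, W = 2, 10, 16
    raw = torch.rand(N, num_anchors * num_classes, H, W)
    out = raw.view(N, -1, num_classes, H, W).permute(0, 3, 4, 1, 2).reshape(N, -1, num_classes)
    # Expected shape: (2, 10 * 16 * 9, 91)
    return out.shape

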
class RetinaNetRegressionHead(nn.Module):
    """
    A regression head for use in RetinaNet.

    Args:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
        norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
    """

    _version = 2

    __annotations__ = {
        "box_coder": det_utils.BoxCoder,
    }

    def __init__(self, in_channels, num_anchors, norm_layer: Optional[Callable[..., nn.Module]] = None):
        super().__init__()

        conv = []
        for _ in range(4):
            conv.append(misc_nn_ops.Conv2dNormActivation(in_channels, in_channels, norm_layer=norm_layer))
        self.conv = nn.Sequential(*conv)

        self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1)
        torch.nn.init.normal_(self.bbox_reg.weight, std=0.01)
        torch.nn.init.zeros_(self.bbox_reg.bias)

        for layer in self.conv.modules():
            if isinstance(layer, nn.Conv2d):
                torch.nn.init.normal_(layer.weight, std=0.01)
                if layer.bias is not None:
                    torch.nn.init.zeros_(layer.bias)

        self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
        self._loss_type = "l1"

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            _v1_to_v2_weights(state_dict, prefix)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Tensor
        losses = []

        bbox_regression = head_outputs["bbox_regression"]

        for targets_per_image, bbox_regression_per_image, anchors_per_image, matched_idxs_per_image in zip(
            targets, bbox_regression, anchors, matched_idxs
        ):
            # determine only the foreground indices, ignore the rest
            foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0]
            num_foreground = foreground_idxs_per_image.numel()

            # select only the foreground boxes
            matched_gt_boxes_per_image = targets_per_image["boxes"][matched_idxs_per_image[foreground_idxs_per_image]]
            bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :]
            anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]

            # compute the loss
            losses.append(
                _box_loss(
                    self._loss_type,
                    self.box_coder,
                    anchors_per_image,
                    matched_gt_boxes_per_image,
                    bbox_regression_per_image,
                )
                / max(1, num_foreground)
            )

        return _sum(losses) / max(1, len(targets))

    def forward(self, x):
        # type: (List[Tensor]) -> Tensor
        all_bbox_regression = []

        for features in x:
            bbox_regression = self.conv(features)
            bbox_regression = self.bbox_reg(bbox_regression)

            # Permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4).
            N, _, H, W = bbox_regression.shape
            bbox_regression = bbox_regression.view(N, -1, 4, H, W)
            bbox_regression = bbox_regression.permute(0, 3, 4, 1, 2)
            bbox_regression = bbox_regression.reshape(N, -1, 4)  # Size=(N, HWA, 4)

            all_bbox_regression.append(bbox_regression)

        return torch.cat(all_bbox_regression, dim=1)

class RetinaNet(nn.Module):
    """
    Implements RetinaNet.

    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
    image, and should be in 0-1 range. Different images can have different sizes.

    The behavior of the model changes depending on if it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (list of dictionary),
    containing:
        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the class label for each ground-truth box

    The model returns a Dict[Tensor] during training, containing the classification and regression
    losses.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
    follows:
        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the predicted labels for each image
        - scores (Tensor[N]): the scores for each prediction

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            It should contain an out_channels attribute, which indicates the number of output
            channels that each feature map has (and it should be the same for all feature maps).
            The backbone should return a single Tensor or an OrderedDict[Tensor].
        num_classes (int): number of output classes of the model (including the background).
        min_size (int): Images are rescaled before feeding them to the backbone:
            we attempt to preserve the aspect ratio and scale the shorter edge
            to ``min_size``. If the resulting longer edge exceeds ``max_size``,
            then downscale so that the longer edge does not exceed ``max_size``.
            This may result in the shorter edge being lower than ``min_size``.
        max_size (int): See ``min_size``.
        image_mean (Tuple[float, float, float]): mean values used for input normalization.
            They are generally the mean values of the dataset on which the backbone has been trained.
        image_std (Tuple[float, float, float]): std values used for input normalization.
            They are generally the std values of the dataset on which the backbone has been trained.
        anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps.
        head (nn.Module): Module run on top of the feature pyramid.
            Defaults to a module containing a classification and regression module.
        score_thresh (float): Score threshold used for postprocessing the detections.
        nms_thresh (float): NMS threshold used for postprocessing the detections.
        detections_per_img (int): Number of best detections to keep after NMS.
        fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training.
        bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training.
        topk_candidates (int): Number of best detections to keep before NMS.

    Example:

        >>> import torch
        >>> import torchvision
        >>> from torchvision.models.detection import RetinaNet
        >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
        >>> # load a pre-trained model for classification and return
        >>> # only the features
        >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
        >>> # RetinaNet needs to know the number of
        >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
        >>> # so we need to add it here
        >>> backbone.out_channels = 1280
        >>>
        >>> # let's make the network generate 5 x 3 anchors per spatial
        >>> # location, with 5 different sizes and 3 different aspect
        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
        >>> # map could potentially have different sizes and
        >>> # aspect ratios
        >>> anchor_generator = AnchorGenerator(
        >>>     sizes=((32, 64, 128, 256, 512),),
        >>>     aspect_ratios=((0.5, 1.0, 2.0),)
        >>> )
        >>>
        >>> # put the pieces together inside a RetinaNet model
        >>> model = RetinaNet(backbone,
        >>>                   num_classes=2,
        >>>                   anchor_generator=anchor_generator)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
    """

    __annotations__ = {
        "box_coder": det_utils.BoxCoder,
        "proposal_matcher": det_utils.Matcher,
    }

    def __init__(
        self,
        backbone,
        num_classes,
        # transform parameters
        min_size=800,
        max_size=1333,
        image_mean=None,
        image_std=None,
        # Anchor parameters
        anchor_generator=None,
        head=None,
        proposal_matcher=None,
        score_thresh=0.05,
        nms_thresh=0.5,
        detections_per_img=300,
        fg_iou_thresh=0.5,
        bg_iou_thresh=0.4,
        topk_candidates=1000,
        **kwargs,
    ):
        super().__init__()
        _log_api_usage_once(self)

        if not hasattr(backbone, "out_channels"):
            raise ValueError(
                "backbone should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)"
            )
        self.backbone = backbone

        if not isinstance(anchor_generator, (AnchorGenerator, type(None))):
            raise TypeError(
                f"anchor_generator should be of type AnchorGenerator or None instead of {type(anchor_generator)}"
            )

        if anchor_generator is None:
            anchor_generator = _default_anchorgen()
        self.anchor_generator = anchor_generator

        if head is None:
            head = RetinaNetHead(backbone.out_channels, anchor_generator.num_anchors_per_location()[0], num_classes)
        self.head = head

        if proposal_matcher is None:
            proposal_matcher = det_utils.Matcher(
                fg_iou_thresh,
                bg_iou_thresh,
                allow_low_quality_matches=True,
            )
        self.proposal_matcher = proposal_matcher

        self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))

        if image_mean is None:
            image_mean = [0.485, 0.456, 0.406]
        if image_std is None:
            image_std = [0.229, 0.224, 0.225]
        self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)

        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.detections_per_img = detections_per_img
        self.topk_candidates = topk_candidates

        # used only on torchscript mode
        self._has_warned = False

    @torch.jit.unused
    def eager_outputs(self, losses, detections):
        # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
        if self.training:
            return losses

        return detections

    def compute_loss(self, targets, head_outputs, anchors):
        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Dict[str, Tensor]
        matched_idxs = []
        for anchors_per_image, targets_per_image in zip(anchors, targets):
            if targets_per_image["boxes"].numel() == 0:
                matched_idxs.append(
                    torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device)
                )
                continue

            match_quality_matrix = box_ops.box_iou(targets_per_image["boxes"], anchors_per_image)
            matched_idxs.append(self.proposal_matcher(match_quality_matrix))

        return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs)

    def postprocess_detections(self, head_outputs, anchors, image_shapes):
        # type: (Dict[str, List[Tensor]], List[List[Tensor]], List[Tuple[int, int]]) -> List[Dict[str, Tensor]]
        class_logits = head_outputs["cls_logits"]
        box_regression = head_outputs["bbox_regression"]

        num_images = len(image_shapes)

        detections: List[Dict[str, Tensor]] = []

        for index in range(num_images):
            box_regression_per_image = [br[index] for br in box_regression]
            logits_per_image = [cl[index] for cl in class_logits]
            anchors_per_image, image_shape = anchors[index], image_shapes[index]

            image_boxes = []
            image_scores = []
            image_labels = []

            for box_regression_per_level, logits_per_level, anchors_per_level in zip(
                box_regression_per_image, logits_per_image, anchors_per_image
            ):
                num_classes = logits_per_level.shape[-1]

                # remove low scoring boxes
                scores_per_level = torch.sigmoid(logits_per_level).flatten()
                keep_idxs = scores_per_level > self.score_thresh
                scores_per_level = scores_per_level[keep_idxs]
                topk_idxs = torch.where(keep_idxs)[0]

                # keep only topk scoring predictions
                num_topk = det_utils._topk_min(topk_idxs, self.topk_candidates, 0)
                scores_per_level, idxs = scores_per_level.topk(num_topk)
                topk_idxs = topk_idxs[idxs]

                anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor")
                labels_per_level = topk_idxs % num_classes

                boxes_per_level = self.box_coder.decode_single(
                    box_regression_per_level[anchor_idxs], anchors_per_level[anchor_idxs]
                )
                boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape)

                image_boxes.append(boxes_per_level)
                image_scores.append(scores_per_level)
                image_labels.append(labels_per_level)

            image_boxes = torch.cat(image_boxes, dim=0)
            image_scores = torch.cat(image_scores, dim=0)
            image_labels = torch.cat(image_labels, dim=0)

            # non-maximum suppression
            keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
            keep = keep[: self.detections_per_img]

            detections.append(
                {
                    "boxes": image_boxes[keep],
                    "scores": image_scores[keep],
                    "labels": image_labels[keep],
                }
            )

        return detections

    def forward(self, images, targets=None):
        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
        """
        Args:
            images (list[Tensor]): images to be processed
            targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)

        Returns:
            result (list[BoxList] or dict[Tensor]): the output from the model.
                During training, it returns a dict[Tensor] which contains the losses.
                During testing, it returns a list[BoxList] containing additional fields
                like `scores`, `labels` and `mask` (for Mask R-CNN models).

        """
        if self.training:
            if targets is None:
                torch._assert(False, "targets should not be none when in training mode")
            else:
                for target in targets:
                    boxes = target["boxes"]
                    torch._assert(isinstance(boxes, torch.Tensor), "Expected target boxes to be of type Tensor.")
                    torch._assert(
                        len(boxes.shape) == 2 and boxes.shape[-1] == 4,
                        "Expected target boxes to be a tensor of shape [N, 4].",
                    )

        # get the original image sizes
        original_image_sizes: List[Tuple[int, int]] = []
        for img in images:
            val = img.shape[-2:]
            torch._assert(
                len(val) == 2,
                f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
            )
            original_image_sizes.append((val[0], val[1]))

        # transform the input
        images, targets = self.transform(images, targets)

        # Check for degenerate boxes
        # TODO: Move this to a function
        if targets is not None:
            for target_idx, target in enumerate(targets):
                boxes = target["boxes"]
                degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
                if degenerate_boxes.any():
                    # print the first degenerate box
                    bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                    degen_bb: List[float] = boxes[bb_idx].tolist()
                    torch._assert(
                        False,
                        "All bounding boxes should have positive height and width."
                        f" Found invalid box {degen_bb} for target at index {target_idx}.",
                    )

        # get the features from the backbone
        features = self.backbone(images.tensors)
        if isinstance(features, torch.Tensor):
            features = OrderedDict([("0", features)])

        # TODO: Do we want a list or a dict?
        features = list(features.values())

        # compute the retinanet heads outputs using the features
        head_outputs = self.head(features)

        # create the set of anchors
        anchors = self.anchor_generator(images, features)

        losses = {}
        detections: List[Dict[str, Tensor]] = []
        if self.training:
            if targets is None:
                torch._assert(False, "targets should not be none when in training mode")
            else:
                # compute the losses
                losses = self.compute_loss(targets, head_outputs, anchors)
        else:
            # recover level sizes
            num_anchors_per_level = [x.size(2) * x.size(3) for x in features]
            HW = 0
            for v in num_anchors_per_level:
                HW += v
            HWA = head_outputs["cls_logits"].size(1)
            A = HWA // HW
            num_anchors_per_level = [hw * A for hw in num_anchors_per_level]

            # split outputs per level
            split_head_outputs: Dict[str, List[Tensor]] = {}
            for k in head_outputs:
                split_head_outputs[k] = list(head_outputs[k].split(num_anchors_per_level, dim=1))
            split_anchors = [list(a.split(num_anchors_per_level)) for a in anchors]

            # compute the detections
            detections = self.postprocess_detections(split_head_outputs, split_anchors, images.image_sizes)
            detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)

        if torch.jit.is_scripting():
            if not self._has_warned:
                warnings.warn("RetinaNet always returns a (Losses, Detections) tuple in scripting")
                self._has_warned = True
            return losses, detections
        return self.eager_outputs(losses, detections)

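
# Illustrative sketch (a hypothetical `_demo_*` helper, not in the upstream file): summarizes
# the two output forms of RetinaNet.forward above -- a loss dict in training mode, and a list of
# per-image prediction dicts in eval mode. The caller supplies a constructed RetinaNet (for
# example from the builders defined below); the input sizes and targets are arbitrary.
def _demo_retinanet_outputs(model: RetinaNet):
    images = [torch.rand(3, 300, 400), torch.rand(3, 400, 300)]
    targets = [
        {"boxes": torch.tensor([[10.0, 20.0, 110.0, 220.0]]), "labels": torch.tensor([1])},
        {"boxes": torch.tensor([[30.0, 40.0, 130.0, 240.0]]), "labels": torch.tensor([2])},
    ]

    model.train()
    losses = model(images, targets)  # {"classification": ..., "bbox_regression": ...}

    model.eval()
    with torch.no_grad():
        detections = model(images)  # [{"boxes": ..., "scores": ..., "labels": ...}, ...]
    return losses, detections

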
_COMMON_META = {
    "categories": _COCO_CATEGORIES,
    "min_size": (1, 1),
}


class RetinaNet_ResNet50_FPN_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 34014999,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 36.4,
                }
            },
            "_ops": 151.54,
            "_file_size": 130.267,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1


class RetinaNet_ResNet50_FPN_V2_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/retinanet_resnet50_fpn_v2_coco-5905b1c5.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 38198935,
            "recipe": "https://github.com/pytorch/vision/pull/5756",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 41.5,
                }
            },
            "_ops": 152.238,
            "_file_size": 146.037,
            "_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""",
        },
    )
    DEFAULT = COCO_V1

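
# Illustrative sketch (a hypothetical `_demo_*` helper, not in the upstream file): the Weights
# entries above carry their COCO metrics and category names in `meta`, which is how the builders
# below derive `num_classes` when pretrained weights are requested.
def _demo_weights_metadata():
    weights = RetinaNet_ResNet50_FPN_V2_Weights.COCO_V1
    num_classes = len(weights.meta["categories"])  # 91 COCO categories, including background
    box_map = weights.meta["_metrics"]["COCO-val2017"]["box_map"]  # 41.5
    return num_classes, box_map

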
@register_model()
@handle_legacy_interface(
    weights=("pretrained", RetinaNet_ResNet50_FPN_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def retinanet_resnet50_fpn(
    *,
    weights: Optional[RetinaNet_ResNet50_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> RetinaNet:
    """
    Constructs a RetinaNet model with a ResNet-50-FPN backbone.

    .. betastatus:: detection module

    Reference: `Focal Loss for Dense Object Detection <https://arxiv.org/abs/1708.02002>`_.

    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
    image, and should be in ``0-1`` range. Different images can have different sizes.

    The behavior of the model changes depending on if it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (list of dictionary),
    containing:

        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (``Int64Tensor[N]``): the class label for each ground-truth box

    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
    losses.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
    follows, where ``N`` is the number of detections:

        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (``Int64Tensor[N]``): the predicted labels for each detection
        - scores (``Tensor[N]``): the scores of each detection

    For more details on the output, you may refer to :ref:`instance_seg_output`.

    Example::

        >>> model = torchvision.models.detection.retinanet_resnet50_fpn(weights=RetinaNet_ResNet50_FPN_Weights.DEFAULT)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Args:
        weights (:class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights`
            below for more details, and possible values. By default, no
            pre-trained weights are used.
        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for
            the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
            passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``torchvision.models.detection.RetinaNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights
        :members:
    """
    weights = RetinaNet_ResNet50_FPN_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    # skip P2 because it generates too many anchors (according to their paper)
    backbone = _resnet_fpn_extractor(
        backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)
    )
    model = RetinaNet(backbone, num_classes, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
        if weights == RetinaNet_ResNet50_FPN_Weights.COCO_V1:
            overwrite_eps(model, 0.0)

    return model

@register_model()
@handle_legacy_interface(
    weights=("pretrained", RetinaNet_ResNet50_FPN_V2_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def retinanet_resnet50_fpn_v2(
    *,
    weights: Optional[RetinaNet_ResNet50_FPN_V2_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = None,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> RetinaNet:
    """
    Constructs an improved RetinaNet model with a ResNet-50-FPN backbone.

    .. betastatus:: detection module

    Reference: `Bridging the Gap Between Anchor-based and Anchor-free Detection via Adaptive Training Sample Selection
    <https://arxiv.org/abs/1912.02424>`_.

    See :func:`~torchvision.models.detection.retinanet_resnet50_fpn` for more details.

    Args:
        weights (:class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights`
            below for more details, and possible values. By default, no
            pre-trained weights are used.
        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for
            the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
            passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``torchvision.models.detection.RetinaNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights
        :members:
    """
    weights = RetinaNet_ResNet50_FPN_V2_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)

    backbone = resnet50(weights=weights_backbone, progress=progress)
    backbone = _resnet_fpn_extractor(
        backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(2048, 256)
    )
    anchor_generator = _default_anchorgen()
    head = RetinaNetHead(
        backbone.out_channels,
        anchor_generator.num_anchors_per_location()[0],
        num_classes,
        norm_layer=partial(nn.GroupNorm, 32),
    )
    head.regression_head._loss_type = "giou"
    model = RetinaNet(backbone, num_classes, anchor_generator=anchor_generator, head=head, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model
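

# Illustrative usage sketch (a hypothetical `_demo_*` helper, not part of the upstream module):
# typical calls to the v2 builder above. `num_classes=5` is an assumed dataset-specific value,
# not a library default.
def _demo_retinanet_v2_usage():
    # COCO-pretrained detector for inference.
    pretrained = retinanet_resnet50_fpn_v2(weights=RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT)
    pretrained.eval()

    # Fine-tuning variant: ImageNet backbone, custom number of classes, GIoU box loss as wired above.
    finetune = retinanet_resnet50_fpn_v2(
        weights_backbone=ResNet50_Weights.IMAGENET1K_V1,
        num_classes=5,
        trainable_backbone_layers=3,
    )
    return pretrained, finetune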