Upload 372 files
This view is limited to 50 files because it contains too many changes.
- .gitattributes +2 -0
- torchvision/_C.so +3 -0
- torchvision/__init__.py +114 -0
- torchvision/__pycache__/__init__.cpython-38.pyc +0 -0
- torchvision/__pycache__/_internally_replaced_utils.cpython-38.pyc +0 -0
- torchvision/__pycache__/_utils.cpython-38.pyc +0 -0
- torchvision/__pycache__/extension.cpython-38.pyc +0 -0
- torchvision/__pycache__/utils.cpython-38.pyc +0 -0
- torchvision/__pycache__/version.cpython-38.pyc +0 -0
- torchvision/_internally_replaced_utils.py +58 -0
- torchvision/_utils.py +32 -0
- torchvision/datapoints/__init__.py +12 -0
- torchvision/datapoints/__pycache__/__init__.cpython-38.pyc +0 -0
- torchvision/datapoints/__pycache__/_bounding_box.cpython-38.pyc +0 -0
- torchvision/datapoints/__pycache__/_datapoint.cpython-38.pyc +0 -0
- torchvision/datapoints/__pycache__/_dataset_wrapper.cpython-38.pyc +0 -0
- torchvision/datapoints/__pycache__/_image.cpython-38.pyc +0 -0
- torchvision/datapoints/__pycache__/_mask.cpython-38.pyc +0 -0
- torchvision/datapoints/__pycache__/_video.cpython-38.pyc +0 -0
- torchvision/datapoints/_bounding_box.py +237 -0
- torchvision/datapoints/_datapoint.py +259 -0
- torchvision/datapoints/_dataset_wrapper.py +499 -0
- torchvision/datapoints/_image.py +260 -0
- torchvision/datapoints/_mask.py +158 -0
- torchvision/datapoints/_video.py +250 -0
- torchvision/datasets/__init__.py +145 -0
- torchvision/datasets/__pycache__/__init__.cpython-38.pyc +0 -0
- torchvision/datasets/__pycache__/_optical_flow.cpython-38.pyc +0 -0
- torchvision/datasets/__pycache__/_stereo_matching.cpython-38.pyc +0 -0
- torchvision/datasets/__pycache__/caltech.cpython-38.pyc +0 -0
- torchvision/datasets/__pycache__/celeba.cpython-38.pyc +0 -0
- torchvision/datasets/__pycache__/cifar.cpython-38.pyc +0 -0
- torchvision/datasets/__pycache__/cityscapes.cpython-38.pyc +0 -0
- torchvision/datasets/__pycache__/clevr.cpython-38.pyc +0 -0
- torchvision/datasets/__pycache__/coco.cpython-38.pyc +0 -0
- torchvision/datasets/__pycache__/country211.cpython-38.pyc +0 -0
- torchvision/datasets/__pycache__/dtd.cpython-38.pyc +0 -0
- torchvision/datasets/__pycache__/eurosat.cpython-38.pyc +0 -0
- torchvision/datasets/__pycache__/fakedata.cpython-38.pyc +0 -0
- torchvision/datasets/__pycache__/fer2013.cpython-38.pyc +0 -0
- torchvision/datasets/__pycache__/fgvc_aircraft.cpython-38.pyc +0 -0
- torchvision/datasets/__pycache__/flickr.cpython-38.pyc +0 -0
- torchvision/datasets/__pycache__/flowers102.cpython-38.pyc +0 -0
- torchvision/datasets/__pycache__/folder.cpython-38.pyc +0 -0
- torchvision/datasets/__pycache__/food101.cpython-38.pyc +0 -0
- torchvision/datasets/__pycache__/gtsrb.cpython-38.pyc +0 -0
- torchvision/datasets/__pycache__/hmdb51.cpython-38.pyc +0 -0
- torchvision/datasets/__pycache__/imagenet.cpython-38.pyc +0 -0
- torchvision/datasets/__pycache__/inaturalist.cpython-38.pyc +0 -0
- torchvision/datasets/__pycache__/kinetics.cpython-38.pyc +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+torchvision/_C.so filter=lfs diff=lfs merge=lfs -text
+torchvision/image.so filter=lfs diff=lfs merge=lfs -text

torchvision/_C.so
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1f60f9e860992a1c25bbad7efc84b2036c82da5f23c976995dbdefeab6ff5d73
size 70104344

torchvision/__init__.py
ADDED
@@ -0,0 +1,114 @@
import os
import warnings
from modulefinder import Module

import torch
from torchvision import datasets, io, models, ops, transforms, utils

from .extension import _HAS_OPS

try:
    from .version import __version__  # noqa: F401
except ImportError:
    pass


# Check if torchvision is being imported within the root folder
if not _HAS_OPS and os.path.dirname(os.path.realpath(__file__)) == os.path.join(
    os.path.realpath(os.getcwd()), "torchvision"
):
    message = (
        "You are importing torchvision within its own root folder ({}). "
        "This is not expected to work and may give errors. Please exit the "
        "torchvision project source and relaunch your python interpreter."
    )
    warnings.warn(message.format(os.getcwd()))

_image_backend = "PIL"

_video_backend = "pyav"


def set_image_backend(backend):
    """
    Specifies the package used to load images.

    Args:
        backend (string): Name of the image backend. one of {'PIL', 'accimage'}.
            The :mod:`accimage` package uses the Intel IPP library. It is
            generally faster than PIL, but does not support as many operations.
    """
    global _image_backend
    if backend not in ["PIL", "accimage"]:
        raise ValueError(f"Invalid backend '{backend}'. Options are 'PIL' and 'accimage'")
    _image_backend = backend


def get_image_backend():
    """
    Gets the name of the package used to load images
    """
    return _image_backend


def set_video_backend(backend):
    """
    Specifies the package used to decode videos.

    Args:
        backend (string): Name of the video backend. one of {'pyav', 'video_reader'}.
            The :mod:`pyav` package uses the 3rd party PyAv library. It is a Pythonic
            binding for the FFmpeg libraries.
            The :mod:`video_reader` package includes a native C++ implementation on
            top of FFMPEG libraries, and a python API of TorchScript custom operator.
            It generally decodes faster than :mod:`pyav`, but is perhaps less robust.

    .. note::
        Building with FFMPEG is disabled by default in the latest `main`. If you want to use the 'video_reader'
        backend, please compile torchvision from source.
    """
    global _video_backend
    if backend not in ["pyav", "video_reader", "cuda"]:
        raise ValueError("Invalid video backend '%s'. Options are 'pyav', 'video_reader' and 'cuda'" % backend)
    if backend == "video_reader" and not io._HAS_VIDEO_OPT:
        # TODO: better messages
        message = "video_reader video backend is not available. Please compile torchvision from source and try again"
        raise RuntimeError(message)
    elif backend == "cuda" and not io._HAS_GPU_VIDEO_DECODER:
        # TODO: better messages
        message = "cuda video backend is not available."
        raise RuntimeError(message)
    else:
        _video_backend = backend


def get_video_backend():
    """
    Returns the currently active video backend used to decode videos.

    Returns:
        str: Name of the video backend. one of {'pyav', 'video_reader'}.
    """

    return _video_backend


def _is_tracing():
    return torch._C._get_tracing_state()


_WARN_ABOUT_BETA_TRANSFORMS = True
_BETA_TRANSFORMS_WARNING = (
    "The torchvision.datapoints and torchvision.transforms.v2 namespaces are still Beta. "
    "While we do not expect major breaking changes, some APIs may still change "
    "according to user feedback. Please submit any feedback you may have in "
    "this issue: https://github.com/pytorch/vision/issues/6753, and you can also "
    "check out https://github.com/pytorch/vision/issues/7319 to learn more about "
    "the APIs that we suspect might involve future changes. "
    "You can silence this warning by calling torchvision.disable_beta_transforms_warning()."
)


def disable_beta_transforms_warning():
    global _WARN_ABOUT_BETA_TRANSFORMS
    _WARN_ABOUT_BETA_TRANSFORMS = False

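As a quick orientation for the module above, a small usage sketch of the backend helpers it defines (standard install assumed; the 'video_reader' and 'cuda' branches only succeed when torchvision was built with the corresponding support):

import torchvision

# Defaults set at import time above: PIL for images, pyav for video.
print(torchvision.get_image_backend())  # "PIL"
print(torchvision.get_video_backend())  # "pyav"

# set_image_backend validates its argument and raises ValueError otherwise.
try:
    torchvision.set_image_backend("opencv")
except ValueError as err:
    print(err)

# Silence the beta warning emitted when torchvision.datapoints is imported;
# call this before that import for it to take effect.
torchvision.disable_beta_transforms_warning()
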
torchvision/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (4.07 kB)

torchvision/__pycache__/_internally_replaced_utils.cpython-38.pyc
ADDED
Binary file (1.76 kB)

torchvision/__pycache__/_utils.cpython-38.pyc
ADDED
Binary file (1.46 kB)

torchvision/__pycache__/extension.cpython-38.pyc
ADDED
Binary file (2.93 kB)

torchvision/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (18.8 kB)

torchvision/__pycache__/version.cpython-38.pyc
ADDED
Binary file (367 Bytes)

torchvision/_internally_replaced_utils.py
ADDED
@@ -0,0 +1,58 @@
import importlib.machinery
import os

from torch.hub import _get_torch_home


_HOME = os.path.join(_get_torch_home(), "datasets", "vision")
_USE_SHARDED_DATASETS = False


def _download_file_from_remote_location(fpath: str, url: str) -> None:
    pass


def _is_remote_location_available() -> bool:
    return False


try:
    from torch.hub import load_state_dict_from_url  # noqa: 401
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url  # noqa: 401


def _get_extension_path(lib_name):

    lib_dir = os.path.dirname(__file__)
    if os.name == "nt":
        # Register the main torchvision library location on the default DLL path
        import ctypes
        import sys

        kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
        with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
        prev_error_mode = kernel32.SetErrorMode(0x0001)

        if with_load_library_flags:
            kernel32.AddDllDirectory.restype = ctypes.c_void_p

        if sys.version_info >= (3, 8):
            os.add_dll_directory(lib_dir)
        elif with_load_library_flags:
            res = kernel32.AddDllDirectory(lib_dir)
            if res is None:
                err = ctypes.WinError(ctypes.get_last_error())
                err.strerror += f' Error adding "{lib_dir}" to the DLL directories.'
                raise err

        kernel32.SetErrorMode(prev_error_mode)

    loader_details = (importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES)

    extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
    ext_specs = extfinder.find_spec(lib_name)
    if ext_specs is None:
        raise ImportError

    return ext_specs.origin

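The extension locator above is how torchvision finds its bundled native libraries (for example the image.so registered in .gitattributes). A minimal sketch, assuming the shared object ships next to the Python files; torch.ops.load_library is the standard PyTorch entry point for registering custom operators:

import torch
from torchvision._internally_replaced_utils import _get_extension_path

try:
    lib_path = _get_extension_path("image")  # resolves e.g. .../torchvision/image.so
    torch.ops.load_library(lib_path)         # registers the custom operators it contains
except (ImportError, OSError) as err:
    print(f"native image extension unavailable: {err}")
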
torchvision/_utils.py
ADDED
@@ -0,0 +1,32 @@
import enum
from typing import Sequence, Type, TypeVar

T = TypeVar("T", bound=enum.Enum)


class StrEnumMeta(enum.EnumMeta):
    auto = enum.auto

    def from_str(self: Type[T], member: str) -> T:  # type: ignore[misc]
        try:
            return self[member]
        except KeyError:
            # TODO: use `add_suggestion` from torchvision.prototype.utils._internal to improve the error message as
            # soon as it is migrated.
            raise ValueError(f"Unknown value '{member}' for {self.__name__}.") from None


class StrEnum(enum.Enum, metaclass=StrEnumMeta):
    pass


def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
    if not seq:
        return ""
    if len(seq) == 1:
        return f"'{seq[0]}'"

    head = "'" + "', '".join([str(item) for item in seq[:-1]]) + "'"
    tail = f"{'' if separate_last and len(seq) == 2 else ','} {separate_last}'{seq[-1]}'"

    return head + tail

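Since the `separate_last` handling in sequence_to_str above is easy to misread, a short sketch of its output:

from torchvision._utils import sequence_to_str

print(sequence_to_str(["cat"]))                                       # 'cat'
print(sequence_to_str(["cat", "dog"], separate_last="and "))          # 'cat' and 'dog'
print(sequence_to_str(["cat", "dog", "bird"], separate_last="and "))  # 'cat', 'dog', and 'bird'
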
torchvision/datapoints/__init__.py
ADDED
@@ -0,0 +1,12 @@
from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS

from ._bounding_box import BoundingBox, BoundingBoxFormat
from ._datapoint import _FillType, _FillTypeJIT, _InputType, _InputTypeJIT
from ._image import _ImageType, _ImageTypeJIT, _TensorImageType, _TensorImageTypeJIT, Image
from ._mask import Mask
from ._video import _TensorVideoType, _TensorVideoTypeJIT, _VideoType, _VideoTypeJIT, Video

if _WARN_ABOUT_BETA_TRANSFORMS:
    import warnings

    warnings.warn(_BETA_TRANSFORMS_WARNING)

torchvision/datapoints/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (840 Bytes)

torchvision/datapoints/__pycache__/_bounding_box.cpython-38.pyc
ADDED
Binary file (8.02 kB)

torchvision/datapoints/__pycache__/_datapoint.cpython-38.pyc
ADDED
Binary file (10 kB)

torchvision/datapoints/__pycache__/_dataset_wrapper.cpython-38.pyc
ADDED
Binary file (16.3 kB)

torchvision/datapoints/__pycache__/_image.cpython-38.pyc
ADDED
Binary file (10.4 kB)

torchvision/datapoints/__pycache__/_mask.cpython-38.pyc
ADDED
Binary file (5.98 kB)

torchvision/datapoints/__pycache__/_video.cpython-38.pyc
ADDED
Binary file (10.1 kB)

torchvision/datapoints/_bounding_box.py
ADDED
@@ -0,0 +1,237 @@
from __future__ import annotations

from enum import Enum
from typing import Any, List, Optional, Sequence, Tuple, Union

import torch
from torchvision.transforms import InterpolationMode  # TODO: this needs to be moved out of transforms

from ._datapoint import _FillTypeJIT, Datapoint


class BoundingBoxFormat(Enum):
    """[BETA] Coordinate format of a bounding box.

    Available formats are

    * ``XYXY``
    * ``XYWH``
    * ``CXCYWH``
    """

    XYXY = "XYXY"
    XYWH = "XYWH"
    CXCYWH = "CXCYWH"


class BoundingBox(Datapoint):
    """[BETA] :class:`torch.Tensor` subclass for bounding boxes.

    Args:
        data: Any data that can be turned into a tensor with :func:`torch.as_tensor`.
        format (BoundingBoxFormat, str): Format of the bounding box.
        spatial_size (two-tuple of ints): Height and width of the corresponding image or video.
        dtype (torch.dtype, optional): Desired data type of the bounding box. If omitted, will be inferred from
            ``data``.
        device (torch.device, optional): Desired device of the bounding box. If omitted and ``data`` is a
            :class:`torch.Tensor`, the device is taken from it. Otherwise, the bounding box is constructed on the CPU.
        requires_grad (bool, optional): Whether autograd should record operations on the bounding box. If omitted and
            ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``.
    """

    format: BoundingBoxFormat
    spatial_size: Tuple[int, int]

    @classmethod
    def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, spatial_size: Tuple[int, int]) -> BoundingBox:
        bounding_box = tensor.as_subclass(cls)
        bounding_box.format = format
        bounding_box.spatial_size = spatial_size
        return bounding_box

    def __new__(
        cls,
        data: Any,
        *,
        format: Union[BoundingBoxFormat, str],
        spatial_size: Tuple[int, int],
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: Optional[bool] = None,
    ) -> BoundingBox:
        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)

        if isinstance(format, str):
            format = BoundingBoxFormat[format.upper()]

        return cls._wrap(tensor, format=format, spatial_size=spatial_size)

    @classmethod
    def wrap_like(
        cls,
        other: BoundingBox,
        tensor: torch.Tensor,
        *,
        format: Optional[BoundingBoxFormat] = None,
        spatial_size: Optional[Tuple[int, int]] = None,
    ) -> BoundingBox:
        """Wrap a :class:`torch.Tensor` as :class:`BoundingBox` from a reference.

        Args:
            other (BoundingBox): Reference bounding box.
            tensor (Tensor): Tensor to be wrapped as :class:`BoundingBox`
            format (BoundingBoxFormat, str, optional): Format of the bounding box. If omitted, it is taken from the
                reference.
            spatial_size (two-tuple of ints, optional): Height and width of the corresponding image or video. If
                omitted, it is taken from the reference.

        """
        if isinstance(format, str):
            format = BoundingBoxFormat[format.upper()]

        return cls._wrap(
            tensor,
            format=format if format is not None else other.format,
            spatial_size=spatial_size if spatial_size is not None else other.spatial_size,
        )

    def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
        return self._make_repr(format=self.format, spatial_size=self.spatial_size)

    def horizontal_flip(self) -> BoundingBox:
        output = self._F.horizontal_flip_bounding_box(
            self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size
        )
        return BoundingBox.wrap_like(self, output)

    def vertical_flip(self) -> BoundingBox:
        output = self._F.vertical_flip_bounding_box(
            self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size
        )
        return BoundingBox.wrap_like(self, output)

    def resize(  # type: ignore[override]
        self,
        size: List[int],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        max_size: Optional[int] = None,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> BoundingBox:
        output, spatial_size = self._F.resize_bounding_box(
            self.as_subclass(torch.Tensor),
            spatial_size=self.spatial_size,
            size=size,
            max_size=max_size,
        )
        return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)

    def crop(self, top: int, left: int, height: int, width: int) -> BoundingBox:
        output, spatial_size = self._F.crop_bounding_box(
            self.as_subclass(torch.Tensor), self.format, top=top, left=left, height=height, width=width
        )
        return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)

    def center_crop(self, output_size: List[int]) -> BoundingBox:
        output, spatial_size = self._F.center_crop_bounding_box(
            self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size, output_size=output_size
        )
        return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)

    def resized_crop(
        self,
        top: int,
        left: int,
        height: int,
        width: int,
        size: List[int],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> BoundingBox:
        output, spatial_size = self._F.resized_crop_bounding_box(
            self.as_subclass(torch.Tensor), self.format, top, left, height, width, size=size
        )
        return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)

    def pad(
        self,
        padding: Union[int, Sequence[int]],
        fill: Optional[Union[int, float, List[float]]] = None,
        padding_mode: str = "constant",
    ) -> BoundingBox:
        output, spatial_size = self._F.pad_bounding_box(
            self.as_subclass(torch.Tensor),
            format=self.format,
            spatial_size=self.spatial_size,
            padding=padding,
            padding_mode=padding_mode,
        )
        return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)

    def rotate(
        self,
        angle: float,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        expand: bool = False,
        center: Optional[List[float]] = None,
        fill: _FillTypeJIT = None,
    ) -> BoundingBox:
        output, spatial_size = self._F.rotate_bounding_box(
            self.as_subclass(torch.Tensor),
            format=self.format,
            spatial_size=self.spatial_size,
            angle=angle,
            expand=expand,
            center=center,
        )
        return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)

    def affine(
        self,
        angle: Union[int, float],
        translate: List[float],
        scale: float,
        shear: List[float],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: _FillTypeJIT = None,
        center: Optional[List[float]] = None,
    ) -> BoundingBox:
        output = self._F.affine_bounding_box(
            self.as_subclass(torch.Tensor),
            self.format,
            self.spatial_size,
            angle,
            translate=translate,
            scale=scale,
            shear=shear,
            center=center,
        )
        return BoundingBox.wrap_like(self, output)

    def perspective(
        self,
        startpoints: Optional[List[List[int]]],
        endpoints: Optional[List[List[int]]],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: _FillTypeJIT = None,
        coefficients: Optional[List[float]] = None,
    ) -> BoundingBox:
        output = self._F.perspective_bounding_box(
            self.as_subclass(torch.Tensor),
            format=self.format,
            spatial_size=self.spatial_size,
            startpoints=startpoints,
            endpoints=endpoints,
            coefficients=coefficients,
        )
        return BoundingBox.wrap_like(self, output)

    def elastic(
        self,
        displacement: torch.Tensor,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: _FillTypeJIT = None,
    ) -> BoundingBox:
        output = self._F.elastic_bounding_box(
            self.as_subclass(torch.Tensor), self.format, self.spatial_size, displacement=displacement
        )
        return BoundingBox.wrap_like(self, output)

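A brief usage sketch of the BoundingBox datapoint defined above (coordinates and sizes are made up for illustration; assumes a build with the beta datapoints API, i.e. this upload):

import torch
from torchvision import datapoints

# Two boxes in XYXY format on an image of height 480 and width 640.
boxes = datapoints.BoundingBox(
    [[10, 20, 110, 220], [50, 50, 150, 150]],
    format=datapoints.BoundingBoxFormat.XYXY,
    spatial_size=(480, 640),
)

# The dispatch methods return new BoundingBox instances with the metadata
# carried over, or updated where the operation changes it (e.g. resize).
flipped = boxes.horizontal_flip()
resized = boxes.resize([240, 320])

print(flipped.spatial_size)  # (480, 640)
print(resized.spatial_size)  # (240, 320)
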
torchvision/datapoints/_datapoint.py
ADDED
@@ -0,0 +1,259 @@
from __future__ import annotations

from types import ModuleType
from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union

import PIL.Image
import torch
from torch._C import DisableTorchFunctionSubclass
from torch.types import _device, _dtype, _size
from torchvision.transforms import InterpolationMode


D = TypeVar("D", bound="Datapoint")
_FillType = Union[int, float, Sequence[int], Sequence[float], None]
_FillTypeJIT = Optional[List[float]]


class Datapoint(torch.Tensor):
    __F: Optional[ModuleType] = None

    @staticmethod
    def _to_tensor(
        data: Any,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: Optional[bool] = None,
    ) -> torch.Tensor:
        if requires_grad is None:
            requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False
        return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad)

    @classmethod
    def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
        raise NotImplementedError

    _NO_WRAPPING_EXCEPTIONS = {
        torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output),
        torch.Tensor.to: lambda cls, input, output: cls.wrap_like(input, output),
        # We don't need to wrap the output of `Tensor.requires_grad_`, since it is an inplace operation and thus
        # retains the type automatically
        torch.Tensor.requires_grad_: lambda cls, input, output: output,
    }

    @classmethod
    def __torch_function__(
        cls,
        func: Callable[..., torch.Tensor],
        types: Tuple[Type[torch.Tensor], ...],
        args: Sequence[Any] = (),
        kwargs: Optional[Mapping[str, Any]] = None,
    ) -> torch.Tensor:
        """For general information about how the __torch_function__ protocol works,
        see https://pytorch.org/docs/stable/notes/extending.html#extending-torch

        TL;DR: Every time a PyTorch operator is called, it goes through the inputs and looks for the
        ``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the
        ``args`` and ``kwargs`` of the original call.

        The default behavior of :class:`~torch.Tensor`'s is to retain a custom tensor type. For the :class:`Datapoint`
        use case, this has two downsides:

        1. Since some :class:`Datapoint`'s require metadata to be constructed, the default wrapping, i.e.
           ``return cls(func(*args, **kwargs))``, will fail for them.
        2. For most operations, there is no way of knowing if the input type is still valid for the output.

        For these reasons, the automatic output wrapping is turned off for most operators. The only exceptions are
        listed in :attr:`Datapoint._NO_WRAPPING_EXCEPTIONS`
        """
        # Since super().__torch_function__ has no hook to prevent the coercing of the output into the input type, we
        # need to reimplement the functionality.

        if not all(issubclass(cls, t) for t in types):
            return NotImplemented

        with DisableTorchFunctionSubclass():
            output = func(*args, **kwargs or dict())

            wrapper = cls._NO_WRAPPING_EXCEPTIONS.get(func)
            # Apart from `func` needing to be an exception, we also require the primary operand, i.e. `args[0]`, to be
            # an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will
            # invoke this method on *all* types involved in the computation by walking the MRO upwards. For example,
            # `torch.Tensor(...).to(datapoints.Image(...))` will invoke `datapoints.Image.__torch_function__` with
            # `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would
            # be wrapped into a `datapoints.Image`.
            if wrapper and isinstance(args[0], cls):
                return wrapper(cls, args[0], output)

            # Inplace `func`'s, canonically identified with a trailing underscore in their name like `.add_(...)`,
            # will retain the input type. Thus, we need to unwrap here.
            if isinstance(output, cls):
                return output.as_subclass(torch.Tensor)

            return output

    def _make_repr(self, **kwargs: Any) -> str:
        # This is a poor man's implementation of the proposal in https://github.com/pytorch/pytorch/issues/76532.
        # If that ever gets implemented, remove this in favor of the solution on the `torch.Tensor` class.
        extra_repr = ", ".join(f"{key}={value}" for key, value in kwargs.items())
        return f"{super().__repr__()[:-1]}, {extra_repr})"

    @property
    def _F(self) -> ModuleType:
        # This implements a lazy import of the functional to get around the cyclic import. This import is deferred
        # until the first time we need reference to the functional module and it's shared across all instances of
        # the class. This approach avoids the DataLoader issue described at
        # https://github.com/pytorch/vision/pull/6476#discussion_r953588621
        if Datapoint.__F is None:
            from ..transforms.v2 import functional

            Datapoint.__F = functional
        return Datapoint.__F

    # Add properties for common attributes like shape, dtype, device, ndim etc
    # this way we return the result without passing into __torch_function__
    @property
    def shape(self) -> _size:  # type: ignore[override]
        with DisableTorchFunctionSubclass():
            return super().shape

    @property
    def ndim(self) -> int:  # type: ignore[override]
        with DisableTorchFunctionSubclass():
            return super().ndim

    @property
    def device(self, *args: Any, **kwargs: Any) -> _device:  # type: ignore[override]
        with DisableTorchFunctionSubclass():
            return super().device

    @property
    def dtype(self) -> _dtype:  # type: ignore[override]
        with DisableTorchFunctionSubclass():
            return super().dtype

    def horizontal_flip(self) -> Datapoint:
        return self

    def vertical_flip(self) -> Datapoint:
        return self

    # TODO: We have to ignore override mypy error as there is torch.Tensor built-in deprecated op: Tensor.resize
    # https://github.com/pytorch/pytorch/blob/e8727994eb7cdb2ab642749d6549bc497563aa06/torch/_tensor.py#L588-L593
    def resize(  # type: ignore[override]
        self,
        size: List[int],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        max_size: Optional[int] = None,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> Datapoint:
        return self

    def crop(self, top: int, left: int, height: int, width: int) -> Datapoint:
        return self

    def center_crop(self, output_size: List[int]) -> Datapoint:
        return self

    def resized_crop(
        self,
        top: int,
        left: int,
        height: int,
        width: int,
        size: List[int],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> Datapoint:
        return self

    def pad(
        self,
        padding: List[int],
        fill: Optional[Union[int, float, List[float]]] = None,
        padding_mode: str = "constant",
    ) -> Datapoint:
        return self

    def rotate(
        self,
        angle: float,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        expand: bool = False,
        center: Optional[List[float]] = None,
        fill: _FillTypeJIT = None,
    ) -> Datapoint:
        return self

    def affine(
        self,
        angle: Union[int, float],
        translate: List[float],
        scale: float,
        shear: List[float],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: _FillTypeJIT = None,
        center: Optional[List[float]] = None,
    ) -> Datapoint:
        return self

    def perspective(
        self,
        startpoints: Optional[List[List[int]]],
        endpoints: Optional[List[List[int]]],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: _FillTypeJIT = None,
        coefficients: Optional[List[float]] = None,
    ) -> Datapoint:
        return self

    def elastic(
        self,
        displacement: torch.Tensor,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: _FillTypeJIT = None,
    ) -> Datapoint:
        return self

    def rgb_to_grayscale(self, num_output_channels: int = 1) -> Datapoint:
        return self

    def adjust_brightness(self, brightness_factor: float) -> Datapoint:
        return self

    def adjust_saturation(self, saturation_factor: float) -> Datapoint:
        return self

    def adjust_contrast(self, contrast_factor: float) -> Datapoint:
        return self

    def adjust_sharpness(self, sharpness_factor: float) -> Datapoint:
        return self

    def adjust_hue(self, hue_factor: float) -> Datapoint:
        return self

    def adjust_gamma(self, gamma: float, gain: float = 1) -> Datapoint:
        return self

    def posterize(self, bits: int) -> Datapoint:
        return self

    def solarize(self, threshold: float) -> Datapoint:
        return self

    def autocontrast(self) -> Datapoint:
        return self

    def equalize(self) -> Datapoint:
        return self

    def invert(self) -> Datapoint:
        return self

    def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Datapoint:
        return self


_InputType = Union[torch.Tensor, PIL.Image.Image, Datapoint]
_InputTypeJIT = torch.Tensor

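The `__torch_function__` logic above means most tensor operations on a Datapoint intentionally fall back to a plain torch.Tensor; only the ops in `_NO_WRAPPING_EXCEPTIONS` are re-wrapped. A small sketch of that behavior, using the Image subclass from this upload:

import torch
from torchvision import datapoints

img = datapoints.Image(torch.rand(3, 32, 32))

print(type(img.clone()).__name__)            # Image  (clone is an exception)
print(type(img.to(torch.float64)).__name__)  # Image  (to is an exception)
print(type(img + 1).__name__)                # Tensor (wrapping disabled for other ops)
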
torchvision/datapoints/_dataset_wrapper.py
ADDED
@@ -0,0 +1,499 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# type: ignore
|
2 |
+
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
import contextlib
|
6 |
+
from collections import defaultdict
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from torch.utils.data import Dataset
|
10 |
+
|
11 |
+
from torchvision import datapoints, datasets
|
12 |
+
from torchvision.transforms.v2 import functional as F
|
13 |
+
|
14 |
+
__all__ = ["wrap_dataset_for_transforms_v2"]
|
15 |
+
|
16 |
+
|
17 |
+
def wrap_dataset_for_transforms_v2(dataset):
|
18 |
+
"""[BETA] Wrap a ``torchvision.dataset`` for usage with :mod:`torchvision.transforms.v2`.
|
19 |
+
|
20 |
+
.. v2betastatus:: wrap_dataset_for_transforms_v2 function
|
21 |
+
|
22 |
+
Example:
|
23 |
+
>>> dataset = torchvision.datasets.CocoDetection(...)
|
24 |
+
>>> dataset = wrap_dataset_for_transforms_v2(dataset)
|
25 |
+
|
26 |
+
.. note::
|
27 |
+
|
28 |
+
For now, only the most popular datasets are supported. Furthermore, the wrapper only supports dataset
|
29 |
+
configurations that are fully supported by ``torchvision.transforms.v2``. If you encounter an error prompting you
|
30 |
+
to raise an issue to ``torchvision`` for a dataset or configuration that you need, please do so.
|
31 |
+
|
32 |
+
The dataset samples are wrapped according to the description below.
|
33 |
+
|
34 |
+
Special cases:
|
35 |
+
|
36 |
+
* :class:`~torchvision.datasets.CocoDetection`: Instead of returning the target as list of dicts, the wrapper
|
37 |
+
returns a dict of lists. In addition, the key-value-pairs ``"boxes"`` (in ``XYXY`` coordinate format),
|
38 |
+
``"masks"`` and ``"labels"`` are added and wrap the data in the corresponding ``torchvision.datapoints``.
|
39 |
+
The original keys are preserved.
|
40 |
+
* :class:`~torchvision.datasets.VOCDetection`: The key-value-pairs ``"boxes"`` and ``"labels"`` are added to
|
41 |
+
the target and wrap the data in the corresponding ``torchvision.datapoints``. The original keys are
|
42 |
+
preserved.
|
43 |
+
* :class:`~torchvision.datasets.CelebA`: The target for ``target_type="bbox"`` is converted to the ``XYXY``
|
44 |
+
coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBox` datapoint.
|
45 |
+
* :class:`~torchvision.datasets.Kitti`: Instead returning the target as list of dictsthe wrapper returns a dict
|
46 |
+
of lists. In addition, the key-value-pairs ``"boxes"`` and ``"labels"`` are added and wrap the data
|
47 |
+
in the corresponding ``torchvision.datapoints``. The original keys are preserved.
|
48 |
+
* :class:`~torchvision.datasets.OxfordIIITPet`: The target for ``target_type="segmentation"`` is wrapped into a
|
49 |
+
:class:`~torchvision.datapoints.Mask` datapoint.
|
50 |
+
* :class:`~torchvision.datasets.Cityscapes`: The target for ``target_type="semantic"`` is wrapped into a
|
51 |
+
:class:`~torchvision.datapoints.Mask` datapoint. The target for ``target_type="instance"`` is *replaced* by
|
52 |
+
a dictionary with the key-value-pairs ``"masks"`` (as :class:`~torchvision.datapoints.Mask` datapoint) and
|
53 |
+
``"labels"``.
|
54 |
+
* :class:`~torchvision.datasets.WIDERFace`: The value for key ``"bbox"`` in the target is converted to ``XYXY``
|
55 |
+
coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBox` datapoint.
|
56 |
+
|
57 |
+
Image classification datasets
|
58 |
+
|
59 |
+
This wrapper is a no-op for image classification datasets, since they were already fully supported by
|
60 |
+
:mod:`torchvision.transforms` and thus no change is needed for :mod:`torchvision.transforms.v2`.
|
61 |
+
|
62 |
+
Segmentation datasets
|
63 |
+
|
64 |
+
Segmentation datasets, e.g. :class:`~torchvision.datasets.VOCSegmentation` return a two-tuple of
|
65 |
+
:class:`PIL.Image.Image`'s. This wrapper leaves the image as is (first item), while wrapping the
|
66 |
+
segmentation mask into a :class:`~torchvision.datapoints.Mask` (second item).
|
67 |
+
|
68 |
+
Video classification datasets
|
69 |
+
|
70 |
+
Video classification datasets, e.g. :class:`~torchvision.datasets.Kinetics` return a three-tuple containing a
|
71 |
+
:class:`torch.Tensor` for the video and audio and a :class:`int` as label. This wrapper wraps the video into a
|
72 |
+
:class:`~torchvision.datapoints.Video` while leaving the other items as is.
|
73 |
+
|
74 |
+
.. note::
|
75 |
+
|
76 |
+
Only datasets constructed with ``output_format="TCHW"`` are supported, since the alternative
|
77 |
+
``output_format="THWC"`` is not supported by :mod:`torchvision.transforms.v2`.
|
78 |
+
|
79 |
+
Args:
|
80 |
+
dataset: the dataset instance to wrap for compatibility with transforms v2.
|
81 |
+
"""
|
82 |
+
return VisionDatasetDatapointWrapper(dataset)
|
83 |
+
|
84 |
+
|
85 |
+
class WrapperFactories(dict):
|
86 |
+
def register(self, dataset_cls):
|
87 |
+
def decorator(wrapper_factory):
|
88 |
+
self[dataset_cls] = wrapper_factory
|
89 |
+
return wrapper_factory
|
90 |
+
|
91 |
+
return decorator
|
92 |
+
|
93 |
+
|
94 |
+
# We need this two-stage design, i.e. a wrapper factory producing the actual wrapper, since some wrappers depend on the
|
95 |
+
# dataset instance rather than just the class, since they require the user defined instance attributes. Thus, we can
|
96 |
+
# provide a wrapping from the dataset class to the factory here, but can only instantiate the wrapper at runtime when
|
97 |
+
# we have access to the dataset instance.
|
98 |
+
WRAPPER_FACTORIES = WrapperFactories()
|
99 |
+
|
100 |
+
|
101 |
+
class VisionDatasetDatapointWrapper(Dataset):
|
102 |
+
def __init__(self, dataset):
|
103 |
+
dataset_cls = type(dataset)
|
104 |
+
|
105 |
+
if not isinstance(dataset, datasets.VisionDataset):
|
106 |
+
raise TypeError(
|
107 |
+
f"This wrapper is meant for subclasses of `torchvision.datasets.VisionDataset`, "
|
108 |
+
f"but got a '{dataset_cls.__name__}' instead."
|
109 |
+
)
|
110 |
+
|
111 |
+
for cls in dataset_cls.mro():
|
112 |
+
if cls in WRAPPER_FACTORIES:
|
113 |
+
wrapper_factory = WRAPPER_FACTORIES[cls]
|
114 |
+
break
|
115 |
+
elif cls is datasets.VisionDataset:
|
116 |
+
# TODO: If we have documentation on how to do that, put a link in the error message.
|
117 |
+
msg = f"No wrapper exists for dataset class {dataset_cls.__name__}. Please wrap the output yourself."
|
118 |
+
if dataset_cls in datasets.__dict__.values():
|
119 |
+
msg = (
|
120 |
+
f"{msg} If an automated wrapper for this dataset would be useful for you, "
|
121 |
+
f"please open an issue at https://github.com/pytorch/vision/issues."
|
122 |
+
)
|
123 |
+
raise TypeError(msg)
|
124 |
+
|
125 |
+
self._dataset = dataset
|
126 |
+
self._wrapper = wrapper_factory(dataset)
|
127 |
+
|
128 |
+
# We need to disable the transforms on the dataset here to be able to inject the wrapping before we apply them.
|
129 |
+
# Although internally, `datasets.VisionDataset` merges `transform` and `target_transform` into the joint
|
130 |
+
# `transforms`
|
131 |
+
# https://github.com/pytorch/vision/blob/135a0f9ea9841b6324b4fe8974e2543cbb95709a/torchvision/datasets/vision.py#L52-L54
|
132 |
+
# some (if not most) datasets still use `transform` and `target_transform` individually. Thus, we need to
|
133 |
+
# disable all three here to be able to extract the untransformed sample to wrap.
|
134 |
+
self.transform, dataset.transform = dataset.transform, None
|
135 |
+
self.target_transform, dataset.target_transform = dataset.target_transform, None
|
136 |
+
self.transforms, dataset.transforms = dataset.transforms, None
|
137 |
+
|
138 |
+
def __getattr__(self, item):
|
139 |
+
with contextlib.suppress(AttributeError):
|
140 |
+
return object.__getattribute__(self, item)
|
141 |
+
|
142 |
+
return getattr(self._dataset, item)
|
143 |
+
|
144 |
+
def __getitem__(self, idx):
|
145 |
+
# This gets us the raw sample since we disabled the transforms for the underlying dataset in the constructor
|
146 |
+
# of this class
|
147 |
+
sample = self._dataset[idx]
|
148 |
+
|
149 |
+
sample = self._wrapper(idx, sample)
|
150 |
+
|
151 |
+
# Regardless of whether the user has supplied the transforms individually (`transform` and `target_transform`)
|
152 |
+
# or joint (`transforms`), we can access the full functionality through `transforms`
|
153 |
+
if self.transforms is not None:
|
154 |
+
sample = self.transforms(*sample)
|
155 |
+
|
156 |
+
return sample
|
157 |
+
|
158 |
+
def __len__(self):
|
159 |
+
return len(self._dataset)
|
160 |
+
|
161 |
+
|
162 |
+
def raise_not_supported(description):
|
163 |
+
raise RuntimeError(
|
164 |
+
f"{description} is currently not supported by this wrapper. "
|
165 |
+
f"If this would be helpful for you, please open an issue at https://github.com/pytorch/vision/issues."
|
166 |
+
)
|
167 |
+
|
168 |
+
|
169 |
+
def identity(item):
|
170 |
+
return item
|
171 |
+
|
172 |
+
|
173 |
+
def identity_wrapper_factory(dataset):
|
174 |
+
def wrapper(idx, sample):
|
175 |
+
return sample
|
176 |
+
|
177 |
+
return wrapper
|
178 |
+
|
179 |
+
|
180 |
+
def pil_image_to_mask(pil_image):
|
181 |
+
return datapoints.Mask(pil_image)
|
182 |
+
|
183 |
+
|
184 |
+
def list_of_dicts_to_dict_of_lists(list_of_dicts):
|
185 |
+
dict_of_lists = defaultdict(list)
|
186 |
+
for dct in list_of_dicts:
|
187 |
+
for key, value in dct.items():
|
188 |
+
dict_of_lists[key].append(value)
|
189 |
+
return dict(dict_of_lists)
|
190 |
+
|
191 |
+
|
192 |
+
def wrap_target_by_type(target, *, target_types, type_wrappers):
|
193 |
+
if not isinstance(target, (tuple, list)):
|
194 |
+
target = [target]
|
195 |
+
|
196 |
+
wrapped_target = tuple(
|
197 |
+
type_wrappers.get(target_type, identity)(item) for target_type, item in zip(target_types, target)
|
198 |
+
)
|
199 |
+
|
200 |
+
if len(wrapped_target) == 1:
|
201 |
+
wrapped_target = wrapped_target[0]
|
202 |
+
|
203 |
+
return wrapped_target
|
204 |
+
|
205 |
+
|
206 |
+
def classification_wrapper_factory(dataset):
|
207 |
+
return identity_wrapper_factory(dataset)
|
208 |
+
|
209 |
+
|
210 |
+
for dataset_cls in [
|
211 |
+
datasets.Caltech256,
|
212 |
+
datasets.CIFAR10,
|
213 |
+
datasets.CIFAR100,
|
214 |
+
datasets.ImageNet,
|
215 |
+
datasets.MNIST,
|
216 |
+
datasets.FashionMNIST,
|
217 |
+
datasets.GTSRB,
|
218 |
+
datasets.DatasetFolder,
|
219 |
+
datasets.ImageFolder,
|
220 |
+
]:
|
221 |
+
WRAPPER_FACTORIES.register(dataset_cls)(classification_wrapper_factory)
|
222 |
+
|
223 |
+
|
224 |
+
def segmentation_wrapper_factory(dataset):
|
225 |
+
def wrapper(idx, sample):
|
226 |
+
image, mask = sample
|
227 |
+
return image, pil_image_to_mask(mask)
|
228 |
+
|
229 |
+
return wrapper
|
230 |
+
|
231 |
+
|
232 |
+
for dataset_cls in [
|
233 |
+
datasets.VOCSegmentation,
|
234 |
+
]:
|
235 |
+
WRAPPER_FACTORIES.register(dataset_cls)(segmentation_wrapper_factory)
|
236 |
+
|
237 |
+
|
238 |
+
def video_classification_wrapper_factory(dataset):
|
239 |
+
if dataset.video_clips.output_format == "THWC":
|
240 |
+
raise RuntimeError(
|
241 |
+
f"{type(dataset).__name__} with `output_format='THWC'` is not supported by this wrapper, "
|
242 |
+
f"since it is not compatible with the transformations. Please use `output_format='TCHW'` instead."
|
243 |
+
)
|
244 |
+
|
245 |
+
def wrapper(idx, sample):
|
246 |
+
video, audio, label = sample
|
247 |
+
|
248 |
+
video = datapoints.Video(video)
|
249 |
+
|
250 |
+
return video, audio, label
|
251 |
+
|
252 |
+
return wrapper
|
253 |
+
|
254 |
+
|
255 |
+
for dataset_cls in [
|
256 |
+
datasets.HMDB51,
|
257 |
+
datasets.Kinetics,
|
258 |
+
datasets.UCF101,
|
259 |
+
]:
|
260 |
+
WRAPPER_FACTORIES.register(dataset_cls)(video_classification_wrapper_factory)
|
261 |
+
|
262 |
+
|
263 |
+
@WRAPPER_FACTORIES.register(datasets.Caltech101)
|
264 |
+
def caltech101_wrapper_factory(dataset):
|
265 |
+
if "annotation" in dataset.target_type:
|
266 |
+
raise_not_supported("Caltech101 dataset with `target_type=['annotation', ...]`")
|
267 |
+
|
268 |
+
return classification_wrapper_factory(dataset)
|
269 |
+
|
270 |
+
|
271 |
+
@WRAPPER_FACTORIES.register(datasets.CocoDetection)
|
272 |
+
def coco_dectection_wrapper_factory(dataset):
|
273 |
+
def segmentation_to_mask(segmentation, *, spatial_size):
|
274 |
+
from pycocotools import mask
|
275 |
+
|
276 |
+
segmentation = (
|
277 |
+
mask.frPyObjects(segmentation, *spatial_size)
|
278 |
+
if isinstance(segmentation, dict)
|
279 |
+
else mask.merge(mask.frPyObjects(segmentation, *spatial_size))
|
280 |
+
)
|
281 |
+
return torch.from_numpy(mask.decode(segmentation))
|
282 |
+
|
283 |
+
def wrapper(idx, sample):
|
284 |
+
image_id = dataset.ids[idx]
|
285 |
+
|
286 |
+
image, target = sample
|
287 |
+
|
288 |
+
if not target:
|
289 |
+
return image, dict(image_id=image_id)
|
290 |
+
|
291 |
+
batched_target = list_of_dicts_to_dict_of_lists(target)
|
292 |
+
|
293 |
+
batched_target["image_id"] = image_id
|
294 |
+
|
295 |
+
spatial_size = tuple(F.get_spatial_size(image))
|
296 |
+
batched_target["boxes"] = F.convert_format_bounding_box(
|
297 |
+
datapoints.BoundingBox(
|
298 |
+
batched_target["bbox"],
|
299 |
+
format=datapoints.BoundingBoxFormat.XYWH,
|
300 |
+
spatial_size=spatial_size,
|
301 |
+
),
|
302 |
+
new_format=datapoints.BoundingBoxFormat.XYXY,
|
303 |
+
)
|
304 |
+
batched_target["masks"] = datapoints.Mask(
|
305 |
+
torch.stack(
|
306 |
+
[
|
307 |
+
segmentation_to_mask(segmentation, spatial_size=spatial_size)
|
308 |
+
for segmentation in batched_target["segmentation"]
|
309 |
+
]
|
310 |
+
),
|
311 |
+
)
|
312 |
+
batched_target["labels"] = torch.tensor(batched_target["category_id"])
|
313 |
+
|
314 |
+
return image, batched_target
|
315 |
+
|
316 |
+
return wrapper
|
317 |
+
|
318 |
+
|
319 |
+
WRAPPER_FACTORIES.register(datasets.CocoCaptions)(identity_wrapper_factory)
|
320 |
+
|
321 |
+
|
322 |
+
VOC_DETECTION_CATEGORIES = [
|
323 |
+
"__background__",
|
324 |
+
"aeroplane",
|
325 |
+
"bicycle",
|
326 |
+
"bird",
|
327 |
+
"boat",
|
328 |
+
"bottle",
|
329 |
+
"bus",
|
330 |
+
"car",
|
331 |
+
"cat",
|
332 |
+
"chair",
|
333 |
+
"cow",
|
334 |
+
"diningtable",
|
335 |
+
"dog",
|
336 |
+
"horse",
|
337 |
+
"motorbike",
|
338 |
+
"person",
|
339 |
+
"pottedplant",
|
340 |
+
"sheep",
|
341 |
+
"sofa",
|
342 |
+
"train",
|
343 |
+
"tvmonitor",
|
344 |
+
]
|
345 |
+
VOC_DETECTION_CATEGORY_TO_IDX = dict(zip(VOC_DETECTION_CATEGORIES, range(len(VOC_DETECTION_CATEGORIES))))
|
346 |
+
|
347 |
+
@WRAPPER_FACTORIES.register(datasets.VOCDetection)
def voc_detection_wrapper_factory(dataset):
    def wrapper(idx, sample):
        image, target = sample

        batched_instances = list_of_dicts_to_dict_of_lists(target["annotation"]["object"])

        target["boxes"] = datapoints.BoundingBox(
            [
                [int(bndbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")]
                for bndbox in batched_instances["bndbox"]
            ],
            format=datapoints.BoundingBoxFormat.XYXY,
            spatial_size=(image.height, image.width),
        )
        target["labels"] = torch.tensor(
            [VOC_DETECTION_CATEGORY_TO_IDX[category] for category in batched_instances["name"]]
        )

        return image, target

    return wrapper


@WRAPPER_FACTORIES.register(datasets.SBDataset)
def sbd_wrapper(dataset):
    if dataset.mode == "boundaries":
        raise_not_supported("SBDataset with mode='boundaries'")

    return segmentation_wrapper_factory(dataset)


@WRAPPER_FACTORIES.register(datasets.CelebA)
def celeba_wrapper_factory(dataset):
    if any(target_type in dataset.target_type for target_type in ["attr", "landmarks"]):
        raise_not_supported("`CelebA` dataset with `target_type=['attr', 'landmarks', ...]`")

    def wrapper(idx, sample):
        image, target = sample

        target = wrap_target_by_type(
            target,
            target_types=dataset.target_type,
            type_wrappers={
                "bbox": lambda item: F.convert_format_bounding_box(
                    datapoints.BoundingBox(
                        item,
                        format=datapoints.BoundingBoxFormat.XYWH,
                        spatial_size=(image.height, image.width),
                    ),
                    new_format=datapoints.BoundingBoxFormat.XYXY,
                ),
            },
        )

        return image, target

    return wrapper


KITTI_CATEGORIES = ["Car", "Van", "Truck", "Pedestrian", "Person_sitting", "Cyclist", "Tram", "Misc", "DontCare"]
KITTI_CATEGORY_TO_IDX = dict(zip(KITTI_CATEGORIES, range(len(KITTI_CATEGORIES))))


@WRAPPER_FACTORIES.register(datasets.Kitti)
def kitti_wrapper_factory(dataset):
    def wrapper(idx, sample):
        image, target = sample

        if target is not None:
            target = list_of_dicts_to_dict_of_lists(target)

            target["boxes"] = datapoints.BoundingBox(
                target["bbox"], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(image.height, image.width)
            )
            target["labels"] = torch.tensor([KITTI_CATEGORY_TO_IDX[category] for category in target["type"]])

        return image, target

    return wrapper


@WRAPPER_FACTORIES.register(datasets.OxfordIIITPet)
def oxford_iiit_pet_wrapper_factor(dataset):
    def wrapper(idx, sample):
        image, target = sample

        if target is not None:
            target = wrap_target_by_type(
                target,
                target_types=dataset._target_types,
                type_wrappers={
                    "segmentation": pil_image_to_mask,
                },
            )

        return image, target

    return wrapper


@WRAPPER_FACTORIES.register(datasets.Cityscapes)
def cityscapes_wrapper_factory(dataset):
    if any(target_type in dataset.target_type for target_type in ["polygon", "color"]):
        raise_not_supported("`Cityscapes` dataset with `target_type=['polygon', 'color', ...]`")

    def instance_segmentation_wrapper(mask):
        # See https://github.com/mcordts/cityscapesScripts/blob/8da5dd00c9069058ccc134654116aac52d4f6fa2/cityscapesscripts/preparation/json2instanceImg.py#L7-L21
        data = pil_image_to_mask(mask)
        masks = []
        labels = []
        for id in data.unique():
            masks.append(data == id)
            label = id
            if label >= 1_000:
                label //= 1_000
            labels.append(label)
        return dict(masks=datapoints.Mask(torch.stack(masks)), labels=torch.stack(labels))

    def wrapper(idx, sample):
        image, target = sample

        target = wrap_target_by_type(
            target,
            target_types=dataset.target_type,
            type_wrappers={
                "instance": instance_segmentation_wrapper,
                "semantic": pil_image_to_mask,
            },
        )

        return image, target

    return wrapper


@WRAPPER_FACTORIES.register(datasets.WIDERFace)
def widerface_wrapper(dataset):
    def wrapper(idx, sample):
        image, target = sample

        if target is not None:
            target["bbox"] = F.convert_format_bounding_box(
                datapoints.BoundingBox(
                    target["bbox"], format=datapoints.BoundingBoxFormat.XYWH, spatial_size=(image.height, image.width)
                ),
                new_format=datapoints.BoundingBoxFormat.XYXY,
            )

        return image, target

    return wrapper
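A minimal, hedged usage sketch of the wrapper factories above: they are not called directly but looked up by wrap_dataset_for_transforms_v2 (re-exported from torchvision.datasets further below), which returns a dataset whose detection targets carry datapoints.BoundingBox boxes and label tensors. The dataset root path is a placeholder.

# Sketch: consuming the VOCDetection wrapper factory registered above.
from torchvision import datasets
from torchvision.datasets import wrap_dataset_for_transforms_v2

voc = datasets.VOCDetection("path/to/voc-root", year="2012", image_set="val")
voc = wrap_dataset_for_transforms_v2(voc)

image, target = voc[0]
print(type(target["boxes"]))   # datapoints.BoundingBox, XYXY format
print(target["labels"])        # tensor of VOC category indices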
torchvision/datapoints/_image.py
ADDED
@@ -0,0 +1,260 @@
from __future__ import annotations

from typing import Any, List, Optional, Tuple, Union

import PIL.Image
import torch
from torchvision.transforms.functional import InterpolationMode

from ._datapoint import _FillTypeJIT, Datapoint


class Image(Datapoint):
    """[BETA] :class:`torch.Tensor` subclass for images.

    Args:
        data (tensor-like, PIL.Image.Image): Any data that can be turned into a tensor with :func:`torch.as_tensor` as
            well as PIL images.
        dtype (torch.dtype, optional): Desired data type of the image. If omitted, will be inferred from
            ``data``.
        device (torch.device, optional): Desired device of the image. If omitted and ``data`` is a
            :class:`torch.Tensor`, the device is taken from it. Otherwise, the image is constructed on the CPU.
        requires_grad (bool, optional): Whether autograd should record operations on the image. If omitted and
            ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``.
    """

    @classmethod
    def _wrap(cls, tensor: torch.Tensor) -> Image:
        image = tensor.as_subclass(cls)
        return image

    def __new__(
        cls,
        data: Any,
        *,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: Optional[bool] = None,
    ) -> Image:
        if isinstance(data, PIL.Image.Image):
            from torchvision.transforms.v2 import functional as F

            data = F.pil_to_tensor(data)

        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
        if tensor.ndim < 2:
            raise ValueError
        elif tensor.ndim == 2:
            tensor = tensor.unsqueeze(0)

        return cls._wrap(tensor)

    @classmethod
    def wrap_like(cls, other: Image, tensor: torch.Tensor) -> Image:
        return cls._wrap(tensor)

    def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
        return self._make_repr()

    @property
    def spatial_size(self) -> Tuple[int, int]:
        return tuple(self.shape[-2:])  # type: ignore[return-value]

    @property
    def num_channels(self) -> int:
        return self.shape[-3]

    def horizontal_flip(self) -> Image:
        output = self._F.horizontal_flip_image_tensor(self.as_subclass(torch.Tensor))
        return Image.wrap_like(self, output)

    def vertical_flip(self) -> Image:
        output = self._F.vertical_flip_image_tensor(self.as_subclass(torch.Tensor))
        return Image.wrap_like(self, output)

    def resize(  # type: ignore[override]
        self,
        size: List[int],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        max_size: Optional[int] = None,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> Image:
        output = self._F.resize_image_tensor(
            self.as_subclass(torch.Tensor), size, interpolation=interpolation, max_size=max_size, antialias=antialias
        )
        return Image.wrap_like(self, output)

    def crop(self, top: int, left: int, height: int, width: int) -> Image:
        output = self._F.crop_image_tensor(self.as_subclass(torch.Tensor), top, left, height, width)
        return Image.wrap_like(self, output)

    def center_crop(self, output_size: List[int]) -> Image:
        output = self._F.center_crop_image_tensor(self.as_subclass(torch.Tensor), output_size=output_size)
        return Image.wrap_like(self, output)

    def resized_crop(
        self,
        top: int,
        left: int,
        height: int,
        width: int,
        size: List[int],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> Image:
        output = self._F.resized_crop_image_tensor(
            self.as_subclass(torch.Tensor),
            top,
            left,
            height,
            width,
            size=list(size),
            interpolation=interpolation,
            antialias=antialias,
        )
        return Image.wrap_like(self, output)

    def pad(
        self,
        padding: List[int],
        fill: Optional[Union[int, float, List[float]]] = None,
        padding_mode: str = "constant",
    ) -> Image:
        output = self._F.pad_image_tensor(self.as_subclass(torch.Tensor), padding, fill=fill, padding_mode=padding_mode)
        return Image.wrap_like(self, output)

    def rotate(
        self,
        angle: float,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        expand: bool = False,
        center: Optional[List[float]] = None,
        fill: _FillTypeJIT = None,
    ) -> Image:
        output = self._F.rotate_image_tensor(
            self.as_subclass(torch.Tensor), angle, interpolation=interpolation, expand=expand, fill=fill, center=center
        )
        return Image.wrap_like(self, output)

    def affine(
        self,
        angle: Union[int, float],
        translate: List[float],
        scale: float,
        shear: List[float],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: _FillTypeJIT = None,
        center: Optional[List[float]] = None,
    ) -> Image:
        output = self._F.affine_image_tensor(
            self.as_subclass(torch.Tensor),
            angle,
            translate=translate,
            scale=scale,
            shear=shear,
            interpolation=interpolation,
            fill=fill,
            center=center,
        )
        return Image.wrap_like(self, output)

    def perspective(
        self,
        startpoints: Optional[List[List[int]]],
        endpoints: Optional[List[List[int]]],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: _FillTypeJIT = None,
        coefficients: Optional[List[float]] = None,
    ) -> Image:
        output = self._F.perspective_image_tensor(
            self.as_subclass(torch.Tensor),
            startpoints,
            endpoints,
            interpolation=interpolation,
            fill=fill,
            coefficients=coefficients,
        )
        return Image.wrap_like(self, output)

    def elastic(
        self,
        displacement: torch.Tensor,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: _FillTypeJIT = None,
    ) -> Image:
        output = self._F.elastic_image_tensor(
            self.as_subclass(torch.Tensor), displacement, interpolation=interpolation, fill=fill
        )
        return Image.wrap_like(self, output)

    def rgb_to_grayscale(self, num_output_channels: int = 1) -> Image:
        output = self._F.rgb_to_grayscale_image_tensor(
            self.as_subclass(torch.Tensor), num_output_channels=num_output_channels
        )
        return Image.wrap_like(self, output)

    def adjust_brightness(self, brightness_factor: float) -> Image:
        output = self._F.adjust_brightness_image_tensor(
            self.as_subclass(torch.Tensor), brightness_factor=brightness_factor
        )
        return Image.wrap_like(self, output)

    def adjust_saturation(self, saturation_factor: float) -> Image:
        output = self._F.adjust_saturation_image_tensor(
            self.as_subclass(torch.Tensor), saturation_factor=saturation_factor
        )
        return Image.wrap_like(self, output)

    def adjust_contrast(self, contrast_factor: float) -> Image:
        output = self._F.adjust_contrast_image_tensor(self.as_subclass(torch.Tensor), contrast_factor=contrast_factor)
        return Image.wrap_like(self, output)

    def adjust_sharpness(self, sharpness_factor: float) -> Image:
        output = self._F.adjust_sharpness_image_tensor(
            self.as_subclass(torch.Tensor), sharpness_factor=sharpness_factor
        )
        return Image.wrap_like(self, output)

    def adjust_hue(self, hue_factor: float) -> Image:
        output = self._F.adjust_hue_image_tensor(self.as_subclass(torch.Tensor), hue_factor=hue_factor)
        return Image.wrap_like(self, output)

    def adjust_gamma(self, gamma: float, gain: float = 1) -> Image:
        output = self._F.adjust_gamma_image_tensor(self.as_subclass(torch.Tensor), gamma=gamma, gain=gain)
        return Image.wrap_like(self, output)

    def posterize(self, bits: int) -> Image:
        output = self._F.posterize_image_tensor(self.as_subclass(torch.Tensor), bits=bits)
        return Image.wrap_like(self, output)

    def solarize(self, threshold: float) -> Image:
        output = self._F.solarize_image_tensor(self.as_subclass(torch.Tensor), threshold=threshold)
        return Image.wrap_like(self, output)

    def autocontrast(self) -> Image:
        output = self._F.autocontrast_image_tensor(self.as_subclass(torch.Tensor))
        return Image.wrap_like(self, output)

    def equalize(self) -> Image:
        output = self._F.equalize_image_tensor(self.as_subclass(torch.Tensor))
        return Image.wrap_like(self, output)

    def invert(self) -> Image:
        output = self._F.invert_image_tensor(self.as_subclass(torch.Tensor))
        return Image.wrap_like(self, output)

    def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Image:
        output = self._F.gaussian_blur_image_tensor(
            self.as_subclass(torch.Tensor), kernel_size=kernel_size, sigma=sigma
        )
        return Image.wrap_like(self, output)

    def normalize(self, mean: List[float], std: List[float], inplace: bool = False) -> Image:
        output = self._F.normalize_image_tensor(self.as_subclass(torch.Tensor), mean=mean, std=std, inplace=inplace)
        return Image.wrap_like(self, output)


_ImageType = Union[torch.Tensor, PIL.Image.Image, Image]
_ImageTypeJIT = torch.Tensor
_TensorImageType = Union[torch.Tensor, Image]
_TensorImageTypeJIT = torch.Tensor
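A short, hedged usage sketch of the class above (the random data is illustrative only): wrapping a plain tensor or PIL image in Image keeps the subclass through the tensor-returning transform methods.

# Sketch: constructing a datapoints.Image and exercising a few of its methods.
import torch
from torchvision import datapoints

img = datapoints.Image(torch.rand(3, 240, 320))    # CHW float tensor
print(img.spatial_size, img.num_channels)          # (240, 320) 3

out = img.resize([120, 160]).horizontal_flip()
print(type(out), out.spatial_size)                 # still an Image, (120, 160)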
torchvision/datapoints/_mask.py
ADDED
@@ -0,0 +1,158 @@
from __future__ import annotations

from typing import Any, List, Optional, Tuple, Union

import PIL.Image
import torch
from torchvision.transforms import InterpolationMode

from ._datapoint import _FillTypeJIT, Datapoint


class Mask(Datapoint):
    """[BETA] :class:`torch.Tensor` subclass for segmentation and detection masks.

    Args:
        data (tensor-like, PIL.Image.Image): Any data that can be turned into a tensor with :func:`torch.as_tensor` as
            well as PIL images.
        dtype (torch.dtype, optional): Desired data type of the mask. If omitted, will be inferred from
            ``data``.
        device (torch.device, optional): Desired device of the mask. If omitted and ``data`` is a
            :class:`torch.Tensor`, the device is taken from it. Otherwise, the mask is constructed on the CPU.
        requires_grad (bool, optional): Whether autograd should record operations on the mask. If omitted and
            ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``.
    """

    @classmethod
    def _wrap(cls, tensor: torch.Tensor) -> Mask:
        return tensor.as_subclass(cls)

    def __new__(
        cls,
        data: Any,
        *,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: Optional[bool] = None,
    ) -> Mask:
        if isinstance(data, PIL.Image.Image):
            from torchvision.transforms.v2 import functional as F

            data = F.pil_to_tensor(data)

        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
        return cls._wrap(tensor)

    @classmethod
    def wrap_like(
        cls,
        other: Mask,
        tensor: torch.Tensor,
    ) -> Mask:
        return cls._wrap(tensor)

    @property
    def spatial_size(self) -> Tuple[int, int]:
        return tuple(self.shape[-2:])  # type: ignore[return-value]

    def horizontal_flip(self) -> Mask:
        output = self._F.horizontal_flip_mask(self.as_subclass(torch.Tensor))
        return Mask.wrap_like(self, output)

    def vertical_flip(self) -> Mask:
        output = self._F.vertical_flip_mask(self.as_subclass(torch.Tensor))
        return Mask.wrap_like(self, output)

    def resize(  # type: ignore[override]
        self,
        size: List[int],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        max_size: Optional[int] = None,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> Mask:
        output = self._F.resize_mask(self.as_subclass(torch.Tensor), size, max_size=max_size)
        return Mask.wrap_like(self, output)

    def crop(self, top: int, left: int, height: int, width: int) -> Mask:
        output = self._F.crop_mask(self.as_subclass(torch.Tensor), top, left, height, width)
        return Mask.wrap_like(self, output)

    def center_crop(self, output_size: List[int]) -> Mask:
        output = self._F.center_crop_mask(self.as_subclass(torch.Tensor), output_size=output_size)
        return Mask.wrap_like(self, output)

    def resized_crop(
        self,
        top: int,
        left: int,
        height: int,
        width: int,
        size: List[int],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> Mask:
        output = self._F.resized_crop_mask(self.as_subclass(torch.Tensor), top, left, height, width, size=size)
        return Mask.wrap_like(self, output)

    def pad(
        self,
        padding: List[int],
        fill: Optional[Union[int, float, List[float]]] = None,
        padding_mode: str = "constant",
    ) -> Mask:
        output = self._F.pad_mask(self.as_subclass(torch.Tensor), padding, padding_mode=padding_mode, fill=fill)
        return Mask.wrap_like(self, output)

    def rotate(
        self,
        angle: float,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        expand: bool = False,
        center: Optional[List[float]] = None,
        fill: _FillTypeJIT = None,
    ) -> Mask:
        output = self._F.rotate_mask(self.as_subclass(torch.Tensor), angle, expand=expand, center=center, fill=fill)
        return Mask.wrap_like(self, output)

    def affine(
        self,
        angle: Union[int, float],
        translate: List[float],
        scale: float,
        shear: List[float],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: _FillTypeJIT = None,
        center: Optional[List[float]] = None,
    ) -> Mask:
        output = self._F.affine_mask(
            self.as_subclass(torch.Tensor),
            angle,
            translate=translate,
            scale=scale,
            shear=shear,
            fill=fill,
            center=center,
        )
        return Mask.wrap_like(self, output)

    def perspective(
        self,
        startpoints: Optional[List[List[int]]],
        endpoints: Optional[List[List[int]]],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: _FillTypeJIT = None,
        coefficients: Optional[List[float]] = None,
    ) -> Mask:
        output = self._F.perspective_mask(
            self.as_subclass(torch.Tensor), startpoints, endpoints, fill=fill, coefficients=coefficients
        )
        return Mask.wrap_like(self, output)

    def elastic(
        self,
        displacement: torch.Tensor,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: _FillTypeJIT = None,
    ) -> Mask:
        output = self._F.elastic_mask(self.as_subclass(torch.Tensor), displacement, fill=fill)
        return Mask.wrap_like(self, output)
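A hedged sketch of typical Mask usage (the random label map is invented for illustration): the geometric methods above intentionally ignore the interpolation argument and stay nearest-neighbour, so integer class indices are never blended.

# Sketch: a Mask of integer class indices keeps hard labels through resize.
import torch
from torchvision import datapoints

mask = datapoints.Mask(torch.randint(0, 3, (32, 32), dtype=torch.uint8))
resized = mask.resize([16, 16])              # nearest-neighbour: no blended class indices
print(type(resized), resized.spatial_size)   # still a Mask, (16, 16)
print(resized.unique())                      # only values that already existed in `mask`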
torchvision/datapoints/_video.py
ADDED
@@ -0,0 +1,250 @@
from __future__ import annotations

from typing import Any, List, Optional, Tuple, Union

import torch
from torchvision.transforms.functional import InterpolationMode

from ._datapoint import _FillTypeJIT, Datapoint


class Video(Datapoint):
    """[BETA] :class:`torch.Tensor` subclass for videos.

    Args:
        data (tensor-like): Any data that can be turned into a tensor with :func:`torch.as_tensor`.
        dtype (torch.dtype, optional): Desired data type of the video. If omitted, will be inferred from
            ``data``.
        device (torch.device, optional): Desired device of the video. If omitted and ``data`` is a
            :class:`torch.Tensor`, the device is taken from it. Otherwise, the video is constructed on the CPU.
        requires_grad (bool, optional): Whether autograd should record operations on the video. If omitted and
            ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``.
    """

    @classmethod
    def _wrap(cls, tensor: torch.Tensor) -> Video:
        video = tensor.as_subclass(cls)
        return video

    def __new__(
        cls,
        data: Any,
        *,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: Optional[bool] = None,
    ) -> Video:
        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
        # validate the converted tensor so list-like `data` is also checked
        if tensor.ndim < 4:
            raise ValueError
        return cls._wrap(tensor)

    @classmethod
    def wrap_like(cls, other: Video, tensor: torch.Tensor) -> Video:
        return cls._wrap(tensor)

    def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
        return self._make_repr()

    @property
    def spatial_size(self) -> Tuple[int, int]:
        return tuple(self.shape[-2:])  # type: ignore[return-value]

    @property
    def num_channels(self) -> int:
        return self.shape[-3]

    @property
    def num_frames(self) -> int:
        return self.shape[-4]

    def horizontal_flip(self) -> Video:
        output = self._F.horizontal_flip_video(self.as_subclass(torch.Tensor))
        return Video.wrap_like(self, output)

    def vertical_flip(self) -> Video:
        output = self._F.vertical_flip_video(self.as_subclass(torch.Tensor))
        return Video.wrap_like(self, output)

    def resize(  # type: ignore[override]
        self,
        size: List[int],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        max_size: Optional[int] = None,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> Video:
        output = self._F.resize_video(
            self.as_subclass(torch.Tensor),
            size,
            interpolation=interpolation,
            max_size=max_size,
            antialias=antialias,
        )
        return Video.wrap_like(self, output)

    def crop(self, top: int, left: int, height: int, width: int) -> Video:
        output = self._F.crop_video(self.as_subclass(torch.Tensor), top, left, height, width)
        return Video.wrap_like(self, output)

    def center_crop(self, output_size: List[int]) -> Video:
        output = self._F.center_crop_video(self.as_subclass(torch.Tensor), output_size=output_size)
        return Video.wrap_like(self, output)

    def resized_crop(
        self,
        top: int,
        left: int,
        height: int,
        width: int,
        size: List[int],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> Video:
        output = self._F.resized_crop_video(
            self.as_subclass(torch.Tensor),
            top,
            left,
            height,
            width,
            size=list(size),
            interpolation=interpolation,
            antialias=antialias,
        )
        return Video.wrap_like(self, output)

    def pad(
        self,
        padding: List[int],
        fill: Optional[Union[int, float, List[float]]] = None,
        padding_mode: str = "constant",
    ) -> Video:
        output = self._F.pad_video(self.as_subclass(torch.Tensor), padding, fill=fill, padding_mode=padding_mode)
        return Video.wrap_like(self, output)

    def rotate(
        self,
        angle: float,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        expand: bool = False,
        center: Optional[List[float]] = None,
        fill: _FillTypeJIT = None,
    ) -> Video:
        output = self._F.rotate_video(
            self.as_subclass(torch.Tensor), angle, interpolation=interpolation, expand=expand, fill=fill, center=center
        )
        return Video.wrap_like(self, output)

    def affine(
        self,
        angle: Union[int, float],
        translate: List[float],
        scale: float,
        shear: List[float],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: _FillTypeJIT = None,
        center: Optional[List[float]] = None,
    ) -> Video:
        output = self._F.affine_video(
            self.as_subclass(torch.Tensor),
            angle,
            translate=translate,
            scale=scale,
            shear=shear,
            interpolation=interpolation,
            fill=fill,
            center=center,
        )
        return Video.wrap_like(self, output)

    def perspective(
        self,
        startpoints: Optional[List[List[int]]],
        endpoints: Optional[List[List[int]]],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: _FillTypeJIT = None,
        coefficients: Optional[List[float]] = None,
    ) -> Video:
        output = self._F.perspective_video(
            self.as_subclass(torch.Tensor),
            startpoints,
            endpoints,
            interpolation=interpolation,
            fill=fill,
            coefficients=coefficients,
        )
        return Video.wrap_like(self, output)

    def elastic(
        self,
        displacement: torch.Tensor,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: _FillTypeJIT = None,
    ) -> Video:
        output = self._F.elastic_video(
            self.as_subclass(torch.Tensor), displacement, interpolation=interpolation, fill=fill
        )
        return Video.wrap_like(self, output)

    def rgb_to_grayscale(self, num_output_channels: int = 1) -> Video:
        output = self._F.rgb_to_grayscale_image_tensor(
            self.as_subclass(torch.Tensor), num_output_channels=num_output_channels
        )
        return Video.wrap_like(self, output)

    def adjust_brightness(self, brightness_factor: float) -> Video:
        output = self._F.adjust_brightness_video(self.as_subclass(torch.Tensor), brightness_factor=brightness_factor)
        return Video.wrap_like(self, output)

    def adjust_saturation(self, saturation_factor: float) -> Video:
        output = self._F.adjust_saturation_video(self.as_subclass(torch.Tensor), saturation_factor=saturation_factor)
        return Video.wrap_like(self, output)

    def adjust_contrast(self, contrast_factor: float) -> Video:
        output = self._F.adjust_contrast_video(self.as_subclass(torch.Tensor), contrast_factor=contrast_factor)
        return Video.wrap_like(self, output)

    def adjust_sharpness(self, sharpness_factor: float) -> Video:
        output = self._F.adjust_sharpness_video(self.as_subclass(torch.Tensor), sharpness_factor=sharpness_factor)
        return Video.wrap_like(self, output)

    def adjust_hue(self, hue_factor: float) -> Video:
        output = self._F.adjust_hue_video(self.as_subclass(torch.Tensor), hue_factor=hue_factor)
        return Video.wrap_like(self, output)

    def adjust_gamma(self, gamma: float, gain: float = 1) -> Video:
        output = self._F.adjust_gamma_video(self.as_subclass(torch.Tensor), gamma=gamma, gain=gain)
        return Video.wrap_like(self, output)

    def posterize(self, bits: int) -> Video:
        output = self._F.posterize_video(self.as_subclass(torch.Tensor), bits=bits)
        return Video.wrap_like(self, output)

    def solarize(self, threshold: float) -> Video:
        output = self._F.solarize_video(self.as_subclass(torch.Tensor), threshold=threshold)
        return Video.wrap_like(self, output)

    def autocontrast(self) -> Video:
        output = self._F.autocontrast_video(self.as_subclass(torch.Tensor))
        return Video.wrap_like(self, output)

    def equalize(self) -> Video:
        output = self._F.equalize_video(self.as_subclass(torch.Tensor))
        return Video.wrap_like(self, output)

    def invert(self) -> Video:
        output = self._F.invert_video(self.as_subclass(torch.Tensor))
        return Video.wrap_like(self, output)

    def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Video:
        output = self._F.gaussian_blur_video(self.as_subclass(torch.Tensor), kernel_size=kernel_size, sigma=sigma)
        return Video.wrap_like(self, output)

    def normalize(self, mean: List[float], std: List[float], inplace: bool = False) -> Video:
        output = self._F.normalize_video(self.as_subclass(torch.Tensor), mean=mean, std=std, inplace=inplace)
        return Video.wrap_like(self, output)


_VideoType = Union[torch.Tensor, Video]
_VideoTypeJIT = torch.Tensor
_TensorVideoType = Union[torch.Tensor, Video]
_TensorVideoTypeJIT = torch.Tensor
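A hedged usage sketch of the class above (the random clip is illustrative only): a Video datapoint is any tensor with at least a (T, C, H, W) layout, and the transform methods return Video instances.

# Sketch: constructing a datapoints.Video and reading its shape-derived properties.
import torch
from torchvision import datapoints

clip = datapoints.Video(torch.rand(8, 3, 112, 112))            # T=8 frames, 3 channels
print(clip.num_frames, clip.num_channels, clip.spatial_size)   # 8 3 (112, 112)

out = clip.resize([56, 56])
print(type(out), tuple(out.shape))                             # still a Video, (8, 3, 56, 56)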
torchvision/datasets/__init__.py
ADDED
@@ -0,0 +1,145 @@
from ._optical_flow import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel
from ._stereo_matching import (
    CarlaStereo,
    CREStereo,
    ETH3DStereo,
    FallingThingsStereo,
    InStereo2k,
    Kitti2012Stereo,
    Kitti2015Stereo,
    Middlebury2014Stereo,
    SceneFlowStereo,
    SintelStereo,
)
from .caltech import Caltech101, Caltech256
from .celeba import CelebA
from .cifar import CIFAR10, CIFAR100
from .cityscapes import Cityscapes
from .clevr import CLEVRClassification
from .coco import CocoCaptions, CocoDetection
from .country211 import Country211
from .dtd import DTD
from .eurosat import EuroSAT
from .fakedata import FakeData
from .fer2013 import FER2013
from .fgvc_aircraft import FGVCAircraft
from .flickr import Flickr30k, Flickr8k
from .flowers102 import Flowers102
from .folder import DatasetFolder, ImageFolder
from .food101 import Food101
from .gtsrb import GTSRB
from .hmdb51 import HMDB51
from .imagenet import ImageNet
from .inaturalist import INaturalist
from .kinetics import Kinetics
from .kitti import Kitti
from .lfw import LFWPairs, LFWPeople
from .lsun import LSUN, LSUNClass
from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST
from .moving_mnist import MovingMNIST
from .omniglot import Omniglot
from .oxford_iiit_pet import OxfordIIITPet
from .pcam import PCAM
from .phototour import PhotoTour
from .places365 import Places365
from .rendered_sst2 import RenderedSST2
from .sbd import SBDataset
from .sbu import SBU
from .semeion import SEMEION
from .stanford_cars import StanfordCars
from .stl10 import STL10
from .sun397 import SUN397
from .svhn import SVHN
from .ucf101 import UCF101
from .usps import USPS
from .vision import VisionDataset
from .voc import VOCDetection, VOCSegmentation
from .widerface import WIDERFace

__all__ = (
    "LSUN",
    "LSUNClass",
    "ImageFolder",
    "DatasetFolder",
    "FakeData",
    "CocoCaptions",
    "CocoDetection",
    "CIFAR10",
    "CIFAR100",
    "EMNIST",
    "FashionMNIST",
    "QMNIST",
    "MNIST",
    "KMNIST",
    "StanfordCars",
    "STL10",
    "SUN397",
    "SVHN",
    "PhotoTour",
    "SEMEION",
    "Omniglot",
    "SBU",
    "Flickr8k",
    "Flickr30k",
    "Flowers102",
    "VOCSegmentation",
    "VOCDetection",
    "Cityscapes",
    "ImageNet",
    "Caltech101",
    "Caltech256",
    "CelebA",
    "WIDERFace",
    "SBDataset",
    "VisionDataset",
    "USPS",
    "Kinetics",
    "HMDB51",
    "UCF101",
    "Places365",
    "Kitti",
    "INaturalist",
    "LFWPeople",
    "LFWPairs",
    "KittiFlow",
    "Sintel",
    "FlyingChairs",
    "FlyingThings3D",
    "HD1K",
    "Food101",
    "DTD",
    "FER2013",
    "GTSRB",
    "CLEVRClassification",
    "OxfordIIITPet",
    "PCAM",
    "Country211",
    "FGVCAircraft",
    "EuroSAT",
    "RenderedSST2",
    "Kitti2012Stereo",
    "Kitti2015Stereo",
    "CarlaStereo",
    "Middlebury2014Stereo",
    "CREStereo",
    "FallingThingsStereo",
    "SceneFlowStereo",
    "SintelStereo",
    "InStereo2k",
    "ETH3DStereo",
)


# We override current module's attributes to handle the import:
# from torchvision.datasets import wrap_dataset_for_transforms_v2
# with beta state v2 warning from torchvision.datapoints
# We also want to avoid raising the warning when importing other attributes
# from torchvision.datasets
# Ref: https://peps.python.org/pep-0562/
def __getattr__(name):
    if name in ("wrap_dataset_for_transforms_v2",):
        from torchvision.datapoints._dataset_wrapper import wrap_dataset_for_transforms_v2

        return wrap_dataset_for_transforms_v2

    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
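A small hedged sketch of what the module-level __getattr__ above buys (PEP 562 lazy attribute access; the nonexistent name below is made up): importing torchvision.datasets alone does not touch torchvision.datapoints, and the beta-state warning only fires once wrap_dataset_for_transforms_v2 is actually requested.

# Sketch: lazy attribute access through the __getattr__ defined above.
import torchvision.datasets as datasets

fn = datasets.wrap_dataset_for_transforms_v2   # triggers the lazy import (and its warning)
try:
    datasets.does_not_exist                    # any other missing name ...
except AttributeError as exc:
    print(exc)                                 # ... raises AttributeError as usual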
torchvision/datasets/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (3.28 kB)

torchvision/datasets/__pycache__/_optical_flow.cpython-38.pyc
ADDED
Binary file (17.6 kB)

torchvision/datasets/__pycache__/_stereo_matching.cpython-38.pyc
ADDED
Binary file (40.2 kB)

torchvision/datasets/__pycache__/caltech.cpython-38.pyc
ADDED
Binary file (8.1 kB)

torchvision/datasets/__pycache__/celeba.cpython-38.pyc
ADDED
Binary file (7.14 kB)

torchvision/datasets/__pycache__/cifar.cpython-38.pyc
ADDED
Binary file (5.81 kB)

torchvision/datasets/__pycache__/cityscapes.cpython-38.pyc
ADDED
Binary file (8.41 kB)

torchvision/datasets/__pycache__/clevr.cpython-38.pyc
ADDED
Binary file (4.11 kB)

torchvision/datasets/__pycache__/coco.cpython-38.pyc
ADDED
Binary file (5.02 kB)

torchvision/datasets/__pycache__/country211.cpython-38.pyc
ADDED
Binary file (2.79 kB)

torchvision/datasets/__pycache__/dtd.cpython-38.pyc
ADDED
Binary file (4.33 kB)

torchvision/datasets/__pycache__/eurosat.cpython-38.pyc
ADDED
Binary file (2.46 kB)

torchvision/datasets/__pycache__/fakedata.cpython-38.pyc
ADDED
Binary file (2.67 kB)

torchvision/datasets/__pycache__/fer2013.cpython-38.pyc
ADDED
Binary file (3.3 kB)

torchvision/datasets/__pycache__/fgvc_aircraft.cpython-38.pyc
ADDED
Binary file (4.72 kB)

torchvision/datasets/__pycache__/flickr.cpython-38.pyc
ADDED
Binary file (5.27 kB)

torchvision/datasets/__pycache__/flowers102.cpython-38.pyc
ADDED
Binary file (4.6 kB)

torchvision/datasets/__pycache__/folder.cpython-38.pyc
ADDED
Binary file (11.5 kB)

torchvision/datasets/__pycache__/food101.cpython-38.pyc
ADDED
Binary file (4.35 kB)

torchvision/datasets/__pycache__/gtsrb.cpython-38.pyc
ADDED
Binary file (3.74 kB)

torchvision/datasets/__pycache__/hmdb51.cpython-38.pyc
ADDED
Binary file (5.64 kB)

torchvision/datasets/__pycache__/imagenet.cpython-38.pyc
ADDED
Binary file (9.69 kB)

torchvision/datasets/__pycache__/inaturalist.cpython-38.pyc
ADDED
Binary file (8.59 kB)

torchvision/datasets/__pycache__/kinetics.cpython-38.pyc
ADDED
Binary file (9.66 kB)