NewLabs committed on
Commit
a0929d8
Parent: 1b01078

Upload 372 files

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +2 -0
  2. torchvision/_C.so +3 -0
  3. torchvision/__init__.py +114 -0
  4. torchvision/__pycache__/__init__.cpython-38.pyc +0 -0
  5. torchvision/__pycache__/_internally_replaced_utils.cpython-38.pyc +0 -0
  6. torchvision/__pycache__/_utils.cpython-38.pyc +0 -0
  7. torchvision/__pycache__/extension.cpython-38.pyc +0 -0
  8. torchvision/__pycache__/utils.cpython-38.pyc +0 -0
  9. torchvision/__pycache__/version.cpython-38.pyc +0 -0
  10. torchvision/_internally_replaced_utils.py +58 -0
  11. torchvision/_utils.py +32 -0
  12. torchvision/datapoints/__init__.py +12 -0
  13. torchvision/datapoints/__pycache__/__init__.cpython-38.pyc +0 -0
  14. torchvision/datapoints/__pycache__/_bounding_box.cpython-38.pyc +0 -0
  15. torchvision/datapoints/__pycache__/_datapoint.cpython-38.pyc +0 -0
  16. torchvision/datapoints/__pycache__/_dataset_wrapper.cpython-38.pyc +0 -0
  17. torchvision/datapoints/__pycache__/_image.cpython-38.pyc +0 -0
  18. torchvision/datapoints/__pycache__/_mask.cpython-38.pyc +0 -0
  19. torchvision/datapoints/__pycache__/_video.cpython-38.pyc +0 -0
  20. torchvision/datapoints/_bounding_box.py +237 -0
  21. torchvision/datapoints/_datapoint.py +259 -0
  22. torchvision/datapoints/_dataset_wrapper.py +499 -0
  23. torchvision/datapoints/_image.py +260 -0
  24. torchvision/datapoints/_mask.py +158 -0
  25. torchvision/datapoints/_video.py +250 -0
  26. torchvision/datasets/__init__.py +145 -0
  27. torchvision/datasets/__pycache__/__init__.cpython-38.pyc +0 -0
  28. torchvision/datasets/__pycache__/_optical_flow.cpython-38.pyc +0 -0
  29. torchvision/datasets/__pycache__/_stereo_matching.cpython-38.pyc +0 -0
  30. torchvision/datasets/__pycache__/caltech.cpython-38.pyc +0 -0
  31. torchvision/datasets/__pycache__/celeba.cpython-38.pyc +0 -0
  32. torchvision/datasets/__pycache__/cifar.cpython-38.pyc +0 -0
  33. torchvision/datasets/__pycache__/cityscapes.cpython-38.pyc +0 -0
  34. torchvision/datasets/__pycache__/clevr.cpython-38.pyc +0 -0
  35. torchvision/datasets/__pycache__/coco.cpython-38.pyc +0 -0
  36. torchvision/datasets/__pycache__/country211.cpython-38.pyc +0 -0
  37. torchvision/datasets/__pycache__/dtd.cpython-38.pyc +0 -0
  38. torchvision/datasets/__pycache__/eurosat.cpython-38.pyc +0 -0
  39. torchvision/datasets/__pycache__/fakedata.cpython-38.pyc +0 -0
  40. torchvision/datasets/__pycache__/fer2013.cpython-38.pyc +0 -0
  41. torchvision/datasets/__pycache__/fgvc_aircraft.cpython-38.pyc +0 -0
  42. torchvision/datasets/__pycache__/flickr.cpython-38.pyc +0 -0
  43. torchvision/datasets/__pycache__/flowers102.cpython-38.pyc +0 -0
  44. torchvision/datasets/__pycache__/folder.cpython-38.pyc +0 -0
  45. torchvision/datasets/__pycache__/food101.cpython-38.pyc +0 -0
  46. torchvision/datasets/__pycache__/gtsrb.cpython-38.pyc +0 -0
  47. torchvision/datasets/__pycache__/hmdb51.cpython-38.pyc +0 -0
  48. torchvision/datasets/__pycache__/imagenet.cpython-38.pyc +0 -0
  49. torchvision/datasets/__pycache__/inaturalist.cpython-38.pyc +0 -0
  50. torchvision/datasets/__pycache__/kinetics.cpython-38.pyc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ torchvision/_C.so filter=lfs diff=lfs merge=lfs -text
+ torchvision/image.so filter=lfs diff=lfs merge=lfs -text
torchvision/_C.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1f60f9e860992a1c25bbad7efc84b2036c82da5f23c976995dbdefeab6ff5d73
+ size 70104344
torchvision/__init__.py ADDED
@@ -0,0 +1,114 @@
+ import os
+ import warnings
+ from modulefinder import Module
+
+ import torch
+ from torchvision import datasets, io, models, ops, transforms, utils
+
+ from .extension import _HAS_OPS
+
+ try:
+     from .version import __version__  # noqa: F401
+ except ImportError:
+     pass
+
+
+ # Check if torchvision is being imported within the root folder
+ if not _HAS_OPS and os.path.dirname(os.path.realpath(__file__)) == os.path.join(
+     os.path.realpath(os.getcwd()), "torchvision"
+ ):
+     message = (
+         "You are importing torchvision within its own root folder ({}). "
+         "This is not expected to work and may give errors. Please exit the "
+         "torchvision project source and relaunch your python interpreter."
+     )
+     warnings.warn(message.format(os.getcwd()))
+
+ _image_backend = "PIL"
+
+ _video_backend = "pyav"
+
+
+ def set_image_backend(backend):
+     """
+     Specifies the package used to load images.
+
+     Args:
+         backend (string): Name of the image backend. one of {'PIL', 'accimage'}.
+             The :mod:`accimage` package uses the Intel IPP library. It is
+             generally faster than PIL, but does not support as many operations.
+     """
+     global _image_backend
+     if backend not in ["PIL", "accimage"]:
+         raise ValueError(f"Invalid backend '{backend}'. Options are 'PIL' and 'accimage'")
+     _image_backend = backend
+
+
+ def get_image_backend():
+     """
+     Gets the name of the package used to load images
+     """
+     return _image_backend
+
+
+ def set_video_backend(backend):
+     """
+     Specifies the package used to decode videos.
+
+     Args:
+         backend (string): Name of the video backend. one of {'pyav', 'video_reader'}.
+             The :mod:`pyav` package uses the 3rd party PyAv library. It is a Pythonic
+             binding for the FFmpeg libraries.
+             The :mod:`video_reader` package includes a native C++ implementation on
+             top of FFMPEG libraries, and a python API of TorchScript custom operator.
+             It generally decodes faster than :mod:`pyav`, but is perhaps less robust.
+
+     .. note::
+         Building with FFMPEG is disabled by default in the latest `main`. If you want to use the 'video_reader'
+         backend, please compile torchvision from source.
+     """
+     global _video_backend
+     if backend not in ["pyav", "video_reader", "cuda"]:
+         raise ValueError("Invalid video backend '%s'. Options are 'pyav', 'video_reader' and 'cuda'" % backend)
+     if backend == "video_reader" and not io._HAS_VIDEO_OPT:
+         # TODO: better messages
+         message = "video_reader video backend is not available. Please compile torchvision from source and try again"
+         raise RuntimeError(message)
+     elif backend == "cuda" and not io._HAS_GPU_VIDEO_DECODER:
+         # TODO: better messages
+         message = "cuda video backend is not available."
+         raise RuntimeError(message)
+     else:
+         _video_backend = backend
+
+
+ def get_video_backend():
+     """
+     Returns the currently active video backend used to decode videos.
+
+     Returns:
+         str: Name of the video backend. one of {'pyav', 'video_reader'}.
+     """
+
+     return _video_backend
+
+
+ def _is_tracing():
+     return torch._C._get_tracing_state()
+
+
+ _WARN_ABOUT_BETA_TRANSFORMS = True
+ _BETA_TRANSFORMS_WARNING = (
+     "The torchvision.datapoints and torchvision.transforms.v2 namespaces are still Beta. "
+     "While we do not expect major breaking changes, some APIs may still change "
+     "according to user feedback. Please submit any feedback you may have in "
+     "this issue: https://github.com/pytorch/vision/issues/6753, and you can also "
+     "check out https://github.com/pytorch/vision/issues/7319 to learn more about "
+     "the APIs that we suspect might involve future changes. "
+     "You can silence this warning by calling torchvision.disable_beta_transforms_warning()."
+ )
+
+
+ def disable_beta_transforms_warning():
+     global _WARN_ABOUT_BETA_TRANSFORMS
+     _WARN_ABOUT_BETA_TRANSFORMS = False
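
For quick reference, a minimal usage sketch of the backend helpers added in torchvision/__init__.py above. The backend names come straight from the docstrings; whether accimage or the video_reader/cuda backends are actually available depends on the local build, so those branches are assumptions.

import torchvision

torchvision.set_image_backend("PIL")           # "accimage" is also accepted if that package is installed (assumption)
print(torchvision.get_image_backend())         # -> "PIL"

torchvision.set_video_backend("pyav")          # "video_reader"/"cuda" require a source build / GPU decoder support
print(torchvision.get_video_backend())         # -> "pyav"

torchvision.disable_beta_transforms_warning()  # silences the datapoints / transforms.v2 beta warning
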
torchvision/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (4.07 kB).
torchvision/__pycache__/_internally_replaced_utils.cpython-38.pyc ADDED
Binary file (1.76 kB).
torchvision/__pycache__/_utils.cpython-38.pyc ADDED
Binary file (1.46 kB).
torchvision/__pycache__/extension.cpython-38.pyc ADDED
Binary file (2.93 kB).
torchvision/__pycache__/utils.cpython-38.pyc ADDED
Binary file (18.8 kB).
torchvision/__pycache__/version.cpython-38.pyc ADDED
Binary file (367 Bytes).
 
torchvision/_internally_replaced_utils.py ADDED
@@ -0,0 +1,58 @@
+ import importlib.machinery
+ import os
+
+ from torch.hub import _get_torch_home
+
+
+ _HOME = os.path.join(_get_torch_home(), "datasets", "vision")
+ _USE_SHARDED_DATASETS = False
+
+
+ def _download_file_from_remote_location(fpath: str, url: str) -> None:
+     pass
+
+
+ def _is_remote_location_available() -> bool:
+     return False
+
+
+ try:
+     from torch.hub import load_state_dict_from_url  # noqa: 401
+ except ImportError:
+     from torch.utils.model_zoo import load_url as load_state_dict_from_url  # noqa: 401
+
+
+ def _get_extension_path(lib_name):
+
+     lib_dir = os.path.dirname(__file__)
+     if os.name == "nt":
+         # Register the main torchvision library location on the default DLL path
+         import ctypes
+         import sys
+
+         kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
+         with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
+         prev_error_mode = kernel32.SetErrorMode(0x0001)
+
+         if with_load_library_flags:
+             kernel32.AddDllDirectory.restype = ctypes.c_void_p
+
+         if sys.version_info >= (3, 8):
+             os.add_dll_directory(lib_dir)
+         elif with_load_library_flags:
+             res = kernel32.AddDllDirectory(lib_dir)
+             if res is None:
+                 err = ctypes.WinError(ctypes.get_last_error())
+                 err.strerror += f' Error adding "{lib_dir}" to the DLL directories.'
+                 raise err
+
+         kernel32.SetErrorMode(prev_error_mode)
+
+     loader_details = (importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES)
+
+     extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
+     ext_specs = extfinder.find_spec(lib_name)
+     if ext_specs is None:
+         raise ImportError
+
+     return ext_specs.origin
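
As a standalone sketch of the lookup that _get_extension_path performs above: a FileFinder restricted to extension-module suffixes resolves a compiled library living next to the package. The directory and the module name "_C" here are illustrative assumptions, not part of any public API.

import importlib.machinery
import os

lib_dir = os.path.dirname(os.path.abspath(__file__))  # directory assumed to contain the compiled extension
loader_details = (importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES)
spec = importlib.machinery.FileFinder(lib_dir, loader_details).find_spec("_C")
if spec is None:
    raise ImportError("no extension module named '_C' found in " + lib_dir)
print(spec.origin)  # absolute path to e.g. _C.so, which is what _get_extension_path returns
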
torchvision/_utils.py ADDED
@@ -0,0 +1,32 @@
+ import enum
+ from typing import Sequence, Type, TypeVar
+
+ T = TypeVar("T", bound=enum.Enum)
+
+
+ class StrEnumMeta(enum.EnumMeta):
+     auto = enum.auto
+
+     def from_str(self: Type[T], member: str) -> T:  # type: ignore[misc]
+         try:
+             return self[member]
+         except KeyError:
+             # TODO: use `add_suggestion` from torchvision.prototype.utils._internal to improve the error message as
+             # soon as it is migrated.
+             raise ValueError(f"Unknown value '{member}' for {self.__name__}.") from None
+
+
+ class StrEnum(enum.Enum, metaclass=StrEnumMeta):
+     pass
+
+
+ def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
+     if not seq:
+         return ""
+     if len(seq) == 1:
+         return f"'{seq[0]}'"
+
+     head = "'" + "', '".join([str(item) for item in seq[:-1]]) + "'"
+     tail = f"{'' if separate_last and len(seq) == 2 else ','} {separate_last}'{seq[-1]}'"
+
+     return head + tail
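
A short sketch of how the two helpers in torchvision/_utils.py behave; the Color enum is a hypothetical example class defined only for illustration.

from torchvision._utils import StrEnum, sequence_to_str

class Color(StrEnum):
    RED = "RED"
    GREEN = "GREEN"

print(Color.from_str("RED"))                                    # Color.RED; unknown names raise ValueError
print(sequence_to_str(["a", "b", "c"], separate_last="and "))   # 'a', 'b', and 'c'
print(sequence_to_str(["a", "b"], separate_last="or "))         # 'a' or 'b'
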
torchvision/datapoints/__init__.py ADDED
@@ -0,0 +1,12 @@
+ from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS
+
+ from ._bounding_box import BoundingBox, BoundingBoxFormat
+ from ._datapoint import _FillType, _FillTypeJIT, _InputType, _InputTypeJIT
+ from ._image import _ImageType, _ImageTypeJIT, _TensorImageType, _TensorImageTypeJIT, Image
+ from ._mask import Mask
+ from ._video import _TensorVideoType, _TensorVideoTypeJIT, _VideoType, _VideoTypeJIT, Video
+
+ if _WARN_ABOUT_BETA_TRANSFORMS:
+     import warnings
+
+     warnings.warn(_BETA_TRANSFORMS_WARNING)
torchvision/datapoints/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (840 Bytes).
torchvision/datapoints/__pycache__/_bounding_box.cpython-38.pyc ADDED
Binary file (8.02 kB).
torchvision/datapoints/__pycache__/_datapoint.cpython-38.pyc ADDED
Binary file (10 kB).
torchvision/datapoints/__pycache__/_dataset_wrapper.cpython-38.pyc ADDED
Binary file (16.3 kB).
torchvision/datapoints/__pycache__/_image.cpython-38.pyc ADDED
Binary file (10.4 kB).
torchvision/datapoints/__pycache__/_mask.cpython-38.pyc ADDED
Binary file (5.98 kB).
torchvision/datapoints/__pycache__/_video.cpython-38.pyc ADDED
Binary file (10.1 kB).
 
torchvision/datapoints/_bounding_box.py ADDED
@@ -0,0 +1,237 @@
1
+ from __future__ import annotations
2
+
3
+ from enum import Enum
4
+ from typing import Any, List, Optional, Sequence, Tuple, Union
5
+
6
+ import torch
7
+ from torchvision.transforms import InterpolationMode # TODO: this needs to be moved out of transforms
8
+
9
+ from ._datapoint import _FillTypeJIT, Datapoint
10
+
11
+
12
+ class BoundingBoxFormat(Enum):
13
+ """[BETA] Coordinate format of a bounding box.
14
+
15
+ Available formats are
16
+
17
+ * ``XYXY``
18
+ * ``XYWH``
19
+ * ``CXCYWH``
20
+ """
21
+
22
+ XYXY = "XYXY"
23
+ XYWH = "XYWH"
24
+ CXCYWH = "CXCYWH"
25
+
26
+
27
+ class BoundingBox(Datapoint):
28
+ """[BETA] :class:`torch.Tensor` subclass for bounding boxes.
29
+
30
+ Args:
31
+ data: Any data that can be turned into a tensor with :func:`torch.as_tensor`.
32
+ format (BoundingBoxFormat, str): Format of the bounding box.
33
+ spatial_size (two-tuple of ints): Height and width of the corresponding image or video.
34
+ dtype (torch.dtype, optional): Desired data type of the bounding box. If omitted, will be inferred from
35
+ ``data``.
36
+ device (torch.device, optional): Desired device of the bounding box. If omitted and ``data`` is a
37
+ :class:`torch.Tensor`, the device is taken from it. Otherwise, the bounding box is constructed on the CPU.
38
+ requires_grad (bool, optional): Whether autograd should record operations on the bounding box. If omitted and
39
+ ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``.
40
+ """
41
+
42
+ format: BoundingBoxFormat
43
+ spatial_size: Tuple[int, int]
44
+
45
+ @classmethod
46
+ def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, spatial_size: Tuple[int, int]) -> BoundingBox:
47
+ bounding_box = tensor.as_subclass(cls)
48
+ bounding_box.format = format
49
+ bounding_box.spatial_size = spatial_size
50
+ return bounding_box
51
+
52
+ def __new__(
53
+ cls,
54
+ data: Any,
55
+ *,
56
+ format: Union[BoundingBoxFormat, str],
57
+ spatial_size: Tuple[int, int],
58
+ dtype: Optional[torch.dtype] = None,
59
+ device: Optional[Union[torch.device, str, int]] = None,
60
+ requires_grad: Optional[bool] = None,
61
+ ) -> BoundingBox:
62
+ tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
63
+
64
+ if isinstance(format, str):
65
+ format = BoundingBoxFormat[format.upper()]
66
+
67
+ return cls._wrap(tensor, format=format, spatial_size=spatial_size)
68
+
69
+ @classmethod
70
+ def wrap_like(
71
+ cls,
72
+ other: BoundingBox,
73
+ tensor: torch.Tensor,
74
+ *,
75
+ format: Optional[BoundingBoxFormat] = None,
76
+ spatial_size: Optional[Tuple[int, int]] = None,
77
+ ) -> BoundingBox:
78
+ """Wrap a :class:`torch.Tensor` as :class:`BoundingBox` from a reference.
79
+
80
+ Args:
81
+ other (BoundingBox): Reference bounding box.
82
+ tensor (Tensor): Tensor to be wrapped as :class:`BoundingBox`
83
+ format (BoundingBoxFormat, str, optional): Format of the bounding box. If omitted, it is taken from the
84
+ reference.
85
+ spatial_size (two-tuple of ints, optional): Height and width of the corresponding image or video. If
86
+ omitted, it is taken from the reference.
87
+
88
+ """
89
+ if isinstance(format, str):
90
+ format = BoundingBoxFormat[format.upper()]
91
+
92
+ return cls._wrap(
93
+ tensor,
94
+ format=format if format is not None else other.format,
95
+ spatial_size=spatial_size if spatial_size is not None else other.spatial_size,
96
+ )
97
+
98
+ def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override]
99
+ return self._make_repr(format=self.format, spatial_size=self.spatial_size)
100
+
101
+ def horizontal_flip(self) -> BoundingBox:
102
+ output = self._F.horizontal_flip_bounding_box(
103
+ self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size
104
+ )
105
+ return BoundingBox.wrap_like(self, output)
106
+
107
+ def vertical_flip(self) -> BoundingBox:
108
+ output = self._F.vertical_flip_bounding_box(
109
+ self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size
110
+ )
111
+ return BoundingBox.wrap_like(self, output)
112
+
113
+ def resize( # type: ignore[override]
114
+ self,
115
+ size: List[int],
116
+ interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
117
+ max_size: Optional[int] = None,
118
+ antialias: Optional[Union[str, bool]] = "warn",
119
+ ) -> BoundingBox:
120
+ output, spatial_size = self._F.resize_bounding_box(
121
+ self.as_subclass(torch.Tensor),
122
+ spatial_size=self.spatial_size,
123
+ size=size,
124
+ max_size=max_size,
125
+ )
126
+ return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
127
+
128
+ def crop(self, top: int, left: int, height: int, width: int) -> BoundingBox:
129
+ output, spatial_size = self._F.crop_bounding_box(
130
+ self.as_subclass(torch.Tensor), self.format, top=top, left=left, height=height, width=width
131
+ )
132
+ return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
133
+
134
+ def center_crop(self, output_size: List[int]) -> BoundingBox:
135
+ output, spatial_size = self._F.center_crop_bounding_box(
136
+ self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size, output_size=output_size
137
+ )
138
+ return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
139
+
140
+ def resized_crop(
141
+ self,
142
+ top: int,
143
+ left: int,
144
+ height: int,
145
+ width: int,
146
+ size: List[int],
147
+ interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
148
+ antialias: Optional[Union[str, bool]] = "warn",
149
+ ) -> BoundingBox:
150
+ output, spatial_size = self._F.resized_crop_bounding_box(
151
+ self.as_subclass(torch.Tensor), self.format, top, left, height, width, size=size
152
+ )
153
+ return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
154
+
155
+ def pad(
156
+ self,
157
+ padding: Union[int, Sequence[int]],
158
+ fill: Optional[Union[int, float, List[float]]] = None,
159
+ padding_mode: str = "constant",
160
+ ) -> BoundingBox:
161
+ output, spatial_size = self._F.pad_bounding_box(
162
+ self.as_subclass(torch.Tensor),
163
+ format=self.format,
164
+ spatial_size=self.spatial_size,
165
+ padding=padding,
166
+ padding_mode=padding_mode,
167
+ )
168
+ return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
169
+
170
+ def rotate(
171
+ self,
172
+ angle: float,
173
+ interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
174
+ expand: bool = False,
175
+ center: Optional[List[float]] = None,
176
+ fill: _FillTypeJIT = None,
177
+ ) -> BoundingBox:
178
+ output, spatial_size = self._F.rotate_bounding_box(
179
+ self.as_subclass(torch.Tensor),
180
+ format=self.format,
181
+ spatial_size=self.spatial_size,
182
+ angle=angle,
183
+ expand=expand,
184
+ center=center,
185
+ )
186
+ return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
187
+
188
+ def affine(
189
+ self,
190
+ angle: Union[int, float],
191
+ translate: List[float],
192
+ scale: float,
193
+ shear: List[float],
194
+ interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
195
+ fill: _FillTypeJIT = None,
196
+ center: Optional[List[float]] = None,
197
+ ) -> BoundingBox:
198
+ output = self._F.affine_bounding_box(
199
+ self.as_subclass(torch.Tensor),
200
+ self.format,
201
+ self.spatial_size,
202
+ angle,
203
+ translate=translate,
204
+ scale=scale,
205
+ shear=shear,
206
+ center=center,
207
+ )
208
+ return BoundingBox.wrap_like(self, output)
209
+
210
+ def perspective(
211
+ self,
212
+ startpoints: Optional[List[List[int]]],
213
+ endpoints: Optional[List[List[int]]],
214
+ interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
215
+ fill: _FillTypeJIT = None,
216
+ coefficients: Optional[List[float]] = None,
217
+ ) -> BoundingBox:
218
+ output = self._F.perspective_bounding_box(
219
+ self.as_subclass(torch.Tensor),
220
+ format=self.format,
221
+ spatial_size=self.spatial_size,
222
+ startpoints=startpoints,
223
+ endpoints=endpoints,
224
+ coefficients=coefficients,
225
+ )
226
+ return BoundingBox.wrap_like(self, output)
227
+
228
+ def elastic(
229
+ self,
230
+ displacement: torch.Tensor,
231
+ interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
232
+ fill: _FillTypeJIT = None,
233
+ ) -> BoundingBox:
234
+ output = self._F.elastic_bounding_box(
235
+ self.as_subclass(torch.Tensor), self.format, self.spatial_size, displacement=displacement
236
+ )
237
+ return BoundingBox.wrap_like(self, output)
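
To illustrate the BoundingBox datapoint defined above, a minimal sketch with made-up coordinates and spatial size; geometric methods such as horizontal_flip and resize return a new BoundingBox with the format and (possibly updated) spatial_size metadata carried over.

from torchvision import datapoints

boxes = datapoints.BoundingBox(
    [[10, 20, 40, 80]],                        # one box in XYXY coordinates (example values)
    format=datapoints.BoundingBoxFormat.XYXY,
    spatial_size=(100, 200),                   # (height, width) of the corresponding image
)
flipped = boxes.horizontal_flip()              # still a BoundingBox with the same format and spatial_size
resized = boxes.resize([50, 100])              # spatial_size metadata is updated to the new size
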
torchvision/datapoints/_datapoint.py ADDED
@@ -0,0 +1,259 @@
1
+ from __future__ import annotations
2
+
3
+ from types import ModuleType
4
+ from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union
5
+
6
+ import PIL.Image
7
+ import torch
8
+ from torch._C import DisableTorchFunctionSubclass
9
+ from torch.types import _device, _dtype, _size
10
+ from torchvision.transforms import InterpolationMode
11
+
12
+
13
+ D = TypeVar("D", bound="Datapoint")
14
+ _FillType = Union[int, float, Sequence[int], Sequence[float], None]
15
+ _FillTypeJIT = Optional[List[float]]
16
+
17
+
18
+ class Datapoint(torch.Tensor):
19
+ __F: Optional[ModuleType] = None
20
+
21
+ @staticmethod
22
+ def _to_tensor(
23
+ data: Any,
24
+ dtype: Optional[torch.dtype] = None,
25
+ device: Optional[Union[torch.device, str, int]] = None,
26
+ requires_grad: Optional[bool] = None,
27
+ ) -> torch.Tensor:
28
+ if requires_grad is None:
29
+ requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False
30
+ return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad)
31
+
32
+ @classmethod
33
+ def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
34
+ raise NotImplementedError
35
+
36
+ _NO_WRAPPING_EXCEPTIONS = {
37
+ torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output),
38
+ torch.Tensor.to: lambda cls, input, output: cls.wrap_like(input, output),
39
+ # We don't need to wrap the output of `Tensor.requires_grad_`, since it is an inplace operation and thus
40
+ # retains the type automatically
41
+ torch.Tensor.requires_grad_: lambda cls, input, output: output,
42
+ }
43
+
44
+ @classmethod
45
+ def __torch_function__(
46
+ cls,
47
+ func: Callable[..., torch.Tensor],
48
+ types: Tuple[Type[torch.Tensor], ...],
49
+ args: Sequence[Any] = (),
50
+ kwargs: Optional[Mapping[str, Any]] = None,
51
+ ) -> torch.Tensor:
52
+ """For general information about how the __torch_function__ protocol works,
53
+ see https://pytorch.org/docs/stable/notes/extending.html#extending-torch
54
+
55
+ TL;DR: Every time a PyTorch operator is called, it goes through the inputs and looks for the
56
+ ``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the
57
+ ``args`` and ``kwargs`` of the original call.
58
+
59
+ The default behavior of :class:`~torch.Tensor`'s is to retain a custom tensor type. For the :class:`Datapoint`
60
+ use case, this has two downsides:
61
+
62
+ 1. Since some :class:`Datapoint`'s require metadata to be constructed, the default wrapping, i.e.
63
+ ``return cls(func(*args, **kwargs))``, will fail for them.
64
+ 2. For most operations, there is no way of knowing if the input type is still valid for the output.
65
+
66
+ For these reasons, the automatic output wrapping is turned off for most operators. The only exceptions are
67
+ listed in :attr:`Datapoint._NO_WRAPPING_EXCEPTIONS`
68
+ """
69
+ # Since super().__torch_function__ has no hook to prevent the coercing of the output into the input type, we
70
+ # need to reimplement the functionality.
71
+
72
+ if not all(issubclass(cls, t) for t in types):
73
+ return NotImplemented
74
+
75
+ with DisableTorchFunctionSubclass():
76
+ output = func(*args, **kwargs or dict())
77
+
78
+ wrapper = cls._NO_WRAPPING_EXCEPTIONS.get(func)
79
+ # Apart from `func` needing to be an exception, we also require the primary operand, i.e. `args[0]`, to be
80
+ # an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will
81
+ # invoke this method on *all* types involved in the computation by walking the MRO upwards. For example,
82
+ # `torch.Tensor(...).to(datapoints.Image(...))` will invoke `datapoints.Image.__torch_function__` with
83
+ # `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would
84
+ # be wrapped into a `datapoints.Image`.
85
+ if wrapper and isinstance(args[0], cls):
86
+ return wrapper(cls, args[0], output)
87
+
88
+ # Inplace `func`'s, canonically identified with a trailing underscore in their name like `.add_(...)`,
89
+ # will retain the input type. Thus, we need to unwrap here.
90
+ if isinstance(output, cls):
91
+ return output.as_subclass(torch.Tensor)
92
+
93
+ return output
94
+
95
+ def _make_repr(self, **kwargs: Any) -> str:
96
+ # This is a poor man's implementation of the proposal in https://github.com/pytorch/pytorch/issues/76532.
97
+ # If that ever gets implemented, remove this in favor of the solution on the `torch.Tensor` class.
98
+ extra_repr = ", ".join(f"{key}={value}" for key, value in kwargs.items())
99
+ return f"{super().__repr__()[:-1]}, {extra_repr})"
100
+
101
+ @property
102
+ def _F(self) -> ModuleType:
103
+ # This implements a lazy import of the functional to get around the cyclic import. This import is deferred
104
+ # until the first time we need reference to the functional module and it's shared across all instances of
105
+ # the class. This approach avoids the DataLoader issue described at
106
+ # https://github.com/pytorch/vision/pull/6476#discussion_r953588621
107
+ if Datapoint.__F is None:
108
+ from ..transforms.v2 import functional
109
+
110
+ Datapoint.__F = functional
111
+ return Datapoint.__F
112
+
113
+ # Add properties for common attributes like shape, dtype, device, ndim etc
114
+ # this way we return the result without passing into __torch_function__
115
+ @property
116
+ def shape(self) -> _size: # type: ignore[override]
117
+ with DisableTorchFunctionSubclass():
118
+ return super().shape
119
+
120
+ @property
121
+ def ndim(self) -> int: # type: ignore[override]
122
+ with DisableTorchFunctionSubclass():
123
+ return super().ndim
124
+
125
+ @property
126
+ def device(self, *args: Any, **kwargs: Any) -> _device: # type: ignore[override]
127
+ with DisableTorchFunctionSubclass():
128
+ return super().device
129
+
130
+ @property
131
+ def dtype(self) -> _dtype: # type: ignore[override]
132
+ with DisableTorchFunctionSubclass():
133
+ return super().dtype
134
+
135
+ def horizontal_flip(self) -> Datapoint:
136
+ return self
137
+
138
+ def vertical_flip(self) -> Datapoint:
139
+ return self
140
+
141
+ # TODO: We have to ignore override mypy error as there is torch.Tensor built-in deprecated op: Tensor.resize
142
+ # https://github.com/pytorch/pytorch/blob/e8727994eb7cdb2ab642749d6549bc497563aa06/torch/_tensor.py#L588-L593
143
+ def resize( # type: ignore[override]
144
+ self,
145
+ size: List[int],
146
+ interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
147
+ max_size: Optional[int] = None,
148
+ antialias: Optional[Union[str, bool]] = "warn",
149
+ ) -> Datapoint:
150
+ return self
151
+
152
+ def crop(self, top: int, left: int, height: int, width: int) -> Datapoint:
153
+ return self
154
+
155
+ def center_crop(self, output_size: List[int]) -> Datapoint:
156
+ return self
157
+
158
+ def resized_crop(
159
+ self,
160
+ top: int,
161
+ left: int,
162
+ height: int,
163
+ width: int,
164
+ size: List[int],
165
+ interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
166
+ antialias: Optional[Union[str, bool]] = "warn",
167
+ ) -> Datapoint:
168
+ return self
169
+
170
+ def pad(
171
+ self,
172
+ padding: List[int],
173
+ fill: Optional[Union[int, float, List[float]]] = None,
174
+ padding_mode: str = "constant",
175
+ ) -> Datapoint:
176
+ return self
177
+
178
+ def rotate(
179
+ self,
180
+ angle: float,
181
+ interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
182
+ expand: bool = False,
183
+ center: Optional[List[float]] = None,
184
+ fill: _FillTypeJIT = None,
185
+ ) -> Datapoint:
186
+ return self
187
+
188
+ def affine(
189
+ self,
190
+ angle: Union[int, float],
191
+ translate: List[float],
192
+ scale: float,
193
+ shear: List[float],
194
+ interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
195
+ fill: _FillTypeJIT = None,
196
+ center: Optional[List[float]] = None,
197
+ ) -> Datapoint:
198
+ return self
199
+
200
+ def perspective(
201
+ self,
202
+ startpoints: Optional[List[List[int]]],
203
+ endpoints: Optional[List[List[int]]],
204
+ interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
205
+ fill: _FillTypeJIT = None,
206
+ coefficients: Optional[List[float]] = None,
207
+ ) -> Datapoint:
208
+ return self
209
+
210
+ def elastic(
211
+ self,
212
+ displacement: torch.Tensor,
213
+ interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
214
+ fill: _FillTypeJIT = None,
215
+ ) -> Datapoint:
216
+ return self
217
+
218
+ def rgb_to_grayscale(self, num_output_channels: int = 1) -> Datapoint:
219
+ return self
220
+
221
+ def adjust_brightness(self, brightness_factor: float) -> Datapoint:
222
+ return self
223
+
224
+ def adjust_saturation(self, saturation_factor: float) -> Datapoint:
225
+ return self
226
+
227
+ def adjust_contrast(self, contrast_factor: float) -> Datapoint:
228
+ return self
229
+
230
+ def adjust_sharpness(self, sharpness_factor: float) -> Datapoint:
231
+ return self
232
+
233
+ def adjust_hue(self, hue_factor: float) -> Datapoint:
234
+ return self
235
+
236
+ def adjust_gamma(self, gamma: float, gain: float = 1) -> Datapoint:
237
+ return self
238
+
239
+ def posterize(self, bits: int) -> Datapoint:
240
+ return self
241
+
242
+ def solarize(self, threshold: float) -> Datapoint:
243
+ return self
244
+
245
+ def autocontrast(self) -> Datapoint:
246
+ return self
247
+
248
+ def equalize(self) -> Datapoint:
249
+ return self
250
+
251
+ def invert(self) -> Datapoint:
252
+ return self
253
+
254
+ def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Datapoint:
255
+ return self
256
+
257
+
258
+ _InputType = Union[torch.Tensor, PIL.Image.Image, Datapoint]
259
+ _InputTypeJIT = torch.Tensor
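
A short sketch of the wrapping rules implemented by Datapoint.__torch_function__ above: only the operators listed in _NO_WRAPPING_EXCEPTIONS (clone, to, requires_grad_) keep the Datapoint subclass, while other operators unwrap to a plain tensor. The Image datapoint is used here purely as a concrete subclass.

import torch
from torchvision import datapoints

img = datapoints.Image(torch.rand(3, 8, 8))

print(type(img.clone()) is datapoints.Image)    # True: clone is a wrapping exception
print(type(img.to("cpu")) is datapoints.Image)  # True: to is a wrapping exception
print(type(img + 1) is torch.Tensor)            # True: most ops fall back to a plain torch.Tensor
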
torchvision/datapoints/_dataset_wrapper.py ADDED
@@ -0,0 +1,499 @@
1
+ # type: ignore
2
+
3
+ from __future__ import annotations
4
+
5
+ import contextlib
6
+ from collections import defaultdict
7
+
8
+ import torch
9
+ from torch.utils.data import Dataset
10
+
11
+ from torchvision import datapoints, datasets
12
+ from torchvision.transforms.v2 import functional as F
13
+
14
+ __all__ = ["wrap_dataset_for_transforms_v2"]
15
+
16
+
17
+ def wrap_dataset_for_transforms_v2(dataset):
18
+ """[BETA] Wrap a ``torchvision.dataset`` for usage with :mod:`torchvision.transforms.v2`.
19
+
20
+ .. v2betastatus:: wrap_dataset_for_transforms_v2 function
21
+
22
+ Example:
23
+ >>> dataset = torchvision.datasets.CocoDetection(...)
24
+ >>> dataset = wrap_dataset_for_transforms_v2(dataset)
25
+
26
+ .. note::
27
+
28
+ For now, only the most popular datasets are supported. Furthermore, the wrapper only supports dataset
29
+ configurations that are fully supported by ``torchvision.transforms.v2``. If you encounter an error prompting you
30
+ to raise an issue to ``torchvision`` for a dataset or configuration that you need, please do so.
31
+
32
+ The dataset samples are wrapped according to the description below.
33
+
34
+ Special cases:
35
+
36
+ * :class:`~torchvision.datasets.CocoDetection`: Instead of returning the target as list of dicts, the wrapper
37
+ returns a dict of lists. In addition, the key-value-pairs ``"boxes"`` (in ``XYXY`` coordinate format),
38
+ ``"masks"`` and ``"labels"`` are added and wrap the data in the corresponding ``torchvision.datapoints``.
39
+ The original keys are preserved.
40
+ * :class:`~torchvision.datasets.VOCDetection`: The key-value-pairs ``"boxes"`` and ``"labels"`` are added to
41
+ the target and wrap the data in the corresponding ``torchvision.datapoints``. The original keys are
42
+ preserved.
43
+ * :class:`~torchvision.datasets.CelebA`: The target for ``target_type="bbox"`` is converted to the ``XYXY``
44
+ coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBox` datapoint.
45
+ * :class:`~torchvision.datasets.Kitti`: Instead of returning the target as a list of dicts, the wrapper returns a dict
46
+ of lists. In addition, the key-value-pairs ``"boxes"`` and ``"labels"`` are added and wrap the data
47
+ in the corresponding ``torchvision.datapoints``. The original keys are preserved.
48
+ * :class:`~torchvision.datasets.OxfordIIITPet`: The target for ``target_type="segmentation"`` is wrapped into a
49
+ :class:`~torchvision.datapoints.Mask` datapoint.
50
+ * :class:`~torchvision.datasets.Cityscapes`: The target for ``target_type="semantic"`` is wrapped into a
51
+ :class:`~torchvision.datapoints.Mask` datapoint. The target for ``target_type="instance"`` is *replaced* by
52
+ a dictionary with the key-value-pairs ``"masks"`` (as :class:`~torchvision.datapoints.Mask` datapoint) and
53
+ ``"labels"``.
54
+ * :class:`~torchvision.datasets.WIDERFace`: The value for key ``"bbox"`` in the target is converted to ``XYXY``
55
+ coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBox` datapoint.
56
+
57
+ Image classification datasets
58
+
59
+ This wrapper is a no-op for image classification datasets, since they were already fully supported by
60
+ :mod:`torchvision.transforms` and thus no change is needed for :mod:`torchvision.transforms.v2`.
61
+
62
+ Segmentation datasets
63
+
64
+ Segmentation datasets, e.g. :class:`~torchvision.datasets.VOCSegmentation` return a two-tuple of
65
+ :class:`PIL.Image.Image`'s. This wrapper leaves the image as is (first item), while wrapping the
66
+ segmentation mask into a :class:`~torchvision.datapoints.Mask` (second item).
67
+
68
+ Video classification datasets
69
+
70
+ Video classification datasets, e.g. :class:`~torchvision.datasets.Kinetics` return a three-tuple containing a
71
+ :class:`torch.Tensor` for the video and audio and a :class:`int` as label. This wrapper wraps the video into a
72
+ :class:`~torchvision.datapoints.Video` while leaving the other items as is.
73
+
74
+ .. note::
75
+
76
+ Only datasets constructed with ``output_format="TCHW"`` are supported, since the alternative
77
+ ``output_format="THWC"`` is not supported by :mod:`torchvision.transforms.v2`.
78
+
79
+ Args:
80
+ dataset: the dataset instance to wrap for compatibility with transforms v2.
81
+ """
82
+ return VisionDatasetDatapointWrapper(dataset)
83
+
84
+
85
+ class WrapperFactories(dict):
86
+ def register(self, dataset_cls):
87
+ def decorator(wrapper_factory):
88
+ self[dataset_cls] = wrapper_factory
89
+ return wrapper_factory
90
+
91
+ return decorator
92
+
93
+
94
+ # We need this two-stage design, i.e. a wrapper factory producing the actual wrapper, since some wrappers depend on the
95
+ # dataset instance rather than just the class, since they require the user defined instance attributes. Thus, we can
96
+ # provide a wrapping from the dataset class to the factory here, but can only instantiate the wrapper at runtime when
97
+ # we have access to the dataset instance.
98
+ WRAPPER_FACTORIES = WrapperFactories()
99
+
100
+
101
+ class VisionDatasetDatapointWrapper(Dataset):
102
+ def __init__(self, dataset):
103
+ dataset_cls = type(dataset)
104
+
105
+ if not isinstance(dataset, datasets.VisionDataset):
106
+ raise TypeError(
107
+ f"This wrapper is meant for subclasses of `torchvision.datasets.VisionDataset`, "
108
+ f"but got a '{dataset_cls.__name__}' instead."
109
+ )
110
+
111
+ for cls in dataset_cls.mro():
112
+ if cls in WRAPPER_FACTORIES:
113
+ wrapper_factory = WRAPPER_FACTORIES[cls]
114
+ break
115
+ elif cls is datasets.VisionDataset:
116
+ # TODO: If we have documentation on how to do that, put a link in the error message.
117
+ msg = f"No wrapper exists for dataset class {dataset_cls.__name__}. Please wrap the output yourself."
118
+ if dataset_cls in datasets.__dict__.values():
119
+ msg = (
120
+ f"{msg} If an automated wrapper for this dataset would be useful for you, "
121
+ f"please open an issue at https://github.com/pytorch/vision/issues."
122
+ )
123
+ raise TypeError(msg)
124
+
125
+ self._dataset = dataset
126
+ self._wrapper = wrapper_factory(dataset)
127
+
128
+ # We need to disable the transforms on the dataset here to be able to inject the wrapping before we apply them.
129
+ # Although internally, `datasets.VisionDataset` merges `transform` and `target_transform` into the joint
130
+ # `transforms`
131
+ # https://github.com/pytorch/vision/blob/135a0f9ea9841b6324b4fe8974e2543cbb95709a/torchvision/datasets/vision.py#L52-L54
132
+ # some (if not most) datasets still use `transform` and `target_transform` individually. Thus, we need to
133
+ # disable all three here to be able to extract the untransformed sample to wrap.
134
+ self.transform, dataset.transform = dataset.transform, None
135
+ self.target_transform, dataset.target_transform = dataset.target_transform, None
136
+ self.transforms, dataset.transforms = dataset.transforms, None
137
+
138
+ def __getattr__(self, item):
139
+ with contextlib.suppress(AttributeError):
140
+ return object.__getattribute__(self, item)
141
+
142
+ return getattr(self._dataset, item)
143
+
144
+ def __getitem__(self, idx):
145
+ # This gets us the raw sample since we disabled the transforms for the underlying dataset in the constructor
146
+ # of this class
147
+ sample = self._dataset[idx]
148
+
149
+ sample = self._wrapper(idx, sample)
150
+
151
+ # Regardless of whether the user has supplied the transforms individually (`transform` and `target_transform`)
152
+ # or joint (`transforms`), we can access the full functionality through `transforms`
153
+ if self.transforms is not None:
154
+ sample = self.transforms(*sample)
155
+
156
+ return sample
157
+
158
+ def __len__(self):
159
+ return len(self._dataset)
160
+
161
+
162
+ def raise_not_supported(description):
163
+ raise RuntimeError(
164
+ f"{description} is currently not supported by this wrapper. "
165
+ f"If this would be helpful for you, please open an issue at https://github.com/pytorch/vision/issues."
166
+ )
167
+
168
+
169
+ def identity(item):
170
+ return item
171
+
172
+
173
+ def identity_wrapper_factory(dataset):
174
+ def wrapper(idx, sample):
175
+ return sample
176
+
177
+ return wrapper
178
+
179
+
180
+ def pil_image_to_mask(pil_image):
181
+ return datapoints.Mask(pil_image)
182
+
183
+
184
+ def list_of_dicts_to_dict_of_lists(list_of_dicts):
185
+ dict_of_lists = defaultdict(list)
186
+ for dct in list_of_dicts:
187
+ for key, value in dct.items():
188
+ dict_of_lists[key].append(value)
189
+ return dict(dict_of_lists)
190
+
191
+
192
+ def wrap_target_by_type(target, *, target_types, type_wrappers):
193
+ if not isinstance(target, (tuple, list)):
194
+ target = [target]
195
+
196
+ wrapped_target = tuple(
197
+ type_wrappers.get(target_type, identity)(item) for target_type, item in zip(target_types, target)
198
+ )
199
+
200
+ if len(wrapped_target) == 1:
201
+ wrapped_target = wrapped_target[0]
202
+
203
+ return wrapped_target
204
+
205
+
206
+ def classification_wrapper_factory(dataset):
207
+ return identity_wrapper_factory(dataset)
208
+
209
+
210
+ for dataset_cls in [
211
+ datasets.Caltech256,
212
+ datasets.CIFAR10,
213
+ datasets.CIFAR100,
214
+ datasets.ImageNet,
215
+ datasets.MNIST,
216
+ datasets.FashionMNIST,
217
+ datasets.GTSRB,
218
+ datasets.DatasetFolder,
219
+ datasets.ImageFolder,
220
+ ]:
221
+ WRAPPER_FACTORIES.register(dataset_cls)(classification_wrapper_factory)
222
+
223
+
224
+ def segmentation_wrapper_factory(dataset):
225
+ def wrapper(idx, sample):
226
+ image, mask = sample
227
+ return image, pil_image_to_mask(mask)
228
+
229
+ return wrapper
230
+
231
+
232
+ for dataset_cls in [
233
+ datasets.VOCSegmentation,
234
+ ]:
235
+ WRAPPER_FACTORIES.register(dataset_cls)(segmentation_wrapper_factory)
236
+
237
+
238
+ def video_classification_wrapper_factory(dataset):
239
+ if dataset.video_clips.output_format == "THWC":
240
+ raise RuntimeError(
241
+ f"{type(dataset).__name__} with `output_format='THWC'` is not supported by this wrapper, "
242
+ f"since it is not compatible with the transformations. Please use `output_format='TCHW'` instead."
243
+ )
244
+
245
+ def wrapper(idx, sample):
246
+ video, audio, label = sample
247
+
248
+ video = datapoints.Video(video)
249
+
250
+ return video, audio, label
251
+
252
+ return wrapper
253
+
254
+
255
+ for dataset_cls in [
256
+ datasets.HMDB51,
257
+ datasets.Kinetics,
258
+ datasets.UCF101,
259
+ ]:
260
+ WRAPPER_FACTORIES.register(dataset_cls)(video_classification_wrapper_factory)
261
+
262
+
263
+ @WRAPPER_FACTORIES.register(datasets.Caltech101)
264
+ def caltech101_wrapper_factory(dataset):
265
+ if "annotation" in dataset.target_type:
266
+ raise_not_supported("Caltech101 dataset with `target_type=['annotation', ...]`")
267
+
268
+ return classification_wrapper_factory(dataset)
269
+
270
+
271
+ @WRAPPER_FACTORIES.register(datasets.CocoDetection)
272
+ def coco_dectection_wrapper_factory(dataset):
273
+ def segmentation_to_mask(segmentation, *, spatial_size):
274
+ from pycocotools import mask
275
+
276
+ segmentation = (
277
+ mask.frPyObjects(segmentation, *spatial_size)
278
+ if isinstance(segmentation, dict)
279
+ else mask.merge(mask.frPyObjects(segmentation, *spatial_size))
280
+ )
281
+ return torch.from_numpy(mask.decode(segmentation))
282
+
283
+ def wrapper(idx, sample):
284
+ image_id = dataset.ids[idx]
285
+
286
+ image, target = sample
287
+
288
+ if not target:
289
+ return image, dict(image_id=image_id)
290
+
291
+ batched_target = list_of_dicts_to_dict_of_lists(target)
292
+
293
+ batched_target["image_id"] = image_id
294
+
295
+ spatial_size = tuple(F.get_spatial_size(image))
296
+ batched_target["boxes"] = F.convert_format_bounding_box(
297
+ datapoints.BoundingBox(
298
+ batched_target["bbox"],
299
+ format=datapoints.BoundingBoxFormat.XYWH,
300
+ spatial_size=spatial_size,
301
+ ),
302
+ new_format=datapoints.BoundingBoxFormat.XYXY,
303
+ )
304
+ batched_target["masks"] = datapoints.Mask(
305
+ torch.stack(
306
+ [
307
+ segmentation_to_mask(segmentation, spatial_size=spatial_size)
308
+ for segmentation in batched_target["segmentation"]
309
+ ]
310
+ ),
311
+ )
312
+ batched_target["labels"] = torch.tensor(batched_target["category_id"])
313
+
314
+ return image, batched_target
315
+
316
+ return wrapper
317
+
318
+
319
+ WRAPPER_FACTORIES.register(datasets.CocoCaptions)(identity_wrapper_factory)
320
+
321
+
322
+ VOC_DETECTION_CATEGORIES = [
323
+ "__background__",
324
+ "aeroplane",
325
+ "bicycle",
326
+ "bird",
327
+ "boat",
328
+ "bottle",
329
+ "bus",
330
+ "car",
331
+ "cat",
332
+ "chair",
333
+ "cow",
334
+ "diningtable",
335
+ "dog",
336
+ "horse",
337
+ "motorbike",
338
+ "person",
339
+ "pottedplant",
340
+ "sheep",
341
+ "sofa",
342
+ "train",
343
+ "tvmonitor",
344
+ ]
345
+ VOC_DETECTION_CATEGORY_TO_IDX = dict(zip(VOC_DETECTION_CATEGORIES, range(len(VOC_DETECTION_CATEGORIES))))
346
+
347
+
348
+ @WRAPPER_FACTORIES.register(datasets.VOCDetection)
349
+ def voc_detection_wrapper_factory(dataset):
350
+ def wrapper(idx, sample):
351
+ image, target = sample
352
+
353
+ batched_instances = list_of_dicts_to_dict_of_lists(target["annotation"]["object"])
354
+
355
+ target["boxes"] = datapoints.BoundingBox(
356
+ [
357
+ [int(bndbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")]
358
+ for bndbox in batched_instances["bndbox"]
359
+ ],
360
+ format=datapoints.BoundingBoxFormat.XYXY,
361
+ spatial_size=(image.height, image.width),
362
+ )
363
+ target["labels"] = torch.tensor(
364
+ [VOC_DETECTION_CATEGORY_TO_IDX[category] for category in batched_instances["name"]]
365
+ )
366
+
367
+ return image, target
368
+
369
+ return wrapper
370
+
371
+
372
+ @WRAPPER_FACTORIES.register(datasets.SBDataset)
373
+ def sbd_wrapper(dataset):
374
+ if dataset.mode == "boundaries":
375
+ raise_not_supported("SBDataset with mode='boundaries'")
376
+
377
+ return segmentation_wrapper_factory(dataset)
378
+
379
+
380
+ @WRAPPER_FACTORIES.register(datasets.CelebA)
381
+ def celeba_wrapper_factory(dataset):
382
+ if any(target_type in dataset.target_type for target_type in ["attr", "landmarks"]):
383
+ raise_not_supported("`CelebA` dataset with `target_type=['attr', 'landmarks', ...]`")
384
+
385
+ def wrapper(idx, sample):
386
+ image, target = sample
387
+
388
+ target = wrap_target_by_type(
389
+ target,
390
+ target_types=dataset.target_type,
391
+ type_wrappers={
392
+ "bbox": lambda item: F.convert_format_bounding_box(
393
+ datapoints.BoundingBox(
394
+ item,
395
+ format=datapoints.BoundingBoxFormat.XYWH,
396
+ spatial_size=(image.height, image.width),
397
+ ),
398
+ new_format=datapoints.BoundingBoxFormat.XYXY,
399
+ ),
400
+ },
401
+ )
402
+
403
+ return image, target
404
+
405
+ return wrapper
406
+
407
+
408
+ KITTI_CATEGORIES = ["Car", "Van", "Truck", "Pedestrian", "Person_sitting", "Cyclist", "Tram", "Misc", "DontCare"]
409
+ KITTI_CATEGORY_TO_IDX = dict(zip(KITTI_CATEGORIES, range(len(KITTI_CATEGORIES))))
410
+
411
+
412
+ @WRAPPER_FACTORIES.register(datasets.Kitti)
413
+ def kitti_wrapper_factory(dataset):
414
+ def wrapper(idx, sample):
415
+ image, target = sample
416
+
417
+ if target is not None:
418
+ target = list_of_dicts_to_dict_of_lists(target)
419
+
420
+ target["boxes"] = datapoints.BoundingBox(
421
+ target["bbox"], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(image.height, image.width)
422
+ )
423
+ target["labels"] = torch.tensor([KITTI_CATEGORY_TO_IDX[category] for category in target["type"]])
424
+
425
+ return image, target
426
+
427
+ return wrapper
428
+
429
+
430
+ @WRAPPER_FACTORIES.register(datasets.OxfordIIITPet)
431
+ def oxford_iiit_pet_wrapper_factor(dataset):
432
+ def wrapper(idx, sample):
433
+ image, target = sample
434
+
435
+ if target is not None:
436
+ target = wrap_target_by_type(
437
+ target,
438
+ target_types=dataset._target_types,
439
+ type_wrappers={
440
+ "segmentation": pil_image_to_mask,
441
+ },
442
+ )
443
+
444
+ return image, target
445
+
446
+ return wrapper
447
+
448
+
449
+ @WRAPPER_FACTORIES.register(datasets.Cityscapes)
450
+ def cityscapes_wrapper_factory(dataset):
451
+ if any(target_type in dataset.target_type for target_type in ["polygon", "color"]):
452
+ raise_not_supported("`Cityscapes` dataset with `target_type=['polygon', 'color', ...]`")
453
+
454
+ def instance_segmentation_wrapper(mask):
455
+ # See https://github.com/mcordts/cityscapesScripts/blob/8da5dd00c9069058ccc134654116aac52d4f6fa2/cityscapesscripts/preparation/json2instanceImg.py#L7-L21
456
+ data = pil_image_to_mask(mask)
457
+ masks = []
458
+ labels = []
459
+ for id in data.unique():
460
+ masks.append(data == id)
461
+ label = id
462
+ if label >= 1_000:
463
+ label //= 1_000
464
+ labels.append(label)
465
+ return dict(masks=datapoints.Mask(torch.stack(masks)), labels=torch.stack(labels))
466
+
467
+ def wrapper(idx, sample):
468
+ image, target = sample
469
+
470
+ target = wrap_target_by_type(
471
+ target,
472
+ target_types=dataset.target_type,
473
+ type_wrappers={
474
+ "instance": instance_segmentation_wrapper,
475
+ "semantic": pil_image_to_mask,
476
+ },
477
+ )
478
+
479
+ return image, target
480
+
481
+ return wrapper
482
+
483
+
484
+ @WRAPPER_FACTORIES.register(datasets.WIDERFace)
485
+ def widerface_wrapper(dataset):
486
+ def wrapper(idx, sample):
487
+ image, target = sample
488
+
489
+ if target is not None:
490
+ target["bbox"] = F.convert_format_bounding_box(
491
+ datapoints.BoundingBox(
492
+ target["bbox"], format=datapoints.BoundingBoxFormat.XYWH, spatial_size=(image.height, image.width)
493
+ ),
494
+ new_format=datapoints.BoundingBoxFormat.XYXY,
495
+ )
496
+
497
+ return image, target
498
+
499
+ return wrapper
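
A hedged usage sketch for wrap_dataset_for_transforms_v2 defined above, following the CocoDetection special case from its docstring. The dataset paths are placeholders, and the import path shown is the module added in this commit; the public re-export location may differ.

import torchvision
from torchvision.datapoints._dataset_wrapper import wrap_dataset_for_transforms_v2

dataset = torchvision.datasets.CocoDetection("path/to/images", "path/to/annotations.json")  # placeholder paths
dataset = wrap_dataset_for_transforms_v2(dataset)

image, target = dataset[0]
# target is now a dict of lists with "boxes" (XYXY BoundingBox), "masks" (Mask) and "labels" added
# alongside the original COCO keys, ready for torchvision.transforms.v2 pipelines.
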
torchvision/datapoints/_image.py ADDED
@@ -0,0 +1,260 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, List, Optional, Tuple, Union
4
+
5
+ import PIL.Image
6
+ import torch
7
+ from torchvision.transforms.functional import InterpolationMode
8
+
9
+ from ._datapoint import _FillTypeJIT, Datapoint
10
+
11
+
12
+ class Image(Datapoint):
13
+ """[BETA] :class:`torch.Tensor` subclass for images.
14
+
15
+ Args:
16
+ data (tensor-like, PIL.Image.Image): Any data that can be turned into a tensor with :func:`torch.as_tensor` as
17
+ well as PIL images.
18
+ dtype (torch.dtype, optional): Desired data type of the bounding box. If omitted, will be inferred from
19
+ ``data``.
20
+ device (torch.device, optional): Desired device of the bounding box. If omitted and ``data`` is a
21
+ :class:`torch.Tensor`, the device is taken from it. Otherwise, the bounding box is constructed on the CPU.
22
+ requires_grad (bool, optional): Whether autograd should record operations on the bounding box. If omitted and
23
+ ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``.
24
+ """
25
+
26
+ @classmethod
27
+ def _wrap(cls, tensor: torch.Tensor) -> Image:
28
+ image = tensor.as_subclass(cls)
29
+ return image
30
+
31
+ def __new__(
32
+ cls,
33
+ data: Any,
34
+ *,
35
+ dtype: Optional[torch.dtype] = None,
36
+ device: Optional[Union[torch.device, str, int]] = None,
37
+ requires_grad: Optional[bool] = None,
38
+ ) -> Image:
39
+ if isinstance(data, PIL.Image.Image):
40
+ from torchvision.transforms.v2 import functional as F
41
+
42
+ data = F.pil_to_tensor(data)
43
+
44
+ tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
45
+ if tensor.ndim < 2:
46
+ raise ValueError
47
+ elif tensor.ndim == 2:
48
+ tensor = tensor.unsqueeze(0)
49
+
50
+ return cls._wrap(tensor)
51
+
52
+ @classmethod
53
+ def wrap_like(cls, other: Image, tensor: torch.Tensor) -> Image:
54
+ return cls._wrap(tensor)
55
+
56
+ def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override]
57
+ return self._make_repr()
58
+
59
+ @property
60
+ def spatial_size(self) -> Tuple[int, int]:
61
+ return tuple(self.shape[-2:]) # type: ignore[return-value]
62
+
63
+ @property
64
+ def num_channels(self) -> int:
65
+ return self.shape[-3]
66
+
67
+ def horizontal_flip(self) -> Image:
68
+ output = self._F.horizontal_flip_image_tensor(self.as_subclass(torch.Tensor))
69
+ return Image.wrap_like(self, output)
70
+
71
+ def vertical_flip(self) -> Image:
72
+ output = self._F.vertical_flip_image_tensor(self.as_subclass(torch.Tensor))
73
+ return Image.wrap_like(self, output)
74
+
75
+ def resize( # type: ignore[override]
76
+ self,
77
+ size: List[int],
78
+ interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
79
+ max_size: Optional[int] = None,
80
+ antialias: Optional[Union[str, bool]] = "warn",
81
+ ) -> Image:
82
+ output = self._F.resize_image_tensor(
83
+ self.as_subclass(torch.Tensor), size, interpolation=interpolation, max_size=max_size, antialias=antialias
84
+ )
85
+ return Image.wrap_like(self, output)
86
+
87
+ def crop(self, top: int, left: int, height: int, width: int) -> Image:
88
+ output = self._F.crop_image_tensor(self.as_subclass(torch.Tensor), top, left, height, width)
89
+ return Image.wrap_like(self, output)
90
+
91
+ def center_crop(self, output_size: List[int]) -> Image:
92
+ output = self._F.center_crop_image_tensor(self.as_subclass(torch.Tensor), output_size=output_size)
93
+ return Image.wrap_like(self, output)
94
+
95
+ def resized_crop(
96
+ self,
97
+ top: int,
98
+ left: int,
99
+ height: int,
100
+ width: int,
101
+ size: List[int],
102
+ interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
103
+ antialias: Optional[Union[str, bool]] = "warn",
104
+ ) -> Image:
105
+ output = self._F.resized_crop_image_tensor(
106
+ self.as_subclass(torch.Tensor),
107
+ top,
108
+ left,
109
+ height,
110
+ width,
111
+ size=list(size),
112
+ interpolation=interpolation,
113
+ antialias=antialias,
114
+ )
115
+ return Image.wrap_like(self, output)
116
+
117
+ def pad(
118
+ self,
119
+ padding: List[int],
120
+ fill: Optional[Union[int, float, List[float]]] = None,
121
+ padding_mode: str = "constant",
122
+ ) -> Image:
123
+ output = self._F.pad_image_tensor(self.as_subclass(torch.Tensor), padding, fill=fill, padding_mode=padding_mode)
124
+ return Image.wrap_like(self, output)
125
+
126
+ def rotate(
127
+ self,
128
+ angle: float,
129
+ interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
130
+ expand: bool = False,
131
+ center: Optional[List[float]] = None,
132
+ fill: _FillTypeJIT = None,
133
+ ) -> Image:
134
+ output = self._F.rotate_image_tensor(
135
+ self.as_subclass(torch.Tensor), angle, interpolation=interpolation, expand=expand, fill=fill, center=center
136
+ )
137
+ return Image.wrap_like(self, output)
138
+
139
+ def affine(
140
+ self,
141
+ angle: Union[int, float],
142
+ translate: List[float],
143
+ scale: float,
144
+ shear: List[float],
145
+ interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
146
+ fill: _FillTypeJIT = None,
147
+ center: Optional[List[float]] = None,
148
+ ) -> Image:
149
+ output = self._F.affine_image_tensor(
150
+ self.as_subclass(torch.Tensor),
151
+ angle,
152
+ translate=translate,
153
+ scale=scale,
154
+ shear=shear,
155
+ interpolation=interpolation,
156
+ fill=fill,
157
+ center=center,
158
+ )
159
+ return Image.wrap_like(self, output)
160
+
161
+ def perspective(
162
+ self,
163
+ startpoints: Optional[List[List[int]]],
164
+ endpoints: Optional[List[List[int]]],
165
+ interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
166
+ fill: _FillTypeJIT = None,
167
+ coefficients: Optional[List[float]] = None,
168
+ ) -> Image:
169
+ output = self._F.perspective_image_tensor(
170
+ self.as_subclass(torch.Tensor),
171
+ startpoints,
172
+ endpoints,
173
+ interpolation=interpolation,
174
+ fill=fill,
175
+ coefficients=coefficients,
176
+ )
177
+ return Image.wrap_like(self, output)
178
+
179
+ def elastic(
180
+ self,
181
+ displacement: torch.Tensor,
182
+ interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
183
+ fill: _FillTypeJIT = None,
184
+ ) -> Image:
185
+ output = self._F.elastic_image_tensor(
186
+ self.as_subclass(torch.Tensor), displacement, interpolation=interpolation, fill=fill
187
+ )
188
+ return Image.wrap_like(self, output)
189
+
190
+ def rgb_to_grayscale(self, num_output_channels: int = 1) -> Image:
191
+ output = self._F.rgb_to_grayscale_image_tensor(
192
+ self.as_subclass(torch.Tensor), num_output_channels=num_output_channels
193
+ )
194
+ return Image.wrap_like(self, output)
195
+
196
+ def adjust_brightness(self, brightness_factor: float) -> Image:
197
+ output = self._F.adjust_brightness_image_tensor(
198
+ self.as_subclass(torch.Tensor), brightness_factor=brightness_factor
199
+ )
200
+ return Image.wrap_like(self, output)
201
+
202
+ def adjust_saturation(self, saturation_factor: float) -> Image:
203
+ output = self._F.adjust_saturation_image_tensor(
204
+ self.as_subclass(torch.Tensor), saturation_factor=saturation_factor
205
+ )
206
+ return Image.wrap_like(self, output)
207
+
208
+ def adjust_contrast(self, contrast_factor: float) -> Image:
209
+ output = self._F.adjust_contrast_image_tensor(self.as_subclass(torch.Tensor), contrast_factor=contrast_factor)
210
+ return Image.wrap_like(self, output)
211
+
212
+ def adjust_sharpness(self, sharpness_factor: float) -> Image:
213
+ output = self._F.adjust_sharpness_image_tensor(
214
+ self.as_subclass(torch.Tensor), sharpness_factor=sharpness_factor
215
+ )
216
+ return Image.wrap_like(self, output)
217
+
218
+ def adjust_hue(self, hue_factor: float) -> Image:
219
+ output = self._F.adjust_hue_image_tensor(self.as_subclass(torch.Tensor), hue_factor=hue_factor)
220
+ return Image.wrap_like(self, output)
221
+
222
+ def adjust_gamma(self, gamma: float, gain: float = 1) -> Image:
223
+ output = self._F.adjust_gamma_image_tensor(self.as_subclass(torch.Tensor), gamma=gamma, gain=gain)
224
+ return Image.wrap_like(self, output)
225
+
226
+ def posterize(self, bits: int) -> Image:
227
+ output = self._F.posterize_image_tensor(self.as_subclass(torch.Tensor), bits=bits)
228
+ return Image.wrap_like(self, output)
229
+
230
+ def solarize(self, threshold: float) -> Image:
231
+ output = self._F.solarize_image_tensor(self.as_subclass(torch.Tensor), threshold=threshold)
232
+ return Image.wrap_like(self, output)
233
+
234
+ def autocontrast(self) -> Image:
235
+ output = self._F.autocontrast_image_tensor(self.as_subclass(torch.Tensor))
236
+ return Image.wrap_like(self, output)
237
+
238
+ def equalize(self) -> Image:
239
+ output = self._F.equalize_image_tensor(self.as_subclass(torch.Tensor))
240
+ return Image.wrap_like(self, output)
241
+
242
+ def invert(self) -> Image:
243
+ output = self._F.invert_image_tensor(self.as_subclass(torch.Tensor))
244
+ return Image.wrap_like(self, output)
245
+
246
+ def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Image:
247
+ output = self._F.gaussian_blur_image_tensor(
248
+ self.as_subclass(torch.Tensor), kernel_size=kernel_size, sigma=sigma
249
+ )
250
+ return Image.wrap_like(self, output)
251
+
252
+ def normalize(self, mean: List[float], std: List[float], inplace: bool = False) -> Image:
253
+ output = self._F.normalize_image_tensor(self.as_subclass(torch.Tensor), mean=mean, std=std, inplace=inplace)
254
+ return Image.wrap_like(self, output)
255
+
256
+
257
+ _ImageType = Union[torch.Tensor, PIL.Image.Image, Image]
258
+ _ImageTypeJIT = torch.Tensor
259
+ _TensorImageType = Union[torch.Tensor, Image]
260
+ _TensorImageTypeJIT = torch.Tensor
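
For context, here is a minimal usage sketch of the Image datapoint defined above. It is an illustration added for this review, not part of the uploaded diff, and it assumes the torchvision build from this commit with the beta datapoints API available:

    import torch
    from torchvision import datapoints

    # A 2D input would be promoted to (1, H, W) by __new__; here we start from (C, H, W).
    img = datapoints.Image(torch.rand(3, 32, 32))
    print(img.num_channels, img.spatial_size)  # 3 (32, 32)

    # Each method unwraps to a plain tensor, calls the v2 functional kernel,
    # and re-wraps the result via Image.wrap_like.
    out = img.resize([64, 64], antialias=True)
    print(type(out).__name__, out.spatial_size)  # Image (64, 64)
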
torchvision/datapoints/_mask.py ADDED
@@ -0,0 +1,158 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, List, Optional, Tuple, Union
4
+
5
+ import PIL.Image
6
+ import torch
7
+ from torchvision.transforms import InterpolationMode
8
+
9
+ from ._datapoint import _FillTypeJIT, Datapoint
10
+
11
+
12
+ class Mask(Datapoint):
13
+ """[BETA] :class:`torch.Tensor` subclass for segmentation and detection masks.
14
+
15
+ Args:
16
+ data (tensor-like, PIL.Image.Image): Any data that can be turned into a tensor with :func:`torch.as_tensor` as
17
+ well as PIL images.
18
+ dtype (torch.dtype, optional): Desired data type of the mask. If omitted, will be inferred from
19
+ ``data``.
20
+ device (torch.device, optional): Desired device of the mask. If omitted and ``data`` is a
21
+ :class:`torch.Tensor`, the device is taken from it. Otherwise, the mask is constructed on the CPU.
22
+ requires_grad (bool, optional): Whether autograd should record operations on the mask. If omitted and
23
+ ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``.
24
+ """
25
+
26
+ @classmethod
27
+ def _wrap(cls, tensor: torch.Tensor) -> Mask:
28
+ return tensor.as_subclass(cls)
29
+
30
+ def __new__(
31
+ cls,
32
+ data: Any,
33
+ *,
34
+ dtype: Optional[torch.dtype] = None,
35
+ device: Optional[Union[torch.device, str, int]] = None,
36
+ requires_grad: Optional[bool] = None,
37
+ ) -> Mask:
38
+ if isinstance(data, PIL.Image.Image):
39
+ from torchvision.transforms.v2 import functional as F
40
+
41
+ data = F.pil_to_tensor(data)
42
+
43
+ tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
44
+ return cls._wrap(tensor)
45
+
46
+ @classmethod
47
+ def wrap_like(
48
+ cls,
49
+ other: Mask,
50
+ tensor: torch.Tensor,
51
+ ) -> Mask:
52
+ return cls._wrap(tensor)
53
+
54
+ @property
55
+ def spatial_size(self) -> Tuple[int, int]:
56
+ return tuple(self.shape[-2:]) # type: ignore[return-value]
57
+
58
+ def horizontal_flip(self) -> Mask:
59
+ output = self._F.horizontal_flip_mask(self.as_subclass(torch.Tensor))
60
+ return Mask.wrap_like(self, output)
61
+
62
+ def vertical_flip(self) -> Mask:
63
+ output = self._F.vertical_flip_mask(self.as_subclass(torch.Tensor))
64
+ return Mask.wrap_like(self, output)
65
+
66
+ def resize( # type: ignore[override]
67
+ self,
68
+ size: List[int],
69
+ interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
70
+ max_size: Optional[int] = None,
71
+ antialias: Optional[Union[str, bool]] = "warn",
72
+ ) -> Mask:
73
+ output = self._F.resize_mask(self.as_subclass(torch.Tensor), size, max_size=max_size)
74
+ return Mask.wrap_like(self, output)
75
+
76
+ def crop(self, top: int, left: int, height: int, width: int) -> Mask:
77
+ output = self._F.crop_mask(self.as_subclass(torch.Tensor), top, left, height, width)
78
+ return Mask.wrap_like(self, output)
79
+
80
+ def center_crop(self, output_size: List[int]) -> Mask:
81
+ output = self._F.center_crop_mask(self.as_subclass(torch.Tensor), output_size=output_size)
82
+ return Mask.wrap_like(self, output)
83
+
84
+ def resized_crop(
85
+ self,
86
+ top: int,
87
+ left: int,
88
+ height: int,
89
+ width: int,
90
+ size: List[int],
91
+ interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
92
+ antialias: Optional[Union[str, bool]] = "warn",
93
+ ) -> Mask:
94
+ output = self._F.resized_crop_mask(self.as_subclass(torch.Tensor), top, left, height, width, size=size)
95
+ return Mask.wrap_like(self, output)
96
+
97
+ def pad(
98
+ self,
99
+ padding: List[int],
100
+ fill: Optional[Union[int, float, List[float]]] = None,
101
+ padding_mode: str = "constant",
102
+ ) -> Mask:
103
+ output = self._F.pad_mask(self.as_subclass(torch.Tensor), padding, padding_mode=padding_mode, fill=fill)
104
+ return Mask.wrap_like(self, output)
105
+
106
+ def rotate(
107
+ self,
108
+ angle: float,
109
+ interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
110
+ expand: bool = False,
111
+ center: Optional[List[float]] = None,
112
+ fill: _FillTypeJIT = None,
113
+ ) -> Mask:
114
+ output = self._F.rotate_mask(self.as_subclass(torch.Tensor), angle, expand=expand, center=center, fill=fill)
115
+ return Mask.wrap_like(self, output)
116
+
117
+ def affine(
118
+ self,
119
+ angle: Union[int, float],
120
+ translate: List[float],
121
+ scale: float,
122
+ shear: List[float],
123
+ interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
124
+ fill: _FillTypeJIT = None,
125
+ center: Optional[List[float]] = None,
126
+ ) -> Mask:
127
+ output = self._F.affine_mask(
128
+ self.as_subclass(torch.Tensor),
129
+ angle,
130
+ translate=translate,
131
+ scale=scale,
132
+ shear=shear,
133
+ fill=fill,
134
+ center=center,
135
+ )
136
+ return Mask.wrap_like(self, output)
137
+
138
+ def perspective(
139
+ self,
140
+ startpoints: Optional[List[List[int]]],
141
+ endpoints: Optional[List[List[int]]],
142
+ interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
143
+ fill: _FillTypeJIT = None,
144
+ coefficients: Optional[List[float]] = None,
145
+ ) -> Mask:
146
+ output = self._F.perspective_mask(
147
+ self.as_subclass(torch.Tensor), startpoints, endpoints, fill=fill, coefficients=coefficients
148
+ )
149
+ return Mask.wrap_like(self, output)
150
+
151
+ def elastic(
152
+ self,
153
+ displacement: torch.Tensor,
154
+ interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
155
+ fill: _FillTypeJIT = None,
156
+ ) -> Mask:
157
+ output = self._F.elastic_mask(self.as_subclass(torch.Tensor), displacement, fill=fill)
158
+ return Mask.wrap_like(self, output)
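
A corresponding sketch for the Mask datapoint (again illustrative, not part of the diff, under the same assumptions). Note that Mask.resize accepts interpolation and antialias only for interface compatibility with the other datapoints; it delegates to the mask kernel, which resizes with nearest-neighbour interpolation so integer class ids are preserved:

    import torch
    from torchvision import datapoints

    # An integer segmentation map with class ids 0..4, shape (H, W).
    mask = datapoints.Mask(torch.randint(0, 5, (32, 32), dtype=torch.uint8))

    # Nearest-neighbour resizing keeps the label values intact.
    resized = mask.resize([64, 64])
    print(type(resized).__name__, resized.shape)  # Mask torch.Size([64, 64])
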
torchvision/datapoints/_video.py ADDED
@@ -0,0 +1,250 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ from torchvision.transforms.functional import InterpolationMode
7
+
8
+ from ._datapoint import _FillTypeJIT, Datapoint
9
+
10
+
11
+ class Video(Datapoint):
12
+ """[BETA] :class:`torch.Tensor` subclass for videos.
13
+
14
+ Args:
15
+ data (tensor-like): Any data that can be turned into a tensor with :func:`torch.as_tensor`.
16
+ dtype (torch.dtype, optional): Desired data type of the video. If omitted, will be inferred from
17
+ ``data``.
18
+ device (torch.device, optional): Desired device of the video. If omitted and ``data`` is a
19
+ :class:`torch.Tensor`, the device is taken from it. Otherwise, the video is constructed on the CPU.
20
+ requires_grad (bool, optional): Whether autograd should record operations on the video. If omitted and
21
+ ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``.
22
+ """
23
+
24
+ @classmethod
25
+ def _wrap(cls, tensor: torch.Tensor) -> Video:
26
+ video = tensor.as_subclass(cls)
27
+ return video
28
+
29
+ def __new__(
30
+ cls,
31
+ data: Any,
32
+ *,
33
+ dtype: Optional[torch.dtype] = None,
34
+ device: Optional[Union[torch.device, str, int]] = None,
35
+ requires_grad: Optional[bool] = None,
36
+ ) -> Video:
37
+ tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
38
+ if tensor.ndim < 4:
39
+ raise ValueError(f"Expected at least a 4D tensor, but got {tensor.ndim} dimensions")
40
+ return cls._wrap(tensor)
41
+
42
+ @classmethod
43
+ def wrap_like(cls, other: Video, tensor: torch.Tensor) -> Video:
44
+ return cls._wrap(tensor)
45
+
46
+ def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override]
47
+ return self._make_repr()
48
+
49
+ @property
50
+ def spatial_size(self) -> Tuple[int, int]:
51
+ return tuple(self.shape[-2:]) # type: ignore[return-value]
52
+
53
+ @property
54
+ def num_channels(self) -> int:
55
+ return self.shape[-3]
56
+
57
+ @property
58
+ def num_frames(self) -> int:
59
+ return self.shape[-4]
60
+
61
+ def horizontal_flip(self) -> Video:
62
+ output = self._F.horizontal_flip_video(self.as_subclass(torch.Tensor))
63
+ return Video.wrap_like(self, output)
64
+
65
+ def vertical_flip(self) -> Video:
66
+ output = self._F.vertical_flip_video(self.as_subclass(torch.Tensor))
67
+ return Video.wrap_like(self, output)
68
+
69
+ def resize( # type: ignore[override]
70
+ self,
71
+ size: List[int],
72
+ interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
73
+ max_size: Optional[int] = None,
74
+ antialias: Optional[Union[str, bool]] = "warn",
75
+ ) -> Video:
76
+ output = self._F.resize_video(
77
+ self.as_subclass(torch.Tensor),
78
+ size,
79
+ interpolation=interpolation,
80
+ max_size=max_size,
81
+ antialias=antialias,
82
+ )
83
+ return Video.wrap_like(self, output)
84
+
85
+ def crop(self, top: int, left: int, height: int, width: int) -> Video:
86
+ output = self._F.crop_video(self.as_subclass(torch.Tensor), top, left, height, width)
87
+ return Video.wrap_like(self, output)
88
+
89
+ def center_crop(self, output_size: List[int]) -> Video:
90
+ output = self._F.center_crop_video(self.as_subclass(torch.Tensor), output_size=output_size)
91
+ return Video.wrap_like(self, output)
92
+
93
+ def resized_crop(
94
+ self,
95
+ top: int,
96
+ left: int,
97
+ height: int,
98
+ width: int,
99
+ size: List[int],
100
+ interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
101
+ antialias: Optional[Union[str, bool]] = "warn",
102
+ ) -> Video:
103
+ output = self._F.resized_crop_video(
104
+ self.as_subclass(torch.Tensor),
105
+ top,
106
+ left,
107
+ height,
108
+ width,
109
+ size=list(size),
110
+ interpolation=interpolation,
111
+ antialias=antialias,
112
+ )
113
+ return Video.wrap_like(self, output)
114
+
115
+ def pad(
116
+ self,
117
+ padding: List[int],
118
+ fill: Optional[Union[int, float, List[float]]] = None,
119
+ padding_mode: str = "constant",
120
+ ) -> Video:
121
+ output = self._F.pad_video(self.as_subclass(torch.Tensor), padding, fill=fill, padding_mode=padding_mode)
122
+ return Video.wrap_like(self, output)
123
+
124
+ def rotate(
125
+ self,
126
+ angle: float,
127
+ interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
128
+ expand: bool = False,
129
+ center: Optional[List[float]] = None,
130
+ fill: _FillTypeJIT = None,
131
+ ) -> Video:
132
+ output = self._F.rotate_video(
133
+ self.as_subclass(torch.Tensor), angle, interpolation=interpolation, expand=expand, fill=fill, center=center
134
+ )
135
+ return Video.wrap_like(self, output)
136
+
137
+ def affine(
138
+ self,
139
+ angle: Union[int, float],
140
+ translate: List[float],
141
+ scale: float,
142
+ shear: List[float],
143
+ interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
144
+ fill: _FillTypeJIT = None,
145
+ center: Optional[List[float]] = None,
146
+ ) -> Video:
147
+ output = self._F.affine_video(
148
+ self.as_subclass(torch.Tensor),
149
+ angle,
150
+ translate=translate,
151
+ scale=scale,
152
+ shear=shear,
153
+ interpolation=interpolation,
154
+ fill=fill,
155
+ center=center,
156
+ )
157
+ return Video.wrap_like(self, output)
158
+
159
+ def perspective(
160
+ self,
161
+ startpoints: Optional[List[List[int]]],
162
+ endpoints: Optional[List[List[int]]],
163
+ interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
164
+ fill: _FillTypeJIT = None,
165
+ coefficients: Optional[List[float]] = None,
166
+ ) -> Video:
167
+ output = self._F.perspective_video(
168
+ self.as_subclass(torch.Tensor),
169
+ startpoints,
170
+ endpoints,
171
+ interpolation=interpolation,
172
+ fill=fill,
173
+ coefficients=coefficients,
174
+ )
175
+ return Video.wrap_like(self, output)
176
+
177
+ def elastic(
178
+ self,
179
+ displacement: torch.Tensor,
180
+ interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
181
+ fill: _FillTypeJIT = None,
182
+ ) -> Video:
183
+ output = self._F.elastic_video(
184
+ self.as_subclass(torch.Tensor), displacement, interpolation=interpolation, fill=fill
185
+ )
186
+ return Video.wrap_like(self, output)
187
+
188
+ def rgb_to_grayscale(self, num_output_channels: int = 1) -> Video:
189
+ output = self._F.rgb_to_grayscale_image_tensor(
190
+ self.as_subclass(torch.Tensor), num_output_channels=num_output_channels
191
+ )
192
+ return Video.wrap_like(self, output)
193
+
194
+ def adjust_brightness(self, brightness_factor: float) -> Video:
195
+ output = self._F.adjust_brightness_video(self.as_subclass(torch.Tensor), brightness_factor=brightness_factor)
196
+ return Video.wrap_like(self, output)
197
+
198
+ def adjust_saturation(self, saturation_factor: float) -> Video:
199
+ output = self._F.adjust_saturation_video(self.as_subclass(torch.Tensor), saturation_factor=saturation_factor)
200
+ return Video.wrap_like(self, output)
201
+
202
+ def adjust_contrast(self, contrast_factor: float) -> Video:
203
+ output = self._F.adjust_contrast_video(self.as_subclass(torch.Tensor), contrast_factor=contrast_factor)
204
+ return Video.wrap_like(self, output)
205
+
206
+ def adjust_sharpness(self, sharpness_factor: float) -> Video:
207
+ output = self._F.adjust_sharpness_video(self.as_subclass(torch.Tensor), sharpness_factor=sharpness_factor)
208
+ return Video.wrap_like(self, output)
209
+
210
+ def adjust_hue(self, hue_factor: float) -> Video:
211
+ output = self._F.adjust_hue_video(self.as_subclass(torch.Tensor), hue_factor=hue_factor)
212
+ return Video.wrap_like(self, output)
213
+
214
+ def adjust_gamma(self, gamma: float, gain: float = 1) -> Video:
215
+ output = self._F.adjust_gamma_video(self.as_subclass(torch.Tensor), gamma=gamma, gain=gain)
216
+ return Video.wrap_like(self, output)
217
+
218
+ def posterize(self, bits: int) -> Video:
219
+ output = self._F.posterize_video(self.as_subclass(torch.Tensor), bits=bits)
220
+ return Video.wrap_like(self, output)
221
+
222
+ def solarize(self, threshold: float) -> Video:
223
+ output = self._F.solarize_video(self.as_subclass(torch.Tensor), threshold=threshold)
224
+ return Video.wrap_like(self, output)
225
+
226
+ def autocontrast(self) -> Video:
227
+ output = self._F.autocontrast_video(self.as_subclass(torch.Tensor))
228
+ return Video.wrap_like(self, output)
229
+
230
+ def equalize(self) -> Video:
231
+ output = self._F.equalize_video(self.as_subclass(torch.Tensor))
232
+ return Video.wrap_like(self, output)
233
+
234
+ def invert(self) -> Video:
235
+ output = self._F.invert_video(self.as_subclass(torch.Tensor))
236
+ return Video.wrap_like(self, output)
237
+
238
+ def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Video:
239
+ output = self._F.gaussian_blur_video(self.as_subclass(torch.Tensor), kernel_size=kernel_size, sigma=sigma)
240
+ return Video.wrap_like(self, output)
241
+
242
+ def normalize(self, mean: List[float], std: List[float], inplace: bool = False) -> Video:
243
+ output = self._F.normalize_video(self.as_subclass(torch.Tensor), mean=mean, std=std, inplace=inplace)
244
+ return Video.wrap_like(self, output)
245
+
246
+
247
+ _VideoType = Union[torch.Tensor, Video]
248
+ _VideoTypeJIT = torch.Tensor
249
+ _TensorVideoType = Union[torch.Tensor, Video]
250
+ _TensorVideoTypeJIT = torch.Tensor
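
And a sketch for the Video datapoint (illustrative only, same assumptions as above). Videos must have at least four dimensions, laid out as (..., T, C, H, W), which is what the num_frames, num_channels, and spatial_size properties read off the trailing dimensions:

    import torch
    from torchvision import datapoints

    # 8 RGB frames of 64x64; anything with fewer than 4 dims is rejected in __new__.
    video = datapoints.Video(torch.rand(8, 3, 64, 64))
    print(video.num_frames, video.num_channels, video.spatial_size)  # 8 3 (64, 64)

    clip = video.center_crop([32, 32])
    print(clip.spatial_size)  # (32, 32)
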
torchvision/datasets/__init__.py ADDED
@@ -0,0 +1,145 @@
1
+ from ._optical_flow import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel
2
+ from ._stereo_matching import (
3
+ CarlaStereo,
4
+ CREStereo,
5
+ ETH3DStereo,
6
+ FallingThingsStereo,
7
+ InStereo2k,
8
+ Kitti2012Stereo,
9
+ Kitti2015Stereo,
10
+ Middlebury2014Stereo,
11
+ SceneFlowStereo,
12
+ SintelStereo,
13
+ )
14
+ from .caltech import Caltech101, Caltech256
15
+ from .celeba import CelebA
16
+ from .cifar import CIFAR10, CIFAR100
17
+ from .cityscapes import Cityscapes
18
+ from .clevr import CLEVRClassification
19
+ from .coco import CocoCaptions, CocoDetection
20
+ from .country211 import Country211
21
+ from .dtd import DTD
22
+ from .eurosat import EuroSAT
23
+ from .fakedata import FakeData
24
+ from .fer2013 import FER2013
25
+ from .fgvc_aircraft import FGVCAircraft
26
+ from .flickr import Flickr30k, Flickr8k
27
+ from .flowers102 import Flowers102
28
+ from .folder import DatasetFolder, ImageFolder
29
+ from .food101 import Food101
30
+ from .gtsrb import GTSRB
31
+ from .hmdb51 import HMDB51
32
+ from .imagenet import ImageNet
33
+ from .inaturalist import INaturalist
34
+ from .kinetics import Kinetics
35
+ from .kitti import Kitti
36
+ from .lfw import LFWPairs, LFWPeople
37
+ from .lsun import LSUN, LSUNClass
38
+ from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST
39
+ from .moving_mnist import MovingMNIST
40
+ from .omniglot import Omniglot
41
+ from .oxford_iiit_pet import OxfordIIITPet
42
+ from .pcam import PCAM
43
+ from .phototour import PhotoTour
44
+ from .places365 import Places365
45
+ from .rendered_sst2 import RenderedSST2
46
+ from .sbd import SBDataset
47
+ from .sbu import SBU
48
+ from .semeion import SEMEION
49
+ from .stanford_cars import StanfordCars
50
+ from .stl10 import STL10
51
+ from .sun397 import SUN397
52
+ from .svhn import SVHN
53
+ from .ucf101 import UCF101
54
+ from .usps import USPS
55
+ from .vision import VisionDataset
56
+ from .voc import VOCDetection, VOCSegmentation
57
+ from .widerface import WIDERFace
58
+
59
+ __all__ = (
60
+ "LSUN",
61
+ "LSUNClass",
62
+ "ImageFolder",
63
+ "DatasetFolder",
64
+ "FakeData",
65
+ "CocoCaptions",
66
+ "CocoDetection",
67
+ "CIFAR10",
68
+ "CIFAR100",
69
+ "EMNIST",
70
+ "FashionMNIST",
71
+ "QMNIST",
72
+ "MNIST",
73
+ "KMNIST",
74
+ "StanfordCars",
75
+ "STL10",
76
+ "SUN397",
77
+ "SVHN",
78
+ "PhotoTour",
79
+ "SEMEION",
80
+ "Omniglot",
81
+ "SBU",
82
+ "Flickr8k",
83
+ "Flickr30k",
84
+ "Flowers102",
85
+ "VOCSegmentation",
86
+ "VOCDetection",
87
+ "Cityscapes",
88
+ "ImageNet",
89
+ "Caltech101",
90
+ "Caltech256",
91
+ "CelebA",
92
+ "WIDERFace",
93
+ "SBDataset",
94
+ "VisionDataset",
95
+ "USPS",
96
+ "Kinetics",
97
+ "HMDB51",
98
+ "UCF101",
99
+ "Places365",
100
+ "Kitti",
101
+ "INaturalist",
102
+ "LFWPeople",
103
+ "LFWPairs",
104
+ "KittiFlow",
105
+ "Sintel",
106
+ "FlyingChairs",
107
+ "FlyingThings3D",
108
+ "HD1K",
109
+ "Food101",
110
+ "DTD",
111
+ "FER2013",
112
+ "GTSRB",
113
+ "CLEVRClassification",
114
+ "OxfordIIITPet",
115
+ "PCAM",
116
+ "Country211",
117
+ "FGVCAircraft",
118
+ "EuroSAT",
119
+ "RenderedSST2",
120
+ "Kitti2012Stereo",
121
+ "Kitti2015Stereo",
122
+ "CarlaStereo",
123
+ "Middlebury2014Stereo",
124
+ "CREStereo",
125
+ "FallingThingsStereo",
126
+ "SceneFlowStereo",
127
+ "SintelStereo",
128
+ "InStereo2k",
129
+ "ETH3DStereo",
130
+ )
131
+
132
+
133
+ # We override this module's attribute lookup (PEP 562) so that the import
134
+ # from torchvision.datasets import wrap_dataset_for_transforms_v2
135
+ # triggers the transforms v2 beta-state warning from torchvision.datapoints,
136
+ # while importing any other attribute from torchvision.datasets
137
+ # does not raise that warning.
138
+ # Ref: https://peps.python.org/pep-0562/
139
+ def __getattr__(name):
140
+ if name in ("wrap_dataset_for_transforms_v2",):
141
+ from torchvision.datapoints._dataset_wrapper import wrap_dataset_for_transforms_v2
142
+
143
+ return wrap_dataset_for_transforms_v2
144
+
145
+ raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
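
A short sketch of how the PEP 562 hook above behaves (illustrative, not part of the diff; the attribute name in the last block is deliberately made up to trigger the error path):

    from torchvision import datasets

    # Regular dataset classes resolve through the normal module namespace.
    assert hasattr(datasets, "CIFAR10")

    # This name is resolved lazily by __getattr__, which only then imports
    # torchvision.datapoints (and emits its transforms v2 beta-state warning).
    wrap = datasets.wrap_dataset_for_transforms_v2

    # Any other unknown attribute raises AttributeError as usual.
    try:
        datasets.not_a_real_dataset  # hypothetical name, used only for illustration
    except AttributeError as exc:
        print(exc)
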
torchvision/datasets/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (3.28 kB).
 
torchvision/datasets/__pycache__/_optical_flow.cpython-38.pyc ADDED
Binary file (17.6 kB).
 
torchvision/datasets/__pycache__/_stereo_matching.cpython-38.pyc ADDED
Binary file (40.2 kB).
 
torchvision/datasets/__pycache__/caltech.cpython-38.pyc ADDED
Binary file (8.1 kB).
 
torchvision/datasets/__pycache__/celeba.cpython-38.pyc ADDED
Binary file (7.14 kB).
 
torchvision/datasets/__pycache__/cifar.cpython-38.pyc ADDED
Binary file (5.81 kB).
 
torchvision/datasets/__pycache__/cityscapes.cpython-38.pyc ADDED
Binary file (8.41 kB).
 
torchvision/datasets/__pycache__/clevr.cpython-38.pyc ADDED
Binary file (4.11 kB).
 
torchvision/datasets/__pycache__/coco.cpython-38.pyc ADDED
Binary file (5.02 kB).
 
torchvision/datasets/__pycache__/country211.cpython-38.pyc ADDED
Binary file (2.79 kB).
 
torchvision/datasets/__pycache__/dtd.cpython-38.pyc ADDED
Binary file (4.33 kB).
 
torchvision/datasets/__pycache__/eurosat.cpython-38.pyc ADDED
Binary file (2.46 kB).
 
torchvision/datasets/__pycache__/fakedata.cpython-38.pyc ADDED
Binary file (2.67 kB).
 
torchvision/datasets/__pycache__/fer2013.cpython-38.pyc ADDED
Binary file (3.3 kB).
 
torchvision/datasets/__pycache__/fgvc_aircraft.cpython-38.pyc ADDED
Binary file (4.72 kB).
 
torchvision/datasets/__pycache__/flickr.cpython-38.pyc ADDED
Binary file (5.27 kB).
 
torchvision/datasets/__pycache__/flowers102.cpython-38.pyc ADDED
Binary file (4.6 kB).
 
torchvision/datasets/__pycache__/folder.cpython-38.pyc ADDED
Binary file (11.5 kB).
 
torchvision/datasets/__pycache__/food101.cpython-38.pyc ADDED
Binary file (4.35 kB).
 
torchvision/datasets/__pycache__/gtsrb.cpython-38.pyc ADDED
Binary file (3.74 kB).
 
torchvision/datasets/__pycache__/hmdb51.cpython-38.pyc ADDED
Binary file (5.64 kB).
 
torchvision/datasets/__pycache__/imagenet.cpython-38.pyc ADDED
Binary file (9.69 kB).
 
torchvision/datasets/__pycache__/inaturalist.cpython-38.pyc ADDED
Binary file (8.59 kB).
 
torchvision/datasets/__pycache__/kinetics.cpython-38.pyc ADDED
Binary file (9.66 kB).