ayousanz committed on
Commit 06cad35 · verified · 1 Parent(s): 69087c5

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50):
  1. .gitattributes +3 -0
  2. .venv/Lib/site-packages/torch/lib/cudnn_heuristic64_9.dll +3 -0
  3. .venv/Lib/site-packages/torch/lib/cudnn_ops64_9.dll +3 -0
  4. .venv/Lib/site-packages/torch/lib/sleef.lib +3 -0
  5. .venv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/__init__.py +0 -0
  6. .venv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.cpp +35 -0
  7. .venv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/timer_callgrind_template.cpp +68 -0
  8. .venv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py +907 -0
  9. .venv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h +0 -0
  10. .venv/Lib/site-packages/torch/utils/bottleneck/__init__.py +0 -0
  11. .venv/Lib/site-packages/torch/utils/bottleneck/__main__.py +230 -0
  12. .venv/Lib/site-packages/torch/utils/data/__init__.py +77 -0
  13. .venv/Lib/site-packages/torch/utils/data/_utils/__init__.py +54 -0
  14. .venv/Lib/site-packages/torch/utils/data/_utils/__pycache__/__init__.cpython-39.pyc +0 -0
  15. .venv/Lib/site-packages/torch/utils/data/_utils/__pycache__/collate.cpython-39.pyc +0 -0
  16. .venv/Lib/site-packages/torch/utils/data/_utils/__pycache__/fetch.cpython-39.pyc +0 -0
  17. .venv/Lib/site-packages/torch/utils/data/_utils/__pycache__/pin_memory.cpython-39.pyc +0 -0
  18. .venv/Lib/site-packages/torch/utils/data/_utils/__pycache__/signal_handling.cpython-39.pyc +0 -0
  19. .venv/Lib/site-packages/torch/utils/data/_utils/__pycache__/worker.cpython-39.pyc +0 -0
  20. .venv/Lib/site-packages/torch/utils/data/_utils/collate.py +398 -0
  21. .venv/Lib/site-packages/torch/utils/data/_utils/fetch.py +55 -0
  22. .venv/Lib/site-packages/torch/utils/data/_utils/pin_memory.py +108 -0
  23. .venv/Lib/site-packages/torch/utils/data/_utils/signal_handling.py +79 -0
  24. .venv/Lib/site-packages/torch/utils/data/_utils/worker.py +376 -0
  25. .venv/Lib/site-packages/torch/utils/data/backward_compatibility.py +11 -0
  26. .venv/Lib/site-packages/torch/utils/data/dataloader.py +1604 -0
  27. .venv/Lib/site-packages/torch/utils/data/datapipes/__init__.py +1 -0
  28. .venv/Lib/site-packages/torch/utils/data/datapipes/__pycache__/__init__.cpython-39.pyc +0 -0
  29. .venv/Lib/site-packages/torch/utils/data/datapipes/__pycache__/_decorator.cpython-39.pyc +0 -0
  30. .venv/Lib/site-packages/torch/utils/data/datapipes/__pycache__/_hook_iterator.cpython-39.pyc +0 -0
  31. .venv/Lib/site-packages/torch/utils/data/datapipes/__pycache__/_typing.cpython-39.pyc +0 -0
  32. .venv/Lib/site-packages/torch/utils/data/datapipes/__pycache__/datapipe.cpython-39.pyc +0 -0
  33. .venv/Lib/site-packages/torch/utils/data/datapipes/_decorator.py +213 -0
  34. .venv/Lib/site-packages/torch/utils/data/datapipes/_hook_iterator.py +279 -0
  35. .venv/Lib/site-packages/torch/utils/data/datapipes/_typing.py +486 -0
  36. .venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__init__.py +11 -0
  37. .venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/__init__.cpython-39.pyc +0 -0
  38. .venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/dataframe_wrapper.cpython-39.pyc +0 -0
  39. .venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/dataframes.cpython-39.pyc +0 -0
  40. .venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/datapipes.cpython-39.pyc +0 -0
  41. .venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/structures.cpython-39.pyc +0 -0
  42. .venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py +128 -0
  43. .venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/dataframes.py +457 -0
  44. .venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/datapipes.py +134 -0
  45. .venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/structures.py +20 -0
  46. .venv/Lib/site-packages/torch/utils/data/datapipes/datapipe.py +415 -0
  47. .venv/Lib/site-packages/torch/utils/data/datapipes/datapipe.pyi +697 -0
  48. .venv/Lib/site-packages/torch/utils/data/datapipes/gen_pyi.py +305 -0
  49. .venv/Lib/site-packages/torch/utils/data/datapipes/iter/__init__.py +65 -0
  50. .venv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/callable.cpython-39.pyc +0 -0
.gitattributes CHANGED
@@ -68,3 +68,6 @@ reference_sample_wavs/syuukovoice_200918_3_01.wav filter=lfs diff=lfs merge=lfs
68
  .venv/Lib/site-packages/torch/lib/libprotoc.lib filter=lfs diff=lfs merge=lfs -text
69
  .venv/Lib/site-packages/torch/lib/curand64_10.dll filter=lfs diff=lfs merge=lfs -text
70
  .venv/Lib/site-packages/torch/lib/cusolverMg64_11.dll filter=lfs diff=lfs merge=lfs -text
71
+ .venv/Lib/site-packages/torch/lib/cudnn_heuristic64_9.dll filter=lfs diff=lfs merge=lfs -text
72
+ .venv/Lib/site-packages/torch/lib/sleef.lib filter=lfs diff=lfs merge=lfs -text
73
+ .venv/Lib/site-packages/torch/lib/cudnn_ops64_9.dll filter=lfs diff=lfs merge=lfs -text
.venv/Lib/site-packages/torch/lib/cudnn_heuristic64_9.dll ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ee6d4831251387ab52a549df7ce7e5256272426eeef23a36d172ca8c725afba1
3
+ size 85741608
.venv/Lib/site-packages/torch/lib/cudnn_ops64_9.dll ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e26d81b20ceda0fff0fce1b60f5c4a7c0b32650afa2ab49f0ea4496816bead5b
3
+ size 107721256
.venv/Lib/site-packages/torch/lib/sleef.lib ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:55eb52de5e0e99ed7cbeeadb0b1e7523bd09278b1160145bd15e200a9df3139a
3
+ size 8862502
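The three binaries above are committed as Git LFS pointer files: the repository keeps only the small `version` / `oid sha256:` / `size` stub shown in each diff, while the actual DLL/lib payload lives in LFS storage (enabled by the new `.gitattributes` rules). As a hedged illustration, not part of this commit, a pointer like these can be read as:
```python
# Illustrative sketch only: read a Git LFS pointer file of the form shown
# above ("version ...", "oid sha256:<hash>", "size <bytes>").
def parse_lfs_pointer(path: str) -> dict:
    fields = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            if key:
                fields[key] = value
    return {
        "version": fields.get("version", ""),
        "sha256": fields.get("oid", "").removeprefix("sha256:"),
        "size": int(fields.get("size", "0")),
    }

# e.g. parse_lfs_pointer(".venv/Lib/site-packages/torch/lib/sleef.lib")
# -> {"version": "https://git-lfs.github.com/spec/v1", "sha256": "55eb...", "size": 8862502}
```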
.venv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/__init__.py ADDED
File without changes
.venv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.cpp ADDED
@@ -0,0 +1,35 @@
1
+ /* Used to collect profiles of old versions of PyTorch. */
2
+ #include <callgrind.h>
3
+ #include <pybind11/pybind11.h>
4
+
5
+ bool _valgrind_supported_platform() {
6
+ #if defined(NVALGRIND)
7
+ return false;
8
+ #else
9
+ return true;
10
+ #endif
11
+ }
12
+
13
+ void _valgrind_toggle() {
14
+ #if defined(NVALGRIND)
15
+ TORCH_CHECK(false, "Valgrind is not supported.");
16
+ #else
17
+ CALLGRIND_TOGGLE_COLLECT;
18
+ #endif
19
+ }
20
+
21
+ void _valgrind_toggle_and_dump_stats() {
22
+ #if defined(NVALGRIND)
23
+ TORCH_CHECK(false, "Valgrind is not supported.");
24
+ #else
25
+ // NB: See note in Module.cpp
26
+ CALLGRIND_TOGGLE_COLLECT;
27
+ CALLGRIND_DUMP_STATS;
28
+ #endif
29
+ }
30
+
31
+ PYBIND11_MODULE(callgrind_bindings, m) {
32
+ m.def("_valgrind_supported_platform", &_valgrind_supported_platform);
33
+ m.def("_valgrind_toggle", &_valgrind_toggle);
34
+ m.def("_valgrind_toggle_and_dump_stats", &_valgrind_toggle_and_dump_stats);
35
+ }
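These compat bindings are only JIT-compiled (via `cpp_jit.get_compat_bindings()`) when `torch._C` does not already expose the Valgrind symbols; either way, the generated benchmark script drives the same three functions. A minimal sketch of that call pattern, assuming a build where the symbols are present on `torch._C`:
```python
# Sketch of how the three bindings above are driven by the generated
# run script (see _construct_script in timer_interface.py below).
import torch._C as callgrind_bindings  # or the JIT-compiled compat module

if callgrind_bindings._valgrind_supported_platform():
    callgrind_bindings._valgrind_toggle()                   # start collection
    # ... execute the statement being measured ...
    callgrind_bindings._valgrind_toggle_and_dump_stats()    # stop and dump profile
```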
.venv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/timer_callgrind_template.cpp ADDED
@@ -0,0 +1,68 @@
1
+ /* C++ template for Timer.collect_callgrind
2
+
3
+ This template will be consumed by `cpp_jit.py`, and will replace:
4
+ `GLOBAL_SETUP_TEMPLATE_LOCATION`,
5
+ `SETUP_TEMPLATE_LOCATION`
6
+ and
7
+ `STMT_TEMPLATE_LOCATION`
8
+ sections with user provided statements.
9
+ */
10
+
11
+ #include <c10/util/irange.h>
12
+ #include <callgrind.h>
13
+ #include <torch/torch.h>
14
+
15
+ #include <string>
16
+
17
+ // Global setup. (e.g. #includes)
18
+ // GLOBAL_SETUP_TEMPLATE_LOCATION
19
+
20
+ #if defined(NVALGRIND)
21
+ static_assert(false);
22
+ #endif
23
+
24
+ int main(int argc, char* argv[]) {
25
+ // This file should only be called inside of `Timer`, so we can adopt a
26
+ // very simple and rigid argument parsing scheme.
27
+ TORCH_CHECK(argc == 9);
28
+ TORCH_CHECK(std::string(argv[1]) == "--number");
29
+ auto number = std::stoi(argv[2]);
30
+
31
+ TORCH_CHECK(
32
+ std::string(argv[3]) == "--number-warmup" ||
33
+ std::string(argv[3]) == "--number_warmup");
34
+ auto number_warmup = std::stoi(argv[4]);
35
+
36
+ TORCH_CHECK(std::string(argv[5]) == "--repeats");
37
+ auto repeats = std::stoi(argv[6]);
38
+
39
+ TORCH_CHECK(
40
+ std::string(argv[7]) == "--number-threads" ||
41
+ std::string(argv[7]) == "--number_threads");
42
+ auto number_threads = std::stoi(argv[8]);
43
+ torch::set_num_threads(number_threads);
44
+
45
+ // Setup
46
+ // SETUP_TEMPLATE_LOCATION
47
+
48
+ // Warmup
49
+ for (const auto i : c10::irange(number_warmup)) {
50
+ (void)i;
51
+ // STMT_TEMPLATE_LOCATION
52
+ }
53
+
54
+ // Main loop
55
+ for (const auto repeat : c10::irange(repeats)) {
56
+ (void)repeat;
57
+ CALLGRIND_TOGGLE_COLLECT;
58
+
59
+ for (const auto i : c10::irange(number)) {
60
+ (void)i;
61
+ // STMT_TEMPLATE_LOCATION
62
+ }
63
+
64
+ // NB: See note in Module.cpp
65
+ CALLGRIND_TOGGLE_COLLECT;
66
+ CALLGRIND_DUMP_STATS;
67
+ }
68
+ }
.venv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py ADDED
@@ -0,0 +1,907 @@
1
+ """Intermediate layer between `Timer` and `valgrind`."""
2
+ import collections
3
+ import enum
4
+ import dataclasses
5
+ import itertools as it
6
+ import os
7
+ import pickle
8
+ import re
9
+ import shutil
10
+ import subprocess
11
+ import sys
12
+ import textwrap
13
+ from typing import (
14
+ cast, Any, Callable, DefaultDict, Dict, Iterator, List, NamedTuple,
15
+ Optional, Tuple, Union, TYPE_CHECKING)
16
+
17
+ import torch
18
+ from torch.utils.benchmark.utils import common, cpp_jit
19
+ from torch.utils.benchmark.utils._stubs import CallgrindModuleType
20
+ import operator
21
+
22
+
23
+ __all__ = ["FunctionCount", "FunctionCounts", "CallgrindStats", "CopyIfCallgrind"]
24
+
25
+
26
+ if TYPE_CHECKING:
27
+ CompletedProcessType = subprocess.CompletedProcess[str]
28
+ else:
29
+ CompletedProcessType = subprocess.CompletedProcess
30
+
31
+
32
+ class FunctionCount(NamedTuple):
33
+ # TODO(#105471): Rename the count field
34
+ count: int # type: ignore[assignment]
35
+ function: str
36
+
37
+
38
+ @dataclasses.dataclass(repr=False, eq=False, frozen=True)
39
+ class FunctionCounts:
40
+ """Container for manipulating Callgrind results.
41
+
42
+ It supports:
43
+ 1) Addition and subtraction to combine or diff results.
44
+ 2) Tuple-like indexing.
45
+ 3) A `denoise` function which strips CPython calls which are known to
46
+ be non-deterministic and quite noisy.
47
+ 4) Two higher order methods (`filter` and `transform`) for custom
48
+ manipulation.
49
+ """
50
+ _data: Tuple[FunctionCount, ...]
51
+ inclusive: bool
52
+ truncate_rows: bool = True
53
+
54
+ # For normal use, torch._tensor_str.PRINT_OPTS.linewidth determines
55
+ # the print settings. This is simply to allow hermetic unit tests.
56
+ _linewidth: Optional[int] = None
57
+
58
+ def __iter__(self) -> Iterator[FunctionCount]:
59
+ yield from self._data
60
+
61
+ def __len__(self) -> int:
62
+ return len(self._data)
63
+
64
+ def __getitem__(self, item: Any) -> Union[FunctionCount, "FunctionCounts"]:
65
+ data: Union[FunctionCount, Tuple[FunctionCount, ...]] = self._data[item]
66
+ return (
67
+ FunctionCounts(cast(Tuple[FunctionCount, ...], data), self.inclusive, truncate_rows=False)
68
+ if isinstance(data, tuple) else data
69
+ )
70
+
71
+ def __repr__(self) -> str:
72
+ count_len = 0
73
+ for c, _ in self:
74
+ # Account for sign in string length.
75
+ count_len = max(count_len, len(str(c)) + int(c < 0))
76
+
77
+ lines = []
78
+ linewidth = self._linewidth or torch._tensor_str.PRINT_OPTS.linewidth
79
+ fn_str_len = max(linewidth - count_len - 4, 40)
80
+ for c, fn in self:
81
+ if len(fn) > fn_str_len:
82
+ left_len = int((fn_str_len - 5) // 2)
83
+ fn = fn[:left_len] + " ... " + fn[-(fn_str_len - left_len - 5):]
84
+ lines.append(f" {c:>{count_len}} {fn}")
85
+
86
+ if self.truncate_rows and len(lines) > 18:
87
+ lines = lines[:9] + ["...".rjust(count_len + 2)] + lines[-9:]
88
+
89
+ if not self.inclusive:
90
+ lines.extend(["", f"Total: {self.sum()}"])
91
+
92
+ return "\n".join([super().__repr__()] + lines)
93
+
94
+ def __add__(
95
+ self,
96
+ other: "FunctionCounts",
97
+ ) -> "FunctionCounts":
98
+ return self._merge(other, lambda c: c)
99
+
100
+ def __sub__(
101
+ self,
102
+ other: "FunctionCounts",
103
+ ) -> "FunctionCounts":
104
+ return self._merge(other, operator.neg)
105
+
106
+ def __mul__(self, other: Union[int, float]) -> "FunctionCounts":
107
+ return self._from_dict({
108
+ fn: int(c * other) for c, fn in self._data
109
+ }, self.inclusive)
110
+
111
+ def transform(self, map_fn: Callable[[str], str]) -> "FunctionCounts":
112
+ """Apply `map_fn` to all of the function names.
113
+
114
+ This can be used to regularize function names (e.g. stripping irrelevant
115
+ parts of the file path), coalesce entries by mapping multiple functions
116
+ to the same name (in which case the counts are added together), etc.
117
+ """
118
+ counts: DefaultDict[str, int] = collections.defaultdict(int)
119
+ for c, fn in self._data:
120
+ counts[map_fn(fn)] += c
121
+
122
+ return self._from_dict(counts, self.inclusive)
123
+
124
+ def filter(self, filter_fn: Callable[[str], bool]) -> "FunctionCounts":
125
+ """Keep only the elements where `filter_fn` applied to function name returns True."""
126
+ return FunctionCounts(tuple(i for i in self if filter_fn(i.function)), self.inclusive)
127
+
128
+ def sum(self) -> int:
129
+ return sum(c for c, _ in self)
130
+
131
+ def denoise(self) -> "FunctionCounts":
132
+ """Remove known noisy instructions.
133
+
134
+ Several instructions in the CPython interpreter are rather noisy. These
135
+ instructions involve unicode to dictionary lookups which Python uses to
136
+ map variable names. FunctionCounts is generally a content agnostic
137
+ container, however this is sufficiently important for obtaining
138
+ reliable results to warrant an exception."""
139
+ return self.filter(lambda fn: "dictobject.c:lookdict_unicode" not in fn)
140
+
141
+ def _merge(
142
+ self,
143
+ second: "FunctionCounts",
144
+ merge_fn: Callable[[int], int]
145
+ ) -> "FunctionCounts":
146
+ assert self.inclusive == second.inclusive, "Cannot merge inclusive and exclusive counts."
147
+ counts: DefaultDict[str, int] = collections.defaultdict(int)
148
+ for c, fn in self:
149
+ counts[fn] += c
150
+
151
+ for c, fn in second:
152
+ counts[fn] += merge_fn(c)
153
+
154
+ return self._from_dict(counts, self.inclusive)
155
+
156
+ @staticmethod
157
+ def _from_dict(counts: Dict[str, int], inclusive: bool) -> "FunctionCounts":
158
+ flat_counts = (FunctionCount(c, fn) for fn, c in counts.items() if c)
159
+ return FunctionCounts(tuple(sorted(flat_counts, reverse=True)), inclusive)
160
+
161
+
162
+ @dataclasses.dataclass(repr=False, eq=False, frozen=True)
163
+ class CallgrindStats:
164
+ """Top level container for Callgrind results collected by Timer.
165
+
166
+ Manipulation is generally done using the FunctionCounts class, which is
167
+ obtained by calling `CallgrindStats.stats(...)`. Several convenience
168
+ methods are provided as well; the most significant is
169
+ `CallgrindStats.as_standardized()`.
170
+ """
171
+ task_spec: common.TaskSpec
172
+ number_per_run: int
173
+ built_with_debug_symbols: bool
174
+ baseline_inclusive_stats: FunctionCounts
175
+ baseline_exclusive_stats: FunctionCounts
176
+ stmt_inclusive_stats: FunctionCounts
177
+ stmt_exclusive_stats: FunctionCounts
178
+ stmt_callgrind_out: Optional[str]
179
+
180
+ def __repr__(self) -> str:
181
+ newline = "\n" # `\` cannot appear in fstring code section.
182
+ base_stats = self.baseline_exclusive_stats
183
+ output = f"""
184
+ {super().__repr__()}
185
+ {self.task_spec.summarize()}
186
+ {'':>25}All{'':>10}Noisy symbols removed
187
+ Instructions: {self.counts(denoise=False):>12}{'':>15}{self.counts(denoise=True):>12}
188
+ Baseline: {base_stats.sum():>12}{'':>15}{base_stats.denoise().sum():>12}
189
+ {self.number_per_run} runs per measurement, {self.task_spec.num_threads} thread{'s' if self.task_spec.num_threads > 1 else ''}
190
+ """.strip()
191
+ if not self.built_with_debug_symbols:
192
+ output += textwrap.dedent("""
193
+ Warning: PyTorch was not built with debug symbols.
194
+ Source information may be limited. Rebuild with
195
+ REL_WITH_DEB_INFO=1 for more detailed results.""")
196
+ return output
197
+
198
+ def stats(self, inclusive: bool = False) -> FunctionCounts:
199
+ """Returns detailed function counts.
200
+
201
+ Conceptually, the FunctionCounts returned can be thought of as a tuple
202
+ of (count, path_and_function_name) tuples.
203
+
204
+ `inclusive` matches the semantics of callgrind. If True, the counts
205
+ include instructions executed by children. `inclusive=True` is useful
206
+ for identifying hot spots in code; `inclusive=False` is useful for
207
+ reducing noise when diffing counts from two different runs. (See
208
+ CallgrindStats.delta(...) for more details)
209
+ """
210
+ return self.stmt_inclusive_stats if inclusive else self.stmt_exclusive_stats
211
+
212
+ def counts(self, *, denoise: bool = False) -> int:
213
+ """Returns the total number of instructions executed.
214
+
215
+ See `FunctionCounts.denoise()` for an explanation of the `denoise` arg.
216
+ """
217
+ stats = self.stmt_exclusive_stats
218
+ return (stats.denoise() if denoise else stats).sum()
219
+
220
+ # FIXME: Once 3.7 is the minimum version, type annotate `other` per PEP 563
221
+ def delta(
222
+ self,
223
+ other: "CallgrindStats",
224
+ inclusive: bool = False,
225
+ ) -> FunctionCounts:
226
+ """Diff two sets of counts.
227
+
228
+ One common reason to collect instruction counts is to determine the
229
+ the effect that a particular change will have on the number of instructions
230
+ needed to perform some unit of work. If a change increases that number, the
231
+ next logical question is "why". This generally involves looking at what part
232
+ if the code increased in instruction count. This function automates that
233
+ process so that one can easily diff counts on both an inclusive and
234
+ exclusive basis.
235
+ """
236
+ return self.stats(inclusive=inclusive) - other.stats(inclusive=inclusive)
237
+
238
+ def as_standardized(self) -> "CallgrindStats":
239
+ """Strip library names and some prefixes from function strings.
240
+
241
+ When comparing two different sets of instruction counts, on stumbling
242
+ block can be path prefixes. Callgrind includes the full filepath
243
+ when reporting a function (as it should). However, this can cause
244
+ issues when diffing profiles. If a key component such as Python
245
+ or PyTorch was built in separate locations in the two profiles, which
246
+ can result in something resembling::
247
+
248
+ 23234231 /tmp/first_build_dir/thing.c:foo(...)
249
+ 9823794 /tmp/first_build_dir/thing.c:bar(...)
250
+ ...
251
+ 53453 .../aten/src/Aten/...:function_that_actually_changed(...)
252
+ ...
253
+ -9823794 /tmp/second_build_dir/thing.c:bar(...)
254
+ -23234231 /tmp/second_build_dir/thing.c:foo(...)
255
+
256
+ Stripping prefixes can ameliorate this issue by regularizing the
257
+ strings and causing better cancellation of equivalent call sites
258
+ when diffing.
259
+ """
260
+ def strip(stats: FunctionCounts) -> FunctionCounts:
261
+ transforms = (
262
+ # PyTorch may have been built in different locations.
263
+ (r"^.+build/\.\./", "build/../"),
264
+ (r"^.+/" + re.escape("build/aten/"), "build/aten/"),
265
+
266
+ # "Python" and "Objects" come from CPython.
267
+ (r"^.+/" + re.escape("Python/"), "Python/"),
268
+ (r"^.+/" + re.escape("Objects/"), "Objects/"),
269
+
270
+ # Strip library name. e.g. `libtorch.so`
271
+ (r"\s\[.+\]$", ""),
272
+ )
273
+
274
+ for before, after in transforms:
275
+ stats = stats.transform(lambda fn: re.sub(before, after, fn))
276
+
277
+ return stats
278
+
279
+ return CallgrindStats(
280
+ task_spec=self.task_spec,
281
+ number_per_run=self.number_per_run,
282
+ built_with_debug_symbols=self.built_with_debug_symbols,
283
+ baseline_inclusive_stats=strip(self.baseline_inclusive_stats),
284
+ baseline_exclusive_stats=strip(self.baseline_exclusive_stats),
285
+ stmt_inclusive_stats=strip(self.stmt_inclusive_stats),
286
+ stmt_exclusive_stats=strip(self.stmt_exclusive_stats),
287
+
288
+ # `as_standardized` will change symbol names, so the contents will
289
+ # no longer map directly to `callgrind.out`
290
+ stmt_callgrind_out=None,
291
+ )
292
+
293
+
294
+ class Serialization(enum.Enum):
295
+ PICKLE = 0
296
+ TORCH = 1
297
+ TORCH_JIT = 2
298
+
299
+
300
+ _GLOBALS_ALLOWED_TYPES: Dict[Serialization, Tuple[Any, ...]] = {
301
+ Serialization.PICKLE: (str, bytes, bool, int, float, complex),
302
+ Serialization.TORCH_JIT: (torch.jit.ScriptFunction, torch.jit.ScriptModule),
303
+ Serialization.TORCH: (torch.nn.Module,),
304
+ }
305
+
306
+
307
+ class CopyIfCallgrind:
308
+ """Signal that a global may be replaced with a deserialized copy.
309
+
310
+ See `GlobalsBridge` for why this matters.
311
+ """
312
+ def __init__(self, value: Any, *, setup: Optional[str] = None):
313
+ for method, supported_types in _GLOBALS_ALLOWED_TYPES.items():
314
+ if any(isinstance(value, t) for t in supported_types):
315
+ self._value: Any = value
316
+ self._setup: Optional[str] = setup
317
+ self._serialization: Serialization = method
318
+ break
319
+ else:
320
+ supported_str = "\n".join([
321
+ getattr(t, "__name__", repr(t))
322
+ for t in it.chain(_GLOBALS_ALLOWED_TYPES.values())])
323
+
324
+ raise ValueError(
325
+ f"Unsupported type: {type(value)}\n"
326
+ f"`collect_callgrind` restricts globals to the following types:\n"
327
+ f"{textwrap.indent(supported_str, ' ')}"
328
+ )
329
+
330
+ @property
331
+ def value(self) -> Any:
332
+ return self._value
333
+
334
+ @property
335
+ def setup(self) -> Optional[str]:
336
+ return self._setup
337
+
338
+ @property
339
+ def serialization(self) -> Serialization:
340
+ return self._serialization
341
+
342
+ @staticmethod
343
+ def unwrap_all(globals: Dict[str, Any]) -> Dict[str, Any]:
344
+ return {
345
+ k: (v.value if isinstance(v, CopyIfCallgrind) else v)
346
+ for k, v in globals.items()
347
+ }
348
+
349
+
350
+ class GlobalsBridge:
351
+ """Handle the transfer of (certain) globals when collecting Callgrind statistics.
352
+
353
+ Key takeaway: Any globals passed must be wrapped in `CopyIfCallgrind` to
354
+ work with `Timer.collect_callgrind`.
355
+
356
+ Consider the following code snippet:
357
+ ```
358
+ import pickle
359
+ import timeit
360
+
361
+ class Counter:
362
+ value = 0
363
+
364
+ def __call__(self):
365
+ self.value += 1
366
+
367
+ counter = Counter()
368
+ timeit.Timer("counter()", globals={"counter": counter}).timeit(10)
369
+ print(counter.value) # 10
370
+
371
+ timeit.Timer(
372
+ "counter()",
373
+ globals={"counter": pickle.loads(pickle.dumps(counter))}
374
+ ).timeit(20)
375
+ print(counter.value) # Still 10
376
+ ```
377
+
378
+ In the first case, `stmt` is executed using the objects in `globals`;
379
+ however, the addition of serialization and deserialization changes the
380
+ semantics and may meaningfully change behavior.
381
+
382
+ This is a practical consideration when collecting Callgrind statistics.
383
+ Unlike `exec` based execution (which `timeit` uses under the hood) which
384
+ can share in-memory data structures with the caller, Callgrind collection
385
+ requires an entirely new process in order to run under Valgrind. This means
386
+ that any data structures used for statement execution will have to be
387
+ serialized and deserialized in the subprocess.
388
+
389
+ In order to avoid surprising semantics from (user invisible) process
390
+ boundaries, what can be passed through `globals` is severely restricted
391
+ for `Timer.collect_callgrind`. It is expected that most setup should be
392
+ achievable (albeit perhaps less ergonomically) by passing a `setup`
393
+ string.
394
+
395
+ There are, however, exceptions. One such class are TorchScripted functions.
396
+ Because they require a concrete file with source code it is not possible
397
+ to define them using a `setup` string. Another group are torch.nn.Modules,
398
+ whose construction can be complex and prohibitively cumbersome to coerce
399
+ into a `setup` string. Finally, most builtin types are sufficiently well
400
+ behaved and sufficiently common to warrant allowing as well. (e.g.
401
+ `globals={"n": 1}` is very convenient.)
402
+
403
+ Fortunately, all have well defined serialization semantics. This class
404
+ is responsible for enabling the Valgrind subprocess to use elements in
405
+ `globals` so long as they are an allowed type.
406
+
407
+ Caveats:
408
+ The user is required to acknowledge this serialization by wrapping
409
+ elements in `globals` with `CopyIfCallgrind`.
410
+
411
+ While ScriptFunction and ScriptModule are expected to save and load
412
+ quite robustly, it is up to the user to ensure that an nn.Module can
413
+ un-pickle successfully.
414
+
415
+ `torch.Tensor` and `np.ndarray` are deliberately excluded. The
416
+ serialization/deserialization process perturbs the representation of a
417
+ tensor in ways that could result in incorrect measurements. For example,
418
+ if a tensor lives in pinned CPU memory, this fact would not be preserved
419
+ by a dump, and that will in turn change the performance of certain CUDA
420
+ operations.
421
+ """
422
+
423
+ def __init__(self, globals: Dict[str, Any], data_dir: str) -> None:
424
+ self._globals: Dict[str, CopyIfCallgrind] = {}
425
+ self._data_dir = data_dir
426
+ if not os.path.exists(data_dir):
427
+ os.mkdir(data_dir)
428
+
429
+ if globals.get("torch", torch) is not torch:
430
+ raise ValueError("`collect_callgrind` does not support mocking out `torch`.")
431
+
432
+ for name, value in globals.items():
433
+ if name in ("torch", "__builtins__"):
434
+ # Torch will be imported by the collection script, and
435
+ # __builtins__ is added by Timer.
436
+ continue
437
+
438
+ if not isinstance(value, CopyIfCallgrind):
439
+ raise ValueError(
440
+ "`collect_callgrind` requires that globals be wrapped in "
441
+ "`CopyIfCallgrind` so that serialization is explicit."
442
+ )
443
+
444
+ self._globals[name] = value
445
+
446
+ def construct(self) -> str:
447
+ load_lines = []
448
+ for name, wrapped_value in self._globals.items():
449
+ if wrapped_value.setup is not None:
450
+ load_lines.append(textwrap.dedent(wrapped_value.setup))
451
+
452
+ if wrapped_value.serialization == Serialization.PICKLE:
453
+ path = os.path.join(self._data_dir, f"{name}.pkl")
454
+ load_lines.append(
455
+ f"with open({repr(path)}, 'rb') as f:\n {name} = pickle.load(f)")
456
+ with open(path, "wb") as f:
457
+ pickle.dump(wrapped_value.value, f)
458
+
459
+ elif wrapped_value.serialization == Serialization.TORCH:
460
+ path = os.path.join(self._data_dir, f"{name}.pt")
461
+ load_lines.append(f"{name} = torch.load({repr(path)})")
462
+ torch.save(wrapped_value.value, path)
463
+
464
+ elif wrapped_value.serialization == Serialization.TORCH_JIT:
465
+ path = os.path.join(self._data_dir, f"{name}.pt")
466
+ load_lines.append(f"{name} = torch.jit.load({repr(path)})")
467
+ with open(path, "wb") as f:
468
+ torch.jit.save(wrapped_value.value, f) # type: ignore[no-untyped-call]
469
+
470
+ else:
471
+ raise NotImplementedError(
472
+ f"Unknown serialization method: {wrapped_value.serialization}")
473
+
474
+ return "\n".join(load_lines)
475
+
476
+
477
+ class _ValgrindWrapper:
478
+ def __init__(self) -> None:
479
+ self._bindings_module: Optional[CallgrindModuleType] = None
480
+ valgrind_symbols = (
481
+ "_valgrind_supported_platform",
482
+ "_valgrind_toggle",
483
+ "_valgrind_toggle_and_dump_stats",
484
+ )
485
+ if all(hasattr(torch._C, symbol) for symbol in valgrind_symbols):
486
+ self._supported_platform: bool = torch._C._valgrind_supported_platform()
487
+
488
+ else:
489
+ print("Callgrind bindings are not present in `torch._C`. JIT-ing bindings.")
490
+ self._bindings_module = cpp_jit.get_compat_bindings()
491
+ assert all(hasattr(self._bindings_module, symbol) for symbol in valgrind_symbols)
492
+ self._supported_platform = self._bindings_module._valgrind_supported_platform()
493
+
494
+ self._commands_available: Dict[str, bool] = {}
495
+ if self._supported_platform:
496
+ # Only bother checking on supported platforms.
497
+ for cmd in ("valgrind", "callgrind_control", "callgrind_annotate"):
498
+ self._commands_available[cmd] = not subprocess.run(
499
+ ["which", cmd],
500
+ capture_output=True,
501
+ check=False,
502
+ ).returncode
503
+
504
+ self._build_type: Optional[str] = None
505
+ build_search = re.search("BUILD_TYPE=(.+),", torch.__config__.show()) # type: ignore[no-untyped-call]
506
+ if build_search is not None:
507
+ self._build_type = build_search.groups()[0].split(",")[0]
508
+
509
+ def _validate(self) -> None:
510
+ if not self._supported_platform:
511
+ raise OSError("Valgrind is not supported on this platform.")
512
+
513
+ missing_cmds = [cmd for cmd, available in self._commands_available.items() if not available]
514
+ if missing_cmds:
515
+ raise OSError("Missing: " + ", ".join(missing_cmds))
516
+
517
+ def collect_callgrind(
518
+ self,
519
+ task_spec: common.TaskSpec,
520
+ globals: Dict[str, Any],
521
+ *,
522
+ number: int,
523
+ repeats: int,
524
+ collect_baseline: bool,
525
+ is_python: bool,
526
+ retain_out_file: bool,
527
+ ) -> Tuple[CallgrindStats, ...]:
528
+ """Collect stats, and attach a reference run which can be used to filter interpreter overhead."""
529
+ self._validate()
530
+ assert is_python or not collect_baseline
531
+
532
+ *task_stats, baseline_stats = self._invoke(
533
+ task_spec=task_spec,
534
+ globals=globals,
535
+ number=number,
536
+ repeats=repeats,
537
+ collect_baseline=collect_baseline,
538
+ is_python=is_python,
539
+ retain_out_file=retain_out_file,
540
+ )
541
+ assert len(task_stats) == repeats
542
+
543
+ return tuple(
544
+ CallgrindStats(
545
+ task_spec=task_spec,
546
+ number_per_run=number,
547
+ built_with_debug_symbols=self._build_type == "RelWithDebInfo",
548
+ baseline_inclusive_stats=baseline_stats[0],
549
+ baseline_exclusive_stats=baseline_stats[1],
550
+ stmt_inclusive_stats=stmt_inclusive_stats,
551
+ stmt_exclusive_stats=stmt_exclusive_stats,
552
+ stmt_callgrind_out=out_contents,
553
+ )
554
+ for stmt_inclusive_stats, stmt_exclusive_stats, out_contents in task_stats
555
+ )
556
+
557
+ def _invoke(
558
+ self,
559
+ *,
560
+ task_spec: common.TaskSpec,
561
+ globals: Dict[str, Any],
562
+ number: int,
563
+ repeats: int,
564
+ collect_baseline: bool,
565
+ is_python: bool,
566
+ retain_out_file: bool,
567
+ ) -> Tuple[Tuple[FunctionCounts, FunctionCounts, Optional[str]], ...]:
568
+ """Core invocation method for Callgrind collection.
569
+
570
+ Valgrind operates by effectively replacing the CPU with an emulated
571
+ version which allows it to instrument any code at the cost of severe
572
+ performance degradation. This has the practical effect that in order
573
+ to collect Callgrind statistics, a new process has to be created
574
+ running under `valgrind`. The steps for this process are:
575
+
576
+ 1) Create a scratch directory.
577
+ 2) Codegen a run script. (_ValgrindWrapper._construct_script)
578
+ Inside the run script:
579
+ * Validate that Python and torch match the parent process
580
+ * Validate that it is indeed running under valgrind
581
+ * Execute `setup` and warm up `stmt`
582
+ * Begin collecting stats
583
+ * Run the `stmt` loop
584
+ * Stop collecting stats
585
+ 3) Parse the run results.
586
+ 4) Cleanup the scratch directory.
587
+ """
588
+ working_dir = common._make_temp_dir(prefix="callgrind")
589
+ data_dir = os.path.join(working_dir, "data")
590
+ script_file = os.path.join(working_dir, "timer_callgrind.py")
591
+ callgrind_out = os.path.join(working_dir, "callgrind.out")
592
+ error_log = os.path.join(working_dir, "error.txt")
593
+ stat_log = os.path.join(working_dir, "callgrind_stat.txt")
594
+ stdout_stderr_log = os.path.join(working_dir, "stdout_stderr.log")
595
+
596
+ def run(args: List[str], **kwargs: Any) -> Tuple[CompletedProcessType, str]:
597
+ # https://thraxil.org/users/anders/posts/2008/03/13/Subprocess-Hanging-PIPE-is-your-enemy/
598
+ f_stdout_stderr = open(stdout_stderr_log, "wb")
599
+ try:
600
+ invocation = subprocess.run(
601
+ args,
602
+ stdout=f_stdout_stderr,
603
+ stderr=subprocess.STDOUT,
604
+ **kwargs,
605
+ )
606
+ with open(stdout_stderr_log) as f:
607
+ return invocation, f.read()
608
+ finally:
609
+ f_stdout_stderr.close()
610
+
611
+ try:
612
+ if is_python:
613
+ if self._bindings_module is not None:
614
+ shutil.copy(
615
+ self._bindings_module.__file__,
616
+ os.path.join(working_dir, os.path.split(self._bindings_module.__file__)[1])
617
+ )
618
+
619
+ script_file = os.path.join(working_dir, "timer_callgrind.py")
620
+ with open(script_file, "w") as f:
621
+ f.write(self._construct_script(
622
+ task_spec,
623
+ globals=GlobalsBridge(globals, data_dir),
624
+ number=number,
625
+ repeats=repeats,
626
+ collect_baseline=collect_baseline,
627
+ error_log=error_log,
628
+ stat_log=stat_log,
629
+ bindings=self._bindings_module))
630
+
631
+ run_loop_cmd = ["python", script_file]
632
+ else:
633
+ assert not collect_baseline
634
+ run_loop_exec = cpp_jit.compile_callgrind_template(
635
+ stmt=task_spec.stmt,
636
+ setup=task_spec.setup,
637
+ global_setup=task_spec.global_setup,
638
+ )
639
+ run_loop_cmd = [
640
+ run_loop_exec,
641
+ "--number", str(number),
642
+ "--number-warmup", str(min(number, 10)),
643
+ "--repeats", str(repeats),
644
+ "--number-threads", str(task_spec.num_threads),
645
+ ]
646
+
647
+ valgrind_invocation, valgrind_invocation_output = run([
648
+ "valgrind",
649
+ "--tool=callgrind",
650
+ f"--callgrind-out-file={callgrind_out}",
651
+ "--dump-line=yes",
652
+ "--dump-instr=yes",
653
+ "--instr-atstart=yes",
654
+ "--collect-atstart=no",
655
+ ] + run_loop_cmd)
656
+
657
+ if valgrind_invocation.returncode:
658
+ error_report = ""
659
+ if os.path.exists(error_log):
660
+ with open(error_log) as f:
661
+ error_report = f.read()
662
+ if not error_report:
663
+ error_report = "Unknown error.\n" + valgrind_invocation_output
664
+
665
+ raise OSError(f"Failed to collect callgrind profile:\n{error_report}")
666
+
667
+ def parse_output(fpath: str, inclusive: bool) -> FunctionCounts:
668
+ annotate_invocation, annotate_invocation_output = run([
669
+ "callgrind_annotate",
670
+ f"--inclusive={'yes' if inclusive else 'no'}",
671
+ "--threshold=100",
672
+ "--show-percs=no",
673
+ fpath
674
+ ], check=True)
675
+
676
+ total_pattern = re.compile(r"^([0-9,]+)\s+PROGRAM TOTALS")
677
+ begin_pattern = re.compile(r"Ir\s+file:function")
678
+ function_pattern = re.compile(r"^\s*([0-9,]+)\s+(.+:.+)$")
679
+
680
+ class ScanState(enum.Enum):
681
+ SCANNING_FOR_TOTAL = 0
682
+ SCANNING_FOR_START = 1
683
+ PARSING = 2
684
+
685
+ scan_state = ScanState.SCANNING_FOR_TOTAL
686
+ fn_counts = []
687
+ for l in annotate_invocation_output.splitlines(keepends=False):
688
+ if scan_state == ScanState.SCANNING_FOR_TOTAL:
689
+ total_match = total_pattern.match(l)
690
+ if total_match:
691
+ program_totals = int(total_match.groups()[0].replace(",", ""))
692
+ scan_state = ScanState.SCANNING_FOR_START
693
+
694
+ elif scan_state == ScanState.SCANNING_FOR_START:
695
+ if begin_pattern.match(l):
696
+ scan_state = ScanState.PARSING
697
+
698
+ else:
699
+ assert scan_state == ScanState.PARSING
700
+ fn_match = function_pattern.match(l)
701
+ if fn_match:
702
+ ir_str, file_function = fn_match.groups()
703
+ ir = int(ir_str.replace(",", ""))
704
+ if ir == program_totals: # type: ignore[possibly-undefined]
705
+ # Callgrind includes some top level red herring symbols when
706
+ # a program dumps multiple profiles.
707
+ continue
708
+ fn_counts.append(FunctionCount(ir, file_function))
709
+
710
+ elif re.match(r"-+", l):
711
+ # Ignore heading separator lines.
712
+ continue
713
+
714
+ else:
715
+ break
716
+
717
+ assert scan_state == ScanState.PARSING, f"Failed to parse {fpath}"
718
+ return FunctionCounts(tuple(sorted(fn_counts, reverse=True)), inclusive=inclusive)
719
+
720
+ def read_results(i: int) -> Tuple[FunctionCounts, FunctionCounts, Optional[str]]:
721
+ if i == repeats and not collect_baseline:
722
+ # Null baseline.
723
+ return (
724
+ FunctionCounts((), inclusive=True),
725
+ FunctionCounts((), inclusive=False),
726
+ None,
727
+ )
728
+
729
+ fpath = f"{callgrind_out}.{i + 1}" # Callgrind one-indexes files.
730
+ callgrind_out_contents: Optional[str] = None
731
+ if retain_out_file:
732
+ with open(fpath) as f:
733
+ callgrind_out_contents = f.read()
734
+
735
+ return (
736
+ parse_output(fpath, inclusive=True),
737
+ parse_output(fpath, inclusive=False),
738
+ callgrind_out_contents
739
+ )
740
+
741
+ return tuple(read_results(i) for i in range(repeats + 1))
742
+ finally:
743
+ shutil.rmtree(working_dir)
744
+
745
+ @staticmethod
746
+ def _construct_script(
747
+ task_spec: common.TaskSpec,
748
+ globals: GlobalsBridge,
749
+ *,
750
+ number: int,
751
+ repeats: int,
752
+ collect_baseline: bool,
753
+ error_log: str,
754
+ stat_log: str,
755
+ bindings: Optional[CallgrindModuleType],
756
+ ) -> str:
757
+ def block_stmt(stmt: str, indent: int = 0) -> str:
758
+ """Partially unroll benchmark loop.
759
+
760
+ The naive template looks something like:
761
+ "for _ in range({number}): {stmt}"
762
+
763
+ However a loop in Python is surprisingly expensive, and significantly
764
+ increases the number of background Python instructions. So instead we
765
+ partially unroll the loops, with a block size of 100 chosen to keep
766
+ the instruction overhead from `range` low while also not ballooning
767
+ the size of the generated file.
768
+ """
769
+ block_size = 100
770
+ loop_count = number // block_size
771
+ if loop_count == 1:
772
+ # There is no point in having `for _ in range(1): ...` rather
773
+ # than just `...`, and this lets us save shave a few background
774
+ # instructions.
775
+ loop_count = 0
776
+ remainder = number - block_size * loop_count
777
+ blocked_stmt = ""
778
+
779
+ if loop_count:
780
+ unrolled_stmts = textwrap.indent("\n".join([stmt] * block_size), " " * 4)
781
+ blocked_stmt += f"for _ in range({loop_count}):\n{unrolled_stmts}\n"
782
+
783
+ if remainder:
784
+ blocked_stmt += "\n".join([stmt] * remainder)
785
+
786
+ return textwrap.indent(blocked_stmt, " " * indent)
787
+
788
+ pass_baseline = (
789
+ "callgrind_bindings._valgrind_toggle()\n"
790
+ f"{block_stmt('pass')}\n"
791
+ "callgrind_bindings._valgrind_toggle_and_dump_stats()"
792
+ )
793
+
794
+ return textwrap.dedent(r"""
795
+ import gc
796
+ import os
797
+ import pickle
798
+ import subprocess
799
+ import sys
800
+ import time
801
+
802
+ # Mitigate https://github.com/pytorch/pytorch/issues/37377
803
+ # which can sometimes cause the subprocess call to fail.
804
+ import numpy as np
805
+
806
+ import torch
807
+ torch.set_num_threads({num_threads})
808
+
809
+ {bindings_import}
810
+
811
+ PID = os.getpid()
812
+
813
+ def log_failure(msg):
814
+ with open({error_log_repr}, "wt") as f:
815
+ f.write(msg)
816
+ sys.exit(1)
817
+
818
+ def check_result(completed_process):
819
+ if completed_process.returncode:
820
+ log_failure(f"Command failed: {{' '.join(completed_process.args)}}")
821
+ return completed_process
822
+
823
+ # =============================================================================
824
+ # == Check that subprocess matches parent =====================================
825
+ # =============================================================================
826
+ if os.path.realpath(sys.executable) != "{parent_interpreter}":
827
+ log_failure(
828
+ "Interpreter mismatch:\n"
829
+ f" {{os.path.realpath(sys.executable)}}\n vs.\n {parent_interpreter}"
830
+ )
831
+
832
+ if torch.__file__ != "{torch_file}":
833
+ log_failure(
834
+ "PyTorch does not match expected file:\n"
835
+ f" {{torch.__file__}}\n vs.\n {torch_file}"
836
+ )
837
+
838
+ # =============================================================================
839
+ # == User specified setup =====================================================
840
+ # =============================================================================
841
+ # Load serialized globals
842
+ {load_globals}
843
+
844
+ # User setup str
845
+ {setup}
846
+
847
+ for _ in range({warmup_number}):
848
+ {indented_stmt}
849
+
850
+ # =============================================================================
851
+ # == Callgrind management =====================================================
852
+ # =============================================================================
853
+ with open("{stat_log}", "wb") as stat_file:
854
+ # If many instances of callgrind are running at once, the output of
855
+ # `callgrind_control` may exceed 16kb which would cause `subprocess.PIPE`
856
+ # to deadlock. So instead we use a file.
857
+ callgrind_stat = check_result(subprocess.run(
858
+ ["callgrind_control", "--stat"],
859
+ stdout=stat_file,
860
+ stderr=subprocess.STDOUT,
861
+ ))
862
+
863
+ with open("{stat_log}", "rt") as stat_file:
864
+ stat_lines = stat_file.read().splitlines()
865
+
866
+ if f"PID {{PID}}: python {{__file__}}" not in stat_lines:
867
+ log_failure("Process does not appear to be running callgrind.")
868
+
869
+ gc.collect()
870
+ time.sleep(0.01)
871
+
872
+ # =============================================================================
873
+ # == User code block ==========================================================
874
+ # =============================================================================
875
+ for _ in range({repeats}):
876
+ callgrind_bindings._valgrind_toggle()
877
+ {blocked_stmt}
878
+ callgrind_bindings._valgrind_toggle_and_dump_stats()
879
+ gc.collect()
880
+
881
+ {baseline}
882
+ """).strip().format(
883
+ indented_stmt=textwrap.indent(task_spec.stmt, " " * 4),
884
+ blocked_stmt=block_stmt(task_spec.stmt, indent=4),
885
+ baseline=(pass_baseline if collect_baseline else ""),
886
+ number=number,
887
+ repeats=repeats,
888
+ load_globals=globals.construct(),
889
+ setup=task_spec.setup,
890
+ warmup_number=min(number, 10),
891
+ num_threads=task_spec.num_threads,
892
+ error_log_repr=repr(error_log),
893
+ stat_log=stat_log,
894
+ parent_interpreter=os.path.realpath(sys.executable),
895
+ torch_file=torch.__file__,
896
+ bindings_import=(
897
+ "import torch._C as callgrind_bindings" if bindings is None
898
+ else f"import {bindings.__name__} as callgrind_bindings"),
899
+ )
900
+
901
+
902
+ CALLGRIND_SINGLETON: Optional[_ValgrindWrapper] = None
903
+ def wrapper_singleton() -> _ValgrindWrapper:
904
+ global CALLGRIND_SINGLETON
905
+ if CALLGRIND_SINGLETON is None:
906
+ CALLGRIND_SINGLETON = _ValgrindWrapper()
907
+ return CALLGRIND_SINGLETON
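In practice this wrapper is reached through `torch.utils.benchmark.Timer.collect_callgrind`, and the resulting `CallgrindStats` / `FunctionCounts` objects are post-processed with the methods defined above. A small usage sketch, assuming a Linux machine with `valgrind` and `callgrind_annotate` on the PATH (otherwise `_validate()` raises):
```python
# Usage sketch for the classes defined above; only runs where Valgrind is available.
from torch.utils.benchmark import Timer

timer = Timer(stmt="x + 1", setup="import torch; x = torch.ones((8, 8))")
stats = timer.collect_callgrind(number=100)   # returns a CallgrindStats

print(stats.counts(denoise=True))             # total instructions, noisy CPython calls removed
top = stats.as_standardized().stats(inclusive=False).denoise()
print(top[:10])                               # FunctionCounts supports tuple-like slicing
```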
.venv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h ADDED
The diff for this file is too large to render. See raw diff
 
.venv/Lib/site-packages/torch/utils/bottleneck/__init__.py ADDED
File without changes
.venv/Lib/site-packages/torch/utils/bottleneck/__main__.py ADDED
@@ -0,0 +1,230 @@
1
+ # mypy: allow-untyped-defs
2
+ import argparse
3
+ import cProfile
4
+ import pstats
5
+ import sys
6
+ import os
7
+ from typing import Dict
8
+
9
+ import torch
10
+ from torch.autograd import profiler
11
+ from torch.utils.collect_env import get_env_info
12
+
13
+
14
+ def redirect_argv(new_argv):
15
+ sys.argv[:] = new_argv[:]
16
+
17
+
18
+ def compiled_with_cuda(sysinfo):
19
+ if sysinfo.cuda_compiled_version:
20
+ return f'compiled w/ CUDA {sysinfo.cuda_compiled_version}'
21
+ return 'not compiled w/ CUDA'
22
+
23
+
24
+ env_summary = """
25
+ --------------------------------------------------------------------------------
26
+ Environment Summary
27
+ --------------------------------------------------------------------------------
28
+ PyTorch {pytorch_version}{debug_str} {cuda_compiled}
29
+ Running with Python {py_version} and {cuda_runtime}
30
+
31
+ `{pip_version} list` truncated output:
32
+ {pip_list_output}
33
+ """.strip()
34
+
35
+
36
+ def run_env_analysis():
37
+ print('Running environment analysis...')
38
+ info = get_env_info()
39
+
40
+ result: Dict[str, str] = {}
41
+
42
+ debug_str = ''
43
+ if info.is_debug_build:
44
+ debug_str = ' DEBUG'
45
+
46
+ cuda_avail = ''
47
+ if info.is_cuda_available:
48
+ cuda = info.cuda_runtime_version
49
+ if cuda is not None:
50
+ cuda_avail = 'CUDA ' + cuda
51
+ else:
52
+ cuda = 'CUDA unavailable'
53
+
54
+ pip_version = info.pip_version
55
+ pip_list_output = info.pip_packages
56
+ if pip_list_output is None:
57
+ pip_list_output = 'Unable to fetch'
58
+
59
+ result = {
60
+ 'debug_str': debug_str,
61
+ 'pytorch_version': info.torch_version,
62
+ 'cuda_compiled': compiled_with_cuda(info),
63
+ 'py_version': f'{sys.version_info[0]}.{sys.version_info[1]}',
64
+ 'cuda_runtime': cuda_avail,
65
+ 'pip_version': pip_version,
66
+ 'pip_list_output': pip_list_output,
67
+ }
68
+
69
+ return env_summary.format(**result)
70
+
71
+
72
+ def run_cprofile(code, globs, launch_blocking=False):
73
+ print('Running your script with cProfile')
74
+ prof = cProfile.Profile()
75
+ prof.enable()
76
+ exec(code, globs, None)
77
+ prof.disable()
78
+ return prof
79
+
80
+
81
+ cprof_summary = """
82
+ --------------------------------------------------------------------------------
83
+ cProfile output
84
+ --------------------------------------------------------------------------------
85
+ """.strip()
86
+
87
+
88
+ def print_cprofile_summary(prof, sortby='tottime', topk=15):
89
+ print(cprof_summary)
90
+ cprofile_stats = pstats.Stats(prof).sort_stats(sortby)
91
+ cprofile_stats.print_stats(topk)
92
+
93
+
94
+ def run_autograd_prof(code, globs):
95
+ def run_prof(use_cuda=False):
96
+ with profiler.profile(use_cuda=use_cuda) as prof:
97
+ exec(code, globs, None)
98
+ return prof
99
+
100
+ print('Running your script with the autograd profiler...')
101
+ result = [run_prof(use_cuda=False)]
102
+ if torch.cuda.is_available():
103
+ result.append(run_prof(use_cuda=True))
104
+ else:
105
+ result.append(None)
106
+
107
+ return result
108
+
109
+
110
+ autograd_prof_summary = """
111
+ --------------------------------------------------------------------------------
112
+ autograd profiler output ({mode} mode)
113
+ --------------------------------------------------------------------------------
114
+ {description}
115
+ {cuda_warning}
116
+ {output}
117
+ """.strip()
118
+
119
+
120
+ def print_autograd_prof_summary(prof, mode, sortby='cpu_time', topk=15):
121
+ valid_sortby = ['cpu_time', 'cuda_time', 'cpu_time_total', 'cuda_time_total', 'count']
122
+ if sortby not in valid_sortby:
123
+ warn = ('WARNING: invalid sorting option for autograd profiler results: {}\n'
124
+ 'Expected `cpu_time`, `cuda_time`, `cpu_time_total`, `cuda_time_total`, or `count`. '
125
+ 'Defaulting to `cpu_time`.')
126
+ print(warn.format(sortby))
127
+ sortby = 'cpu_time'
128
+
129
+ if mode == 'CUDA':
130
+ cuda_warning = ('\n\tBecause the autograd profiler uses the CUDA event API,\n'
131
+ '\tthe CUDA time column reports approximately max(cuda_time, cpu_time).\n'
132
+ '\tPlease ignore this output if your code does not use CUDA.\n')
133
+ else:
134
+ cuda_warning = ''
135
+
136
+ sorted_events = sorted(prof.function_events,
137
+ key=lambda x: getattr(x, sortby), reverse=True)
138
+ topk_events = sorted_events[:topk]
139
+
140
+ result = {
141
+ 'mode': mode,
142
+ 'description': f'top {topk} events sorted by {sortby}',
143
+ 'output': torch.autograd.profiler_util._build_table(topk_events),
144
+ 'cuda_warning': cuda_warning
145
+ }
146
+
147
+ print(autograd_prof_summary.format(**result))
148
+
149
+
150
+ descript = """
151
+ `bottleneck` is a tool that can be used as an initial step for debugging
152
+ bottlenecks in your program.
153
+
154
+ It summarizes runs of your script with the Python profiler and PyTorch\'s
155
+ autograd profiler. Because your script will be profiled, please ensure that it
156
+ exits in a finite amount of time.
157
+
158
+ For more complicated uses of the profilers, please see
159
+ https://docs.python.org/3/library/profile.html and
160
+ https://pytorch.org/docs/main/autograd.html#profiler for more information.
161
+ """.strip()
162
+
163
+
164
+ def parse_args():
165
+ parser = argparse.ArgumentParser(description=descript)
166
+ parser.add_argument('scriptfile', type=str,
167
+ help='Path to the script to be run. '
168
+ 'Usually run with `python path/to/script`.')
169
+ parser.add_argument('args', type=str, nargs=argparse.REMAINDER,
170
+ help='Command-line arguments to be passed to the script.')
171
+ return parser.parse_args()
172
+
173
+
174
+ def cpu_time_total(autograd_prof):
175
+ return sum(event.cpu_time_total for event in autograd_prof.function_events)
176
+
177
+
178
+ def main():
179
+ args = parse_args()
180
+
181
+ # Customizable constants.
182
+ scriptfile = args.scriptfile
183
+ scriptargs = [] if args.args is None else args.args
184
+ scriptargs.insert(0, scriptfile)
185
+ cprofile_sortby = 'tottime'
186
+ cprofile_topk = 15
187
+ autograd_prof_sortby = 'cpu_time_total'
188
+ autograd_prof_topk = 15
189
+
190
+ redirect_argv(scriptargs)
191
+
192
+ sys.path.insert(0, os.path.dirname(scriptfile))
193
+ with open(scriptfile, 'rb') as stream:
194
+ code = compile(stream.read(), scriptfile, 'exec')
195
+ globs = {
196
+ '__file__': scriptfile,
197
+ '__name__': '__main__',
198
+ '__package__': None,
199
+ '__cached__': None,
200
+ }
201
+
202
+ print(descript)
203
+
204
+ env_summary = run_env_analysis()
205
+
206
+ if torch.cuda.is_available():
207
+ torch.cuda.init()
208
+ cprofile_prof = run_cprofile(code, globs)
209
+ autograd_prof_cpu, autograd_prof_cuda = run_autograd_prof(code, globs)
210
+
211
+ print(env_summary)
212
+ print_cprofile_summary(cprofile_prof, cprofile_sortby, cprofile_topk)
213
+
214
+ if not torch.cuda.is_available():
215
+ print_autograd_prof_summary(autograd_prof_cpu, 'CPU', autograd_prof_sortby, autograd_prof_topk)
216
+ return
217
+
218
+ # Print both the result of the CPU-mode and CUDA-mode autograd profilers
219
+ # if their execution times are very different.
220
+ cuda_prof_exec_time = cpu_time_total(autograd_prof_cuda)
221
+ if len(autograd_prof_cpu.function_events) > 0:
222
+ cpu_prof_exec_time = cpu_time_total(autograd_prof_cpu)
223
+ pct_diff = (cuda_prof_exec_time - cpu_prof_exec_time) / cuda_prof_exec_time
224
+ if abs(pct_diff) > 0.05:
225
+ print_autograd_prof_summary(autograd_prof_cpu, 'CPU', autograd_prof_sortby, autograd_prof_topk)
226
+
227
+ print_autograd_prof_summary(autograd_prof_cuda, 'CUDA', autograd_prof_sortby, autograd_prof_topk)
228
+
229
+ if __name__ == '__main__':
230
+ main()
.venv/Lib/site-packages/torch/utils/data/__init__.py ADDED
@@ -0,0 +1,77 @@
1
+ from torch.utils.data.dataloader import (
2
+ _DatasetKind,
3
+ DataLoader,
4
+ default_collate,
5
+ default_convert,
6
+ get_worker_info,
7
+ )
8
+ from torch.utils.data.datapipes._decorator import (
9
+ argument_validation,
10
+ functional_datapipe,
11
+ guaranteed_datapipes_determinism,
12
+ non_deterministic,
13
+ runtime_validation,
14
+ runtime_validation_disabled,
15
+ )
16
+ from torch.utils.data.datapipes.datapipe import (
17
+ DataChunk,
18
+ DFIterDataPipe,
19
+ IterDataPipe,
20
+ MapDataPipe,
21
+ )
22
+ from torch.utils.data.dataset import (
23
+ ChainDataset,
24
+ ConcatDataset,
25
+ Dataset,
26
+ IterableDataset,
27
+ random_split,
28
+ StackDataset,
29
+ Subset,
30
+ TensorDataset,
31
+ )
32
+ from torch.utils.data.distributed import DistributedSampler
33
+ from torch.utils.data.sampler import (
34
+ BatchSampler,
35
+ RandomSampler,
36
+ Sampler,
37
+ SequentialSampler,
38
+ SubsetRandomSampler,
39
+ WeightedRandomSampler,
40
+ )
41
+
42
+
43
+ __all__ = [
44
+ "BatchSampler",
45
+ "ChainDataset",
46
+ "ConcatDataset",
47
+ "DFIterDataPipe",
48
+ "DataChunk",
49
+ "DataLoader",
50
+ "Dataset",
51
+ "DistributedSampler",
52
+ "IterDataPipe",
53
+ "IterableDataset",
54
+ "MapDataPipe",
55
+ "RandomSampler",
56
+ "Sampler",
57
+ "SequentialSampler",
58
+ "StackDataset",
59
+ "Subset",
60
+ "SubsetRandomSampler",
61
+ "TensorDataset",
62
+ "WeightedRandomSampler",
63
+ "_DatasetKind",
64
+ "argument_validation",
65
+ "default_collate",
66
+ "default_convert",
67
+ "functional_datapipe",
68
+ "get_worker_info",
69
+ "guaranteed_datapipes_determinism",
70
+ "non_deterministic",
71
+ "random_split",
72
+ "runtime_validation",
73
+ "runtime_validation_disabled",
74
+ ]
75
+
76
+ # Please keep this list sorted
77
+ assert __all__ == sorted(__all__)
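This `__init__` simply re-exports the dataset, sampler, and `DataLoader` machinery under `torch.utils.data`. A minimal sketch of the most common combination exported here (shapes and batch size are arbitrary):
```python
# Minimal sketch using names re-exported above (TensorDataset, DataLoader).
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))
loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=0)

for features, labels in loader:
    # default_collate stacks each batch along a new leading batch dimension
    pass
```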
.venv/Lib/site-packages/torch/utils/data/_utils/__init__.py ADDED
@@ -0,0 +1,54 @@
1
+ # mypy: allow-untyped-defs
2
+ r"""Utility classes & functions for data loading. Code in this folder is mostly used by ../dataloader.py.
3
+
4
+ A lot of multiprocessing is used in data loading, which only supports running
5
+ functions defined in global environment (py2 can't serialize static methods).
6
+ Therefore, for code tidiness we put these functions into different files in this
7
+ folder.
8
+ """
9
+
10
+ import atexit
11
+ import sys
12
+
13
+ # old private location of the ExceptionWrapper that some users rely on:
14
+ from torch._utils import ExceptionWrapper
15
+
16
+
17
+ IS_WINDOWS = sys.platform == "win32"
18
+
19
+
20
+ MP_STATUS_CHECK_INTERVAL = 5.0
21
+ r"""Interval (in seconds) to check status of processes to avoid hanging in
22
+ multiprocessing data loading. This is mainly used in getting data from
23
+ another process, in which case we need to periodically check whether the
24
+ sender is alive to prevent hanging."""
25
+
26
+
27
+ python_exit_status = False
28
+ r"""Whether Python is shutting down. This flag is guaranteed to be set before
29
+ the Python core library resources are freed, but Python may already be exiting
30
+ for some time when this is set.
31
+
32
+ Hook to set this flag is `_set_python_exit_flag`, and is inspired by a similar
33
+ hook in Python 3.7 multiprocessing library:
34
+ https://github.com/python/cpython/blob/d4d60134b29290049e28df54f23493de4f1824b6/Lib/multiprocessing/util.py#L277-L327
35
+ """
36
+
37
+
38
+ try:
39
+ import numpy
40
+
41
+ HAS_NUMPY = True
42
+ except ModuleNotFoundError:
43
+ HAS_NUMPY = False
44
+
45
+
46
+ def _set_python_exit_flag():
47
+ global python_exit_status
48
+ python_exit_status = True
49
+
50
+
51
+ atexit.register(_set_python_exit_flag)
52
+
53
+
54
+ from . import collate, fetch, pin_memory, signal_handling, worker
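A small sketch of the polling pattern `MP_STATUS_CHECK_INTERVAL` exists for: block on a queue only for one interval at a time, then check that the peer is still alive before blocking again. The queue and the liveness callback below are stand-ins, not the DataLoader internals.

import queue

MP_STATUS_CHECK_INTERVAL = 5.0  # mirrors the constant defined above

def get_with_liveness_check(q, peer_is_alive):
    # Wake up every MP_STATUS_CHECK_INTERVAL seconds so a dead sender
    # cannot leave us blocked forever.
    while True:
        try:
            return q.get(timeout=MP_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            if not peer_is_alive():
                raise RuntimeError("peer process died before sending data")

# In-process demo with a peer that is trivially alive.
q = queue.Queue()
q.put("payload")
print(get_with_liveness_check(q, peer_is_alive=lambda: True))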
.venv/Lib/site-packages/torch/utils/data/_utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (1.11 kB). View file
 
.venv/Lib/site-packages/torch/utils/data/_utils/__pycache__/collate.cpython-39.pyc ADDED
Binary file (13.5 kB). View file
 
.venv/Lib/site-packages/torch/utils/data/_utils/__pycache__/fetch.cpython-39.pyc ADDED
Binary file (2.28 kB). View file
 
.venv/Lib/site-packages/torch/utils/data/_utils/__pycache__/pin_memory.cpython-39.pyc ADDED
Binary file (3.25 kB). View file
 
.venv/Lib/site-packages/torch/utils/data/_utils/__pycache__/signal_handling.cpython-39.pyc ADDED
Binary file (2.63 kB). View file
 
.venv/Lib/site-packages/torch/utils/data/_utils/__pycache__/worker.cpython-39.pyc ADDED
Binary file (7.87 kB). View file
 
.venv/Lib/site-packages/torch/utils/data/_utils/collate.py ADDED
@@ -0,0 +1,398 @@
1
+ # mypy: allow-untyped-defs
2
+ r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers.
3
+
4
+ These methods are used to collate samples fetched from dataset into Tensor(s).
5
+ These **need** to be in global scope since Py2 doesn't support serializing
6
+ static methods.
7
+
8
+ `default_collate` and `default_convert` are exposed to users via 'dataloader.py'.
9
+ """
10
+
11
+ import collections
12
+ import contextlib
13
+ import copy
14
+ import re
15
+ from typing import Callable, Dict, Optional, Tuple, Type, Union
16
+
17
+ import torch
18
+
19
+
20
+ np_str_obj_array_pattern = re.compile(r"[SaUO]")
21
+
22
+
23
+ def default_convert(data):
24
+ r"""
25
+ Convert each NumPy array element into a :class:`torch.Tensor`.
26
+
27
+ If the input is a `Sequence`, `Collection`, or `Mapping`, it tries to convert each element inside to a :class:`torch.Tensor`.
28
+ If the input is not a NumPy array, it is left unchanged.
29
+ This is used as the default function for collation when both `batch_sampler` and `batch_size`
30
+ are NOT defined in :class:`~torch.utils.data.DataLoader`.
31
+
32
+ The general input type to output type mapping is similar to that
33
+ of :func:`~torch.utils.data.default_collate`. See the description there for more details.
34
+
35
+ Args:
36
+ data: a single data point to be converted
37
+
38
+ Examples:
39
+ >>> # xdoctest: +SKIP
40
+ >>> # Example with `int`
41
+ >>> default_convert(0)
42
+ 0
43
+ >>> # Example with NumPy array
44
+ >>> default_convert(np.array([0, 1]))
45
+ tensor([0, 1])
46
+ >>> # Example with NamedTuple
47
+ >>> Point = namedtuple('Point', ['x', 'y'])
48
+ >>> default_convert(Point(0, 0))
49
+ Point(x=0, y=0)
50
+ >>> default_convert(Point(np.array(0), np.array(0)))
51
+ Point(x=tensor(0), y=tensor(0))
52
+ >>> # Example with List
53
+ >>> default_convert([np.array([0, 1]), np.array([2, 3])])
54
+ [tensor([0, 1]), tensor([2, 3])]
55
+ """
56
+ elem_type = type(data)
57
+ if isinstance(data, torch.Tensor):
58
+ return data
59
+ elif (
60
+ elem_type.__module__ == "numpy"
61
+ and elem_type.__name__ != "str_"
62
+ and elem_type.__name__ != "string_"
63
+ ):
64
+ # array of string classes and object
65
+ if (
66
+ elem_type.__name__ == "ndarray"
67
+ and np_str_obj_array_pattern.search(data.dtype.str) is not None
68
+ ):
69
+ return data
70
+ return torch.as_tensor(data)
71
+ elif isinstance(data, collections.abc.Mapping):
72
+ try:
73
+ if isinstance(data, collections.abc.MutableMapping):
74
+ # The mapping type may have extra properties, so we can't just
75
+ # use `type(data)(...)` to create the new mapping.
76
+ # Create a clone and update it if the mapping type is mutable.
77
+ clone = copy.copy(data)
78
+ clone.update({key: default_convert(data[key]) for key in data})
79
+ return clone
80
+ else:
81
+ return elem_type({key: default_convert(data[key]) for key in data})
82
+ except TypeError:
83
+ # The mapping type may not support `copy()` / `update(mapping)`
84
+ # or `__init__(iterable)`.
85
+ return {key: default_convert(data[key]) for key in data}
86
+ elif isinstance(data, tuple) and hasattr(data, "_fields"): # namedtuple
87
+ return elem_type(*(default_convert(d) for d in data))
88
+ elif isinstance(data, tuple):
89
+ return [default_convert(d) for d in data] # Backwards compatibility.
90
+ elif isinstance(data, collections.abc.Sequence) and not isinstance(
91
+ data, (str, bytes)
92
+ ):
93
+ try:
94
+ if isinstance(data, collections.abc.MutableSequence):
95
+ # The sequence type may have extra properties, so we can't just
96
+ # use `type(data)(...)` to create the new sequence.
97
+ # Create a clone and update it if the sequence type is mutable.
98
+ clone = copy.copy(data) # type: ignore[arg-type]
99
+ for i, d in enumerate(data):
100
+ clone[i] = default_convert(d)
101
+ return clone
102
+ else:
103
+ return elem_type([default_convert(d) for d in data])
104
+ except TypeError:
105
+ # The sequence type may not support `copy()` / `__setitem__(index, item)`
106
+ # or `__init__(iterable)` (e.g., `range`).
107
+ return [default_convert(d) for d in data]
108
+ else:
109
+ return data
110
+
111
+
112
+ default_collate_err_msg_format = (
113
+ "default_collate: batch must contain tensors, numpy arrays, numbers, "
114
+ "dicts or lists; found {}"
115
+ )
116
+
117
+
118
+ def collate(
119
+ batch,
120
+ *,
121
+ collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
122
+ ):
123
+ r"""
124
+ General collate function that handles collection-type elements within each batch.
125
+
126
+ The function also exposes a function registry to deal with specific element types. `default_collate_fn_map`
127
+ provides default collate functions for tensors, numpy arrays, numbers and strings.
128
+
129
+ Args:
130
+ batch: a single batch to be collated
131
+ collate_fn_map: Optional dictionary mapping from element type to the corresponding collate function.
132
+ If the element type isn't present in this dictionary,
133
+ this function will go through each key of the dictionary in the insertion order to
134
+ invoke the corresponding collate function if the element type is a subclass of the key.
135
+
136
+ Examples:
137
+ >>> def collate_tensor_fn(batch, *, collate_fn_map):
138
+ ... # Extend this function to handle batch of tensors
139
+ ... return torch.stack(batch, 0)
140
+ >>> def custom_collate(batch):
141
+ ... collate_map = {torch.Tensor: collate_tensor_fn}
142
+ ... return collate(batch, collate_fn_map=collate_map)
143
+ >>> # Extend `default_collate` by in-place modifying `default_collate_fn_map`
144
+ >>> default_collate_fn_map.update({torch.Tensor: collate_tensor_fn})
145
+
146
+ Note:
147
+ Each collate function requires a positional argument for batch and a keyword argument
148
+ for the dictionary of collate functions as `collate_fn_map`.
149
+ """
150
+ elem = batch[0]
151
+ elem_type = type(elem)
152
+
153
+ if collate_fn_map is not None:
154
+ if elem_type in collate_fn_map:
155
+ return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
156
+
157
+ for collate_type in collate_fn_map:
158
+ if isinstance(elem, collate_type):
159
+ return collate_fn_map[collate_type](
160
+ batch, collate_fn_map=collate_fn_map
161
+ )
162
+
163
+ if isinstance(elem, collections.abc.Mapping):
164
+ try:
165
+ if isinstance(elem, collections.abc.MutableMapping):
166
+ # The mapping type may have extra properties, so we can't just
167
+ # use `type(data)(...)` to create the new mapping.
168
+ # Create a clone and update it if the mapping type is mutable.
169
+ clone = copy.copy(elem)
170
+ clone.update(
171
+ {
172
+ key: collate(
173
+ [d[key] for d in batch], collate_fn_map=collate_fn_map
174
+ )
175
+ for key in elem
176
+ }
177
+ )
178
+ return clone
179
+ else:
180
+ return elem_type(
181
+ {
182
+ key: collate(
183
+ [d[key] for d in batch], collate_fn_map=collate_fn_map
184
+ )
185
+ for key in elem
186
+ }
187
+ )
188
+ except TypeError:
189
+ # The mapping type may not support `copy()` / `update(mapping)`
190
+ # or `__init__(iterable)`.
191
+ return {
192
+ key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map)
193
+ for key in elem
194
+ }
195
+ elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple
196
+ return elem_type(
197
+ *(
198
+ collate(samples, collate_fn_map=collate_fn_map)
199
+ for samples in zip(*batch)
200
+ )
201
+ )
202
+ elif isinstance(elem, collections.abc.Sequence):
203
+ # check to make sure that the elements in batch have consistent size
204
+ it = iter(batch)
205
+ elem_size = len(next(it))
206
+ if not all(len(elem) == elem_size for elem in it):
207
+ raise RuntimeError("each element in list of batch should be of equal size")
208
+ transposed = list(zip(*batch)) # It may be accessed twice, so we use a list.
209
+
210
+ if isinstance(elem, tuple):
211
+ return [
212
+ collate(samples, collate_fn_map=collate_fn_map)
213
+ for samples in transposed
214
+ ] # Backwards compatibility.
215
+ else:
216
+ try:
217
+ if isinstance(elem, collections.abc.MutableSequence):
218
+ # The sequence type may have extra properties, so we can't just
219
+ # use `type(data)(...)` to create the new sequence.
220
+ # Create a clone and update it if the sequence type is mutable.
221
+ clone = copy.copy(elem) # type: ignore[arg-type]
222
+ for i, samples in enumerate(transposed):
223
+ clone[i] = collate(samples, collate_fn_map=collate_fn_map)
224
+ return clone
225
+ else:
226
+ return elem_type(
227
+ [
228
+ collate(samples, collate_fn_map=collate_fn_map)
229
+ for samples in transposed
230
+ ]
231
+ )
232
+ except TypeError:
233
+ # The sequence type may not support `copy()` / `__setitem__(index, item)`
234
+ # or `__init__(iterable)` (e.g., `range`).
235
+ return [
236
+ collate(samples, collate_fn_map=collate_fn_map)
237
+ for samples in transposed
238
+ ]
239
+
240
+ raise TypeError(default_collate_err_msg_format.format(elem_type))
241
+
242
+
243
+ def collate_tensor_fn(
244
+ batch,
245
+ *,
246
+ collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
247
+ ):
248
+ elem = batch[0]
249
+ out = None
250
+ if elem.is_nested:
251
+ raise RuntimeError(
252
+ "Batches of nested tensors are not currently supported by the default collate_fn; "
253
+ "please provide a custom collate_fn to handle them appropriately."
254
+ )
255
+ if elem.layout in {
256
+ torch.sparse_coo,
257
+ torch.sparse_csr,
258
+ torch.sparse_bsr,
259
+ torch.sparse_csc,
260
+ torch.sparse_bsc,
261
+ }:
262
+ raise RuntimeError(
263
+ "Batches of sparse tensors are not currently supported by the default collate_fn; "
264
+ "please provide a custom collate_fn to handle them appropriately."
265
+ )
266
+ if torch.utils.data.get_worker_info() is not None:
267
+ # If we're in a background process, concatenate directly into a
268
+ # shared memory tensor to avoid an extra copy
269
+ numel = sum(x.numel() for x in batch)
270
+ storage = elem._typed_storage()._new_shared(numel, device=elem.device)
271
+ out = elem.new(storage).resize_(len(batch), *list(elem.size()))
272
+ return torch.stack(batch, 0, out=out)
273
+
274
+
275
+ def collate_numpy_array_fn(
276
+ batch,
277
+ *,
278
+ collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
279
+ ):
280
+ elem = batch[0]
281
+ # array of string classes and object
282
+ if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
283
+ raise TypeError(default_collate_err_msg_format.format(elem.dtype))
284
+
285
+ return collate([torch.as_tensor(b) for b in batch], collate_fn_map=collate_fn_map)
286
+
287
+
288
+ def collate_numpy_scalar_fn(
289
+ batch,
290
+ *,
291
+ collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
292
+ ):
293
+ return torch.as_tensor(batch)
294
+
295
+
296
+ def collate_float_fn(
297
+ batch,
298
+ *,
299
+ collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
300
+ ):
301
+ return torch.tensor(batch, dtype=torch.float64)
302
+
303
+
304
+ def collate_int_fn(
305
+ batch,
306
+ *,
307
+ collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
308
+ ):
309
+ return torch.tensor(batch)
310
+
311
+
312
+ def collate_str_fn(
313
+ batch,
314
+ *,
315
+ collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
316
+ ):
317
+ return batch
318
+
319
+
320
+ default_collate_fn_map: Dict[Union[Type, Tuple[Type, ...]], Callable] = {
321
+ torch.Tensor: collate_tensor_fn
322
+ }
323
+ with contextlib.suppress(ImportError):
324
+ import numpy as np
325
+
326
+ # For both ndarray and memmap (subclass of ndarray)
327
+ default_collate_fn_map[np.ndarray] = collate_numpy_array_fn
328
+ # See scalars hierarchy: https://numpy.org/doc/stable/reference/arrays.scalars.html
329
+ # Skip string scalars
330
+ default_collate_fn_map[(np.bool_, np.number, np.object_)] = collate_numpy_scalar_fn
331
+ default_collate_fn_map[float] = collate_float_fn
332
+ default_collate_fn_map[int] = collate_int_fn
333
+ default_collate_fn_map[str] = collate_str_fn
334
+ default_collate_fn_map[bytes] = collate_str_fn
335
+
336
+
337
+ def default_collate(batch):
338
+ r"""
339
+ Take in a batch of data and put the elements within the batch into a tensor with an additional outer dimension - batch size.
340
+
341
+ The exact output type can be a :class:`torch.Tensor`, a `Sequence` of :class:`torch.Tensor`, a
342
+ Collection of :class:`torch.Tensor`, or left unchanged, depending on the input type.
343
+ This is used as the default function for collation when
344
+ `batch_size` or `batch_sampler` is defined in :class:`~torch.utils.data.DataLoader`.
345
+
346
+ Here is the general input type (based on the type of the element within the batch) to output type mapping:
347
+
348
+ * :class:`torch.Tensor` -> :class:`torch.Tensor` (with an added outer dimension batch size)
349
+ * NumPy Arrays -> :class:`torch.Tensor`
350
+ * `float` -> :class:`torch.Tensor`
351
+ * `int` -> :class:`torch.Tensor`
352
+ * `str` -> `str` (unchanged)
353
+ * `bytes` -> `bytes` (unchanged)
354
+ * `Mapping[K, V_i]` -> `Mapping[K, default_collate([V_1, V_2, ...])]`
355
+ * `NamedTuple[V1_i, V2_i, ...]` -> `NamedTuple[default_collate([V1_1, V1_2, ...]),
356
+ default_collate([V2_1, V2_2, ...]), ...]`
357
+ * `Sequence[V1_i, V2_i, ...]` -> `Sequence[default_collate([V1_1, V1_2, ...]),
358
+ default_collate([V2_1, V2_2, ...]), ...]`
359
+
360
+ Args:
361
+ batch: a single batch to be collated
362
+
363
+ Examples:
364
+ >>> # xdoctest: +SKIP
365
+ >>> # Example with a batch of `int`s:
366
+ >>> default_collate([0, 1, 2, 3])
367
+ tensor([0, 1, 2, 3])
368
+ >>> # Example with a batch of `str`s:
369
+ >>> default_collate(['a', 'b', 'c'])
370
+ ['a', 'b', 'c']
371
+ >>> # Example with `Map` inside the batch:
372
+ >>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}])
373
+ {'A': tensor([ 0, 100]), 'B': tensor([ 1, 100])}
374
+ >>> # Example with `NamedTuple` inside the batch:
375
+ >>> Point = namedtuple('Point', ['x', 'y'])
376
+ >>> default_collate([Point(0, 0), Point(1, 1)])
377
+ Point(x=tensor([0, 1]), y=tensor([0, 1]))
378
+ >>> # Example with `Tuple` inside the batch:
379
+ >>> default_collate([(0, 1), (2, 3)])
380
+ [tensor([0, 2]), tensor([1, 3])]
381
+ >>> # Example with `List` inside the batch:
382
+ >>> default_collate([[0, 1], [2, 3]])
383
+ [tensor([0, 2]), tensor([1, 3])]
384
+ >>> # Two options to extend `default_collate` to handle specific type
385
+ >>> # Option 1: Write custom collate function and invoke `default_collate`
386
+ >>> def custom_collate(batch):
387
+ ... elem = batch[0]
388
+ ... if isinstance(elem, CustomType): # Some custom condition
389
+ ... return ...
390
+ ... else: # Fall back to `default_collate`
391
+ ... return default_collate(batch)
392
+ >>> # Option 2: In-place modify `default_collate_fn_map`
393
+ >>> def collate_customtype_fn(batch, *, collate_fn_map=None):
394
+ ... return ...
395
+ >>> default_collate_fn_map.update({CustomType: collate_customtype_fn})
396
+ >>> default_collate(batch) # Handle `CustomType` automatically
397
+ """
398
+ return collate(batch, collate_fn_map=default_collate_fn_map)
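A sketch of the registry-based extension point the docstrings above describe: registering a collate function for a custom element type so that `default_collate` can handle it. `Pair` is a made-up type used only for illustration.

from dataclasses import dataclass

import torch
from torch.utils.data import default_collate
from torch.utils.data._utils.collate import collate, default_collate_fn_map

@dataclass
class Pair:
    left: torch.Tensor
    right: torch.Tensor

def collate_pair_fn(batch, *, collate_fn_map=None):
    # Collate each field recursively with the same registry.
    return Pair(
        left=collate([p.left for p in batch], collate_fn_map=collate_fn_map),
        right=collate([p.right for p in batch], collate_fn_map=collate_fn_map),
    )

# Register the handler; default_collate consults this map by element type.
default_collate_fn_map[Pair] = collate_pair_fn

batch = [Pair(torch.zeros(2), torch.ones(2)) for _ in range(4)]
out = default_collate(batch)
print(out.left.shape, out.right.shape)  # torch.Size([4, 2]) torch.Size([4, 2])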
.venv/Lib/site-packages/torch/utils/data/_utils/fetch.py ADDED
@@ -0,0 +1,55 @@
1
+ # mypy: allow-untyped-defs
2
+ r"""Contains definitions of the methods used by the _BaseDataLoaderIter to fetch data from an iterable-style or map-style dataset.
3
+
4
+ This logic is shared in both single- and multi-processing data loading.
5
+ """
6
+
7
+
8
+ class _BaseDatasetFetcher:
9
+ def __init__(self, dataset, auto_collation, collate_fn, drop_last):
10
+ self.dataset = dataset
11
+ self.auto_collation = auto_collation
12
+ self.collate_fn = collate_fn
13
+ self.drop_last = drop_last
14
+
15
+ def fetch(self, possibly_batched_index):
16
+ raise NotImplementedError
17
+
18
+
19
+ class _IterableDatasetFetcher(_BaseDatasetFetcher):
20
+ def __init__(self, dataset, auto_collation, collate_fn, drop_last):
21
+ super().__init__(dataset, auto_collation, collate_fn, drop_last)
22
+ self.dataset_iter = iter(dataset)
23
+ self.ended = False
24
+
25
+ def fetch(self, possibly_batched_index):
26
+ if self.ended:
27
+ raise StopIteration
28
+
29
+ if self.auto_collation:
30
+ data = []
31
+ for _ in possibly_batched_index:
32
+ try:
33
+ data.append(next(self.dataset_iter))
34
+ except StopIteration:
35
+ self.ended = True
36
+ break
37
+ if len(data) == 0 or (
38
+ self.drop_last and len(data) < len(possibly_batched_index)
39
+ ):
40
+ raise StopIteration
41
+ else:
42
+ data = next(self.dataset_iter)
43
+ return self.collate_fn(data)
44
+
45
+
46
+ class _MapDatasetFetcher(_BaseDatasetFetcher):
47
+ def fetch(self, possibly_batched_index):
48
+ if self.auto_collation:
49
+ if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
50
+ data = self.dataset.__getitems__(possibly_batched_index)
51
+ else:
52
+ data = [self.dataset[idx] for idx in possibly_batched_index]
53
+ else:
54
+ data = self.dataset[possibly_batched_index]
55
+ return self.collate_fn(data)
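A toy sketch of what the map-style fetch path above reduces to when auto-collation is enabled and the dataset defines no `__getitems__`: index the samples one by one, then hand the list to the collate function. The dataset and indices are invented for illustration; the fetcher classes themselves are private.

import torch
from torch.utils.data import Dataset, default_collate

class SquaresDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.tensor(idx * idx)

dataset = SquaresDataset()
batch_indices = [1, 2, 3]

# Equivalent of the auto-collation branch: one lookup per index, then collate.
samples = [dataset[i] for i in batch_indices]
print(default_collate(samples))  # tensor([1, 4, 9])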
.venv/Lib/site-packages/torch/utils/data/_utils/pin_memory.py ADDED
@@ -0,0 +1,108 @@
1
+ # mypy: allow-untyped-defs
2
+ r"""Contains definitions of the methods used by the _BaseDataLoaderIter to put fetched tensors into pinned memory.
3
+
4
+ These **need** to be in global scope since Py2 doesn't support serializing
5
+ static methods.
6
+ """
7
+
8
+ import collections
9
+ import copy
10
+ import queue
11
+
12
+ import torch
13
+ from torch._utils import ExceptionWrapper
14
+
15
+ from . import MP_STATUS_CHECK_INTERVAL
16
+
17
+
18
+ def _pin_memory_loop(in_queue, out_queue, device_id, done_event, device):
19
+ # This setting is thread local, and prevents the copy in pin_memory from
20
+ # consuming all CPU cores.
21
+ torch.set_num_threads(1)
22
+
23
+ torch.multiprocessing._set_thread_name("pt_data_pin")
24
+
25
+ if device == "cuda":
26
+ torch.cuda.set_device(device_id)
27
+ elif device == "xpu":
28
+ torch.xpu.set_device(device_id) # type: ignore[attr-defined]
29
+ elif device == torch._C._get_privateuse1_backend_name():
30
+ custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name())
31
+ custom_device_mod.set_device(device_id)
32
+
33
+ def do_one_step():
34
+ try:
35
+ r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
36
+ except queue.Empty:
37
+ return
38
+ idx, data = r
39
+ if not done_event.is_set() and not isinstance(data, ExceptionWrapper):
40
+ try:
41
+ data = pin_memory(data, device)
42
+ except Exception:
43
+ data = ExceptionWrapper(
44
+ where=f"in pin memory thread for device {device_id}"
45
+ )
46
+ r = (idx, data)
47
+ while not done_event.is_set():
48
+ try:
49
+ out_queue.put(r, timeout=MP_STATUS_CHECK_INTERVAL)
50
+ break
51
+ except queue.Full:
52
+ continue
53
+
54
+ # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
55
+ # logic of this function.
56
+ while not done_event.is_set():
57
+ # Make sure that we don't preserve any object from one iteration
58
+ # to the next
59
+ do_one_step()
60
+
61
+
62
+ def pin_memory(data, device=None):
63
+ if isinstance(data, torch.Tensor):
64
+ return data.pin_memory(device)
65
+ elif isinstance(data, (str, bytes)):
66
+ return data
67
+ elif isinstance(data, collections.abc.Mapping):
68
+ try:
69
+ if isinstance(data, collections.abc.MutableMapping):
70
+ # The mapping type may have extra properties, so we can't just
71
+ # use `type(data)(...)` to create the new mapping.
72
+ # Create a clone and update it if the mapping type is mutable.
73
+ clone = copy.copy(data)
74
+ clone.update(
75
+ {k: pin_memory(sample, device) for k, sample in data.items()}
76
+ )
77
+ return clone
78
+ else:
79
+ return type(data)({k: pin_memory(sample, device) for k, sample in data.items()}) # type: ignore[call-arg]
80
+ except TypeError:
81
+ # The mapping type may not support `copy()` / `update(mapping)`
82
+ # or `__init__(iterable)`.
83
+ return {k: pin_memory(sample, device) for k, sample in data.items()}
84
+ elif isinstance(data, tuple) and hasattr(data, "_fields"): # namedtuple
85
+ return type(data)(*(pin_memory(sample, device) for sample in data))
86
+ elif isinstance(data, tuple):
87
+ return [
88
+ pin_memory(sample, device) for sample in data
89
+ ] # Backwards compatibility.
90
+ elif isinstance(data, collections.abc.Sequence):
91
+ try:
92
+ if isinstance(data, collections.abc.MutableSequence):
93
+ # The sequence type may have extra properties, so we can't just
94
+ # use `type(data)(...)` to create the new sequence.
95
+ # Create a clone and update it if the sequence type is mutable.
96
+ clone = copy.copy(data) # type: ignore[arg-type]
97
+ for i, item in enumerate(data):
98
+ clone[i] = pin_memory(item, device)
99
+ return clone
100
+ return type(data)([pin_memory(sample, device) for sample in data]) # type: ignore[call-arg]
101
+ except TypeError:
102
+ # The sequence type may not support `copy()` / `__setitem__(index, item)`
103
+ # or `__init__(iterable)` (e.g., `range`).
104
+ return [pin_memory(sample, device) for sample in data]
105
+ elif hasattr(data, "pin_memory"):
106
+ return data.pin_memory()
107
+ else:
108
+ return data
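A sketch of the recursive traversal `pin_memory` performs on a nested batch. Pinning host memory requires a CUDA-capable runtime, so the example checks availability and otherwise skips the call.

import torch
from torch.utils.data._utils.pin_memory import pin_memory

batch = {
    "image": torch.randn(2, 3),
    "meta": [torch.tensor([0, 1]), "strings pass through unchanged"],
}

if torch.cuda.is_available():
    pinned = pin_memory(batch)
    # Tensors anywhere in the nested structure end up in page-locked memory.
    print(pinned["image"].is_pinned(), pinned["meta"][0].is_pinned())
else:
    print("CUDA not available; leaving the batch unpinned")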
.venv/Lib/site-packages/torch/utils/data/_utils/signal_handling.py ADDED
@@ -0,0 +1,79 @@
1
+ # mypy: allow-untyped-defs
2
+ r"""Signal handling for multiprocessing data loading.
3
+
4
+ NOTE [ Signal handling in multiprocessing data loading ]
5
+
6
+ In cases like DataLoader, if a worker process dies due to bus error/segfault
7
+ or just hang, the main process will hang waiting for data. This is difficult
8
+ to avoid on PyTorch side as it can be caused by limited shm, or other
9
+ libraries users call in the workers. In this file and `DataLoader.cpp`, we make
10
+ our best effort to provide some error message to users when such unfortunate
11
+ events happen.
12
+
13
+ When a _BaseDataLoaderIter starts worker processes, their pids are registered in a
14
+ defined in `DataLoader.cpp`: id(_BaseDataLoaderIter) => Collection[ Worker pids ]
15
+ via `_set_worker_pids`.
16
+
17
+ When an error happens in a worker process, the main process received a SIGCHLD,
18
+ and Python will eventually call the handler registered below
19
+ (in `_set_SIGCHLD_handler`). In the handler, the `_error_if_any_worker_fails`
20
+ call checks all registered worker pids and raise proper error message to
21
+ prevent main process from hanging waiting for data from worker.
22
+
23
+ Additionally, at the beginning of each worker's `_utils.worker._worker_loop`,
24
+ `_set_worker_signal_handlers` is called to register critical signal handlers
25
+ (e.g., for SIGSEGV, SIGBUS, SIGFPE, SIGTERM) in C, which just prints an error
26
+ message to stderr before triggering the default handler. So a message will also
27
+ be printed from the worker process when it is killed by such signals.
28
+
29
+ See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for the reasoning of
30
+ this signal handling design and other mechanism we implement to make our
31
+ multiprocessing data loading robust to errors.
32
+ """
33
+
34
+ import signal
35
+ import threading
36
+
37
+ # Some of the following imported functions are not used in this file, but are to
38
+ # be used `_utils.signal_handling.XXXXX`.
39
+ from torch._C import ( # noqa: F401
40
+ _error_if_any_worker_fails,
41
+ _remove_worker_pids,
42
+ _set_worker_pids,
43
+ _set_worker_signal_handlers,
44
+ )
45
+
46
+ from . import IS_WINDOWS
47
+
48
+
49
+ _SIGCHLD_handler_set = False
50
+ r"""Whether SIGCHLD handler is set for DataLoader worker failures. Only one
51
+ handler needs to be set for all DataLoaders in a process."""
52
+
53
+
54
+ def _set_SIGCHLD_handler():
55
+ # Windows doesn't support SIGCHLD handler
56
+ if IS_WINDOWS:
57
+ return
58
+ # can't set signal in child threads
59
+ if not isinstance(threading.current_thread(), threading._MainThread): # type: ignore[attr-defined]
60
+ return
61
+ global _SIGCHLD_handler_set
62
+ if _SIGCHLD_handler_set:
63
+ return
64
+ previous_handler = signal.getsignal(signal.SIGCHLD)
65
+ if not callable(previous_handler):
66
+ # This doesn't catch default handler, but SIGCHLD default handler is a
67
+ # no-op.
68
+ previous_handler = None
69
+
70
+ def handler(signum, frame):
71
+ # This following call uses `waitid` with WNOHANG from C side. Therefore,
72
+ # Python can still get and update the process status successfully.
73
+ _error_if_any_worker_fails()
74
+ if previous_handler is not None:
75
+ assert callable(previous_handler)
76
+ previous_handler(signum, frame)
77
+
78
+ signal.signal(signal.SIGCHLD, handler)
79
+ _SIGCHLD_handler_set = True
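The handler-chaining pattern used above, shown on a generic POSIX signal: remember the previous handler, do the extra work, then delegate. This only illustrates the pattern; it is not the DataLoader's SIGCHLD handler.

import signal
import sys

def install_chained_handler(signum, extra_check):
    previous = signal.getsignal(signum)
    if not callable(previous):
        # SIG_DFL / SIG_IGN are not callable; nothing to chain to.
        previous = None

    def handler(received_signum, frame):
        extra_check()  # our additional work runs first
        if previous is not None:
            previous(received_signum, frame)  # then delegate to the old handler

    signal.signal(signum, handler)

if sys.platform != "win32":
    install_chained_handler(signal.SIGUSR1, lambda: print("extra check ran"))
    signal.raise_signal(signal.SIGUSR1)  # triggers the chained handler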
.venv/Lib/site-packages/torch/utils/data/_utils/worker.py ADDED
@@ -0,0 +1,376 @@
1
+ # mypy: allow-untyped-defs
2
+ r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers.
3
+
4
+ These **need** to be in global scope since Py2 doesn't support serializing
5
+ static methods.
6
+ """
7
+
8
+ import os
9
+ import queue
10
+ import random
11
+ from dataclasses import dataclass
12
+ from typing import Optional, TYPE_CHECKING, Union
13
+
14
+ import torch
15
+ from torch._utils import ExceptionWrapper
16
+
17
+ from . import HAS_NUMPY, IS_WINDOWS, MP_STATUS_CHECK_INTERVAL, signal_handling
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from torch.utils.data import Dataset
22
+
23
+ if IS_WINDOWS:
24
+ import ctypes
25
+ from ctypes.wintypes import BOOL, DWORD, HANDLE
26
+
27
+ # On Windows, the parent ID of the worker process remains unchanged when the manager process
28
+ is gone, and the only way to check it through the OS is to let the worker have a process handle
29
+ # of the manager and ask if the process status has changed.
30
+ class ManagerWatchdog:
31
+ def __init__(self) -> None:
32
+ self.manager_pid = os.getppid()
33
+
34
+ # mypy cannot detect this code is windows only
35
+ self.kernel32 = ctypes.WinDLL("kernel32", use_last_error=True) # type: ignore[attr-defined]
36
+ self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD)
37
+ self.kernel32.OpenProcess.restype = HANDLE
38
+ self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD)
39
+ self.kernel32.WaitForSingleObject.restype = DWORD
40
+
41
+ # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx
42
+ SYNCHRONIZE = 0x00100000
43
+ self.manager_handle = self.kernel32.OpenProcess(
44
+ SYNCHRONIZE, 0, self.manager_pid
45
+ )
46
+
47
+ if not self.manager_handle:
48
+ raise ctypes.WinError(ctypes.get_last_error()) # type: ignore[attr-defined]
49
+
50
+ self.manager_dead = False
51
+
52
+ def is_alive(self):
53
+ if not self.manager_dead:
54
+ # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
55
+ self.manager_dead = (
56
+ self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0
57
+ )
58
+ return not self.manager_dead
59
+
60
+ else:
61
+
62
+ class ManagerWatchdog: # type: ignore[no-redef]
63
+ def __init__(self) -> None:
64
+ self.manager_pid = os.getppid()
65
+ self.manager_dead = False
66
+
67
+ def is_alive(self):
68
+ if not self.manager_dead:
69
+ self.manager_dead = os.getppid() != self.manager_pid
70
+ return not self.manager_dead
71
+
72
+
73
+ _worker_info: Optional["WorkerInfo"] = None
74
+
75
+
76
+ class WorkerInfo:
77
+ id: int
78
+ num_workers: int
79
+ seed: int
80
+ dataset: "Dataset"
81
+ __initialized = False
82
+
83
+ def __init__(self, **kwargs):
84
+ for k, v in kwargs.items():
85
+ setattr(self, k, v)
86
+ self.__keys = tuple(kwargs.keys())
87
+ self.__initialized = True
88
+
89
+ def __setattr__(self, key, val):
90
+ if self.__initialized:
91
+ raise RuntimeError(
92
+ f"Cannot assign attributes to {self.__class__.__name__} objects"
93
+ )
94
+ return super().__setattr__(key, val)
95
+
96
+ def __repr__(self):
97
+ items = []
98
+ for k in self.__keys:
99
+ items.append(f"{k}={getattr(self, k)}")
100
+ return f"{self.__class__.__name__}({', '.join(items)})"
101
+
102
+
103
+ def get_worker_info() -> Optional[WorkerInfo]:
104
+ r"""Returns the information about the current
105
+ :class:`~torch.utils.data.DataLoader` iterator worker process.
106
+
107
+ When called in a worker, this returns an object guaranteed to have the
108
+ following attributes:
109
+
110
+ * :attr:`id`: the current worker id.
111
+ * :attr:`num_workers`: the total number of workers.
112
+ * :attr:`seed`: the random seed set for the current worker. This value is
113
+ determined by main process RNG and the worker id. See
114
+ :class:`~torch.utils.data.DataLoader`'s documentation for more details.
115
+ * :attr:`dataset`: the copy of the dataset object in **this** process. Note
116
+ that this will be a different object in a different process than the one
117
+ in the main process.
118
+
119
+ When called in the main process, this returns ``None``.
120
+
121
+ .. note::
122
+ When used in a :attr:`worker_init_fn` passed over to
123
+ :class:`~torch.utils.data.DataLoader`, this method can be useful to
124
+ set up each worker process differently, for instance, using ``worker_id``
125
+ to configure the ``dataset`` object to only read a specific fraction of a
126
+ sharded dataset, or use ``seed`` to seed other libraries used in dataset
127
+ code.
128
+ """
129
+ return _worker_info
130
+
131
+
132
+ r"""Dummy class used to signal the end of an IterableDataset"""
133
+
134
+
135
+ @dataclass(frozen=True)
136
+ class _IterableDatasetStopIteration:
137
+ worker_id: int
138
+
139
+
140
+ r"""Dummy class used to resume the fetching when worker reuse is enabled"""
141
+
142
+
143
+ @dataclass(frozen=True)
144
+ class _ResumeIteration:
145
+ seed: Optional[int] = None
146
+
147
+
148
+ # The function `_generate_state` is adapted from `numpy.random.SeedSequence`
149
+ # from https://github.com/numpy/numpy/blob/main/numpy/random/bit_generator.pyx
150
+ # It's MIT licensed, here is the copyright:
151
+
152
+ # Copyright (c) 2015 Melissa E. O'Neill
153
+ # Copyright (c) 2019 NumPy Developers
154
+ #
155
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
156
+ # of this software and associated documentation files (the "Software"), to deal
157
+ # in the Software without restriction, including without limitation the rights
158
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
159
+ # copies of the Software, and to permit persons to whom the Software is
160
+ # furnished to do so, subject to the following conditions:
161
+ #
162
+ # The above copyright notice and this permission notice shall be included in
163
+ # all copies or substantial portions of the Software.
164
+ #
165
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
166
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
167
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
168
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
169
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
170
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
171
+ # SOFTWARE.
172
+
173
+
174
+ # This function generates an array of int32 as the seed for
175
+ # `numpy.random`, in order to prevent state collision due to same
176
+ # seed and algorithm for `numpy.random` and `random` modules.
177
+ # TODO: Implement `SeedSequence` like object for `torch.random`
178
+ def _generate_state(base_seed, worker_id):
179
+ INIT_A = 0x43B0D7E5
180
+ MULT_A = 0x931E8875
181
+ INIT_B = 0x8B51F9DD
182
+ MULT_B = 0x58F38DED
183
+ MIX_MULT_L = 0xCA01F9DD
184
+ MIX_MULT_R = 0x4973F715
185
+ XSHIFT = 4 * 8 // 2
186
+ MASK32 = 0xFFFFFFFF
187
+
188
+ entropy = [worker_id, base_seed & MASK32, base_seed >> 32, 0]
189
+ pool = [0] * 4
190
+
191
+ hash_const_A = INIT_A
192
+
193
+ def hash(value):
194
+ nonlocal hash_const_A
195
+ value = (value ^ hash_const_A) & MASK32
196
+ hash_const_A = (hash_const_A * MULT_A) & MASK32
197
+ value = (value * hash_const_A) & MASK32
198
+ value = (value ^ (value >> XSHIFT)) & MASK32
199
+ return value
200
+
201
+ def mix(x, y):
202
+ result_x = (MIX_MULT_L * x) & MASK32
203
+ result_y = (MIX_MULT_R * y) & MASK32
204
+ result = (result_x - result_y) & MASK32
205
+ result = (result ^ (result >> XSHIFT)) & MASK32
206
+ return result
207
+
208
+ # Add in the entropy to the pool.
209
+ for i in range(len(pool)):
210
+ pool[i] = hash(entropy[i])
211
+
212
+ # Mix all bits together so late bits can affect earlier bits.
213
+ for i_src in range(len(pool)):
214
+ for i_dst in range(len(pool)):
215
+ if i_src != i_dst:
216
+ pool[i_dst] = mix(pool[i_dst], hash(pool[i_src]))
217
+
218
+ hash_const_B = INIT_B
219
+ state = []
220
+ for i_dst in range(4):
221
+ data_val = pool[i_dst]
222
+ data_val = (data_val ^ hash_const_B) & MASK32
223
+ hash_const_B = (hash_const_B * MULT_B) & MASK32
224
+ data_val = (data_val * hash_const_B) & MASK32
225
+ data_val = (data_val ^ (data_val >> XSHIFT)) & MASK32
226
+ state.append(data_val)
227
+ return state
228
+
229
+
230
+ def _worker_loop(
231
+ dataset_kind,
232
+ dataset,
233
+ index_queue,
234
+ data_queue,
235
+ done_event,
236
+ auto_collation,
237
+ collate_fn,
238
+ drop_last,
239
+ base_seed,
240
+ init_fn,
241
+ worker_id,
242
+ num_workers,
243
+ persistent_workers,
244
+ shared_seed,
245
+ ):
246
+ # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
247
+ # logic of this function.
248
+
249
+ try:
250
+ # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
251
+ # module's handlers are executed after Python returns from C low-level
252
+ # handlers, likely when the same fatal signal had already happened
253
+ # again.
254
+ # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers
255
+ signal_handling._set_worker_signal_handlers()
256
+
257
+ torch.multiprocessing._set_thread_name("pt_data_worker")
258
+
259
+ torch.set_num_threads(1)
260
+ seed = base_seed + worker_id
261
+ random.seed(seed)
262
+ torch.manual_seed(seed)
263
+ if HAS_NUMPY:
264
+ np_seed = _generate_state(base_seed, worker_id)
265
+ import numpy as np
266
+
267
+ np.random.seed(np_seed)
268
+
269
+ from torch.utils.data import IterDataPipe
270
+ from torch.utils.data.graph_settings import apply_random_seed
271
+
272
+ shared_rng = torch.Generator()
273
+ if isinstance(dataset, IterDataPipe):
274
+ assert shared_seed is not None
275
+ shared_rng.manual_seed(shared_seed)
276
+ dataset = apply_random_seed(dataset, shared_rng)
277
+
278
+ global _worker_info
279
+ _worker_info = WorkerInfo(
280
+ id=worker_id, num_workers=num_workers, seed=seed, dataset=dataset
281
+ )
282
+
283
+ from torch.utils.data import _DatasetKind
284
+
285
+ init_exception = None
286
+
287
+ try:
288
+ if init_fn is not None:
289
+ init_fn(worker_id)
290
+
291
+ fetcher = _DatasetKind.create_fetcher(
292
+ dataset_kind, dataset, auto_collation, collate_fn, drop_last
293
+ )
294
+ except Exception:
295
+ init_exception = ExceptionWrapper(
296
+ where=f"in DataLoader worker process {worker_id}"
297
+ )
298
+
299
+ # When using Iterable mode, some worker can exit earlier than others due
300
+ # to the IterableDataset behaving differently for different workers.
301
+ # When such things happen, an `_IterableDatasetStopIteration` object is
302
+ # sent over to the main process with the ID of this worker, so that the
303
+ # main process won't send more tasks to this worker, and will send
304
+ # `None` to this worker to properly exit it.
305
+ #
306
+ # Note that we cannot set `done_event` from a worker as it is shared
307
+ # among all processes. Instead, we set the `iteration_end` flag to
308
+ # signify that the iterator is exhausted. When either `done_event` or
309
+ # `iteration_end` is set, we skip all processing step and just wait for
310
+ # `None`.
311
+ iteration_end = False
312
+
313
+ watchdog = ManagerWatchdog()
314
+
315
+ while watchdog.is_alive():
316
+ try:
317
+ r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
318
+ except queue.Empty:
319
+ continue
320
+ if isinstance(r, _ResumeIteration):
321
+ # Acknowledge the main process
322
+ data_queue.put((r, None))
323
+ iteration_end = False
324
+
325
+ if isinstance(dataset, IterDataPipe):
326
+ assert r.seed is not None
327
+ shared_rng.manual_seed(r.seed)
328
+ dataset = apply_random_seed(dataset, shared_rng)
329
+
330
+ # Recreate the fetcher for worker-reuse policy
331
+ fetcher = _DatasetKind.create_fetcher(
332
+ dataset_kind, dataset, auto_collation, collate_fn, drop_last
333
+ )
334
+ continue
335
+ elif r is None:
336
+ # Received the final signal
337
+ assert done_event.is_set() or iteration_end
338
+ break
339
+ elif done_event.is_set() or iteration_end:
340
+ # `done_event` is set. But I haven't received the final signal
341
+ # (None) yet. I will keep continuing until get it, and skip the
342
+ # processing steps.
343
+ continue
344
+ idx, index = r
345
+ data: Union[_IterableDatasetStopIteration, ExceptionWrapper]
346
+ if init_exception is not None:
347
+ data = init_exception
348
+ init_exception = None
349
+ else:
350
+ try:
351
+ data = fetcher.fetch(index) # type: ignore[possibly-undefined]
352
+ except Exception as e:
353
+ if (
354
+ isinstance(e, StopIteration)
355
+ and dataset_kind == _DatasetKind.Iterable
356
+ ):
357
+ data = _IterableDatasetStopIteration(worker_id)
358
+ # Set `iteration_end`
359
+ # (1) to save future `next(...)` calls, and
360
+ # (2) to avoid sending multiple `_IterableDatasetStopIteration`s.
361
+ iteration_end = True
362
+ else:
363
+ # It is important that we don't store exc_info in a variable.
364
+ # `ExceptionWrapper` does the correct thing.
365
+ # See NOTE [ Python Traceback Reference Cycle Problem ]
366
+ data = ExceptionWrapper(
367
+ where=f"in DataLoader worker process {worker_id}"
368
+ )
369
+ data_queue.put((idx, data))
370
+ del data, idx, index, r # save memory
371
+ except KeyboardInterrupt:
372
+ # Main process will raise KeyboardInterrupt anyways.
373
+ pass
374
+ if done_event.is_set():
375
+ data_queue.cancel_join_thread()
376
+ data_queue.close()
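A sketch of the per-worker sharding that the `get_worker_info` docstring above recommends for iterable datasets: each worker restricts itself to a contiguous slice of the range so that workers do not emit duplicate samples. The range bounds and worker count are arbitrary.

import math

from torch.utils.data import DataLoader, IterableDataset, get_worker_info

class RangeDataset(IterableDataset):
    def __init__(self, start, end):
        self.start, self.end = start, end

    def __iter__(self):
        info = get_worker_info()
        if info is None:
            # Single-process loading: yield the full range.
            lo, hi = self.start, self.end
        else:
            # Split the range into num_workers contiguous shards.
            per_worker = math.ceil((self.end - self.start) / info.num_workers)
            lo = self.start + info.id * per_worker
            hi = min(lo + per_worker, self.end)
        yield from range(lo, hi)

if __name__ == "__main__":
    loader = DataLoader(RangeDataset(0, 10), num_workers=2, batch_size=None)
    print(sorted(int(x) for x in loader))  # [0, 1, ..., 9] with no duplicates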
.venv/Lib/site-packages/torch/utils/data/backward_compatibility.py ADDED
@@ -0,0 +1,11 @@
1
+ # mypy: allow-untyped-defs
2
+ from typing_extensions import deprecated as _deprecated
3
+
4
+
5
+ @_deprecated(
6
+ "Usage of `backward_compatibility.worker_init_fn` is deprecated "
7
+ "as `DataLoader` automatically applies sharding in every worker",
8
+ category=FutureWarning,
9
+ )
10
+ def worker_init_fn(worker_id):
11
+ pass
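A short sketch of what the `typing_extensions.deprecated` decorator does at call time: the wrapped function emits the configured warning category when it is invoked. The function name and message below are made up.

import warnings

from typing_extensions import deprecated

@deprecated("use new_helper() instead", category=FutureWarning)
def old_helper():
    return 42

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old_helper()

print(caught[0].category.__name__, caught[0].message)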
.venv/Lib/site-packages/torch/utils/data/dataloader.py ADDED
@@ -0,0 +1,1604 @@
1
+ # mypy: allow-untyped-defs
2
+ r"""Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter.
3
+
4
+ To support these two classes, in `./_utils` we define many utility methods and
5
+ functions to be run in multiprocessing. E.g., the data loading worker loop is
6
+ in `./_utils/worker.py`.
7
+ """
8
+
9
+ import functools
10
+ import itertools
11
+ import logging
12
+ import multiprocessing as python_multiprocessing
13
+ import os
14
+ import queue
15
+ import threading
16
+ import warnings
17
+ from typing import Any, Callable, Generic, Iterable, List, Optional, TypeVar, Union
18
+
19
+ import torch
20
+ import torch.distributed as dist
21
+ import torch.utils.data.graph_settings
22
+ from torch._utils import ExceptionWrapper
23
+ from torch.utils.data import _utils
24
+ from torch.utils.data.datapipes.datapipe import (
25
+ _IterDataPipeSerializationWrapper,
26
+ _MapDataPipeSerializationWrapper,
27
+ IterDataPipe,
28
+ MapDataPipe,
29
+ )
30
+ from torch.utils.data.dataset import Dataset, IterableDataset
31
+ from torch.utils.data.sampler import (
32
+ BatchSampler,
33
+ RandomSampler,
34
+ Sampler,
35
+ SequentialSampler,
36
+ )
37
+
38
+
39
+ __all__ = [
40
+ "DataLoader",
41
+ "get_worker_info",
42
+ "default_collate",
43
+ "default_convert",
44
+ ]
45
+
46
+
47
+ _T = TypeVar("_T")
48
+ _T_co = TypeVar("_T_co", covariant=True)
49
+ _worker_init_fn_t = Callable[[int], None]
50
+
51
+ # Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that
52
+ # type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'.
53
+ # See https://github.com/python/mypy/issues/3737.
54
+ _collate_fn_t = Callable[[List[_T]], Any]
55
+
56
+
57
+ # These functions used to be defined in this file. However, it was moved to
58
+ # _utils/collate.py. Although it is rather hard to access this from user land
59
+ # (one has to explicitly directly `import torch.utils.data.dataloader`), there
60
+ # probably is user code out there using it. This aliasing maintains BC in this
61
+ # aspect.
62
+ default_collate: _collate_fn_t = _utils.collate.default_collate
63
+ default_convert = _utils.collate.default_convert
64
+
65
+ get_worker_info = _utils.worker.get_worker_info
66
+
67
+ logger = logging.getLogger(__name__)
68
+
69
+
70
+ class _DatasetKind:
71
+ Map = 0
72
+ Iterable = 1
73
+
74
+ @staticmethod
75
+ def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
76
+ if kind == _DatasetKind.Map:
77
+ return _utils.fetch._MapDatasetFetcher(
78
+ dataset, auto_collation, collate_fn, drop_last
79
+ )
80
+ else:
81
+ return _utils.fetch._IterableDatasetFetcher(
82
+ dataset, auto_collation, collate_fn, drop_last
83
+ )
84
+
85
+
86
+ class _InfiniteConstantSampler(Sampler):
87
+ r"""Analogous to ``itertools.repeat(None, None)``.
88
+
89
+ Used as sampler for :class:`~torch.utils.data.IterableDataset`.
90
+ """
91
+
92
+ def __iter__(self):
93
+ while True:
94
+ yield None
95
+
96
+
97
+ def _get_distributed_settings():
98
+ if dist.is_available() and dist.is_initialized():
99
+ return dist.get_world_size(), dist.get_rank()
100
+ else:
101
+ return 1, 0
102
+
103
+
104
+ def _sharding_worker_init_fn(worker_init_fn, world_size, rank_id, worker_id):
105
+ global_worker_id = worker_id
106
+ info = torch.utils.data.get_worker_info()
107
+ assert info is not None
108
+ total_workers = info.num_workers
109
+ datapipe = info.dataset
110
+ assert isinstance(datapipe, (IterDataPipe, MapDataPipe))
111
+ # To distribute elements across distributed process evenly, we should shard data on distributed
112
+ # processes first then shard on worker processes
113
+ total_workers *= world_size
114
+ global_worker_id = global_worker_id * world_size + rank_id
115
+ # For BC, use default SHARDING_PRIORITIES
116
+ torch.utils.data.graph_settings.apply_sharding(
117
+ datapipe, total_workers, global_worker_id
118
+ )
119
+ if worker_init_fn is not None:
120
+ worker_init_fn(worker_id)
121
+
122
+
123
+ def _share_dist_seed(generator, pg):
124
+ _shared_seed = torch.empty((), dtype=torch.int64).random_(generator=generator)
125
+ if isinstance(pg, dist.ProcessGroup):
126
+ dist.broadcast(_shared_seed, src=0, group=pg)
127
+ return _shared_seed.item()
128
+
129
+
130
+ class DataLoader(Generic[_T_co]):
131
+ r"""
132
+ Data loader combines a dataset and a sampler, and provides an iterable over the given dataset.
133
+
134
+ The :class:`~torch.utils.data.DataLoader` supports both map-style and
135
+ iterable-style datasets with single- or multi-process loading, customizing
136
+ loading order and optional automatic batching (collation) and memory pinning.
137
+
138
+ See :py:mod:`torch.utils.data` documentation page for more details.
139
+
140
+ Args:
141
+ dataset (Dataset): dataset from which to load the data.
142
+ batch_size (int, optional): how many samples per batch to load
143
+ (default: ``1``).
144
+ shuffle (bool, optional): set to ``True`` to have the data reshuffled
145
+ at every epoch (default: ``False``).
146
+ sampler (Sampler or Iterable, optional): defines the strategy to draw
147
+ samples from the dataset. Can be any ``Iterable`` with ``__len__``
148
+ implemented. If specified, :attr:`shuffle` must not be specified.
149
+ batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but
150
+ returns a batch of indices at a time. Mutually exclusive with
151
+ :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`,
152
+ and :attr:`drop_last`.
153
+ num_workers (int, optional): how many subprocesses to use for data
154
+ loading. ``0`` means that the data will be loaded in the main process.
155
+ (default: ``0``)
156
+ collate_fn (Callable, optional): merges a list of samples to form a
157
+ mini-batch of Tensor(s). Used when using batched loading from a
158
+ map-style dataset.
159
+ pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
160
+ into device/CUDA pinned memory before returning them. If your data elements
161
+ are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
162
+ see the example below.
163
+ drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
164
+ if the dataset size is not divisible by the batch size. If ``False`` and
165
+ the size of dataset is not divisible by the batch size, then the last batch
166
+ will be smaller. (default: ``False``)
167
+ timeout (numeric, optional): if positive, the timeout value for collecting a batch
168
+ from workers. Should always be non-negative. (default: ``0``)
169
+ worker_init_fn (Callable, optional): If not ``None``, this will be called on each
170
+ worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
171
+ input, after seeding and before data loading. (default: ``None``)
172
+ multiprocessing_context (str or multiprocessing.context.BaseContext, optional): If
173
+ ``None``, the default `multiprocessing context`_ of your operating system will
174
+ be used. (default: ``None``)
175
+ generator (torch.Generator, optional): If not ``None``, this RNG will be used
176
+ by RandomSampler to generate random indexes and multiprocessing to generate
177
+ ``base_seed`` for workers. (default: ``None``)
178
+ prefetch_factor (int, optional, keyword-only arg): Number of batches loaded
179
+ in advance by each worker. ``2`` means there will be a total of
180
+ 2 * num_workers batches prefetched across all workers. (default value depends
181
+ on the set value for num_workers. If value of num_workers=0 default is ``None``.
182
+ Otherwise, if value of ``num_workers > 0`` default is ``2``).
183
+ persistent_workers (bool, optional): If ``True``, the data loader will not shut down
184
+ the worker processes after a dataset has been consumed once. This allows to
185
+ maintain the workers `Dataset` instances alive. (default: ``False``)
186
+ pin_memory_device (str, optional): the device to :attr:`pin_memory` to if ``pin_memory`` is
187
+ ``True``.
188
+
189
+
190
+ .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
191
+ cannot be an unpicklable object, e.g., a lambda function. See
192
+ :ref:`multiprocessing-best-practices` on more details related
193
+ to multiprocessing in PyTorch.
194
+
195
+ .. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used.
196
+ When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`,
197
+ it instead returns an estimate based on ``len(dataset) / batch_size``, with proper
198
+ rounding depending on :attr:`drop_last`, regardless of multi-process loading
199
+ configurations. This represents the best guess PyTorch can make because PyTorch
200
+ trusts user :attr:`dataset` code in correctly handling multi-process
201
+ loading to avoid duplicate data.
202
+
203
+ However, if sharding results in multiple workers having incomplete last batches,
204
+ this estimate can still be inaccurate, because (1) an otherwise complete batch can
205
+ be broken into multiple ones and (2) more than one batch worth of samples can be
206
+ dropped when :attr:`drop_last` is set. Unfortunately, PyTorch can not detect such
207
+ cases in general.
208
+
209
+ See `Dataset Types`_ for more details on these two types of datasets and how
210
+ :class:`~torch.utils.data.IterableDataset` interacts with
211
+ `Multi-process data loading`_.
212
+
213
+ .. warning:: See :ref:`reproducibility`, and :ref:`dataloader-workers-random-seed`, and
214
+ :ref:`data-loading-randomness` notes for random seed related questions.
215
+
216
+ .. _multiprocessing context:
217
+ https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
218
+ """
219
+
220
+ dataset: Dataset[_T_co]
221
+ batch_size: Optional[int]
222
+ num_workers: int
223
+ pin_memory: bool
224
+ drop_last: bool
225
+ timeout: float
226
+ sampler: Union[Sampler, Iterable]
227
+ pin_memory_device: str
228
+ prefetch_factor: Optional[int]
229
+ _iterator: Optional["_BaseDataLoaderIter"]
230
+ __initialized = False
231
+
232
+ def __init__(
233
+ self,
234
+ dataset: Dataset[_T_co],
235
+ batch_size: Optional[int] = 1,
236
+ shuffle: Optional[bool] = None,
237
+ sampler: Union[Sampler, Iterable, None] = None,
238
+ batch_sampler: Union[Sampler[List], Iterable[List], None] = None,
239
+ num_workers: int = 0,
240
+ collate_fn: Optional[_collate_fn_t] = None,
241
+ pin_memory: bool = False,
242
+ drop_last: bool = False,
243
+ timeout: float = 0,
244
+ worker_init_fn: Optional[_worker_init_fn_t] = None,
245
+ multiprocessing_context=None,
246
+ generator=None,
247
+ *,
248
+ prefetch_factor: Optional[int] = None,
249
+ persistent_workers: bool = False,
250
+ pin_memory_device: str = "",
251
+ ):
252
+ torch._C._log_api_usage_once("python.data_loader")
253
+
254
+ if num_workers < 0:
255
+ raise ValueError(
256
+ "num_workers option should be non-negative; "
257
+ "use num_workers=0 to disable multiprocessing."
258
+ )
259
+
260
+ if timeout < 0:
261
+ raise ValueError("timeout option should be non-negative")
262
+
263
+ if num_workers == 0 and prefetch_factor is not None:
264
+ raise ValueError(
265
+ "prefetch_factor option could only be specified in multiprocessing."
266
+ "let num_workers > 0 to enable multiprocessing, otherwise set prefetch_factor to None."
267
+ )
268
+ elif num_workers > 0 and prefetch_factor is None:
269
+ prefetch_factor = 2
270
+ elif prefetch_factor is not None and prefetch_factor < 0:
271
+ raise ValueError("prefetch_factor option should be non-negative")
272
+
273
+ if persistent_workers and num_workers == 0:
274
+ raise ValueError("persistent_workers option needs num_workers > 0")
275
+
276
+ self.dataset = dataset
277
+ self.num_workers = num_workers
278
+ self.prefetch_factor = prefetch_factor
279
+ self.pin_memory = pin_memory
280
+ self.pin_memory_device = pin_memory_device
281
+ self.timeout = timeout
282
+ self.worker_init_fn = worker_init_fn
283
+ self.multiprocessing_context = multiprocessing_context
284
+
285
+ # Adds forward compatibilities so classic DataLoader can work with DataPipes:
286
+ # _DataPipeSerializationWrapper container makes it easier to serialize without redefining pickler
287
+ if isinstance(self.dataset, IterDataPipe):
288
+ self.dataset = _IterDataPipeSerializationWrapper(self.dataset)
289
+ elif isinstance(self.dataset, MapDataPipe):
290
+ self.dataset = _MapDataPipeSerializationWrapper(self.dataset)
291
+
292
+ # Arg-check dataset related before checking samplers because we want to
293
+ # tell users that iterable-style datasets are incompatible with custom
294
+ # samplers first, so that they don't learn that this combo doesn't work
295
+ # after spending time fixing the custom sampler errors.
296
+ if isinstance(dataset, IterableDataset):
297
+ self._dataset_kind = _DatasetKind.Iterable
298
+ # NOTE [ Custom Samplers and IterableDataset ]
299
+ #
300
+ # `IterableDataset` does not support custom `batch_sampler` or
301
+ # `sampler` since the key is irrelevant (unless we support
302
+ # generator-style dataset one day...).
303
+ #
304
+ # For `sampler`, we always create a dummy sampler. This is an
305
+ # infinite sampler even when the dataset may have an implemented
306
+ # finite `__len__` because in multi-process data loading, naive
307
+ # settings will return duplicated data (which may be desired), and
308
+ # thus using a sampler with length matching that of dataset will
309
+ # cause data loss (you may have duplicates of the first couple
310
+ # batches, but never see anything afterwards). Therefore,
311
+ # `IterableDataset` always uses an infinite sampler, an instance of
312
+ # `_InfiniteConstantSampler` defined above.
313
+ #
314
+ # A custom `batch_sampler` essentially only controls the batch size.
315
+ # However, it is unclear how useful it would be since an iterable-style
316
+ # dataset can handle that within itself. Moreover, it is pointless
317
+ # in multi-process data loading as the assignment order of batches
318
+ # to workers is an implementation detail so users can not control
319
+ # how to batchify each worker's iterable. Thus, we disable this
320
+ # option. If this turns out to be useful in future, we can re-enable
321
+ # this, and support custom samplers that specify the assignments to
322
+ # specific workers.
323
+ if isinstance(dataset, IterDataPipe):
324
+ if shuffle is not None:
325
+ dataset = torch.utils.data.graph_settings.apply_shuffle_settings(
326
+ dataset, shuffle=shuffle
327
+ )
328
+ # We cannot check `shuffle is not None` here, since previously `shuffle=False` was the default.
329
+ elif shuffle not in {False, None}:
330
+ raise ValueError(
331
+ f"DataLoader with IterableDataset: expected unspecified shuffle option, but got shuffle={shuffle}"
332
+ )
333
+
334
+ if sampler is not None:
335
+ # See NOTE [ Custom Samplers and IterableDataset ]
336
+ raise ValueError(
337
+ f"DataLoader with IterableDataset: expected unspecified sampler option, but got sampler={sampler}"
338
+ )
339
+ elif batch_sampler is not None:
340
+ # See NOTE [ Custom Samplers and IterableDataset ]
341
+ raise ValueError(
342
+ "DataLoader with IterableDataset: expected unspecified "
343
+ f"batch_sampler option, but got batch_sampler={batch_sampler}"
344
+ )
345
+ else:
346
+ shuffle = bool(shuffle)
347
+ self._dataset_kind = _DatasetKind.Map
348
+
349
+ if sampler is not None and shuffle:
350
+ raise ValueError("sampler option is mutually exclusive with " "shuffle")
351
+
352
+ if batch_sampler is not None:
353
+ # auto_collation with custom batch_sampler
354
+ if batch_size != 1 or shuffle or sampler is not None or drop_last:
355
+ raise ValueError(
356
+ "batch_sampler option is mutually exclusive "
357
+ "with batch_size, shuffle, sampler, and "
358
+ "drop_last"
359
+ )
360
+ batch_size = None
361
+ drop_last = False
362
+ elif batch_size is None:
363
+ # no auto_collation
364
+ if drop_last:
365
+ raise ValueError(
366
+ "batch_size=None option disables auto-batching "
367
+ "and is mutually exclusive with drop_last"
368
+ )
369
+
370
+ if sampler is None: # give default samplers
371
+ if self._dataset_kind == _DatasetKind.Iterable:
372
+ # See NOTE [ Custom Samplers and IterableDataset ]
373
+ sampler = _InfiniteConstantSampler()
374
+ else: # map-style
375
+ if shuffle:
376
+ sampler = RandomSampler(dataset, generator=generator) # type: ignore[arg-type]
377
+ else:
378
+ sampler = SequentialSampler(dataset) # type: ignore[arg-type]
379
+
380
+ if batch_size is not None and batch_sampler is None:
381
+ # auto_collation without custom batch_sampler
382
+ batch_sampler = BatchSampler(sampler, batch_size, drop_last)
383
+
384
+ self.batch_size = batch_size
385
+ self.drop_last = drop_last
386
+ self.sampler = sampler
387
+ self.batch_sampler = batch_sampler
388
+ self.generator = generator
389
+
390
+ if collate_fn is None:
391
+ if self._auto_collation:
392
+ collate_fn = _utils.collate.default_collate
393
+ else:
394
+ collate_fn = _utils.collate.default_convert
395
+
396
+ self.collate_fn = collate_fn
397
+ self.persistent_workers = persistent_workers
398
+
399
+ self.__initialized = True
400
+ self._IterableDataset_len_called = (
401
+ None # See NOTE [ IterableDataset and __len__ ]
402
+ )
403
+
404
+ self._iterator = None
405
+
406
+ self.check_worker_number_rationality()
407
+
408
+ torch.set_vital("Dataloader", "enabled", "True") # type: ignore[attr-defined]
409
+
410
+ def _get_iterator(self) -> "_BaseDataLoaderIter":
411
+ if self.num_workers == 0:
412
+ return _SingleProcessDataLoaderIter(self)
413
+ else:
414
+ self.check_worker_number_rationality()
415
+ return _MultiProcessingDataLoaderIter(self)
416
+
417
+ @property
418
+ def multiprocessing_context(self):
419
+ return self.__multiprocessing_context
420
+
421
+ @multiprocessing_context.setter
422
+ def multiprocessing_context(self, multiprocessing_context):
423
+ if multiprocessing_context is not None:
424
+ if self.num_workers > 0:
425
+ if isinstance(multiprocessing_context, str):
426
+ valid_start_methods = torch.multiprocessing.get_all_start_methods()
427
+ if multiprocessing_context not in valid_start_methods:
428
+ raise ValueError(
429
+ "multiprocessing_context option "
430
+ f"should specify a valid start method in {valid_start_methods!r}, but got "
431
+ f"multiprocessing_context={multiprocessing_context!r}"
432
+ )
433
+ multiprocessing_context = torch.multiprocessing.get_context(
434
+ multiprocessing_context
435
+ )
436
+
437
+ if not isinstance(
438
+ multiprocessing_context, python_multiprocessing.context.BaseContext
439
+ ):
440
+ raise TypeError(
441
+ "multiprocessing_context option should be a valid context "
442
+ "object or a string specifying the start method, but got "
443
+ f"multiprocessing_context={multiprocessing_context}"
444
+ )
445
+ else:
446
+ raise ValueError(
447
+ "multiprocessing_context can only be used with "
448
+ "multi-process loading (num_workers > 0), but got "
449
+ f"num_workers={self.num_workers}"
450
+ )
451
+
452
+ self.__multiprocessing_context = multiprocessing_context
453
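For illustration only, the setter above accepts either a start-method name or a context object; a hedged sketch reusing the hypothetical `dataset` from the earlier example:

import multiprocessing as mp
from torch.utils.data import DataLoader

# Pass the start method by name ...
spawn_loader = DataLoader(dataset, num_workers=2, multiprocessing_context="spawn")
# ... or pass a context object directly ("fork" is not available on Windows).
fork_loader = DataLoader(dataset, num_workers=2, multiprocessing_context=mp.get_context("fork"))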
+
454
+ def __setattr__(self, attr, val):
455
+ if self.__initialized and attr in (
456
+ "batch_size",
457
+ "batch_sampler",
458
+ "sampler",
459
+ "drop_last",
460
+ "dataset",
461
+ "persistent_workers",
462
+ ):
463
+ raise ValueError(
464
+ f"{attr} attribute should not be set after {self.__class__.__name__} is initialized"
465
+ )
466
+
467
+ super().__setattr__(attr, val)
468
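A small hedged sketch of what this guard does in practice, reusing the hypothetical `loader` from the first example:

try:
    loader.batch_size = 16  # rejected: these core attributes are frozen after __init__
except ValueError as err:
    print(err)  # -> "batch_size attribute should not be set after DataLoader is initialized"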
+
469
+ # We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
470
+ # since '_BaseDataLoaderIter' references 'DataLoader'.
471
+ def __iter__(self) -> "_BaseDataLoaderIter":
472
+ # When using a single worker, the returned iterator should be
473
+ # created every time to avoid resetting its state.
474
+ # However, in the case of a multi-worker iterator,
475
+ # the iterator is only created once in the lifetime of the
476
+ # DataLoader object so that workers can be reused.
477
+ if self.persistent_workers and self.num_workers > 0:
478
+ if self._iterator is None:
479
+ self._iterator = self._get_iterator()
480
+ else:
481
+ self._iterator._reset(self)
482
+ return self._iterator
483
+ else:
484
+ return self._get_iterator()
485
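To illustrate the reuse behaviour described in the comment above, a hedged sketch (assuming `loader` was built with `persistent_workers=True` and `num_workers > 0`, as in the earlier example):

it_a = iter(loader)
_ = list(it_a)        # first epoch drains the iterator
it_b = iter(loader)   # second epoch: the same iterator object is reset and returned
assert it_a is it_b   # worker processes were not torn down between epochs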
+
486
+ @property
487
+ def _auto_collation(self):
488
+ return self.batch_sampler is not None
489
+
490
+ @property
491
+ def _index_sampler(self):
492
+ # The actual sampler used for generating indices for `_DatasetFetcher`
493
+ # (see _utils/fetch.py) to read data at each time. This would be
494
+ # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
495
+ # We can't change `.sampler` and `.batch_sampler` attributes for BC
496
+ # reasons.
497
+ if self._auto_collation:
498
+ return self.batch_sampler
499
+ else:
500
+ return self.sampler
501
+
502
+ def __len__(self) -> int:
503
+ if self._dataset_kind == _DatasetKind.Iterable:
504
+ # NOTE [ IterableDataset and __len__ ]
505
+ #
506
+ # For `IterableDataset`, `__len__` could be inaccurate when one naively
507
+ # does multi-processing data loading, since the samples will be duplicated.
508
+ # However, no real use case should be actually using that behavior, so
509
+ # it should count as a user error. We should generally trust user
510
+ # code to do the proper thing (e.g., configure each replica differently
511
+ # in `__iter__`), and give us the correct `__len__` if they choose to
512
+ # implement it (this will still throw if the dataset does not implement
513
+ # a `__len__`).
514
+ #
515
+ # To provide a further warning, we track if `__len__` was called on the
516
+ # `DataLoader`, save the returned value in `self._len_called`, and warn
517
+ # if the iterator ends up yielding more than this number of samples.
518
+
519
+ # Cannot statically verify that dataset is Sized
520
+ length = self._IterableDataset_len_called = len(self.dataset) # type: ignore[assignment, arg-type]
521
+ if (
522
+ self.batch_size is not None
523
+ ): # IterableDataset doesn't allow custom sampler or batch_sampler
524
+ from math import ceil
525
+
526
+ if self.drop_last:
527
+ length = length // self.batch_size
528
+ else:
529
+ length = ceil(length / self.batch_size)
530
+ return length
531
+ else:
532
+ return len(self._index_sampler)
533
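A worked, hedged illustration of the length estimate above, using a hypothetical iterable-style dataset that reports `__len__`:

from torch.utils.data import DataLoader, IterableDataset

class TenItems(IterableDataset):  # hypothetical toy dataset
    def __iter__(self):
        return iter(range(10))
    def __len__(self):
        return 10

print(len(DataLoader(TenItems(), batch_size=3)))                  # ceil(10 / 3) == 4
print(len(DataLoader(TenItems(), batch_size=3, drop_last=True)))  # 10 // 3 == 3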
+
534
+ def check_worker_number_rationality(self):
535
+ # This function checks whether the dataloader's worker number is reasonable given the
536
+ # current system's resources. The current rule is that if the number of workers this
537
+ # DataLoader will create is bigger than the number of logical CPUs it is allowed to
538
+ # use, then we emit a warning so the user can pay attention.
539
+ #
540
+ # e.g. Suppose the current system has 2 physical CPUs with 16 cores each, and each core supports 2
541
+ # threads; the total number of logical CPUs is then 2 * 16 * 2 = 64. Let's say the current
542
+ # DataLoader process can use half of them, which is 32; then the reasonable max number of
543
+ # workers initiated from this process is 32.
544
+ # Now, let's say the created DataLoader has num_workers = 40, which is bigger than 32.
545
+ # So the warning message is triggered to notify the user to lower the worker number if
546
+ # necessary.
547
+ #
548
+ #
549
+ # [Note] Please note that this function respects `cpuset` only when os.sched_getaffinity is
550
+ # available (available on most Linux systems, but not on macOS or Windows).
551
+ # When os.sched_getaffinity is not available, os.cpu_count() is called instead, but
552
+ # it doesn't respect cpuset.
553
+ # We don't take threading into account since each worker process is single threaded
554
+ # at this time.
555
+ #
556
+ # We don't set any threading flags (e.g. OMP_NUM_THREADS, MKL_NUM_THREADS, etc.)
557
+ # other than `torch.set_num_threads` to 1 in the worker process. If the functions
558
+ # passed in use 3rd-party modules that rely on those threading flags to determine
559
+ # how many threads to create (e.g. numpy, etc.), then it is the caller's responsibility to
560
+ # set those flags correctly.
561
+ def _create_warning_msg(num_worker_suggest, num_worker_created, cpuset_checked):
562
+ suggested_max_worker_msg = (
563
+ (
564
+ (
565
+ "Our suggested max number of worker in current system is {}{}, which is smaller "
566
+ "than what this DataLoader is going to create."
567
+ ).format(
568
+ num_worker_suggest,
569
+ (
570
+ ""
571
+ if cpuset_checked
572
+ else " (`cpuset` is not taken into account)"
573
+ ),
574
+ )
575
+ )
576
+ if num_worker_suggest is not None
577
+ else (
578
+ "DataLoader is not able to compute a suggested max number of worker in current system."
579
+ )
580
+ )
581
+
582
+ warn_msg = (
583
+ f"This DataLoader will create {num_worker_created} worker processes in total. {suggested_max_worker_msg} "
584
+ "Please be aware that excessive worker creation might get DataLoader running slow or even freeze, "
585
+ "lower the worker number to avoid potential slowness/freeze if necessary."
586
+ )
587
+ return warn_msg
588
+
589
+ if not self.num_workers or self.num_workers == 0:
590
+ return
591
+
592
+ # try to compute a suggested max number of worker based on system's resource
593
+ max_num_worker_suggest = None
594
+ cpuset_checked = False
595
+ if hasattr(os, "sched_getaffinity"):
596
+ try:
597
+ max_num_worker_suggest = len(os.sched_getaffinity(0))
598
+ cpuset_checked = True
599
+ except Exception:
600
+ pass
601
+ if max_num_worker_suggest is None:
602
+ # os.cpu_count() could return Optional[int]
603
+ # get cpu count first and check None in order to satisfy mypy check
604
+ cpu_count = os.cpu_count()
605
+ if cpu_count is not None:
606
+ max_num_worker_suggest = cpu_count
607
+
608
+ if max_num_worker_suggest is None:
609
+ warnings.warn(
610
+ _create_warning_msg(
611
+ max_num_worker_suggest, self.num_workers, cpuset_checked
612
+ )
613
+ )
614
+ return
615
+
616
+ if self.num_workers > max_num_worker_suggest:
617
+ warnings.warn(
618
+ _create_warning_msg(
619
+ max_num_worker_suggest, self.num_workers, cpuset_checked
620
+ )
621
+ )
622
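The affinity-then-cpu_count fallback used above can be reproduced on its own; a hedged sketch with a hypothetical helper name:

import os

def suggested_max_workers():
    # Prefer the cpuset-aware affinity count where the platform provides it.
    if hasattr(os, "sched_getaffinity"):
        try:
            return len(os.sched_getaffinity(0))
        except Exception:
            pass
    # Fall back to the raw logical-CPU count (does not respect cpuset).
    return os.cpu_count()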
+
623
+
624
+ class _BaseDataLoaderIter:
625
+ def __init__(self, loader: DataLoader) -> None:
626
+ self._dataset = loader.dataset
627
+ self._shared_seed = None
628
+ self._pg = None
629
+ if isinstance(self._dataset, IterDataPipe):
630
+ if dist.is_available() and dist.is_initialized():
631
+ self._pg = dist.new_group(backend="gloo")
632
+ self._shared_seed = _share_dist_seed(loader.generator, self._pg)
633
+ shared_rng = torch.Generator()
634
+ shared_rng.manual_seed(self._shared_seed)
635
+ self._dataset = torch.utils.data.graph_settings.apply_random_seed(
636
+ self._dataset, shared_rng
637
+ )
638
+ self._dataset_kind = loader._dataset_kind
639
+ self._IterableDataset_len_called = loader._IterableDataset_len_called
640
+ self._auto_collation = loader._auto_collation
641
+ self._drop_last = loader.drop_last
642
+ self._index_sampler = loader._index_sampler
643
+ self._num_workers = loader.num_workers
644
+ ws, rank = _get_distributed_settings()
645
+ self._world_size = ws
646
+ self._rank = rank
647
+ # For other backends, pin_memory_device needs to be set. If it is not set,
648
+ # the default behaviour is the CUDA device. If pin_memory_device is selected
649
+ # and pin_memory is not set, the default behaviour is False.
650
+ if len(loader.pin_memory_device) == 0:
651
+ self._pin_memory = loader.pin_memory and torch.cuda.is_available()
652
+ self._pin_memory_device = None
653
+ else:
654
+ if not loader.pin_memory:
655
+ warn_msg = (
656
+ "pin memory device is set and pin_memory flag is not used then device pinned memory won't be used"
657
+ "please set pin_memory to true, if you need to use the device pin memory"
658
+ )
659
+ warnings.warn(warn_msg)
660
+
661
+ self._pin_memory = loader.pin_memory
662
+ self._pin_memory_device = loader.pin_memory_device
663
+ self._timeout = loader.timeout
664
+ self._collate_fn = loader.collate_fn
665
+ self._sampler_iter = iter(self._index_sampler)
666
+ self._base_seed = (
667
+ torch.empty((), dtype=torch.int64)
668
+ .random_(generator=loader.generator)
669
+ .item()
670
+ )
671
+ self._persistent_workers = loader.persistent_workers
672
+ self._num_yielded = 0
673
+ self._profile_name = f"enumerate(DataLoader)#{self.__class__.__name__}.__next__"
674
+
675
+ def __iter__(self) -> "_BaseDataLoaderIter":
676
+ return self
677
+
678
+ def _reset(self, loader, first_iter=False):
679
+ self._sampler_iter = iter(self._index_sampler)
680
+ self._num_yielded = 0
681
+ self._IterableDataset_len_called = loader._IterableDataset_len_called
682
+ if isinstance(self._dataset, IterDataPipe):
683
+ self._shared_seed = _share_dist_seed(loader.generator, self._pg)
684
+ shared_rng = torch.Generator()
685
+ shared_rng.manual_seed(self._shared_seed)
686
+ self._dataset = torch.utils.data.graph_settings.apply_random_seed(
687
+ self._dataset, shared_rng
688
+ )
689
+
690
+ def _next_index(self):
691
+ return next(self._sampler_iter) # may raise StopIteration
692
+
693
+ def _next_data(self):
694
+ raise NotImplementedError
695
+
696
+ def __next__(self) -> Any:
697
+ with torch.autograd.profiler.record_function(self._profile_name):
698
+ if self._sampler_iter is None:
699
+ # TODO(https://github.com/pytorch/pytorch/issues/76750)
700
+ self._reset() # type: ignore[call-arg]
701
+ data = self._next_data()
702
+ self._num_yielded += 1
703
+ if (
704
+ self._dataset_kind == _DatasetKind.Iterable
705
+ and self._IterableDataset_len_called is not None
706
+ and self._num_yielded > self._IterableDataset_len_called
707
+ ):
708
+ warn_msg = (
709
+ f"Length of IterableDataset {self._dataset} was reported to be {self._IterableDataset_len_called}"
710
+ f"(when accessing len(dataloader)), but {self._num_yielded} samples have been fetched. "
711
+ )
712
+ if self._num_workers > 0:
713
+ warn_msg += (
714
+ "For multiprocessing data-loading, this could be caused by not properly configuring the "
715
+ "IterableDataset replica at each worker. Please see "
716
+ "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples."
717
+ )
718
+ warnings.warn(warn_msg)
719
+ return data
720
+
721
+ def __len__(self) -> int:
722
+ return len(self._index_sampler)
723
+
724
+ def __getstate__(self):
725
+ # TODO: add limited pickling support for sharing an iterator
726
+ # across multiple threads for HOGWILD.
727
+ # Probably the best way to do this is by moving the sample pushing
728
+ # to a separate thread and then just sharing the data queue
729
+ # but signalling the end is tricky without a non-blocking API
730
+ raise NotImplementedError(f"{self.__class__.__name__} cannot be pickled")
731
+
732
+
733
+ class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
734
+ def __init__(self, loader):
735
+ super().__init__(loader)
736
+ assert self._timeout == 0
737
+ assert self._num_workers == 0
738
+
739
+ # Adds forward compatibilities so classic DataLoader can work with DataPipes:
740
+ # Taking care of distributed sharding
741
+ if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
742
+ # For BC, use default SHARDING_PRIORITIES
743
+ torch.utils.data.graph_settings.apply_sharding(
744
+ self._dataset, self._world_size, self._rank
745
+ )
746
+
747
+ self._dataset_fetcher = _DatasetKind.create_fetcher(
748
+ self._dataset_kind,
749
+ self._dataset,
750
+ self._auto_collation,
751
+ self._collate_fn,
752
+ self._drop_last,
753
+ )
754
+
755
+ def _next_data(self):
756
+ index = self._next_index() # may raise StopIteration
757
+ data = self._dataset_fetcher.fetch(index) # may raise StopIteration
758
+ if self._pin_memory:
759
+ data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
760
+ return data
761
+
762
+
763
+ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
764
+ r"""Iterates once over the DataLoader's dataset, as specified by the sampler."""
765
+
766
+ # NOTE [ Data Loader Multiprocessing Shutdown Logic ]
767
+ #
768
+ # Preliminary:
769
+ #
770
+ # Our data model looks like this (queues are indicated with curly brackets):
771
+ #
772
+ # main process ||
773
+ # | ||
774
+ # {index_queue} ||
775
+ # | ||
776
+ # worker processes || DATA
777
+ # | ||
778
+ # {worker_result_queue} || FLOW
779
+ # | ||
780
+ # pin_memory_thread of main process || DIRECTION
781
+ # | ||
782
+ # {data_queue} ||
783
+ # | ||
784
+ # data output \/
785
+ #
786
+ # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
787
+ # `pin_memory=False`.
788
+ #
789
+ #
790
+ # Terminating multiprocessing logic requires very careful design. In
791
+ # particular, we need to make sure that
792
+ #
793
+ # 1. The iterator gracefully exits the workers when its last reference is
794
+ # gone or it is depleted.
795
+ #
796
+ # In this case, the workers should be gracefully exited because the
797
+ # main process may still need to continue to run, and we want cleaning
798
+ # up code in the workers to be executed (e.g., releasing GPU memory).
799
+ # Naturally, we implement the shutdown logic in `__del__` of
800
+ # DataLoaderIterator.
801
+ #
802
+ # We delay the discussion on the logic in this case until later.
803
+ #
804
+ # 2. The iterator exits the workers when the loader process and/or worker
805
+ # processes exits normally or with error.
806
+ #
807
+ # We set all workers and `pin_memory_thread` to have `daemon=True`.
808
+ #
809
+ # You may ask, why can't we make the workers non-daemonic, and
810
+ # gracefully exit using the same logic as we have in `__del__` when the
811
+ # iterator gets deleted (see 1 above)?
812
+ #
813
+ # First of all, `__del__` is **not** guaranteed to be called when
814
+ # interpreter exits. Even if it is called, by the time it executes,
815
+ # many Python core library resources may already be freed, and even
816
+ # simple things like acquiring an internal lock of a queue may hang.
817
+ # Therefore, in this case, we actually need to prevent `__del__` from
818
+ # being executed, and rely on the automatic termination of daemonic
819
+ # children.
820
+ #
821
+ # Thus, we register an `atexit` hook that sets a global flag
822
+ # `_utils.python_exit_status`. Since `atexit` hooks are executed in the
823
+ # reverse order of registration, we are guaranteed that this flag is
824
+ # set before library resources we use are freed (which, at least in
825
+ # CPython, is done via an `atexit` handler defined in
826
+ # `multiprocessing/util.py`
827
+ # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/util.py#L320-L362
828
+ # registered when an object requiring this mechanism is first
829
+ # created, e.g., `mp.Queue`
830
+ # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/context.py#L100-L103
831
+ # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/queues.py#L29
832
+ # )
833
+ #
834
+ # So in `__del__`, we check if `_utils.python_exit_status` is set or
835
+ # `None` (freed), and perform no-op if so.
836
+ #
837
+ # However, simply letting library clean-up codes run can also be bad,
838
+ # because such codes (i.e., `multiprocessing.util._exit_function()`)
839
+ # include join putting threads for `mp.Queue`, which can be blocking.
840
+ # Hence, the main process putting threads are called with
841
+ # `cancel_join_thread` at creation. See later section
842
+ # [ 3b. A process won't hang when putting into a queue; ]
843
+ # for more details.
844
+ #
845
+ # Here are two example cases where library clean-up codes can run
846
+ # before `__del__` is called:
847
+ #
848
+ # 1. If we hold onto a reference to the iterator, it more often
849
+ # than not tries to do `multiprocessing` library cleaning before
850
+ # clearing the alive referenced objects (https://github.com/pytorch/pytorch/issues/48666)
851
+ # and thus prevents our cleaning-up code from running first.
852
+ #
853
+ # 2. A similar issue araises when a `DataLoader` is used in a subprocess.
854
+ # When a process ends, it shuts the all its daemonic children
855
+ # down with a SIGTERM (instead of joining them without a timeout).
856
+ # Simiarly for threads, but by a different mechanism. This fact,
857
+ # together with a few implementation details of multiprocessing, forces
858
+ # us to make workers daemonic. All of our problems arise when a
859
+ # DataLoader is used in a subprocess, and are caused by multiprocessing
860
+ # code which looks more or less like this:
861
+ #
862
+ # try:
863
+ # your_function_using_a_dataloader()
864
+ # finally:
865
+ # multiprocessing.util._exit_function()
866
+ #
867
+ # The joining/termination mentioned above happens inside
868
+ # `_exit_function()`. Now, if `your_function_using_a_dataloader()`
869
+ # throws, the stack trace stored in the exception will prevent the
870
+ # frame which uses `DataLoaderIter` to be freed. If the frame has any
871
+ # reference to the `DataLoaderIter` (e.g., in a method of the iter),
872
+ # its `__del__`, which starts the shutdown procedure, will not be
873
+ # called. That, in turn, means that workers aren't notified. Attempting
874
+ # to join in `_exit_function` will then result in a hang.
875
+ #
876
+ # For context, `_exit_function` is also registered as an `atexit` call.
877
+ # So it is unclear to me (@ssnl) why this is needed in a finally block.
878
+ # The code dates back to 2008 and there is no comment on the original
879
+ # PEP 371 or patch https://bugs.python.org/issue3050 (containing both
880
+ # the finally block and the `atexit` registration) that explains this.
881
+ #
882
+ #
883
+ # Finally, another choice is to just shutdown workers with logic in 1
884
+ # above whenever we see an error in `next`. This isn't ideal because
885
+ # a. It prevents users from using try-catch to resume data loading.
886
+ # b. It doesn't prevent hanging if users have references to the
887
+ # iterator.
888
+ #
889
+ # 3. All processes exit if any of them die unexpectedly by fatal signals.
890
+ #
891
+ # As shown above, the workers are set as daemonic children of the main
892
+ # process. However, automatic cleaning-up of such child processes only
893
+ # happens if the parent process exits gracefully (e.g., not via fatal
894
+ # signals like SIGKILL). So we must ensure that each process will exit
895
+ # even if the process that should send/receive data to/from it was
897
+ # killed, i.e.,
897
+ #
898
+ # a. A process won't hang when getting from a queue.
899
+ #
900
+ # Even with carefully designed data dependencies (i.e., a `put()`
901
+ # always corresponding to a `get()`), hanging on `get()` can still
902
+ # happen when data in queue is corrupted (e.g., due to
903
+ # `cancel_join_thread` or unexpected exit).
904
+ #
905
+ # For child exit, we set a timeout whenever we try to get data
906
+ # from `data_queue`, and check the workers' status on each timeout
907
+ # and error.
908
+ # See `_DataLoaderiter._get_batch()` and
909
+ # `_DataLoaderiter._try_get_data()` for details.
910
+ #
911
+ # Additionally, for child exit on non-Windows platforms, we also
912
+ # register a SIGCHLD handler (which is supported on Windows) on
913
+ # the main process, which checks if any of the workers fail in the
914
+ # (Python) handler. This is more efficient and faster in detecting
915
+ # worker failures, compared to only using the above mechanism.
916
+ # See `DataLoader.cpp` and `_utils/signal_handling.py` for details.
917
+ #
918
+ # For `.get()` calls where the sender(s) is not the workers, we
919
+ # guard them with timeouts, and check the status of the sender
920
+ # when timeout happens:
921
+ # + in the workers, the `_utils.worker.ManagerWatchdog` class
922
+ # checks the status of the main process.
923
+ # + if `pin_memory=True`, when getting from `pin_memory_thread`,
924
+ # check `pin_memory_thread` status periodically until `.get()`
925
+ # returns or see that `pin_memory_thread` died.
926
+ #
927
+ # b. A process won't hang when putting into a queue;
928
+ #
929
+ # We use `mp.Queue` which has a separate background thread to put
930
+ # objects from an unbounded buffer array. The background thread is
931
+ # daemonic and usually automatically joined when the process
932
+ # *exits*.
933
+ #
934
+ # In case that the receiver has ended abruptly while
935
+ # reading from the pipe, the join will hang forever. The usual
936
+ # solution for this in Python is calling `q.cancel_join_thread`,
937
+ # which prevents automatically joining it when finalizing
938
+ # (exiting).
939
+ #
940
+ # Nonetheless, `cancel_join_thread` must only be called when the
941
+ # queue is **not** going to be read from or write into by another
942
+ # process, because it may hold onto a lock or leave corrupted data
943
+ # in the queue, leading other readers/writers to hang.
944
+ #
945
+ # Hence,
946
+ # + For worker processes, we only do so (for their output
947
+ # queues, i.e., `worker_result_queue`) before exiting.
948
+ # + For `pin_memory_thread`, its output queue `data_queue` is a
949
+ # `queue.Queue` that does blocking `put` if the queue is full.
950
+ # So there is no above problem, but as a result, in
951
+ # `_pin_memory_loop`, we do need to wrap the `put` in a loop
952
+ # that breaks not only upon success, but also when the main
953
+ # process stops reading, i.e., is shutting down.
954
+ # + For loader process, we `cancel_join_thread()` for all
955
+ # `_index_queues` because the whole purpose of workers and
956
+ # `pin_memory_thread` is to serve the loader process. If
957
+ # loader process is already exiting, we don't really care if
958
+ # the queues are corrupted.
959
+ #
960
+ #
961
+ # Now let's get back to 1:
962
+ # how we gracefully exit the workers when the last reference to the
963
+ # iterator is gone.
964
+ #
965
+ # To achieve this, we implement the following logic along with the design
966
+ # choices mentioned above:
967
+ #
968
+ # `workers_done_event`:
969
+ # A `multiprocessing.Event` shared among the main process and all worker
970
+ # processes. This is used to signal the workers that the iterator is
971
+ # shutting down. After it is set, they will not send processed data to
972
+ # queues anymore, and only wait for the final `None` before exiting.
973
+ # `done_event` isn't strictly needed. I.e., we can just check for `None`
974
+ # from the input queue, but it allows us to skip wasting resources
975
+ # processing data if we are already shutting down.
976
+ #
977
+ # `pin_memory_thread_done_event`:
978
+ # A `threading.Event` for a similar purpose to that of
979
+ # `workers_done_event`, but is for the `pin_memory_thread`. The reason
980
+ # that separate events are needed is that `pin_memory_thread` reads from
981
+ # the output queue of the workers. But a worker, upon seeing that
982
+ # `workers_done_event` is set, only wants to see the final `None`, and is
983
+ # not required to flush all data in the output queue (e.g., it may call
984
+ # `cancel_join_thread` on that queue if its `IterableDataset` iterator
985
+ # happens to exhaust coincidentally, which is out of the control of the
986
+ # main process). Thus, since we will exit `pin_memory_thread` before the
987
+ # workers (see below), two separate events are used.
988
+ #
989
+ # NOTE: In short, the protocol is that the main process will set these
990
+ # `done_event`s and then send the corresponding processes/threads a `None`,
991
+ # and that they may exit at any time after receiving the `None`.
992
+ #
993
+ # NOTE: Using `None` as the final signal is valid, since normal data will
994
+ # always be a 2-tuple with the 1st element being the index of the data
995
+ # transferred (different from dataset index/key), and the 2nd being
996
+ # either the dataset key or the data sample (depending on which part
997
+ # of the data model the queue is at).
998
+ #
999
+ # [ worker processes ]
1000
+ # While loader process is alive:
1001
+ # Get from `index_queue`.
1002
+ # If get anything else,
1003
+ # Check `workers_done_event`.
1004
+ # If set, continue to next iteration
1005
+ # i.e., keep getting until see the `None`, then exit.
1006
+ # Otherwise, process data:
1007
+ # If is fetching from an `IterableDataset` and the iterator
1008
+ # is exhausted, send an `_IterableDatasetStopIteration`
1009
+ # object to signal iteration end. The main process, upon
1010
+ # receiving such an object, will send `None` to this
1011
+ # worker and not use the corresponding `index_queue`
1012
+ # anymore.
1013
+ # If timed out,
1014
+ # No matter `workers_done_event` is set (still need to see `None`)
1015
+ # or not, must continue to next iteration.
1016
+ # (outside loop)
1017
+ # If `workers_done_event` is set, (this can be False with `IterableDataset`)
1018
+ # `data_queue.cancel_join_thread()`. (Everything is ending here:
1019
+ # main process won't read from it;
1020
+ # other workers will also call
1021
+ # `cancel_join_thread`.)
1022
+ #
1023
+ # [ pin_memory_thread ]
1024
+ # # No need to check main thread. If this thread is alive, the main loader
1025
+ # # thread must be alive, because this thread is set as daemonic.
1026
+ # While `pin_memory_thread_done_event` is not set:
1027
+ # Get from `worker_result_queue`.
1028
+ # If timed out, continue to get in the next iteration.
1029
+ # Otherwise, process data.
1030
+ # While `pin_memory_thread_done_event` is not set:
1031
+ # Put processed data to `data_queue` (a `queue.Queue` with blocking put)
1032
+ # If timed out, continue to put in the next iteration.
1033
+ # Otherwise, break, i.e., continuing to the outer loop.
1034
+ #
1035
+ # NOTE: we don't check the status of the main thread because
1036
+ # 1. if the process is killed by fatal signal, `pin_memory_thread`
1037
+ # ends.
1038
+ # 2. in other cases, either the cleaning-up in __del__ or the
1039
+ # automatic exit of daemonic thread will take care of it.
1040
+ # This won't busy-wait either because `.get(timeout)` does not
1041
+ # busy-wait.
1042
+ #
1043
+ # [ main process ]
1044
+ # In the DataLoader Iter's `__del__`
1045
+ # b. Exit `pin_memory_thread`
1046
+ # i. Set `pin_memory_thread_done_event`.
1047
+ # ii Put `None` in `worker_result_queue`.
1048
+ # iii. Join the `pin_memory_thread`.
1049
+ # iv. `worker_result_queue.cancel_join_thread()`.
1050
+ #
1051
+ # c. Exit the workers.
1052
+ # i. Set `workers_done_event`.
1053
+ # ii. Put `None` in each worker's `index_queue`.
1054
+ # iii. Join the workers.
1055
+ # iv. Call `.cancel_join_thread()` on each worker's `index_queue`.
1056
+ #
1057
+ # NOTE: (c) is better placed after (b) because it may leave corrupted
1058
+ # data in `worker_result_queue`, which `pin_memory_thread`
1059
+ # reads from, in which case exiting the `pin_memory_thread` can only
1060
+ # happen via timing out, which is slow. Nonetheless, the same thing
1061
+ # happens if a worker is killed by signal at unfortunate times,
1062
+ # but in other cases, we are better off having a non-corrupted
1063
+ # `worker_result_queue` for `pin_memory_thread`.
1064
+ #
1065
+ # NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b)
1066
+ # can be omitted
1067
+ #
1068
+ # NB: `done_event`s aren't strictly needed. E.g., we can just check for
1069
+ # `None` from `index_queue`, but it allows us to skip wasting resources
1070
+ # processing indices already in `index_queue` if we are already shutting
1071
+ # down.
1072
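The `done_event` plus sentinel-`None` protocol sketched in the note above boils down to a few lines; a hedged, self-contained toy version (the queue, event, and worker names here are illustrative, not the DataLoader's own objects):

import multiprocessing as mp

def toy_worker(index_queue, done_event):
    # Keep draining the queue; once done_event is set, skip real work and
    # only wait for the final `None` sentinel before exiting.
    while True:
        item = index_queue.get()
        if item is None:
            break
        if done_event.is_set():
            continue  # shutting down: ignore remaining work, keep looking for None
        # ... process `item` here ...

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    q, done = ctx.Queue(), ctx.Event()
    w = ctx.Process(target=toy_worker, args=(q, done), daemon=True)
    w.start()
    done.set()   # signal shutdown first ...
    q.put(None)  # ... then send the sentinel
    w.join()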
+
1073
+ def __init__(self, loader):
1074
+ super().__init__(loader)
1075
+
1076
+ self._prefetch_factor = loader.prefetch_factor
1077
+
1078
+ assert self._num_workers > 0
1079
+ assert self._prefetch_factor > 0
1080
+
1081
+ if loader.multiprocessing_context is None:
1082
+ multiprocessing_context = torch.multiprocessing
1083
+ else:
1084
+ multiprocessing_context = loader.multiprocessing_context
1085
+
1086
+ self._worker_init_fn = loader.worker_init_fn
1087
+
1088
+ # Adds forward compatibilities so classic DataLoader can work with DataPipes:
1089
+ # Additional worker init function will take care of sharding in MP and Distributed
1090
+ if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
1091
+ self._worker_init_fn = functools.partial(
1092
+ _sharding_worker_init_fn,
1093
+ self._worker_init_fn,
1094
+ self._world_size,
1095
+ self._rank,
1096
+ )
1097
+
1098
+ # No certainty which module multiprocessing_context is
1099
+ self._worker_result_queue = multiprocessing_context.Queue() # type: ignore[var-annotated]
1100
+ self._worker_pids_set = False
1101
+ self._shutdown = False
1102
+ self._workers_done_event = multiprocessing_context.Event()
1103
+
1104
+ self._index_queues = []
1105
+ self._workers = []
1106
+ for i in range(self._num_workers):
1107
+ # No certainty which module multiprocessing_context is
1108
+ index_queue = multiprocessing_context.Queue() # type: ignore[var-annotated]
1109
+ # Need to `cancel_join_thread` here!
1110
+ # See sections (2) and (3b) above.
1111
+ index_queue.cancel_join_thread()
1112
+ w = multiprocessing_context.Process(
1113
+ target=_utils.worker._worker_loop,
1114
+ args=(
1115
+ self._dataset_kind,
1116
+ self._dataset,
1117
+ index_queue,
1118
+ self._worker_result_queue,
1119
+ self._workers_done_event,
1120
+ self._auto_collation,
1121
+ self._collate_fn,
1122
+ self._drop_last,
1123
+ self._base_seed,
1124
+ self._worker_init_fn,
1125
+ i,
1126
+ self._num_workers,
1127
+ self._persistent_workers,
1128
+ self._shared_seed,
1129
+ ),
1130
+ )
1131
+ w.daemon = True
1132
+ # NB: Process.start() actually takes some time as it needs to
1133
+ # start a process and pass the arguments over via a pipe.
1134
+ # Therefore, we only add a worker to the self._workers list after
1135
+ # it has started, so that we do not call .join() if the program dies
1136
+ # before it starts and __del__ tries to join it, which would give:
1137
+ # AssertionError: can only join a started process.
1138
+ w.start()
1139
+ self._index_queues.append(index_queue)
1140
+ self._workers.append(w)
1141
+
1142
+ if self._pin_memory:
1143
+ self._pin_memory_thread_done_event = threading.Event()
1144
+
1145
+ # Queue is not type-annotated
1146
+ self._data_queue = queue.Queue() # type: ignore[var-annotated]
1147
+ if self._pin_memory_device == "xpu":
1148
+ current_device = torch.xpu.current_device() # type: ignore[attr-defined]
1149
+ elif self._pin_memory_device == torch._C._get_privateuse1_backend_name():
1150
+ custom_device_mod = getattr(
1151
+ torch, torch._C._get_privateuse1_backend_name()
1152
+ )
1153
+ current_device = custom_device_mod.current_device()
1154
+ else:
1155
+ current_device = torch.cuda.current_device() # choose cuda for default
1156
+ pin_memory_thread = threading.Thread(
1157
+ target=_utils.pin_memory._pin_memory_loop,
1158
+ args=(
1159
+ self._worker_result_queue,
1160
+ self._data_queue,
1161
+ current_device,
1162
+ self._pin_memory_thread_done_event,
1163
+ self._pin_memory_device,
1164
+ ),
1165
+ )
1166
+ pin_memory_thread.daemon = True
1167
+ pin_memory_thread.start()
1168
+ # Similar to workers (see comment above), we only register
1169
+ # pin_memory_thread once it is started.
1170
+ self._pin_memory_thread = pin_memory_thread
1171
+ else:
1172
+ self._data_queue = self._worker_result_queue # type: ignore[assignment]
1173
+
1174
+ # In some rare cases, persistent workers (daemonic processes)
1175
+ # would be terminated before `__del__` of the iterator is invoked
1176
+ # when the main process exits.
1177
+ # That would cause a failure when pin_memory_thread tries to read
1178
+ # corrupted data from worker_result_queue.
1179
+ # atexit is used to shut down the thread and child processes in the
1180
+ # right sequence before the main process exits.
1181
+ if self._persistent_workers and self._pin_memory:
1182
+ import atexit
1183
+
1184
+ for w in self._workers:
1185
+ atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w)
1186
+
1187
+ # .pid can be None only before process is spawned (not the case, so ignore)
1188
+ _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers)) # type: ignore[misc]
1189
+ _utils.signal_handling._set_SIGCHLD_handler()
1190
+ self._worker_pids_set = True
1191
+ self._reset(loader, first_iter=True)
1192
+
1193
+ def _reset(self, loader, first_iter=False):
1194
+ super()._reset(loader, first_iter)
1195
+ self._send_idx = 0 # idx of the next task to be sent to workers
1196
+ self._rcvd_idx = 0 # idx of the next task to be returned in __next__
1197
+ # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
1198
+ # map: task idx => - (worker_id,) if data isn't fetched (outstanding)
1199
+ # \ (worker_id, data) if data is already fetched (out-of-order)
1200
+ self._task_info = {}
1201
+ self._tasks_outstanding = (
1202
+ 0 # always equal to count(v for v in task_info.values() if len(v) == 1)
1203
+ )
1204
+ # A list of booleans representing whether each worker still has work to
1205
+ # do, i.e., not having exhausted its iterable dataset object. It always
1206
+ # contains all `True`s if not using an iterable-style dataset
1207
+ # (i.e., if kind != Iterable).
1208
+ # Note that this indicates that a worker still has work to do *for this epoch*.
1209
+ # It does not mean that a worker is dead. In case of `_persistent_workers`,
1210
+ # the worker will be reset to available in the next epoch.
1211
+ self._workers_status = [True for i in range(self._num_workers)]
1212
+ # Reset the worker queue cycle so it resumes next epoch at worker 0
1213
+ self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
1214
+ # We resume the prefetching in case it was enabled
1215
+ if not first_iter:
1216
+ for idx in range(self._num_workers):
1217
+ self._index_queues[idx].put(
1218
+ _utils.worker._ResumeIteration(self._shared_seed)
1219
+ )
1220
+ resume_iteration_cnt = self._num_workers
1221
+ while resume_iteration_cnt > 0:
1222
+ return_idx, return_data = self._get_data()
1223
+ if isinstance(return_idx, _utils.worker._ResumeIteration):
1224
+ assert return_data is None
1225
+ resume_iteration_cnt -= 1
1226
+ # prime the prefetch loop
1227
+ for _ in range(self._prefetch_factor * self._num_workers):
1228
+ self._try_put_index()
1229
+
1230
+ def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
1231
+ # Tries to fetch data from `self._data_queue` once for a given timeout.
1232
+ # This can also be used as inner loop of fetching without timeout, with
1233
+ # the sender status as the loop condition.
1234
+ #
1235
+ # This raises a `RuntimeError` if any worker died unexpectedly. This error
1236
+ # can come from either the SIGCHLD handler in `_utils/signal_handling.py`
1237
+ # (only for non-Windows platforms), or the manual check below on errors
1238
+ # and timeouts.
1239
+ #
1240
+ # Returns a 2-tuple:
1241
+ # (bool: whether successfully get data, any: data if successful else None)
1242
+ try:
1243
+ data = self._data_queue.get(timeout=timeout)
1244
+ return (True, data)
1245
+ except Exception as e:
1246
+ # At timeout and error, we manually check whether any worker has
1247
+ # failed. Note that this is the only mechanism for Windows to detect
1248
+ # worker failures.
1249
+ failed_workers = []
1250
+ for worker_id, w in enumerate(self._workers):
1251
+ if self._workers_status[worker_id] and not w.is_alive():
1252
+ failed_workers.append(w)
1253
+ self._mark_worker_as_unavailable(worker_id)
1254
+ if len(failed_workers) > 0:
1255
+ pids_str = ", ".join(str(w.pid) for w in failed_workers)
1256
+ raise RuntimeError(
1257
+ f"DataLoader worker (pid(s) {pids_str}) exited unexpectedly"
1258
+ ) from e
1259
+ if isinstance(e, queue.Empty):
1260
+ return (False, None)
1261
+
1262
+ import errno
1263
+ import tempfile
1264
+
1265
+ try:
1266
+ # Raise an exception if we are this close to the FDs limit.
1267
+ # Apparently, trying to open only one file is not a sufficient
1268
+ # test.
1269
+ # See NOTE [ DataLoader on Linux and open files limit ]
1270
+ fds_limit_margin = 10
1271
+ fs = [tempfile.NamedTemporaryFile() for i in range(fds_limit_margin)]
1272
+ except OSError as e:
1273
+ if e.errno == errno.EMFILE:
1274
+ raise RuntimeError(
1275
+ "Too many open files. Communication with the"
1276
+ " workers is no longer possible. Please increase the"
1277
+ " limit using `ulimit -n` in the shell or change the"
1278
+ " sharing strategy by calling"
1279
+ " `torch.multiprocessing.set_sharing_strategy('file_system')`"
1280
+ " at the beginning of your code"
1281
+ ) from None
1282
+ raise
1283
+
1284
+ # NOTE [ DataLoader on Linux and open files limit ]
1285
+ #
1286
+ # On Linux when DataLoader is used with multiprocessing we pass the data between
1287
+ # the root process and the workers through SHM files. We remove those files from
1288
+ # the filesystem as soon as they are created and keep them alive by
1289
+ # passing around their file descriptors through AF_UNIX sockets. (See
1290
+ # docs/source/multiprocessing.rst and 'Multiprocessing Technical Notes` in
1291
+ # the wiki (https://github.com/pytorch/pytorch/wiki).)
1292
+ #
1293
+ # This sometimes leads us to exceeding the open files limit. When that happens,
1294
+ # and the offending file descriptor is coming over a socket, the `socket` Python
1295
+ # package silently strips the file descriptor from the message, setting only the
1296
+ # `MSG_CTRUNC` flag (which might be a bit misleading since the manpage says that
1297
+ # it _indicates that some control data were discarded due to lack of space in
1298
+ # the buffer for ancillary data_). This might reflect the C implementation of
1299
+ # AF_UNIX sockets.
1300
+ #
1301
+ # This behaviour can be reproduced with the script and instructions at the
1302
+ # bottom of this note.
1303
+ #
1304
+ # When that happens, the standard Python `multiprocessing` (and not
1305
+ # `torch.multiprocessing`) raises a `RuntimeError: received 0 items of ancdata`
1306
+ #
1307
+ # Sometimes, instead of the FD being stripped, you may get an `OSError:
1308
+ # Too many open files`, both in the script below and in DataLoader. However,
1309
+ # this is rare and seems to be nondeterministic.
1310
+ #
1311
+ #
1312
+ # #!/usr/bin/env python3
1313
+ # import sys
1314
+ # import socket
1315
+ # import os
1316
+ # import array
1317
+ # import shutil
1318
+ # import socket
1319
+ #
1320
+ #
1321
+ # if len(sys.argv) != 4:
1322
+ # print("Usage: ", sys.argv[0], " tmp_dirname iteration (send|recv)")
1323
+ # sys.exit(1)
1324
+ #
1325
+ # if __name__ == '__main__':
1326
+ # dirname = sys.argv[1]
1327
+ # sock_path = dirname + "/sock"
1328
+ # iterations = int(sys.argv[2])
1329
+ # def dummy_path(i):
1330
+ # return dirname + "/" + str(i) + ".dummy"
1331
+ #
1332
+ #
1333
+ # if sys.argv[3] == 'send':
1334
+ # while not os.path.exists(sock_path):
1335
+ # pass
1336
+ # client = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
1337
+ # client.connect(sock_path)
1338
+ # for i in range(iterations):
1339
+ # fd = os.open(dummy_path(i), os.O_WRONLY | os.O_CREAT)
1340
+ # ancdata = array.array('i', [fd])
1341
+ # msg = bytes([i % 256])
1342
+ # print("Sending fd ", fd, " (iteration #", i, ")")
1343
+ # client.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, ancdata)])
1344
+ #
1345
+ #
1346
+ # else:
1347
+ # assert sys.argv[3] == 'recv'
1348
+ #
1349
+ # if os.path.exists(dirname):
1350
+ # raise Exception("Directory exists")
1351
+ #
1352
+ # os.mkdir(dirname)
1353
+ #
1354
+ # print("Opening socket...")
1355
+ # server = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
1356
+ # server.bind(sock_path)
1357
+ #
1358
+ # print("Listening...")
1359
+ # for i in range(iterations):
1360
+ # a = array.array('i')
1361
+ # msg, ancdata, flags, addr = server.recvmsg(1, socket.CMSG_SPACE(a.itemsize))
1362
+ # assert(len(ancdata) == 1)
1363
+ # cmsg_level, cmsg_type, cmsg_data = ancdata[0]
1364
+ # a.frombytes(cmsg_data)
1365
+ # print("Received fd ", a[0], " (iteration #", i, ")")
1366
+ #
1367
+ # shutil.rmtree(dirname)
1368
+ #
1369
+ # Steps to reproduce:
1370
+ #
1371
+ # 1. Run two shells and set lower file descriptor limit in the receiving one:
1372
+ # (shell1) ulimit -n 1020
1373
+ # (shell2) ulimit -n 1022
1374
+ #
1375
+ # 2. Run the script above with the `recv` option in the first shell
1376
+ # (shell1) ./test_socket.py sock_tmp 1017 recv
1377
+ #
1378
+ # 3. Run the script with the `send` option in the second shell:
1379
+ # (shell2) ./test_socket.py sock_tmp 1017 send
1380
+
1381
+ def _get_data(self):
1382
+ # Fetches data from `self._data_queue`.
1383
+ #
1384
+ # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
1385
+ # which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)`
1386
+ # in a loop. This is the only mechanism to detect worker failures for
1387
+ # Windows. For other platforms, a SIGCHLD handler is also used for
1388
+ # worker failure detection.
1389
+ #
1390
+ # If `pin_memory=True`, we also need to check if `pin_memory_thread` has
1391
+ # died at timeouts.
1392
+ if self._timeout > 0:
1393
+ success, data = self._try_get_data(self._timeout)
1394
+ if success:
1395
+ return data
1396
+ else:
1397
+ raise RuntimeError(
1398
+ f"DataLoader timed out after {self._timeout} seconds"
1399
+ )
1400
+ elif self._pin_memory:
1401
+ while self._pin_memory_thread.is_alive():
1402
+ success, data = self._try_get_data()
1403
+ if success:
1404
+ return data
1405
+ else:
1406
+ # while condition is false, i.e., pin_memory_thread died.
1407
+ raise RuntimeError("Pin memory thread exited unexpectedly")
1408
+ # In this case, `self._data_queue` is a `queue.Queue`. But we don't
1409
+ # need to call `.task_done()` because we don't use `.join()`.
1410
+ else:
1411
+ while True:
1412
+ success, data = self._try_get_data()
1413
+ if success:
1414
+ return data
1415
+
1416
+ def _next_data(self):
1417
+ while True:
1418
+ # If the worker responsible for `self._rcvd_idx` has already ended
1419
+ # and was unable to fulfill this task (due to exhausting an `IterableDataset`),
1420
+ # we try to advance `self._rcvd_idx` to find the next valid index.
1421
+ #
1422
+ # This part needs to run in the loop because both the `self._get_data()`
1423
+ # call and `_IterableDatasetStopIteration` check below can mark
1424
+ # extra worker(s) as dead.
1425
+ while self._rcvd_idx < self._send_idx:
1426
+ info = self._task_info[self._rcvd_idx]
1427
+ worker_id = info[0]
1428
+ if (
1429
+ len(info) == 2 or self._workers_status[worker_id]
1430
+ ): # has data or is still active
1431
+ break
1432
+ del self._task_info[self._rcvd_idx]
1433
+ self._rcvd_idx += 1
1434
+ else:
1435
+ # no valid `self._rcvd_idx` is found (i.e., didn't break)
1436
+ if not self._persistent_workers:
1437
+ self._shutdown_workers()
1438
+ raise StopIteration
1439
+
1440
+ # Now `self._rcvd_idx` is the batch index we want to fetch
1441
+
1442
+ # Check if the next sample has already been generated
1443
+ if len(self._task_info[self._rcvd_idx]) == 2:
1444
+ data = self._task_info.pop(self._rcvd_idx)[1]
1445
+ return self._process_data(data)
1446
+
1447
+ assert not self._shutdown and self._tasks_outstanding > 0
1448
+ idx, data = self._get_data()
1449
+ self._tasks_outstanding -= 1
1450
+ if self._dataset_kind == _DatasetKind.Iterable:
1451
+ # Check for _IterableDatasetStopIteration
1452
+ if isinstance(data, _utils.worker._IterableDatasetStopIteration):
1453
+ if self._persistent_workers:
1454
+ self._workers_status[data.worker_id] = False
1455
+ else:
1456
+ self._mark_worker_as_unavailable(data.worker_id)
1457
+ self._try_put_index()
1458
+ continue
1459
+
1460
+ if idx != self._rcvd_idx:
1461
+ # store out-of-order samples
1462
+ self._task_info[idx] += (data,)
1463
+ else:
1464
+ del self._task_info[idx]
1465
+ return self._process_data(data)
1466
+
1467
+ def _try_put_index(self):
1468
+ assert self._tasks_outstanding < self._prefetch_factor * self._num_workers
1469
+
1470
+ try:
1471
+ index = self._next_index()
1472
+ except StopIteration:
1473
+ return
1474
+ for _ in range(self._num_workers): # find the next active worker, if any
1475
+ worker_queue_idx = next(self._worker_queue_idx_cycle)
1476
+ if self._workers_status[worker_queue_idx]:
1477
+ break
1478
+ else:
1479
+ # not found (i.e., didn't break)
1480
+ return
1481
+
1482
+ self._index_queues[worker_queue_idx].put((self._send_idx, index)) # type: ignore[possibly-undefined]
1483
+ self._task_info[self._send_idx] = (worker_queue_idx,)
1484
+ self._tasks_outstanding += 1
1485
+ self._send_idx += 1
1486
+
1487
+ def _process_data(self, data):
1488
+ self._rcvd_idx += 1
1489
+ self._try_put_index()
1490
+ if isinstance(data, ExceptionWrapper):
1491
+ data.reraise()
1492
+ return data
1493
+
1494
+ def _mark_worker_as_unavailable(self, worker_id, shutdown=False):
1495
+ # Mark a worker as having finished its work, e.g., due to
1496
+ # exhausting an `IterableDataset`. This should be used only when this
1497
+ # `_MultiProcessingDataLoaderIter` is going to continue running.
1498
+
1499
+ assert self._workers_status[worker_id] or (
1500
+ self._persistent_workers and shutdown
1501
+ )
1502
+
1503
+ # Signal termination to that specific worker.
1504
+ q = self._index_queues[worker_id]
1505
+ # Indicate that no more data will be put on this queue by the current
1506
+ # process.
1507
+ q.put(None)
1508
+
1509
+ # Note that we don't actually join the worker here, nor do we remove the
1510
+ # worker's pid from C side struct because (1) joining may be slow, and
1511
+ # (2) since we don't join, the worker may still raise error, and we
1512
+ # prefer capturing those, rather than ignoring them, even though they
1513
+ # are raised after the worker has finished its job.
1514
+ # Joining is deferred to `_shutdown_workers`, which is called when
1515
+ # all workers finish their jobs (e.g., `IterableDataset` replicas) or
1516
+ # when this iterator is garbage collected.
1517
+
1518
+ self._workers_status[worker_id] = False
1519
+
1520
+ assert self._workers_done_event.is_set() == shutdown
1521
+
1522
+ def _shutdown_workers(self):
1523
+ # Called when shutting down this `_MultiProcessingDataLoaderIter`.
1524
+ # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
1525
+ # the logic of this function.
1526
+ if (
1527
+ _utils is None
1528
+ or _utils.python_exit_status is True
1529
+ or _utils.python_exit_status is None
1530
+ ):
1531
+ # See (2) of the note. If Python is shutting down, do no-op.
1532
+ return
1533
+ # Normal exit when last reference is gone / iterator is depleted.
1534
+ # See (1) and the second half of the note.
1535
+ if not self._shutdown:
1536
+ self._shutdown = True
1537
+ try:
1538
+ # Normal exit when last reference is gone / iterator is depleted.
1539
+ # See (1) and the second half of the note.
1540
+
1541
+ # Exit `pin_memory_thread` first because exiting workers may leave
1542
+ # corrupted data in `worker_result_queue` which `pin_memory_thread`
1543
+ # reads from.
1544
+ if hasattr(self, "_pin_memory_thread"):
1545
+ # Use hasattr in case error happens before we set the attribute.
1546
+ self._pin_memory_thread_done_event.set()
1547
+ # Send something to pin_memory_thread in case it is waiting
1548
+ # so that it can wake up and check `pin_memory_thread_done_event`
1549
+ self._worker_result_queue.put((None, None))
1550
+ self._pin_memory_thread.join()
1551
+ self._worker_result_queue.cancel_join_thread()
1552
+ self._worker_result_queue.close()
1553
+
1554
+ # Exit workers now.
1555
+ self._workers_done_event.set()
1556
+ for worker_id in range(len(self._workers)):
1557
+ # Get number of workers from `len(self._workers)` instead of
1558
+ # `self._num_workers` in case we error before starting all
1559
+ # workers.
1560
+ # If we are using workers_status with persistent_workers,
1561
+ # we have to shut it down because the worker is paused.
1562
+ if self._persistent_workers or self._workers_status[worker_id]:
1563
+ self._mark_worker_as_unavailable(worker_id, shutdown=True)
1564
+ for w in self._workers:
1565
+ # We should be able to join here, but in case anything went
1566
+ # wrong, we set a timeout and if the workers fail to join,
1567
+ # they are killed in the `finally` block.
1568
+ w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
1569
+ for q in self._index_queues:
1570
+ q.cancel_join_thread()
1571
+ q.close()
1572
+ finally:
1573
+ # Even though all this function does is putting into queues that
1574
+ # we have called `cancel_join_thread` on, weird things can
1575
+ # happen when a worker is killed by a signal, e.g., hanging in
1576
+ # `Event.set()`. So we need to guard this with SIGCHLD handler,
1577
+ # and remove pids from the C side data structure only at the
1578
+ # end.
1579
+ #
1580
+ # FIXME: Unfortunately, for Windows, we are missing a worker
1581
+ # error detection mechanism here in this function, as it
1582
+ # doesn't provide a SIGCHLD handler.
1583
+ if self._worker_pids_set:
1584
+ _utils.signal_handling._remove_worker_pids(id(self))
1585
+ self._worker_pids_set = False
1586
+ for w in self._workers:
1587
+ if w.is_alive():
1588
+ # Existing mechanisms try to make the workers exit
1589
+ # peacefully, but in case that we unfortunately reach
1590
+ # here, which we shouldn't, (e.g., pytorch/pytorch#39570),
1591
+ # we kill the worker.
1592
+ w.terminate()
1593
+
1594
+ # staticmethod is used to remove reference to `_MultiProcessingDataLoaderIter`
1595
+ @staticmethod
1596
+ def _clean_up_worker(w):
1597
+ try:
1598
+ w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
1599
+ finally:
1600
+ if w.is_alive():
1601
+ w.terminate()
1602
+
1603
+ def __del__(self):
1604
+ self._shutdown_workers()
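The shutdown path above runs when a multi-process DataLoader iterator is exhausted, deleted, or garbage collected. A minimal sketch of a usage pattern that exercises it (not part of this commit; dataset and sizes are arbitrary):

```python
# Hedged sketch: persistent workers are reused across epochs and are only
# joined/terminated once the iterator is shut down (e.g. when the loader is deleted).
import torch
from torch.utils.data import DataLoader, TensorDataset

if __name__ == "__main__":
    ds = TensorDataset(torch.arange(8).float())
    loader = DataLoader(ds, batch_size=2, num_workers=2, persistent_workers=True)

    for epoch in range(2):
        for (batch,) in loader:   # workers stay alive between epochs
            pass

    del loader                    # triggers _shutdown_workers via __del__
```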
.venv/Lib/site-packages/torch/utils/data/datapipes/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from torch.utils.data.datapipes import dataframe as dataframe, iter as iter, map as map
.venv/Lib/site-packages/torch/utils/data/datapipes/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (287 Bytes). View file
 
.venv/Lib/site-packages/torch/utils/data/datapipes/__pycache__/_decorator.cpython-39.pyc ADDED
Binary file (6.2 kB). View file
 
.venv/Lib/site-packages/torch/utils/data/datapipes/__pycache__/_hook_iterator.cpython-39.pyc ADDED
Binary file (8.45 kB). View file
 
.venv/Lib/site-packages/torch/utils/data/datapipes/__pycache__/_typing.cpython-39.pyc ADDED
Binary file (11.2 kB). View file
 
.venv/Lib/site-packages/torch/utils/data/datapipes/__pycache__/datapipe.cpython-39.pyc ADDED
Binary file (16.7 kB). View file
 
.venv/Lib/site-packages/torch/utils/data/datapipes/_decorator.py ADDED
@@ -0,0 +1,213 @@
1
+ # mypy: allow-untyped-defs
2
+ import inspect
3
+ from functools import wraps
4
+ from typing import Any, Callable, get_type_hints, Optional, Type, Union
5
+
6
+ from torch.utils.data.datapipes._typing import _DataPipeMeta
7
+ from torch.utils.data.datapipes.datapipe import IterDataPipe, MapDataPipe
8
+
9
+
10
+ ######################################################
11
+ # Functional API
12
+ ######################################################
13
+ class functional_datapipe:
14
+ name: str
15
+
16
+ def __init__(self, name: str, enable_df_api_tracing=False) -> None:
17
+ """
18
+ Define a functional datapipe.
19
+
20
+ Args:
21
+ enable_df_api_tracing - if set, any returned DataPipe would accept
22
+ DataFrames API in tracing mode.
23
+ """
24
+ self.name = name
25
+ self.enable_df_api_tracing = enable_df_api_tracing
26
+
27
+ def __call__(self, cls):
28
+ if issubclass(cls, IterDataPipe):
29
+ if isinstance(cls, Type): # type: ignore[arg-type]
30
+ if not isinstance(cls, _DataPipeMeta):
31
+ raise TypeError(
32
+ "`functional_datapipe` can only decorate IterDataPipe"
33
+ )
34
+ # with non_deterministic decorator
35
+ else:
36
+ if not isinstance(cls, non_deterministic) and not (
37
+ hasattr(cls, "__self__")
38
+ and isinstance(cls.__self__, non_deterministic)
39
+ ):
40
+ raise TypeError(
41
+ "`functional_datapipe` can only decorate IterDataPipe"
42
+ )
43
+ IterDataPipe.register_datapipe_as_function(
44
+ self.name, cls, enable_df_api_tracing=self.enable_df_api_tracing
45
+ )
46
+ elif issubclass(cls, MapDataPipe):
47
+ MapDataPipe.register_datapipe_as_function(self.name, cls)
48
+
49
+ return cls
50
+
51
+
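A hedged illustration of the decorator above (not part of this file): registering a custom IterDataPipe under a functional name makes that name callable on any other IterDataPipe. The `"double"` name and the wrapper class are hypothetical.

```python
from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import IterDataPipe
from torch.utils.data.datapipes.iter import IterableWrapper

@functional_datapipe("double")          # hypothetical functional name
class DoublerIterDataPipe(IterDataPipe):
    def __init__(self, source_datapipe):
        self.source_datapipe = source_datapipe

    def __iter__(self):
        for x in self.source_datapipe:
            yield 2 * x

# Every IterDataPipe now exposes `.double()` through the registration above.
print(list(IterableWrapper(range(4)).double()))   # [0, 2, 4, 6]
```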
52
+ ######################################################
53
+ # Determinism
54
+ ######################################################
55
+ _determinism: bool = False
56
+
57
+
58
+ class guaranteed_datapipes_determinism:
59
+ prev: bool
60
+
61
+ def __init__(self) -> None:
62
+ global _determinism
63
+ self.prev = _determinism
64
+ _determinism = True
65
+
66
+ def __enter__(self) -> None:
67
+ pass
68
+
69
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
70
+ global _determinism
71
+ _determinism = self.prev
72
+
73
+
74
+ class non_deterministic:
75
+ cls: Optional[Type[IterDataPipe]] = None
76
+ # TODO: Lambda for picking
77
+ deterministic_fn: Callable[[], bool]
78
+
79
+ def __init__(self, arg: Union[Type[IterDataPipe], Callable[[], bool]]) -> None:
80
+ # 1. Decorator doesn't have any argument
81
+ if isinstance(arg, Type): # type: ignore[arg-type]
82
+ if not issubclass(arg, IterDataPipe): # type: ignore[arg-type]
83
+ raise TypeError(
84
+ "Only `IterDataPipe` can be decorated with `non_deterministic`"
85
+ f", but {arg.__name__} is found"
86
+ )
87
+ self.cls = arg # type: ignore[assignment]
88
+ # 2. Decorator has an argument of a function
89
+ # This class should behave differently given different inputs. Use this
90
+ # function to verify the determinism for each instance.
91
+ # When the function returns True, the instance is non-deterministic. Otherwise,
92
+ # the instance is a deterministic DataPipe.
93
+ elif isinstance(arg, Callable): # type:ignore[arg-type]
94
+ self.deterministic_fn = arg # type: ignore[assignment, misc]
95
+ else:
96
+ raise TypeError(f"{arg} can not be decorated by non_deterministic")
97
+
98
+ def __call__(self, *args, **kwargs):
99
+ global _determinism
100
+ # Decorate IterDataPipe
101
+ if self.cls is not None:
102
+ if _determinism:
103
+ raise TypeError(
104
+ f"{self.cls.__name__} is non-deterministic, but you set 'guaranteed_datapipes_determinism'. "
105
+ "You can turn off determinism for this DataPipe if that is acceptable "
106
+ "for your application"
107
+ )
108
+ return self.cls(*args, **kwargs) # type: ignore[call-arg]
109
+
110
+ # Decorate with a functional argument
111
+ if not (
112
+ isinstance(args[0], type)
113
+ and issubclass(args[0], IterDataPipe) # type: ignore[arg-type]
114
+ ):
115
+ raise TypeError(
116
+ f"Only `IterDataPipe` can be decorated, but {args[0].__name__} is found"
117
+ )
118
+ self.cls = args[0]
119
+ return self.deterministic_wrapper_fn
120
+
121
+ def deterministic_wrapper_fn(self, *args, **kwargs) -> IterDataPipe:
122
+ res = self.deterministic_fn(*args, **kwargs) # type: ignore[call-arg, misc]
123
+ if not isinstance(res, bool):
124
+ raise TypeError(
125
+ "deterministic_fn of `non_deterministic` decorator is required "
126
+ f"to return a boolean value, but {type(res)} is found"
127
+ )
128
+ global _determinism
129
+ if _determinism and res:
130
+ raise TypeError(
131
+ f"{self.cls.__name__} is non-deterministic with the inputs, but you set " # type: ignore[union-attr]
132
+ "'guaranteed_datapipes_determinism'. You can turn off determinism "
133
+ "for this DataPipe if that is acceptable for your application"
134
+ )
135
+ return self.cls(*args, **kwargs) # type: ignore[call-arg, misc]
136
+
137
+
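A hedged sketch of how the two pieces above interact (`RandomishIterDataPipe` is hypothetical): a pipe marked `non_deterministic` can be constructed normally, but construction fails while `guaranteed_datapipes_determinism` is active.

```python
from torch.utils.data.datapipes._decorator import (
    guaranteed_datapipes_determinism,
    non_deterministic,
)
from torch.utils.data.datapipes.datapipe import IterDataPipe

@non_deterministic
class RandomishIterDataPipe(IterDataPipe):
    def __iter__(self):
        yield from (3, 1, 2)          # pretend this order is not reproducible

RandomishIterDataPipe()               # fine: determinism is not requested
with guaranteed_datapipes_determinism():
    try:
        RandomishIterDataPipe()       # TypeError: pipe is declared non-deterministic
    except TypeError as e:
        print(e)
```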
138
+ ######################################################
139
+ # Type validation
140
+ ######################################################
141
+ # Validate each argument of DataPipe with hint as a subtype of the hint.
142
+ def argument_validation(f):
143
+ signature = inspect.signature(f)
144
+ hints = get_type_hints(f)
145
+
146
+ @wraps(f)
147
+ def wrapper(*args, **kwargs):
148
+ bound = signature.bind(*args, **kwargs)
149
+ for argument_name, value in bound.arguments.items():
150
+ if argument_name in hints and isinstance(
151
+ hints[argument_name], _DataPipeMeta
152
+ ):
153
+ hint = hints[argument_name]
154
+ if not isinstance(value, IterDataPipe):
155
+ raise TypeError(
156
+ f"Expected argument '{argument_name}' as a IterDataPipe, but found {type(value)}"
157
+ )
158
+ if not value.type.issubtype(hint.type):
159
+ raise TypeError(
160
+ f"Expected type of argument '{argument_name}' as a subtype of "
161
+ f"hint {hint.type}, but found {value.type}"
162
+ )
163
+
164
+ return f(*args, **kwargs)
165
+
166
+ return wrapper
167
+
168
+
169
+ # Default value is True
170
+ _runtime_validation_enabled: bool = True
171
+
172
+
173
+ class runtime_validation_disabled:
174
+ prev: bool
175
+
176
+ def __init__(self) -> None:
177
+ global _runtime_validation_enabled
178
+ self.prev = _runtime_validation_enabled
179
+ _runtime_validation_enabled = False
180
+
181
+ def __enter__(self) -> None:
182
+ pass
183
+
184
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
185
+ global _runtime_validation_enabled
186
+ _runtime_validation_enabled = self.prev
187
+
188
+
189
+ # Runtime checking
190
+ # Validate output data is subtype of return hint
191
+ def runtime_validation(f):
192
+ # TODO:
193
+ # Can be extended to validate '__getitem__' and nonblocking
194
+ if f.__name__ != "__iter__":
195
+ raise TypeError(
196
+ f"Can not decorate function {f.__name__} with 'runtime_validation'"
197
+ )
198
+
199
+ @wraps(f)
200
+ def wrapper(self):
201
+ global _runtime_validation_enabled
202
+ if not _runtime_validation_enabled:
203
+ yield from f(self)
204
+ else:
205
+ it = f(self)
206
+ for d in it:
207
+ if not self.type.issubtype_of_instance(d):
208
+ raise RuntimeError(
209
+ f"Expected an instance as subtype of {self.type}, but found {d}({type(d)})"
210
+ )
211
+ yield d
212
+
213
+ return wrapper
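A hedged sketch of the runtime validator above. `NumbersIterDataPipe` is hypothetical, and the explicit `type` attribute stands in for the annotation object the typing machinery would normally attach; `runtime_validation` consults it for every yielded element unless `runtime_validation_disabled` is active.

```python
from typing import Iterator
from torch.utils.data.datapipes._decorator import (
    runtime_validation,
    runtime_validation_disabled,
)
from torch.utils.data.datapipes._typing import _DataPipeType
from torch.utils.data.datapipes.datapipe import IterDataPipe

class NumbersIterDataPipe(IterDataPipe):
    type = _DataPipeType(int)          # the annotation the validator checks against

    @runtime_validation
    def __iter__(self) -> Iterator[int]:
        yield from (1, 2, "three")     # the string violates the declared element type

dp = NumbersIterDataPipe()
try:
    list(dp)                           # RuntimeError when "three" is yielded
except RuntimeError as e:
    print(e)

with runtime_validation_disabled():
    print(list(dp))                    # validation skipped: [1, 2, 'three']
```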
.venv/Lib/site-packages/torch/utils/data/datapipes/_hook_iterator.py ADDED
@@ -0,0 +1,279 @@
1
+ # mypy: allow-untyped-defs
2
+ import functools
3
+ import inspect
4
+ from enum import Enum
5
+
6
+ import torch
7
+
8
+
9
+ class _SnapshotState(Enum):
10
+ r"""
11
+ These are the snapshotting-related states that IterDataPipes can be in.
12
+
13
+ `NotStarted` - allows you to restore a snapshot and create an iterator with reset
14
+ `Restored` - cannot restore again, allows you to create an iterator without resetting the DataPipe
15
+ `Iterating` - can restore, will reset if you create a new iterator
16
+ """
17
+
18
+ NotStarted = 0
19
+ Restored = 1
20
+ Iterating = 2
21
+
22
+
23
+ def _simplify_obj_name(obj) -> str:
24
+ """Simplify the display strings of objects for the purpose of rendering within DataPipe error messages."""
25
+ if inspect.isfunction(obj):
26
+ return obj.__name__
27
+ else:
28
+ return repr(obj)
29
+
30
+
31
+ def _strip_datapipe_from_name(name: str) -> str:
32
+ return name.replace("IterDataPipe", "").replace("MapDataPipe", "")
33
+
34
+
35
+ def _generate_input_args_string(obj):
36
+ """Generate a string for the input arguments of an object."""
37
+ signature = inspect.signature(obj.__class__)
38
+ input_param_names = set(signature.parameters.keys())
39
+ result = []
40
+ for name, value in inspect.getmembers(obj):
41
+ if name in input_param_names:
42
+ result.append((name, _simplify_obj_name(value)))
43
+ return ", ".join([f"{name}={value}" for name, value in result])
44
+
45
+
46
+ def _generate_iterdatapipe_msg(datapipe, simplify_dp_name: bool = False):
47
+ output_string = (
48
+ f"{datapipe.__class__.__name__}({_generate_input_args_string(datapipe)})"
49
+ )
50
+ if simplify_dp_name:
51
+ output_string = _strip_datapipe_from_name(output_string)
52
+ return output_string
53
+
54
+
55
+ def _gen_invalid_iterdatapipe_msg(datapipe):
56
+ return (
57
+ "This iterator has been invalidated because another iterator has been created "
58
+ f"from the same IterDataPipe: {_generate_iterdatapipe_msg(datapipe)}\n"
59
+ "This may be caused multiple references to the same IterDataPipe. We recommend "
60
+ "using `.fork()` if that is necessary."
61
+ )
62
+
63
+
64
+ _feedback_msg = (
65
+ "\nFor feedback regarding this single iterator per IterDataPipe constraint, feel free "
66
+ "to comment on this issue: https://github.com/pytorch/data/issues/45."
67
+ )
68
+
69
+
70
+ def _check_iterator_valid(datapipe, iterator_id, next_method_exists=False) -> None:
71
+ r"""
72
+ Given an instance of a DataPipe and an iterator ID, check if the IDs match, and if not, raises an exception.
73
+
74
+ In the case of ChildDataPipe, the ID gets compared to the one stored in `main_datapipe` as well.
75
+ """
76
+ if next_method_exists:
77
+ # This is the case where `IterDataPipe` has both `__iter__` and `__next__`.
78
+ # The `_valid_iterator_id` should either be never set (`None`), or set by at most one
79
+ # iterator (`0`). Otherwise, it means there are multiple iterators.
80
+ if datapipe._valid_iterator_id is not None and datapipe._valid_iterator_id != 0:
81
+ extra_msg = "\nNote that this exception is raised inside your IterDataPipe's a `__next__` method"
82
+ raise RuntimeError(
83
+ _gen_invalid_iterdatapipe_msg(datapipe) + extra_msg + _feedback_msg
84
+ )
85
+ elif (
86
+ hasattr(datapipe, "_is_child_datapipe") and datapipe._is_child_datapipe is True
87
+ ):
88
+ if hasattr(datapipe, "_check_valid_iterator_id"):
89
+ if not datapipe._check_valid_iterator_id(iterator_id):
90
+ raise RuntimeError(
91
+ "This iterator has been invalidated, because a new iterator has been created "
92
+ f"from one of the ChildDataPipes of "
93
+ f"{_generate_iterdatapipe_msg(datapipe.main_datapipe)}."
94
+ + _feedback_msg
95
+ )
96
+ else:
97
+ raise RuntimeError(
98
+ "ChildDataPipe must have method `_check_valid_iterator_id`."
99
+ )
100
+ elif datapipe._valid_iterator_id != iterator_id:
101
+ raise RuntimeError(_gen_invalid_iterdatapipe_msg(datapipe) + _feedback_msg)
102
+
103
+
104
+ def _set_datapipe_valid_iterator_id(datapipe):
105
+ """Given a DataPipe, updates its valid iterator ID and reset the DataPipe."""
106
+ if hasattr(datapipe, "_is_child_datapipe") and datapipe._is_child_datapipe is True:
107
+ if hasattr(datapipe, "_set_main_datapipe_valid_iterator_id"):
108
+ datapipe._set_main_datapipe_valid_iterator_id() # reset() is called within this method when appropriate
109
+ else:
110
+ raise RuntimeError(
111
+ "ChildDataPipe must have method `_set_main_datapipe_valid_iterator_id`."
112
+ )
113
+ else:
114
+ if datapipe._valid_iterator_id is None:
115
+ datapipe._valid_iterator_id = 0
116
+ else:
117
+ datapipe._valid_iterator_id += 1
118
+ datapipe.reset()
119
+ return datapipe._valid_iterator_id
120
+
121
+
122
+ def hook_iterator(namespace):
123
+ r"""
124
+ Define a hook that is applied to all `__iter__` of metaclass `_DataPipeMeta`.
125
+
126
+ This is done for the purpose of profiling and checking if an iterator is still valid.
127
+ """
128
+
129
+ def profiler_record_fn_context(datapipe):
130
+ if not hasattr(datapipe, "_profile_name"):
131
+ datapipe._profile_name = _generate_iterdatapipe_msg(
132
+ datapipe, simplify_dp_name=True
133
+ )
134
+ return torch.autograd.profiler.record_function(datapipe._profile_name)
135
+
136
+ class IteratorDecorator:
137
+ r"""
138
+ Wrap the iterator and modify its `__next__` method.
139
+
140
+ This decorator is applied to DataPipes of which `__iter__` method is NOT a generator function.
141
+ Such `__iter__` methods commonly return `self`, but not necessarily.
142
+ """
143
+
144
+ def __init__(self, iterator, datapipe, iterator_id, has_next_method):
145
+ self.iterator = iterator
146
+ self.datapipe = datapipe
147
+ self.iterator_id = iterator_id
148
+ self._profiler_enabled = torch.autograd._profiler_enabled()
149
+ # Check if `__iter__` returns `self` and `DataPipe` has `__next__`
150
+ self.self_and_has_next_method = (
151
+ self.iterator is self.datapipe and has_next_method
152
+ )
153
+
154
+ def __iter__(self):
155
+ return self
156
+
157
+ def _get_next(self):
158
+ """Return next with logic related to iterator validity, profiler, and incrementation of samples yielded."""
159
+ _check_iterator_valid(self.datapipe, self.iterator_id)
160
+ result = next(self.iterator)
161
+ if not self.self_and_has_next_method:
162
+ self.datapipe._number_of_samples_yielded += 1
163
+ return result
164
+
165
+ def __next__(self):
166
+ # TODO: Add try-except to in-place reduce traceback from the Exception
167
+ # See: https://github.com/pytorch/data/issues/284
168
+ if self._profiler_enabled:
169
+ with profiler_record_fn_context(self.datapipe):
170
+ return self._get_next()
171
+ else: # Decided against using `contextlib.nullcontext` for performance reasons
172
+ return self._get_next()
173
+
174
+ def __getattr__(self, name):
175
+ return getattr(self.iterator, name)
176
+
177
+ func = namespace["__iter__"]
178
+
179
+ # ``__iter__`` of IterDataPipe is a generator function
180
+ if inspect.isgeneratorfunction(func):
181
+
182
+ @functools.wraps(func)
183
+ def wrap_generator(*args, **kwargs):
184
+ gen = func(*args, **kwargs)
185
+ datapipe = args[0]
186
+ if datapipe._fast_forward_iterator:
187
+ it = datapipe._fast_forward_iterator
188
+ datapipe._fast_forward_iterator = None
189
+ datapipe._snapshot_state = _SnapshotState.Iterating
190
+ while True:
191
+ try:
192
+ yield next(it)
193
+ except StopIteration:
194
+ return
195
+ iterator_id = _set_datapipe_valid_iterator_id(
196
+ datapipe
197
+ ) # This ID is tied to each created iterator
198
+ _profiler_enabled = torch.autograd._profiler_enabled()
199
+ try:
200
+ if _profiler_enabled:
201
+ with profiler_record_fn_context(datapipe):
202
+ response = gen.send(None)
203
+ else:
204
+ response = gen.send(None)
205
+
206
+ while True:
207
+ datapipe._number_of_samples_yielded += 1
208
+ request = yield response
209
+ # Pass through here every time `__next__` is called
210
+ if _profiler_enabled:
211
+ with profiler_record_fn_context(datapipe):
212
+ _check_iterator_valid(datapipe, iterator_id)
213
+ response = gen.send(request)
214
+ else: # Decided against using `contextlib.nullcontext` for performance reasons
215
+ _check_iterator_valid(datapipe, iterator_id)
216
+ response = gen.send(request)
217
+ except StopIteration as e:
218
+ return
219
+ except Exception as e:
220
+ # TODO: Simplify the traceback message to skip over `response = gen.send(None)`
221
+ # Part of https://github.com/pytorch/data/issues/284
222
+ datapipe = args[0]
223
+ msg = "thrown by __iter__ of"
224
+ single_iterator_msg = "single iterator per IterDataPipe constraint"
225
+ if hasattr(e.args, "__len__"):
226
+ full_msg = f"{msg} {datapipe.__class__.__name__}({_generate_input_args_string(datapipe)})"
227
+ if len(e.args) == 0 or not isinstance(
228
+ e.args[0], str
229
+ ): # If an exception message doesn't exist
230
+ e.args = (f"\nThis exception is {full_msg}",)
231
+ elif msg not in e.args[0] and single_iterator_msg not in e.args[0]:
232
+ e.args = (
233
+ e.args[0] + f"\nThis exception is {full_msg}",
234
+ ) + e.args[1:]
235
+ raise
236
+
237
+ namespace["__iter__"] = wrap_generator
238
+ else: # ``__iter__`` of IterDataPipe is NOT a generator function
239
+ # IterDataPipe is an iterator with both ``__iter__`` and ``__next__``
240
+ # And ``__iter__`` may or may not return `self`
241
+ if "__next__" in namespace: # If `__next__` exists, put a wrapper around it
242
+ next_func = namespace["__next__"]
243
+
244
+ @functools.wraps(next_func)
245
+ def wrap_next(*args, **kwargs):
246
+ datapipe = args[0]
247
+ if torch.autograd._profiler_enabled():
248
+ with profiler_record_fn_context(datapipe):
249
+ result = next_func(*args, **kwargs)
250
+ else:
251
+ result = next_func(*args, **kwargs)
252
+ datapipe._number_of_samples_yielded += 1
253
+ return result
254
+
255
+ namespace["__next__"] = wrap_next
256
+
257
+ # Note that if `__next__` and `__iter__` do something completely unrelated, it may cause issues, but
258
+ # the user will be violating the iterator protocol. Potential issue:
259
+ # 1. Valid iterator ID may not update or checked properly
260
+ # 2. The number of samples yielded will be miscounted
261
+
262
+ # Regardless if `__next__` exists or not, `__iter__` needs a wrapper to track the number of valid iterators
263
+ @functools.wraps(func)
264
+ def wrap_iter(*args, **kwargs):
265
+ iter_ret = func(*args, **kwargs)
266
+ datapipe = args[0]
267
+ datapipe._snapshot_state = _SnapshotState.Iterating
268
+ if datapipe._fast_forward_iterator:
269
+ iter_ret = datapipe._fast_forward_iterator
270
+ datapipe._fast_forward_iterator = None
271
+ return iter_ret
272
+ iterator_id = _set_datapipe_valid_iterator_id(
273
+ datapipe
274
+ ) # This ID is tied to each created iterator
275
+ return IteratorDecorator(
276
+ iter_ret, datapipe, iterator_id, "__next__" in namespace
277
+ )
278
+
279
+ namespace["__iter__"] = wrap_iter
.venv/Lib/site-packages/torch/utils/data/datapipes/_typing.py ADDED
@@ -0,0 +1,486 @@
1
+ # mypy: allow-untyped-defs
2
+ # Taking reference from official Python typing
3
+ # https://github.com/python/cpython/blob/master/Lib/typing.py
4
+
5
+ import collections
6
+ import functools
7
+ import numbers
8
+ import sys
9
+
10
+ # Please check [Note: TypeMeta and TypeAlias]
11
+ # In case of metaclass conflict due to ABCMeta or _ProtocolMeta
12
+ # For Python 3.9, only Protocol in typing uses metaclass
13
+ from abc import ABCMeta
14
+
15
+ # TODO: Use TypeAlias when Python 3.6 is deprecated
16
+ from typing import ( # type: ignore[attr-defined]
17
+ _eval_type,
18
+ _GenericAlias,
19
+ _tp_cache,
20
+ _type_check,
21
+ _type_repr,
22
+ Any,
23
+ Dict,
24
+ ForwardRef,
25
+ Generic,
26
+ get_type_hints,
27
+ Iterator,
28
+ List,
29
+ Set,
30
+ Tuple,
31
+ TypeVar,
32
+ Union,
33
+ )
34
+
35
+ from torch.utils.data.datapipes._hook_iterator import _SnapshotState, hook_iterator
36
+
37
+
38
+ class GenericMeta(ABCMeta): # type: ignore[no-redef]
39
+ pass
40
+
41
+
42
+ class Integer(numbers.Integral):
43
+ pass
44
+
45
+
46
+ class Boolean(numbers.Integral):
47
+ pass
48
+
49
+
50
+ # Python 'type' object is not subscriptable
51
+ # Tuple[int, List, dict] -> valid
52
+ # tuple[int, list, dict] -> invalid
53
+ # Map Python 'type' to abstract base class
54
+ TYPE2ABC = {
55
+ bool: Boolean,
56
+ int: Integer,
57
+ float: numbers.Real,
58
+ complex: numbers.Complex,
59
+ dict: Dict,
60
+ list: List,
61
+ set: Set,
62
+ tuple: Tuple,
63
+ None: type(None),
64
+ }
65
+
66
+
67
+ def issubtype(left, right, recursive=True):
68
+ r"""
69
+ Check if the left-side type is a subtype of the right-side type.
70
+
71
+ If either type is a composite type like `Union` or a `TypeVar` with
72
+ bounds, it is expanded into a list of types, and the check requires all
73
+ left-side types to be subtypes of at least one right-side type.
74
+ """
75
+ left = TYPE2ABC.get(left, left)
76
+ right = TYPE2ABC.get(right, right)
77
+
78
+ if right is Any or left == right:
79
+ return True
80
+
81
+ if isinstance(right, _GenericAlias):
82
+ if getattr(right, "__origin__", None) is Generic:
83
+ return True
84
+
85
+ if right == type(None):
86
+ return False
87
+
88
+ # Right-side type
89
+ constraints = _decompose_type(right)
90
+
91
+ if len(constraints) == 0 or Any in constraints:
92
+ return True
93
+
94
+ if left is Any:
95
+ return False
96
+
97
+ # Left-side type
98
+ variants = _decompose_type(left)
99
+
100
+ # all() will return True for empty variants
101
+ if len(variants) == 0:
102
+ return False
103
+
104
+ return all(
105
+ _issubtype_with_constraints(variant, constraints, recursive)
106
+ for variant in variants
107
+ )
108
+
109
+
110
+ def _decompose_type(t, to_list=True):
111
+ if isinstance(t, TypeVar):
112
+ if t.__bound__ is not None:
113
+ ts = [t.__bound__]
114
+ else:
115
+ # For T_co, __constraints__ is ()
116
+ ts = list(t.__constraints__)
117
+ elif hasattr(t, "__origin__") and t.__origin__ == Union:
118
+ ts = t.__args__
119
+ else:
120
+ if not to_list:
121
+ return None
122
+ ts = [t]
123
+ # Ignored: Generator has incompatible item type "object"; expected "Type[Any]"
124
+ ts = [TYPE2ABC.get(_t, _t) for _t in ts] # type: ignore[misc]
125
+ return ts
126
+
127
+
128
+ def _issubtype_with_constraints(variant, constraints, recursive=True):
129
+ r"""
130
+ Check if the variant is a subtype of either one from constraints.
131
+
132
+ For composite types like `Union` and `TypeVar` with bounds, they
133
+ would be expanded for testing.
134
+ """
135
+ if variant in constraints:
136
+ return True
137
+
138
+ # [Note: Subtype for Union and TypeVar]
139
+ # Python typing is able to flatten Union[Union[...]] or Union[TypeVar].
140
+ # But it couldn't flatten the following scenarios:
141
+ # - Union[int, TypeVar[Union[...]]]
142
+ # - TypeVar[TypeVar[...]]
143
+ # So, variant and each constraint may be a TypeVar or a Union.
144
+ # In these cases, all of inner types from the variant are required to be
145
+ # extracted and verified as a subtype of any constraint. And, all of
146
+ # inner types from any constraint being a TypeVar or a Union are
147
+ # also required to be extracted and verified if the variant belongs to
148
+ # any of them.
149
+
150
+ # Variant
151
+ vs = _decompose_type(variant, to_list=False)
152
+
153
+ # Variant is TypeVar or Union
154
+ if vs is not None:
155
+ return all(_issubtype_with_constraints(v, constraints, recursive) for v in vs)
156
+
157
+ # Variant is not TypeVar or Union
158
+ if hasattr(variant, "__origin__") and variant.__origin__ is not None:
159
+ v_origin = variant.__origin__
160
+ # In Python-3.9 typing library untyped generics do not have args
161
+ v_args = getattr(variant, "__args__", None)
162
+ else:
163
+ v_origin = variant
164
+ v_args = None
165
+
166
+ # Constraints
167
+ for constraint in constraints:
168
+ cs = _decompose_type(constraint, to_list=False)
169
+
170
+ # Constraint is TypeVar or Union
171
+ if cs is not None:
172
+ if _issubtype_with_constraints(variant, cs, recursive):
173
+ return True
174
+ # Constraint is not TypeVar or Union
175
+ else:
176
+ # __origin__ can be None for plain list, tuple, ... in Python 3.6
177
+ if hasattr(constraint, "__origin__") and constraint.__origin__ is not None:
178
+ c_origin = constraint.__origin__
179
+ if v_origin == c_origin:
180
+ if not recursive:
181
+ return True
182
+ # In Python-3.9 typing library untyped generics do not have args
183
+ c_args = getattr(constraint, "__args__", None)
184
+ if c_args is None or len(c_args) == 0:
185
+ return True
186
+ if (
187
+ v_args is not None
188
+ and len(v_args) == len(c_args)
189
+ and all(
190
+ issubtype(v_arg, c_arg)
191
+ for v_arg, c_arg in zip(v_args, c_args)
192
+ )
193
+ ):
194
+ return True
195
+ # Tuple[int] -> Tuple
196
+ else:
197
+ if v_origin == constraint:
198
+ return True
199
+
200
+ return False
201
+
202
+
203
+ def issubinstance(data, data_type):
204
+ if not issubtype(type(data), data_type, recursive=False):
205
+ return False
206
+
207
+ # In Python-3.9 typing library __args__ attribute is not defined for untyped generics
208
+ dt_args = getattr(data_type, "__args__", None)
209
+ if isinstance(data, tuple):
210
+ if dt_args is None or len(dt_args) == 0:
211
+ return True
212
+ if len(dt_args) != len(data):
213
+ return False
214
+ return all(issubinstance(d, t) for d, t in zip(data, dt_args))
215
+ elif isinstance(data, (list, set)):
216
+ if dt_args is None or len(dt_args) == 0:
217
+ return True
218
+ t = dt_args[0]
219
+ return all(issubinstance(d, t) for d in data)
220
+ elif isinstance(data, dict):
221
+ if dt_args is None or len(dt_args) == 0:
222
+ return True
223
+ kt, vt = dt_args
224
+ return all(
225
+ issubinstance(k, kt) and issubinstance(v, vt) for k, v in data.items()
226
+ )
227
+
228
+ return True
229
+
230
+
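A hedged sketch of the two helpers above in isolation; they are plain functions, so they can be exercised without constructing any DataPipe.

```python
from typing import List, Union
from torch.utils.data.datapipes._typing import issubinstance, issubtype

print(issubtype(List[int], List[Union[int, float]]))   # True
print(issubtype(List[str], List[int]))                 # False
print(issubinstance([1, 2, 3], List[int]))             # True
print(issubinstance([1, "a"], List[int]))              # False
```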
231
+ # [Note: TypeMeta and TypeAlias]
232
+ # In order to keep compatibility for Python 3.6, use Meta for the typing.
233
+ # TODO: When PyTorch drops the support for Python 3.6, it can be converted
234
+ # into the Alias system and using `__class_getitem__` for DataPipe. The
235
+ # typing system will gain benefit of performance and resolving metaclass
236
+ # conflicts as elaborated in https://www.python.org/dev/peps/pep-0560/
237
+
238
+
239
+ class _DataPipeType:
240
+ r"""Save type annotation in `param`."""
241
+
242
+ def __init__(self, param):
243
+ self.param = param
244
+
245
+ def __repr__(self):
246
+ return _type_repr(self.param)
247
+
248
+ def __eq__(self, other):
249
+ if isinstance(other, _DataPipeType):
250
+ return self.param == other.param
251
+ return NotImplemented
252
+
253
+ def __hash__(self):
254
+ return hash(self.param)
255
+
256
+ def issubtype(self, other):
257
+ if isinstance(other.param, _GenericAlias):
258
+ if getattr(other.param, "__origin__", None) is Generic:
259
+ return True
260
+ if isinstance(other, _DataPipeType):
261
+ return issubtype(self.param, other.param)
262
+ if isinstance(other, type):
263
+ return issubtype(self.param, other)
264
+ raise TypeError(f"Expected '_DataPipeType' or 'type', but found {type(other)}")
265
+
266
+ def issubtype_of_instance(self, other):
267
+ return issubinstance(other, self.param)
268
+
269
+
270
+ # Default type for DataPipe without annotation
271
+ _T_co = TypeVar("_T_co", covariant=True)
272
+ _DEFAULT_TYPE = _DataPipeType(Generic[_T_co])
273
+
274
+
275
+ class _DataPipeMeta(GenericMeta):
276
+ r"""
277
+ Metaclass for `DataPipe`.
278
+
279
+ Add `type` attribute and `__init_subclass__` based on the type, and validate the return hint of `__iter__`.
280
+
281
+ Note that there is subclass `_IterDataPipeMeta` specifically for `IterDataPipe`.
282
+ """
283
+
284
+ type: _DataPipeType
285
+
286
+ def __new__(cls, name, bases, namespace, **kwargs):
287
+ return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload]
288
+
289
+ # TODO: the statements below are not reachable by design as there is a bug and typing is low priority for now.
290
+ cls.__origin__ = None
291
+ if "type" in namespace:
292
+ return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload]
293
+
294
+ namespace["__type_class__"] = False
295
+ # For plain derived class without annotation
296
+ for base in bases:
297
+ if isinstance(base, _DataPipeMeta):
298
+ return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload]
299
+
300
+ namespace.update(
301
+ {"type": _DEFAULT_TYPE, "__init_subclass__": _dp_init_subclass}
302
+ )
303
+ return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload]
304
+
305
+ def __init__(self, name, bases, namespace, **kwargs):
306
+ super().__init__(name, bases, namespace, **kwargs) # type: ignore[call-overload]
307
+
308
+ # TODO: Fix isinstance bug
309
+ @_tp_cache
310
+ def _getitem_(self, params):
311
+ if params is None:
312
+ raise TypeError(f"{self.__name__}[t]: t can not be None")
313
+ if isinstance(params, str):
314
+ params = ForwardRef(params)
315
+ if not isinstance(params, tuple):
316
+ params = (params,)
317
+
318
+ msg = f"{self.__name__}[t]: t must be a type"
319
+ params = tuple(_type_check(p, msg) for p in params)
320
+
321
+ if isinstance(self.type.param, _GenericAlias):
322
+ orig = getattr(self.type.param, "__origin__", None)
323
+ if isinstance(orig, type) and orig is not Generic:
324
+ p = self.type.param[params] # type: ignore[index]
325
+ t = _DataPipeType(p)
326
+ l = len(str(self.type)) + 2
327
+ name = self.__name__[:-l]
328
+ name = name + "[" + str(t) + "]"
329
+ bases = (self,) + self.__bases__
330
+ return self.__class__(
331
+ name,
332
+ bases,
333
+ {
334
+ "__init_subclass__": _dp_init_subclass,
335
+ "type": t,
336
+ "__type_class__": True,
337
+ },
338
+ )
339
+
340
+ if len(params) > 1:
341
+ raise TypeError(
342
+ f"Too many parameters for {self} actual {len(params)}, expected 1"
343
+ )
344
+
345
+ t = _DataPipeType(params[0])
346
+
347
+ if not t.issubtype(self.type):
348
+ raise TypeError(
349
+ f"Can not subclass a DataPipe[{t}] from DataPipe[{self.type}]"
350
+ )
351
+
352
+ # Types are equal, fast path for inheritance
353
+ if self.type == t:
354
+ return self
355
+
356
+ name = self.__name__ + "[" + str(t) + "]"
357
+ bases = (self,) + self.__bases__
358
+
359
+ return self.__class__(
360
+ name,
361
+ bases,
362
+ {"__init_subclass__": _dp_init_subclass, "__type_class__": True, "type": t},
363
+ )
364
+
365
+ # TODO: Fix isinstance bug
366
+ def _eq_(self, other):
367
+ if not isinstance(other, _DataPipeMeta):
368
+ return NotImplemented
369
+ if self.__origin__ is None or other.__origin__ is None: # type: ignore[has-type]
370
+ return self is other
371
+ return (
372
+ self.__origin__ == other.__origin__ # type: ignore[has-type]
373
+ and self.type == other.type
374
+ )
375
+
376
+ # TODO: Fix isinstance bug
377
+ def _hash_(self):
378
+ return hash((self.__name__, self.type))
379
+
380
+
381
+ class _IterDataPipeMeta(_DataPipeMeta):
382
+ r"""
383
+ Metaclass for `IterDataPipe` and inherits from `_DataPipeMeta`.
384
+
385
+ Add various functions for behaviors specific to `IterDataPipe`.
386
+ """
387
+
388
+ def __new__(cls, name, bases, namespace, **kwargs):
389
+ if "reset" in namespace:
390
+ reset_func = namespace["reset"]
391
+
392
+ @functools.wraps(reset_func)
393
+ def conditional_reset(*args, **kwargs):
394
+ r"""
395
+ Only execute DataPipe's `reset()` method if `_SnapshotState` is `Iterating` or `NotStarted`.
396
+
397
+ This allows recently restored DataPipe to preserve its restored state during the initial `__iter__` call.
398
+ """
399
+ datapipe = args[0]
400
+ if datapipe._snapshot_state in (
401
+ _SnapshotState.Iterating,
402
+ _SnapshotState.NotStarted,
403
+ ):
404
+ # Reset `NotStarted` is necessary because the `source_datapipe` of a DataPipe might have
405
+ # already begun iterating.
406
+ datapipe._number_of_samples_yielded = 0
407
+ datapipe._fast_forward_iterator = None
408
+ reset_func(*args, **kwargs)
409
+ datapipe._snapshot_state = _SnapshotState.Iterating
410
+
411
+ namespace["reset"] = conditional_reset
412
+
413
+ if "__iter__" in namespace:
414
+ hook_iterator(namespace)
415
+ return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload]
416
+
417
+
418
+ def _dp_init_subclass(sub_cls, *args, **kwargs):
419
+ # Add function for datapipe instance to reinforce the type
420
+ sub_cls.reinforce_type = reinforce_type
421
+
422
+ # TODO:
423
+ # - add global switch for type checking at compile-time
424
+
425
+ # Ignore internal type class
426
+ if getattr(sub_cls, "__type_class__", False):
427
+ return
428
+
429
+ # Check if the string type is valid
430
+ if isinstance(sub_cls.type.param, ForwardRef):
431
+ base_globals = sys.modules[sub_cls.__module__].__dict__
432
+ try:
433
+ param = _eval_type(sub_cls.type.param, base_globals, locals())
434
+ sub_cls.type.param = param
435
+ except TypeError as e:
436
+ raise TypeError(
437
+ f"{sub_cls.type.param.__forward_arg__} is not supported by Python typing"
438
+ ) from e
439
+
440
+ if "__iter__" in sub_cls.__dict__:
441
+ iter_fn = sub_cls.__dict__["__iter__"]
442
+ hints = get_type_hints(iter_fn)
443
+ if "return" in hints:
444
+ return_hint = hints["return"]
445
+ # Plain Return Hint for Python 3.6
446
+ if return_hint == Iterator:
447
+ return
448
+ if not (
449
+ hasattr(return_hint, "__origin__")
450
+ and (
451
+ return_hint.__origin__ == Iterator
452
+ or return_hint.__origin__ == collections.abc.Iterator
453
+ )
454
+ ):
455
+ raise TypeError(
456
+ "Expected 'Iterator' as the return annotation for `__iter__` of {}"
457
+ ", but found {}".format(
458
+ sub_cls.__name__, _type_repr(hints["return"])
459
+ )
460
+ )
461
+ data_type = return_hint.__args__[0]
462
+ if not issubtype(data_type, sub_cls.type.param):
463
+ raise TypeError(
464
+ f"Expected return type of '__iter__' as a subtype of {sub_cls.type},"
465
+ f" but found {_type_repr(data_type)} for {sub_cls.__name__}"
466
+ )
467
+
468
+
469
+ def reinforce_type(self, expected_type):
470
+ r"""
471
+ Reinforce the type for DataPipe instance.
472
+
473
+ The 'expected_type' is required to be a subtype of the original type
474
+ hint to restrict the type requirement of DataPipe instance.
475
+ """
476
+ if isinstance(expected_type, tuple):
477
+ expected_type = Tuple[expected_type]
478
+ _type_check(expected_type, msg="'expected_type' must be a type")
479
+
480
+ if not issubtype(expected_type, self.type.param):
481
+ raise TypeError(
482
+ f"Expected 'expected_type' as subtype of {self.type}, but found {_type_repr(expected_type)}"
483
+ )
484
+
485
+ self.type = _DataPipeType(expected_type)
486
+ return self
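A hedged sketch of the annotation wrapper above; `_DataPipeType` simply stores a typing annotation and defers the actual comparisons to `issubtype`/`issubinstance`.

```python
from typing import List, Union
from torch.utils.data.datapipes._typing import _DataPipeType

t_numbers = _DataPipeType(List[Union[int, float]])
t_ints = _DataPipeType(List[int])

print(t_ints.issubtype(t_numbers))              # True:  List[int] <: List[Union[int, float]]
print(t_numbers.issubtype(t_ints))              # False
print(t_numbers.issubtype_of_instance([1.5]))   # True: the value matches the annotation
```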
.venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ from torch.utils.data.datapipes.dataframe.dataframes import (
2
+ CaptureDataFrame,
3
+ DFIterDataPipe,
4
+ )
5
+ from torch.utils.data.datapipes.dataframe.datapipes import DataFramesAsTuplesPipe
6
+
7
+
8
+ __all__ = ["CaptureDataFrame", "DFIterDataPipe", "DataFramesAsTuplesPipe"]
9
+
10
+ # Please keep this list sorted
11
+ assert __all__ == sorted(__all__)
.venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (476 Bytes). View file
 
.venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/dataframe_wrapper.cpython-39.pyc ADDED
Binary file (3.77 kB). View file
 
.venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/dataframes.cpython-39.pyc ADDED
Binary file (15.9 kB). View file
 
.venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/datapipes.cpython-39.pyc ADDED
Binary file (4.62 kB). View file
 
.venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/structures.cpython-39.pyc ADDED
Binary file (1.05 kB). View file
 
.venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py ADDED
@@ -0,0 +1,128 @@
1
+ # mypy: allow-untyped-defs
2
+ from typing import Any, Optional
3
+
4
+
5
+ _pandas: Any = None
6
+ _WITH_PANDAS: Optional[bool] = None
7
+
8
+
9
+ def _try_import_pandas() -> bool:
10
+ try:
11
+ import pandas # type: ignore[import]
12
+
13
+ global _pandas
14
+ _pandas = pandas
15
+ return True
16
+ except ImportError:
17
+ return False
18
+
19
+
20
+ # pandas used only for prototyping, will be shortly replaced with TorchArrow
21
+ def _with_pandas() -> bool:
22
+ global _WITH_PANDAS
23
+ if _WITH_PANDAS is None:
24
+ _WITH_PANDAS = _try_import_pandas()
25
+ return _WITH_PANDAS
26
+
27
+
28
+ class PandasWrapper:
29
+ @classmethod
30
+ def create_dataframe(cls, data, columns):
31
+ if not _with_pandas():
32
+ raise RuntimeError("DataFrames prototype requires pandas to function")
33
+ return _pandas.DataFrame(data, columns=columns) # type: ignore[union-attr]
34
+
35
+ @classmethod
36
+ def is_dataframe(cls, data):
37
+ if not _with_pandas():
38
+ return False
39
+ return isinstance(data, _pandas.core.frame.DataFrame) # type: ignore[union-attr]
40
+
41
+ @classmethod
42
+ def is_column(cls, data):
43
+ if not _with_pandas():
44
+ return False
45
+ return isinstance(data, _pandas.core.series.Series) # type: ignore[union-attr]
46
+
47
+ @classmethod
48
+ def iterate(cls, data):
49
+ if not _with_pandas():
50
+ raise RuntimeError("DataFrames prototype requires pandas to function")
51
+ yield from data.itertuples(index=False)
52
+
53
+ @classmethod
54
+ def concat(cls, buffer):
55
+ if not _with_pandas():
56
+ raise RuntimeError("DataFrames prototype requires pandas to function")
57
+ return _pandas.concat(buffer) # type: ignore[union-attr]
58
+
59
+ @classmethod
60
+ def get_item(cls, data, idx):
61
+ if not _with_pandas():
62
+ raise RuntimeError("DataFrames prototype requires pandas to function")
63
+ return data[idx : idx + 1]
64
+
65
+ @classmethod
66
+ def get_len(cls, df):
67
+ if not _with_pandas():
68
+ raise RuntimeError("DataFrames prototype requires pandas to function")
69
+ return len(df.index)
70
+
71
+ @classmethod
72
+ def get_columns(cls, df):
73
+ if not _with_pandas():
74
+ raise RuntimeError("DataFrames prototype requires pandas to function")
75
+ return list(df.columns.values.tolist())
76
+
77
+
78
+ # When you build your own implementation, override it with dataframe_wrapper.set_df_wrapper(new_wrapper_class)
79
+ default_wrapper = PandasWrapper
80
+
81
+
82
+ def get_df_wrapper():
83
+ return default_wrapper
84
+
85
+
86
+ def set_df_wrapper(wrapper):
87
+ global default_wrapper
88
+ default_wrapper = wrapper
89
+
90
+
91
+ def create_dataframe(data, columns=None):
92
+ wrapper = get_df_wrapper()
93
+ return wrapper.create_dataframe(data, columns)
94
+
95
+
96
+ def is_dataframe(data):
97
+ wrapper = get_df_wrapper()
98
+ return wrapper.is_dataframe(data)
99
+
100
+
101
+ def get_columns(data):
102
+ wrapper = get_df_wrapper()
103
+ return wrapper.get_columns(data)
104
+
105
+
106
+ def is_column(data):
107
+ wrapper = get_df_wrapper()
108
+ return wrapper.is_column(data)
109
+
110
+
111
+ def concat(buffer):
112
+ wrapper = get_df_wrapper()
113
+ return wrapper.concat(buffer)
114
+
115
+
116
+ def iterate(data):
117
+ wrapper = get_df_wrapper()
118
+ return wrapper.iterate(data)
119
+
120
+
121
+ def get_item(data, idx):
122
+ wrapper = get_df_wrapper()
123
+ return wrapper.get_item(data, idx)
124
+
125
+
126
+ def get_len(df):
127
+ wrapper = get_df_wrapper()
128
+ return wrapper.get_len(df)
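A hedged sketch of swapping in a custom backend, as the comment above suggests; `MyFrameWrapper` is hypothetical and only needs to provide the same classmethods as `PandasWrapper`.

```python
from torch.utils.data.datapipes.dataframe import dataframe_wrapper

class MyFrameWrapper:
    """Hypothetical drop-in replacement; a real backend would also implement
    create_dataframe, iterate, concat, get_item, get_len, get_columns, is_column."""

    @classmethod
    def is_dataframe(cls, data):
        return False                    # plug in your own frame-type check

previous = dataframe_wrapper.get_df_wrapper()
dataframe_wrapper.set_df_wrapper(MyFrameWrapper)
assert dataframe_wrapper.get_df_wrapper() is MyFrameWrapper
dataframe_wrapper.set_df_wrapper(previous)      # restore the default PandasWrapper
```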
.venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/dataframes.py ADDED
@@ -0,0 +1,457 @@
1
+ # mypy: allow-untyped-defs
2
+ from typing import Any, Dict, List, Optional
3
+
4
+ from torch.utils.data.datapipes._decorator import functional_datapipe
5
+ from torch.utils.data.datapipes.dataframe.structures import DataChunkDF
6
+ from torch.utils.data.datapipes.datapipe import DFIterDataPipe, IterDataPipe
7
+
8
+
9
+ # TODO(VitalyFedyunin): Add error when two different traces get combined
10
+
11
+ __all__ = [
12
+ "Capture",
13
+ "CaptureA",
14
+ "CaptureAdd",
15
+ "CaptureCall",
16
+ "CaptureControl",
17
+ "CaptureDataFrame",
18
+ "CaptureDataFrameWithDataPipeOps",
19
+ "CaptureF",
20
+ "CaptureGetAttr",
21
+ "CaptureGetItem",
22
+ "CaptureInitial",
23
+ "CaptureLikeMock",
24
+ "CaptureMul",
25
+ "CaptureSetItem",
26
+ "CaptureSub",
27
+ "CaptureVariable",
28
+ "CaptureVariableAssign",
29
+ "DataFrameTracer",
30
+ "DataFrameTracedOps",
31
+ "disable_capture",
32
+ "get_val",
33
+ ]
34
+
35
+
36
+ def disable_capture():
37
+ CaptureControl.disabled = True
38
+
39
+
40
+ class CaptureControl:
41
+ disabled = False
42
+
43
+
44
+ class DataFrameTracedOps(DFIterDataPipe):
45
+ def __init__(self, source_datapipe, output_var):
46
+ self.source_datapipe = source_datapipe
47
+ self.output_var = output_var
48
+
49
+ def __iter__(self):
50
+ for item in self.source_datapipe:
51
+ yield self.output_var.apply_ops(item)
52
+
53
+
54
+ # TODO(VitalyFedyunin): Extract this list from the DFIterDataPipe registered functions
55
+ DATAPIPES_OPS = [
56
+ "_dataframes_as_tuples",
57
+ "groupby",
58
+ "_dataframes_filter",
59
+ "map",
60
+ "to_datapipe",
61
+ "shuffle",
62
+ "concat",
63
+ "batch",
64
+ "_dataframes_per_row",
65
+ "_dataframes_concat",
66
+ "_dataframes_shuffle",
67
+ ]
68
+
69
+ UNIMPLEMENTED_ATTR = ["__deepcopy__", "__setstate__", "is_shardable", "apply_sharding"]
70
+
71
+
72
+ class Capture:
73
+ # TODO: All operations are shared across entire InitialCapture, need to figure out what if we join two captures
74
+
75
+ def __init__(self, schema_df=None):
76
+ self.ctx = {"operations": [], "variables": [], "schema_df": schema_df}
77
+
78
+ def __str__(self):
79
+ return self._ops_str()
80
+
81
+ def _ops_str(self):
82
+ res = ""
83
+ for op in self.ctx["operations"]:
84
+ if len(res) > 0:
85
+ res += "\n"
86
+ res += str(op)
87
+ return res
88
+
89
+ def __getstate__(self):
90
+ # TODO(VitalyFedyunin): Currently can't pickle (why?)
91
+ self.ctx["schema_df"] = None
92
+ for var in self.ctx["variables"]:
93
+ var.calculated_value = None
94
+ state = {}
95
+ for item in self.__dict__:
96
+ state[item] = getattr(self, item)
97
+ return state
98
+
99
+ def __setstate__(self, state):
100
+ for k, v in state.items():
101
+ setattr(self, k, v)
102
+
103
+ def __getattr__(self, attrname):
104
+ if attrname == "kwarg" or attrname == "kwargs":
105
+ raise RuntimeError("no kwargs!")
106
+ if attrname in ["__deepcopy__"]:
107
+ raise AttributeError
108
+ result = CaptureGetAttr(self, attrname, ctx=self.ctx)
109
+ return result
110
+
111
+ def __getitem__(self, key):
112
+ return CaptureGetItem(self, key, ctx=self.ctx)
113
+
114
+ def __setitem__(self, key, value):
115
+ self.ctx["operations"].append(CaptureSetItem(self, key, value, ctx=self.ctx))
116
+
117
+ def __add__(self, add_val):
118
+ res = CaptureAdd(self, add_val, ctx=self.ctx)
119
+ var = CaptureVariable(res, ctx=self.ctx)
120
+ self.ctx["operations"].append(
121
+ CaptureVariableAssign(variable=var, value=res, ctx=self.ctx)
122
+ )
123
+ return var
124
+
125
+ def __sub__(self, add_val):
126
+ res = CaptureSub(self, add_val, ctx=self.ctx)
127
+ var = CaptureVariable(res, ctx=self.ctx)
128
+ self.ctx["operations"].append(
129
+ CaptureVariableAssign(variable=var, value=res, ctx=self.ctx)
130
+ )
131
+ return var
132
+
133
+ def __mul__(self, add_val):
134
+ res = CaptureMul(self, add_val, ctx=self.ctx)
135
+ var = CaptureVariable(res, ctx=self.ctx)
136
+ t = CaptureVariableAssign(variable=var, value=res, ctx=self.ctx)
137
+ self.ctx["operations"].append(t)
138
+ return var
139
+
140
+ def _is_context_empty(self):
141
+ return len(self.ctx["operations"]) == 0 and len(self.ctx["variables"]) == 0
142
+
143
+ def apply_ops_2(self, dataframe):
144
+ # TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer)
145
+ self.ctx["variables"][0].calculated_value = dataframe
146
+ for op in self.ctx["operations"]:
147
+ op.execute()
148
+
149
+ @property
150
+ def columns(self):
151
+ self.apply_ops_2(self.ctx["schema_df"])
152
+ value = self.execute()
153
+ return value.columns
154
+
155
+ # TODO(VitalyFedyunin): Add tests
156
+ # TODO(VitalyFedyunin): Need to join context if one of them are empty because we used capture
157
+
158
+ def __call__(self, *args, **kwargs):
159
+ # TODO: Check if args or kwargs have more than one different context
160
+ if self._is_context_empty():
161
+ # TODO: Allow CaptureA to take context from mock
162
+ for arg in args:
163
+ if isinstance(arg, Capture) and not arg._is_context_empty():
164
+ self.ctx = arg.ctx
165
+ break
166
+ if self._is_context_empty():
167
+ for k, v in kwargs.items():
168
+ if isinstance(k, Capture) and not k._is_context_empty():
169
+ self.ctx = k.ctx
170
+ break
171
+ if isinstance(v, Capture) and not v._is_context_empty():
172
+ self.ctx = v.ctx
173
+ break
174
+
175
+ res = CaptureCall(self, ctx=self.ctx, args=args, kwargs=kwargs)
176
+ var = CaptureVariable(None, ctx=self.ctx)
177
+ t = CaptureVariableAssign(ctx=self.ctx, variable=var, value=res)
178
+ self.ctx["operations"].append(t)
179
+ return var
180
+
181
+
182
+ class CaptureF(Capture):
183
+ def __init__(self, ctx=None, **kwargs):
184
+ if ctx is None:
185
+ self.ctx = {"operations": [], "variables": []}
186
+ else:
187
+ self.ctx = ctx
188
+ self.kwargs = kwargs
189
+
190
+
191
+ class CaptureA(CaptureF):
192
+ def __str__(self):
193
+ return f"{self.kwargs['name']}"
194
+
195
+ def execute(self):
196
+ value = self.kwargs["real_attribute"]
197
+ return value
198
+
199
+
200
+ class CaptureLikeMock:
201
+ def __init__(self, name):
202
+ import unittest.mock as mock
203
+
204
+ # TODO(VitalyFedyunin): Do not use provate function here, copy own implementation instead.
205
+ get_target, attribute = mock._get_target(name) # type: ignore[attr-defined]
206
+ self.get_target = get_target
207
+ self.attribute = attribute
208
+ self.name = name
209
+
210
+ def __enter__(self):
211
+ self.save = getattr(self.get_target(), self.attribute)
212
+ capt = CaptureA(name=self.name, real_attribute=self.save)
213
+ setattr(self.get_target(), self.attribute, capt)
214
+
215
+ def __exit__(self, *exc_info):
216
+ setattr(self.get_target(), self.attribute, self.save)
217
+
218
+
219
+ class CaptureCall(Capture):
220
+ def __init__(self, callable, ctx=None, **kwargs):
221
+ if ctx is None:
222
+ self.ctx = {"operations": [], "variables": []}
223
+ else:
224
+ self.ctx = ctx
225
+ self.kwargs = kwargs
226
+ self.callable = callable
227
+
228
+ def __str__(self):
229
+ return "{callable}({args},{kwargs})".format(
230
+ callable=self.callable, **self.kwargs
231
+ )
232
+
233
+ def execute(self):
234
+ # TODO: VitalyFedyunin execute kwargs and maybe nested structures
235
+ executed_args = []
236
+ for arg in self.kwargs["args"]:
237
+ if isinstance(arg, Capture):
238
+ executed_args.append(arg.execute())
239
+ else:
240
+ executed_args.append(arg)
241
+ left = get_val(self.callable)
242
+ return left(*executed_args, **self.kwargs["kwargs"])
243
+
244
+
245
+ class CaptureVariableAssign(CaptureF):
246
+ def __str__(self):
247
+ variable = self.kwargs["variable"]
248
+ value = self.kwargs["value"]
249
+ return f"{variable} = {value}"
250
+
251
+ def execute(self):
252
+ self.kwargs["variable"].calculated_value = self.kwargs["value"].execute()
253
+
254
+
255
+ class CaptureVariable(Capture):
256
+ # TODO(VitalyFedyunin): This should be atomic and thread safe
257
+ names_idx = 0
258
+
259
+ def __init__(self, value, ctx):
260
+ if CaptureControl.disabled:
261
+ raise RuntimeError("Attempting to create capture variable with capture off")
262
+ self.ctx = ctx
263
+ self.value = value
264
+ self.name = f"var_{CaptureVariable.names_idx}"
265
+ CaptureVariable.names_idx += 1
266
+ self.ctx["variables"].append(self)
267
+
268
+ def __str__(self):
269
+ return self.name
270
+
271
+ def execute(self):
272
+ return self.calculated_value
273
+
274
+ def apply_ops(self, dataframe):
275
+ # TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer)
276
+ self.ctx["variables"][0].calculated_value = dataframe
277
+ for op in self.ctx["operations"]:
278
+ op.execute()
279
+ return self.calculated_value
280
+
281
+
282
+ class CaptureGetItem(Capture):
283
+ def __init__(self, left, key, ctx):
284
+ self.ctx = ctx
285
+ self.left = left
286
+ self.key = key
287
+
288
+ def __str__(self):
289
+ return f"{self.left}[{get_val(self.key)}]"
290
+
291
+ def execute(self):
292
+ left = self.left.execute()
293
+ return left[self.key]
294
+
295
+
296
+ class CaptureSetItem(Capture):
297
+ def __init__(self, left, key, value, ctx):
298
+ self.ctx = ctx
299
+ self.left = left
300
+ self.key = key
301
+ self.value = value
302
+
303
+ def __str__(self):
304
+ return f"{self.left}[{get_val(self.key)}] = {self.value}"
305
+
306
+ def execute(self):
307
+ left = self.left.execute()
308
+ value = self.value.execute()
309
+ left[self.key] = value
310
+
311
+
312
+ class CaptureAdd(Capture):
313
+ def __init__(self, left, right, ctx):
314
+ self.ctx = ctx
315
+ self.left = left
316
+ self.right = right
317
+
318
+ def __str__(self):
319
+ return f"{self.left} + {self.right}"
320
+
321
+ def execute(self):
322
+ return get_val(self.left) + get_val(self.right)
323
+
324
+
325
+ class CaptureMul(Capture):
326
+ def __init__(self, left, right, ctx):
327
+ self.ctx = ctx
328
+ self.left = left
329
+ self.right = right
330
+
331
+ def __str__(self):
332
+ return f"{self.left} * {self.right}"
333
+
334
+ def execute(self):
335
+ return get_val(self.left) * get_val(self.right)
336
+
337
+
338
+ class CaptureSub(Capture):
339
+ def __init__(self, left, right, ctx):
340
+ self.ctx = ctx
341
+ self.left = left
342
+ self.right = right
343
+
344
+ def __str__(self):
345
+ return f"{self.left} - {self.right}"
346
+
347
+ def execute(self):
348
+ return get_val(self.left) - get_val(self.right)
349
+
350
+
351
+ class CaptureGetAttr(Capture):
352
+ def __init__(self, src, name, ctx):
353
+ self.ctx = ctx
354
+ self.src = src
355
+ self.name = name
356
+
357
+ def __str__(self):
358
+ return f"{self.src}.{self.name}"
359
+
360
+ def execute(self):
361
+ val = get_val(self.src)
362
+ return getattr(val, self.name)
363
+
364
+
365
+ def get_val(capture):
366
+ if isinstance(capture, Capture):
367
+ return capture.execute()
368
+ elif isinstance(capture, str):
369
+ return f'"{capture}"'
370
+ else:
371
+ return capture
372
+
373
+
374
+ class CaptureInitial(CaptureVariable):
375
+ def __init__(self, schema_df=None):
376
+ new_ctx: Dict[str, List[Any]] = {
377
+ "operations": [],
378
+ "variables": [],
379
+ "schema_df": schema_df,
380
+ }
381
+ super().__init__(None, new_ctx)
382
+ self.name = f"input_{self.name}"
383
+
384
+
385
+ class CaptureDataFrame(CaptureInitial):
386
+ pass
387
+
388
+
389
+ class CaptureDataFrameWithDataPipeOps(CaptureDataFrame):
390
+ def as_datapipe(self):
391
+ return DataFrameTracedOps(self.ctx["variables"][0].source_datapipe, self)
392
+
393
+ def raw_iterator(self):
394
+ return self.as_datapipe().__iter__()
395
+
396
+ def __iter__(self):
397
+ return iter(self._dataframes_as_tuples())
398
+
399
+ def batch(self, batch_size=10, drop_last: bool = False, wrapper_class=DataChunkDF):
400
+ dp = self._dataframes_per_row()._dataframes_concat(batch_size)
401
+ dp = dp.as_datapipe().batch(1, drop_last=drop_last, wrapper_class=wrapper_class)
402
+ dp._dp_contains_dataframe = True
403
+ return dp
404
+
405
+ def groupby(
406
+ self,
407
+ group_key_fn,
408
+ *,
409
+ buffer_size=10000,
410
+ group_size=None,
411
+ guaranteed_group_size=None,
412
+ drop_remaining=False,
413
+ ):
414
+ dp = self._dataframes_per_row()
415
+ dp = dp.as_datapipe().groupby(
416
+ group_key_fn,
417
+ buffer_size=buffer_size,
418
+ group_size=group_size,
419
+ guaranteed_group_size=guaranteed_group_size,
420
+ drop_remaining=drop_remaining,
421
+ )
422
+ return dp
423
+
424
+ def shuffle(self, *args, **kwargs):
425
+ return self._dataframes_shuffle(*args, **kwargs)
426
+
427
+ def filter(self, *args, **kwargs):
428
+ return self._dataframes_filter(*args, **kwargs)
429
+
430
+ def collate(self, *args, **kwargs):
431
+ raise RuntimeError("Can't collate unbatched DataFrames stream")
432
+
433
+ def __getattr__(self, attrname): # ?
434
+ if attrname in UNIMPLEMENTED_ATTR:
435
+ raise AttributeError("Attempting to get ", attrname)
436
+ if attrname in DATAPIPES_OPS:
437
+ return (self.as_datapipe()).__getattr__(attrname)
438
+ return super().__getattr__(attrname)
439
+
440
+
441
+ @functional_datapipe("trace_as_dataframe")
442
+ class DataFrameTracer(CaptureDataFrameWithDataPipeOps, IterDataPipe): # type: ignore[misc]
443
+ source_datapipe: Optional[Any] = None
444
+
445
+ # TODO(VitalyFedyunin): Must implement all special functions of datapipes
446
+
447
+ def set_shuffle_settings(self, *args, **kwargs):
448
+ pass
449
+
450
+ def is_shardable(self):
451
+ return False
452
+
453
+ def __init__(self, source_datapipe, schema_df=None):
454
+ self.source_datapipe = source_datapipe
455
+ if schema_df is None:
456
+ schema_df = next(iter(self.source_datapipe))
457
+ super().__init__(schema_df=schema_df)
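The Capture classes above build a deferred expression tree over a schema DataFrame: each operation (getitem, setitem, add, mul, sub, getattr) is recorded as a small object, `__str__` renders the recorded trace, and `execute()` replays it against real data. A minimal standalone sketch of the same pattern follows; the names are illustrative only and not part of the library.

class Expr:
    def execute(self):
        raise NotImplementedError

class Const(Expr):
    def __init__(self, value):
        self.value = value
    def __str__(self):
        return repr(self.value)
    def execute(self):
        return self.value

class Add(Expr):
    def __init__(self, left, right):
        self.left, self.right = left, right
    def __str__(self):
        return f"{self.left} + {self.right}"
    def execute(self):
        # Replay the recorded operation against real values.
        return self.left.execute() + self.right.execute()

expr = Add(Const(2), Add(Const(3), Const(4)))
print(expr)            # "2 + 3 + 4"  -- the recorded trace
print(expr.execute())  # 9            -- the replayed computation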
.venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/datapipes.py ADDED
@@ -0,0 +1,134 @@
1
+ # mypy: allow-untyped-defs
2
+ import random
3
+
4
+ from torch.utils.data.datapipes._decorator import functional_datapipe
5
+ from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper
6
+ from torch.utils.data.datapipes.datapipe import DFIterDataPipe, IterDataPipe
7
+
8
+
9
+ __all__ = [
10
+ "ConcatDataFramesPipe",
11
+ "DataFramesAsTuplesPipe",
12
+ "ExampleAggregateAsDataFrames",
13
+ "FilterDataFramesPipe",
14
+ "PerRowDataFramesPipe",
15
+ "ShuffleDataFramesPipe",
16
+ ]
17
+
18
+
19
+ @functional_datapipe("_dataframes_as_tuples")
20
+ class DataFramesAsTuplesPipe(IterDataPipe):
21
+ def __init__(self, source_datapipe):
22
+ self.source_datapipe = source_datapipe
23
+
24
+ def __iter__(self):
25
+ for df in self.source_datapipe:
26
+ # for record in df.to_records(index=False):
27
+ yield from df_wrapper.iterate(df)
28
+
29
+
30
+ @functional_datapipe("_dataframes_per_row", enable_df_api_tracing=True)
31
+ class PerRowDataFramesPipe(DFIterDataPipe):
32
+ def __init__(self, source_datapipe):
33
+ self.source_datapipe = source_datapipe
34
+
35
+ def __iter__(self):
36
+ for df in self.source_datapipe:
37
+ # TODO(VitalyFedyunin): Replacing with TorchArrow only API, as we are dropping pandas as followup
38
+ for i in range(len(df)):
39
+ yield df[i : i + 1]
40
+
41
+
42
+ @functional_datapipe("_dataframes_concat", enable_df_api_tracing=True)
43
+ class ConcatDataFramesPipe(DFIterDataPipe):
44
+ def __init__(self, source_datapipe, batch=3):
45
+ self.source_datapipe = source_datapipe
46
+ self.n_batch = batch
47
+
48
+ def __iter__(self):
49
+ buffer = []
50
+ for df in self.source_datapipe:
51
+ buffer.append(df)
52
+ if len(buffer) == self.n_batch:
53
+ yield df_wrapper.concat(buffer)
54
+ buffer = []
55
+ if len(buffer):
56
+ yield df_wrapper.concat(buffer)
57
+
58
+
59
+ @functional_datapipe("_dataframes_shuffle", enable_df_api_tracing=True)
60
+ class ShuffleDataFramesPipe(DFIterDataPipe):
61
+ def __init__(self, source_datapipe):
62
+ self.source_datapipe = source_datapipe
63
+
64
+ def __iter__(self):
65
+ size = None
66
+ all_buffer = []
67
+ for df in self.source_datapipe:
68
+ if size is None:
69
+ size = df_wrapper.get_len(df)
70
+ for i in range(df_wrapper.get_len(df)):
71
+ all_buffer.append(df_wrapper.get_item(df, i))
72
+ random.shuffle(all_buffer)
73
+ buffer = []
74
+ for df in all_buffer:
75
+ buffer.append(df)
76
+ if len(buffer) == size:
77
+ yield df_wrapper.concat(buffer)
78
+ buffer = []
79
+ if len(buffer):
80
+ yield df_wrapper.concat(buffer)
81
+
82
+
83
+ @functional_datapipe("_dataframes_filter", enable_df_api_tracing=True)
84
+ class FilterDataFramesPipe(DFIterDataPipe):
85
+ def __init__(self, source_datapipe, filter_fn):
86
+ self.source_datapipe = source_datapipe
87
+ self.filter_fn = filter_fn
88
+
89
+ def __iter__(self):
90
+ size = None
91
+ all_buffer = []
92
+ filter_res = []
93
+ for df in self.source_datapipe:
94
+ if size is None:
95
+ size = len(df.index)
96
+ for i in range(len(df.index)):
97
+ all_buffer.append(df[i : i + 1])
98
+ filter_res.append(self.filter_fn(df.iloc[i]))
99
+
100
+ buffer = []
101
+ for df, res in zip(all_buffer, filter_res):
102
+ if res:
103
+ buffer.append(df)
104
+ if len(buffer) == size:
105
+ yield df_wrapper.concat(buffer)
106
+ buffer = []
107
+ if len(buffer):
108
+ yield df_wrapper.concat(buffer)
109
+
110
+
111
+ @functional_datapipe("_to_dataframes_pipe", enable_df_api_tracing=True)
112
+ class ExampleAggregateAsDataFrames(DFIterDataPipe):
113
+ def __init__(self, source_datapipe, dataframe_size=10, columns=None):
114
+ self.source_datapipe = source_datapipe
115
+ self.columns = columns
116
+ self.dataframe_size = dataframe_size
117
+
118
+ def _as_list(self, item):
119
+ try:
120
+ return list(item)
121
+ except (
122
+ Exception
123
+ ): # TODO(VitalyFedyunin): Replace with better iterable exception
124
+ return [item]
125
+
126
+ def __iter__(self):
127
+ aggregate = []
128
+ for item in self.source_datapipe:
129
+ aggregate.append(self._as_list(item))
130
+ if len(aggregate) == self.dataframe_size:
131
+ yield df_wrapper.create_dataframe(aggregate, columns=self.columns)
132
+ aggregate = []
133
+ if len(aggregate) > 0:
134
+ yield df_wrapper.create_dataframe(aggregate, columns=self.columns)
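The pipes in this file share one buffering idiom: accumulate incoming items until a target size is reached, yield the combined chunk, and flush whatever remains once the source is exhausted. Below is a hedged standalone sketch of that idiom, with plain lists standing in for DataFrames and `df_wrapper.concat`.

def buffer_and_flush(source, size, combine=list):
    # Accumulate items, emit a combined chunk every `size` items,
    # then flush the (possibly smaller) remainder at the end.
    buffer = []
    for item in source:
        buffer.append(item)
        if len(buffer) == size:
            yield combine(buffer)
            buffer = []
    if buffer:
        yield combine(buffer)

print(list(buffer_and_flush(range(7), 3)))  # [[0, 1, 2], [3, 4, 5], [6]]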
.venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/structures.py ADDED
@@ -0,0 +1,20 @@
1
+ # mypy: allow-untyped-defs
2
+ from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper
3
+ from torch.utils.data.datapipes.datapipe import DataChunk
4
+
5
+
6
+ __all__ = ["DataChunkDF"]
7
+
8
+
9
+ class DataChunkDF(DataChunk):
10
+ """DataChunkDF iterating over individual items inside of DataFrame containers, to access DataFrames user `raw_iterator`."""
11
+
12
+ def __iter__(self):
13
+ for df in self.items:
14
+ yield from df_wrapper.iterate(df)
15
+
16
+ def __len__(self):
17
+ total_len = 0
18
+ for df in self.items:
19
+ total_len += df_wrapper.get_len(df)
20
+ return total_len
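`DataChunkDF` flattens its DataFrame containers on iteration, while `__len__` reports the total row count across them. A hedged sketch of the same behavior, with lists of rows standing in for DataFrames and `df_wrapper`:

class ChunkOfFrames:
    def __init__(self, frames):
        self.items = frames  # each "frame" here is just a list of rows
    def __iter__(self):
        for frame in self.items:
            yield from frame  # yield individual rows, not frames
    def __len__(self):
        return sum(len(frame) for frame in self.items)

chunk = ChunkOfFrames([[("a", 1), ("b", 2)], [("c", 3)]])
print(list(chunk))  # [('a', 1), ('b', 2), ('c', 3)]
print(len(chunk))   # 3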
.venv/Lib/site-packages/torch/utils/data/datapipes/datapipe.py ADDED
@@ -0,0 +1,415 @@
1
+ import functools
2
+ import pickle
3
+ from typing import Callable, Dict, Iterable, Iterator, List, Optional, TypeVar
4
+
5
+ from torch.utils._import_utils import import_dill
6
+ from torch.utils.data.datapipes._hook_iterator import _SnapshotState
7
+ from torch.utils.data.datapipes._typing import _DataPipeMeta, _IterDataPipeMeta
8
+ from torch.utils.data.datapipes.utils.common import (
9
+ _deprecation_warning,
10
+ _iter_deprecated_functional_names,
11
+ _map_deprecated_functional_names,
12
+ )
13
+ from torch.utils.data.dataset import Dataset, IterableDataset
14
+
15
+
16
+ dill = import_dill()
17
+ HAS_DILL = dill is not None
18
+
19
+ __all__ = [
20
+ "DataChunk",
21
+ "DFIterDataPipe",
22
+ "IterDataPipe",
23
+ "MapDataPipe",
24
+ ]
25
+
26
+
27
+ _T = TypeVar("_T")
28
+ _T_co = TypeVar("_T_co", covariant=True)
29
+
30
+ UNTRACABLE_DATAFRAME_PIPES = [
31
+ "batch", # As it returns DataChunks
32
+ "groupby", # As it returns DataChunks
33
+ "_dataframes_as_tuples", # As it unpacks DF
34
+ "trace_as_dataframe", # As it used to mark DF for tracing
35
+ ]
36
+
37
+
38
+ class DataChunk(List[_T]):
39
+ def __init__(self, items: Iterable[_T]) -> None:
40
+ items = list(items)
41
+ super().__init__(items)
42
+ self.items = items
43
+
44
+ def as_str(self, indent: str = "") -> str:
45
+ return indent + "[" + ", ".join(str(i) for i in iter(self)) + "]"
46
+
47
+ def __iter__(self) -> Iterator[_T]:
48
+ yield from super().__iter__()
49
+
50
+ def raw_iterator(self) -> Iterator[_T]:
51
+ yield from self.items
52
+
53
+
54
+ class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta):
55
+ r"""
56
+ Iterable-style DataPipe.
57
+
58
+ All DataPipes that represent an iterable of data samples should subclass this.
59
+ This style of DataPipes is particularly useful when data come from a stream, or
60
+ when the number of samples is too large to fit them all in memory. ``IterDataPipe`` is lazily initialized and its
61
+ elements are computed only when ``next()`` is called on the iterator of an ``IterDataPipe``.
62
+
63
+ All subclasses should overwrite :meth:`__iter__`, which would return an
64
+ iterator of samples in this DataPipe. Calling ``__iter__`` of an ``IterDataPipe`` automatically invokes its
65
+ method ``reset()``, which by default performs no operation. When writing a custom ``IterDataPipe``, users should
66
+ override ``reset()`` if necessary. The common usages include resetting buffers, pointers,
67
+ and various state variables within the custom ``IterDataPipe``.
68
+
69
+ Note:
70
+ Only `one` iterator can be valid for each ``IterDataPipe`` at a time,
71
+ and the creation a second iterator will invalidate the first one. This constraint is necessary because
72
+ some ``IterDataPipe`` have internal buffers, whose states can become invalid if there are multiple iterators.
73
+ The code example below presents details on how this constraint looks in practice.
74
+ If you have any feedback related to this constraint, please see `GitHub IterDataPipe Single Iterator Issue`_.
75
+
76
+ These DataPipes can be invoked in two ways, using the class constructor or applying their
77
+ functional form onto an existing ``IterDataPipe`` (recommended, available to most but not all DataPipes).
78
+ You can chain multiple `IterDataPipe` together to form a pipeline that will perform multiple
79
+ operations in succession.
80
+
81
+ .. _GitHub IterDataPipe Single Iterator Issue:
82
+ https://github.com/pytorch/data/issues/45
83
+
84
+ Note:
85
+ When a subclass is used with :class:`~torch.utils.data.DataLoader`, each
86
+ item in the DataPipe will be yielded from the :class:`~torch.utils.data.DataLoader`
87
+ iterator. When :attr:`num_workers > 0`, each worker process will have a
88
+ different copy of the DataPipe object, so it is often desired to configure
89
+ each copy independently to avoid having duplicate data returned from the
90
+ workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker
91
+ process, returns information about the worker. It can be used in either the
92
+ dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's
93
+ :attr:`worker_init_fn` option to modify each copy's behavior.
94
+
95
+ Examples:
96
+ General Usage:
97
+ >>> # xdoctest: +SKIP
98
+ >>> from torchdata.datapipes.iter import IterableWrapper, Mapper
99
+ >>> dp = IterableWrapper(range(10))
100
+ >>> map_dp_1 = Mapper(dp, lambda x: x + 1) # Using class constructor
101
+ >>> map_dp_2 = dp.map(lambda x: x + 1) # Using functional form (recommended)
102
+ >>> list(map_dp_1)
103
+ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
104
+ >>> list(map_dp_2)
105
+ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
106
+ >>> filter_dp = map_dp_1.filter(lambda x: x % 2 == 0)
107
+ >>> list(filter_dp)
108
+ [2, 4, 6, 8, 10]
109
+ Single Iterator Constraint Example:
110
+ >>> from torchdata.datapipes.iter import IterableWrapper, Mapper
111
+ >>> source_dp = IterableWrapper(range(10))
112
+ >>> it1 = iter(source_dp)
113
+ >>> list(it1)
114
+ [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
115
+ >>> it1 = iter(source_dp)
116
+ >>> it2 = iter(source_dp) # The creation of a new iterator invalidates `it1`
117
+ >>> next(it2)
118
+ 0
119
+ >>> next(it1) # Further usage of `it1` will raise a `RunTimeError`
120
+ """
121
+
122
+ functions: Dict[str, Callable] = {}
123
+ reduce_ex_hook: Optional[Callable] = None
124
+ getstate_hook: Optional[Callable] = None
125
+ str_hook: Optional[Callable] = None
126
+ repr_hook: Optional[Callable] = None
127
+ _valid_iterator_id: Optional[int] = None
128
+ _number_of_samples_yielded: int = 0
129
+ _snapshot_state: _SnapshotState = _SnapshotState.NotStarted
130
+ _fast_forward_iterator: Optional[Iterator] = None
131
+
132
+ def __iter__(self) -> Iterator[_T_co]:
133
+ return self
134
+
135
+ def __getattr__(self, attribute_name):
136
+ if attribute_name in IterDataPipe.functions:
137
+ if attribute_name in _iter_deprecated_functional_names:
138
+ kwargs = _iter_deprecated_functional_names[attribute_name]
139
+ _deprecation_warning(**kwargs)
140
+ f = IterDataPipe.functions[attribute_name]
141
+ function = functools.partial(f, self)
142
+ functools.update_wrapper(wrapper=function, wrapped=f, assigned=("__doc__",))
143
+ return function
144
+ else:
145
+ raise AttributeError(
146
+ f"'{self.__class__.__name__}' object has no attribute '{attribute_name}"
147
+ )
148
+
149
+ @classmethod
150
+ def register_function(cls, function_name, function):
151
+ cls.functions[function_name] = function
152
+
153
+ @classmethod
154
+ def register_datapipe_as_function(
155
+ cls, function_name, cls_to_register, enable_df_api_tracing=False
156
+ ):
157
+ if function_name in cls.functions:
158
+ raise Exception( # noqa: TRY002
159
+ f"Unable to add DataPipe function name {function_name} as it is already taken"
160
+ )
161
+
162
+ def class_function(cls, enable_df_api_tracing, source_dp, *args, **kwargs):
163
+ result_pipe = cls(source_dp, *args, **kwargs)
164
+ if isinstance(result_pipe, IterDataPipe):
165
+ if enable_df_api_tracing or isinstance(source_dp, DFIterDataPipe):
166
+ if function_name not in UNTRACABLE_DATAFRAME_PIPES:
167
+ result_pipe = result_pipe.trace_as_dataframe()
168
+
169
+ return result_pipe
170
+
171
+ function = functools.partial(
172
+ class_function, cls_to_register, enable_df_api_tracing
173
+ )
174
+ functools.update_wrapper(
175
+ wrapper=function, wrapped=cls_to_register, assigned=("__doc__",)
176
+ )
177
+ cls.functions[function_name] = function
178
+
179
+ def __getstate__(self):
180
+ """
181
+ Serialize `lambda` functions when `dill` is available.
182
+
183
+ If this doesn't cover your custom DataPipe's use case, consider writing custom methods for
184
+ `__getstate__` and `__setstate__`, or use `pickle.dumps` for serialization.
185
+ """
186
+ state = self.__dict__
187
+ if IterDataPipe.getstate_hook is not None:
188
+ return IterDataPipe.getstate_hook(state)
189
+ return state
190
+
191
+ def __reduce_ex__(self, *args, **kwargs):
192
+ if IterDataPipe.reduce_ex_hook is not None:
193
+ try:
194
+ return IterDataPipe.reduce_ex_hook(self)
195
+ except NotImplementedError:
196
+ pass
197
+ return super().__reduce_ex__(*args, **kwargs)
198
+
199
+ @classmethod
200
+ def set_getstate_hook(cls, hook_fn):
201
+ if IterDataPipe.getstate_hook is not None and hook_fn is not None:
202
+ raise RuntimeError("Attempt to override existing getstate_hook")
203
+ IterDataPipe.getstate_hook = hook_fn
204
+
205
+ @classmethod
206
+ def set_reduce_ex_hook(cls, hook_fn):
207
+ if IterDataPipe.reduce_ex_hook is not None and hook_fn is not None:
208
+ raise RuntimeError("Attempt to override existing reduce_ex_hook")
209
+ IterDataPipe.reduce_ex_hook = hook_fn
210
+
211
+ def __repr__(self):
212
+ if self.repr_hook is not None:
213
+ return self.repr_hook(self)
214
+ # Instead of showing <torch. ... .MapperIterDataPipe object at 0x.....>, return the class name
215
+ return str(self.__class__.__qualname__)
216
+
217
+ def __str__(self):
218
+ if self.str_hook is not None:
219
+ return self.str_hook(self)
220
+ # Instead of showing <torch. ... .MapperIterDataPipe object at 0x.....>, return the class name
221
+ return str(self.__class__.__qualname__)
222
+
223
+ def __dir__(self):
224
+ # for auto-completion in a REPL (e.g. Jupyter notebook)
225
+ return list(super().__dir__()) + list(self.functions.keys())
226
+
227
+ def reset(self) -> None:
228
+ r"""
229
+ Reset the `IterDataPipe` to the initial state.
230
+
231
+ By default, no-op. For subclasses of `IterDataPipe`, depending on their functionalities,
232
+ they may want to override this method with implementations that
233
+ may clear the buffers and reset pointers of the DataPipe.
234
+ The `reset` method is always called when `__iter__` is called as part of `hook_iterator`.
235
+ """
236
+
237
+
238
+ class DFIterDataPipe(IterDataPipe):
239
+ def _is_dfpipe(self):
240
+ return True
241
+
242
+
243
+ class MapDataPipe(Dataset[_T_co], metaclass=_DataPipeMeta):
244
+ r"""
245
+ Map-style DataPipe.
246
+
247
+ All datasets that represent a map from keys to data samples should subclass this.
248
+ Subclasses should overwrite :meth:`__getitem__`, supporting fetching a
249
+ data sample for a given, unique key. Subclasses can also optionally overwrite
250
+ :meth:`__len__`, which is expected to return the size of the dataset by many
251
+ :class:`~torch.utils.data.Sampler` implementations and the default options
252
+ of :class:`~torch.utils.data.DataLoader`.
253
+
254
+ These DataPipes can be invoked in two ways, using the class constructor or applying their
255
+ functional form onto an existing `MapDataPipe` (recommended, available to most but not all DataPipes).
256
+
257
+ Note:
258
+ :class:`~torch.utils.data.DataLoader` by default constructs an index
259
+ sampler that yields integral indices. To make it work with a map-style
260
+ DataPipe with non-integral indices/keys, a custom sampler must be provided.
261
+
262
+ Example:
263
+ >>> # xdoctest: +SKIP
264
+ >>> from torchdata.datapipes.map import SequenceWrapper, Mapper
265
+ >>> dp = SequenceWrapper(range(10))
266
+ >>> map_dp_1 = dp.map(lambda x: x + 1) # Using functional form (recommended)
267
+ >>> list(map_dp_1)
268
+ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
269
+ >>> map_dp_2 = Mapper(dp, lambda x: x + 1) # Using class constructor
270
+ >>> list(map_dp_2)
271
+ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
272
+ >>> batch_dp = map_dp_1.batch(batch_size=2)
273
+ >>> list(batch_dp)
274
+ [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
275
+ """
276
+
277
+ functions: Dict[str, Callable] = {}
278
+ reduce_ex_hook: Optional[Callable] = None
279
+ getstate_hook: Optional[Callable] = None
280
+ str_hook: Optional[Callable] = None
281
+ repr_hook: Optional[Callable] = None
282
+
283
+ def __getattr__(self, attribute_name):
284
+ if attribute_name in MapDataPipe.functions:
285
+ if attribute_name in _map_deprecated_functional_names:
286
+ kwargs = _map_deprecated_functional_names[attribute_name]
287
+ _deprecation_warning(**kwargs)
288
+ f = MapDataPipe.functions[attribute_name]
289
+ function = functools.partial(f, self)
290
+ functools.update_wrapper(wrapper=function, wrapped=f, assigned=("__doc__",))
291
+ return function
292
+ else:
293
+ raise AttributeError(
294
+ f"'{self.__class__.__name__}' object has no attribute '{attribute_name}"
295
+ )
296
+
297
+ @classmethod
298
+ def register_function(cls, function_name, function):
299
+ cls.functions[function_name] = function
300
+
301
+ @classmethod
302
+ def register_datapipe_as_function(cls, function_name, cls_to_register):
303
+ if function_name in cls.functions:
304
+ raise Exception( # noqa: TRY002
305
+ f"Unable to add DataPipe function name {function_name} as it is already taken"
306
+ )
307
+
308
+ def class_function(cls, source_dp, *args, **kwargs):
309
+ result_pipe = cls(source_dp, *args, **kwargs)
310
+ return result_pipe
311
+
312
+ function = functools.partial(class_function, cls_to_register)
313
+ functools.update_wrapper(
314
+ wrapper=function, wrapped=cls_to_register, assigned=("__doc__",)
315
+ )
316
+ cls.functions[function_name] = function
317
+
318
+ def __getstate__(self):
319
+ """
320
+ Serialize `lambda` functions when `dill` is available.
321
+
322
+ If this doesn't cover your custom DataPipe's use case, consider writing custom methods for
323
+ `__getstate__` and `__setstate__`, or use `pickle.dumps` for serialization.
324
+ """
325
+ state = self.__dict__
326
+ if MapDataPipe.getstate_hook is not None:
327
+ return MapDataPipe.getstate_hook(state)
328
+ return state
329
+
330
+ def __reduce_ex__(self, *args, **kwargs):
331
+ if MapDataPipe.reduce_ex_hook is not None:
332
+ try:
333
+ return MapDataPipe.reduce_ex_hook(self)
334
+ except NotImplementedError:
335
+ pass
336
+ return super().__reduce_ex__(*args, **kwargs)
337
+
338
+ @classmethod
339
+ def set_getstate_hook(cls, hook_fn):
340
+ if MapDataPipe.getstate_hook is not None and hook_fn is not None:
341
+ raise RuntimeError("Attempt to override existing getstate_hook")
342
+ MapDataPipe.getstate_hook = hook_fn
343
+
344
+ @classmethod
345
+ def set_reduce_ex_hook(cls, hook_fn):
346
+ if MapDataPipe.reduce_ex_hook is not None and hook_fn is not None:
347
+ raise RuntimeError("Attempt to override existing reduce_ex_hook")
348
+ MapDataPipe.reduce_ex_hook = hook_fn
349
+
350
+ def __repr__(self):
351
+ if self.repr_hook is not None:
352
+ return self.repr_hook(self)
353
+ # Instead of showing <torch. ... .MapperMapDataPipe object at 0x.....>, return the class name
354
+ return str(self.__class__.__qualname__)
355
+
356
+ def __str__(self):
357
+ if self.str_hook is not None:
358
+ return self.str_hook(self)
359
+ # Instead of showing <torch. ... .MapperMapDataPipe object at 0x.....>, return the class name
360
+ return str(self.__class__.__qualname__)
361
+
362
+ def __dir__(self):
363
+ # for auto-completion in a REPL (e.g. Jupyter notebook)
364
+ return list(super().__dir__()) + list(self.functions.keys())
365
+
366
+
367
+ class _DataPipeSerializationWrapper:
368
+ def __init__(self, datapipe):
369
+ self._datapipe = datapipe
370
+
371
+ def __getstate__(self):
372
+ use_dill = False
373
+ try:
374
+ value = pickle.dumps(self._datapipe)
375
+ except Exception:
376
+ if HAS_DILL:
377
+ value = dill.dumps(self._datapipe)
378
+ use_dill = True
379
+ else:
380
+ raise
381
+ return (value, use_dill)
382
+
383
+ def __setstate__(self, state):
384
+ value, use_dill = state
385
+ if use_dill:
386
+ self._datapipe = dill.loads(value)
387
+ else:
388
+ self._datapipe = pickle.loads(value)
389
+
390
+ def __len__(self):
391
+ try:
392
+ return len(self._datapipe)
393
+ except Exception as e:
394
+ raise TypeError(
395
+ f"{type(self).__name__} instance doesn't have valid length"
396
+ ) from e
397
+
398
+
399
+ class _IterDataPipeSerializationWrapper(_DataPipeSerializationWrapper, IterDataPipe):
400
+ def __init__(self, datapipe: IterDataPipe[_T_co]):
401
+ super().__init__(datapipe)
402
+ self._datapipe_iter: Optional[Iterator[_T_co]] = None
403
+
404
+ def __iter__(self) -> "_IterDataPipeSerializationWrapper":
405
+ self._datapipe_iter = iter(self._datapipe)
406
+ return self
407
+
408
+ def __next__(self) -> _T_co: # type: ignore[type-var]
409
+ assert self._datapipe_iter is not None
410
+ return next(self._datapipe_iter)
411
+
412
+
413
+ class _MapDataPipeSerializationWrapper(_DataPipeSerializationWrapper, MapDataPipe):
414
+ def __getitem__(self, idx):
415
+ return self._datapipe[idx]
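The `functions` registry plus `__getattr__` in `IterDataPipe`/`MapDataPipe` above is what turns classes decorated with `@functional_datapipe("name")` into chainable methods such as `dp.map(...)` or `dp.batch(...)`. A minimal, hypothetical sketch of that dispatch mechanism (all names below are illustrative only):

import functools

class MiniPipe:
    functions = {}  # name -> callable(source, *args, **kwargs)

    def __init__(self, source):
        self.source = source

    def __iter__(self):
        yield from self.source

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails; dispatch to the registry.
        if name in MiniPipe.functions:
            return functools.partial(MiniPipe.functions[name], self)
        raise AttributeError(name)

    @classmethod
    def register(cls, name, pipe_cls):
        # Store a constructor-like callable keyed by the functional name.
        cls.functions[name] = lambda source, *args, **kwargs: pipe_cls(source, *args, **kwargs)

class Doubler(MiniPipe):
    def __iter__(self):
        for x in self.source:
            yield x * 2

MiniPipe.register("double", Doubler)
print(list(MiniPipe(range(3)).double()))  # [0, 2, 4]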
.venv/Lib/site-packages/torch/utils/data/datapipes/datapipe.pyi ADDED
@@ -0,0 +1,697 @@
1
+ # mypy: allow-untyped-defs
2
+ # This base template ("datapipe.pyi.in") is generated from mypy stubgen with minimal editing for code injection
3
+ # The output file will be "datapipe.pyi". This is executed as part of torch/CMakeLists.txt
4
+ # Note that, for mypy, .pyi file takes precedent over .py file, such that we must define the interface for other
5
+ # classes/objects here, even though we are not injecting extra code into them at the moment.
6
+
7
+ from typing import (
8
+ Any,
9
+ Callable,
10
+ Dict,
11
+ Iterable,
12
+ Iterator,
13
+ List,
14
+ Literal,
15
+ Optional,
16
+ Type,
17
+ TypeVar,
18
+ Union,
19
+ )
20
+
21
+ from torch.utils.data import Dataset, default_collate, IterableDataset
22
+ from torch.utils.data.datapipes._hook_iterator import _SnapshotState
23
+ from torch.utils.data.datapipes._typing import _DataPipeMeta, _IterDataPipeMeta
24
+
25
+ _T = TypeVar("_T")
26
+ _T_co = TypeVar("_T_co", covariant=True)
27
+ UNTRACABLE_DATAFRAME_PIPES: Any
28
+
29
+ class DataChunk(List[_T]):
30
+ items: List[_T]
31
+ def __init__(self, items: Iterable[_T]) -> None: ...
32
+ def as_str(self, indent: str = "") -> str: ...
33
+ def __iter__(self) -> Iterator[_T]: ...
34
+ def raw_iterator(self) -> Iterator[_T]: ...
35
+
36
+ class MapDataPipe(Dataset[_T_co], metaclass=_DataPipeMeta):
37
+ functions: Dict[str, Callable] = ...
38
+ reduce_ex_hook: Optional[Callable] = ...
39
+ getstate_hook: Optional[Callable] = ...
40
+ str_hook: Optional[Callable] = ...
41
+ repr_hook: Optional[Callable] = ...
42
+ def __getattr__(self, attribute_name: Any): ...
43
+ @classmethod
44
+ def register_function(cls, function_name: Any, function: Any) -> None: ...
45
+ @classmethod
46
+ def register_datapipe_as_function(
47
+ cls,
48
+ function_name: Any,
49
+ cls_to_register: Any,
50
+ ): ...
51
+ def __getstate__(self): ...
52
+ def __reduce_ex__(self, *args: Any, **kwargs: Any): ...
53
+ @classmethod
54
+ def set_getstate_hook(cls, hook_fn: Any) -> None: ...
55
+ @classmethod
56
+ def set_reduce_ex_hook(cls, hook_fn: Any) -> None: ...
57
+ # Functional form of 'BatcherMapDataPipe'
58
+ def batch(self, batch_size: int, drop_last: bool = False, wrapper_class: Type[DataChunk] = DataChunk) -> MapDataPipe:
59
+ r"""
60
+ Create mini-batches of data (functional name: ``batch``).
61
+
62
+ An outer dimension will be added as ``batch_size`` if ``drop_last`` is set to ``True``,
63
+ or ``length % batch_size`` for the last batch if ``drop_last`` is set to ``False``.
64
+
65
+ Args:
66
+ datapipe: Iterable DataPipe being batched
67
+ batch_size: The size of each batch
68
+ drop_last: Option to drop the last batch if it's not full
69
+
70
+ Example:
71
+ >>> # xdoctest: +SKIP
72
+ >>> from torchdata.datapipes.map import SequenceWrapper
73
+ >>> dp = SequenceWrapper(range(10))
74
+ >>> batch_dp = dp.batch(batch_size=2)
75
+ >>> list(batch_dp)
76
+ [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
77
+ """
78
+
79
+ # Functional form of 'ConcaterMapDataPipe'
80
+ def concat(self, *datapipes: MapDataPipe) -> MapDataPipe:
81
+ r"""
82
+ Concatenate multiple Map DataPipes (functional name: ``concat``).
83
+
84
+ The new index is the cumulative sum of the lengths of the source DataPipes.
85
+ For example, if there are 2 source DataPipes both with length 5,
86
+ index 0 to 4 of the resulting `ConcatMapDataPipe` would refer to
87
+ elements of the first DataPipe, and 5 to 9 would refer to elements
88
+ of the second DataPipe.
89
+
90
+ Args:
91
+ datapipes: Map DataPipes being concatenated
92
+
93
+ Example:
94
+ >>> # xdoctest: +SKIP
95
+ >>> from torchdata.datapipes.map import SequenceWrapper
96
+ >>> dp1 = SequenceWrapper(range(3))
97
+ >>> dp2 = SequenceWrapper(range(3))
98
+ >>> concat_dp = dp1.concat(dp2)
99
+ >>> list(concat_dp)
100
+ [0, 1, 2, 0, 1, 2]
101
+ """
102
+
103
+ # Functional form of 'MapperMapDataPipe'
104
+ def map(self, fn: Callable= ...) -> MapDataPipe:
105
+ r"""
106
+ Apply the input function over each item from the source DataPipe (functional name: ``map``).
107
+
108
+ The function can be any regular Python function or partial object. Lambda
109
+ function is not recommended as it is not supported by pickle.
110
+
111
+ Args:
112
+ datapipe: Source MapDataPipe
113
+ fn: Function being applied to each item
114
+
115
+ Example:
116
+ >>> # xdoctest: +SKIP
117
+ >>> from torchdata.datapipes.map import SequenceWrapper, Mapper
118
+ >>> def add_one(x):
119
+ ... return x + 1
120
+ >>> dp = SequenceWrapper(range(10))
121
+ >>> map_dp_1 = dp.map(add_one)
122
+ >>> list(map_dp_1)
123
+ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
124
+ >>> map_dp_2 = Mapper(dp, lambda x: x + 1)
125
+ >>> list(map_dp_2)
126
+ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
127
+ """
128
+
129
+ # Functional form of 'ShufflerIterDataPipe'
130
+ def shuffle(self, *, indices: Optional[List] = None) -> IterDataPipe:
131
+ r"""
132
+ Shuffle the input MapDataPipe via its indices (functional name: ``shuffle``).
133
+
134
+ When it is used with :class:`~torch.utils.data.DataLoader`, the methods to
135
+ set up random seed are different based on :attr:`num_workers`.
136
+
137
+ For single-process mode (:attr:`num_workers == 0`), the random seed is set before
138
+ the :class:`~torch.utils.data.DataLoader` in the main process. For multi-process
139
+ mode (:attr:`num_worker > 0`), ``worker_init_fn`` is used to set up a random seed
140
+ for each worker process.
141
+
142
+ Args:
143
+ datapipe: MapDataPipe being shuffled
144
+ indices: a list of indices of the MapDataPipe. If not provided, we assume it uses 0-based indexing
145
+
146
+ Example:
147
+ >>> # xdoctest: +SKIP
148
+ >>> from torchdata.datapipes.map import SequenceWrapper
149
+ >>> dp = SequenceWrapper(range(10))
150
+ >>> shuffle_dp = dp.shuffle().set_seed(0)
151
+ >>> list(shuffle_dp)
152
+ [7, 8, 1, 5, 3, 4, 2, 0, 9, 6]
153
+ >>> list(shuffle_dp)
154
+ [6, 1, 9, 5, 2, 4, 7, 3, 8, 0]
155
+ >>> # Reset seed for Shuffler
156
+ >>> shuffle_dp = shuffle_dp.set_seed(0)
157
+ >>> list(shuffle_dp)
158
+ [7, 8, 1, 5, 3, 4, 2, 0, 9, 6]
159
+
160
+ Note:
161
+ Even though this ``shuffle`` operation takes a ``MapDataPipe`` as the input, it returns an
162
+ ``IterDataPipe`` rather than a ``MapDataPipe``, because a ``MapDataPipe`` should be insensitive to
163
+ the order of data for the sake of random reads, but an ``IterDataPipe`` depends on the order
164
+ of data during processing.
165
+ """
166
+
167
+ # Functional form of 'ZipperMapDataPipe'
168
+ def zip(self, *datapipes: MapDataPipe[_T_co]) -> MapDataPipe:
169
+ r"""
170
+ Aggregates elements into a tuple from each of the input DataPipes (functional name: ``zip``).
171
+
172
+ This MapDataPipe is exhausted as soon as the shortest input DataPipe is exhausted.
173
+
174
+ Args:
175
+ *datapipes: Map DataPipes being aggregated
176
+
177
+ Example:
178
+ >>> # xdoctest: +SKIP
179
+ >>> from torchdata.datapipes.map import SequenceWrapper
180
+ >>> dp1 = SequenceWrapper(range(3))
181
+ >>> dp2 = SequenceWrapper(range(10, 13))
182
+ >>> zip_dp = dp1.zip(dp2)
183
+ >>> list(zip_dp)
184
+ [(0, 10), (1, 11), (2, 12)]
185
+ """
186
+
187
+
188
+ class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta):
189
+ functions: Dict[str, Callable] = ...
190
+ reduce_ex_hook: Optional[Callable] = ...
191
+ getstate_hook: Optional[Callable] = ...
192
+ str_hook: Optional[Callable] = ...
193
+ repr_hook: Optional[Callable] = ...
194
+ _number_of_samples_yielded: int = ...
195
+ _snapshot_state: _SnapshotState = _SnapshotState.Iterating # noqa: PYI015
196
+ _fast_forward_iterator: Optional[Iterator] = ...
197
+ def __getattr__(self, attribute_name: Any): ...
198
+ @classmethod
199
+ def register_function(cls, function_name: Any, function: Any) -> None: ...
200
+ @classmethod
201
+ def register_datapipe_as_function(
202
+ cls,
203
+ function_name: Any,
204
+ cls_to_register: Any,
205
+ enable_df_api_tracing: bool = ...,
206
+ ): ...
207
+ def __getstate__(self): ...
208
+ def __reduce_ex__(self, *args: Any, **kwargs: Any): ...
209
+ @classmethod
210
+ def set_getstate_hook(cls, hook_fn: Any) -> None: ...
211
+ @classmethod
212
+ def set_reduce_ex_hook(cls, hook_fn: Any) -> None: ...
213
+ # Functional form of 'BatcherIterDataPipe'
214
+ def batch(self, batch_size: int, drop_last: bool = False, wrapper_class: Type[DataChunk] = DataChunk) -> IterDataPipe:
215
+ r"""
216
+ Creates mini-batches of data (functional name: ``batch``).
217
+
218
+ An outer dimension will be added as ``batch_size`` if ``drop_last`` is set to ``True``, or ``length % batch_size`` for the
219
+ last batch if ``drop_last`` is set to ``False``.
220
+
221
+ Args:
222
+ datapipe: Iterable DataPipe being batched
223
+ batch_size: The size of each batch
224
+ drop_last: Option to drop the last batch if it's not full
225
+ wrapper_class: wrapper to apply onto each batch (type ``List``) before yielding,
226
+ defaults to ``DataChunk``
227
+
228
+ Example:
229
+ >>> # xdoctest: +SKIP
230
+ >>> from torchdata.datapipes.iter import IterableWrapper
231
+ >>> dp = IterableWrapper(range(10))
232
+ >>> dp = dp.batch(batch_size=3, drop_last=True)
233
+ >>> list(dp)
234
+ [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
235
+ """
236
+
237
+ # Functional form of 'CollatorIterDataPipe'
238
+ def collate(self, conversion: Union[Callable[..., Any], Dict[Union[str, Any], Union[Callable, Any]], None] = default_collate, collate_fn: Optional[Callable] = None) -> IterDataPipe:
239
+ r"""
240
+ Collates samples from DataPipe to Tensor(s) by a custom collate function (functional name: ``collate``).
241
+
242
+ By default, it uses :func:`torch.utils.data.default_collate`.
243
+
244
+ .. note::
245
+ While writing a custom collate function, you can import :func:`torch.utils.data.default_collate` for the
246
+ default behavior and `functools.partial` to specify any additional arguments.
247
+
248
+ Args:
249
+ datapipe: Iterable DataPipe being collated
250
+ collate_fn: Customized collate function to collect and combine data or a batch of data.
251
+ Default function collates to Tensor(s) based on data type.
252
+
253
+ Example:
254
+ >>> # xdoctest: +SKIP
255
+ >>> # Convert integer data to float Tensor
256
+ >>> class MyIterDataPipe(torch.utils.data.IterDataPipe):
257
+ ... def __init__(self, start, end):
258
+ ... super(MyIterDataPipe).__init__()
259
+ ... assert end > start, "this example code only works with end >= start"
260
+ ... self.start = start
261
+ ... self.end = end
262
+ ...
263
+ ... def __iter__(self):
264
+ ... return iter(range(self.start, self.end))
265
+ ...
266
+ ... def __len__(self):
267
+ ... return self.end - self.start
268
+ ...
269
+ >>> ds = MyIterDataPipe(start=3, end=7)
270
+ >>> print(list(ds))
271
+ [3, 4, 5, 6]
272
+ >>> def collate_fn(batch):
273
+ ... return torch.tensor(batch, dtype=torch.float)
274
+ ...
275
+ >>> collated_ds = CollateIterDataPipe(ds, collate_fn=collate_fn)
276
+ >>> print(list(collated_ds))
277
+ [tensor(3.), tensor(4.), tensor(5.), tensor(6.)]
278
+ """
279
+
280
+ # Functional form of 'ConcaterIterDataPipe'
281
+ def concat(self, *datapipes: IterDataPipe) -> IterDataPipe:
282
+ r"""
283
+ Concatenates multiple Iterable DataPipes (functional name: ``concat``).
284
+
285
+ The resulting DataPipe will yield all the elements from the first input DataPipe, before yielding from the subsequent ones.
286
+
287
+ Args:
288
+ datapipes: Iterable DataPipes being concatenated
289
+
290
+ Example:
291
+ >>> # xdoctest: +REQUIRES(module:torchdata)
292
+ >>> import random
293
+ >>> from torchdata.datapipes.iter import IterableWrapper
294
+ >>> dp1 = IterableWrapper(range(3))
295
+ >>> dp2 = IterableWrapper(range(5))
296
+ >>> list(dp1.concat(dp2))
297
+ [0, 1, 2, 0, 1, 2, 3, 4]
298
+ """
299
+
300
+ # Functional form of 'DemultiplexerIterDataPipe'
301
+ def demux(self, num_instances: int, classifier_fn: Callable[[_T_co], Optional[int]], drop_none: bool = False, buffer_size: int = 1000) -> List[IterDataPipe]:
302
+ r"""
303
+ Splits the input DataPipe into multiple child DataPipes, using the given classification function (functional name: ``demux``).
304
+
305
+ A list of the child DataPipes is returned from this operation.
306
+
307
+ Args:
308
+ datapipe: Iterable DataPipe being filtered
309
+ num_instances: number of instances of the DataPipe to create
310
+ classifier_fn: a function that maps values to an integer within the range ``[0, num_instances - 1]`` or ``None``
311
+ drop_none: defaults to ``False``, if ``True``, the function will skip over elements classified as ``None``
312
+ buffer_size: this defines the maximum number of inputs that the buffer can hold across all child
313
+ DataPipes while waiting for their values to be yielded.
314
+ Defaults to ``1000``. Use ``-1`` for the unlimited buffer.
315
+
316
+ Examples:
317
+ >>> # xdoctest: +REQUIRES(module:torchdata)
318
+ >>> from torchdata.datapipes.iter import IterableWrapper
319
+ >>> def odd_or_even(n):
320
+ ... return n % 2
321
+ >>> source_dp = IterableWrapper(range(5))
322
+ >>> dp1, dp2 = source_dp.demux(num_instances=2, classifier_fn=odd_or_even)
323
+ >>> list(dp1)
324
+ [0, 2, 4]
325
+ >>> list(dp2)
326
+ [1, 3]
327
+ >>> # It can also filter out any element that gets `None` from the `classifier_fn`
328
+ >>> def odd_or_even_no_zero(n):
329
+ ... return n % 2 if n != 0 else None
330
+ >>> dp1, dp2 = source_dp.demux(num_instances=2, classifier_fn=odd_or_even_no_zero, drop_none=True)
331
+ >>> list(dp1)
332
+ [2, 4]
333
+ >>> list(dp2)
334
+ [1, 3]
335
+ """
336
+
337
+ # Functional form of 'FilterIterDataPipe'
338
+ def filter(self, filter_fn: Callable, input_col=None) -> IterDataPipe:
339
+ r"""
340
+ Filters out elements from the source datapipe according to input ``filter_fn`` (functional name: ``filter``).
341
+
342
+ Args:
343
+ datapipe: Iterable DataPipe being filtered
344
+ filter_fn: Customized function mapping an element to a boolean.
345
+ input_col: Index or indices of data which ``filter_fn`` is applied, such as:
346
+
347
+ - ``None`` as default to apply ``filter_fn`` to the data directly.
348
+ - Integer(s) is used for list/tuple.
349
+ - Key(s) is used for dict.
350
+
351
+ Example:
352
+ >>> # xdoctest: +SKIP
353
+ >>> from torchdata.datapipes.iter import IterableWrapper
354
+ >>> def is_even(n):
355
+ ... return n % 2 == 0
356
+ >>> dp = IterableWrapper(range(5))
357
+ >>> filter_dp = dp.filter(filter_fn=is_even)
358
+ >>> list(filter_dp)
359
+ [0, 2, 4]
360
+ """
361
+
362
+ # Functional form of 'ForkerIterDataPipe'
363
+ def fork(self, num_instances: int, buffer_size: int = 1000, copy: Optional[Literal["shallow", "deep"]] = None) -> List[IterDataPipe]:
364
+ r"""
365
+ Creates multiple instances of the same Iterable DataPipe (functional name: ``fork``).
366
+
367
+ Args:
368
+ datapipe: Iterable DataPipe being copied
369
+ num_instances: number of instances of the datapipe to create
370
+ buffer_size: this restricts how far ahead the leading child DataPipe
371
+ can read relative to the slowest child DataPipe.
372
+ Defaults to ``1000``. Use ``-1`` for the unlimited buffer.
373
+ copy: copy strategy to use for items yielded by each branch. Supported
374
+ options are ``None`` for no copying, ``"shallow"`` for shallow object
375
+ copies, and ``"deep"`` for deep object copies. Defaults to ``None``.
376
+
377
+ Note:
378
+ All branches of the forked pipeline return the identical object unless
379
+ the copy parameter is supplied. If the object is mutable or contains
380
+ mutable objects, changing them in one branch will affect all others.
381
+
382
+ Example:
383
+ >>> # xdoctest: +REQUIRES(module:torchdata)
384
+ >>> from torchdata.datapipes.iter import IterableWrapper
385
+ >>> source_dp = IterableWrapper(range(5))
386
+ >>> dp1, dp2 = source_dp.fork(num_instances=2)
387
+ >>> list(dp1)
388
+ [0, 1, 2, 3, 4]
389
+ >>> list(dp2)
390
+ [0, 1, 2, 3, 4]
391
+ """
392
+
393
+ # Functional form of 'GrouperIterDataPipe'
394
+ def groupby(self, group_key_fn: Callable[[_T_co], Any], *, keep_key: bool = False, buffer_size: int = 10000, group_size: Optional[int] = None, guaranteed_group_size: Optional[int] = None, drop_remaining: bool = False) -> IterDataPipe:
395
+ r"""
396
+ Groups data from IterDataPipe by keys from ``group_key_fn``, yielding a ``DataChunk`` with batch size up to ``group_size``.
397
+
398
+ (functional name: ``groupby``).
399
+
400
+ The samples are read sequentially from the source ``datapipe``, and a batch of samples belonging to the same group
401
+ will be yielded as soon as the size of the batch reaches ``group_size``. When the buffer is full,
402
+ the DataPipe will yield the largest batch with the same key, provided that its size is larger
403
+ than ``guaranteed_group_size``. If its size is smaller, it will be dropped if ``drop_remaining=True``.
404
+
405
+ After iterating through the entirety of source ``datapipe``, everything not dropped due to the buffer capacity
406
+ will be yielded from the buffer, even if the group sizes are smaller than ``guaranteed_group_size``.
407
+
408
+ Args:
409
+ datapipe: Iterable datapipe to be grouped
410
+ group_key_fn: Function used to generate group key from the data of the source datapipe
411
+ keep_key: Option to yield the matching key along with the items in a tuple,
412
+ resulting in `(key, [items])` otherwise returning [items]
413
+ buffer_size: The size of buffer for ungrouped data
414
+ group_size: The max size of each group, a batch is yielded as soon as it reaches this size
415
+ guaranteed_group_size: The guaranteed minimum group size to be yielded in case the buffer is full
416
+ drop_remaining: Specifies if the group smaller than ``guaranteed_group_size`` will be dropped from buffer
417
+ when the buffer is full
418
+
419
+ Example:
420
+ >>> import os
421
+ >>> # xdoctest: +SKIP
422
+ >>> from torchdata.datapipes.iter import IterableWrapper
423
+ >>> def group_fn(file):
424
+ ... return os.path.basename(file).split(".")[0]
425
+ >>> source_dp = IterableWrapper(["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"])
426
+ >>> dp0 = source_dp.groupby(group_key_fn=group_fn)
427
+ >>> list(dp0)
428
+ [['a.png', 'a.json', 'a.jpg'], ['b.png', 'b.json'], ['c.json']]
429
+ >>> # A group is yielded as soon as its size equals to `group_size`
430
+ >>> dp1 = source_dp.groupby(group_key_fn=group_fn, group_size=2)
431
+ >>> list(dp1)
432
+ [['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]
433
+ >>> # Scenario where `buffer` is full, and group 'a' needs to be yielded since its size > `guaranteed_group_size`
434
+ >>> dp2 = source_dp.groupby(group_key_fn=group_fn, buffer_size=3, group_size=3, guaranteed_group_size=2)
435
+ >>> list(dp2)
436
+ [['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]
437
+ """
438
+
439
+ # Functional form of 'FileListerIterDataPipe'
440
+ def list_files(self, masks: Union[str, List[str]] = "", *, recursive: bool = False, abspath: bool = False, non_deterministic: bool = False, length: int = -1) -> IterDataPipe:
441
+ r"""
442
+ Given path(s) to the root directory, yields file pathname(s) (path + filename) of files within the root directory.
443
+
444
+ Multiple root directories can be provided (functional name: ``list_files``).
445
+
446
+ Args:
447
+ root: Root directory or a sequence of root directories
448
+ masks: Unix style filter string or string list for filtering file name(s)
449
+ recursive: Whether to return pathname from nested directories or not
450
+ abspath: Whether to return relative pathname or absolute pathname
451
+ non_deterministic: Whether to return pathname in sorted order or not.
452
+ If ``False``, the results yielded from each root directory will be sorted
453
+ length: Nominal length of the datapipe
454
+
455
+ Example:
456
+ >>> # xdoctest: +SKIP
457
+ >>> from torchdata.datapipes.iter import FileLister
458
+ >>> dp = FileLister(root=".", recursive=True)
459
+ >>> list(dp)
460
+ ['example.py', './data/data.tar']
461
+ """
462
+
463
+ # Functional form of 'MapperIterDataPipe'
464
+ def map(self, fn: Callable, input_col=None, output_col=None) -> IterDataPipe:
465
+ r"""
466
+ Applies a function over each item from the source DataPipe (functional name: ``map``).
467
+
468
+ The function can be any regular Python function or partial object. Lambda
469
+ function is not recommended as it is not supported by pickle.
470
+
471
+ Args:
472
+ datapipe: Source Iterable DataPipe
473
+ fn: Function being applied over each item
474
+ input_col: Index or indices of data which ``fn`` is applied, such as:
475
+
476
+ - ``None`` as default to apply ``fn`` to the data directly.
477
+ - Integer(s) is used for list/tuple.
478
+ - Key(s) is used for dict.
479
+
480
+ output_col: Index of data where result of ``fn`` is placed. ``output_col`` can be specified
481
+ only when ``input_col`` is not ``None``
482
+
483
+ - ``None`` as default to replace the index that ``input_col`` specified; For ``input_col`` with
484
+ multiple indices, the left-most one is used, and other indices will be removed.
485
+ - Integer is used for list/tuple. ``-1`` represents to append result at the end.
486
+ - Key is used for dict. New key is acceptable.
487
+
488
+ Example:
489
+ >>> # xdoctest: +SKIP
490
+ >>> from torchdata.datapipes.iter import IterableWrapper, Mapper
491
+ >>> def add_one(x):
492
+ ... return x + 1
493
+ >>> dp = IterableWrapper(range(10))
494
+ >>> map_dp_1 = dp.map(add_one) # Invocation via functional form is preferred
495
+ >>> list(map_dp_1)
496
+ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
497
+ >>> # We discourage the usage of `lambda` functions as they are not serializable with `pickle`
498
+ >>> # Use `functools.partial` or explicitly define the function instead
499
+ >>> map_dp_2 = Mapper(dp, lambda x: x + 1)
500
+ >>> list(map_dp_2)
501
+ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
502
+ """
503
+
504
+ # Functional form of 'MultiplexerIterDataPipe'
505
+ def mux(self, *datapipes) -> IterDataPipe:
506
+ r"""
507
+ Yields one element at a time from each of the input Iterable DataPipes (functional name: ``mux``).
508
+
509
+ As in, one element from the 1st input DataPipe, then one element from the 2nd DataPipe in the next iteration,
510
+ and so on. It ends when the shortest input DataPipe is exhausted.
511
+
512
+ Args:
513
+ datapipes: Iterable DataPipes that will take turn to yield their elements, until the shortest DataPipe is exhausted
514
+
515
+ Example:
516
+ >>> # xdoctest: +REQUIRES(module:torchdata)
517
+ >>> from torchdata.datapipes.iter import IterableWrapper
518
+ >>> dp1, dp2, dp3 = IterableWrapper(range(3)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25))
519
+ >>> list(dp1.mux(dp2, dp3))
520
+ [0, 10, 20, 1, 11, 21, 2, 12, 22]
521
+ """
522
+
523
+ # Functional form of 'FileOpenerIterDataPipe'
524
+ def open_files(self, mode: str = "r", encoding: Optional[str] = None, length: int = -1) -> IterDataPipe:
525
+ r"""
526
+ Given pathnames, opens files and yield pathname and file stream in a tuple (functional name: ``open_files``).
527
+
528
+ Args:
529
+ datapipe: Iterable datapipe that provides pathnames
530
+ mode: An optional string that specifies the mode in which
531
+ the file is opened by ``open()``. It defaults to ``r``, other options are
532
+ ``b`` for reading in binary mode and ``t`` for text mode.
533
+ encoding: An optional string that specifies the encoding of the
534
+ underlying file. It defaults to ``None`` to match the default encoding of ``open``.
535
+ length: Nominal length of the datapipe
536
+
537
+ Note:
538
+ The opened file handles will be closed by Python's GC periodically. Users can choose
539
+ to close them explicitly.
540
+
541
+ Example:
542
+ >>> # xdoctest: +SKIP
543
+ >>> from torchdata.datapipes.iter import FileLister, FileOpener, StreamReader
544
+ >>> dp = FileLister(root=".").filter(lambda fname: fname.endswith('.txt'))
545
+ >>> dp = FileOpener(dp)
546
+ >>> dp = StreamReader(dp)
547
+ >>> list(dp)
548
+ [('./abc.txt', 'abc')]
549
+ """
550
+
551
+ # Functional form of 'StreamReaderIterDataPipe'
552
+ def read_from_stream(self, chunk=None) -> IterDataPipe:
553
+ r"""
554
+ Given IO streams and their label names, yield bytes with label name as tuple.
555
+
556
+ (functional name: ``read_from_stream``).
557
+
558
+ Args:
559
+ datapipe: Iterable DataPipe provides label/URL and byte stream
560
+ chunk: Number of bytes to be read from stream per iteration.
561
+ If ``None``, all bytes will be read until the EOF.
562
+
563
+ Example:
564
+ >>> # xdoctest: +SKIP
565
+ >>> from torchdata.datapipes.iter import IterableWrapper, StreamReader
566
+ >>> from io import StringIO
567
+ >>> dp = IterableWrapper([("alphabet", StringIO("abcde"))])
568
+ >>> list(StreamReader(dp, chunk=1))
569
+ [('alphabet', 'a'), ('alphabet', 'b'), ('alphabet', 'c'), ('alphabet', 'd'), ('alphabet', 'e')]
570
+ """
571
+
572
+ # Functional form of 'RoutedDecoderIterDataPipe'
573
+ def routed_decode(self, *handlers: Callable, key_fn: Callable= ...) -> IterDataPipe:
574
+ r"""
575
+ Decodes binary streams from input DataPipe, yields pathname and decoded data in a tuple.
576
+
577
+ (functional name: ``routed_decode``)
578
+
579
+ Args:
580
+ datapipe: Iterable datapipe that provides pathname and binary stream in tuples
581
+ handlers: Optional user defined decoder handlers. If ``None``, basic and image decoder
582
+ handlers will be set as default. If multiple handles are provided, the priority
583
+ order follows the order of handlers (the first handler has the top priority)
584
+ key_fn: Function for decoder to extract key from pathname to dispatch handlers.
585
+ Default is set to extract file extension from pathname
586
+
587
+ Note:
588
+ When ``key_fn`` is specified returning anything other than extension, the default
589
+ handler will not work and users need to specify custom handler. Custom handler
590
+ could use regex to determine the eligibility to handle data.
591
+ """
592
+
593
+ # Functional form of 'ShardingFilterIterDataPipe'
594
+ def sharding_filter(self, sharding_group_filter=None) -> IterDataPipe:
595
+ r"""
596
+ Wrapper that allows DataPipe to be sharded (functional name: ``sharding_filter``).
597
+
598
+ After ``apply_sharding`` is called, each instance of the DataPipe (on different workers) will have every `n`-th element of the
599
+ original DataPipe, where `n` equals the number of instances.
600
+
601
+ Args:
602
+ source_datapipe: Iterable DataPipe that will be sharded
603
+ """
604
+
605
+ # Functional form of 'ShufflerIterDataPipe'
606
+ def shuffle(self, *, buffer_size: int = 10000, unbatch_level: int = 0) -> IterDataPipe:
607
+ r"""
608
+ Shuffle the input DataPipe with a buffer (functional name: ``shuffle``).
609
+
610
+ The buffer with ``buffer_size`` is filled with elements from the datapipe first. Then,
611
+ each item will be yielded from the buffer by reservoir sampling via iterator.
612
+
613
+ ``buffer_size`` is required to be larger than ``0``. For ``buffer_size == 1``, the
614
+ datapipe is not shuffled. In order to fully shuffle all elements from datapipe,
615
+ ``buffer_size`` is required to be greater than or equal to the size of datapipe.
616
+
617
+ When it is used with :class:`torch.utils.data.DataLoader`, the methods to
618
+ set up random seed are different based on :attr:`num_workers`.
619
+
620
+ For single-process mode (:attr:`num_workers == 0`), the random seed is set before
621
+ the :class:`~torch.utils.data.DataLoader` in the main process. For multi-process
622
+ mode (:attr:`num_worker > 0`), `worker_init_fn` is used to set up a random seed
623
+ for each worker process.
624
+
625
+ Args:
626
+ datapipe: The IterDataPipe being shuffled
627
+ buffer_size: The buffer size for shuffling (default to ``10000``)
628
+ unbatch_level: Specifies if it is necessary to unbatch source data before
629
+ applying the shuffle
630
+
631
+ Example:
632
+ >>> # xdoctest: +SKIP
633
+ >>> from torchdata.datapipes.iter import IterableWrapper
634
+ >>> dp = IterableWrapper(range(10))
635
+ >>> shuffle_dp = dp.shuffle()
636
+ >>> list(shuffle_dp)
637
+ [0, 4, 1, 6, 3, 2, 9, 5, 7, 8]
638
+ """
639
+
640
+ # Functional form of 'UnBatcherIterDataPipe'
641
+ def unbatch(self, unbatch_level: int = 1) -> IterDataPipe:
642
+ r"""
643
+ Undoes batching of data (functional name: ``unbatch``).
644
+
645
+ In other words, it flattens the data up to the specified level within a batched DataPipe.
646
+
647
+ Args:
648
+ datapipe: Iterable DataPipe being un-batched
649
+ unbatch_level: Defaults to ``1`` (only flattening the top level). If set to ``2``,
650
+ it will flatten the top two levels, and ``-1`` will flatten the entire DataPipe.
651
+
652
+ Example:
653
+ >>> # xdoctest: +SKIP
654
+ >>> from torchdata.datapipes.iter import IterableWrapper
655
+ >>> source_dp = IterableWrapper([[[0, 1], [2]], [[3, 4], [5]], [[6]]])
656
+ >>> dp1 = source_dp.unbatch()
657
+ >>> list(dp1)
658
+ [[0, 1], [2], [3, 4], [5], [6]]
659
+ >>> dp2 = source_dp.unbatch(unbatch_level=2)
660
+ >>> list(dp2)
661
+ [0, 1, 2, 3, 4, 5, 6]
662
+ """
663
+
664
+ # Functional form of 'ZipperIterDataPipe'
665
+ def zip(self, *datapipes: IterDataPipe) -> IterDataPipe:
666
+ r"""
667
+ Aggregates elements into a tuple from each of the input DataPipes (functional name: ``zip``).
668
+
669
+ The output is stopped as soon as the shortest input DataPipe is exhausted.
670
+
671
+ Args:
672
+ *datapipes: Iterable DataPipes being aggregated
673
+
674
+ Example:
675
+ >>> # xdoctest: +REQUIRES(module:torchdata)
676
+ >>> from torchdata.datapipes.iter import IterableWrapper
677
+ >>> dp1, dp2, dp3 = IterableWrapper(range(5)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25))
678
+ >>> list(dp1.zip(dp2, dp3))
679
+ [(0, 10, 20), (1, 11, 21), (2, 12, 22), (3, 13, 23), (4, 14, 24)]
680
+ """
681
+
682
+
683
+ class DFIterDataPipe(IterDataPipe):
684
+ def _is_dfpipe(self): ...
685
+ def __iter__(self): ...
686
+
687
+ class _DataPipeSerializationWrapper:
688
+ def __init__(self, datapipe): ...
689
+ def __getstate__(self): ...
690
+ def __setstate__(self, state): ...
691
+ def __len__(self): ...
692
+
693
+ class _IterDataPipeSerializationWrapper(_DataPipeSerializationWrapper, IterDataPipe):
694
+ def __iter__(self): ...
695
+
696
+ class _MapDataPipeSerializationWrapper(_DataPipeSerializationWrapper, MapDataPipe):
697
+ def __getitem__(self, idx): ...
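The stub above only declares the functional forms; chaining them at runtime looks like the hedged sketch below. It assumes the in-tree `IterableWrapper`; lambdas are used purely for brevity here, despite the pickling caveat in the docstrings.

from torch.utils.data.datapipes.iter import IterableWrapper

dp = IterableWrapper(range(10))
dp = dp.filter(filter_fn=lambda x: x % 2 == 0)  # keep even numbers
dp = dp.map(lambda x: x + 1)                    # shift by one
dp = dp.batch(batch_size=2)                     # yields DataChunk batches
print(list(dp))  # expected: [[1, 3], [5, 7], [9]]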
.venv/Lib/site-packages/torch/utils/data/datapipes/gen_pyi.py ADDED
@@ -0,0 +1,305 @@
1
+ # mypy: allow-untyped-defs
2
+ import os
3
+ import pathlib
4
+ from collections import defaultdict
5
+ from typing import Any, Dict, List, Set, Tuple, Union
6
+
7
+
8
+ def materialize_lines(lines: List[str], indentation: int) -> str:
9
+ output = ""
10
+ new_line_with_indent = "\n" + " " * indentation
11
+ for i, line in enumerate(lines):
12
+ if i != 0:
13
+ output += new_line_with_indent
14
+ output += line.replace("\n", new_line_with_indent)
15
+ return output
16
+
17
+
18
+ def gen_from_template(
19
+ dir: str,
20
+ template_name: str,
21
+ output_name: str,
22
+ replacements: List[Tuple[str, Any, int]],
23
+ ):
24
+ template_path = os.path.join(dir, template_name)
25
+ output_path = os.path.join(dir, output_name)
26
+
27
+ with open(template_path) as f:
28
+ content = f.read()
29
+ for placeholder, lines, indentation in replacements:
30
+ with open(output_path, "w") as f:
31
+ content = content.replace(
32
+ placeholder, materialize_lines(lines, indentation)
33
+ )
34
+ f.write(content)
35
+
36
+
37
+ def find_file_paths(dir_paths: List[str], files_to_exclude: Set[str]) -> Set[str]:
38
+ """
39
+ When given a path to a directory, returns the paths to the relevant files within it.
40
+
41
+ This function does NOT recursively traverse subdirectories.
42
+ """
43
+ paths: Set[str] = set()
44
+ for dir_path in dir_paths:
45
+ all_files = os.listdir(dir_path)
46
+ python_files = {fname for fname in all_files if ".py" == fname[-3:]}
47
+ filter_files = {
48
+ fname for fname in python_files if fname not in files_to_exclude
49
+ }
50
+ paths.update({os.path.join(dir_path, fname) for fname in filter_files})
51
+ return paths
52
+
53
+
54
+ def extract_method_name(line: str) -> str:
55
+ """Extract method name from decorator in the form of "@functional_datapipe({method_name})"."""
56
+ if '("' in line:
57
+ start_token, end_token = '("', '")'
58
+ elif "('" in line:
59
+ start_token, end_token = "('", "')"
60
+ else:
61
+ raise RuntimeError(
62
+ f"Unable to find appropriate method name within line:\n{line}"
63
+ )
64
+ start, end = line.find(start_token) + len(start_token), line.find(end_token)
65
+ return line[start:end]
66
+
67
+
68
+ def extract_class_name(line: str) -> str:
69
+ """Extract class name from class definition in the form of "class {CLASS_NAME}({Type}):"."""
70
+ start_token = "class "
71
+ end_token = "("
72
+ start, end = line.find(start_token) + len(start_token), line.find(end_token)
73
+ return line[start:end]
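A quick sanity check of the two extractors above (editor's sketch; the inputs are made-up lines in the shapes the parser expects, and the import path assumes an installed torch):

from torch.utils.data.datapipes.gen_pyi import extract_class_name, extract_method_name

# Pulls the functional name out of the decorator line.
assert extract_method_name('@functional_datapipe("shuffle")') == "shuffle"
# Pulls the class name out of the class-definition line.
assert extract_class_name("class ShufflerIterDataPipe(IterDataPipe[T_co]):") == "ShufflerIterDataPipe"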
74
+
75
+
76
+ def parse_datapipe_file(
77
+ file_path: str,
78
+ ) -> Tuple[Dict[str, str], Dict[str, str], Set[str], Dict[str, List[str]]]:
79
+ """Given a path to a file, parse it and return a dictionary mapping method names to function signatures."""
80
+ method_to_signature, method_to_class_name, special_output_type = {}, {}, set()
81
+ doc_string_dict = defaultdict(list)
82
+ with open(file_path) as f:
83
+ open_paren_count = 0
84
+ method_name, class_name, signature = "", "", ""
85
+ skip = False
86
+ for line in f:
87
+ if line.count('"""') % 2 == 1:
88
+ skip = not skip
89
+ if skip or '"""' in line: # Saving docstrings
90
+ doc_string_dict[method_name].append(line)
91
+ continue
92
+ if "@functional_datapipe" in line:
93
+ method_name = extract_method_name(line)
94
+ doc_string_dict[method_name] = []
95
+ continue
96
+ if method_name and "class " in line:
97
+ class_name = extract_class_name(line)
98
+ continue
99
+ if method_name and ("def __init__(" in line or "def __new__(" in line):
100
+ if "def __new__(" in line:
101
+ special_output_type.add(method_name)
102
+ open_paren_count += 1
103
+ start = line.find("(") + len("(")
104
+ line = line[start:]
105
+ if open_paren_count > 0:
106
+ open_paren_count += line.count("(")
107
+ open_paren_count -= line.count(")")
108
+ if open_paren_count == 0:
109
+ end = line.rfind(")")
110
+ signature += line[:end]
111
+ method_to_signature[method_name] = process_signature(signature)
112
+ method_to_class_name[method_name] = class_name
113
+ method_name, class_name, signature = "", "", ""
114
+ elif open_paren_count < 0:
115
+ raise RuntimeError(
116
+ "open parenthesis count < 0. This shouldn't be possible."
117
+ )
118
+ else:
119
+ signature += line.strip("\n").strip(" ")
120
+ return (
121
+ method_to_signature,
122
+ method_to_class_name,
123
+ special_output_type,
124
+ doc_string_dict,
125
+ )
126
+
127
+
128
+ def parse_datapipe_files(
129
+ file_paths: Set[str],
130
+ ) -> Tuple[Dict[str, str], Dict[str, str], Set[str], Dict[str, List[str]]]:
131
+ (
132
+ methods_and_signatures,
133
+ methods_and_class_names,
134
+ methods_with_special_output_types,
135
+ ) = ({}, {}, set())
136
+ methods_and_doc_strings = {}
137
+ for path in file_paths:
138
+ (
139
+ method_to_signature,
140
+ method_to_class_name,
141
+ methods_needing_special_output_types,
142
+ doc_string_dict,
143
+ ) = parse_datapipe_file(path)
144
+ methods_and_signatures.update(method_to_signature)
145
+ methods_and_class_names.update(method_to_class_name)
146
+ methods_with_special_output_types.update(methods_needing_special_output_types)
147
+ methods_and_doc_strings.update(doc_string_dict)
148
+ return (
149
+ methods_and_signatures,
150
+ methods_and_class_names,
151
+ methods_with_special_output_types,
152
+ methods_and_doc_strings,
153
+ )
154
+
155
+
156
+ def split_outside_bracket(line: str, delimiter: str = ",") -> List[str]:
157
+ """Given a line of text, split it on commas, unless a comma falls within brackets '[]'."""
158
+ bracket_count = 0
159
+ curr_token = ""
160
+ res = []
161
+ for char in line:
162
+ if char == "[":
163
+ bracket_count += 1
164
+ elif char == "]":
165
+ bracket_count -= 1
166
+ elif char == delimiter and bracket_count == 0:
167
+ res.append(curr_token)
168
+ curr_token = ""
169
+ continue
170
+ curr_token += char
171
+ res.append(curr_token)
172
+ return res
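For example (editor's sketch, again importing from the installed module), the bracket counter keeps a comma inside a type annotation from splitting the token:

from torch.utils.data.datapipes.gen_pyi import split_outside_bracket

# The comma inside Dict[str, int] is protected; only top-level commas split.
parts = split_outside_bracket("self, fn: Callable, mapping: Dict[str, int]")
assert parts == ["self", " fn: Callable", " mapping: Dict[str, int]"]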
173
+
174
+
175
+ def process_signature(line: str) -> str:
176
+ """
177
+ Clean up a given raw function signature.
178
+
179
+ This includes removing the self-referential datapipe argument, default
180
+ arguments of input functions, newlines, and spaces.
181
+ """
182
+ tokens: List[str] = split_outside_bracket(line)
183
+ for i, token in enumerate(tokens):
184
+ tokens[i] = token.strip(" ")
185
+ if token == "cls":
186
+ tokens[i] = "self"
187
+ elif i > 0 and ("self" == tokens[i - 1]) and (tokens[i][0] != "*"):
188
+ # Remove the datapipe after 'self' or 'cls' unless it has '*'
189
+ tokens[i] = ""
190
+ elif "Callable =" in token: # Remove default argument if it is a function
191
+ head, default_arg = token.rsplit("=", 2)
192
+ tokens[i] = head.strip(" ") + "= ..."
193
+ tokens = [t for t in tokens if t != ""]
194
+ line = ", ".join(tokens)
195
+ return line
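Putting the pieces together (editor's sketch): ``cls`` becomes ``self``, the datapipe argument that follows it is dropped, and a ``Callable`` default collapses to ``...``; the uneven spacing around ``=`` is a quirk of the strip-and-concatenate step above.

from torch.utils.data.datapipes.gen_pyi import process_signature

sig = process_signature("cls, datapipe: IterDataPipe, fn: Callable = default_fn")
assert sig == "self, fn: Callable= ..."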
196
+
197
+
198
+ def get_method_definitions(
199
+ file_path: Union[str, List[str]],
200
+ files_to_exclude: Set[str],
201
+ deprecated_files: Set[str],
202
+ default_output_type: str,
203
+ method_to_special_output_type: Dict[str, str],
204
+ root: str = "",
205
+ ) -> List[str]:
206
+ """
207
+ .pyi generation process for functional DataPipes.
208
+
209
+ # 1. Find the files that we want to process (exclude the ones we don't)
210
+ # 2. Parse method name and signature
211
+ # 3. Remove first argument after self (unless it is "*datapipes"), default args, and spaces
212
+ """
213
+ if root == "":
214
+ root = str(pathlib.Path(__file__).parent.resolve())
215
+ file_path = [file_path] if isinstance(file_path, str) else file_path
216
+ file_path = [os.path.join(root, path) for path in file_path]
217
+ file_paths = find_file_paths(
218
+ file_path, files_to_exclude=files_to_exclude.union(deprecated_files)
219
+ )
220
+ (
221
+ methods_and_signatures,
222
+ methods_and_class_names,
223
+ methods_w_special_output_types,
224
+ methods_and_doc_strings,
225
+ ) = parse_datapipe_files(file_paths)
226
+
227
+ for fn_name in method_to_special_output_type:
228
+ if fn_name not in methods_w_special_output_types:
229
+ methods_w_special_output_types.add(fn_name)
230
+
231
+ method_definitions = []
232
+ for method_name, arguments in methods_and_signatures.items():
233
+ class_name = methods_and_class_names[method_name]
234
+ if method_name in methods_w_special_output_types:
235
+ output_type = method_to_special_output_type[method_name]
236
+ else:
237
+ output_type = default_output_type
238
+ doc_string = "".join(methods_and_doc_strings[method_name])
239
+ if doc_string == "":
240
+ doc_string = " ...\n"
241
+ method_definitions.append(
242
+ f"# Functional form of '{class_name}'\n"
243
+ f"def {method_name}({arguments}) -> {output_type}:\n"
244
+ f"{doc_string}"
245
+ )
246
+ method_definitions.sort(
247
+ key=lambda s: s.split("\n")[1]
248
+ ) # sorting based on method_name
249
+
250
+ return method_definitions
251
+
252
+
253
+ # Defined outside of main() so they can be imported by TorchData
254
+ iterDP_file_path: str = "iter"
255
+ iterDP_files_to_exclude: Set[str] = {"__init__.py", "utils.py"}
256
+ iterDP_deprecated_files: Set[str] = set()
257
+ iterDP_method_to_special_output_type: Dict[str, str] = {
258
+ "demux": "List[IterDataPipe]",
259
+ "fork": "List[IterDataPipe]",
260
+ }
261
+
262
+ mapDP_file_path: str = "map"
263
+ mapDP_files_to_exclude: Set[str] = {"__init__.py", "utils.py"}
264
+ mapDP_deprecated_files: Set[str] = set()
265
+ mapDP_method_to_special_output_type: Dict[str, str] = {"shuffle": "IterDataPipe"}
266
+
267
+
268
+ def main() -> None:
269
+ """
270
+ Inject the generated method definitions into the template datapipe.pyi.in.
271
+
272
+ TODO: The current implementation of this script only generates interfaces for built-in methods. To generate
273
+ interfaces for user-defined DataPipes, consider changing `IterDataPipe.register_datapipe_as_function`.
274
+ """
275
+ iter_method_definitions = get_method_definitions(
276
+ iterDP_file_path,
277
+ iterDP_files_to_exclude,
278
+ iterDP_deprecated_files,
279
+ "IterDataPipe",
280
+ iterDP_method_to_special_output_type,
281
+ )
282
+
283
+ map_method_definitions = get_method_definitions(
284
+ mapDP_file_path,
285
+ mapDP_files_to_exclude,
286
+ mapDP_deprecated_files,
287
+ "MapDataPipe",
288
+ mapDP_method_to_special_output_type,
289
+ )
290
+
291
+ path = pathlib.Path(__file__).parent.resolve()
292
+ replacements = [
293
+ ("${IterDataPipeMethods}", iter_method_definitions, 4),
294
+ ("${MapDataPipeMethods}", map_method_definitions, 4),
295
+ ]
296
+ gen_from_template(
297
+ dir=str(path),
298
+ template_name="datapipe.pyi.in",
299
+ output_name="datapipe.pyi",
300
+ replacements=replacements,
301
+ )
302
+
303
+
304
+ if __name__ == "__main__":
305
+ main()
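Outside of `main()`, the same machinery can be exercised against a throw-away template (editor's sketch: `example.pyi.in`, the `${Methods}` placeholder, and the method list are made up for illustration; the functions are assumed importable from `torch.utils.data.datapipes.gen_pyi` in an installed torch):

import os
import tempfile

from torch.utils.data.datapipes.gen_pyi import gen_from_template

with tempfile.TemporaryDirectory() as d:
    # Tiny stand-in for datapipe.pyi.in with a single placeholder.
    with open(os.path.join(d, "example.pyi.in"), "w") as f:
        f.write("class IterDataPipe:\n    ${Methods}\n")
    methods = [
        "def shuffle(self) -> IterDataPipe: ...",
        "def zip(self, *datapipes) -> IterDataPipe: ...",
    ]
    gen_from_template(
        dir=d,
        template_name="example.pyi.in",
        output_name="example.pyi",
        replacements=[("${Methods}", methods, 4)],
    )
    with open(os.path.join(d, "example.pyi")) as f:
        print(f.read())
    # class IterDataPipe:
    #     def shuffle(self) -> IterDataPipe: ...
    #     def zip(self, *datapipes) -> IterDataPipe: ...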
.venv/Lib/site-packages/torch/utils/data/datapipes/iter/__init__.py ADDED
@@ -0,0 +1,65 @@
1
+ from torch.utils.data.datapipes.iter.callable import (
2
+ CollatorIterDataPipe as Collator,
3
+ MapperIterDataPipe as Mapper,
4
+ )
5
+ from torch.utils.data.datapipes.iter.combinatorics import (
6
+ SamplerIterDataPipe as Sampler,
7
+ ShufflerIterDataPipe as Shuffler,
8
+ )
9
+ from torch.utils.data.datapipes.iter.combining import (
10
+ ConcaterIterDataPipe as Concater,
11
+ DemultiplexerIterDataPipe as Demultiplexer,
12
+ ForkerIterDataPipe as Forker,
13
+ MultiplexerIterDataPipe as Multiplexer,
14
+ ZipperIterDataPipe as Zipper,
15
+ )
16
+ from torch.utils.data.datapipes.iter.filelister import (
17
+ FileListerIterDataPipe as FileLister,
18
+ )
19
+ from torch.utils.data.datapipes.iter.fileopener import (
20
+ FileOpenerIterDataPipe as FileOpener,
21
+ )
22
+ from torch.utils.data.datapipes.iter.grouping import (
23
+ BatcherIterDataPipe as Batcher,
24
+ GrouperIterDataPipe as Grouper,
25
+ UnBatcherIterDataPipe as UnBatcher,
26
+ )
27
+ from torch.utils.data.datapipes.iter.routeddecoder import (
28
+ RoutedDecoderIterDataPipe as RoutedDecoder,
29
+ )
30
+ from torch.utils.data.datapipes.iter.selecting import FilterIterDataPipe as Filter
31
+ from torch.utils.data.datapipes.iter.sharding import (
32
+ ShardingFilterIterDataPipe as ShardingFilter,
33
+ )
34
+ from torch.utils.data.datapipes.iter.streamreader import (
35
+ StreamReaderIterDataPipe as StreamReader,
36
+ )
37
+ from torch.utils.data.datapipes.iter.utils import (
38
+ IterableWrapperIterDataPipe as IterableWrapper,
39
+ )
40
+
41
+
42
+ __all__ = [
43
+ "Batcher",
44
+ "Collator",
45
+ "Concater",
46
+ "Demultiplexer",
47
+ "FileLister",
48
+ "FileOpener",
49
+ "Filter",
50
+ "Forker",
51
+ "Grouper",
52
+ "IterableWrapper",
53
+ "Mapper",
54
+ "Multiplexer",
55
+ "RoutedDecoder",
56
+ "Sampler",
57
+ "ShardingFilter",
58
+ "Shuffler",
59
+ "StreamReader",
60
+ "UnBatcher",
61
+ "Zipper",
62
+ ]
63
+
64
+ # Please keep this list sorted
65
+ assert __all__ == sorted(__all__)
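A short usage sketch of the class-form aliases exported above (assumes torch is installed); these mirror the functional forms generated into `datapipe.pyi`:

from torch.utils.data.datapipes.iter import Batcher, IterableWrapper, Mapper

def square(x: int) -> int:
    return x * x

# Mapper applies the function element-wise, Batcher groups into fixed-size chunks.
dp = Mapper(IterableWrapper(range(6)), fn=square)
batched = Batcher(dp, batch_size=2)
print([list(batch) for batch in batched])  # [[0, 1], [4, 9], [16, 25]]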
.venv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/callable.cpython-39.pyc ADDED
Binary file (7.84 kB).