Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
- .gitattributes +3 -0
- .venv/Lib/site-packages/torch/lib/cudnn_heuristic64_9.dll +3 -0
- .venv/Lib/site-packages/torch/lib/cudnn_ops64_9.dll +3 -0
- .venv/Lib/site-packages/torch/lib/sleef.lib +3 -0
- .venv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/__init__.py +0 -0
- .venv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.cpp +35 -0
- .venv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/timer_callgrind_template.cpp +68 -0
- .venv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py +907 -0
- .venv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h +0 -0
- .venv/Lib/site-packages/torch/utils/bottleneck/__init__.py +0 -0
- .venv/Lib/site-packages/torch/utils/bottleneck/__main__.py +230 -0
- .venv/Lib/site-packages/torch/utils/data/__init__.py +77 -0
- .venv/Lib/site-packages/torch/utils/data/_utils/__init__.py +54 -0
- .venv/Lib/site-packages/torch/utils/data/_utils/__pycache__/__init__.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/utils/data/_utils/__pycache__/collate.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/utils/data/_utils/__pycache__/fetch.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/utils/data/_utils/__pycache__/pin_memory.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/utils/data/_utils/__pycache__/signal_handling.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/utils/data/_utils/__pycache__/worker.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/utils/data/_utils/collate.py +398 -0
- .venv/Lib/site-packages/torch/utils/data/_utils/fetch.py +55 -0
- .venv/Lib/site-packages/torch/utils/data/_utils/pin_memory.py +108 -0
- .venv/Lib/site-packages/torch/utils/data/_utils/signal_handling.py +79 -0
- .venv/Lib/site-packages/torch/utils/data/_utils/worker.py +376 -0
- .venv/Lib/site-packages/torch/utils/data/backward_compatibility.py +11 -0
- .venv/Lib/site-packages/torch/utils/data/dataloader.py +1604 -0
- .venv/Lib/site-packages/torch/utils/data/datapipes/__init__.py +1 -0
- .venv/Lib/site-packages/torch/utils/data/datapipes/__pycache__/__init__.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/utils/data/datapipes/__pycache__/_decorator.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/utils/data/datapipes/__pycache__/_hook_iterator.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/utils/data/datapipes/__pycache__/_typing.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/utils/data/datapipes/__pycache__/datapipe.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/utils/data/datapipes/_decorator.py +213 -0
- .venv/Lib/site-packages/torch/utils/data/datapipes/_hook_iterator.py +279 -0
- .venv/Lib/site-packages/torch/utils/data/datapipes/_typing.py +486 -0
- .venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__init__.py +11 -0
- .venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/__init__.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/dataframe_wrapper.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/dataframes.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/datapipes.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/structures.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py +128 -0
- .venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/dataframes.py +457 -0
- .venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/datapipes.py +134 -0
- .venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/structures.py +20 -0
- .venv/Lib/site-packages/torch/utils/data/datapipes/datapipe.py +415 -0
- .venv/Lib/site-packages/torch/utils/data/datapipes/datapipe.pyi +697 -0
- .venv/Lib/site-packages/torch/utils/data/datapipes/gen_pyi.py +305 -0
- .venv/Lib/site-packages/torch/utils/data/datapipes/iter/__init__.py +65 -0
- .venv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/callable.cpython-39.pyc +0 -0
.gitattributes
CHANGED
@@ -68,3 +68,6 @@ reference_sample_wavs/syuukovoice_200918_3_01.wav filter=lfs diff=lfs merge=lfs
 .venv/Lib/site-packages/torch/lib/libprotoc.lib filter=lfs diff=lfs merge=lfs -text
 .venv/Lib/site-packages/torch/lib/curand64_10.dll filter=lfs diff=lfs merge=lfs -text
 .venv/Lib/site-packages/torch/lib/cusolverMg64_11.dll filter=lfs diff=lfs merge=lfs -text
+.venv/Lib/site-packages/torch/lib/cudnn_heuristic64_9.dll filter=lfs diff=lfs merge=lfs -text
+.venv/Lib/site-packages/torch/lib/sleef.lib filter=lfs diff=lfs merge=lfs -text
+.venv/Lib/site-packages/torch/lib/cudnn_ops64_9.dll filter=lfs diff=lfs merge=lfs -text
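Note (not part of the commit): the three added lines follow the standard Git LFS attribute pattern (`<path> filter=lfs diff=lfs merge=lfs -text`). A minimal, hypothetical sketch of appending such rules programmatically; the helper name and the idea of scripting this are assumptions, only the rule format and paths come from the diff above:

```python
from pathlib import Path

# Hypothetical helper: append LFS tracking rules for large binaries to .gitattributes.
LFS_RULE = "{path} filter=lfs diff=lfs merge=lfs -text"

def track_with_lfs(repo_root: str, rel_paths: list) -> None:
    gitattributes = Path(repo_root) / ".gitattributes"
    existing = gitattributes.read_text().splitlines() if gitattributes.exists() else []
    rules = [LFS_RULE.format(path=p) for p in rel_paths]
    with gitattributes.open("a") as f:
        for rule in rules:
            if rule not in existing:  # avoid duplicate entries
                f.write(rule + "\n")

# Example: the three binaries added in this commit.
track_with_lfs(".", [
    ".venv/Lib/site-packages/torch/lib/cudnn_heuristic64_9.dll",
    ".venv/Lib/site-packages/torch/lib/sleef.lib",
    ".venv/Lib/site-packages/torch/lib/cudnn_ops64_9.dll",
])
```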
.venv/Lib/site-packages/torch/lib/cudnn_heuristic64_9.dll
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ee6d4831251387ab52a549df7ce7e5256272426eeef23a36d172ca8c725afba1
+size 85741608
.venv/Lib/site-packages/torch/lib/cudnn_ops64_9.dll
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e26d81b20ceda0fff0fce1b60f5c4a7c0b32650afa2ab49f0ea4496816bead5b
+size 107721256
.venv/Lib/site-packages/torch/lib/sleef.lib
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:55eb52de5e0e99ed7cbeeadb0b1e7523bd09278b1160145bd15e200a9df3139a
+size 8862502
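Note (illustrative, not part of the commit): each of the three new binaries is stored as a Git LFS pointer file with the `version`/`oid`/`size` fields shown above. A small sketch of parsing such a pointer; the function is hypothetical, the pointer text is the sleef.lib pointer from this diff:

```python
def parse_lfs_pointer(text: str) -> dict:
    # A pointer file is a short set of "key value" lines (version, oid, size).
    fields = dict(line.split(" ", 1) for line in text.strip().splitlines())
    return {
        "version": fields["version"],
        "oid": fields["oid"].split("sha256:", 1)[-1],
        "size_bytes": int(fields["size"]),
    }

pointer = """\
version https://git-lfs.github.com/spec/v1
oid sha256:55eb52de5e0e99ed7cbeeadb0b1e7523bd09278b1160145bd15e200a9df3139a
size 8862502
"""
print(parse_lfs_pointer(pointer))  # sleef.lib: ~8.9 MB object
```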
.venv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/__init__.py
ADDED
File without changes
.venv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.cpp
ADDED
@@ -0,0 +1,35 @@
/* Used to collect profiles of old versions of PyTorch. */
#include <callgrind.h>
#include <pybind11/pybind11.h>

bool _valgrind_supported_platform() {
#if defined(NVALGRIND)
  return false;
#else
  return true;
#endif
}

void _valgrind_toggle() {
#if defined(NVALGRIND)
  TORCH_CHECK(false, "Valgrind is not supported.");
#else
  CALLGRIND_TOGGLE_COLLECT;
#endif
}

void _valgrind_toggle_and_dump_stats() {
#if defined(NVALGRIND)
  TORCH_CHECK(false, "Valgrind is not supported.");
#else
  // NB: See note in Module.cpp
  CALLGRIND_TOGGLE_COLLECT;
  CALLGRIND_DUMP_STATS;
#endif
}

PYBIND11_MODULE(callgrind_bindings, m) {
  m.def("_valgrind_supported_platform", &_valgrind_supported_platform);
  m.def("_valgrind_toggle", &_valgrind_toggle);
  m.def("_valgrind_toggle_and_dump_stats", &_valgrind_toggle_and_dump_stats);
}
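Note (illustrative, not part of the commit): the generated collection script in `timer_interface.py` (further below) uses these bindings by toggling collection around the measured statement. A minimal sketch of that pattern, assuming the module above has been built and is importable as `callgrind_bindings`, and that the process is already running under `valgrind --tool=callgrind --collect-atstart=no`:

```python
import callgrind_bindings  # assumed: the pybind11 module defined above

def measure(stmt_fn, number: int) -> None:
    # Toggling brackets the region of interest; outside a callgrind run the
    # toggle macros are effectively no-ops.
    if not callgrind_bindings._valgrind_supported_platform():
        raise OSError("Valgrind is not supported on this platform.")
    callgrind_bindings._valgrind_toggle()
    for _ in range(number):
        stmt_fn()
    callgrind_bindings._valgrind_toggle_and_dump_stats()

measure(lambda: sum(range(100)), number=1000)
```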
.venv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/timer_callgrind_template.cpp
ADDED
@@ -0,0 +1,68 @@
/* C++ template for Timer.collect_callgrind

This template will be consumed by `cpp_jit.py`, and will replace:
    `GLOBAL_SETUP_TEMPLATE_LOCATION`,
    `SETUP_TEMPLATE_LOCATION`
  and
    `STMT_TEMPLATE_LOCATION`
sections with user provided statements.
*/

#include <c10/util/irange.h>
#include <callgrind.h>
#include <torch/torch.h>

#include <string>

// Global setup. (e.g. #includes)
// GLOBAL_SETUP_TEMPLATE_LOCATION

#if defined(NVALGRIND)
static_assert(false);
#endif

int main(int argc, char* argv[]) {
  // This file should only be called inside of `Timer`, so we can adopt a
  // very simple and rigid argument parsing scheme.
  TORCH_CHECK(argc == 9);
  TORCH_CHECK(std::string(argv[1]) == "--number");
  auto number = std::stoi(argv[2]);

  TORCH_CHECK(
      std::string(argv[3]) == "--number-warmup" ||
      std::string(argv[3]) == "--number_warmup");
  auto number_warmup = std::stoi(argv[4]);

  TORCH_CHECK(std::string(argv[5]) == "--repeats");
  auto repeats = std::stoi(argv[6]);

  TORCH_CHECK(
      std::string(argv[7]) == "--number-threads" ||
      std::string(argv[7]) == "--number_threads");
  auto number_threads = std::stoi(argv[8]);
  torch::set_num_threads(number_threads);

  // Setup
  // SETUP_TEMPLATE_LOCATION

  // Warmup
  for (const auto i : c10::irange(number_warmup)) {
    (void)i;
    // STMT_TEMPLATE_LOCATION
  }

  // Main loop
  for (const auto repeat : c10::irange(repeats)) {
    (void)repeat;
    CALLGRIND_TOGGLE_COLLECT;

    for (const auto i : c10::irange(number)) {
      (void)i;
      // STMT_TEMPLATE_LOCATION
    }

    // NB: See note in Module.cpp
    CALLGRIND_TOGGLE_COLLECT;
    CALLGRIND_DUMP_STATS;
  }
}
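Note (illustrative, not part of the commit): `timer_interface.py` (next file) compiles this template and builds the argument list that its rigid parser expects, then runs the binary under valgrind. A sketch of an equivalent invocation; the path to the compiled binary is hypothetical, the flag names and valgrind options come from the files in this diff:

```python
import subprocess

run_loop_exec = "/tmp/timer_callgrind_template"  # hypothetical path to the compiled template

cmd = [
    "valgrind", "--tool=callgrind", "--collect-atstart=no",
    run_loop_exec,
    "--number", "100",
    "--number-warmup", "10",
    "--repeats", "3",
    "--number-threads", "1",
]
subprocess.run(cmd, check=True)
```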
.venv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py
ADDED
@@ -0,0 +1,907 @@
"""Intermediate layer between `Timer` and `valgrind`."""
import collections
import enum
import dataclasses
import itertools as it
import os
import pickle
import re
import shutil
import subprocess
import sys
import textwrap
from typing import (
    cast, Any, Callable, DefaultDict, Dict, Iterator, List, NamedTuple,
    Optional, Tuple, Union, TYPE_CHECKING)

import torch
from torch.utils.benchmark.utils import common, cpp_jit
from torch.utils.benchmark.utils._stubs import CallgrindModuleType
import operator


__all__ = ["FunctionCount", "FunctionCounts", "CallgrindStats", "CopyIfCallgrind"]


if TYPE_CHECKING:
    CompletedProcessType = subprocess.CompletedProcess[str]
else:
    CompletedProcessType = subprocess.CompletedProcess


class FunctionCount(NamedTuple):
    # TODO(#105471): Rename the count field
    count: int  # type: ignore[assignment]
    function: str


@dataclasses.dataclass(repr=False, eq=False, frozen=True)
class FunctionCounts:
    """Container for manipulating Callgrind results.

    It supports:
        1) Addition and subtraction to combine or diff results.
        2) Tuple-like indexing.
        3) A `denoise` function which strips CPython calls which are known to
           be non-deterministic and quite noisy.
        4) Two higher order methods (`filter` and `transform`) for custom
           manipulation.
    """
    _data: Tuple[FunctionCount, ...]
    inclusive: bool
    truncate_rows: bool = True

    # For normal use, torch._tensor_str.PRINT_OPTS.linewidth determines
    # the print settings. This is simply to allow hermetic unit tests.
    _linewidth: Optional[int] = None

    def __iter__(self) -> Iterator[FunctionCount]:
        yield from self._data

    def __len__(self) -> int:
        return len(self._data)

    def __getitem__(self, item: Any) -> Union[FunctionCount, "FunctionCounts"]:
        data: Union[FunctionCount, Tuple[FunctionCount, ...]] = self._data[item]
        return (
            FunctionCounts(cast(Tuple[FunctionCount, ...], data), self.inclusive, truncate_rows=False)
            if isinstance(data, tuple) else data
        )

    def __repr__(self) -> str:
        count_len = 0
        for c, _ in self:
            # Account for sign in string length.
            count_len = max(count_len, len(str(c)) + int(c < 0))

        lines = []
        linewidth = self._linewidth or torch._tensor_str.PRINT_OPTS.linewidth
        fn_str_len = max(linewidth - count_len - 4, 40)
        for c, fn in self:
            if len(fn) > fn_str_len:
                left_len = int((fn_str_len - 5) // 2)
                fn = fn[:left_len] + " ... " + fn[-(fn_str_len - left_len - 5):]
            lines.append(f"  {c:>{count_len}}  {fn}")

        if self.truncate_rows and len(lines) > 18:
            lines = lines[:9] + ["...".rjust(count_len + 2)] + lines[-9:]

        if not self.inclusive:
            lines.extend(["", f"Total: {self.sum()}"])

        return "\n".join([super().__repr__()] + lines)

    def __add__(
        self,
        other: "FunctionCounts",
    ) -> "FunctionCounts":
        return self._merge(other, lambda c: c)

    def __sub__(
        self,
        other: "FunctionCounts",
    ) -> "FunctionCounts":
        return self._merge(other, operator.neg)

    def __mul__(self, other: Union[int, float]) -> "FunctionCounts":
        return self._from_dict({
            fn: int(c * other) for c, fn in self._data
        }, self.inclusive)

    def transform(self, map_fn: Callable[[str], str]) -> "FunctionCounts":
        """Apply `map_fn` to all of the function names.

        This can be used to regularize function names (e.g. stripping irrelevant
        parts of the file path), coalesce entries by mapping multiple functions
        to the same name (in which case the counts are added together), etc.
        """
        counts: DefaultDict[str, int] = collections.defaultdict(int)
        for c, fn in self._data:
            counts[map_fn(fn)] += c

        return self._from_dict(counts, self.inclusive)

    def filter(self, filter_fn: Callable[[str], bool]) -> "FunctionCounts":
        """Keep only the elements where `filter_fn` applied to function name returns True."""
        return FunctionCounts(tuple(i for i in self if filter_fn(i.function)), self.inclusive)

    def sum(self) -> int:
        return sum(c for c, _ in self)

    def denoise(self) -> "FunctionCounts":
        """Remove known noisy instructions.

        Several instructions in the CPython interpreter are rather noisy. These
        instructions involve unicode to dictionary lookups which Python uses to
        map variable names. FunctionCounts is generally a content agnostic
        container, however this is sufficiently important for obtaining
        reliable results to warrant an exception."""
        return self.filter(lambda fn: "dictobject.c:lookdict_unicode" not in fn)

    def _merge(
        self,
        second: "FunctionCounts",
        merge_fn: Callable[[int], int]
    ) -> "FunctionCounts":
        assert self.inclusive == second.inclusive, "Cannot merge inclusive and exclusive counts."
        counts: DefaultDict[str, int] = collections.defaultdict(int)
        for c, fn in self:
            counts[fn] += c

        for c, fn in second:
            counts[fn] += merge_fn(c)

        return self._from_dict(counts, self.inclusive)

    @staticmethod
    def _from_dict(counts: Dict[str, int], inclusive: bool) -> "FunctionCounts":
        flat_counts = (FunctionCount(c, fn) for fn, c in counts.items() if c)
        return FunctionCounts(tuple(sorted(flat_counts, reverse=True)), inclusive)


@dataclasses.dataclass(repr=False, eq=False, frozen=True)
class CallgrindStats:
    """Top level container for Callgrind results collected by Timer.

    Manipulation is generally done using the FunctionCounts class, which is
    obtained by calling `CallgrindStats.stats(...)`. Several convenience
    methods are provided as well; the most significant is
    `CallgrindStats.as_standardized()`.
    """
    task_spec: common.TaskSpec
    number_per_run: int
    built_with_debug_symbols: bool
    baseline_inclusive_stats: FunctionCounts
    baseline_exclusive_stats: FunctionCounts
    stmt_inclusive_stats: FunctionCounts
    stmt_exclusive_stats: FunctionCounts
    stmt_callgrind_out: Optional[str]

    def __repr__(self) -> str:
        newline = "\n"  # `\` cannot appear in fstring code section.
        base_stats = self.baseline_exclusive_stats
        output = f"""
{super().__repr__()}
{self.task_spec.summarize()}
{'':>25}All{'':>10}Noisy symbols removed
  Instructions: {self.counts(denoise=False):>12}{'':>15}{self.counts(denoise=True):>12}
  Baseline:     {base_stats.sum():>12}{'':>15}{base_stats.denoise().sum():>12}
{self.number_per_run} runs per measurement, {self.task_spec.num_threads} thread{'s' if self.task_spec.num_threads > 1 else ''}
""".strip()
        if not self.built_with_debug_symbols:
            output += textwrap.dedent("""
            Warning: PyTorch was not built with debug symbols.
                     Source information may be limited. Rebuild with
                     REL_WITH_DEB_INFO=1 for more detailed results.""")
        return output

    def stats(self, inclusive: bool = False) -> FunctionCounts:
        """Returns detailed function counts.

        Conceptually, the FunctionCounts returned can be thought of as a tuple
        of (count, path_and_function_name) tuples.

        `inclusive` matches the semantics of callgrind. If True, the counts
        include instructions executed by children. `inclusive=True` is useful
        for identifying hot spots in code; `inclusive=False` is useful for
        reducing noise when diffing counts from two different runs. (See
        CallgrindStats.delta(...) for more details)
        """
        return self.stmt_inclusive_stats if inclusive else self.stmt_exclusive_stats

    def counts(self, *, denoise: bool = False) -> int:
        """Returns the total number of instructions executed.

        See `FunctionCounts.denoise()` for an explanation of the `denoise` arg.
        """
        stats = self.stmt_exclusive_stats
        return (stats.denoise() if denoise else stats).sum()

    # FIXME: Once 3.7 is the minimum version, type annotate `other` per PEP 563
    def delta(
        self,
        other: "CallgrindStats",
        inclusive: bool = False,
    ) -> FunctionCounts:
        """Diff two sets of counts.

        One common reason to collect instruction counts is to determine the
        the effect that a particular change will have on the number of instructions
        needed to perform some unit of work. If a change increases that number, the
        next logical question is "why". This generally involves looking at what part
        if the code increased in instruction count. This function automates that
        process so that one can easily diff counts on both an inclusive and
        exclusive basis.
        """
        return self.stats(inclusive=inclusive) - other.stats(inclusive=inclusive)

    def as_standardized(self) -> "CallgrindStats":
        """Strip library names and some prefixes from function strings.

        When comparing two different sets of instruction counts, on stumbling
        block can be path prefixes. Callgrind includes the full filepath
        when reporting a function (as it should). However, this can cause
        issues when diffing profiles. If a key component such as Python
        or PyTorch was built in separate locations in the two profiles, which
        can result in something resembling::

            23234231 /tmp/first_build_dir/thing.c:foo(...)
             9823794 /tmp/first_build_dir/thing.c:bar(...)
                 ...
               53453 .../aten/src/Aten/...:function_that_actually_changed(...)
                 ...
            -9823794 /tmp/second_build_dir/thing.c:bar(...)
           -23234231 /tmp/second_build_dir/thing.c:foo(...)

        Stripping prefixes can ameliorate this issue by regularizing the
        strings and causing better cancellation of equivalent call sites
        when diffing.
        """
        def strip(stats: FunctionCounts) -> FunctionCounts:
            transforms = (
                # PyTorch may have been built in different locations.
                (r"^.+build/\.\./", "build/../"),
                (r"^.+/" + re.escape("build/aten/"), "build/aten/"),

                # "Python" and "Objects" come from CPython.
                (r"^.+/" + re.escape("Python/"), "Python/"),
                (r"^.+/" + re.escape("Objects/"), "Objects/"),

                # Strip library name. e.g. `libtorch.so`
                (r"\s\[.+\]$", ""),
            )

            for before, after in transforms:
                stats = stats.transform(lambda fn: re.sub(before, after, fn))

            return stats

        return CallgrindStats(
            task_spec=self.task_spec,
            number_per_run=self.number_per_run,
            built_with_debug_symbols=self.built_with_debug_symbols,
            baseline_inclusive_stats=strip(self.baseline_inclusive_stats),
            baseline_exclusive_stats=strip(self.baseline_exclusive_stats),
            stmt_inclusive_stats=strip(self.stmt_inclusive_stats),
            stmt_exclusive_stats=strip(self.stmt_exclusive_stats),

            # `as_standardized` will change symbol names, so the contents will
            # no longer map directly to `callgrind.out`
            stmt_callgrind_out=None,
        )


class Serialization(enum.Enum):
    PICKLE = 0
    TORCH = 1
    TORCH_JIT = 2


_GLOBALS_ALLOWED_TYPES: Dict[Serialization, Tuple[Any, ...]] = {
    Serialization.PICKLE: (str, bytes, bool, int, float, complex),
    Serialization.TORCH_JIT: (torch.jit.ScriptFunction, torch.jit.ScriptModule),
    Serialization.TORCH: (torch.nn.Module,),
}


class CopyIfCallgrind:
    """Signal that a global may be replaced with a deserialized copy.

    See `GlobalsBridge` for why this matters.
    """
    def __init__(self, value: Any, *, setup: Optional[str] = None):
        for method, supported_types in _GLOBALS_ALLOWED_TYPES.items():
            if any(isinstance(value, t) for t in supported_types):
                self._value: Any = value
                self._setup: Optional[str] = setup
                self._serialization: Serialization = method
                break
        else:
            supported_str = "\n".join([
                getattr(t, "__name__", repr(t))
                for t in it.chain(_GLOBALS_ALLOWED_TYPES.values())])

            raise ValueError(
                f"Unsupported type: {type(value)}\n"
                f"`collect_callgrind` restricts globals to the following types:\n"
                f"{textwrap.indent(supported_str, '  ')}"
            )

    @property
    def value(self) -> Any:
        return self._value

    @property
    def setup(self) -> Optional[str]:
        return self._setup

    @property
    def serialization(self) -> Serialization:
        return self._serialization

    @staticmethod
    def unwrap_all(globals: Dict[str, Any]) -> Dict[str, Any]:
        return {
            k: (v.value if isinstance(v, CopyIfCallgrind) else v)
            for k, v in globals.items()
        }


class GlobalsBridge:
    """Handle the transfer of (certain) globals when collecting Callgrind statistics.

    Key takeaway: Any globals passed must be wrapped in `CopyIfCallgrind` to
    work with `Timer.collect_callgrind`.

    Consider the following code snippet:
    ```
        import pickle
        import timeit

        class Counter:
            value = 0

            def __call__(self):
                self.value += 1

        counter = Counter()
        timeit.Timer("counter()", globals={"counter": counter}).timeit(10)
        print(counter.value)  # 10

        timeit.Timer(
            "counter()",
            globals={"counter": pickle.loads(pickle.dumps(counter))}
        ).timeit(20)
        print(counter.value)  # Still 10
    ```

    In the first case, `stmt` is executed using the objects in `globals`;
    however, the addition of serialization and deserialization changes the
    semantics and may meaningfully change behavior.

    This is a practical consideration when collecting Callgrind statistics.
    Unlike `exec` based execution (which `timeit` uses under the hood) which
    can share in-memory data structures with the caller, Callgrind collection
    requires an entirely new process in order to run under Valgrind. This means
    that any data structures used for statement execution will have to be
    serialized and deserialized in the subprocess.

    In order to avoid surprising semantics from (user invisible) process
    boundaries, what can be passed through `globals` is severely restricted
    for `Timer.collect_callgrind`. It is expected that most setup should be
    achievable (albeit perhaps less ergonomically) by passing a `setup`
    string.

    There are, however, exceptions. One such class are TorchScripted functions.
    Because they require a concrete file with source code it is not possible
    to define them using a `setup` string. Another group are torch.nn.Modules,
    whose construction can be complex and prohibitively cumbersome to coerce
    into a `setup` string. Finally, most builtin types are sufficiently well
    behaved and sufficiently common to warrant allowing as well. (e.g.
    `globals={"n": 1}` is very convenient.)

    Fortunately, all have well defined serialization semantics. This class
    is responsible for enabling the Valgrind subprocess to use elements in
    `globals` so long as they are an allowed type.

    Caveats:
        The user is required to acknowledge this serialization by wrapping
        elements in `globals` with `CopyIfCallgrind`.

        While ScriptFunction and ScriptModule are expected to save and load
        quite robustly, it is up to the user to ensure that an nn.Module can
        un-pickle successfully.

        `torch.Tensor` and `np.ndarray` are deliberately excluded. The
        serialization/deserialization process perturbs the representation of a
        tensor in ways that could result in incorrect measurements. For example,
        if a tensor lives in pinned CPU memory, this fact would not be preserved
        by a dump, and that will in turn change the performance of certain CUDA
        operations.
    """

    def __init__(self, globals: Dict[str, Any], data_dir: str) -> None:
        self._globals: Dict[str, CopyIfCallgrind] = {}
        self._data_dir = data_dir
        if not os.path.exists(data_dir):
            os.mkdir(data_dir)

        if globals.get("torch", torch) is not torch:
            raise ValueError("`collect_callgrind` does not support mocking out `torch`.")

        for name, value in globals.items():
            if name in ("torch", "__builtins__"):
                # Torch will be imported by the collection script, and
                # __builtins__ is added by Timer.
                continue

            if not isinstance(value, CopyIfCallgrind):
                raise ValueError(
                    "`collect_callgrind` requires that globals be wrapped in "
                    "`CopyIfCallgrind` so that serialization is explicit."
                )

            self._globals[name] = value

    def construct(self) -> str:
        load_lines = []
        for name, wrapped_value in self._globals.items():
            if wrapped_value.setup is not None:
                load_lines.append(textwrap.dedent(wrapped_value.setup))

            if wrapped_value.serialization == Serialization.PICKLE:
                path = os.path.join(self._data_dir, f"{name}.pkl")
                load_lines.append(
                    f"with open({repr(path)}, 'rb') as f:\n    {name} = pickle.load(f)")
                with open(path, "wb") as f:
                    pickle.dump(wrapped_value.value, f)

            elif wrapped_value.serialization == Serialization.TORCH:
                path = os.path.join(self._data_dir, f"{name}.pt")
                load_lines.append(f"{name} = torch.load({repr(path)})")
                torch.save(wrapped_value.value, path)

            elif wrapped_value.serialization == Serialization.TORCH_JIT:
                path = os.path.join(self._data_dir, f"{name}.pt")
                load_lines.append(f"{name} = torch.jit.load({repr(path)})")
                with open(path, "wb") as f:
                    torch.jit.save(wrapped_value.value, f)  # type: ignore[no-untyped-call]

            else:
                raise NotImplementedError(
                    f"Unknown serialization method: {wrapped_value.serialization}")

        return "\n".join(load_lines)


class _ValgrindWrapper:
    def __init__(self) -> None:
        self._bindings_module: Optional[CallgrindModuleType] = None
        valgrind_symbols = (
            "_valgrind_supported_platform",
            "_valgrind_toggle",
            "_valgrind_toggle_and_dump_stats",
        )
        if all(hasattr(torch._C, symbol) for symbol in valgrind_symbols):
            self._supported_platform: bool = torch._C._valgrind_supported_platform()

        else:
            print("Callgrind bindings are not present in `torch._C`. JIT-ing bindings.")
            self._bindings_module = cpp_jit.get_compat_bindings()
            assert all(hasattr(self._bindings_module, symbol) for symbol in valgrind_symbols)
            self._supported_platform = self._bindings_module._valgrind_supported_platform()

        self._commands_available: Dict[str, bool] = {}
        if self._supported_platform:
            # Only bother checking on supported platforms.
            for cmd in ("valgrind", "callgrind_control", "callgrind_annotate"):
                self._commands_available[cmd] = not subprocess.run(
                    ["which", cmd],
                    capture_output=True,
                    check=False,
                ).returncode

        self._build_type: Optional[str] = None
        build_search = re.search("BUILD_TYPE=(.+),", torch.__config__.show())  # type: ignore[no-untyped-call]
        if build_search is not None:
            self._build_type = build_search.groups()[0].split(",")[0]

    def _validate(self) -> None:
        if not self._supported_platform:
            raise OSError("Valgrind is not supported on this platform.")

        missing_cmds = [cmd for cmd, available in self._commands_available.items() if not available]
        if missing_cmds:
            raise OSError("Missing: " + ", ".join(missing_cmds))

    def collect_callgrind(
        self,
        task_spec: common.TaskSpec,
        globals: Dict[str, Any],
        *,
        number: int,
        repeats: int,
        collect_baseline: bool,
        is_python: bool,
        retain_out_file: bool,
    ) -> Tuple[CallgrindStats, ...]:
        """Collect stats, and attach a reference run which can be used to filter interpreter overhead."""
        self._validate()
        assert is_python or not collect_baseline

        *task_stats, baseline_stats = self._invoke(
            task_spec=task_spec,
            globals=globals,
            number=number,
            repeats=repeats,
            collect_baseline=collect_baseline,
            is_python=is_python,
            retain_out_file=retain_out_file,
        )
        assert len(task_stats) == repeats

        return tuple(
            CallgrindStats(
                task_spec=task_spec,
                number_per_run=number,
                built_with_debug_symbols=self._build_type == "RelWithDebInfo",
                baseline_inclusive_stats=baseline_stats[0],
                baseline_exclusive_stats=baseline_stats[1],
                stmt_inclusive_stats=stmt_inclusive_stats,
                stmt_exclusive_stats=stmt_exclusive_stats,
                stmt_callgrind_out=out_contents,
            )
            for stmt_inclusive_stats, stmt_exclusive_stats, out_contents in task_stats
        )

    def _invoke(
        self,
        *,
        task_spec: common.TaskSpec,
        globals: Dict[str, Any],
        number: int,
        repeats: int,
        collect_baseline: bool,
        is_python: bool,
        retain_out_file: bool,
    ) -> Tuple[Tuple[FunctionCounts, FunctionCounts, Optional[str]], ...]:
        """Core invocation method for Callgrind collection.

        Valgrind operates by effectively replacing the CPU with an emulated
        version which allows it to instrument any code at the cost of severe
        performance degradation. This has the practical effect that in order
        to collect Callgrind statistics, a new process has to be created
        running under `valgrind`. The steps for this process are:

        1) Create a scratch directory.
        2) Codegen a run script. (_ValgrindWrapper._construct_script)
            Inside the run script:
                * Validate that Python and torch match the parent process
                * Validate that it is indeed running under valgrind
                * Execute `setup` and warm up `stmt`
                * Begin collecting stats
                * Run the `stmt` loop
                * Stop collecting stats
        3) Parse the run results.
        4) Cleanup the scratch directory.
        """
        working_dir = common._make_temp_dir(prefix="callgrind")
        data_dir = os.path.join(working_dir, "data")
        script_file = os.path.join(working_dir, "timer_callgrind.py")
        callgrind_out = os.path.join(working_dir, "callgrind.out")
        error_log = os.path.join(working_dir, "error.txt")
        stat_log = os.path.join(working_dir, "callgrind_stat.txt")
        stdout_stderr_log = os.path.join(working_dir, "stdout_stderr.log")

        def run(args: List[str], **kwargs: Any) -> Tuple[CompletedProcessType, str]:
            # https://thraxil.org/users/anders/posts/2008/03/13/Subprocess-Hanging-PIPE-is-your-enemy/
            f_stdout_stderr = open(stdout_stderr_log, "wb")
            try:
                invocation = subprocess.run(
                    args,
                    stdout=f_stdout_stderr,
                    stderr=subprocess.STDOUT,
                    **kwargs,
                )
                with open(stdout_stderr_log) as f:
                    return invocation, f.read()
            finally:
                f_stdout_stderr.close()

        try:
            if is_python:
                if self._bindings_module is not None:
                    shutil.copy(
                        self._bindings_module.__file__,
                        os.path.join(working_dir, os.path.split(self._bindings_module.__file__)[1])
                    )

                script_file = os.path.join(working_dir, "timer_callgrind.py")
                with open(script_file, "w") as f:
                    f.write(self._construct_script(
                        task_spec,
                        globals=GlobalsBridge(globals, data_dir),
                        number=number,
                        repeats=repeats,
                        collect_baseline=collect_baseline,
                        error_log=error_log,
                        stat_log=stat_log,
                        bindings=self._bindings_module))

                run_loop_cmd = ["python", script_file]
            else:
                assert not collect_baseline
                run_loop_exec = cpp_jit.compile_callgrind_template(
                    stmt=task_spec.stmt,
                    setup=task_spec.setup,
                    global_setup=task_spec.global_setup,
                )
                run_loop_cmd = [
                    run_loop_exec,
                    "--number", str(number),
                    "--number-warmup", str(min(number, 10)),
                    "--repeats", str(repeats),
                    "--number-threads", str(task_spec.num_threads),
                ]

            valgrind_invocation, valgrind_invocation_output = run([
                "valgrind",
                "--tool=callgrind",
                f"--callgrind-out-file={callgrind_out}",
                "--dump-line=yes",
                "--dump-instr=yes",
                "--instr-atstart=yes",
                "--collect-atstart=no",
            ] + run_loop_cmd)

            if valgrind_invocation.returncode:
                error_report = ""
                if os.path.exists(error_log):
                    with open(error_log) as f:
                        error_report = f.read()
                    if not error_report:
                        error_report = "Unknown error.\n" + valgrind_invocation_output

                raise OSError(f"Failed to collect callgrind profile:\n{error_report}")

            def parse_output(fpath: str, inclusive: bool) -> FunctionCounts:
                annotate_invocation, annotate_invocation_output = run([
                    "callgrind_annotate",
                    f"--inclusive={'yes' if inclusive else 'no'}",
                    "--threshold=100",
                    "--show-percs=no",
                    fpath
                ], check=True)

                total_pattern = re.compile(r"^([0-9,]+)\s+PROGRAM TOTALS")
                begin_pattern = re.compile(r"Ir\s+file:function")
                function_pattern = re.compile(r"^\s*([0-9,]+)\s+(.+:.+)$")

                class ScanState(enum.Enum):
                    SCANNING_FOR_TOTAL = 0
                    SCANNING_FOR_START = 1
                    PARSING = 2

                scan_state = ScanState.SCANNING_FOR_TOTAL
                fn_counts = []
                for l in annotate_invocation_output.splitlines(keepends=False):
                    if scan_state == ScanState.SCANNING_FOR_TOTAL:
                        total_match = total_pattern.match(l)
                        if total_match:
                            program_totals = int(total_match.groups()[0].replace(",", ""))
                            scan_state = ScanState.SCANNING_FOR_START

                    elif scan_state == ScanState.SCANNING_FOR_START:
                        if begin_pattern.match(l):
                            scan_state = ScanState.PARSING

                    else:
                        assert scan_state == ScanState.PARSING
                        fn_match = function_pattern.match(l)
                        if fn_match:
                            ir_str, file_function = fn_match.groups()
                            ir = int(ir_str.replace(",", ""))
                            if ir == program_totals:  # type: ignore[possibly-undefined]
                                # Callgrind includes some top level red herring symbols when
                                # a program dumps multiple profiles.
                                continue
                            fn_counts.append(FunctionCount(ir, file_function))

                        elif re.match(r"-+", l):
                            # Ignore heading separator lines.
                            continue

                        else:
                            break

                assert scan_state == ScanState.PARSING, f"Failed to parse {fpath}"
                return FunctionCounts(tuple(sorted(fn_counts, reverse=True)), inclusive=inclusive)

            def read_results(i: int) -> Tuple[FunctionCounts, FunctionCounts, Optional[str]]:
                if i == repeats and not collect_baseline:
                    # Null baseline.
                    return (
                        FunctionCounts((), inclusive=True),
                        FunctionCounts((), inclusive=False),
                        None,
                    )

                fpath = f"{callgrind_out}.{i + 1}"  # Callgrind one-indexes files.
                callgrind_out_contents: Optional[str] = None
                if retain_out_file:
                    with open(fpath) as f:
                        callgrind_out_contents = f.read()

                return (
                    parse_output(fpath, inclusive=True),
                    parse_output(fpath, inclusive=False),
                    callgrind_out_contents
                )

            return tuple(read_results(i) for i in range(repeats + 1))
        finally:
            shutil.rmtree(working_dir)

    @staticmethod
    def _construct_script(
        task_spec: common.TaskSpec,
        globals: GlobalsBridge,
        *,
        number: int,
        repeats: int,
        collect_baseline: bool,
        error_log: str,
        stat_log: str,
        bindings: Optional[CallgrindModuleType],
    ) -> str:
        def block_stmt(stmt: str, indent: int = 0) -> str:
            """Partially unroll benchmark loop.

            The naive template looks something like:
                "for _ in range({number}): {stmt}"

            However a loop in Python is surprisingly expensive, and significantly
            increases the number of background Python instructions. So instead we
            partially unroll the loops, with a block size of 100 chosen to keep
            the instruction overhead from `range` low while also not ballooning
            the size of the generated file.
            """
            block_size = 100
            loop_count = number // block_size
            if loop_count == 1:
                # There is no point in having `for _ in range(1): ...` rather
                # than just `...`, and this lets us save shave a few background
                # instructions.
                loop_count = 0
            remainder = number - block_size * loop_count
            blocked_stmt = ""

            if loop_count:
                unrolled_stmts = textwrap.indent("\n".join([stmt] * block_size), " " * 4)
                blocked_stmt += f"for _ in range({loop_count}):\n{unrolled_stmts}\n"

            if remainder:
                blocked_stmt += "\n".join([stmt] * remainder)

            return textwrap.indent(blocked_stmt, " " * indent)

        pass_baseline = (
            "callgrind_bindings._valgrind_toggle()\n"
            f"{block_stmt('pass')}\n"
            "callgrind_bindings._valgrind_toggle_and_dump_stats()"
        )

        return textwrap.dedent(r"""
            import gc
            import os
            import pickle
            import subprocess
            import sys
            import time

            # Mitigate https://github.com/pytorch/pytorch/issues/37377
            # which can sometimes cause the subprocess call to fail.
            import numpy as np

            import torch
            torch.set_num_threads({num_threads})

            {bindings_import}

            PID = os.getpid()

            def log_failure(msg):
                with open({error_log_repr}, "wt") as f:
                    f.write(msg)
                sys.exit(1)

            def check_result(completed_process):
                if completed_process.returncode:
                    log_failure(f"Command failed: {{' '.join(completed_process.args)}}")
                return completed_process

            # =============================================================================
            # == Check that subprocess matches parent =====================================
            # =============================================================================
            if os.path.realpath(sys.executable) != "{parent_interpreter}":
                log_failure(
                    "Interpreter mismatch:\n"
                    f"  {{os.path.realpath(sys.executable)}}\n    vs.\n  {parent_interpreter}"
                )

            if torch.__file__ != "{torch_file}":
                log_failure(
                    "PyTorch does not match expected file:\n"
                    f"  {{torch.__file__}}\n    vs.\n  {torch_file}"
                )

            # =============================================================================
            # == User specified setup =====================================================
            # =============================================================================
            # Load serialized globals
            {load_globals}

            # User setup str
            {setup}

            for _ in range({warmup_number}):
            {indented_stmt}

            # =============================================================================
            # == Callgrind management =====================================================
            # =============================================================================
            with open("{stat_log}", "wb") as stat_file:
                # If many instances of callgrind are running at once, the output of
                # `callgrind_control` may exceed 16kb which would cause `subprocess.PIPE`
                # to deadlock. So instead we use a file.
                callgrind_stat = check_result(subprocess.run(
                    ["callgrind_control", "--stat"],
                    stdout=stat_file,
                    stderr=subprocess.STDOUT,
                ))

            with open("{stat_log}", "rt") as stat_file:
                stat_lines = stat_file.read().splitlines()

            if f"PID {{PID}}: python {{__file__}}" not in stat_lines:
                log_failure("Process does not appear to be running callgrind.")

            gc.collect()
            time.sleep(0.01)

            # =============================================================================
            # == User code block ==========================================================
            # =============================================================================
            for _ in range({repeats}):
                callgrind_bindings._valgrind_toggle()
            {blocked_stmt}
                callgrind_bindings._valgrind_toggle_and_dump_stats()
                gc.collect()

            {baseline}
        """).strip().format(
            indented_stmt=textwrap.indent(task_spec.stmt, " " * 4),
            blocked_stmt=block_stmt(task_spec.stmt, indent=4),
            baseline=(pass_baseline if collect_baseline else ""),
            number=number,
            repeats=repeats,
            load_globals=globals.construct(),
            setup=task_spec.setup,
            warmup_number=min(number, 10),
            num_threads=task_spec.num_threads,
            error_log_repr=repr(error_log),
            stat_log=stat_log,
            parent_interpreter=os.path.realpath(sys.executable),
            torch_file=torch.__file__,
            bindings_import=(
                "import torch._C as callgrind_bindings" if bindings is None
                else f"import {bindings.__name__} as callgrind_bindings"),
        )


CALLGRIND_SINGLETON: Optional[_ValgrindWrapper] = None
def wrapper_singleton() -> _ValgrindWrapper:
    global CALLGRIND_SINGLETON
    if CALLGRIND_SINGLETON is None:
        CALLGRIND_SINGLETON = _ValgrindWrapper()
    return CALLGRIND_SINGLETON
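Note (illustrative, not part of the commit): the public entry point to the module above is `Timer.collect_callgrind` in `torch.utils.benchmark`; it returns the `CallgrindStats`/`FunctionCounts` objects defined in this file. A small usage sketch, assuming a working valgrind installation (the statement and tensor shape are arbitrary examples):

```python
from torch.utils.benchmark import Timer

timer = Timer(
    stmt="y = x * 2",
    setup="import torch; x = torch.ones(128)",
)
stats = timer.collect_callgrind(number=100)

# Total instruction count, with noisy CPython dict lookups stripped.
print(stats.counts(denoise=True))

# Exclusive counts, regularized so profiles from two builds can be diffed cleanly.
counts = stats.as_standardized().stats(inclusive=False).denoise()
print(counts.filter(lambda fn: "aten" in fn))
```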
.venv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h
ADDED
The diff for this file is too large to render; see the raw diff.
.venv/Lib/site-packages/torch/utils/bottleneck/__init__.py
ADDED
File without changes
.venv/Lib/site-packages/torch/utils/bottleneck/__main__.py
ADDED
|
@@ -0,0 +1,230 @@
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import argparse
|
| 3 |
+
import cProfile
|
| 4 |
+
import pstats
|
| 5 |
+
import sys
|
| 6 |
+
import os
|
| 7 |
+
from typing import Dict
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch.autograd import profiler
|
| 11 |
+
from torch.utils.collect_env import get_env_info
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def redirect_argv(new_argv):
|
| 15 |
+
sys.argv[:] = new_argv[:]
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def compiled_with_cuda(sysinfo):
|
| 19 |
+
if sysinfo.cuda_compiled_version:
|
| 20 |
+
return f'compiled w/ CUDA {sysinfo.cuda_compiled_version}'
|
| 21 |
+
return 'not compiled w/ CUDA'
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
env_summary = """
|
| 25 |
+
--------------------------------------------------------------------------------
|
| 26 |
+
Environment Summary
|
| 27 |
+
--------------------------------------------------------------------------------
|
| 28 |
+
PyTorch {pytorch_version}{debug_str} {cuda_compiled}
|
| 29 |
+
Running with Python {py_version} and {cuda_runtime}
|
| 30 |
+
|
| 31 |
+
`{pip_version} list` truncated output:
|
| 32 |
+
{pip_list_output}
|
| 33 |
+
""".strip()
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def run_env_analysis():
|
| 37 |
+
print('Running environment analysis...')
|
| 38 |
+
info = get_env_info()
|
| 39 |
+
|
| 40 |
+
result: Dict[str, str] = {}
|
| 41 |
+
|
| 42 |
+
debug_str = ''
|
| 43 |
+
if info.is_debug_build:
|
| 44 |
+
debug_str = ' DEBUG'
|
| 45 |
+
|
| 46 |
+
cuda_avail = ''
|
| 47 |
+
if info.is_cuda_available:
|
| 48 |
+
cuda = info.cuda_runtime_version
|
| 49 |
+
if cuda is not None:
|
| 50 |
+
cuda_avail = 'CUDA ' + cuda
|
| 51 |
+
else:
|
| 52 |
+
cuda = 'CUDA unavailable'
|
| 53 |
+
|
| 54 |
+
pip_version = info.pip_version
|
| 55 |
+
pip_list_output = info.pip_packages
|
| 56 |
+
if pip_list_output is None:
|
| 57 |
+
pip_list_output = 'Unable to fetch'
|
| 58 |
+
|
| 59 |
+
result = {
|
| 60 |
+
'debug_str': debug_str,
|
| 61 |
+
'pytorch_version': info.torch_version,
|
| 62 |
+
'cuda_compiled': compiled_with_cuda(info),
|
| 63 |
+
'py_version': f'{sys.version_info[0]}.{sys.version_info[1]}',
|
| 64 |
+
'cuda_runtime': cuda_avail,
|
| 65 |
+
'pip_version': pip_version,
|
| 66 |
+
'pip_list_output': pip_list_output,
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
return env_summary.format(**result)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def run_cprofile(code, globs, launch_blocking=False):
|
| 73 |
+
print('Running your script with cProfile')
|
| 74 |
+
prof = cProfile.Profile()
|
| 75 |
+
prof.enable()
|
| 76 |
+
exec(code, globs, None)
|
| 77 |
+
prof.disable()
|
| 78 |
+
return prof
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
cprof_summary = """
|
| 82 |
+
--------------------------------------------------------------------------------
|
| 83 |
+
cProfile output
|
| 84 |
+
--------------------------------------------------------------------------------
|
| 85 |
+
""".strip()
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def print_cprofile_summary(prof, sortby='tottime', topk=15):
|
| 89 |
+
print(cprof_summary)
|
| 90 |
+
cprofile_stats = pstats.Stats(prof).sort_stats(sortby)
|
| 91 |
+
cprofile_stats.print_stats(topk)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def run_autograd_prof(code, globs):
|
| 95 |
+
def run_prof(use_cuda=False):
|
| 96 |
+
with profiler.profile(use_cuda=use_cuda) as prof:
|
| 97 |
+
exec(code, globs, None)
|
| 98 |
+
return prof
|
| 99 |
+
|
| 100 |
+
print('Running your script with the autograd profiler...')
|
| 101 |
+
result = [run_prof(use_cuda=False)]
|
| 102 |
+
if torch.cuda.is_available():
|
| 103 |
+
result.append(run_prof(use_cuda=True))
|
| 104 |
+
else:
|
| 105 |
+
result.append(None)
|
| 106 |
+
|
| 107 |
+
return result
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
autograd_prof_summary = """
|
| 111 |
+
--------------------------------------------------------------------------------
|
| 112 |
+
autograd profiler output ({mode} mode)
|
| 113 |
+
--------------------------------------------------------------------------------
|
| 114 |
+
{description}
|
| 115 |
+
{cuda_warning}
|
| 116 |
+
{output}
|
| 117 |
+
""".strip()
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def print_autograd_prof_summary(prof, mode, sortby='cpu_time', topk=15):
|
| 121 |
+
valid_sortby = ['cpu_time', 'cuda_time', 'cpu_time_total', 'cuda_time_total', 'count']
|
| 122 |
+
if sortby not in valid_sortby:
|
| 123 |
+
warn = ('WARNING: invalid sorting option for autograd profiler results: {}\n'
|
| 124 |
+
'Expected `cpu_time`, `cuda_time`, `cpu_time_total`, `cuda_time_total`, or `count`. '
|
| 125 |
+
'Defaulting to `cpu_time`.')
|
| 126 |
+
print(warn.format(sortby))
|
| 127 |
+
sortby = 'cpu_time'
|
| 128 |
+
|
| 129 |
+
if mode == 'CUDA':
|
| 130 |
+
cuda_warning = ('\n\tBecause the autograd profiler uses the CUDA event API,\n'
|
| 131 |
+
'\tthe CUDA time column reports approximately max(cuda_time, cpu_time).\n'
|
| 132 |
+
'\tPlease ignore this output if your code does not use CUDA.\n')
|
| 133 |
+
else:
|
| 134 |
+
cuda_warning = ''
|
| 135 |
+
|
| 136 |
+
sorted_events = sorted(prof.function_events,
|
| 137 |
+
key=lambda x: getattr(x, sortby), reverse=True)
|
| 138 |
+
topk_events = sorted_events[:topk]
|
| 139 |
+
|
| 140 |
+
result = {
|
| 141 |
+
'mode': mode,
|
| 142 |
+
'description': f'top {topk} events sorted by {sortby}',
|
| 143 |
+
'output': torch.autograd.profiler_util._build_table(topk_events),
|
| 144 |
+
'cuda_warning': cuda_warning
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
print(autograd_prof_summary.format(**result))
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
descript = """
|
| 151 |
+
`bottleneck` is a tool that can be used as an initial step for debugging
|
| 152 |
+
bottlenecks in your program.
|
| 153 |
+
|
| 154 |
+
It summarizes runs of your script with the Python profiler and PyTorch\'s
|
| 155 |
+
autograd profiler. Because your script will be profiled, please ensure that it
|
| 156 |
+
exits in a finite amount of time.
|
| 157 |
+
|
| 158 |
+
For more complicated uses of the profilers, please see
|
| 159 |
+
https://docs.python.org/3/library/profile.html and
|
| 160 |
+
https://pytorch.org/docs/main/autograd.html#profiler for more information.
|
| 161 |
+
""".strip()
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def parse_args():
|
| 165 |
+
parser = argparse.ArgumentParser(description=descript)
|
| 166 |
+
parser.add_argument('scriptfile', type=str,
|
| 167 |
+
help='Path to the script to be run. '
|
| 168 |
+
'Usually run with `python path/to/script`.')
|
| 169 |
+
parser.add_argument('args', type=str, nargs=argparse.REMAINDER,
|
| 170 |
+
help='Command-line arguments to be passed to the script.')
|
| 171 |
+
return parser.parse_args()
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def cpu_time_total(autograd_prof):
|
| 175 |
+
return sum(event.cpu_time_total for event in autograd_prof.function_events)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def main():
|
| 179 |
+
args = parse_args()
|
| 180 |
+
|
| 181 |
+
# Customizable constants.
|
| 182 |
+
scriptfile = args.scriptfile
|
| 183 |
+
scriptargs = [] if args.args is None else args.args
|
| 184 |
+
scriptargs.insert(0, scriptfile)
|
| 185 |
+
cprofile_sortby = 'tottime'
|
| 186 |
+
cprofile_topk = 15
|
| 187 |
+
autograd_prof_sortby = 'cpu_time_total'
|
| 188 |
+
autograd_prof_topk = 15
|
| 189 |
+
|
| 190 |
+
redirect_argv(scriptargs)
|
| 191 |
+
|
| 192 |
+
sys.path.insert(0, os.path.dirname(scriptfile))
|
| 193 |
+
with open(scriptfile, 'rb') as stream:
|
| 194 |
+
code = compile(stream.read(), scriptfile, 'exec')
|
| 195 |
+
globs = {
|
| 196 |
+
'__file__': scriptfile,
|
| 197 |
+
'__name__': '__main__',
|
| 198 |
+
'__package__': None,
|
| 199 |
+
'__cached__': None,
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
print(descript)
|
| 203 |
+
|
| 204 |
+
env_summary = run_env_analysis()
|
| 205 |
+
|
| 206 |
+
if torch.cuda.is_available():
|
| 207 |
+
torch.cuda.init()
|
| 208 |
+
cprofile_prof = run_cprofile(code, globs)
|
| 209 |
+
autograd_prof_cpu, autograd_prof_cuda = run_autograd_prof(code, globs)
|
| 210 |
+
|
| 211 |
+
print(env_summary)
|
| 212 |
+
print_cprofile_summary(cprofile_prof, cprofile_sortby, cprofile_topk)
|
| 213 |
+
|
| 214 |
+
if not torch.cuda.is_available():
|
| 215 |
+
print_autograd_prof_summary(autograd_prof_cpu, 'CPU', autograd_prof_sortby, autograd_prof_topk)
|
| 216 |
+
return
|
| 217 |
+
|
| 218 |
+
# Print both the result of the CPU-mode and CUDA-mode autograd profilers
|
| 219 |
+
# if their execution times are very different.
|
| 220 |
+
cuda_prof_exec_time = cpu_time_total(autograd_prof_cuda)
|
| 221 |
+
if len(autograd_prof_cpu.function_events) > 0:
|
| 222 |
+
cpu_prof_exec_time = cpu_time_total(autograd_prof_cpu)
|
| 223 |
+
pct_diff = (cuda_prof_exec_time - cpu_prof_exec_time) / cuda_prof_exec_time
|
| 224 |
+
if abs(pct_diff) > 0.05:
|
| 225 |
+
print_autograd_prof_summary(autograd_prof_cpu, 'CPU', autograd_prof_sortby, autograd_prof_topk)
|
| 226 |
+
|
| 227 |
+
print_autograd_prof_summary(autograd_prof_cuda, 'CUDA', autograd_prof_sortby, autograd_prof_topk)
|
| 228 |
+
|
| 229 |
+
if __name__ == '__main__':
|
| 230 |
+
main()
|
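For orientation, a hypothetical target script and the usual invocation of this module (the file name toy_train.py is illustrative only):

# toy_train.py -- small script for bottleneck to profile (hypothetical example)
import torch

x = torch.randn(64, 128)
w = torch.randn(128, 10, requires_grad=True)
for _ in range(100):
    loss = (x @ w).pow(2).mean()
    loss.backward()
    w.grad = None

# Run from a shell as:
#   python -m torch.utils.bottleneck toy_train.py
# which prints the environment summary, cProfile table, and autograd profiler
# tables produced by main() above.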
.venv/Lib/site-packages/torch/utils/data/__init__.py
ADDED
|
@@ -0,0 +1,77 @@
| 1 |
+
from torch.utils.data.dataloader import (
|
| 2 |
+
_DatasetKind,
|
| 3 |
+
DataLoader,
|
| 4 |
+
default_collate,
|
| 5 |
+
default_convert,
|
| 6 |
+
get_worker_info,
|
| 7 |
+
)
|
| 8 |
+
from torch.utils.data.datapipes._decorator import (
|
| 9 |
+
argument_validation,
|
| 10 |
+
functional_datapipe,
|
| 11 |
+
guaranteed_datapipes_determinism,
|
| 12 |
+
non_deterministic,
|
| 13 |
+
runtime_validation,
|
| 14 |
+
runtime_validation_disabled,
|
| 15 |
+
)
|
| 16 |
+
from torch.utils.data.datapipes.datapipe import (
|
| 17 |
+
DataChunk,
|
| 18 |
+
DFIterDataPipe,
|
| 19 |
+
IterDataPipe,
|
| 20 |
+
MapDataPipe,
|
| 21 |
+
)
|
| 22 |
+
from torch.utils.data.dataset import (
|
| 23 |
+
ChainDataset,
|
| 24 |
+
ConcatDataset,
|
| 25 |
+
Dataset,
|
| 26 |
+
IterableDataset,
|
| 27 |
+
random_split,
|
| 28 |
+
StackDataset,
|
| 29 |
+
Subset,
|
| 30 |
+
TensorDataset,
|
| 31 |
+
)
|
| 32 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 33 |
+
from torch.utils.data.sampler import (
|
| 34 |
+
BatchSampler,
|
| 35 |
+
RandomSampler,
|
| 36 |
+
Sampler,
|
| 37 |
+
SequentialSampler,
|
| 38 |
+
SubsetRandomSampler,
|
| 39 |
+
WeightedRandomSampler,
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
__all__ = [
|
| 44 |
+
"BatchSampler",
|
| 45 |
+
"ChainDataset",
|
| 46 |
+
"ConcatDataset",
|
| 47 |
+
"DFIterDataPipe",
|
| 48 |
+
"DataChunk",
|
| 49 |
+
"DataLoader",
|
| 50 |
+
"Dataset",
|
| 51 |
+
"DistributedSampler",
|
| 52 |
+
"IterDataPipe",
|
| 53 |
+
"IterableDataset",
|
| 54 |
+
"MapDataPipe",
|
| 55 |
+
"RandomSampler",
|
| 56 |
+
"Sampler",
|
| 57 |
+
"SequentialSampler",
|
| 58 |
+
"StackDataset",
|
| 59 |
+
"Subset",
|
| 60 |
+
"SubsetRandomSampler",
|
| 61 |
+
"TensorDataset",
|
| 62 |
+
"WeightedRandomSampler",
|
| 63 |
+
"_DatasetKind",
|
| 64 |
+
"argument_validation",
|
| 65 |
+
"default_collate",
|
| 66 |
+
"default_convert",
|
| 67 |
+
"functional_datapipe",
|
| 68 |
+
"get_worker_info",
|
| 69 |
+
"guaranteed_datapipes_determinism",
|
| 70 |
+
"non_deterministic",
|
| 71 |
+
"random_split",
|
| 72 |
+
"runtime_validation",
|
| 73 |
+
"runtime_validation_disabled",
|
| 74 |
+
]
|
| 75 |
+
|
| 76 |
+
# Please keep this list sorted
|
| 77 |
+
assert __all__ == sorted(__all__)
|
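A short sketch of the re-exported classes working together; the dataset contents and shapes are arbitrary:

# Hedged example combining TensorDataset, RandomSampler, and DataLoader.
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

features = torch.randn(100, 4)
labels = torch.randint(0, 2, (100,))
dataset = TensorDataset(features, labels)

loader = DataLoader(dataset, batch_size=16, sampler=RandomSampler(dataset))
for batch_features, batch_labels in loader:
    pass  # batch_features: (16, 4), except possibly a shorter final batch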
.venv/Lib/site-packages/torch/utils/data/_utils/__init__.py
ADDED
|
@@ -0,0 +1,54 @@
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
r"""Utility classes & functions for data loading. Code in this folder is mostly used by ../dataloder.py.
|
| 3 |
+
|
| 4 |
+
A lot of multiprocessing is used in data loading, which only supports running
|
| 5 |
+
functions defined in global environment (py2 can't serialize static methods).
|
| 6 |
+
Therefore, for code tidiness we put these functions into different files in this
|
| 7 |
+
folder.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import atexit
|
| 11 |
+
import sys
|
| 12 |
+
|
| 13 |
+
# old private location of the ExceptionWrapper that some users rely on:
|
| 14 |
+
from torch._utils import ExceptionWrapper
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
IS_WINDOWS = sys.platform == "win32"
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
MP_STATUS_CHECK_INTERVAL = 5.0
|
| 21 |
+
r"""Interval (in seconds) to check status of processes to avoid hanging in
|
| 22 |
+
multiprocessing data loading. This is mainly used in getting data from
|
| 23 |
+
another process, in which case we need to periodically check whether the
|
| 24 |
+
sender is alive to prevent hanging."""
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
python_exit_status = False
|
| 28 |
+
r"""Whether Python is shutting down. This flag is guaranteed to be set before
|
| 29 |
+
the Python core library resources are freed, but Python may already be exiting
|
| 30 |
+
for some time when this is set.
|
| 31 |
+
|
| 32 |
+
Hook to set this flag is `_set_python_exit_flag`, and is inspired by a similar
|
| 33 |
+
hook in Python 3.7 multiprocessing library:
|
| 34 |
+
https://github.com/python/cpython/blob/d4d60134b29290049e28df54f23493de4f1824b6/Lib/multiprocessing/util.py#L277-L327
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
try:
|
| 39 |
+
import numpy
|
| 40 |
+
|
| 41 |
+
HAS_NUMPY = True
|
| 42 |
+
except ModuleNotFoundError:
|
| 43 |
+
HAS_NUMPY = False
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _set_python_exit_flag():
|
| 47 |
+
global python_exit_status
|
| 48 |
+
python_exit_status = True
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
atexit.register(_set_python_exit_flag)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
from . import collate, fetch, pin_memory, signal_handling, worker
|
.venv/Lib/site-packages/torch/utils/data/_utils/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (1.11 kB).
|
|
|
.venv/Lib/site-packages/torch/utils/data/_utils/__pycache__/collate.cpython-39.pyc
ADDED
|
Binary file (13.5 kB).
|
|
|
.venv/Lib/site-packages/torch/utils/data/_utils/__pycache__/fetch.cpython-39.pyc
ADDED
|
Binary file (2.28 kB).
|
|
|
.venv/Lib/site-packages/torch/utils/data/_utils/__pycache__/pin_memory.cpython-39.pyc
ADDED
|
Binary file (3.25 kB).
|
|
|
.venv/Lib/site-packages/torch/utils/data/_utils/__pycache__/signal_handling.cpython-39.pyc
ADDED
|
Binary file (2.63 kB).
|
|
|
.venv/Lib/site-packages/torch/utils/data/_utils/__pycache__/worker.cpython-39.pyc
ADDED
|
Binary file (7.87 kB).
|
|
|
.venv/Lib/site-packages/torch/utils/data/_utils/collate.py
ADDED
|
@@ -0,0 +1,398 @@
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers.
|
| 3 |
+
|
| 4 |
+
These methods are used to collate samples fetched from dataset into Tensor(s).
|
| 5 |
+
These **need** to be in global scope since Py2 doesn't support serializing
|
| 6 |
+
static methods.
|
| 7 |
+
|
| 8 |
+
`default_collate` and `default_convert` are exposed to users via 'dataloader.py'.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import collections
|
| 12 |
+
import contextlib
|
| 13 |
+
import copy
|
| 14 |
+
import re
|
| 15 |
+
from typing import Callable, Dict, Optional, Tuple, Type, Union
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
np_str_obj_array_pattern = re.compile(r"[SaUO]")
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def default_convert(data):
|
| 24 |
+
r"""
|
| 25 |
+
Convert each NumPy array element into a :class:`torch.Tensor`.
|
| 26 |
+
|
| 27 |
+
If the input is a `Sequence`, `Collection`, or `Mapping`, it tries to convert each element inside to a :class:`torch.Tensor`.
|
| 28 |
+
If the input is not a NumPy array, it is left unchanged.
|
| 29 |
+
This is used as the default function for collation when both `batch_sampler` and `batch_size`
|
| 30 |
+
are NOT defined in :class:`~torch.utils.data.DataLoader`.
|
| 31 |
+
|
| 32 |
+
The general input type to output type mapping is similar to that
|
| 33 |
+
of :func:`~torch.utils.data.default_collate`. See the description there for more details.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
data: a single data point to be converted
|
| 37 |
+
|
| 38 |
+
Examples:
|
| 39 |
+
>>> # xdoctest: +SKIP
|
| 40 |
+
>>> # Example with `int`
|
| 41 |
+
>>> default_convert(0)
|
| 42 |
+
0
|
| 43 |
+
>>> # Example with NumPy array
|
| 44 |
+
>>> default_convert(np.array([0, 1]))
|
| 45 |
+
tensor([0, 1])
|
| 46 |
+
>>> # Example with NamedTuple
|
| 47 |
+
>>> Point = namedtuple('Point', ['x', 'y'])
|
| 48 |
+
>>> default_convert(Point(0, 0))
|
| 49 |
+
Point(x=0, y=0)
|
| 50 |
+
>>> default_convert(Point(np.array(0), np.array(0)))
|
| 51 |
+
Point(x=tensor(0), y=tensor(0))
|
| 52 |
+
>>> # Example with List
|
| 53 |
+
>>> default_convert([np.array([0, 1]), np.array([2, 3])])
|
| 54 |
+
[tensor([0, 1]), tensor([2, 3])]
|
| 55 |
+
"""
|
| 56 |
+
elem_type = type(data)
|
| 57 |
+
if isinstance(data, torch.Tensor):
|
| 58 |
+
return data
|
| 59 |
+
elif (
|
| 60 |
+
elem_type.__module__ == "numpy"
|
| 61 |
+
and elem_type.__name__ != "str_"
|
| 62 |
+
and elem_type.__name__ != "string_"
|
| 63 |
+
):
|
| 64 |
+
# array of string classes and object
|
| 65 |
+
if (
|
| 66 |
+
elem_type.__name__ == "ndarray"
|
| 67 |
+
and np_str_obj_array_pattern.search(data.dtype.str) is not None
|
| 68 |
+
):
|
| 69 |
+
return data
|
| 70 |
+
return torch.as_tensor(data)
|
| 71 |
+
elif isinstance(data, collections.abc.Mapping):
|
| 72 |
+
try:
|
| 73 |
+
if isinstance(data, collections.abc.MutableMapping):
|
| 74 |
+
# The mapping type may have extra properties, so we can't just
|
| 75 |
+
# use `type(data)(...)` to create the new mapping.
|
| 76 |
+
# Create a clone and update it if the mapping type is mutable.
|
| 77 |
+
clone = copy.copy(data)
|
| 78 |
+
clone.update({key: default_convert(data[key]) for key in data})
|
| 79 |
+
return clone
|
| 80 |
+
else:
|
| 81 |
+
return elem_type({key: default_convert(data[key]) for key in data})
|
| 82 |
+
except TypeError:
|
| 83 |
+
# The mapping type may not support `copy()` / `update(mapping)`
|
| 84 |
+
# or `__init__(iterable)`.
|
| 85 |
+
return {key: default_convert(data[key]) for key in data}
|
| 86 |
+
elif isinstance(data, tuple) and hasattr(data, "_fields"): # namedtuple
|
| 87 |
+
return elem_type(*(default_convert(d) for d in data))
|
| 88 |
+
elif isinstance(data, tuple):
|
| 89 |
+
return [default_convert(d) for d in data] # Backwards compatibility.
|
| 90 |
+
elif isinstance(data, collections.abc.Sequence) and not isinstance(
|
| 91 |
+
data, (str, bytes)
|
| 92 |
+
):
|
| 93 |
+
try:
|
| 94 |
+
if isinstance(data, collections.abc.MutableSequence):
|
| 95 |
+
# The sequence type may have extra properties, so we can't just
|
| 96 |
+
# use `type(data)(...)` to create the new sequence.
|
| 97 |
+
# Create a clone and update it if the sequence type is mutable.
|
| 98 |
+
clone = copy.copy(data) # type: ignore[arg-type]
|
| 99 |
+
for i, d in enumerate(data):
|
| 100 |
+
clone[i] = default_convert(d)
|
| 101 |
+
return clone
|
| 102 |
+
else:
|
| 103 |
+
return elem_type([default_convert(d) for d in data])
|
| 104 |
+
except TypeError:
|
| 105 |
+
# The sequence type may not support `copy()` / `__setitem__(index, item)`
|
| 106 |
+
# or `__init__(iterable)` (e.g., `range`).
|
| 107 |
+
return [default_convert(d) for d in data]
|
| 108 |
+
else:
|
| 109 |
+
return data
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
default_collate_err_msg_format = (
|
| 113 |
+
"default_collate: batch must contain tensors, numpy arrays, numbers, "
|
| 114 |
+
"dicts or lists; found {}"
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def collate(
|
| 119 |
+
batch,
|
| 120 |
+
*,
|
| 121 |
+
collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
|
| 122 |
+
):
|
| 123 |
+
r"""
|
| 124 |
+
General collate function that handles collection type of element within each batch.
|
| 125 |
+
|
| 126 |
+
The function also opens function registry to deal with specific element types. `default_collate_fn_map`
|
| 127 |
+
provides default collate functions for tensors, numpy arrays, numbers and strings.
|
| 128 |
+
|
| 129 |
+
Args:
|
| 130 |
+
batch: a single batch to be collated
|
| 131 |
+
collate_fn_map: Optional dictionary mapping from element type to the corresponding collate function.
|
| 132 |
+
If the element type isn't present in this dictionary,
|
| 133 |
+
this function will go through each key of the dictionary in the insertion order to
|
| 134 |
+
invoke the corresponding collate function if the element type is a subclass of the key.
|
| 135 |
+
|
| 136 |
+
Examples:
|
| 137 |
+
>>> def collate_tensor_fn(batch, *, collate_fn_map):
|
| 138 |
+
... # Extend this function to handle batch of tensors
|
| 139 |
+
... return torch.stack(batch, 0)
|
| 140 |
+
>>> def custom_collate(batch):
|
| 141 |
+
... collate_map = {torch.Tensor: collate_tensor_fn}
|
| 142 |
+
... return collate(batch, collate_fn_map=collate_map)
|
| 143 |
+
>>> # Extend `default_collate` by in-place modifying `default_collate_fn_map`
|
| 144 |
+
>>> default_collate_fn_map.update({torch.Tensor: collate_tensor_fn})
|
| 145 |
+
|
| 146 |
+
Note:
|
| 147 |
+
Each collate function requires a positional argument for batch and a keyword argument
|
| 148 |
+
for the dictionary of collate functions as `collate_fn_map`.
|
| 149 |
+
"""
|
| 150 |
+
elem = batch[0]
|
| 151 |
+
elem_type = type(elem)
|
| 152 |
+
|
| 153 |
+
if collate_fn_map is not None:
|
| 154 |
+
if elem_type in collate_fn_map:
|
| 155 |
+
return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
|
| 156 |
+
|
| 157 |
+
for collate_type in collate_fn_map:
|
| 158 |
+
if isinstance(elem, collate_type):
|
| 159 |
+
return collate_fn_map[collate_type](
|
| 160 |
+
batch, collate_fn_map=collate_fn_map
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
if isinstance(elem, collections.abc.Mapping):
|
| 164 |
+
try:
|
| 165 |
+
if isinstance(elem, collections.abc.MutableMapping):
|
| 166 |
+
# The mapping type may have extra properties, so we can't just
|
| 167 |
+
# use `type(data)(...)` to create the new mapping.
|
| 168 |
+
# Create a clone and update it if the mapping type is mutable.
|
| 169 |
+
clone = copy.copy(elem)
|
| 170 |
+
clone.update(
|
| 171 |
+
{
|
| 172 |
+
key: collate(
|
| 173 |
+
[d[key] for d in batch], collate_fn_map=collate_fn_map
|
| 174 |
+
)
|
| 175 |
+
for key in elem
|
| 176 |
+
}
|
| 177 |
+
)
|
| 178 |
+
return clone
|
| 179 |
+
else:
|
| 180 |
+
return elem_type(
|
| 181 |
+
{
|
| 182 |
+
key: collate(
|
| 183 |
+
[d[key] for d in batch], collate_fn_map=collate_fn_map
|
| 184 |
+
)
|
| 185 |
+
for key in elem
|
| 186 |
+
}
|
| 187 |
+
)
|
| 188 |
+
except TypeError:
|
| 189 |
+
# The mapping type may not support `copy()` / `update(mapping)`
|
| 190 |
+
# or `__init__(iterable)`.
|
| 191 |
+
return {
|
| 192 |
+
key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map)
|
| 193 |
+
for key in elem
|
| 194 |
+
}
|
| 195 |
+
elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple
|
| 196 |
+
return elem_type(
|
| 197 |
+
*(
|
| 198 |
+
collate(samples, collate_fn_map=collate_fn_map)
|
| 199 |
+
for samples in zip(*batch)
|
| 200 |
+
)
|
| 201 |
+
)
|
| 202 |
+
elif isinstance(elem, collections.abc.Sequence):
|
| 203 |
+
# check to make sure that the elements in batch have consistent size
|
| 204 |
+
it = iter(batch)
|
| 205 |
+
elem_size = len(next(it))
|
| 206 |
+
if not all(len(elem) == elem_size for elem in it):
|
| 207 |
+
raise RuntimeError("each element in list of batch should be of equal size")
|
| 208 |
+
transposed = list(zip(*batch)) # It may be accessed twice, so we use a list.
|
| 209 |
+
|
| 210 |
+
if isinstance(elem, tuple):
|
| 211 |
+
return [
|
| 212 |
+
collate(samples, collate_fn_map=collate_fn_map)
|
| 213 |
+
for samples in transposed
|
| 214 |
+
] # Backwards compatibility.
|
| 215 |
+
else:
|
| 216 |
+
try:
|
| 217 |
+
if isinstance(elem, collections.abc.MutableSequence):
|
| 218 |
+
# The sequence type may have extra properties, so we can't just
|
| 219 |
+
# use `type(data)(...)` to create the new sequence.
|
| 220 |
+
# Create a clone and update it if the sequence type is mutable.
|
| 221 |
+
clone = copy.copy(elem) # type: ignore[arg-type]
|
| 222 |
+
for i, samples in enumerate(transposed):
|
| 223 |
+
clone[i] = collate(samples, collate_fn_map=collate_fn_map)
|
| 224 |
+
return clone
|
| 225 |
+
else:
|
| 226 |
+
return elem_type(
|
| 227 |
+
[
|
| 228 |
+
collate(samples, collate_fn_map=collate_fn_map)
|
| 229 |
+
for samples in transposed
|
| 230 |
+
]
|
| 231 |
+
)
|
| 232 |
+
except TypeError:
|
| 233 |
+
# The sequence type may not support `copy()` / `__setitem__(index, item)`
|
| 234 |
+
# or `__init__(iterable)` (e.g., `range`).
|
| 235 |
+
return [
|
| 236 |
+
collate(samples, collate_fn_map=collate_fn_map)
|
| 237 |
+
for samples in transposed
|
| 238 |
+
]
|
| 239 |
+
|
| 240 |
+
raise TypeError(default_collate_err_msg_format.format(elem_type))
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def collate_tensor_fn(
|
| 244 |
+
batch,
|
| 245 |
+
*,
|
| 246 |
+
collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
|
| 247 |
+
):
|
| 248 |
+
elem = batch[0]
|
| 249 |
+
out = None
|
| 250 |
+
if elem.is_nested:
|
| 251 |
+
raise RuntimeError(
|
| 252 |
+
"Batches of nested tensors are not currently supported by the default collate_fn; "
|
| 253 |
+
"please provide a custom collate_fn to handle them appropriately."
|
| 254 |
+
)
|
| 255 |
+
if elem.layout in {
|
| 256 |
+
torch.sparse_coo,
|
| 257 |
+
torch.sparse_csr,
|
| 258 |
+
torch.sparse_bsr,
|
| 259 |
+
torch.sparse_csc,
|
| 260 |
+
torch.sparse_bsc,
|
| 261 |
+
}:
|
| 262 |
+
raise RuntimeError(
|
| 263 |
+
"Batches of sparse tensors are not currently supported by the default collate_fn; "
|
| 264 |
+
"please provide a custom collate_fn to handle them appropriately."
|
| 265 |
+
)
|
| 266 |
+
if torch.utils.data.get_worker_info() is not None:
|
| 267 |
+
# If we're in a background process, concatenate directly into a
|
| 268 |
+
# shared memory tensor to avoid an extra copy
|
| 269 |
+
numel = sum(x.numel() for x in batch)
|
| 270 |
+
storage = elem._typed_storage()._new_shared(numel, device=elem.device)
|
| 271 |
+
out = elem.new(storage).resize_(len(batch), *list(elem.size()))
|
| 272 |
+
return torch.stack(batch, 0, out=out)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def collate_numpy_array_fn(
|
| 276 |
+
batch,
|
| 277 |
+
*,
|
| 278 |
+
collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
|
| 279 |
+
):
|
| 280 |
+
elem = batch[0]
|
| 281 |
+
# array of string classes and object
|
| 282 |
+
if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
|
| 283 |
+
raise TypeError(default_collate_err_msg_format.format(elem.dtype))
|
| 284 |
+
|
| 285 |
+
return collate([torch.as_tensor(b) for b in batch], collate_fn_map=collate_fn_map)
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def collate_numpy_scalar_fn(
|
| 289 |
+
batch,
|
| 290 |
+
*,
|
| 291 |
+
collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
|
| 292 |
+
):
|
| 293 |
+
return torch.as_tensor(batch)
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def collate_float_fn(
|
| 297 |
+
batch,
|
| 298 |
+
*,
|
| 299 |
+
collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
|
| 300 |
+
):
|
| 301 |
+
return torch.tensor(batch, dtype=torch.float64)
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def collate_int_fn(
|
| 305 |
+
batch,
|
| 306 |
+
*,
|
| 307 |
+
collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
|
| 308 |
+
):
|
| 309 |
+
return torch.tensor(batch)
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def collate_str_fn(
|
| 313 |
+
batch,
|
| 314 |
+
*,
|
| 315 |
+
collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
|
| 316 |
+
):
|
| 317 |
+
return batch
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
default_collate_fn_map: Dict[Union[Type, Tuple[Type, ...]], Callable] = {
|
| 321 |
+
torch.Tensor: collate_tensor_fn
|
| 322 |
+
}
|
| 323 |
+
with contextlib.suppress(ImportError):
|
| 324 |
+
import numpy as np
|
| 325 |
+
|
| 326 |
+
# For both ndarray and memmap (subclass of ndarray)
|
| 327 |
+
default_collate_fn_map[np.ndarray] = collate_numpy_array_fn
|
| 328 |
+
# See scalars hierarchy: https://numpy.org/doc/stable/reference/arrays.scalars.html
|
| 329 |
+
# Skip string scalars
|
| 330 |
+
default_collate_fn_map[(np.bool_, np.number, np.object_)] = collate_numpy_scalar_fn
|
| 331 |
+
default_collate_fn_map[float] = collate_float_fn
|
| 332 |
+
default_collate_fn_map[int] = collate_int_fn
|
| 333 |
+
default_collate_fn_map[str] = collate_str_fn
|
| 334 |
+
default_collate_fn_map[bytes] = collate_str_fn
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def default_collate(batch):
|
| 338 |
+
r"""
|
| 339 |
+
Take in a batch of data and put the elements within the batch into a tensor with an additional outer dimension - batch size.
|
| 340 |
+
|
| 341 |
+
The exact output type can be a :class:`torch.Tensor`, a `Sequence` of :class:`torch.Tensor`, a
|
| 342 |
+
Collection of :class:`torch.Tensor`, or left unchanged, depending on the input type.
|
| 343 |
+
This is used as the default function for collation when
|
| 344 |
+
`batch_size` or `batch_sampler` is defined in :class:`~torch.utils.data.DataLoader`.
|
| 345 |
+
|
| 346 |
+
Here is the general input type (based on the type of the element within the batch) to output type mapping:
|
| 347 |
+
|
| 348 |
+
* :class:`torch.Tensor` -> :class:`torch.Tensor` (with an added outer dimension batch size)
|
| 349 |
+
* NumPy Arrays -> :class:`torch.Tensor`
|
| 350 |
+
* `float` -> :class:`torch.Tensor`
|
| 351 |
+
* `int` -> :class:`torch.Tensor`
|
| 352 |
+
* `str` -> `str` (unchanged)
|
| 353 |
+
* `bytes` -> `bytes` (unchanged)
|
| 354 |
+
* `Mapping[K, V_i]` -> `Mapping[K, default_collate([V_1, V_2, ...])]`
|
| 355 |
+
* `NamedTuple[V1_i, V2_i, ...]` -> `NamedTuple[default_collate([V1_1, V1_2, ...]),
|
| 356 |
+
default_collate([V2_1, V2_2, ...]), ...]`
|
| 357 |
+
* `Sequence[V1_i, V2_i, ...]` -> `Sequence[default_collate([V1_1, V1_2, ...]),
|
| 358 |
+
default_collate([V2_1, V2_2, ...]), ...]`
|
| 359 |
+
|
| 360 |
+
Args:
|
| 361 |
+
batch: a single batch to be collated
|
| 362 |
+
|
| 363 |
+
Examples:
|
| 364 |
+
>>> # xdoctest: +SKIP
|
| 365 |
+
>>> # Example with a batch of `int`s:
|
| 366 |
+
>>> default_collate([0, 1, 2, 3])
|
| 367 |
+
tensor([0, 1, 2, 3])
|
| 368 |
+
>>> # Example with a batch of `str`s:
|
| 369 |
+
>>> default_collate(['a', 'b', 'c'])
|
| 370 |
+
['a', 'b', 'c']
|
| 371 |
+
>>> # Example with `Map` inside the batch:
|
| 372 |
+
>>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}])
|
| 373 |
+
{'A': tensor([ 0, 100]), 'B': tensor([ 1, 100])}
|
| 374 |
+
>>> # Example with `NamedTuple` inside the batch:
|
| 375 |
+
>>> Point = namedtuple('Point', ['x', 'y'])
|
| 376 |
+
>>> default_collate([Point(0, 0), Point(1, 1)])
|
| 377 |
+
Point(x=tensor([0, 1]), y=tensor([0, 1]))
|
| 378 |
+
>>> # Example with `Tuple` inside the batch:
|
| 379 |
+
>>> default_collate([(0, 1), (2, 3)])
|
| 380 |
+
[tensor([0, 2]), tensor([1, 3])]
|
| 381 |
+
>>> # Example with `List` inside the batch:
|
| 382 |
+
>>> default_collate([[0, 1], [2, 3]])
|
| 383 |
+
[tensor([0, 2]), tensor([1, 3])]
|
| 384 |
+
>>> # Two options to extend `default_collate` to handle specific type
|
| 385 |
+
>>> # Option 1: Write custom collate function and invoke `default_collate`
|
| 386 |
+
>>> def custom_collate(batch):
|
| 387 |
+
... elem = batch[0]
|
| 388 |
+
... if isinstance(elem, CustomType): # Some custom condition
|
| 389 |
+
... return ...
|
| 390 |
+
... else: # Fall back to `default_collate`
|
| 391 |
+
... return default_collate(batch)
|
| 392 |
+
>>> # Option 2: In-place modify `default_collate_fn_map`
|
| 393 |
+
>>> def collate_customtype_fn(batch, *, collate_fn_map=None):
|
| 394 |
+
... return ...
|
| 395 |
+
>>> default_collate_fn_map.update(CustomType, collate_customtype_fn)
|
| 396 |
+
>>> default_collate(batch) # Handle `CustomType` automatically
|
| 397 |
+
"""
|
| 398 |
+
return collate(batch, collate_fn_map=default_collate_fn_map)
|
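A small sketch of the registry mechanism described above: default_collate stacks tensors and recurses into containers, and default_collate_fn_map can be extended for custom element types (CustomPair is a hypothetical user type, not part of the library):

# Hedged sketch of default_collate and of registering a custom collate function.
import torch
from torch.utils.data._utils.collate import collate, default_collate, default_collate_fn_map

batch = [{"x": torch.tensor([1.0]), "label": 0}, {"x": torch.tensor([2.0]), "label": 1}]
out = default_collate(batch)
# out["x"] has shape (2, 1); out["label"] is tensor([0, 1])

class CustomPair:  # hypothetical user type
    def __init__(self, a, b):
        self.a, self.b = a, b

def collate_custom_pair(batch, *, collate_fn_map=None):
    # Collate each field with the generic collate so nested element types keep working.
    return CustomPair(
        collate([p.a for p in batch], collate_fn_map=collate_fn_map),
        collate([p.b for p in batch], collate_fn_map=collate_fn_map),
    )

default_collate_fn_map[CustomPair] = collate_custom_pair  # now handled by default_collate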
.venv/Lib/site-packages/torch/utils/data/_utils/fetch.py
ADDED
|
@@ -0,0 +1,55 @@
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
r"""Contains definitions of the methods used by the _BaseDataLoaderIter to fetch data from an iterable-style or map-style dataset.
|
| 3 |
+
|
| 4 |
+
This logic is shared in both single- and multi-processing data loading.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class _BaseDatasetFetcher:
|
| 9 |
+
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
|
| 10 |
+
self.dataset = dataset
|
| 11 |
+
self.auto_collation = auto_collation
|
| 12 |
+
self.collate_fn = collate_fn
|
| 13 |
+
self.drop_last = drop_last
|
| 14 |
+
|
| 15 |
+
def fetch(self, possibly_batched_index):
|
| 16 |
+
raise NotImplementedError
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class _IterableDatasetFetcher(_BaseDatasetFetcher):
|
| 20 |
+
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
|
| 21 |
+
super().__init__(dataset, auto_collation, collate_fn, drop_last)
|
| 22 |
+
self.dataset_iter = iter(dataset)
|
| 23 |
+
self.ended = False
|
| 24 |
+
|
| 25 |
+
def fetch(self, possibly_batched_index):
|
| 26 |
+
if self.ended:
|
| 27 |
+
raise StopIteration
|
| 28 |
+
|
| 29 |
+
if self.auto_collation:
|
| 30 |
+
data = []
|
| 31 |
+
for _ in possibly_batched_index:
|
| 32 |
+
try:
|
| 33 |
+
data.append(next(self.dataset_iter))
|
| 34 |
+
except StopIteration:
|
| 35 |
+
self.ended = True
|
| 36 |
+
break
|
| 37 |
+
if len(data) == 0 or (
|
| 38 |
+
self.drop_last and len(data) < len(possibly_batched_index)
|
| 39 |
+
):
|
| 40 |
+
raise StopIteration
|
| 41 |
+
else:
|
| 42 |
+
data = next(self.dataset_iter)
|
| 43 |
+
return self.collate_fn(data)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class _MapDatasetFetcher(_BaseDatasetFetcher):
|
| 47 |
+
def fetch(self, possibly_batched_index):
|
| 48 |
+
if self.auto_collation:
|
| 49 |
+
if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
|
| 50 |
+
data = self.dataset.__getitems__(possibly_batched_index)
|
| 51 |
+
else:
|
| 52 |
+
data = [self.dataset[idx] for idx in possibly_batched_index]
|
| 53 |
+
else:
|
| 54 |
+
data = self.dataset[possibly_batched_index]
|
| 55 |
+
return self.collate_fn(data)
|
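To make the two fetch paths concrete: with a batch_size the map-style fetcher collects a list of samples and collates them (auto_collation), while batch_size=None fetches one sample per index without batching. A minimal illustration through the public DataLoader API:

# Hedged illustration of the batched vs. sample-wise fetcher paths.
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(6, dtype=torch.float32))

batched = DataLoader(dataset, batch_size=2)         # auto_collation: list of samples, then collate
sample_wise = DataLoader(dataset, batch_size=None)  # one sample per fetch, no batching

print(next(iter(batched))[0].shape)      # torch.Size([2])
print(next(iter(sample_wise))[0].shape)  # torch.Size([])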
.venv/Lib/site-packages/torch/utils/data/_utils/pin_memory.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
r"""Contains definitions of the methods used by the _BaseDataLoaderIter to put fetched tensors into pinned memory.
|
| 3 |
+
|
| 4 |
+
These **need** to be in global scope since Py2 doesn't support serializing
|
| 5 |
+
static methods.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import collections
|
| 9 |
+
import copy
|
| 10 |
+
import queue
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from torch._utils import ExceptionWrapper
|
| 14 |
+
|
| 15 |
+
from . import MP_STATUS_CHECK_INTERVAL
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _pin_memory_loop(in_queue, out_queue, device_id, done_event, device):
|
| 19 |
+
# This setting is thread local, and prevents the copy in pin_memory from
|
| 20 |
+
# consuming all CPU cores.
|
| 21 |
+
torch.set_num_threads(1)
|
| 22 |
+
|
| 23 |
+
torch.multiprocessing._set_thread_name("pt_data_pin")
|
| 24 |
+
|
| 25 |
+
if device == "cuda":
|
| 26 |
+
torch.cuda.set_device(device_id)
|
| 27 |
+
elif device == "xpu":
|
| 28 |
+
torch.xpu.set_device(device_id) # type: ignore[attr-defined]
|
| 29 |
+
elif device == torch._C._get_privateuse1_backend_name():
|
| 30 |
+
custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name())
|
| 31 |
+
custom_device_mod.set_device(device_id)
|
| 32 |
+
|
| 33 |
+
def do_one_step():
|
| 34 |
+
try:
|
| 35 |
+
r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
|
| 36 |
+
except queue.Empty:
|
| 37 |
+
return
|
| 38 |
+
idx, data = r
|
| 39 |
+
if not done_event.is_set() and not isinstance(data, ExceptionWrapper):
|
| 40 |
+
try:
|
| 41 |
+
data = pin_memory(data, device)
|
| 42 |
+
except Exception:
|
| 43 |
+
data = ExceptionWrapper(
|
| 44 |
+
where=f"in pin memory thread for device {device_id}"
|
| 45 |
+
)
|
| 46 |
+
r = (idx, data)
|
| 47 |
+
while not done_event.is_set():
|
| 48 |
+
try:
|
| 49 |
+
out_queue.put(r, timeout=MP_STATUS_CHECK_INTERVAL)
|
| 50 |
+
break
|
| 51 |
+
except queue.Full:
|
| 52 |
+
continue
|
| 53 |
+
|
| 54 |
+
# See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
|
| 55 |
+
# logic of this function.
|
| 56 |
+
while not done_event.is_set():
|
| 57 |
+
# Make sure that we don't preserve any object from one iteration
|
| 58 |
+
# to the next
|
| 59 |
+
do_one_step()
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def pin_memory(data, device=None):
|
| 63 |
+
if isinstance(data, torch.Tensor):
|
| 64 |
+
return data.pin_memory(device)
|
| 65 |
+
elif isinstance(data, (str, bytes)):
|
| 66 |
+
return data
|
| 67 |
+
elif isinstance(data, collections.abc.Mapping):
|
| 68 |
+
try:
|
| 69 |
+
if isinstance(data, collections.abc.MutableMapping):
|
| 70 |
+
# The sequence type may have extra properties, so we can't just
|
| 71 |
+
# use `type(data)(...)` to create the new sequence.
|
| 72 |
+
# Create a clone and update it if the sequence type is mutable.
|
| 73 |
+
clone = copy.copy(data)
|
| 74 |
+
clone.update(
|
| 75 |
+
{k: pin_memory(sample, device) for k, sample in data.items()}
|
| 76 |
+
)
|
| 77 |
+
return clone
|
| 78 |
+
else:
|
| 79 |
+
return type(data)({k: pin_memory(sample, device) for k, sample in data.items()}) # type: ignore[call-arg]
|
| 80 |
+
except TypeError:
|
| 81 |
+
# The mapping type may not support `copy()` / `update(mapping)`
|
| 82 |
+
# or `__init__(iterable)`.
|
| 83 |
+
return {k: pin_memory(sample, device) for k, sample in data.items()}
|
| 84 |
+
elif isinstance(data, tuple) and hasattr(data, "_fields"): # namedtuple
|
| 85 |
+
return type(data)(*(pin_memory(sample, device) for sample in data))
|
| 86 |
+
elif isinstance(data, tuple):
|
| 87 |
+
return [
|
| 88 |
+
pin_memory(sample, device) for sample in data
|
| 89 |
+
] # Backwards compatibility.
|
| 90 |
+
elif isinstance(data, collections.abc.Sequence):
|
| 91 |
+
try:
|
| 92 |
+
if isinstance(data, collections.abc.MutableSequence):
|
| 93 |
+
# The sequence type may have extra properties, so we can't just
|
| 94 |
+
# use `type(data)(...)` to create the new sequence.
|
| 95 |
+
# Create a clone and update it if the sequence type is mutable.
|
| 96 |
+
clone = copy.copy(data) # type: ignore[arg-type]
|
| 97 |
+
for i, item in enumerate(data):
|
| 98 |
+
clone[i] = pin_memory(item, device)
|
| 99 |
+
return clone
|
| 100 |
+
return type(data)([pin_memory(sample, device) for sample in data]) # type: ignore[call-arg]
|
| 101 |
+
except TypeError:
|
| 102 |
+
# The sequence type may not support `copy()` / `__setitem__(index, item)`
|
| 103 |
+
# or `__init__(iterable)` (e.g., `range`).
|
| 104 |
+
return [pin_memory(sample, device) for sample in data]
|
| 105 |
+
elif hasattr(data, "pin_memory"):
|
| 106 |
+
return data.pin_memory()
|
| 107 |
+
else:
|
| 108 |
+
return data
|
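The pin_memory helper above is normally reached by passing pin_memory=True to DataLoader; pinning only takes effect on a CUDA-capable build, hence the guard in this sketch:

# Hedged sketch: pinned-memory batches enable asynchronous host-to-GPU copies.
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(32, 3))
loader = DataLoader(dataset, batch_size=8, pin_memory=True)

(batch,) = next(iter(loader))
if torch.cuda.is_available():
    assert batch.is_pinned()
    gpu_batch = batch.to("cuda", non_blocking=True)  # copy can overlap with compute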
.venv/Lib/site-packages/torch/utils/data/_utils/signal_handling.py
ADDED
|
@@ -0,0 +1,79 @@
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
r"""Signal handling for multiprocessing data loading.
|
| 3 |
+
|
| 4 |
+
NOTE [ Signal handling in multiprocessing data loading ]
|
| 5 |
+
|
| 6 |
+
In cases like DataLoader, if a worker process dies due to bus error/segfault
|
| 7 |
+
or just hangs, the main process will hang waiting for data. This is difficult
|
| 8 |
+
to avoid on PyTorch side as it can be caused by limited shm, or other
|
| 9 |
+
libraries users call in the workers. In this file and `DataLoader.cpp`, we make
|
| 10 |
+
our best effort to provide some error message to users when such unfortunate
|
| 11 |
+
events happen.
|
| 12 |
+
|
| 13 |
+
When a _BaseDataLoaderIter starts worker processes, their pids are registered in a map
|
| 14 |
+
defined in `DataLoader.cpp`: id(_BaseDataLoaderIter) => Collection[ Worker pids ]
|
| 15 |
+
via `_set_worker_pids`.
|
| 16 |
+
|
| 17 |
+
When an error happens in a worker process, the main process receives a SIGCHLD,
|
| 18 |
+
and Python will eventually call the handler registered below
|
| 19 |
+
(in `_set_SIGCHLD_handler`). In the handler, the `_error_if_any_worker_fails`
|
| 20 |
+
call checks all registered worker pids and raises a proper error message to
|
| 21 |
+
prevent main process from hanging waiting for data from worker.
|
| 22 |
+
|
| 23 |
+
Additionally, at the beginning of each worker's `_utils.worker._worker_loop`,
|
| 24 |
+
`_set_worker_signal_handlers` is called to register critical signal handlers
|
| 25 |
+
(e.g., for SIGSEGV, SIGBUS, SIGFPE, SIGTERM) in C, which just prints an error
|
| 26 |
+
message to stderr before triggering the default handler. So a message will also
|
| 27 |
+
be printed from the worker process when it is killed by such signals.
|
| 28 |
+
|
| 29 |
+
See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for the reasoning of
|
| 30 |
+
this signal handling design and other mechanism we implement to make our
|
| 31 |
+
multiprocessing data loading robust to errors.
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
import signal
|
| 35 |
+
import threading
|
| 36 |
+
|
| 37 |
+
# Some of the following imported functions are not used in this file, but are to
|
| 38 |
+
# be used `_utils.signal_handling.XXXXX`.
|
| 39 |
+
from torch._C import ( # noqa: F401
|
| 40 |
+
_error_if_any_worker_fails,
|
| 41 |
+
_remove_worker_pids,
|
| 42 |
+
_set_worker_pids,
|
| 43 |
+
_set_worker_signal_handlers,
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
from . import IS_WINDOWS
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
_SIGCHLD_handler_set = False
|
| 50 |
+
r"""Whether SIGCHLD handler is set for DataLoader worker failures. Only one
|
| 51 |
+
handler needs to be set for all DataLoaders in a process."""
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _set_SIGCHLD_handler():
|
| 55 |
+
# Windows doesn't support SIGCHLD handler
|
| 56 |
+
if IS_WINDOWS:
|
| 57 |
+
return
|
| 58 |
+
# can't set signal in child threads
|
| 59 |
+
if not isinstance(threading.current_thread(), threading._MainThread): # type: ignore[attr-defined]
|
| 60 |
+
return
|
| 61 |
+
global _SIGCHLD_handler_set
|
| 62 |
+
if _SIGCHLD_handler_set:
|
| 63 |
+
return
|
| 64 |
+
previous_handler = signal.getsignal(signal.SIGCHLD)
|
| 65 |
+
if not callable(previous_handler):
|
| 66 |
+
# This doesn't catch default handler, but SIGCHLD default handler is a
|
| 67 |
+
# no-op.
|
| 68 |
+
previous_handler = None
|
| 69 |
+
|
| 70 |
+
def handler(signum, frame):
|
| 71 |
+
# This following call uses `waitid` with WNOHANG from C side. Therefore,
|
| 72 |
+
# Python can still get and update the process status successfully.
|
| 73 |
+
_error_if_any_worker_fails()
|
| 74 |
+
if previous_handler is not None:
|
| 75 |
+
assert callable(previous_handler)
|
| 76 |
+
previous_handler(signum, frame)
|
| 77 |
+
|
| 78 |
+
signal.signal(signal.SIGCHLD, handler)
|
| 79 |
+
_SIGCHLD_handler_set = True
|
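The chaining behaviour described in the note (run the worker-failure check, then defer to any previously installed handler) follows a standard Python signal pattern; a generic, non-PyTorch sketch of that pattern, with the C-side check replaced by a stub:

# Hedged, generic sketch of chaining a SIGCHLD handler (POSIX only; not the PyTorch internals).
import signal

def _check_children():
    # Stand-in for _error_if_any_worker_fails(); a real check would inspect child pids.
    pass

_previous = signal.getsignal(signal.SIGCHLD)
if not callable(_previous):
    _previous = None  # the default SIGCHLD handler is a no-op, so it is safe to drop

def _handler(signum, frame):
    _check_children()
    if _previous is not None:
        _previous(signum, frame)

signal.signal(signal.SIGCHLD, _handler)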
.venv/Lib/site-packages/torch/utils/data/_utils/worker.py
ADDED
|
@@ -0,0 +1,376 @@
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers.
|
| 3 |
+
|
| 4 |
+
These **need** to be in global scope since Py2 doesn't support serializing
|
| 5 |
+
static methods.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import queue
|
| 10 |
+
import random
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
from typing import Optional, TYPE_CHECKING, Union
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
from torch._utils import ExceptionWrapper
|
| 16 |
+
|
| 17 |
+
from . import HAS_NUMPY, IS_WINDOWS, MP_STATUS_CHECK_INTERVAL, signal_handling
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from torch.utils.data import Dataset
|
| 22 |
+
|
| 23 |
+
if IS_WINDOWS:
|
| 24 |
+
import ctypes
|
| 25 |
+
from ctypes.wintypes import BOOL, DWORD, HANDLE
|
| 26 |
+
|
| 27 |
+
# On Windows, the parent ID of the worker process remains unchanged when the manager process
|
| 28 |
+
# is gone, and the only way to check it through OS is to let the worker have a process handle
|
| 29 |
+
# of the manager and ask if the process status has changed.
|
| 30 |
+
class ManagerWatchdog:
|
| 31 |
+
def __init__(self) -> None:
|
| 32 |
+
self.manager_pid = os.getppid()
|
| 33 |
+
|
| 34 |
+
# mypy cannot detect this code is windows only
|
| 35 |
+
self.kernel32 = ctypes.WinDLL("kernel32", use_last_error=True) # type: ignore[attr-defined]
|
| 36 |
+
self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD)
|
| 37 |
+
self.kernel32.OpenProcess.restype = HANDLE
|
| 38 |
+
self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD)
|
| 39 |
+
self.kernel32.WaitForSingleObject.restype = DWORD
|
| 40 |
+
|
| 41 |
+
# Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx
|
| 42 |
+
SYNCHRONIZE = 0x00100000
|
| 43 |
+
self.manager_handle = self.kernel32.OpenProcess(
|
| 44 |
+
SYNCHRONIZE, 0, self.manager_pid
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
if not self.manager_handle:
|
| 48 |
+
raise ctypes.WinError(ctypes.get_last_error()) # type: ignore[attr-defined]
|
| 49 |
+
|
| 50 |
+
self.manager_dead = False
|
| 51 |
+
|
| 52 |
+
def is_alive(self):
|
| 53 |
+
if not self.manager_dead:
|
| 54 |
+
# Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
|
| 55 |
+
self.manager_dead = (
|
| 56 |
+
self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0
|
| 57 |
+
)
|
| 58 |
+
return not self.manager_dead
|
| 59 |
+
|
| 60 |
+
else:
|
| 61 |
+
|
| 62 |
+
class ManagerWatchdog: # type: ignore[no-redef]
|
| 63 |
+
def __init__(self) -> None:
|
| 64 |
+
self.manager_pid = os.getppid()
|
| 65 |
+
self.manager_dead = False
|
| 66 |
+
|
| 67 |
+
def is_alive(self):
|
| 68 |
+
if not self.manager_dead:
|
| 69 |
+
self.manager_dead = os.getppid() != self.manager_pid
|
| 70 |
+
return not self.manager_dead
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
_worker_info: Optional["WorkerInfo"] = None
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class WorkerInfo:
|
| 77 |
+
id: int
|
| 78 |
+
num_workers: int
|
| 79 |
+
seed: int
|
| 80 |
+
dataset: "Dataset"
|
| 81 |
+
__initialized = False
|
| 82 |
+
|
| 83 |
+
    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)
        self.__keys = tuple(kwargs.keys())
        self.__initialized = True

    def __setattr__(self, key, val):
        if self.__initialized:
            raise RuntimeError(
                f"Cannot assign attributes to {self.__class__.__name__} objects"
            )
        return super().__setattr__(key, val)

    def __repr__(self):
        items = []
        for k in self.__keys:
            items.append(f"{k}={getattr(self, k)}")
        return f"{self.__class__.__name__}({', '.join(items)})"


def get_worker_info() -> Optional[WorkerInfo]:
    r"""Returns the information about the current
    :class:`~torch.utils.data.DataLoader` iterator worker process.

    When called in a worker, this returns an object guaranteed to have the
    following attributes:

    * :attr:`id`: the current worker id.
    * :attr:`num_workers`: the total number of workers.
    * :attr:`seed`: the random seed set for the current worker. This value is
      determined by main process RNG and the worker id. See
      :class:`~torch.utils.data.DataLoader`'s documentation for more details.
    * :attr:`dataset`: the copy of the dataset object in **this** process. Note
      that this will be a different object in a different process than the one
      in the main process.

    When called in the main process, this returns ``None``.

    .. note::
        When used in a :attr:`worker_init_fn` passed over to
        :class:`~torch.utils.data.DataLoader`, this method can be useful to
        set up each worker process differently, for instance, using ``worker_id``
        to configure the ``dataset`` object to only read a specific fraction of a
        sharded dataset, or use ``seed`` to seed other libraries used in dataset
        code.
    """
    return _worker_info


r"""Dummy class used to signal the end of an IterableDataset"""


@dataclass(frozen=True)
class _IterableDatasetStopIteration:
    worker_id: int


r"""Dummy class used to resume the fetching when worker reuse is enabled"""


@dataclass(frozen=True)
class _ResumeIteration:
    seed: Optional[int] = None


# The function `_generate_state` is adapted from `numpy.random.SeedSequence`
# from https://github.com/numpy/numpy/blob/main/numpy/random/bit_generator.pyx
# It's MIT licensed, here is the copyright:

# Copyright (c) 2015 Melissa E. O'Neill
# Copyright (c) 2019 NumPy Developers
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.


# This function generates an array of int32 as the seed for
# `numpy.random`, in order to prevent state collision due to same
# seed and algorithm for `numpy.random` and `random` modules.
# TODO: Implement `SeedSequence` like object for `torch.random`
def _generate_state(base_seed, worker_id):
    INIT_A = 0x43B0D7E5
    MULT_A = 0x931E8875
    INIT_B = 0x8B51F9DD
    MULT_B = 0x58F38DED
    MIX_MULT_L = 0xCA01F9DD
    MIX_MULT_R = 0x4973F715
    XSHIFT = 4 * 8 // 2
    MASK32 = 0xFFFFFFFF

    entropy = [worker_id, base_seed & MASK32, base_seed >> 32, 0]
    pool = [0] * 4

    hash_const_A = INIT_A

    def hash(value):
        nonlocal hash_const_A
        value = (value ^ hash_const_A) & MASK32
        hash_const_A = (hash_const_A * MULT_A) & MASK32
        value = (value * hash_const_A) & MASK32
        value = (value ^ (value >> XSHIFT)) & MASK32
        return value

    def mix(x, y):
        result_x = (MIX_MULT_L * x) & MASK32
        result_y = (MIX_MULT_R * y) & MASK32
        result = (result_x - result_y) & MASK32
        result = (result ^ (result >> XSHIFT)) & MASK32
        return result

    # Add in the entropy to the pool.
    for i in range(len(pool)):
        pool[i] = hash(entropy[i])

    # Mix all bits together so late bits can affect earlier bits.
    for i_src in range(len(pool)):
        for i_dst in range(len(pool)):
            if i_src != i_dst:
                pool[i_dst] = mix(pool[i_dst], hash(pool[i_src]))

    hash_const_B = INIT_B
    state = []
    for i_dst in range(4):
        data_val = pool[i_dst]
        data_val = (data_val ^ hash_const_B) & MASK32
        hash_const_B = (hash_const_B * MULT_B) & MASK32
        data_val = (data_val * hash_const_B) & MASK32
        data_val = (data_val ^ (data_val >> XSHIFT)) & MASK32
        state.append(data_val)
    return state


def _worker_loop(
    dataset_kind,
    dataset,
    index_queue,
    data_queue,
    done_event,
    auto_collation,
    collate_fn,
    drop_last,
    base_seed,
    init_fn,
    worker_id,
    num_workers,
    persistent_workers,
    shared_seed,
):
    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
    # logic of this function.

    try:
        # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
        # module's handlers are executed after Python returns from C low-level
        # handlers, likely when the same fatal signal had already happened
        # again.
        # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers
        signal_handling._set_worker_signal_handlers()

        torch.multiprocessing._set_thread_name("pt_data_worker")

        torch.set_num_threads(1)
        seed = base_seed + worker_id
        random.seed(seed)
        torch.manual_seed(seed)
        if HAS_NUMPY:
            np_seed = _generate_state(base_seed, worker_id)
            import numpy as np

            np.random.seed(np_seed)

        from torch.utils.data import IterDataPipe
        from torch.utils.data.graph_settings import apply_random_seed

        shared_rng = torch.Generator()
        if isinstance(dataset, IterDataPipe):
            assert shared_seed is not None
            shared_rng.manual_seed(shared_seed)
            dataset = apply_random_seed(dataset, shared_rng)

        global _worker_info
        _worker_info = WorkerInfo(
            id=worker_id, num_workers=num_workers, seed=seed, dataset=dataset
        )

        from torch.utils.data import _DatasetKind

        init_exception = None

        try:
            if init_fn is not None:
                init_fn(worker_id)

            fetcher = _DatasetKind.create_fetcher(
                dataset_kind, dataset, auto_collation, collate_fn, drop_last
            )
        except Exception:
            init_exception = ExceptionWrapper(
                where=f"in DataLoader worker process {worker_id}"
            )

        # When using Iterable mode, some worker can exit earlier than others due
        # to the IterableDataset behaving differently for different workers.
        # When such things happen, an `_IterableDatasetStopIteration` object is
        # sent over to the main process with the ID of this worker, so that the
        # main process won't send more tasks to this worker, and will send
        # `None` to this worker to properly exit it.
        #
        # Note that we cannot set `done_event` from a worker as it is shared
        # among all processes. Instead, we set the `iteration_end` flag to
        # signify that the iterator is exhausted. When either `done_event` or
        # `iteration_end` is set, we skip all processing step and just wait for
        # `None`.
        iteration_end = False

        watchdog = ManagerWatchdog()

        while watchdog.is_alive():
            try:
                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                continue
            if isinstance(r, _ResumeIteration):
                # Acknowledge the main process
                data_queue.put((r, None))
                iteration_end = False

                if isinstance(dataset, IterDataPipe):
                    assert r.seed is not None
                    shared_rng.manual_seed(r.seed)
                    dataset = apply_random_seed(dataset, shared_rng)

                # Recreate the fetcher for worker-reuse policy
                fetcher = _DatasetKind.create_fetcher(
                    dataset_kind, dataset, auto_collation, collate_fn, drop_last
                )
                continue
            elif r is None:
                # Received the final signal
                assert done_event.is_set() or iteration_end
                break
            elif done_event.is_set() or iteration_end:
                # `done_event` is set. But I haven't received the final signal
                # (None) yet. I will keep continuing until get it, and skip the
                # processing steps.
                continue
            idx, index = r
            data: Union[_IterableDatasetStopIteration, ExceptionWrapper]
            if init_exception is not None:
                data = init_exception
                init_exception = None
            else:
                try:
                    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
                except Exception as e:
                    if (
                        isinstance(e, StopIteration)
                        and dataset_kind == _DatasetKind.Iterable
                    ):
                        data = _IterableDatasetStopIteration(worker_id)
                        # Set `iteration_end`
                        #   (1) to save future `next(...)` calls, and
                        #   (2) to avoid sending multiple `_IterableDatasetStopIteration`s.
                        iteration_end = True
                    else:
                        # It is important that we don't store exc_info in a variable.
                        # `ExceptionWrapper` does the correct thing.
                        # See NOTE [ Python Traceback Reference Cycle Problem ]
                        data = ExceptionWrapper(
                            where=f"in DataLoader worker process {worker_id}"
                        )
            data_queue.put((idx, data))
            del data, idx, index, r  # save memory
    except KeyboardInterrupt:
        # Main process will raise KeyboardInterrupt anyways.
        pass
    if done_event.is_set():
        data_queue.cancel_join_thread()
        data_queue.close()
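The docstring of `get_worker_info` above describes using a `worker_init_fn` to give each worker its own slice of a sharded dataset. The snippet below is a minimal sketch of that pattern, not part of the file above; `RangeDataset` and its integer `start`/`end` bounds are hypothetical names invented for illustration.

import math
import torch
from torch.utils.data import DataLoader, IterableDataset, get_worker_info

class RangeDataset(IterableDataset):
    # Hypothetical dataset that yields the integers in [start, end).
    def __init__(self, start, end):
        self.start, self.end = start, end

    def __iter__(self):
        return iter(range(self.start, self.end))

def worker_init_fn(worker_id):
    # Runs once in every worker process, after seeding and before loading data.
    info = get_worker_info()
    assert info is not None          # only ever called inside a worker
    dataset = info.dataset           # the copy of the dataset living in *this* worker
    overall_start, overall_end = dataset.start, dataset.end
    # Give each worker a contiguous shard so no sample is read twice.
    per_worker = int(math.ceil((overall_end - overall_start) / float(info.num_workers)))
    dataset.start = overall_start + info.id * per_worker
    dataset.end = min(dataset.start + per_worker, overall_end)

if __name__ == "__main__":
    loader = DataLoader(RangeDataset(0, 10), num_workers=2, worker_init_fn=worker_init_fn)
    print(list(loader))  # each element appears exactly once across the two workers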
.venv/Lib/site-packages/torch/utils/data/backward_compatibility.py
ADDED
@@ -0,0 +1,11 @@
# mypy: allow-untyped-defs
from typing_extensions import deprecated as _deprecated


@_deprecated(
    "Usage of `backward_compatibility.worker_init_fn` is deprecated "
    "as `DataLoader` automatically applies sharding in every worker",
    category=FutureWarning,
)
def worker_init_fn(worker_id):
    pass
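The next file, dataloader.py, defines the `DataLoader` class whose docstring documents `batch_size`, `shuffle`, `num_workers`, `persistent_workers` and the other constructor arguments. Below is a minimal usage sketch tying those arguments together; the toy tensors and hyperparameter values are illustrative only.

import torch
from torch.utils.data import DataLoader, TensorDataset

if __name__ == "__main__":
    # Illustrative toy data; any map-style Dataset works the same way.
    features = torch.randn(100, 3)
    labels = torch.randint(0, 2, (100,))
    dataset = TensorDataset(features, labels)

    loader = DataLoader(
        dataset,
        batch_size=8,              # auto-batching via the default BatchSampler
        shuffle=True,              # a RandomSampler is created internally
        num_workers=2,             # worker processes run _worker_loop from _utils/worker.py
        persistent_workers=True,   # keep the workers alive between epochs
    )

    for epoch in range(2):
        for batch_features, batch_labels in loader:
            pass  # a training step would go here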
.venv/Lib/site-packages/torch/utils/data/dataloader.py
ADDED
@@ -0,0 +1,1604 @@
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
r"""Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter.
|
| 3 |
+
|
| 4 |
+
To support these two classes, in `./_utils` we define many utility methods and
|
| 5 |
+
functions to be run in multiprocessing. E.g., the data loading worker loop is
|
| 6 |
+
in `./_utils/worker.py`.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import functools
|
| 10 |
+
import itertools
|
| 11 |
+
import logging
|
| 12 |
+
import multiprocessing as python_multiprocessing
|
| 13 |
+
import os
|
| 14 |
+
import queue
|
| 15 |
+
import threading
|
| 16 |
+
import warnings
|
| 17 |
+
from typing import Any, Callable, Generic, Iterable, List, Optional, TypeVar, Union
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.distributed as dist
|
| 21 |
+
import torch.utils.data.graph_settings
|
| 22 |
+
from torch._utils import ExceptionWrapper
|
| 23 |
+
from torch.utils.data import _utils
|
| 24 |
+
from torch.utils.data.datapipes.datapipe import (
|
| 25 |
+
_IterDataPipeSerializationWrapper,
|
| 26 |
+
_MapDataPipeSerializationWrapper,
|
| 27 |
+
IterDataPipe,
|
| 28 |
+
MapDataPipe,
|
| 29 |
+
)
|
| 30 |
+
from torch.utils.data.dataset import Dataset, IterableDataset
|
| 31 |
+
from torch.utils.data.sampler import (
|
| 32 |
+
BatchSampler,
|
| 33 |
+
RandomSampler,
|
| 34 |
+
Sampler,
|
| 35 |
+
SequentialSampler,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
__all__ = [
|
| 40 |
+
"DataLoader",
|
| 41 |
+
"get_worker_info",
|
| 42 |
+
"default_collate",
|
| 43 |
+
"default_convert",
|
| 44 |
+
]
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
_T = TypeVar("_T")
|
| 48 |
+
_T_co = TypeVar("_T_co", covariant=True)
|
| 49 |
+
_worker_init_fn_t = Callable[[int], None]
|
| 50 |
+
|
| 51 |
+
# Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that
|
| 52 |
+
# type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'.
|
| 53 |
+
# See https://github.com/python/mypy/issues/3737.
|
| 54 |
+
_collate_fn_t = Callable[[List[_T]], Any]
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# These functions used to be defined in this file. However, it was moved to
|
| 58 |
+
# _utils/collate.py. Although it is rather hard to access this from user land
|
| 59 |
+
# (one has to explicitly directly `import torch.utils.data.dataloader`), there
|
| 60 |
+
# probably is user code out there using it. This aliasing maintains BC in this
|
| 61 |
+
# aspect.
|
| 62 |
+
default_collate: _collate_fn_t = _utils.collate.default_collate
|
| 63 |
+
default_convert = _utils.collate.default_convert
|
| 64 |
+
|
| 65 |
+
get_worker_info = _utils.worker.get_worker_info
|
| 66 |
+
|
| 67 |
+
logger = logging.getLogger(__name__)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class _DatasetKind:
|
| 71 |
+
Map = 0
|
| 72 |
+
Iterable = 1
|
| 73 |
+
|
| 74 |
+
@staticmethod
|
| 75 |
+
def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
|
| 76 |
+
if kind == _DatasetKind.Map:
|
| 77 |
+
return _utils.fetch._MapDatasetFetcher(
|
| 78 |
+
dataset, auto_collation, collate_fn, drop_last
|
| 79 |
+
)
|
| 80 |
+
else:
|
| 81 |
+
return _utils.fetch._IterableDatasetFetcher(
|
| 82 |
+
dataset, auto_collation, collate_fn, drop_last
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class _InfiniteConstantSampler(Sampler):
|
| 87 |
+
r"""Analogous to ``itertools.repeat(None, None)``.
|
| 88 |
+
|
| 89 |
+
Used as sampler for :class:`~torch.utils.data.IterableDataset`.
|
| 90 |
+
"""
|
| 91 |
+
|
| 92 |
+
def __iter__(self):
|
| 93 |
+
while True:
|
| 94 |
+
yield None
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def _get_distributed_settings():
|
| 98 |
+
if dist.is_available() and dist.is_initialized():
|
| 99 |
+
return dist.get_world_size(), dist.get_rank()
|
| 100 |
+
else:
|
| 101 |
+
return 1, 0
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def _sharding_worker_init_fn(worker_init_fn, world_size, rank_id, worker_id):
|
| 105 |
+
global_worker_id = worker_id
|
| 106 |
+
info = torch.utils.data.get_worker_info()
|
| 107 |
+
assert info is not None
|
| 108 |
+
total_workers = info.num_workers
|
| 109 |
+
datapipe = info.dataset
|
| 110 |
+
assert isinstance(datapipe, (IterDataPipe, MapDataPipe))
|
| 111 |
+
# To distribute elements across distributed process evenly, we should shard data on distributed
|
| 112 |
+
# processes first then shard on worker processes
|
| 113 |
+
total_workers *= world_size
|
| 114 |
+
global_worker_id = global_worker_id * world_size + rank_id
|
| 115 |
+
# For BC, use default SHARDING_PRIORITIES
|
| 116 |
+
torch.utils.data.graph_settings.apply_sharding(
|
| 117 |
+
datapipe, total_workers, global_worker_id
|
| 118 |
+
)
|
| 119 |
+
if worker_init_fn is not None:
|
| 120 |
+
worker_init_fn(worker_id)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def _share_dist_seed(generator, pg):
|
| 124 |
+
_shared_seed = torch.empty((), dtype=torch.int64).random_(generator=generator)
|
| 125 |
+
if isinstance(pg, dist.ProcessGroup):
|
| 126 |
+
dist.broadcast(_shared_seed, src=0, group=pg)
|
| 127 |
+
return _shared_seed.item()
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class DataLoader(Generic[_T_co]):
|
| 131 |
+
r"""
|
| 132 |
+
Data loader combines a dataset and a sampler, and provides an iterable over the given dataset.
|
| 133 |
+
|
| 134 |
+
The :class:`~torch.utils.data.DataLoader` supports both map-style and
|
| 135 |
+
iterable-style datasets with single- or multi-process loading, customizing
|
| 136 |
+
loading order and optional automatic batching (collation) and memory pinning.
|
| 137 |
+
|
| 138 |
+
See :py:mod:`torch.utils.data` documentation page for more details.
|
| 139 |
+
|
| 140 |
+
Args:
|
| 141 |
+
dataset (Dataset): dataset from which to load the data.
|
| 142 |
+
batch_size (int, optional): how many samples per batch to load
|
| 143 |
+
(default: ``1``).
|
| 144 |
+
shuffle (bool, optional): set to ``True`` to have the data reshuffled
|
| 145 |
+
at every epoch (default: ``False``).
|
| 146 |
+
sampler (Sampler or Iterable, optional): defines the strategy to draw
|
| 147 |
+
samples from the dataset. Can be any ``Iterable`` with ``__len__``
|
| 148 |
+
implemented. If specified, :attr:`shuffle` must not be specified.
|
| 149 |
+
batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but
|
| 150 |
+
returns a batch of indices at a time. Mutually exclusive with
|
| 151 |
+
:attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`,
|
| 152 |
+
and :attr:`drop_last`.
|
| 153 |
+
num_workers (int, optional): how many subprocesses to use for data
|
| 154 |
+
loading. ``0`` means that the data will be loaded in the main process.
|
| 155 |
+
(default: ``0``)
|
| 156 |
+
collate_fn (Callable, optional): merges a list of samples to form a
|
| 157 |
+
mini-batch of Tensor(s). Used when using batched loading from a
|
| 158 |
+
map-style dataset.
|
| 159 |
+
pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
|
| 160 |
+
into device/CUDA pinned memory before returning them. If your data elements
|
| 161 |
+
are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
|
| 162 |
+
see the example below.
|
| 163 |
+
drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
|
| 164 |
+
if the dataset size is not divisible by the batch size. If ``False`` and
|
| 165 |
+
the size of dataset is not divisible by the batch size, then the last batch
|
| 166 |
+
will be smaller. (default: ``False``)
|
| 167 |
+
timeout (numeric, optional): if positive, the timeout value for collecting a batch
|
| 168 |
+
from workers. Should always be non-negative. (default: ``0``)
|
| 169 |
+
worker_init_fn (Callable, optional): If not ``None``, this will be called on each
|
| 170 |
+
worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
|
| 171 |
+
input, after seeding and before data loading. (default: ``None``)
|
| 172 |
+
multiprocessing_context (str or multiprocessing.context.BaseContext, optional): If
|
| 173 |
+
``None``, the default `multiprocessing context`_ of your operating system will
|
| 174 |
+
be used. (default: ``None``)
|
| 175 |
+
generator (torch.Generator, optional): If not ``None``, this RNG will be used
|
| 176 |
+
by RandomSampler to generate random indexes and multiprocessing to generate
|
| 177 |
+
``base_seed`` for workers. (default: ``None``)
|
| 178 |
+
prefetch_factor (int, optional, keyword-only arg): Number of batches loaded
|
| 179 |
+
in advance by each worker. ``2`` means there will be a total of
|
| 180 |
+
2 * num_workers batches prefetched across all workers. (default value depends
|
| 181 |
+
on the set value for num_workers. If value of num_workers=0 default is ``None``.
|
| 182 |
+
Otherwise, if value of ``num_workers > 0`` default is ``2``).
|
| 183 |
+
persistent_workers (bool, optional): If ``True``, the data loader will not shut down
|
| 184 |
+
the worker processes after a dataset has been consumed once. This allows to
|
| 185 |
+
maintain the workers `Dataset` instances alive. (default: ``False``)
|
| 186 |
+
pin_memory_device (str, optional): the device to :attr:`pin_memory` to if ``pin_memory`` is
|
| 187 |
+
``True``.
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
.. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
|
| 191 |
+
cannot be an unpicklable object, e.g., a lambda function. See
|
| 192 |
+
:ref:`multiprocessing-best-practices` on more details related
|
| 193 |
+
to multiprocessing in PyTorch.
|
| 194 |
+
|
| 195 |
+
.. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used.
|
| 196 |
+
When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`,
|
| 197 |
+
it instead returns an estimate based on ``len(dataset) / batch_size``, with proper
|
| 198 |
+
rounding depending on :attr:`drop_last`, regardless of multi-process loading
|
| 199 |
+
configurations. This represents the best guess PyTorch can make because PyTorch
|
| 200 |
+
trusts user :attr:`dataset` code in correctly handling multi-process
|
| 201 |
+
loading to avoid duplicate data.
|
| 202 |
+
|
| 203 |
+
However, if sharding results in multiple workers having incomplete last batches,
|
| 204 |
+
this estimate can still be inaccurate, because (1) an otherwise complete batch can
|
| 205 |
+
be broken into multiple ones and (2) more than one batch worth of samples can be
|
| 206 |
+
dropped when :attr:`drop_last` is set. Unfortunately, PyTorch can not detect such
|
| 207 |
+
cases in general.
|
| 208 |
+
|
| 209 |
+
See `Dataset Types`_ for more details on these two types of datasets and how
|
| 210 |
+
:class:`~torch.utils.data.IterableDataset` interacts with
|
| 211 |
+
`Multi-process data loading`_.
|
| 212 |
+
|
| 213 |
+
.. warning:: See :ref:`reproducibility`, and :ref:`dataloader-workers-random-seed`, and
|
| 214 |
+
:ref:`data-loading-randomness` notes for random seed related questions.
|
| 215 |
+
|
| 216 |
+
.. _multiprocessing context:
|
| 217 |
+
https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
|
| 218 |
+
"""
|
| 219 |
+
|
| 220 |
+
dataset: Dataset[_T_co]
|
| 221 |
+
batch_size: Optional[int]
|
| 222 |
+
num_workers: int
|
| 223 |
+
pin_memory: bool
|
| 224 |
+
drop_last: bool
|
| 225 |
+
timeout: float
|
| 226 |
+
sampler: Union[Sampler, Iterable]
|
| 227 |
+
pin_memory_device: str
|
| 228 |
+
prefetch_factor: Optional[int]
|
| 229 |
+
_iterator: Optional["_BaseDataLoaderIter"]
|
| 230 |
+
__initialized = False
|
| 231 |
+
|
| 232 |
+
def __init__(
|
| 233 |
+
self,
|
| 234 |
+
dataset: Dataset[_T_co],
|
| 235 |
+
batch_size: Optional[int] = 1,
|
| 236 |
+
shuffle: Optional[bool] = None,
|
| 237 |
+
sampler: Union[Sampler, Iterable, None] = None,
|
| 238 |
+
batch_sampler: Union[Sampler[List], Iterable[List], None] = None,
|
| 239 |
+
num_workers: int = 0,
|
| 240 |
+
collate_fn: Optional[_collate_fn_t] = None,
|
| 241 |
+
pin_memory: bool = False,
|
| 242 |
+
drop_last: bool = False,
|
| 243 |
+
timeout: float = 0,
|
| 244 |
+
worker_init_fn: Optional[_worker_init_fn_t] = None,
|
| 245 |
+
multiprocessing_context=None,
|
| 246 |
+
generator=None,
|
| 247 |
+
*,
|
| 248 |
+
prefetch_factor: Optional[int] = None,
|
| 249 |
+
persistent_workers: bool = False,
|
| 250 |
+
pin_memory_device: str = "",
|
| 251 |
+
):
|
| 252 |
+
torch._C._log_api_usage_once("python.data_loader")
|
| 253 |
+
|
| 254 |
+
if num_workers < 0:
|
| 255 |
+
raise ValueError(
|
| 256 |
+
"num_workers option should be non-negative; "
|
| 257 |
+
"use num_workers=0 to disable multiprocessing."
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
if timeout < 0:
|
| 261 |
+
raise ValueError("timeout option should be non-negative")
|
| 262 |
+
|
| 263 |
+
if num_workers == 0 and prefetch_factor is not None:
|
| 264 |
+
raise ValueError(
|
| 265 |
+
"prefetch_factor option could only be specified in multiprocessing."
|
| 266 |
+
"let num_workers > 0 to enable multiprocessing, otherwise set prefetch_factor to None."
|
| 267 |
+
)
|
| 268 |
+
elif num_workers > 0 and prefetch_factor is None:
|
| 269 |
+
prefetch_factor = 2
|
| 270 |
+
elif prefetch_factor is not None and prefetch_factor < 0:
|
| 271 |
+
raise ValueError("prefetch_factor option should be non-negative")
|
| 272 |
+
|
| 273 |
+
if persistent_workers and num_workers == 0:
|
| 274 |
+
raise ValueError("persistent_workers option needs num_workers > 0")
|
| 275 |
+
|
| 276 |
+
self.dataset = dataset
|
| 277 |
+
self.num_workers = num_workers
|
| 278 |
+
self.prefetch_factor = prefetch_factor
|
| 279 |
+
self.pin_memory = pin_memory
|
| 280 |
+
self.pin_memory_device = pin_memory_device
|
| 281 |
+
self.timeout = timeout
|
| 282 |
+
self.worker_init_fn = worker_init_fn
|
| 283 |
+
self.multiprocessing_context = multiprocessing_context
|
| 284 |
+
|
| 285 |
+
# Adds forward compatibilities so classic DataLoader can work with DataPipes:
|
| 286 |
+
# _DataPipeSerializationWrapper container makes it easier to serialize without redefining pickler
|
| 287 |
+
if isinstance(self.dataset, IterDataPipe):
|
| 288 |
+
self.dataset = _IterDataPipeSerializationWrapper(self.dataset)
|
| 289 |
+
elif isinstance(self.dataset, MapDataPipe):
|
| 290 |
+
self.dataset = _MapDataPipeSerializationWrapper(self.dataset)
|
| 291 |
+
|
| 292 |
+
# Arg-check dataset related before checking samplers because we want to
|
| 293 |
+
# tell users that iterable-style datasets are incompatible with custom
|
| 294 |
+
# samplers first, so that they don't learn that this combo doesn't work
|
| 295 |
+
# after spending time fixing the custom sampler errors.
|
| 296 |
+
if isinstance(dataset, IterableDataset):
|
| 297 |
+
self._dataset_kind = _DatasetKind.Iterable
|
| 298 |
+
# NOTE [ Custom Samplers and IterableDataset ]
|
| 299 |
+
#
|
| 300 |
+
# `IterableDataset` does not support custom `batch_sampler` or
|
| 301 |
+
# `sampler` since the key is irrelevant (unless we support
|
| 302 |
+
# generator-style dataset one day...).
|
| 303 |
+
#
|
| 304 |
+
# For `sampler`, we always create a dummy sampler. This is an
|
| 305 |
+
# infinite sampler even when the dataset may have an implemented
|
| 306 |
+
# finite `__len__` because in multi-process data loading, naive
|
| 307 |
+
# settings will return duplicated data (which may be desired), and
|
| 308 |
+
# thus using a sampler with length matching that of dataset will
|
| 309 |
+
# cause data lost (you may have duplicates of the first couple
|
| 310 |
+
# batches, but never see anything afterwards). Therefore,
|
| 311 |
+
# `Iterabledataset` always uses an infinite sampler, an instance of
|
| 312 |
+
# `_InfiniteConstantSampler` defined above.
|
| 313 |
+
#
|
| 314 |
+
# A custom `batch_sampler` essentially only controls the batch size.
|
| 315 |
+
# However, it is unclear how useful it would be since an iterable-style
|
| 316 |
+
# dataset can handle that within itself. Moreover, it is pointless
|
| 317 |
+
# in multi-process data loading as the assignment order of batches
|
| 318 |
+
# to workers is an implementation detail so users can not control
|
| 319 |
+
# how to batchify each worker's iterable. Thus, we disable this
|
| 320 |
+
# option. If this turns out to be useful in future, we can re-enable
|
| 321 |
+
# this, and support custom samplers that specify the assignments to
|
| 322 |
+
# specific workers.
|
| 323 |
+
if isinstance(dataset, IterDataPipe):
|
| 324 |
+
if shuffle is not None:
|
| 325 |
+
dataset = torch.utils.data.graph_settings.apply_shuffle_settings(
|
| 326 |
+
dataset, shuffle=shuffle
|
| 327 |
+
)
|
| 328 |
+
# We cannot check `shuffle is not None` here, since previously `shuffle=False` was the default.
|
| 329 |
+
elif shuffle not in {False, None}:
|
| 330 |
+
raise ValueError(
|
| 331 |
+
f"DataLoader with IterableDataset: expected unspecified shuffle option, but got shuffle={shuffle}"
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
if sampler is not None:
|
| 335 |
+
# See NOTE [ Custom Samplers and IterableDataset ]
|
| 336 |
+
raise ValueError(
|
| 337 |
+
f"DataLoader with IterableDataset: expected unspecified sampler option, but got sampler={sampler}"
|
| 338 |
+
)
|
| 339 |
+
elif batch_sampler is not None:
|
| 340 |
+
# See NOTE [ Custom Samplers and IterableDataset ]
|
| 341 |
+
raise ValueError(
|
| 342 |
+
"DataLoader with IterableDataset: expected unspecified "
|
| 343 |
+
f"batch_sampler option, but got batch_sampler={batch_sampler}"
|
| 344 |
+
)
|
| 345 |
+
else:
|
| 346 |
+
shuffle = bool(shuffle)
|
| 347 |
+
self._dataset_kind = _DatasetKind.Map
|
| 348 |
+
|
| 349 |
+
if sampler is not None and shuffle:
|
| 350 |
+
raise ValueError("sampler option is mutually exclusive with " "shuffle")
|
| 351 |
+
|
| 352 |
+
if batch_sampler is not None:
|
| 353 |
+
# auto_collation with custom batch_sampler
|
| 354 |
+
if batch_size != 1 or shuffle or sampler is not None or drop_last:
|
| 355 |
+
raise ValueError(
|
| 356 |
+
"batch_sampler option is mutually exclusive "
|
| 357 |
+
"with batch_size, shuffle, sampler, and "
|
| 358 |
+
"drop_last"
|
| 359 |
+
)
|
| 360 |
+
batch_size = None
|
| 361 |
+
drop_last = False
|
| 362 |
+
elif batch_size is None:
|
| 363 |
+
# no auto_collation
|
| 364 |
+
if drop_last:
|
| 365 |
+
raise ValueError(
|
| 366 |
+
"batch_size=None option disables auto-batching "
|
| 367 |
+
"and is mutually exclusive with drop_last"
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
if sampler is None: # give default samplers
|
| 371 |
+
if self._dataset_kind == _DatasetKind.Iterable:
|
| 372 |
+
# See NOTE [ Custom Samplers and IterableDataset ]
|
| 373 |
+
sampler = _InfiniteConstantSampler()
|
| 374 |
+
else: # map-style
|
| 375 |
+
if shuffle:
|
| 376 |
+
sampler = RandomSampler(dataset, generator=generator) # type: ignore[arg-type]
|
| 377 |
+
else:
|
| 378 |
+
sampler = SequentialSampler(dataset) # type: ignore[arg-type]
|
| 379 |
+
|
| 380 |
+
if batch_size is not None and batch_sampler is None:
|
| 381 |
+
# auto_collation without custom batch_sampler
|
| 382 |
+
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
|
| 383 |
+
|
| 384 |
+
self.batch_size = batch_size
|
| 385 |
+
self.drop_last = drop_last
|
| 386 |
+
self.sampler = sampler
|
| 387 |
+
self.batch_sampler = batch_sampler
|
| 388 |
+
self.generator = generator
|
| 389 |
+
|
| 390 |
+
if collate_fn is None:
|
| 391 |
+
if self._auto_collation:
|
| 392 |
+
collate_fn = _utils.collate.default_collate
|
| 393 |
+
else:
|
| 394 |
+
collate_fn = _utils.collate.default_convert
|
| 395 |
+
|
| 396 |
+
self.collate_fn = collate_fn
|
| 397 |
+
self.persistent_workers = persistent_workers
|
| 398 |
+
|
| 399 |
+
self.__initialized = True
|
| 400 |
+
self._IterableDataset_len_called = (
|
| 401 |
+
None # See NOTE [ IterableDataset and __len__ ]
|
| 402 |
+
)
|
| 403 |
+
|
| 404 |
+
self._iterator = None
|
| 405 |
+
|
| 406 |
+
self.check_worker_number_rationality()
|
| 407 |
+
|
| 408 |
+
torch.set_vital("Dataloader", "enabled", "True") # type: ignore[attr-defined]
|
| 409 |
+
|
| 410 |
+
def _get_iterator(self) -> "_BaseDataLoaderIter":
|
| 411 |
+
if self.num_workers == 0:
|
| 412 |
+
return _SingleProcessDataLoaderIter(self)
|
| 413 |
+
else:
|
| 414 |
+
self.check_worker_number_rationality()
|
| 415 |
+
return _MultiProcessingDataLoaderIter(self)
|
| 416 |
+
|
| 417 |
+
@property
|
| 418 |
+
def multiprocessing_context(self):
|
| 419 |
+
return self.__multiprocessing_context
|
| 420 |
+
|
| 421 |
+
@multiprocessing_context.setter
|
| 422 |
+
def multiprocessing_context(self, multiprocessing_context):
|
| 423 |
+
if multiprocessing_context is not None:
|
| 424 |
+
if self.num_workers > 0:
|
| 425 |
+
if isinstance(multiprocessing_context, str):
|
| 426 |
+
valid_start_methods = torch.multiprocessing.get_all_start_methods()
|
| 427 |
+
if multiprocessing_context not in valid_start_methods:
|
| 428 |
+
raise ValueError(
|
| 429 |
+
"multiprocessing_context option "
|
| 430 |
+
f"should specify a valid start method in {valid_start_methods!r}, but got "
|
| 431 |
+
f"multiprocessing_context={multiprocessing_context!r}"
|
| 432 |
+
)
|
| 433 |
+
multiprocessing_context = torch.multiprocessing.get_context(
|
| 434 |
+
multiprocessing_context
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
if not isinstance(
|
| 438 |
+
multiprocessing_context, python_multiprocessing.context.BaseContext
|
| 439 |
+
):
|
| 440 |
+
raise TypeError(
|
| 441 |
+
"multiprocessing_context option should be a valid context "
|
| 442 |
+
"object or a string specifying the start method, but got "
|
| 443 |
+
f"multiprocessing_context={multiprocessing_context}"
|
| 444 |
+
)
|
| 445 |
+
else:
|
| 446 |
+
raise ValueError(
|
| 447 |
+
"multiprocessing_context can only be used with "
|
| 448 |
+
"multi-process loading (num_workers > 0), but got "
|
| 449 |
+
f"num_workers={self.num_workers}"
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
self.__multiprocessing_context = multiprocessing_context
|
| 453 |
+
|
| 454 |
+
def __setattr__(self, attr, val):
|
| 455 |
+
if self.__initialized and attr in (
|
| 456 |
+
"batch_size",
|
| 457 |
+
"batch_sampler",
|
| 458 |
+
"sampler",
|
| 459 |
+
"drop_last",
|
| 460 |
+
"dataset",
|
| 461 |
+
"persistent_workers",
|
| 462 |
+
):
|
| 463 |
+
raise ValueError(
|
| 464 |
+
f"{attr} attribute should not be set after {self.__class__.__name__} is initialized"
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
super().__setattr__(attr, val)
|
| 468 |
+
|
| 469 |
+
# We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
|
| 470 |
+
# since '_BaseDataLoaderIter' references 'DataLoader'.
|
| 471 |
+
def __iter__(self) -> "_BaseDataLoaderIter":
|
| 472 |
+
# When using a single worker the returned iterator should be
|
| 473 |
+
# created everytime to avoid resetting its state
|
| 474 |
+
# However, in the case of a multiple workers iterator
|
| 475 |
+
# the iterator is only created once in the lifetime of the
|
| 476 |
+
# DataLoader object so that workers can be reused
|
| 477 |
+
if self.persistent_workers and self.num_workers > 0:
|
| 478 |
+
if self._iterator is None:
|
| 479 |
+
self._iterator = self._get_iterator()
|
| 480 |
+
else:
|
| 481 |
+
self._iterator._reset(self)
|
| 482 |
+
return self._iterator
|
| 483 |
+
else:
|
| 484 |
+
return self._get_iterator()
|
| 485 |
+
|
| 486 |
+
@property
|
| 487 |
+
def _auto_collation(self):
|
| 488 |
+
return self.batch_sampler is not None
|
| 489 |
+
|
| 490 |
+
@property
|
| 491 |
+
def _index_sampler(self):
|
| 492 |
+
# The actual sampler used for generating indices for `_DatasetFetcher`
|
| 493 |
+
# (see _utils/fetch.py) to read data at each time. This would be
|
| 494 |
+
# `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
|
| 495 |
+
# We can't change `.sampler` and `.batch_sampler` attributes for BC
|
| 496 |
+
# reasons.
|
| 497 |
+
if self._auto_collation:
|
| 498 |
+
return self.batch_sampler
|
| 499 |
+
else:
|
| 500 |
+
return self.sampler
|
| 501 |
+
|
| 502 |
+
def __len__(self) -> int:
|
| 503 |
+
if self._dataset_kind == _DatasetKind.Iterable:
|
| 504 |
+
# NOTE [ IterableDataset and __len__ ]
|
| 505 |
+
#
|
| 506 |
+
# For `IterableDataset`, `__len__` could be inaccurate when one naively
|
| 507 |
+
# does multi-processing data loading, since the samples will be duplicated.
|
| 508 |
+
# However, no real use case should be actually using that behavior, so
|
| 509 |
+
# it should count as a user error. We should generally trust user
|
| 510 |
+
# code to do the proper thing (e.g., configure each replica differently
|
| 511 |
+
# in `__iter__`), and give us the correct `__len__` if they choose to
|
| 512 |
+
# implement it (this will still throw if the dataset does not implement
|
| 513 |
+
# a `__len__`).
|
| 514 |
+
#
|
| 515 |
+
# To provide a further warning, we track if `__len__` was called on the
|
| 516 |
+
# `DataLoader`, save the returned value in `self._len_called`, and warn
|
| 517 |
+
# if the iterator ends up yielding more than this number of samples.
|
| 518 |
+
|
| 519 |
+
# Cannot statically verify that dataset is Sized
|
| 520 |
+
length = self._IterableDataset_len_called = len(self.dataset) # type: ignore[assignment, arg-type]
|
| 521 |
+
if (
|
| 522 |
+
self.batch_size is not None
|
| 523 |
+
): # IterableDataset doesn't allow custom sampler or batch_sampler
|
| 524 |
+
from math import ceil
|
| 525 |
+
|
| 526 |
+
if self.drop_last:
|
| 527 |
+
length = length // self.batch_size
|
| 528 |
+
else:
|
| 529 |
+
length = ceil(length / self.batch_size)
|
| 530 |
+
return length
|
| 531 |
+
else:
|
| 532 |
+
return len(self._index_sampler)
|
| 533 |
+
|
| 534 |
+
def check_worker_number_rationality(self):
|
| 535 |
+
# This function check whether the dataloader's worker number is rational based on
|
| 536 |
+
# current system's resource. Current rule is that if the number of workers this
|
| 537 |
+
# Dataloader will create is bigger than the number of logical cpus that is allowed to
|
| 538 |
+
# use, than we will pop up a warning to let user pay attention.
|
| 539 |
+
#
|
| 540 |
+
# eg. If current system has 2 physical CPUs with 16 cores each. And each core support 2
|
| 541 |
+
# threads, then the total logical cpus here is 2 * 16 * 2 = 64. Let's say current
|
| 542 |
+
# DataLoader process can use half of them which is 32, then the rational max number of
|
| 543 |
+
# worker that initiated from this process is 32.
|
| 544 |
+
# Now, let's say the created DataLoader has num_works = 40, which is bigger than 32.
|
| 545 |
+
# So the warning message is triggered to notify the user to lower the worker number if
|
| 546 |
+
# necessary.
|
| 547 |
+
#
|
| 548 |
+
#
|
| 549 |
+
# [Note] Please note that this function repects `cpuset` only when os.sched_getaffinity is
|
| 550 |
+
# available (available in most of Linux system, but not OSX and Windows).
|
| 551 |
+
# When os.sched_getaffinity is not available, os.cpu_count() is called instead, but
|
| 552 |
+
# it doesn't repect cpuset.
|
| 553 |
+
# We don't take threading into account since each worker process is single threaded
|
| 554 |
+
# at this time.
|
| 555 |
+
#
|
| 556 |
+
# We don't set any threading flags (eg. OMP_NUM_THREADS, MKL_NUM_THREADS, etc)
|
| 557 |
+
# other than `torch.set_num_threads` to 1 in the worker process, if the passing
|
| 558 |
+
# in functions use 3rd party modules that rely on those threading flags to determine
|
| 559 |
+
# how many thread to create (eg. numpy, etc), then it is caller's responsibility to
|
| 560 |
+
# set those flags correctly.
|
| 561 |
+
def _create_warning_msg(num_worker_suggest, num_worker_created, cpuset_checked):
|
| 562 |
+
suggested_max_worker_msg = (
|
| 563 |
+
(
|
| 564 |
+
(
|
| 565 |
+
"Our suggested max number of worker in current system is {}{}, which is smaller "
|
| 566 |
+
"than what this DataLoader is going to create."
|
| 567 |
+
).format(
|
| 568 |
+
num_worker_suggest,
|
| 569 |
+
(
|
| 570 |
+
""
|
| 571 |
+
if cpuset_checked
|
| 572 |
+
else " (`cpuset` is not taken into account)"
|
| 573 |
+
),
|
| 574 |
+
)
|
| 575 |
+
)
|
| 576 |
+
if num_worker_suggest is not None
|
| 577 |
+
else (
|
| 578 |
+
"DataLoader is not able to compute a suggested max number of worker in current system."
|
| 579 |
+
)
|
| 580 |
+
)
|
| 581 |
+
|
| 582 |
+
warn_msg = (
|
| 583 |
+
f"This DataLoader will create {num_worker_created} worker processes in total. {suggested_max_worker_msg} "
|
| 584 |
+
"Please be aware that excessive worker creation might get DataLoader running slow or even freeze, "
|
| 585 |
+
"lower the worker number to avoid potential slowness/freeze if necessary."
|
| 586 |
+
)
|
| 587 |
+
return warn_msg
|
| 588 |
+
|
| 589 |
+
if not self.num_workers or self.num_workers == 0:
|
| 590 |
+
return
|
| 591 |
+
|
| 592 |
+
# try to compute a suggested max number of worker based on system's resource
|
| 593 |
+
max_num_worker_suggest = None
|
| 594 |
+
cpuset_checked = False
|
| 595 |
+
if hasattr(os, "sched_getaffinity"):
|
| 596 |
+
try:
|
| 597 |
+
max_num_worker_suggest = len(os.sched_getaffinity(0))
|
| 598 |
+
cpuset_checked = True
|
| 599 |
+
except Exception:
|
| 600 |
+
pass
|
| 601 |
+
if max_num_worker_suggest is None:
|
| 602 |
+
# os.cpu_count() could return Optional[int]
|
| 603 |
+
# get cpu count first and check None in order to satisfy mypy check
|
| 604 |
+
cpu_count = os.cpu_count()
|
| 605 |
+
if cpu_count is not None:
|
| 606 |
+
max_num_worker_suggest = cpu_count
|
| 607 |
+
|
| 608 |
+
if max_num_worker_suggest is None:
|
| 609 |
+
warnings.warn(
|
| 610 |
+
_create_warning_msg(
|
| 611 |
+
max_num_worker_suggest, self.num_workers, cpuset_checked
|
| 612 |
+
)
|
| 613 |
+
)
|
| 614 |
+
return
|
| 615 |
+
|
| 616 |
+
if self.num_workers > max_num_worker_suggest:
|
| 617 |
+
warnings.warn(
|
| 618 |
+
_create_warning_msg(
|
| 619 |
+
max_num_worker_suggest, self.num_workers, cpuset_checked
|
| 620 |
+
)
|
| 621 |
+
)
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
class _BaseDataLoaderIter:
|
| 625 |
+
def __init__(self, loader: DataLoader) -> None:
|
| 626 |
+
self._dataset = loader.dataset
|
| 627 |
+
self._shared_seed = None
|
| 628 |
+
self._pg = None
|
| 629 |
+
if isinstance(self._dataset, IterDataPipe):
|
| 630 |
+
if dist.is_available() and dist.is_initialized():
|
| 631 |
+
self._pg = dist.new_group(backend="gloo")
|
| 632 |
+
self._shared_seed = _share_dist_seed(loader.generator, self._pg)
|
| 633 |
+
shared_rng = torch.Generator()
|
| 634 |
+
shared_rng.manual_seed(self._shared_seed)
|
| 635 |
+
self._dataset = torch.utils.data.graph_settings.apply_random_seed(
|
| 636 |
+
self._dataset, shared_rng
|
| 637 |
+
)
|
| 638 |
+
self._dataset_kind = loader._dataset_kind
|
| 639 |
+
self._IterableDataset_len_called = loader._IterableDataset_len_called
|
| 640 |
+
self._auto_collation = loader._auto_collation
|
| 641 |
+
self._drop_last = loader.drop_last
|
| 642 |
+
self._index_sampler = loader._index_sampler
|
| 643 |
+
self._num_workers = loader.num_workers
|
| 644 |
+
ws, rank = _get_distributed_settings()
|
| 645 |
+
self._world_size = ws
|
| 646 |
+
self._rank = rank
|
| 647 |
+
# for other backends, pin_memory_device need to set. if not set
|
| 648 |
+
# default behaviour is CUDA device. if pin_memory_device is selected
|
| 649 |
+
# and pin_memory is not set, the default behaviour false.
|
| 650 |
+
if len(loader.pin_memory_device) == 0:
|
| 651 |
+
self._pin_memory = loader.pin_memory and torch.cuda.is_available()
|
| 652 |
+
self._pin_memory_device = None
|
| 653 |
+
else:
|
| 654 |
+
if not loader.pin_memory:
|
| 655 |
+
warn_msg = (
|
| 656 |
+
"pin memory device is set and pin_memory flag is not used then device pinned memory won't be used"
|
| 657 |
+
"please set pin_memory to true, if you need to use the device pin memory"
|
| 658 |
+
)
|
| 659 |
+
warnings.warn(warn_msg)
|
| 660 |
+
|
| 661 |
+
self._pin_memory = loader.pin_memory
|
| 662 |
+
self._pin_memory_device = loader.pin_memory_device
|
| 663 |
+
self._timeout = loader.timeout
|
| 664 |
+
self._collate_fn = loader.collate_fn
|
| 665 |
+
self._sampler_iter = iter(self._index_sampler)
|
| 666 |
+
self._base_seed = (
|
| 667 |
+
torch.empty((), dtype=torch.int64)
|
| 668 |
+
.random_(generator=loader.generator)
|
| 669 |
+
.item()
|
| 670 |
+
)
|
| 671 |
+
self._persistent_workers = loader.persistent_workers
|
| 672 |
+
self._num_yielded = 0
|
| 673 |
+
self._profile_name = f"enumerate(DataLoader)#{self.__class__.__name__}.__next__"
|
| 674 |
+
|
| 675 |
+
def __iter__(self) -> "_BaseDataLoaderIter":
|
| 676 |
+
return self
|
| 677 |
+
|
| 678 |
+
def _reset(self, loader, first_iter=False):
|
| 679 |
+
self._sampler_iter = iter(self._index_sampler)
|
| 680 |
+
self._num_yielded = 0
|
| 681 |
+
self._IterableDataset_len_called = loader._IterableDataset_len_called
|
| 682 |
+
if isinstance(self._dataset, IterDataPipe):
|
| 683 |
+
self._shared_seed = _share_dist_seed(loader.generator, self._pg)
|
| 684 |
+
shared_rng = torch.Generator()
|
| 685 |
+
shared_rng.manual_seed(self._shared_seed)
|
| 686 |
+
self._dataset = torch.utils.data.graph_settings.apply_random_seed(
|
| 687 |
+
self._dataset, shared_rng
|
| 688 |
+
)
|
| 689 |
+
|
| 690 |
+
def _next_index(self):
|
| 691 |
+
return next(self._sampler_iter) # may raise StopIteration
|
| 692 |
+
|
| 693 |
+
def _next_data(self):
|
| 694 |
+
raise NotImplementedError
|
| 695 |
+
|
| 696 |
+
def __next__(self) -> Any:
|
| 697 |
+
with torch.autograd.profiler.record_function(self._profile_name):
|
| 698 |
+
if self._sampler_iter is None:
|
| 699 |
+
# TODO(https://github.com/pytorch/pytorch/issues/76750)
|
| 700 |
+
self._reset() # type: ignore[call-arg]
|
| 701 |
+
data = self._next_data()
|
| 702 |
+
self._num_yielded += 1
|
| 703 |
+
if (
|
| 704 |
+
self._dataset_kind == _DatasetKind.Iterable
|
| 705 |
+
and self._IterableDataset_len_called is not None
|
| 706 |
+
and self._num_yielded > self._IterableDataset_len_called
|
| 707 |
+
):
|
| 708 |
+
warn_msg = (
|
| 709 |
+
f"Length of IterableDataset {self._dataset} was reported to be {self._IterableDataset_len_called}"
|
| 710 |
+
f"(when accessing len(dataloader)), but {self._num_yielded} samples have been fetched. "
|
| 711 |
+
)
|
| 712 |
+
if self._num_workers > 0:
|
| 713 |
+
warn_msg += (
|
| 714 |
+
"For multiprocessing data-loading, this could be caused by not properly configuring the "
|
| 715 |
+
"IterableDataset replica at each worker. Please see "
|
| 716 |
+
"https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples."
|
| 717 |
+
)
|
| 718 |
+
warnings.warn(warn_msg)
|
| 719 |
+
return data
|
| 720 |
+
|
| 721 |
+
def __len__(self) -> int:
|
| 722 |
+
return len(self._index_sampler)
|
| 723 |
+
|
| 724 |
+
def __getstate__(self):
|
| 725 |
+
# TODO: add limited pickling support for sharing an iterator
|
| 726 |
+
# across multiple threads for HOGWILD.
|
| 727 |
+
# Probably the best way to do this is by moving the sample pushing
|
| 728 |
+
# to a separate thread and then just sharing the data queue
|
| 729 |
+
# but signalling the end is tricky without a non-blocking API
|
| 730 |
+
raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)
|
| 731 |
+
|
| 732 |
+
|
| 733 |
+
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
|
| 734 |
+
def __init__(self, loader):
|
| 735 |
+
super().__init__(loader)
|
| 736 |
+
assert self._timeout == 0
|
| 737 |
+
assert self._num_workers == 0
|
| 738 |
+
|
| 739 |
+
# Adds forward compatibilities so classic DataLoader can work with DataPipes:
|
| 740 |
+
# Taking care of distributed sharding
|
| 741 |
+
if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
|
| 742 |
+
# For BC, use default SHARDING_PRIORITIES
|
| 743 |
+
torch.utils.data.graph_settings.apply_sharding(
|
| 744 |
+
self._dataset, self._world_size, self._rank
|
| 745 |
+
)
|
| 746 |
+
|
| 747 |
+
self._dataset_fetcher = _DatasetKind.create_fetcher(
|
| 748 |
+
self._dataset_kind,
|
| 749 |
+
self._dataset,
|
| 750 |
+
self._auto_collation,
|
| 751 |
+
self._collate_fn,
|
| 752 |
+
self._drop_last,
|
| 753 |
+
)
|
| 754 |
+
|
| 755 |
+
def _next_data(self):
|
| 756 |
+
index = self._next_index() # may raise StopIteration
|
| 757 |
+
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
|
| 758 |
+
if self._pin_memory:
|
| 759 |
+
data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
|
| 760 |
+
return data
|
| 761 |
+
|
| 762 |
+
|
| 763 |
+
class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
|
| 764 |
+
r"""Iterates once over the DataLoader's dataset, as specified by the sampler."""
|
| 765 |
+
|
| 766 |
+
# NOTE [ Data Loader Multiprocessing Shutdown Logic ]
|
| 767 |
+
#
|
| 768 |
+
# Preliminary:
|
| 769 |
+
#
|
| 770 |
+
# Our data model looks like this (queues are indicated with curly brackets):
|
| 771 |
+
#
|
| 772 |
+
# main process ||
|
| 773 |
+
# | ||
|
| 774 |
+
# {index_queue} ||
|
| 775 |
+
# | ||
|
| 776 |
+
# worker processes || DATA
|
| 777 |
+
# | ||
|
| 778 |
+
# {worker_result_queue} || FLOW
|
| 779 |
+
# | ||
|
| 780 |
+
# pin_memory_thread of main process || DIRECTION
|
| 781 |
+
# | ||
|
| 782 |
+
# {data_queue} ||
|
| 783 |
+
# | ||
|
| 784 |
+
# data output \/
|
| 785 |
+
#
|
| 786 |
+
# P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
|
| 787 |
+
# `pin_memory=False`.
|
| 788 |
+
#
|
| 789 |
+
#
|
| 790 |
+
# Terminating multiprocessing logic requires very careful design. In
|
| 791 |
+
# particular, we need to make sure that
|
| 792 |
+
#
|
| 793 |
+
# 1. The iterator gracefully exits the workers when its last reference is
|
| 794 |
+
# gone or it is depleted.
|
| 795 |
+
#
|
| 796 |
+
# In this case, the workers should be gracefully exited because the
|
| 797 |
+
# main process may still need to continue to run, and we want cleaning
|
| 798 |
+
# up code in the workers to be executed (e.g., releasing GPU memory).
|
| 799 |
+
# Naturally, we implement the shutdown logic in `__del__` of
|
| 800 |
+
# DataLoaderIterator.
|
| 801 |
+
#
|
| 802 |
+
# We delay the discussion on the logic in this case until later.
|
| 803 |
+
#
|
| 804 |
+
# 2. The iterator exits the workers when the loader process and/or worker
|
| 805 |
+
# processes exits normally or with error.
|
| 806 |
+
#
|
| 807 |
+
# We set all workers and `pin_memory_thread` to have `daemon=True`.
|
| 808 |
+
#
|
| 809 |
+
# You may ask, why can't we make the workers non-daemonic, and
|
| 810 |
+
# gracefully exit using the same logic as we have in `__del__` when the
|
| 811 |
+
# iterator gets deleted (see 1 above)?
|
| 812 |
+
#
|
| 813 |
+
# First of all, `__del__` is **not** guaranteed to be called when
|
| 814 |
+
# interpreter exits. Even if it is called, by the time it executes,
|
| 815 |
+
# many Python core library resources may already be freed, and even
|
| 816 |
+
# simple things like acquiring an internal lock of a queue may hang.
|
| 817 |
+
# Therefore, in this case, we actually need to prevent `__del__` from
|
| 818 |
+
# being executed, and rely on the automatic termination of daemonic
|
| 819 |
+
# children.
|
| 820 |
+
#
|
| 821 |
+
# Thus, we register an `atexit` hook that sets a global flag
|
| 822 |
+
# `_utils.python_exit_status`. Since `atexit` hooks are executed in the
|
| 823 |
+
# reverse order of registration, we are guaranteed that this flag is
|
| 824 |
+
# set before library resources we use are freed (which, at least in
|
| 825 |
+
# CPython, is done via an `atexit` handler defined in
|
| 826 |
+
# `multiprocessing/util.py`
|
| 827 |
+
# https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/util.py#L320-L362
|
| 828 |
+
# registered when an object requiring this mechanism is first
|
| 829 |
+
# created, e.g., `mp.Queue`
|
| 830 |
+
# https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/context.py#L100-L103
|
| 831 |
+
# https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/queues.py#L29
|
| 832 |
+
# )
|
| 833 |
+
#
|
| 834 |
+
# So in `__del__`, we check if `_utils.python_exit_status` is set or
|
| 835 |
+
# `None` (freed), and perform no-op if so.
|
| 836 |
+
#
|
| 837 |
+
# However, simply letting library clean-up codes run can also be bad,
|
| 838 |
+
# because such codes (i.e., `multiprocessing.util._exit_function()`)
|
| 839 |
+
# include joining putting threads for `mp.Queue`, which can be blocking.
|
| 840 |
+
# Hence, for the queues that the main process puts into, we call
|
| 841 |
+
# `cancel_join_thread` at creation. See later section
|
| 842 |
+
# [ 3b. A process won't hang when putting into a queue; ]
|
| 843 |
+
# for more details.
|
| 844 |
+
#
|
| 845 |
+
# Here are two example cases where library clean-up codes can run
|
| 846 |
+
# before `__del__` is called:
|
| 847 |
+
#
|
| 848 |
+
# 1. If we hold onto a reference to the iterator, it more often
|
| 849 |
+
# than not tries to do `multiprocessing` library cleaning before
|
| 850 |
+
# clearing the alive referenced objects (https://github.com/pytorch/pytorch/issues/48666)
|
| 851 |
+
# and thus prevents our cleaning-up code from running first.
|
| 852 |
+
#
|
| 853 |
+
# 2. A similar issue arises when a `DataLoader` is used in a subprocess.
|
| 854 |
+
# When a process ends, it shuts all of its daemonic children
|
| 855 |
+
# down with a SIGTERM (instead of joining them without a timeout).
|
| 856 |
+
# Similarly for threads, but by a different mechanism. This fact,
|
| 857 |
+
# together with a few implementation details of multiprocessing, forces
|
| 858 |
+
# us to make workers daemonic. All of our problems arise when a
|
| 859 |
+
# DataLoader is used in a subprocess, and are caused by multiprocessing
|
| 860 |
+
# code which looks more or less like this:
|
| 861 |
+
#
|
| 862 |
+
# try:
|
| 863 |
+
# your_function_using_a_dataloader()
|
| 864 |
+
# finally:
|
| 865 |
+
# multiprocessing.util._exit_function()
|
| 866 |
+
#
|
| 867 |
+
# The joining/termination mentioned above happens inside
|
| 868 |
+
# `_exit_function()`. Now, if `your_function_using_a_dataloader()`
|
| 869 |
+
# throws, the stack trace stored in the exception will prevent the
|
| 870 |
+
# frame which uses `DataLoaderIter` from being freed. If the frame has any
|
| 871 |
+
# reference to the `DataLoaderIter` (e.g., in a method of the iter),
|
| 872 |
+
# its `__del__`, which starts the shutdown procedure, will not be
|
| 873 |
+
# called. That, in turn, means that workers aren't notified. Attempting
|
| 874 |
+
# to join in `_exit_function` will then result in a hang.
|
| 875 |
+
#
|
| 876 |
+
# For context, `_exit_function` is also registered as an `atexit` call.
|
| 877 |
+
# So it is unclear to me (@ssnl) why this is needed in a finally block.
|
| 878 |
+
# The code dates back to 2008 and there is no comment on the original
|
| 879 |
+
# PEP 371 or patch https://bugs.python.org/issue3050 (containing both
|
| 880 |
+
# the finally block and the `atexit` registration) that explains this.
|
| 881 |
+
#
|
| 882 |
+
#
|
| 883 |
+
# Finally, another choice is to just shutdown workers with logic in 1
|
| 884 |
+
# above whenever we see an error in `next`. This isn't ideal because
|
| 885 |
+
# a. It prevents users from using try-catch to resume data loading.
|
| 886 |
+
# b. It doesn't prevent hanging if users have references to the
|
| 887 |
+
# iterator.
|
| 888 |
+
#
|
| 889 |
+
# 3. All processes exit if any of them die unexpectedly by fatal signals.
|
| 890 |
+
#
|
| 891 |
+
# As shown above, the workers are set as daemonic children of the main
|
| 892 |
+
# process. However, automatic cleaning-up of such child processes only
|
| 893 |
+
# happens if the parent process exits gracefully (e.g., not via fatal
|
| 894 |
+
# signals like SIGKILL). So we must ensure that each process will exit
|
| 895 |
+
# even if the process that should send/receive data to/from it was
|
| 896 |
+
# killed, i.e.,
|
| 897 |
+
#
|
| 898 |
+
# a. A process won't hang when getting from a queue.
|
| 899 |
+
#
|
| 900 |
+
# Even with carefully designed data dependencies (i.e., a `put()`
|
| 901 |
+
# always corresponding to a `get()`), hanging on `get()` can still
|
| 902 |
+
# happen when data in queue is corrupted (e.g., due to
|
| 903 |
+
# `cancel_join_thread` or unexpected exit).
|
| 904 |
+
#
|
| 905 |
+
# For child exit, we set a timeout whenever we try to get data
|
| 906 |
+
# from `data_queue`, and check the workers' status on each timeout
|
| 907 |
+
# and error.
|
| 908 |
+
# See `_DataLoaderiter._get_batch()` and
|
| 909 |
+
# `_DataLoaderiter._try_get_data()` for details.
|
| 910 |
+
#
|
| 911 |
+
# Additionally, for child exit on non-Windows platforms, we also
|
| 912 |
+
# register a SIGCHLD handler (which is not supported on Windows) on
|
| 913 |
+
# the main process, which checks if any of the workers fail in the
|
| 914 |
+
# (Python) handler. This is more efficient and faster in detecting
|
| 915 |
+
# worker failures, compared to only using the above mechanism.
|
| 916 |
+
# See `DataLoader.cpp` and `_utils/signal_handling.py` for details.
|
| 917 |
+
#
|
| 918 |
+
# For `.get()` calls where the sender(s) is not the workers, we
|
| 919 |
+
# guard them with timeouts, and check the status of the sender
|
| 920 |
+
# when timeout happens:
|
| 921 |
+
# + in the workers, the `_utils.worker.ManagerWatchdog` class
|
| 922 |
+
# checks the status of the main process.
|
| 923 |
+
# + if `pin_memory=True`, when getting from `pin_memory_thread`,
|
| 924 |
+
# check `pin_memory_thread` status periodically until `.get()`
|
| 925 |
+
# returns or we see that `pin_memory_thread` has died.
|
| 926 |
+
#
|
| 927 |
+
# b. A process won't hang when putting into a queue;
|
| 928 |
+
#
|
| 929 |
+
# We use `mp.Queue` which has a separate background thread to put
|
| 930 |
+
# objects from an unbounded buffer array. The background thread is
|
| 931 |
+
# daemonic and usually automatically joined when the process
|
| 932 |
+
# *exits*.
|
| 933 |
+
#
|
| 934 |
+
# In case that the receiver has ended abruptly while
|
| 935 |
+
# reading from the pipe, the join will hang forever. The usual
|
| 936 |
+
# solution for this in Python is calling `q.cancel_join_thread`,
|
| 937 |
+
# which prevents automatically joining it when finalizing
|
| 938 |
+
# (exiting).
|
| 939 |
+
#
|
| 940 |
+
# Nonetheless, `cancel_join_thread` must only be called when the
|
| 941 |
+
# queue is **not** going to be read from or written to by another
|
| 942 |
+
# process, because it may hold onto a lock or leave corrupted data
|
| 943 |
+
# in the queue, leading other readers/writers to hang.
|
| 944 |
+
#
|
| 945 |
+
# Hence,
|
| 946 |
+
# + For worker processes, we only do so (for their output
|
| 947 |
+
# queues, i.e., `worker_result_queue`) before exiting.
|
| 948 |
+
# + For `pin_memory_thread`, its output queue `data_queue` is a
|
| 949 |
+
# `queue.Queue` that does blocking `put` if the queue is full.
|
| 950 |
+
# So there is no above problem, but as a result, in
|
| 951 |
+
# `_pin_memory_loop`, we do need to wrap the `put` in a loop
|
| 952 |
+
# that breaks not only upon success, but also when the main
|
| 953 |
+
# process stops reading, i.e., is shutting down.
|
| 954 |
+
# + For loader process, we `cancel_join_thread()` for all
|
| 955 |
+
# `_index_queues` because the whole purpose of workers and
|
| 956 |
+
# `pin_memory_thread` is to serve the loader process. If
|
| 957 |
+
# loader process is already exiting, we don't really care if
|
| 958 |
+
# the queues are corrupted.
|
| 959 |
+
#
|
| 960 |
+
#
|
| 961 |
+
# Now let's get back to 1:
|
| 962 |
+
# how we gracefully exit the workers when the last reference to the
|
| 963 |
+
# iterator is gone.
|
| 964 |
+
#
|
| 965 |
+
# To achieve this, we implement the following logic along with the design
|
| 966 |
+
# choices mentioned above:
|
| 967 |
+
#
|
| 968 |
+
# `workers_done_event`:
|
| 969 |
+
# A `multiprocessing.Event` shared among the main process and all worker
|
| 970 |
+
# processes. This is used to signal the workers that the iterator is
|
| 971 |
+
# shutting down. After it is set, they will not send processed data to
|
| 972 |
+
# queues anymore, and only wait for the final `None` before exiting.
|
| 973 |
+
# `done_event` isn't strictly needed. I.e., we can just check for `None`
|
| 974 |
+
# from the input queue, but it allows us to skip wasting resources
|
| 975 |
+
# processing data if we are already shutting down.
|
| 976 |
+
#
|
| 977 |
+
# `pin_memory_thread_done_event`:
|
| 978 |
+
# A `threading.Event` for a similar purpose to that of
|
| 979 |
+
# `workers_done_event`, but is for the `pin_memory_thread`. The reason
|
| 980 |
+
# that separate events are needed is that `pin_memory_thread` reads from
|
| 981 |
+
# the output queue of the workers. But the workers, upon seeing that
|
| 982 |
+
# `workers_done_event` is set, each wants only to see the final `None`, and is
|
| 983 |
+
# not required to flush all data in the output queue (e.g., it may call
|
| 984 |
+
# `cancel_join_thread` on that queue if its `IterableDataset` iterator
|
| 985 |
+
# happens to exhaust coincidentally, which is out of the control of the
|
| 986 |
+
# main process). Thus, since we will exit `pin_memory_thread` before the
|
| 987 |
+
# workers (see below), two separate events are used.
|
| 988 |
+
#
|
| 989 |
+
# NOTE: In short, the protocol is that the main process will set these
|
| 990 |
+
# `done_event`s and then send the corresponding processes/threads a `None`,
|
| 991 |
+
# and that they may exit at any time after receiving the `None`.
|
| 992 |
+
#
|
| 993 |
+
# NOTE: Using `None` as the final signal is valid, since normal data will
|
| 994 |
+
# always be a 2-tuple with the 1st element being the index of the data
|
| 995 |
+
# transferred (different from dataset index/key), and the 2nd being
|
| 996 |
+
# either the dataset key or the data sample (depending on which part
|
| 997 |
+
# of the data model the queue is at).
|
| 998 |
+
#
|
| 999 |
+
# [ worker processes ]
|
| 1000 |
+
# While loader process is alive:
|
| 1001 |
+
# Get from `index_queue`.
|
| 1002 |
+
# If we get anything (i.e., did not time out),
|
| 1003 |
+
# Check `workers_done_event`.
|
| 1004 |
+
# If set, continue to next iteration
|
| 1005 |
+
# i.e., keep getting until we see the `None`, then exit.
|
| 1006 |
+
# Otherwise, process data:
|
| 1007 |
+
# If is fetching from an `IterableDataset` and the iterator
|
| 1008 |
+
# is exhausted, send an `_IterableDatasetStopIteration`
|
| 1009 |
+
# object to signal iteration end. The main process, upon
|
| 1010 |
+
# receiving such an object, will send `None` to this
|
| 1011 |
+
# worker and not use the corresponding `index_queue`
|
| 1012 |
+
# anymore.
|
| 1013 |
+
# If timed out,
|
| 1014 |
+
# No matter whether `workers_done_event` is set (still need to see `None`)
|
| 1015 |
+
# or not, must continue to next iteration.
|
| 1016 |
+
# (outside loop)
|
| 1017 |
+
# If `workers_done_event` is set, (this can be False with `IterableDataset`)
|
| 1018 |
+
# `data_queue.cancel_join_thread()`. (Everything is ending here:
|
| 1019 |
+
# main process won't read from it;
|
| 1020 |
+
# other workers will also call
|
| 1021 |
+
# `cancel_join_thread`.)
|
| 1022 |
+
#
|
| 1023 |
+
# [ pin_memory_thread ]
|
| 1024 |
+
# # No need to check main thread. If this thread is alive, the main loader
|
| 1025 |
+
# # thread must be alive, because this thread is set as daemonic.
|
| 1026 |
+
# While `pin_memory_thread_done_event` is not set:
|
| 1027 |
+
# Get from `worker_result_queue`.
|
| 1028 |
+
# If timed out, continue to get in the next iteration.
|
| 1029 |
+
# Otherwise, process data.
|
| 1030 |
+
# While `pin_memory_thread_done_event` is not set:
|
| 1031 |
+
# Put processed data to `data_queue` (a `queue.Queue` with blocking put)
|
| 1032 |
+
# If timed out, continue to put in the next iteration.
|
| 1033 |
+
# Otherwise, break, i.e., continue to the outer loop.
|
| 1034 |
+
#
|
| 1035 |
+
# NOTE: we don't check the status of the main thread because
|
| 1036 |
+
# 1. if the process is killed by a fatal signal, `pin_memory_thread`
|
| 1037 |
+
# ends.
|
| 1038 |
+
# 2. in other cases, either the cleaning-up in __del__ or the
|
| 1039 |
+
# automatic exit of daemonic thread will take care of it.
|
| 1040 |
+
# This won't busy-wait either because `.get(timeout)` does not
|
| 1041 |
+
# busy-wait.
|
| 1042 |
+
#
|
| 1043 |
+
# [ main process ]
|
| 1044 |
+
# In the DataLoader Iter's `__del__`
|
| 1045 |
+
# b. Exit `pin_memory_thread`
|
| 1046 |
+
# i. Set `pin_memory_thread_done_event`.
|
| 1047 |
+
# ii Put `None` in `worker_result_queue`.
|
| 1048 |
+
# iii. Join the `pin_memory_thread`.
|
| 1049 |
+
# iv. `worker_result_queue.cancel_join_thread()`.
|
| 1050 |
+
#
|
| 1051 |
+
# c. Exit the workers.
|
| 1052 |
+
# i. Set `workers_done_event`.
|
| 1053 |
+
# ii. Put `None` in each worker's `index_queue`.
|
| 1054 |
+
# iii. Join the workers.
|
| 1055 |
+
# iv. Call `.cancel_join_thread()` on each worker's `index_queue`.
|
| 1056 |
+
#
|
| 1057 |
+
# NOTE: (c) is better placed after (b) because it may leave corrupted
|
| 1058 |
+
# data in `worker_result_queue`, which `pin_memory_thread`
|
| 1059 |
+
# reads from, in which case the `pin_memory_thread` can only
|
| 1060 |
+
# exit after timing out, which is slow. Nonetheless, the same thing
|
| 1061 |
+
# happens if a worker is killed by signal at unfortunate times,
|
| 1062 |
+
# but in other cases, we are better off having a non-corrupted
|
| 1063 |
+
# `worker_result_queue` for `pin_memory_thread`.
|
| 1064 |
+
#
|
| 1065 |
+
# NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b)
|
| 1066 |
+
# can be omitted
|
| 1067 |
+
#
|
| 1068 |
+
# NB: `done_event`s aren't strictly needed. E.g., we can just check for
|
| 1069 |
+
# `None` from `index_queue`, but it allows us to skip wasting resources
|
| 1070 |
+
# processing indices already in `index_queue` if we are already shutting
|
| 1071 |
+
# down.
|
| 1072 |
+
|
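# A minimal, hedged sketch of the "set the event, then send `None`" handshake
# described in the note above, using plain `multiprocessing` primitives; the
# names below are illustrative only, not attributes of this class:
#
#     import multiprocessing as mp
#
#     def toy_worker(index_queue, done_event):
#         while True:
#             item = index_queue.get()
#             if item is None:            # final signal: now safe to exit
#                 break
#             if done_event.is_set():     # shutting down: drain until `None`
#                 continue
#             ...                         # otherwise, process `item`
#
#     # main-process side, mirroring steps (c)i-(c)iii of the note above:
#     #   done_event.set()
#     #   index_queue.put(None)
#     #   worker_process.join()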
| 1073 |
+
def __init__(self, loader):
|
| 1074 |
+
super().__init__(loader)
|
| 1075 |
+
|
| 1076 |
+
self._prefetch_factor = loader.prefetch_factor
|
| 1077 |
+
|
| 1078 |
+
assert self._num_workers > 0
|
| 1079 |
+
assert self._prefetch_factor > 0
|
| 1080 |
+
|
| 1081 |
+
if loader.multiprocessing_context is None:
|
| 1082 |
+
multiprocessing_context = torch.multiprocessing
|
| 1083 |
+
else:
|
| 1084 |
+
multiprocessing_context = loader.multiprocessing_context
|
| 1085 |
+
|
| 1086 |
+
self._worker_init_fn = loader.worker_init_fn
|
| 1087 |
+
|
| 1088 |
+
# Adds forward compatibilities so classic DataLoader can work with DataPipes:
|
| 1089 |
+
# Additional worker init function will take care of sharding in MP and Distributed
|
| 1090 |
+
if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
|
| 1091 |
+
self._worker_init_fn = functools.partial(
|
| 1092 |
+
_sharding_worker_init_fn,
|
| 1093 |
+
self._worker_init_fn,
|
| 1094 |
+
self._world_size,
|
| 1095 |
+
self._rank,
|
| 1096 |
+
)
|
| 1097 |
+
|
| 1098 |
+
# No certainty which module multiprocessing_context is
|
| 1099 |
+
self._worker_result_queue = multiprocessing_context.Queue() # type: ignore[var-annotated]
|
| 1100 |
+
self._worker_pids_set = False
|
| 1101 |
+
self._shutdown = False
|
| 1102 |
+
self._workers_done_event = multiprocessing_context.Event()
|
| 1103 |
+
|
| 1104 |
+
self._index_queues = []
|
| 1105 |
+
self._workers = []
|
| 1106 |
+
for i in range(self._num_workers):
|
| 1107 |
+
# No certainty which module multiprocessing_context is
|
| 1108 |
+
index_queue = multiprocessing_context.Queue() # type: ignore[var-annotated]
|
| 1109 |
+
# Need to `cancel_join_thread` here!
|
| 1110 |
+
# See sections (2) and (3b) above.
|
| 1111 |
+
index_queue.cancel_join_thread()
|
| 1112 |
+
w = multiprocessing_context.Process(
|
| 1113 |
+
target=_utils.worker._worker_loop,
|
| 1114 |
+
args=(
|
| 1115 |
+
self._dataset_kind,
|
| 1116 |
+
self._dataset,
|
| 1117 |
+
index_queue,
|
| 1118 |
+
self._worker_result_queue,
|
| 1119 |
+
self._workers_done_event,
|
| 1120 |
+
self._auto_collation,
|
| 1121 |
+
self._collate_fn,
|
| 1122 |
+
self._drop_last,
|
| 1123 |
+
self._base_seed,
|
| 1124 |
+
self._worker_init_fn,
|
| 1125 |
+
i,
|
| 1126 |
+
self._num_workers,
|
| 1127 |
+
self._persistent_workers,
|
| 1128 |
+
self._shared_seed,
|
| 1129 |
+
),
|
| 1130 |
+
)
|
| 1131 |
+
w.daemon = True
|
| 1132 |
+
# NB: Process.start() actually take some time as it needs to
|
| 1133 |
+
# start a process and pass the arguments over via a pipe.
|
| 1134 |
+
# Therefore, we only add a worker to self._workers list after
|
| 1135 |
+
# it started, so that we do not call .join() if program dies
|
| 1136 |
+
# before it starts, and __del__ tries to join but will get:
|
| 1137 |
+
# AssertionError: can only join a started process.
|
| 1138 |
+
w.start()
|
| 1139 |
+
self._index_queues.append(index_queue)
|
| 1140 |
+
self._workers.append(w)
|
| 1141 |
+
|
| 1142 |
+
if self._pin_memory:
|
| 1143 |
+
self._pin_memory_thread_done_event = threading.Event()
|
| 1144 |
+
|
| 1145 |
+
# Queue is not type-annotated
|
| 1146 |
+
self._data_queue = queue.Queue() # type: ignore[var-annotated]
|
| 1147 |
+
if self._pin_memory_device == "xpu":
|
| 1148 |
+
current_device = torch.xpu.current_device() # type: ignore[attr-defined]
|
| 1149 |
+
elif self._pin_memory_device == torch._C._get_privateuse1_backend_name():
|
| 1150 |
+
custom_device_mod = getattr(
|
| 1151 |
+
torch, torch._C._get_privateuse1_backend_name()
|
| 1152 |
+
)
|
| 1153 |
+
current_device = custom_device_mod.current_device()
|
| 1154 |
+
else:
|
| 1155 |
+
current_device = torch.cuda.current_device() # choose cuda for default
|
| 1156 |
+
pin_memory_thread = threading.Thread(
|
| 1157 |
+
target=_utils.pin_memory._pin_memory_loop,
|
| 1158 |
+
args=(
|
| 1159 |
+
self._worker_result_queue,
|
| 1160 |
+
self._data_queue,
|
| 1161 |
+
current_device,
|
| 1162 |
+
self._pin_memory_thread_done_event,
|
| 1163 |
+
self._pin_memory_device,
|
| 1164 |
+
),
|
| 1165 |
+
)
|
| 1166 |
+
pin_memory_thread.daemon = True
|
| 1167 |
+
pin_memory_thread.start()
|
| 1168 |
+
# Similar to workers (see comment above), we only register
|
| 1169 |
+
# pin_memory_thread once it is started.
|
| 1170 |
+
self._pin_memory_thread = pin_memory_thread
|
| 1171 |
+
else:
|
| 1172 |
+
self._data_queue = self._worker_result_queue # type: ignore[assignment]
|
| 1173 |
+
|
| 1174 |
+
# In some rare cases, persistent workers (daemonic processes)
|
| 1175 |
+
# would be terminated before `__del__` of iterator is invoked
|
| 1176 |
+
# when main process exits
|
| 1177 |
+
# It would cause failure when pin_memory_thread tries to read
|
| 1178 |
+
# corrupted data from worker_result_queue
|
| 1179 |
+
# atexit is used to shutdown thread and child processes in the
|
| 1180 |
+
# right sequence before main process exits
|
| 1181 |
+
if self._persistent_workers and self._pin_memory:
|
| 1182 |
+
import atexit
|
| 1183 |
+
|
| 1184 |
+
for w in self._workers:
|
| 1185 |
+
atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w)
|
| 1186 |
+
|
| 1187 |
+
# .pid can be None only before process is spawned (not the case, so ignore)
|
| 1188 |
+
_utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers)) # type: ignore[misc]
|
| 1189 |
+
_utils.signal_handling._set_SIGCHLD_handler()
|
| 1190 |
+
self._worker_pids_set = True
|
| 1191 |
+
self._reset(loader, first_iter=True)
|
| 1192 |
+
|
| 1193 |
+
def _reset(self, loader, first_iter=False):
|
| 1194 |
+
super()._reset(loader, first_iter)
|
| 1195 |
+
self._send_idx = 0 # idx of the next task to be sent to workers
|
| 1196 |
+
self._rcvd_idx = 0 # idx of the next task to be returned in __next__
|
| 1197 |
+
# information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
|
| 1198 |
+
# map: task idx => - (worker_id,) if data isn't fetched (outstanding)
|
| 1199 |
+
# \ (worker_id, data) if data is already fetched (out-of-order)
|
| 1200 |
+
self._task_info = {}
|
| 1201 |
+
self._tasks_outstanding = (
|
| 1202 |
+
0 # always equal to count(v for v in task_info.values() if len(v) == 1)
|
| 1203 |
+
)
|
| 1204 |
+
# A list of booleans representing whether each worker still has work to
|
| 1205 |
+
# do, i.e., not having exhausted its iterable dataset object. It always
|
| 1206 |
+
# contains all `True`s if not using an iterable-style dataset
|
| 1207 |
+
# (i.e., if kind != Iterable).
|
| 1208 |
+
# Note that this indicates that a worker still has work to do *for this epoch*.
|
| 1209 |
+
# It does not mean that a worker is dead. In case of `_persistent_workers`,
|
| 1210 |
+
# the worker will be reset to available in the next epoch.
|
| 1211 |
+
self._workers_status = [True for i in range(self._num_workers)]
|
| 1212 |
+
# Reset the worker queue cycle so it resumes next epoch at worker 0
|
| 1213 |
+
self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
|
| 1214 |
+
# We resume the prefetching in case it was enabled
|
| 1215 |
+
if not first_iter:
|
| 1216 |
+
for idx in range(self._num_workers):
|
| 1217 |
+
self._index_queues[idx].put(
|
| 1218 |
+
_utils.worker._ResumeIteration(self._shared_seed)
|
| 1219 |
+
)
|
| 1220 |
+
resume_iteration_cnt = self._num_workers
|
| 1221 |
+
while resume_iteration_cnt > 0:
|
| 1222 |
+
return_idx, return_data = self._get_data()
|
| 1223 |
+
if isinstance(return_idx, _utils.worker._ResumeIteration):
|
| 1224 |
+
assert return_data is None
|
| 1225 |
+
resume_iteration_cnt -= 1
|
| 1226 |
+
# prime the prefetch loop
|
| 1227 |
+
for _ in range(self._prefetch_factor * self._num_workers):
|
| 1228 |
+
self._try_put_index()
|
| 1229 |
+
|
| 1230 |
+
def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
|
| 1231 |
+
# Tries to fetch data from `self._data_queue` once for a given timeout.
|
| 1232 |
+
# This can also be used as inner loop of fetching without timeout, with
|
| 1233 |
+
# the sender status as the loop condition.
|
| 1234 |
+
#
|
| 1235 |
+
# This raises a `RuntimeError` if any worker died unexpectedly. This error
|
| 1236 |
+
# can come from either the SIGCHLD handler in `_utils/signal_handling.py`
|
| 1237 |
+
# (only for non-Windows platforms), or the manual check below on errors
|
| 1238 |
+
# and timeouts.
|
| 1239 |
+
#
|
| 1240 |
+
# Returns a 2-tuple:
|
| 1241 |
+
# (bool: whether data was successfully fetched, any: the data if successful, else None)
|
| 1242 |
+
try:
|
| 1243 |
+
data = self._data_queue.get(timeout=timeout)
|
| 1244 |
+
return (True, data)
|
| 1245 |
+
except Exception as e:
|
| 1246 |
+
# At timeout and error, we manually check whether any worker has
|
| 1247 |
+
# failed. Note that this is the only mechanism for Windows to detect
|
| 1248 |
+
# worker failures.
|
| 1249 |
+
failed_workers = []
|
| 1250 |
+
for worker_id, w in enumerate(self._workers):
|
| 1251 |
+
if self._workers_status[worker_id] and not w.is_alive():
|
| 1252 |
+
failed_workers.append(w)
|
| 1253 |
+
self._mark_worker_as_unavailable(worker_id)
|
| 1254 |
+
if len(failed_workers) > 0:
|
| 1255 |
+
pids_str = ", ".join(str(w.pid) for w in failed_workers)
|
| 1256 |
+
raise RuntimeError(
|
| 1257 |
+
f"DataLoader worker (pid(s) {pids_str}) exited unexpectedly"
|
| 1258 |
+
) from e
|
| 1259 |
+
if isinstance(e, queue.Empty):
|
| 1260 |
+
return (False, None)
|
| 1261 |
+
|
| 1262 |
+
import errno
|
| 1263 |
+
import tempfile
|
| 1264 |
+
|
| 1265 |
+
try:
|
| 1266 |
+
# Raise an exception if we are this close to the FDs limit.
|
| 1267 |
+
# Apparently, trying to open only one file is not a sufficient
|
| 1268 |
+
# test.
|
| 1269 |
+
# See NOTE [ DataLoader on Linux and open files limit ]
|
| 1270 |
+
fds_limit_margin = 10
|
| 1271 |
+
fs = [tempfile.NamedTemporaryFile() for i in range(fds_limit_margin)]
|
| 1272 |
+
except OSError as e:
|
| 1273 |
+
if e.errno == errno.EMFILE:
|
| 1274 |
+
raise RuntimeError(
|
| 1275 |
+
"Too many open files. Communication with the"
|
| 1276 |
+
" workers is no longer possible. Please increase the"
|
| 1277 |
+
" limit using `ulimit -n` in the shell or change the"
|
| 1278 |
+
" sharing strategy by calling"
|
| 1279 |
+
" `torch.multiprocessing.set_sharing_strategy('file_system')`"
|
| 1280 |
+
" at the beginning of your code"
|
| 1281 |
+
) from None
|
| 1282 |
+
raise
|
| 1283 |
+
|
| 1284 |
+
# NOTE [ DataLoader on Linux and open files limit ]
|
| 1285 |
+
#
|
| 1286 |
+
# On Linux when DataLoader is used with multiprocessing we pass the data between
|
| 1287 |
+
# the root process and the workers through SHM files. We remove those files from
|
| 1288 |
+
# the filesystem as soon as they are created and keep them alive by
|
| 1289 |
+
# passing around their file descriptors through AF_UNIX sockets. (See
|
| 1290 |
+
# docs/source/multiprocessing.rst and `Multiprocessing Technical Notes` in
|
| 1291 |
+
# the wiki (https://github.com/pytorch/pytorch/wiki).)
|
| 1292 |
+
#
|
| 1293 |
+
# This sometimes leads us to exceeding the open files limit. When that happens,
|
| 1294 |
+
# and the offending file descriptor is coming over a socket, the `socket` Python
|
| 1295 |
+
# package silently strips the file descriptor from the message, setting only the
|
| 1296 |
+
# `MSG_CTRUNC` flag (which might be a bit misleading since the manpage says that
|
| 1297 |
+
# it _indicates that some control data were discarded due to lack of space in
|
| 1298 |
+
# the buffer for ancillary data_). This might reflect the C implementation of
|
| 1299 |
+
# AF_UNIX sockets.
|
| 1300 |
+
#
|
| 1301 |
+
# This behaviour can be reproduced with the script and instructions at the
|
| 1302 |
+
# bottom of this note.
|
| 1303 |
+
#
|
| 1304 |
+
# When that happens, the standard Python `multiprocessing` (and not
|
| 1305 |
+
# `torch.multiprocessing`) raises a `RuntimeError: received 0 items of ancdata`
|
| 1306 |
+
#
|
| 1307 |
+
# Sometimes, instead of the FD being stripped, you may get an `OSError:
|
| 1308 |
+
# Too many open files`, both in the script below and in DataLoader. However,
|
| 1309 |
+
# this is rare and seems to be nondeterministic.
|
| 1310 |
+
#
|
| 1311 |
+
#
|
| 1312 |
+
# #!/usr/bin/env python3
|
| 1313 |
+
# import sys
|
| 1314 |
+
# import socket
|
| 1315 |
+
# import os
|
| 1316 |
+
# import array
|
| 1317 |
+
# import shutil
|
| 1318 |
+
# import socket
|
| 1319 |
+
#
|
| 1320 |
+
#
|
| 1321 |
+
# if len(sys.argv) != 4:
|
| 1322 |
+
# print("Usage: ", sys.argv[0], " tmp_dirname iteration (send|recv)")
|
| 1323 |
+
# sys.exit(1)
|
| 1324 |
+
#
|
| 1325 |
+
# if __name__ == '__main__':
|
| 1326 |
+
# dirname = sys.argv[1]
|
| 1327 |
+
# sock_path = dirname + "/sock"
|
| 1328 |
+
# iterations = int(sys.argv[2])
|
| 1329 |
+
# def dummy_path(i):
|
| 1330 |
+
# return dirname + "/" + str(i) + ".dummy"
|
| 1331 |
+
#
|
| 1332 |
+
#
|
| 1333 |
+
# if sys.argv[3] == 'send':
|
| 1334 |
+
# while not os.path.exists(sock_path):
|
| 1335 |
+
# pass
|
| 1336 |
+
# client = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
|
| 1337 |
+
# client.connect(sock_path)
|
| 1338 |
+
# for i in range(iterations):
|
| 1339 |
+
# fd = os.open(dummy_path(i), os.O_WRONLY | os.O_CREAT)
|
| 1340 |
+
# ancdata = array.array('i', [fd])
|
| 1341 |
+
# msg = bytes([i % 256])
|
| 1342 |
+
# print("Sending fd ", fd, " (iteration #", i, ")")
|
| 1343 |
+
# client.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, ancdata)])
|
| 1344 |
+
#
|
| 1345 |
+
#
|
| 1346 |
+
# else:
|
| 1347 |
+
# assert sys.argv[3] == 'recv'
|
| 1348 |
+
#
|
| 1349 |
+
# if os.path.exists(dirname):
|
| 1350 |
+
# raise Exception("Directory exists")
|
| 1351 |
+
#
|
| 1352 |
+
# os.mkdir(dirname)
|
| 1353 |
+
#
|
| 1354 |
+
# print("Opening socket...")
|
| 1355 |
+
# server = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
|
| 1356 |
+
# server.bind(sock_path)
|
| 1357 |
+
#
|
| 1358 |
+
# print("Listening...")
|
| 1359 |
+
# for i in range(iterations):
|
| 1360 |
+
# a = array.array('i')
|
| 1361 |
+
# msg, ancdata, flags, addr = server.recvmsg(1, socket.CMSG_SPACE(a.itemsize))
|
| 1362 |
+
# assert(len(ancdata) == 1)
|
| 1363 |
+
# cmsg_level, cmsg_type, cmsg_data = ancdata[0]
|
| 1364 |
+
# a.frombytes(cmsg_data)
|
| 1365 |
+
# print("Received fd ", a[0], " (iteration #", i, ")")
|
| 1366 |
+
#
|
| 1367 |
+
# shutil.rmtree(dirname)
|
| 1368 |
+
#
|
| 1369 |
+
# Steps to reproduce:
|
| 1370 |
+
#
|
| 1371 |
+
# 1. Run two shells and set lower file descriptor limit in the receiving one:
|
| 1372 |
+
# (shell1) ulimit -n 1020
|
| 1373 |
+
# (shell2) ulimit -n 1022
|
| 1374 |
+
#
|
| 1375 |
+
# 2. Run the script above with the `recv` option in the first shell
|
| 1376 |
+
# (shell1) ./test_socket.py sock_tmp 1017 recv
|
| 1377 |
+
#
|
| 1378 |
+
# 3. Run the script with the `send` option in the second shell:
|
| 1379 |
+
# (shell2) ./test_socket.py sock_tmp 1017 send
|
| 1380 |
+
|
| 1381 |
+
def _get_data(self):
|
| 1382 |
+
# Fetches data from `self._data_queue`.
|
| 1383 |
+
#
|
| 1384 |
+
# We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
|
| 1385 |
+
# which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)`
|
| 1386 |
+
# in a loop. This is the only mechanism to detect worker failures for
|
| 1387 |
+
# Windows. For other platforms, a SIGCHLD handler is also used for
|
| 1388 |
+
# worker failure detection.
|
| 1389 |
+
#
|
| 1390 |
+
# If `pin_memory=True`, we also need to check if `pin_memory_thread` has
|
| 1391 |
+
# died at timeouts.
|
| 1392 |
+
if self._timeout > 0:
|
| 1393 |
+
success, data = self._try_get_data(self._timeout)
|
| 1394 |
+
if success:
|
| 1395 |
+
return data
|
| 1396 |
+
else:
|
| 1397 |
+
raise RuntimeError(
|
| 1398 |
+
f"DataLoader timed out after {self._timeout} seconds"
|
| 1399 |
+
)
|
| 1400 |
+
elif self._pin_memory:
|
| 1401 |
+
while self._pin_memory_thread.is_alive():
|
| 1402 |
+
success, data = self._try_get_data()
|
| 1403 |
+
if success:
|
| 1404 |
+
return data
|
| 1405 |
+
else:
|
| 1406 |
+
# while condition is false, i.e., pin_memory_thread died.
|
| 1407 |
+
raise RuntimeError("Pin memory thread exited unexpectedly")
|
| 1408 |
+
# In this case, `self._data_queue` is a `queue.Queue`. But we don't
|
| 1409 |
+
# need to call `.task_done()` because we don't use `.join()`.
|
| 1410 |
+
else:
|
| 1411 |
+
while True:
|
| 1412 |
+
success, data = self._try_get_data()
|
| 1413 |
+
if success:
|
| 1414 |
+
return data
|
| 1415 |
+
|
| 1416 |
+
def _next_data(self):
|
| 1417 |
+
while True:
|
| 1418 |
+
# If the worker responsible for `self._rcvd_idx` has already ended
|
| 1419 |
+
# and was unable to fulfill this task (due to exhausting an `IterableDataset`),
|
| 1420 |
+
# we try to advance `self._rcvd_idx` to find the next valid index.
|
| 1421 |
+
#
|
| 1422 |
+
# This part needs to run in the loop because both the `self._get_data()`
|
| 1423 |
+
# call and `_IterableDatasetStopIteration` check below can mark
|
| 1424 |
+
# extra worker(s) as dead.
|
| 1425 |
+
while self._rcvd_idx < self._send_idx:
|
| 1426 |
+
info = self._task_info[self._rcvd_idx]
|
| 1427 |
+
worker_id = info[0]
|
| 1428 |
+
if (
|
| 1429 |
+
len(info) == 2 or self._workers_status[worker_id]
|
| 1430 |
+
): # has data or is still active
|
| 1431 |
+
break
|
| 1432 |
+
del self._task_info[self._rcvd_idx]
|
| 1433 |
+
self._rcvd_idx += 1
|
| 1434 |
+
else:
|
| 1435 |
+
# no valid `self._rcvd_idx` is found (i.e., didn't break)
|
| 1436 |
+
if not self._persistent_workers:
|
| 1437 |
+
self._shutdown_workers()
|
| 1438 |
+
raise StopIteration
|
| 1439 |
+
|
| 1440 |
+
# Now `self._rcvd_idx` is the batch index we want to fetch
|
| 1441 |
+
|
| 1442 |
+
# Check if the next sample has already been generated
|
| 1443 |
+
if len(self._task_info[self._rcvd_idx]) == 2:
|
| 1444 |
+
data = self._task_info.pop(self._rcvd_idx)[1]
|
| 1445 |
+
return self._process_data(data)
|
| 1446 |
+
|
| 1447 |
+
assert not self._shutdown and self._tasks_outstanding > 0
|
| 1448 |
+
idx, data = self._get_data()
|
| 1449 |
+
self._tasks_outstanding -= 1
|
| 1450 |
+
if self._dataset_kind == _DatasetKind.Iterable:
|
| 1451 |
+
# Check for _IterableDatasetStopIteration
|
| 1452 |
+
if isinstance(data, _utils.worker._IterableDatasetStopIteration):
|
| 1453 |
+
if self._persistent_workers:
|
| 1454 |
+
self._workers_status[data.worker_id] = False
|
| 1455 |
+
else:
|
| 1456 |
+
self._mark_worker_as_unavailable(data.worker_id)
|
| 1457 |
+
self._try_put_index()
|
| 1458 |
+
continue
|
| 1459 |
+
|
| 1460 |
+
if idx != self._rcvd_idx:
|
| 1461 |
+
# store out-of-order samples
|
| 1462 |
+
self._task_info[idx] += (data,)
|
| 1463 |
+
else:
|
| 1464 |
+
del self._task_info[idx]
|
| 1465 |
+
return self._process_data(data)
|
| 1466 |
+
|
| 1467 |
+
def _try_put_index(self):
|
| 1468 |
+
assert self._tasks_outstanding < self._prefetch_factor * self._num_workers
|
| 1469 |
+
|
| 1470 |
+
try:
|
| 1471 |
+
index = self._next_index()
|
| 1472 |
+
except StopIteration:
|
| 1473 |
+
return
|
| 1474 |
+
for _ in range(self._num_workers): # find the next active worker, if any
|
| 1475 |
+
worker_queue_idx = next(self._worker_queue_idx_cycle)
|
| 1476 |
+
if self._workers_status[worker_queue_idx]:
|
| 1477 |
+
break
|
| 1478 |
+
else:
|
| 1479 |
+
# not found (i.e., didn't break)
|
| 1480 |
+
return
|
| 1481 |
+
|
| 1482 |
+
self._index_queues[worker_queue_idx].put((self._send_idx, index)) # type: ignore[possibly-undefined]
|
| 1483 |
+
self._task_info[self._send_idx] = (worker_queue_idx,)
|
| 1484 |
+
self._tasks_outstanding += 1
|
| 1485 |
+
self._send_idx += 1
|
| 1486 |
+
|
| 1487 |
+
def _process_data(self, data):
|
| 1488 |
+
self._rcvd_idx += 1
|
| 1489 |
+
self._try_put_index()
|
| 1490 |
+
if isinstance(data, ExceptionWrapper):
|
| 1491 |
+
data.reraise()
|
| 1492 |
+
return data
|
| 1493 |
+
|
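# A hedged worked example of the bookkeeping above (illustrative numbers):
# with `num_workers=2` and `prefetch_factor=2`, priming in `_reset` leaves
#     _send_idx == 4, _rcvd_idx == 0, _tasks_outstanding == 4,
#     _task_info == {0: (0,), 1: (1,), 2: (0,), 3: (1,)}
# assuming the sampler can supply at least four indices and both workers are
# active. Each `_process_data` call then advances `_rcvd_idx` and calls
# `_try_put_index`, so `_tasks_outstanding` never exceeds
# `prefetch_factor * num_workers`.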
| 1494 |
+
def _mark_worker_as_unavailable(self, worker_id, shutdown=False):
|
| 1495 |
+
# Mark a worker as having finished its work e.g., due to
|
| 1496 |
+
# exhausting an `IterableDataset`. This should be used only when this
|
| 1497 |
+
# `_MultiProcessingDataLoaderIter` is going to continue running.
|
| 1498 |
+
|
| 1499 |
+
assert self._workers_status[worker_id] or (
|
| 1500 |
+
self._persistent_workers and shutdown
|
| 1501 |
+
)
|
| 1502 |
+
|
| 1503 |
+
# Signal termination to that specific worker.
|
| 1504 |
+
q = self._index_queues[worker_id]
|
| 1505 |
+
# Indicate that no more data will be put on this queue by the current
|
| 1506 |
+
# process.
|
| 1507 |
+
q.put(None)
|
| 1508 |
+
|
| 1509 |
+
# Note that we don't actually join the worker here, nor do we remove the
|
| 1510 |
+
# worker's pid from C side struct because (1) joining may be slow, and
|
| 1511 |
+
# (2) since we don't join, the worker may still raise error, and we
|
| 1512 |
+
# prefer capturing those, rather than ignoring them, even though they
|
| 1513 |
+
# are raised after the worker has finished its job.
|
| 1514 |
+
# Joining is deferred to `_shutdown_workers`, which is called when
|
| 1515 |
+
# all workers finish their jobs (e.g., `IterableDataset` replicas) or
|
| 1516 |
+
# when this iterator is garbage collected.
|
| 1517 |
+
|
| 1518 |
+
self._workers_status[worker_id] = False
|
| 1519 |
+
|
| 1520 |
+
assert self._workers_done_event.is_set() == shutdown
|
| 1521 |
+
|
| 1522 |
+
def _shutdown_workers(self):
|
| 1523 |
+
# Called when shutting down this `_MultiProcessingDataLoaderIter`.
|
| 1524 |
+
# See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
|
| 1525 |
+
# the logic of this function.
|
| 1526 |
+
if (
|
| 1527 |
+
_utils is None
|
| 1528 |
+
or _utils.python_exit_status is True
|
| 1529 |
+
or _utils.python_exit_status is None
|
| 1530 |
+
):
|
| 1531 |
+
# See (2) of the note. If Python is shutting down, do no-op.
|
| 1532 |
+
return
|
| 1533 |
+
# Normal exit when last reference is gone / iterator is depleted.
|
| 1534 |
+
# See (1) and the second half of the note.
|
| 1535 |
+
if not self._shutdown:
|
| 1536 |
+
self._shutdown = True
|
| 1537 |
+
try:
|
| 1538 |
+
# Normal exit when last reference is gone / iterator is depleted.
|
| 1539 |
+
# See (1) and the second half of the note.
|
| 1540 |
+
|
| 1541 |
+
# Exit `pin_memory_thread` first because exiting workers may leave
|
| 1542 |
+
# corrupted data in `worker_result_queue` which `pin_memory_thread`
|
| 1543 |
+
# reads from.
|
| 1544 |
+
if hasattr(self, "_pin_memory_thread"):
|
| 1545 |
+
# Use hasattr in case error happens before we set the attribute.
|
| 1546 |
+
self._pin_memory_thread_done_event.set()
|
| 1547 |
+
# Send something to pin_memory_thread in case it is waiting
|
| 1548 |
+
# so that it can wake up and check `pin_memory_thread_done_event`
|
| 1549 |
+
self._worker_result_queue.put((None, None))
|
| 1550 |
+
self._pin_memory_thread.join()
|
| 1551 |
+
self._worker_result_queue.cancel_join_thread()
|
| 1552 |
+
self._worker_result_queue.close()
|
| 1553 |
+
|
| 1554 |
+
# Exit workers now.
|
| 1555 |
+
self._workers_done_event.set()
|
| 1556 |
+
for worker_id in range(len(self._workers)):
|
| 1557 |
+
# Get number of workers from `len(self._workers)` instead of
|
| 1558 |
+
# `self._num_workers` in case we error before starting all
|
| 1559 |
+
# workers.
|
| 1560 |
+
# If we are using workers_status with persistent_workers
|
| 1561 |
+
# we have to shut it down because the worker is paused
|
| 1562 |
+
if self._persistent_workers or self._workers_status[worker_id]:
|
| 1563 |
+
self._mark_worker_as_unavailable(worker_id, shutdown=True)
|
| 1564 |
+
for w in self._workers:
|
| 1565 |
+
# We should be able to join here, but in case anything went
|
| 1566 |
+
# wrong, we set a timeout and if the workers fail to join,
|
| 1567 |
+
# they are killed in the `finally` block.
|
| 1568 |
+
w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
|
| 1569 |
+
for q in self._index_queues:
|
| 1570 |
+
q.cancel_join_thread()
|
| 1571 |
+
q.close()
|
| 1572 |
+
finally:
|
| 1573 |
+
# Even though all this function does is putting into queues that
|
| 1574 |
+
# we have called `cancel_join_thread` on, weird things can
|
| 1575 |
+
# happen when a worker is killed by a signal, e.g., hanging in
|
| 1576 |
+
# `Event.set()`. So we need to guard this with SIGCHLD handler,
|
| 1577 |
+
# and remove pids from the C side data structure only at the
|
| 1578 |
+
# end.
|
| 1579 |
+
#
|
| 1580 |
+
# FIXME: Unfortunately, for Windows, we are missing a worker
|
| 1581 |
+
# error detection mechanism here in this function, as it
|
| 1582 |
+
# doesn't provide a SIGCHLD handler.
|
| 1583 |
+
if self._worker_pids_set:
|
| 1584 |
+
_utils.signal_handling._remove_worker_pids(id(self))
|
| 1585 |
+
self._worker_pids_set = False
|
| 1586 |
+
for w in self._workers:
|
| 1587 |
+
if w.is_alive():
|
| 1588 |
+
# Existing mechanisms try to make the workers exit
|
| 1589 |
+
# peacefully, but in case that we unfortunately reach
|
| 1590 |
+
# here, which we shouldn't, (e.g., pytorch/pytorch#39570),
|
| 1591 |
+
# we kill the worker.
|
| 1592 |
+
w.terminate()
|
| 1593 |
+
|
| 1594 |
+
# staticmethod is used to remove reference to `_MultiProcessingDataLoaderIter`
|
| 1595 |
+
@staticmethod
|
| 1596 |
+
def _clean_up_worker(w):
|
| 1597 |
+
try:
|
| 1598 |
+
w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
|
| 1599 |
+
finally:
|
| 1600 |
+
if w.is_alive():
|
| 1601 |
+
w.terminate()
|
| 1602 |
+
|
| 1603 |
+
def __del__(self):
|
| 1604 |
+
self._shutdown_workers()
|
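The warning text near the top of this dataloader section, about IterableDataset length mismatches under multiprocessing, points at per-worker replica configuration. Below is a minimal sketch of that configuration; `RangeIterable` is a hypothetical dataset used only for illustration and is not part of the files in this diff.

import math

from torch.utils.data import DataLoader, IterableDataset, get_worker_info


class RangeIterable(IterableDataset):
    # Hypothetical iterable dataset over [start, end); illustration only.
    def __init__(self, start: int, end: int):
        self.start, self.end = start, end

    def __iter__(self):
        info = get_worker_info()
        if info is None:
            # Single-process loading: iterate over the full range.
            lo, hi = self.start, self.end
        else:
            # Multi-process loading: keep only this worker's shard so the
            # replicas do not all yield the same data.
            per_worker = math.ceil((self.end - self.start) / info.num_workers)
            lo = self.start + info.id * per_worker
            hi = min(lo + per_worker, self.end)
        return iter(range(lo, hi))


if __name__ == "__main__":
    loader = DataLoader(RangeIterable(0, 8), num_workers=2, batch_size=2)
    print(list(loader))  # each worker contributes a disjoint shard
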
.venv/Lib/site-packages/torch/utils/data/datapipes/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
| 1 |
+
from torch.utils.data.datapipes import dataframe as dataframe, iter as iter, map as map
|
.venv/Lib/site-packages/torch/utils/data/datapipes/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (287 Bytes).
|
|
|
.venv/Lib/site-packages/torch/utils/data/datapipes/__pycache__/_decorator.cpython-39.pyc
ADDED
|
Binary file (6.2 kB).
|
|
|
.venv/Lib/site-packages/torch/utils/data/datapipes/__pycache__/_hook_iterator.cpython-39.pyc
ADDED
|
Binary file (8.45 kB).
|
|
|
.venv/Lib/site-packages/torch/utils/data/datapipes/__pycache__/_typing.cpython-39.pyc
ADDED
|
Binary file (11.2 kB).
|
|
|
.venv/Lib/site-packages/torch/utils/data/datapipes/__pycache__/datapipe.cpython-39.pyc
ADDED
|
Binary file (16.7 kB).
|
|
|
.venv/Lib/site-packages/torch/utils/data/datapipes/_decorator.py
ADDED
|
@@ -0,0 +1,213 @@
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import inspect
|
| 3 |
+
from functools import wraps
|
| 4 |
+
from typing import Any, Callable, get_type_hints, Optional, Type, Union
|
| 5 |
+
|
| 6 |
+
from torch.utils.data.datapipes._typing import _DataPipeMeta
|
| 7 |
+
from torch.utils.data.datapipes.datapipe import IterDataPipe, MapDataPipe
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
######################################################
|
| 11 |
+
# Functional API
|
| 12 |
+
######################################################
|
| 13 |
+
class functional_datapipe:
|
| 14 |
+
name: str
|
| 15 |
+
|
| 16 |
+
def __init__(self, name: str, enable_df_api_tracing=False) -> None:
|
| 17 |
+
"""
|
| 18 |
+
Define a functional datapipe.
|
| 19 |
+
|
| 20 |
+
Args:
|
| 21 |
+
enable_df_api_tracing - if set, any returned DataPipe would accept
|
| 22 |
+
DataFrames API in tracing mode.
|
| 23 |
+
"""
|
| 24 |
+
self.name = name
|
| 25 |
+
self.enable_df_api_tracing = enable_df_api_tracing
|
| 26 |
+
|
| 27 |
+
def __call__(self, cls):
|
| 28 |
+
if issubclass(cls, IterDataPipe):
|
| 29 |
+
if isinstance(cls, Type): # type: ignore[arg-type]
|
| 30 |
+
if not isinstance(cls, _DataPipeMeta):
|
| 31 |
+
raise TypeError(
|
| 32 |
+
"`functional_datapipe` can only decorate IterDataPipe"
|
| 33 |
+
)
|
| 34 |
+
# with non_deterministic decorator
|
| 35 |
+
else:
|
| 36 |
+
if not isinstance(cls, non_deterministic) and not (
|
| 37 |
+
hasattr(cls, "__self__")
|
| 38 |
+
and isinstance(cls.__self__, non_deterministic)
|
| 39 |
+
):
|
| 40 |
+
raise TypeError(
|
| 41 |
+
"`functional_datapipe` can only decorate IterDataPipe"
|
| 42 |
+
)
|
| 43 |
+
IterDataPipe.register_datapipe_as_function(
|
| 44 |
+
self.name, cls, enable_df_api_tracing=self.enable_df_api_tracing
|
| 45 |
+
)
|
| 46 |
+
elif issubclass(cls, MapDataPipe):
|
| 47 |
+
MapDataPipe.register_datapipe_as_function(self.name, cls)
|
| 48 |
+
|
| 49 |
+
return cls
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
######################################################
|
| 53 |
+
# Determinism
|
| 54 |
+
######################################################
|
| 55 |
+
_determinism: bool = False
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class guaranteed_datapipes_determinism:
|
| 59 |
+
prev: bool
|
| 60 |
+
|
| 61 |
+
def __init__(self) -> None:
|
| 62 |
+
global _determinism
|
| 63 |
+
self.prev = _determinism
|
| 64 |
+
_determinism = True
|
| 65 |
+
|
| 66 |
+
def __enter__(self) -> None:
|
| 67 |
+
pass
|
| 68 |
+
|
| 69 |
+
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
| 70 |
+
global _determinism
|
| 71 |
+
_determinism = self.prev
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class non_deterministic:
|
| 75 |
+
cls: Optional[Type[IterDataPipe]] = None
|
| 76 |
+
# TODO: Lambda for picking
|
| 77 |
+
deterministic_fn: Callable[[], bool]
|
| 78 |
+
|
| 79 |
+
def __init__(self, arg: Union[Type[IterDataPipe], Callable[[], bool]]) -> None:
|
| 80 |
+
# 1. Decorator doesn't have any argument
|
| 81 |
+
if isinstance(arg, Type): # type: ignore[arg-type]
|
| 82 |
+
if not issubclass(arg, IterDataPipe): # type: ignore[arg-type]
|
| 83 |
+
raise TypeError(
|
| 84 |
+
"Only `IterDataPipe` can be decorated with `non_deterministic`"
|
| 85 |
+
f", but {arg.__name__} is found"
|
| 86 |
+
)
|
| 87 |
+
self.cls = arg # type: ignore[assignment]
|
| 88 |
+
# 2. Decorator has an argument of a function
|
| 89 |
+
# This class should behave differently given different inputs. Use this
|
| 90 |
+
# function to verify the determinism for each instance.
|
| 91 |
+
# When the function returns True, the instance is non-deterministic. Otherwise,
|
| 92 |
+
# the instance is a deterministic DataPipe.
|
| 93 |
+
elif isinstance(arg, Callable): # type:ignore[arg-type]
|
| 94 |
+
self.deterministic_fn = arg # type: ignore[assignment, misc]
|
| 95 |
+
else:
|
| 96 |
+
raise TypeError(f"{arg} can not be decorated by non_deterministic")
|
| 97 |
+
|
| 98 |
+
def __call__(self, *args, **kwargs):
|
| 99 |
+
global _determinism
|
| 100 |
+
# Decorate IterDataPipe
|
| 101 |
+
if self.cls is not None:
|
| 102 |
+
if _determinism:
|
| 103 |
+
raise TypeError(
|
| 104 |
+
f"{self.cls.__name__} is non-deterministic, but you set 'guaranteed_datapipes_determinism'. "
|
| 105 |
+
"You can turn off determinism for this DataPipe if that is acceptable "
|
| 106 |
+
"for your application"
|
| 107 |
+
)
|
| 108 |
+
return self.cls(*args, **kwargs) # type: ignore[call-arg]
|
| 109 |
+
|
| 110 |
+
# Decorate with a functional argument
|
| 111 |
+
if not (
|
| 112 |
+
isinstance(args[0], type)
|
| 113 |
+
and issubclass(args[0], IterDataPipe) # type: ignore[arg-type]
|
| 114 |
+
):
|
| 115 |
+
raise TypeError(
|
| 116 |
+
f"Only `IterDataPipe` can be decorated, but {args[0].__name__} is found"
|
| 117 |
+
)
|
| 118 |
+
self.cls = args[0]
|
| 119 |
+
return self.deterministic_wrapper_fn
|
| 120 |
+
|
| 121 |
+
def deterministic_wrapper_fn(self, *args, **kwargs) -> IterDataPipe:
|
| 122 |
+
res = self.deterministic_fn(*args, **kwargs) # type: ignore[call-arg, misc]
|
| 123 |
+
if not isinstance(res, bool):
|
| 124 |
+
raise TypeError(
|
| 125 |
+
"deterministic_fn of `non_deterministic` decorator is required "
|
| 126 |
+
f"to return a boolean value, but {type(res)} is found"
|
| 127 |
+
)
|
| 128 |
+
global _determinism
|
| 129 |
+
if _determinism and res:
|
| 130 |
+
raise TypeError(
|
| 131 |
+
f"{self.cls.__name__} is non-deterministic with the inputs, but you set " # type: ignore[union-attr]
|
| 132 |
+
"'guaranteed_datapipes_determinism'. You can turn off determinism "
|
| 133 |
+
"for this DataPipe if that is acceptable for your application"
|
| 134 |
+
)
|
| 135 |
+
return self.cls(*args, **kwargs) # type: ignore[call-arg, misc]
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
######################################################
|
| 139 |
+
# Type validation
|
| 140 |
+
######################################################
|
| 141 |
+
# Validate each argument of DataPipe with hint as a subtype of the hint.
|
| 142 |
+
def argument_validation(f):
|
| 143 |
+
signature = inspect.signature(f)
|
| 144 |
+
hints = get_type_hints(f)
|
| 145 |
+
|
| 146 |
+
@wraps(f)
|
| 147 |
+
def wrapper(*args, **kwargs):
|
| 148 |
+
bound = signature.bind(*args, **kwargs)
|
| 149 |
+
for argument_name, value in bound.arguments.items():
|
| 150 |
+
if argument_name in hints and isinstance(
|
| 151 |
+
hints[argument_name], _DataPipeMeta
|
| 152 |
+
):
|
| 153 |
+
hint = hints[argument_name]
|
| 154 |
+
if not isinstance(value, IterDataPipe):
|
| 155 |
+
raise TypeError(
|
| 156 |
+
f"Expected argument '{argument_name}' as a IterDataPipe, but found {type(value)}"
|
| 157 |
+
)
|
| 158 |
+
if not value.type.issubtype(hint.type):
|
| 159 |
+
raise TypeError(
|
| 160 |
+
f"Expected type of argument '{argument_name}' as a subtype of "
|
| 161 |
+
f"hint {hint.type}, but found {value.type}"
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
return f(*args, **kwargs)
|
| 165 |
+
|
| 166 |
+
return wrapper
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
# Default value is True
|
| 170 |
+
_runtime_validation_enabled: bool = True
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class runtime_validation_disabled:
|
| 174 |
+
prev: bool
|
| 175 |
+
|
| 176 |
+
def __init__(self) -> None:
|
| 177 |
+
global _runtime_validation_enabled
|
| 178 |
+
self.prev = _runtime_validation_enabled
|
| 179 |
+
_runtime_validation_enabled = False
|
| 180 |
+
|
| 181 |
+
def __enter__(self) -> None:
|
| 182 |
+
pass
|
| 183 |
+
|
| 184 |
+
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
| 185 |
+
global _runtime_validation_enabled
|
| 186 |
+
_runtime_validation_enabled = self.prev
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
# Runtime checking
|
| 190 |
+
# Validate output data is subtype of return hint
|
| 191 |
+
def runtime_validation(f):
|
| 192 |
+
# TODO:
|
| 193 |
+
# Can be extended to validate '__getitem__' and nonblocking
|
| 194 |
+
if f.__name__ != "__iter__":
|
| 195 |
+
raise TypeError(
|
| 196 |
+
f"Can not decorate function {f.__name__} with 'runtime_validation'"
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
@wraps(f)
|
| 200 |
+
def wrapper(self):
|
| 201 |
+
global _runtime_validation_enabled
|
| 202 |
+
if not _runtime_validation_enabled:
|
| 203 |
+
yield from f(self)
|
| 204 |
+
else:
|
| 205 |
+
it = f(self)
|
| 206 |
+
for d in it:
|
| 207 |
+
if not self.type.issubtype_of_instance(d):
|
| 208 |
+
raise RuntimeError(
|
| 209 |
+
f"Expected an instance as subtype of {self.type}, but found {d}({type(d)})"
|
| 210 |
+
)
|
| 211 |
+
yield d
|
| 212 |
+
|
| 213 |
+
return wrapper
|
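A brief usage sketch for the decorator module above; `DoublerIterDataPipe` and the registered name `double` are hypothetical and only illustrate how `functional_datapipe` exposes a functional form on `IterDataPipe`.

from torch.utils.data import IterDataPipe, functional_datapipe
from torch.utils.data.datapipes.iter import IterableWrapper


@functional_datapipe("double")
class DoublerIterDataPipe(IterDataPipe):
    # Hypothetical example pipe, not part of the module above.
    def __init__(self, source_datapipe):
        self.source_datapipe = source_datapipe

    def __iter__(self):
        for x in self.source_datapipe:
            yield 2 * x


# Registration exposes the functional form on every IterDataPipe instance:
dp = IterableWrapper([1, 2, 3]).double()
print(list(dp))  # [2, 4, 6]
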
.venv/Lib/site-packages/torch/utils/data/datapipes/_hook_iterator.py
ADDED
|
@@ -0,0 +1,279 @@
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import functools
|
| 3 |
+
import inspect
|
| 4 |
+
from enum import Enum
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class _SnapshotState(Enum):
|
| 10 |
+
r"""
|
| 11 |
+
These are the snapshotting-related states that IterDataPipes can be in.
|
| 12 |
+
|
| 13 |
+
`NotStarted` - allows you to restore a snapshot and create an iterator with reset
|
| 14 |
+
`Restored` - cannot restore again, allows you to create an iterator without resetting the DataPipe
|
| 15 |
+
`Iterating` - can restore, will reset if you create a new iterator
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
NotStarted = 0
|
| 19 |
+
Restored = 1
|
| 20 |
+
Iterating = 2
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _simplify_obj_name(obj) -> str:
|
| 24 |
+
"""Simplify the display strings of objects for the purpose of rendering within DataPipe error messages."""
|
| 25 |
+
if inspect.isfunction(obj):
|
| 26 |
+
return obj.__name__
|
| 27 |
+
else:
|
| 28 |
+
return repr(obj)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _strip_datapipe_from_name(name: str) -> str:
|
| 32 |
+
return name.replace("IterDataPipe", "").replace("MapDataPipe", "")
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _generate_input_args_string(obj):
|
| 36 |
+
"""Generate a string for the input arguments of an object."""
|
| 37 |
+
signature = inspect.signature(obj.__class__)
|
| 38 |
+
input_param_names = set(signature.parameters.keys())
|
| 39 |
+
result = []
|
| 40 |
+
for name, value in inspect.getmembers(obj):
|
| 41 |
+
if name in input_param_names:
|
| 42 |
+
result.append((name, _simplify_obj_name(value)))
|
| 43 |
+
return ", ".join([f"{name}={value}" for name, value in result])
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _generate_iterdatapipe_msg(datapipe, simplify_dp_name: bool = False):
|
| 47 |
+
output_string = (
|
| 48 |
+
f"{datapipe.__class__.__name__}({_generate_input_args_string(datapipe)})"
|
| 49 |
+
)
|
| 50 |
+
if simplify_dp_name:
|
| 51 |
+
output_string = _strip_datapipe_from_name(output_string)
|
| 52 |
+
return output_string
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _gen_invalid_iterdatapipe_msg(datapipe):
|
| 56 |
+
return (
|
| 57 |
+
"This iterator has been invalidated because another iterator has been created "
|
| 58 |
+
f"from the same IterDataPipe: {_generate_iterdatapipe_msg(datapipe)}\n"
|
| 59 |
+
"This may be caused multiple references to the same IterDataPipe. We recommend "
|
| 60 |
+
"using `.fork()` if that is necessary."
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
_feedback_msg = (
|
| 65 |
+
"\nFor feedback regarding this single iterator per IterDataPipe constraint, feel free "
|
| 66 |
+
"to comment on this issue: https://github.com/pytorch/data/issues/45."
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _check_iterator_valid(datapipe, iterator_id, next_method_exists=False) -> None:
|
| 71 |
+
r"""
|
| 72 |
+
Given an instance of a DataPipe and an iterator ID, check if the IDs match, and if not, raises an exception.
|
| 73 |
+
|
| 74 |
+
In the case of ChildDataPipe, the ID gets compared to the one stored in `main_datapipe` as well.
|
| 75 |
+
"""
|
| 76 |
+
if next_method_exists:
|
| 77 |
+
# This is the case where `IterDataPipe` has both `__iter__` and `__next__`.
|
| 78 |
+
# The `_valid_iterator_id` should either be never set (`None`), or set by at most one
|
| 79 |
+
# iterator (`0`). Otherwise, it means there are multiple iterators.
|
| 80 |
+
if datapipe._valid_iterator_id is not None and datapipe._valid_iterator_id != 0:
|
| 81 |
+
extra_msg = "\nNote that this exception is raised inside your IterDataPipe's a `__next__` method"
|
| 82 |
+
raise RuntimeError(
|
| 83 |
+
_gen_invalid_iterdatapipe_msg(datapipe) + extra_msg + _feedback_msg
|
| 84 |
+
)
|
| 85 |
+
elif (
|
| 86 |
+
hasattr(datapipe, "_is_child_datapipe") and datapipe._is_child_datapipe is True
|
| 87 |
+
):
|
| 88 |
+
if hasattr(datapipe, "_check_valid_iterator_id"):
|
| 89 |
+
if not datapipe._check_valid_iterator_id(iterator_id):
|
| 90 |
+
raise RuntimeError(
|
| 91 |
+
"This iterator has been invalidated, because a new iterator has been created "
|
| 92 |
+
f"from one of the ChildDataPipes of "
|
| 93 |
+
f"{_generate_iterdatapipe_msg(datapipe.main_datapipe)}."
|
| 94 |
+
+ _feedback_msg
|
| 95 |
+
)
|
| 96 |
+
else:
|
| 97 |
+
raise RuntimeError(
|
| 98 |
+
"ChildDataPipe must have method `_check_valid_iterator_id`."
|
| 99 |
+
)
|
| 100 |
+
elif datapipe._valid_iterator_id != iterator_id:
|
| 101 |
+
raise RuntimeError(_gen_invalid_iterdatapipe_msg(datapipe) + _feedback_msg)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def _set_datapipe_valid_iterator_id(datapipe):
|
| 105 |
+
"""Given a DataPipe, updates its valid iterator ID and reset the DataPipe."""
|
| 106 |
+
if hasattr(datapipe, "_is_child_datapipe") and datapipe._is_child_datapipe is True:
|
| 107 |
+
if hasattr(datapipe, "_set_main_datapipe_valid_iterator_id"):
|
| 108 |
+
datapipe._set_main_datapipe_valid_iterator_id() # reset() is called within this method when appropriate
|
| 109 |
+
else:
|
| 110 |
+
raise RuntimeError(
|
| 111 |
+
"ChildDataPipe must have method `_set_main_datapipe_valid_iterator_id`."
|
| 112 |
+
)
|
| 113 |
+
else:
|
| 114 |
+
if datapipe._valid_iterator_id is None:
|
| 115 |
+
datapipe._valid_iterator_id = 0
|
| 116 |
+
else:
|
| 117 |
+
datapipe._valid_iterator_id += 1
|
| 118 |
+
datapipe.reset()
|
| 119 |
+
return datapipe._valid_iterator_id
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def hook_iterator(namespace):
|
| 123 |
+
r"""
|
| 124 |
+
Define a hook that is applied to all `__iter__` of metaclass `_DataPipeMeta`.
|
| 125 |
+
|
| 126 |
+
This is done for the purpose of profiling and checking if an iterator is still valid.
|
| 127 |
+
"""
|
| 128 |
+
|
| 129 |
+
def profiler_record_fn_context(datapipe):
|
| 130 |
+
if not hasattr(datapipe, "_profile_name"):
|
| 131 |
+
datapipe._profile_name = _generate_iterdatapipe_msg(
|
| 132 |
+
datapipe, simplify_dp_name=True
|
| 133 |
+
)
|
| 134 |
+
return torch.autograd.profiler.record_function(datapipe._profile_name)
|
| 135 |
+
|
| 136 |
+
class IteratorDecorator:
|
| 137 |
+
r"""
|
| 138 |
+
Wrap the iterator and modifying its `__next__` method.
|
| 139 |
+
|
| 140 |
+
This decorator is applied to DataPipes of which `__iter__` method is NOT a generator function.
|
| 141 |
+
Those `__iter__` method commonly returns `self` but not necessarily.
|
| 142 |
+
"""
|
| 143 |
+
|
| 144 |
+
def __init__(self, iterator, datapipe, iterator_id, has_next_method):
|
| 145 |
+
self.iterator = iterator
|
| 146 |
+
self.datapipe = datapipe
|
| 147 |
+
self.iterator_id = iterator_id
|
| 148 |
+
self._profiler_enabled = torch.autograd._profiler_enabled()
|
| 149 |
+
# Check if `__iter__` returns `self` and `DataPipe` has `__next__`
|
| 150 |
+
self.self_and_has_next_method = (
|
| 151 |
+
self.iterator is self.datapipe and has_next_method
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
def __iter__(self):
|
| 155 |
+
return self
|
| 156 |
+
|
| 157 |
+
def _get_next(self):
|
| 158 |
+
"""Return next with logic related to iterator validity, profiler, and incrementation of samples yielded."""
|
| 159 |
+
_check_iterator_valid(self.datapipe, self.iterator_id)
|
| 160 |
+
result = next(self.iterator)
|
| 161 |
+
if not self.self_and_has_next_method:
|
| 162 |
+
self.datapipe._number_of_samples_yielded += 1
|
| 163 |
+
return result
|
| 164 |
+
|
| 165 |
+
def __next__(self):
|
| 166 |
+
# TODO: Add try-except to in-place reduce traceback from the Exception
|
| 167 |
+
# See: https://github.com/pytorch/data/issues/284
|
| 168 |
+
if self._profiler_enabled:
|
| 169 |
+
with profiler_record_fn_context(self.datapipe):
|
| 170 |
+
return self._get_next()
|
| 171 |
+
else: # Decided against using `contextlib.nullcontext` for performance reasons
|
| 172 |
+
return self._get_next()
|
| 173 |
+
|
| 174 |
+
def __getattr__(self, name):
|
| 175 |
+
return getattr(self.iterator, name)
|
| 176 |
+
|
| 177 |
+
func = namespace["__iter__"]
|
| 178 |
+
|
| 179 |
+
# ``__iter__`` of IterDataPipe is a generator function
|
| 180 |
+
if inspect.isgeneratorfunction(func):
|
| 181 |
+
|
| 182 |
+
@functools.wraps(func)
|
| 183 |
+
def wrap_generator(*args, **kwargs):
|
| 184 |
+
gen = func(*args, **kwargs)
|
| 185 |
+
datapipe = args[0]
|
| 186 |
+
if datapipe._fast_forward_iterator:
|
| 187 |
+
it = datapipe._fast_forward_iterator
|
| 188 |
+
datapipe._fast_forward_iterator = None
|
| 189 |
+
datapipe._snapshot_state = _SnapshotState.Iterating
|
| 190 |
+
while True:
|
| 191 |
+
try:
|
| 192 |
+
yield next(it)
|
| 193 |
+
except StopIteration:
|
| 194 |
+
return
|
| 195 |
+
iterator_id = _set_datapipe_valid_iterator_id(
|
| 196 |
+
datapipe
|
| 197 |
+
) # This ID is tied to each created iterator
|
| 198 |
+
_profiler_enabled = torch.autograd._profiler_enabled()
|
| 199 |
+
try:
|
| 200 |
+
if _profiler_enabled:
|
| 201 |
+
with profiler_record_fn_context(datapipe):
|
| 202 |
+
response = gen.send(None)
|
| 203 |
+
else:
|
| 204 |
+
response = gen.send(None)
|
| 205 |
+
|
| 206 |
+
while True:
|
| 207 |
+
datapipe._number_of_samples_yielded += 1
|
| 208 |
+
request = yield response
|
| 209 |
+
# Pass through here every time `__next__` is called
|
| 210 |
+
if _profiler_enabled:
|
| 211 |
+
with profiler_record_fn_context(datapipe):
|
| 212 |
+
_check_iterator_valid(datapipe, iterator_id)
|
| 213 |
+
response = gen.send(request)
|
| 214 |
+
else: # Decided against using `contextlib.nullcontext` for performance reasons
|
| 215 |
+
_check_iterator_valid(datapipe, iterator_id)
|
| 216 |
+
response = gen.send(request)
|
| 217 |
+
except StopIteration as e:
|
| 218 |
+
return
|
| 219 |
+
except Exception as e:
|
| 220 |
+
# TODO: Simplify the traceback message to skip over `response = gen.send(None)`
|
| 221 |
+
# Part of https://github.com/pytorch/data/issues/284
|
| 222 |
+
datapipe = args[0]
|
| 223 |
+
msg = "thrown by __iter__ of"
|
| 224 |
+
single_iterator_msg = "single iterator per IterDataPipe constraint"
|
| 225 |
+
if hasattr(e.args, "__len__"):
|
| 226 |
+
full_msg = f"{msg} {datapipe.__class__.__name__}({_generate_input_args_string(datapipe)})"
|
| 227 |
+
if len(e.args) == 0 or not isinstance(
|
| 228 |
+
e.args[0], str
|
| 229 |
+
): # If an exception message doesn't exist
|
| 230 |
+
e.args = (f"\nThis exception is {full_msg}",)
|
| 231 |
+
elif msg not in e.args[0] and single_iterator_msg not in e.args[0]:
|
| 232 |
+
e.args = (
|
| 233 |
+
e.args[0] + f"\nThis exception is {full_msg}",
|
| 234 |
+
) + e.args[1:]
|
| 235 |
+
raise
|
| 236 |
+
|
| 237 |
+
namespace["__iter__"] = wrap_generator
|
| 238 |
+
else: # ``__iter__`` of IterDataPipe is NOT a generator function
|
| 239 |
+
# IterDataPipe is an iterator with both ``__iter__`` and ``__next__``
|
| 240 |
+
# And ``__iter__`` may or may not return `self`
|
| 241 |
+
if "__next__" in namespace: # If `__next__` exists, put a wrapper around it
|
| 242 |
+
next_func = namespace["__next__"]
|
| 243 |
+
|
| 244 |
+
@functools.wraps(next_func)
|
| 245 |
+
def wrap_next(*args, **kwargs):
|
| 246 |
+
datapipe = args[0]
|
| 247 |
+
if torch.autograd._profiler_enabled():
|
| 248 |
+
with profiler_record_fn_context(datapipe):
|
| 249 |
+
result = next_func(*args, **kwargs)
|
| 250 |
+
else:
|
| 251 |
+
result = next_func(*args, **kwargs)
|
| 252 |
+
datapipe._number_of_samples_yielded += 1
|
| 253 |
+
return result
|
| 254 |
+
|
| 255 |
+
namespace["__next__"] = wrap_next
|
| 256 |
+
|
| 257 |
+
# Note that if the `__next__` and `__iter__` do something completely unrelated. It may cause issue but
|
| 258 |
+
# the user will be violating the iterator protocol. Potential issue:
|
| 259 |
+
# 1. Valid iterator ID may not update or checked properly
|
| 260 |
+
# 2. The number of samples yielded will be miscounted
|
| 261 |
+
|
| 262 |
+
# Regardless if `__next__` exists or not, `__iter__` needs a wrapper to track the number of valid iterators
|
| 263 |
+
@functools.wraps(func)
|
| 264 |
+
def wrap_iter(*args, **kwargs):
|
| 265 |
+
iter_ret = func(*args, **kwargs)
|
| 266 |
+
datapipe = args[0]
|
| 267 |
+
datapipe._snapshot_state = _SnapshotState.Iterating
|
| 268 |
+
if datapipe._fast_forward_iterator:
|
| 269 |
+
iter_ret = datapipe._fast_forward_iterator
|
| 270 |
+
datapipe._fast_forward_iterator = None
|
| 271 |
+
return iter_ret
|
| 272 |
+
iterator_id = _set_datapipe_valid_iterator_id(
|
| 273 |
+
datapipe
|
| 274 |
+
) # This ID is tied to each created iterator
|
| 275 |
+
return IteratorDecorator(
|
| 276 |
+
iter_ret, datapipe, iterator_id, "__next__" in namespace
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
namespace["__iter__"] = wrap_iter
|
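
A minimal sketch of the behavior this hook enforces, assuming only a standard torch install (IterableWrapper is the stock IterDataPipe from torch.utils.data.datapipes.iter); it is illustration, not part of the diff above:

    # Creating a second iterator over the same IterDataPipe invalidates the first one,
    # because wrap_generator bumps the valid iterator ID and resets the DataPipe.
    from torch.utils.data.datapipes.iter import IterableWrapper

    dp = IterableWrapper(range(5))
    it1 = iter(dp)
    print(next(it1))  # 0
    it2 = iter(dp)    # new iterator ID; dp is reset
    print(next(it2))  # 0
    try:
        next(it1)     # stale iterator ID -> RuntimeError from _check_iterator_valid
    except RuntimeError as err:
        print("first iterator invalidated:", type(err).__name__)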
.venv/Lib/site-packages/torch/utils/data/datapipes/_typing.py
ADDED
@@ -0,0 +1,486 @@
# mypy: allow-untyped-defs
# Taking reference from official Python typing
# https://github.com/python/cpython/blob/master/Lib/typing.py

import collections
import functools
import numbers
import sys

# Please check [Note: TypeMeta and TypeAlias]
# In case of metaclass conflict due to ABCMeta or _ProtocolMeta
# For Python 3.9, only Protocol in typing uses metaclass
from abc import ABCMeta

# TODO: Use TypeAlias when Python 3.6 is deprecated
from typing import (  # type: ignore[attr-defined]
    _eval_type,
    _GenericAlias,
    _tp_cache,
    _type_check,
    _type_repr,
    Any,
    Dict,
    ForwardRef,
    Generic,
    get_type_hints,
    Iterator,
    List,
    Set,
    Tuple,
    TypeVar,
    Union,
)

from torch.utils.data.datapipes._hook_iterator import _SnapshotState, hook_iterator


class GenericMeta(ABCMeta):  # type: ignore[no-redef]
    pass


class Integer(numbers.Integral):
    pass


class Boolean(numbers.Integral):
    pass


# Python 'type' object is not subscriptable
# Tuple[int, List, dict] -> valid
# tuple[int, list, dict] -> invalid
# Map Python 'type' to abstract base class
TYPE2ABC = {
    bool: Boolean,
    int: Integer,
    float: numbers.Real,
    complex: numbers.Complex,
    dict: Dict,
    list: List,
    set: Set,
    tuple: Tuple,
    None: type(None),
}


def issubtype(left, right, recursive=True):
    r"""
    Check if the left-side type is a subtype of the right-side type.

    If any of type is a composite type like `Union` and `TypeVar` with
    bounds, it would be expanded into a list of types and check all
    of left-side types are subtypes of either one from right-side types.
    """
    left = TYPE2ABC.get(left, left)
    right = TYPE2ABC.get(right, right)

    if right is Any or left == right:
        return True

    if isinstance(right, _GenericAlias):
        if getattr(right, "__origin__", None) is Generic:
            return True

    if right == type(None):
        return False

    # Right-side type
    constraints = _decompose_type(right)

    if len(constraints) == 0 or Any in constraints:
        return True

    if left is Any:
        return False

    # Left-side type
    variants = _decompose_type(left)

    # all() will return True for empty variants
    if len(variants) == 0:
        return False

    return all(
        _issubtype_with_constraints(variant, constraints, recursive)
        for variant in variants
    )


def _decompose_type(t, to_list=True):
    if isinstance(t, TypeVar):
        if t.__bound__ is not None:
            ts = [t.__bound__]
        else:
            # For T_co, __constraints__ is ()
            ts = list(t.__constraints__)
    elif hasattr(t, "__origin__") and t.__origin__ == Union:
        ts = t.__args__
    else:
        if not to_list:
            return None
        ts = [t]
    # Ignored: Generator has incompatible item type "object"; expected "Type[Any]"
    ts = [TYPE2ABC.get(_t, _t) for _t in ts]  # type: ignore[misc]
    return ts


def _issubtype_with_constraints(variant, constraints, recursive=True):
    r"""
    Check if the variant is a subtype of either one from constraints.

    For composite types like `Union` and `TypeVar` with bounds, they
    would be expanded for testing.
    """
    if variant in constraints:
        return True

    # [Note: Subtype for Union and TypeVar]
    # Python typing is able to flatten Union[Union[...]] or Union[TypeVar].
    # But it couldn't flatten the following scenarios:
    #   - Union[int, TypeVar[Union[...]]]
    #   - TypeVar[TypeVar[...]]
    # So, variant and each constraint may be a TypeVar or a Union.
    # In these cases, all of inner types from the variant are required to be
    # extraced and verified as a subtype of any constraint. And, all of
    # inner types from any constraint being a TypeVar or a Union are
    # also required to be extracted and verified if the variant belongs to
    # any of them.

    # Variant
    vs = _decompose_type(variant, to_list=False)

    # Variant is TypeVar or Union
    if vs is not None:
        return all(_issubtype_with_constraints(v, constraints, recursive) for v in vs)

    # Variant is not TypeVar or Union
    if hasattr(variant, "__origin__") and variant.__origin__ is not None:
        v_origin = variant.__origin__
        # In Python-3.9 typing library untyped generics do not have args
        v_args = getattr(variant, "__args__", None)
    else:
        v_origin = variant
        v_args = None

    # Constraints
    for constraint in constraints:
        cs = _decompose_type(constraint, to_list=False)

        # Constraint is TypeVar or Union
        if cs is not None:
            if _issubtype_with_constraints(variant, cs, recursive):
                return True
        # Constraint is not TypeVar or Union
        else:
            # __origin__ can be None for plain list, tuple, ... in Python 3.6
            if hasattr(constraint, "__origin__") and constraint.__origin__ is not None:
                c_origin = constraint.__origin__
                if v_origin == c_origin:
                    if not recursive:
                        return True
                    # In Python-3.9 typing library untyped generics do not have args
                    c_args = getattr(constraint, "__args__", None)
                    if c_args is None or len(c_args) == 0:
                        return True
                    if (
                        v_args is not None
                        and len(v_args) == len(c_args)
                        and all(
                            issubtype(v_arg, c_arg)
                            for v_arg, c_arg in zip(v_args, c_args)
                        )
                    ):
                        return True
            # Tuple[int] -> Tuple
            else:
                if v_origin == constraint:
                    return True

    return False


def issubinstance(data, data_type):
    if not issubtype(type(data), data_type, recursive=False):
        return False

    # In Python-3.9 typing library __args__ attribute is not defined for untyped generics
    dt_args = getattr(data_type, "__args__", None)
    if isinstance(data, tuple):
        if dt_args is None or len(dt_args) == 0:
            return True
        if len(dt_args) != len(data):
            return False
        return all(issubinstance(d, t) for d, t in zip(data, dt_args))
    elif isinstance(data, (list, set)):
        if dt_args is None or len(dt_args) == 0:
            return True
        t = dt_args[0]
        return all(issubinstance(d, t) for d in data)
    elif isinstance(data, dict):
        if dt_args is None or len(dt_args) == 0:
            return True
        kt, vt = dt_args
        return all(
            issubinstance(k, kt) and issubinstance(v, vt) for k, v in data.items()
        )

    return True


# [Note: TypeMeta and TypeAlias]
# In order to keep compatibility for Python 3.6, use Meta for the typing.
# TODO: When PyTorch drops the support for Python 3.6, it can be converted
# into the Alias system and using `__class_getitem__` for DataPipe. The
# typing system will gain benefit of performance and resolving metaclass
# conflicts as elaborated in https://www.python.org/dev/peps/pep-0560/


class _DataPipeType:
    r"""Save type annotation in `param`."""

    def __init__(self, param):
        self.param = param

    def __repr__(self):
        return _type_repr(self.param)

    def __eq__(self, other):
        if isinstance(other, _DataPipeType):
            return self.param == other.param
        return NotImplemented

    def __hash__(self):
        return hash(self.param)

    def issubtype(self, other):
        if isinstance(other.param, _GenericAlias):
            if getattr(other.param, "__origin__", None) is Generic:
                return True
        if isinstance(other, _DataPipeType):
            return issubtype(self.param, other.param)
        if isinstance(other, type):
            return issubtype(self.param, other)
        raise TypeError(f"Expected '_DataPipeType' or 'type', but found {type(other)}")

    def issubtype_of_instance(self, other):
        return issubinstance(other, self.param)


# Default type for DataPipe without annotation
_T_co = TypeVar("_T_co", covariant=True)
_DEFAULT_TYPE = _DataPipeType(Generic[_T_co])


class _DataPipeMeta(GenericMeta):
    r"""
    Metaclass for `DataPipe`.

    Add `type` attribute and `__init_subclass__` based on the type, and validate the return hint of `__iter__`.

    Note that there is subclass `_IterDataPipeMeta` specifically for `IterDataPipe`.
    """

    type: _DataPipeType

    def __new__(cls, name, bases, namespace, **kwargs):
        return super().__new__(cls, name, bases, namespace, **kwargs)  # type: ignore[call-overload]

        # TODO: the statements below are not reachable by design as there is a bug and typing is low priority for now.
        cls.__origin__ = None
        if "type" in namespace:
            return super().__new__(cls, name, bases, namespace, **kwargs)  # type: ignore[call-overload]

        namespace["__type_class__"] = False
        # For plain derived class without annotation
        for base in bases:
            if isinstance(base, _DataPipeMeta):
                return super().__new__(cls, name, bases, namespace, **kwargs)  # type: ignore[call-overload]

        namespace.update(
            {"type": _DEFAULT_TYPE, "__init_subclass__": _dp_init_subclass}
        )
        return super().__new__(cls, name, bases, namespace, **kwargs)  # type: ignore[call-overload]

    def __init__(self, name, bases, namespace, **kwargs):
        super().__init__(name, bases, namespace, **kwargs)  # type: ignore[call-overload]

    # TODO: Fix isinstance bug
    @_tp_cache
    def _getitem_(self, params):
        if params is None:
            raise TypeError(f"{self.__name__}[t]: t can not be None")
        if isinstance(params, str):
            params = ForwardRef(params)
        if not isinstance(params, tuple):
            params = (params,)

        msg = f"{self.__name__}[t]: t must be a type"
        params = tuple(_type_check(p, msg) for p in params)

        if isinstance(self.type.param, _GenericAlias):
            orig = getattr(self.type.param, "__origin__", None)
            if isinstance(orig, type) and orig is not Generic:
                p = self.type.param[params]  # type: ignore[index]
                t = _DataPipeType(p)
                l = len(str(self.type)) + 2
                name = self.__name__[:-l]
                name = name + "[" + str(t) + "]"
                bases = (self,) + self.__bases__
                return self.__class__(
                    name,
                    bases,
                    {
                        "__init_subclass__": _dp_init_subclass,
                        "type": t,
                        "__type_class__": True,
                    },
                )

        if len(params) > 1:
            raise TypeError(
                f"Too many parameters for {self} actual {len(params)}, expected 1"
            )

        t = _DataPipeType(params[0])

        if not t.issubtype(self.type):
            raise TypeError(
                f"Can not subclass a DataPipe[{t}] from DataPipe[{self.type}]"
            )

        # Types are equal, fast path for inheritance
        if self.type == t:
            return self

        name = self.__name__ + "[" + str(t) + "]"
        bases = (self,) + self.__bases__

        return self.__class__(
            name,
            bases,
            {"__init_subclass__": _dp_init_subclass, "__type_class__": True, "type": t},
        )

    # TODO: Fix isinstance bug
    def _eq_(self, other):
        if not isinstance(other, _DataPipeMeta):
            return NotImplemented
        if self.__origin__ is None or other.__origin__ is None:  # type: ignore[has-type]
            return self is other
        return (
            self.__origin__ == other.__origin__  # type: ignore[has-type]
            and self.type == other.type
        )

    # TODO: Fix isinstance bug
    def _hash_(self):
        return hash((self.__name__, self.type))


class _IterDataPipeMeta(_DataPipeMeta):
    r"""
    Metaclass for `IterDataPipe` and inherits from `_DataPipeMeta`.

    Add various functions for behaviors specific to `IterDataPipe`.
    """

    def __new__(cls, name, bases, namespace, **kwargs):
        if "reset" in namespace:
            reset_func = namespace["reset"]

            @functools.wraps(reset_func)
            def conditional_reset(*args, **kwargs):
                r"""
                Only execute DataPipe's `reset()` method if `_SnapshotState` is `Iterating` or `NotStarted`.

                This allows recently restored DataPipe to preserve its restored state during the initial `__iter__` call.
                """
                datapipe = args[0]
                if datapipe._snapshot_state in (
                    _SnapshotState.Iterating,
                    _SnapshotState.NotStarted,
                ):
                    # Reset `NotStarted` is necessary because the `source_datapipe` of a DataPipe might have
                    # already begun iterating.
                    datapipe._number_of_samples_yielded = 0
                    datapipe._fast_forward_iterator = None
                    reset_func(*args, **kwargs)
                datapipe._snapshot_state = _SnapshotState.Iterating

            namespace["reset"] = conditional_reset

        if "__iter__" in namespace:
            hook_iterator(namespace)
        return super().__new__(cls, name, bases, namespace, **kwargs)  # type: ignore[call-overload]


def _dp_init_subclass(sub_cls, *args, **kwargs):
    # Add function for datapipe instance to reinforce the type
    sub_cls.reinforce_type = reinforce_type

    # TODO:
    # - add global switch for type checking at compile-time

    # Ignore internal type class
    if getattr(sub_cls, "__type_class__", False):
        return

    # Check if the string type is valid
    if isinstance(sub_cls.type.param, ForwardRef):
        base_globals = sys.modules[sub_cls.__module__].__dict__
        try:
            param = _eval_type(sub_cls.type.param, base_globals, locals())
            sub_cls.type.param = param
        except TypeError as e:
            raise TypeError(
                f"{sub_cls.type.param.__forward_arg__} is not supported by Python typing"
            ) from e

    if "__iter__" in sub_cls.__dict__:
        iter_fn = sub_cls.__dict__["__iter__"]
        hints = get_type_hints(iter_fn)
        if "return" in hints:
            return_hint = hints["return"]
            # Plain Return Hint for Python 3.6
            if return_hint == Iterator:
                return
            if not (
                hasattr(return_hint, "__origin__")
                and (
                    return_hint.__origin__ == Iterator
                    or return_hint.__origin__ == collections.abc.Iterator
                )
            ):
                raise TypeError(
                    "Expected 'Iterator' as the return annotation for `__iter__` of {}"
                    ", but found {}".format(
                        sub_cls.__name__, _type_repr(hints["return"])
                    )
                )
            data_type = return_hint.__args__[0]
            if not issubtype(data_type, sub_cls.type.param):
                raise TypeError(
                    f"Expected return type of '__iter__' as a subtype of {sub_cls.type},"
                    f" but found {_type_repr(data_type)} for {sub_cls.__name__}"
                )


def reinforce_type(self, expected_type):
    r"""
    Reinforce the type for DataPipe instance.

    And the 'expected_type' is required to be a subtype of the original type
    hint to restrict the type requirement of DataPipe instance.
    """
    if isinstance(expected_type, tuple):
        expected_type = Tuple[expected_type]
    _type_check(expected_type, msg="'expected_type' must be a type")

    if not issubtype(expected_type, self.type.param):
        raise TypeError(
            f"Expected 'expected_type' as subtype of {self.type}, but found {_type_repr(expected_type)}"
        )

    self.type = _DataPipeType(expected_type)
    return self
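
A rough illustration of how the subtype helpers above behave; this is a sketch for intuition only and assumes the private `_typing` module is importable from the torch build in this environment:

    # Sketch: exercising issubtype / issubinstance defined above.
    from typing import List, Tuple, Union

    from torch.utils.data.datapipes._typing import issubinstance, issubtype

    print(issubtype(int, Union[int, str]))         # True: int matches one Union member
    print(issubtype(List[int], List))              # True: untyped generic accepts any args
    print(issubtype(Tuple[int, str], Tuple[int]))  # False: argument counts differ
    print(issubinstance([1, 2, 3], List[int]))     # True: every element passes the check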
.venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__init__.py
ADDED
@@ -0,0 +1,11 @@
from torch.utils.data.datapipes.dataframe.dataframes import (
    CaptureDataFrame,
    DFIterDataPipe,
)
from torch.utils.data.datapipes.dataframe.datapipes import DataFramesAsTuplesPipe


__all__ = ["CaptureDataFrame", "DFIterDataPipe", "DataFramesAsTuplesPipe"]

# Please keep this list sorted
assert __all__ == sorted(__all__)
.venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (476 Bytes).
.venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/dataframe_wrapper.cpython-39.pyc
ADDED
Binary file (3.77 kB).
.venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/dataframes.cpython-39.pyc
ADDED
Binary file (15.9 kB).
.venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/datapipes.cpython-39.pyc
ADDED
Binary file (4.62 kB).
.venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/structures.cpython-39.pyc
ADDED
Binary file (1.05 kB).
.venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py
ADDED
@@ -0,0 +1,128 @@
# mypy: allow-untyped-defs
from typing import Any, Optional


_pandas: Any = None
_WITH_PANDAS: Optional[bool] = None


def _try_import_pandas() -> bool:
    try:
        import pandas  # type: ignore[import]

        global _pandas
        _pandas = pandas
        return True
    except ImportError:
        return False


# pandas used only for prototyping, will be shortly replaced with TorchArrow
def _with_pandas() -> bool:
    global _WITH_PANDAS
    if _WITH_PANDAS is None:
        _WITH_PANDAS = _try_import_pandas()
    return _WITH_PANDAS


class PandasWrapper:
    @classmethod
    def create_dataframe(cls, data, columns):
        if not _with_pandas():
            raise RuntimeError("DataFrames prototype requires pandas to function")
        return _pandas.DataFrame(data, columns=columns)  # type: ignore[union-attr]

    @classmethod
    def is_dataframe(cls, data):
        if not _with_pandas():
            return False
        return isinstance(data, _pandas.core.frame.DataFrame)  # type: ignore[union-attr]

    @classmethod
    def is_column(cls, data):
        if not _with_pandas():
            return False
        return isinstance(data, _pandas.core.series.Series)  # type: ignore[union-attr]

    @classmethod
    def iterate(cls, data):
        if not _with_pandas():
            raise RuntimeError("DataFrames prototype requires pandas to function")
        yield from data.itertuples(index=False)

    @classmethod
    def concat(cls, buffer):
        if not _with_pandas():
            raise RuntimeError("DataFrames prototype requires pandas to function")
        return _pandas.concat(buffer)  # type: ignore[union-attr]

    @classmethod
    def get_item(cls, data, idx):
        if not _with_pandas():
            raise RuntimeError("DataFrames prototype requires pandas to function")
        return data[idx : idx + 1]

    @classmethod
    def get_len(cls, df):
        if not _with_pandas():
            raise RuntimeError("DataFrames prototype requires pandas to function")
        return len(df.index)

    @classmethod
    def get_columns(cls, df):
        if not _with_pandas():
            raise RuntimeError("DataFrames prototype requires pandas to function")
        return list(df.columns.values.tolist())


# When you build own implementation just override it with dataframe_wrapper.set_df_wrapper(new_wrapper_class)
default_wrapper = PandasWrapper


def get_df_wrapper():
    return default_wrapper


def set_df_wrapper(wrapper):
    global default_wrapper
    default_wrapper = wrapper


def create_dataframe(data, columns=None):
    wrapper = get_df_wrapper()
    return wrapper.create_dataframe(data, columns)


def is_dataframe(data):
    wrapper = get_df_wrapper()
    return wrapper.is_dataframe(data)


def get_columns(data):
    wrapper = get_df_wrapper()
    return wrapper.get_columns(data)


def is_column(data):
    wrapper = get_df_wrapper()
    return wrapper.is_column(data)


def concat(buffer):
    wrapper = get_df_wrapper()
    return wrapper.concat(buffer)


def iterate(data):
    wrapper = get_df_wrapper()
    return wrapper.iterate(data)


def get_item(data, idx):
    wrapper = get_df_wrapper()
    return wrapper.get_item(data, idx)


def get_len(df):
    wrapper = get_df_wrapper()
    return wrapper.get_len(df)
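
The module-level comment invites swapping the backend via `set_df_wrapper`. A hedged sketch of what such an override could look like; the `ListOfDictsWrapper` name and its behavior are illustrative only and not part of the library:

    # Hypothetical backend that treats a list of dicts as a "dataframe".
    # Only the hooks exercised below are implemented.
    from torch.utils.data.datapipes.dataframe import dataframe_wrapper


    class ListOfDictsWrapper:
        @classmethod
        def create_dataframe(cls, data, columns):
            return [dict(zip(columns, row)) for row in data]

        @classmethod
        def is_dataframe(cls, data):
            return isinstance(data, list) and all(isinstance(r, dict) for r in data)

        @classmethod
        def get_len(cls, df):
            return len(df)

        @classmethod
        def get_columns(cls, df):
            return list(df[0].keys()) if df else []


    dataframe_wrapper.set_df_wrapper(ListOfDictsWrapper)
    df = dataframe_wrapper.create_dataframe([(1, 2)], columns=["x", "y"])
    print(dataframe_wrapper.get_columns(df))  # ['x', 'y']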
.venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/dataframes.py
ADDED
@@ -0,0 +1,457 @@
# mypy: allow-untyped-defs
from typing import Any, Dict, List, Optional

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.dataframe.structures import DataChunkDF
from torch.utils.data.datapipes.datapipe import DFIterDataPipe, IterDataPipe


# TODO(VitalyFedyunin): Add error when two different traces get combined

__all__ = [
    "Capture",
    "CaptureA",
    "CaptureAdd",
    "CaptureCall",
    "CaptureControl",
    "CaptureDataFrame",
    "CaptureDataFrameWithDataPipeOps",
    "CaptureF",
    "CaptureGetAttr",
    "CaptureGetItem",
    "CaptureInitial",
    "CaptureLikeMock",
    "CaptureMul",
    "CaptureSetItem",
    "CaptureSub",
    "CaptureVariable",
    "CaptureVariableAssign",
    "DataFrameTracer",
    "DataFrameTracedOps",
    "disable_capture",
    "get_val",
]


def disable_capture():
    CaptureControl.disabled = True


class CaptureControl:
    disabled = False


class DataFrameTracedOps(DFIterDataPipe):
    def __init__(self, source_datapipe, output_var):
        self.source_datapipe = source_datapipe
        self.output_var = output_var

    def __iter__(self):
        for item in self.source_datapipe:
            yield self.output_var.apply_ops(item)


# TODO(VitalyFedyunin): Extract this list from the DFIterDataPipe registred functions
DATAPIPES_OPS = [
    "_dataframes_as_tuples",
    "groupby",
    "_dataframes_filter",
    "map",
    "to_datapipe",
    "shuffle",
    "concat",
    "batch",
    "_dataframes_per_row",
    "_dataframes_concat",
    "_dataframes_shuffle",
]

UNIMPLEMENTED_ATTR = ["__deepcopy__", "__setstate__", "is_shardable", "apply_sharding"]


class Capture:
    # TODO: All operations are shared across entire InitialCapture, need to figure out what if we join two captures

    def __init__(self, schema_df=None):
        self.ctx = {"operations": [], "variables": [], "schema_df": schema_df}

    def __str__(self):
        return self._ops_str()

    def _ops_str(self):
        res = ""
        for op in self.ctx["operations"]:
            if len(res) > 0:
                res += "\n"
            res += str(op)
        return res

    def __getstate__(self):
        # TODO(VitalyFedyunin): Currently can't pickle (why?)
        self.ctx["schema_df"] = None
        for var in self.ctx["variables"]:
            var.calculated_value = None
        state = {}
        for item in self.__dict__:
            state[item] = getattr(self, item)
        return state

    def __setstate__(self, state):
        for k, v in state.items():
            setattr(self, k, v)

    def __getattr__(self, attrname):
        if attrname == "kwarg" or attrname == "kwargs":
            raise RuntimeError("no kwargs!")
        if attrname in ["__deepcopy__"]:
            raise AttributeError
        result = CaptureGetAttr(self, attrname, ctx=self.ctx)
        return result

    def __getitem__(self, key):
        return CaptureGetItem(self, key, ctx=self.ctx)

    def __setitem__(self, key, value):
        self.ctx["operations"].append(CaptureSetItem(self, key, value, ctx=self.ctx))

    def __add__(self, add_val):
        res = CaptureAdd(self, add_val, ctx=self.ctx)
        var = CaptureVariable(res, ctx=self.ctx)
        self.ctx["operations"].append(
            CaptureVariableAssign(variable=var, value=res, ctx=self.ctx)
        )
        return var

    def __sub__(self, add_val):
        res = CaptureSub(self, add_val, ctx=self.ctx)
        var = CaptureVariable(res, ctx=self.ctx)
        self.ctx["operations"].append(
            CaptureVariableAssign(variable=var, value=res, ctx=self.ctx)
        )
        return var

    def __mul__(self, add_val):
        res = CaptureMul(self, add_val, ctx=self.ctx)
        var = CaptureVariable(res, ctx=self.ctx)
        t = CaptureVariableAssign(variable=var, value=res, ctx=self.ctx)
        self.ctx["operations"].append(t)
        return var

    def _is_context_empty(self):
        return len(self.ctx["operations"]) == 0 and len(self.ctx["variables"]) == 0

    def apply_ops_2(self, dataframe):
        # TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer)
        self.ctx["variables"][0].calculated_value = dataframe
        for op in self.ctx["operations"]:
            op.execute()

    @property
    def columns(self):
        self.apply_ops_2(self.ctx["schema_df"])
        value = self.execute()
        return value.columns

    # TODO(VitalyFedyunin): Add tests
    # TODO(VitalyFedyunin): Need to join context if one of them are empty because we used capture

    def __call__(self, *args, **kwargs):
        # TODO: Check if args or kwargs have more than one different context
        if self._is_context_empty():
            # TODO: Allow CaptureA to take context from mock
            for arg in args:
                if isinstance(arg, Capture) and not arg._is_context_empty():
                    self.ctx = arg.ctx
                    break
            if self._is_context_empty():
                for k, v in kwargs.items():
                    if isinstance(k, Capture) and not k._is_context_empty():
                        self.ctx = k.ctx
                        break
                    if isinstance(v, Capture) and not v._is_context_empty():
                        self.ctx = v.ctx
                        break

        res = CaptureCall(self, ctx=self.ctx, args=args, kwargs=kwargs)
        var = CaptureVariable(None, ctx=self.ctx)
        t = CaptureVariableAssign(ctx=self.ctx, variable=var, value=res)
        self.ctx["operations"].append(t)
        return var


class CaptureF(Capture):
    def __init__(self, ctx=None, **kwargs):
        if ctx is None:
            self.ctx = {"operations": [], "variables": []}
        else:
            self.ctx = ctx
        self.kwargs = kwargs


class CaptureA(CaptureF):
    def __str__(self):
        return f"{self.kwargs['name']}"

    def execute(self):
        value = self.kwargs["real_attribute"]
        return value


class CaptureLikeMock:
    def __init__(self, name):
        import unittest.mock as mock

        # TODO(VitalyFedyunin): Do not use provate function here, copy own implementation instead.
        get_target, attribute = mock._get_target(name)  # type: ignore[attr-defined]
        self.get_target = get_target
        self.attribute = attribute
        self.name = name

    def __enter__(self):
        self.save = getattr(self.get_target(), self.attribute)
        capt = CaptureA(name=self.name, real_attribute=self.save)
        setattr(self.get_target(), self.attribute, capt)

    def __exit__(self, *exc_info):
        setattr(self.get_target(), self.attribute, self.save)


class CaptureCall(Capture):
    def __init__(self, callable, ctx=None, **kwargs):
        if ctx is None:
            self.ctx = {"operations": [], "variables": []}
        else:
            self.ctx = ctx
        self.kwargs = kwargs
        self.callable = callable

    def __str__(self):
        return "{callable}({args},{kwargs})".format(
            callable=self.callable, **self.kwargs
        )

    def execute(self):
        # TODO: VitalyFedyunin execute kwargs and maybe nested structures
        executed_args = []
        for arg in self.kwargs["args"]:
            if isinstance(arg, Capture):
                executed_args.append(arg.execute())
            else:
                executed_args.append(arg)
        left = get_val(self.callable)
        return left(*executed_args, **self.kwargs["kwargs"])


class CaptureVariableAssign(CaptureF):
    def __str__(self):
        variable = self.kwargs["variable"]
        value = self.kwargs["value"]
        return f"{variable} = {value}"

    def execute(self):
        self.kwargs["variable"].calculated_value = self.kwargs["value"].execute()


class CaptureVariable(Capture):
    # TODO(VitalyFedyunin): This should be atomic and thread safe
    names_idx = 0

    def __init__(self, value, ctx):
        if CaptureControl.disabled:
            raise RuntimeError("Attempting to create capture variable with capture off")
        self.ctx = ctx
        self.value = value
        self.name = f"var_{CaptureVariable.names_idx}"
        CaptureVariable.names_idx += 1
        self.ctx["variables"].append(self)

    def __str__(self):
        return self.name

    def execute(self):
        return self.calculated_value

    def apply_ops(self, dataframe):
        # TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer)
        self.ctx["variables"][0].calculated_value = dataframe
        for op in self.ctx["operations"]:
            op.execute()
        return self.calculated_value


class CaptureGetItem(Capture):
    def __init__(self, left, key, ctx):
        self.ctx = ctx
        self.left = left
        self.key = key

    def __str__(self):
        return f"{self.left}[{get_val(self.key)}]"

    def execute(self):
        left = self.left.execute()
        return left[self.key]


class CaptureSetItem(Capture):
    def __init__(self, left, key, value, ctx):
        self.ctx = ctx
        self.left = left
        self.key = key
        self.value = value

    def __str__(self):
        return f"{self.left}[{get_val(self.key)}] = {self.value}"

    def execute(self):
        left = self.left.execute()
        value = self.value.execute()
        left[self.key] = value


class CaptureAdd(Capture):
    def __init__(self, left, right, ctx):
        self.ctx = ctx
        self.left = left
        self.right = right

    def __str__(self):
        return f"{self.left} + {self.right}"

    def execute(self):
        return get_val(self.left) + get_val(self.right)


class CaptureMul(Capture):
    def __init__(self, left, right, ctx):
        self.ctx = ctx
        self.left = left
        self.right = right

    def __str__(self):
        return f"{self.left} * {self.right}"

    def execute(self):
        return get_val(self.left) * get_val(self.right)


class CaptureSub(Capture):
    def __init__(self, left, right, ctx):
        self.ctx = ctx
        self.left = left
        self.right = right

    def __str__(self):
        return f"{self.left} - {self.right}"

    def execute(self):
        return get_val(self.left) - get_val(self.right)


class CaptureGetAttr(Capture):
    def __init__(self, src, name, ctx):
        self.ctx = ctx
        self.src = src
        self.name = name

    def __str__(self):
        return f"{self.src}.{self.name}"

    def execute(self):
        val = get_val(self.src)
        return getattr(val, self.name)


def get_val(capture):
    if isinstance(capture, Capture):
        return capture.execute()
    elif isinstance(capture, str):
        return f'"{capture}"'
    else:
        return capture


class CaptureInitial(CaptureVariable):
    def __init__(self, schema_df=None):
        new_ctx: Dict[str, List[Any]] = {
            "operations": [],
            "variables": [],
            "schema_df": schema_df,
        }
        super().__init__(None, new_ctx)
        self.name = f"input_{self.name}"


class CaptureDataFrame(CaptureInitial):
    pass


class CaptureDataFrameWithDataPipeOps(CaptureDataFrame):
    def as_datapipe(self):
        return DataFrameTracedOps(self.ctx["variables"][0].source_datapipe, self)

    def raw_iterator(self):
        return self.as_datapipe().__iter__()

    def __iter__(self):
        return iter(self._dataframes_as_tuples())

    def batch(self, batch_size=10, drop_last: bool = False, wrapper_class=DataChunkDF):
        dp = self._dataframes_per_row()._dataframes_concat(batch_size)
        dp = dp.as_datapipe().batch(1, drop_last=drop_last, wrapper_class=wrapper_class)
        dp._dp_contains_dataframe = True
        return dp

    def groupby(
        self,
        group_key_fn,
        *,
        buffer_size=10000,
        group_size=None,
        guaranteed_group_size=None,
        drop_remaining=False,
    ):
        dp = self._dataframes_per_row()
        dp = dp.as_datapipe().groupby(
            group_key_fn,
            buffer_size=buffer_size,
            group_size=group_size,
            guaranteed_group_size=guaranteed_group_size,
            drop_remaining=drop_remaining,
        )
        return dp

    def shuffle(self, *args, **kwargs):
        return self._dataframes_shuffle(*args, **kwargs)

    def filter(self, *args, **kwargs):
        return self._dataframes_filter(*args, **kwargs)

    def collate(self, *args, **kwargs):
        raise RuntimeError("Can't collate unbatched DataFrames stream")

    def __getattr__(self, attrname):  # ?
        if attrname in UNIMPLEMENTED_ATTR:
            raise AttributeError("Attempting to get ", attrname)
        if attrname in DATAPIPES_OPS:
            return (self.as_datapipe()).__getattr__(attrname)
        return super().__getattr__(attrname)


@functional_datapipe("trace_as_dataframe")
class DataFrameTracer(CaptureDataFrameWithDataPipeOps, IterDataPipe):  # type: ignore[misc]
    source_datapipe: Optional[Any] = None

    # TODO(VitalyFedyunin): Must implement all special functions of datapipes

    def set_shuffle_settings(self, *args, **kwargs):
        pass

    def is_shardable(self):
        return False

    def __init__(self, source_datapipe, schema_df=None):
        self.source_datapipe = source_datapipe
        if schema_df is None:
            schema_df = next(iter(self.source_datapipe))
        super().__init__(schema_df=schema_df)
.venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/datapipes.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import random
|
| 3 |
+
|
| 4 |
+
from torch.utils.data.datapipes._decorator import functional_datapipe
|
| 5 |
+
from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper
|
| 6 |
+
from torch.utils.data.datapipes.datapipe import DFIterDataPipe, IterDataPipe
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
__all__ = [
|
| 10 |
+
"ConcatDataFramesPipe",
|
| 11 |
+
"DataFramesAsTuplesPipe",
|
| 12 |
+
"ExampleAggregateAsDataFrames",
|
| 13 |
+
"FilterDataFramesPipe",
|
| 14 |
+
"PerRowDataFramesPipe",
|
| 15 |
+
"ShuffleDataFramesPipe",
|
| 16 |
+
]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@functional_datapipe("_dataframes_as_tuples")
|
| 20 |
+
class DataFramesAsTuplesPipe(IterDataPipe):
|
| 21 |
+
def __init__(self, source_datapipe):
|
| 22 |
+
self.source_datapipe = source_datapipe
|
| 23 |
+
|
| 24 |
+
def __iter__(self):
|
| 25 |
+
for df in self.source_datapipe:
|
| 26 |
+
# for record in df.to_records(index=False):
|
| 27 |
+
yield from df_wrapper.iterate(df)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@functional_datapipe("_dataframes_per_row", enable_df_api_tracing=True)
|
| 31 |
+
class PerRowDataFramesPipe(DFIterDataPipe):
|
| 32 |
+
def __init__(self, source_datapipe):
|
| 33 |
+
self.source_datapipe = source_datapipe
|
| 34 |
+
|
| 35 |
+
def __iter__(self):
|
| 36 |
+
for df in self.source_datapipe:
|
| 37 |
+
# TODO(VitalyFedyunin): Replacing with TorchArrow only API, as we are dropping pandas as followup
|
| 38 |
+
for i in range(len(df)):
|
| 39 |
+
yield df[i : i + 1]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@functional_datapipe("_dataframes_concat", enable_df_api_tracing=True)
|
| 43 |
+
class ConcatDataFramesPipe(DFIterDataPipe):
|
| 44 |
+
def __init__(self, source_datapipe, batch=3):
|
| 45 |
+
self.source_datapipe = source_datapipe
|
| 46 |
+
self.n_batch = batch
|
| 47 |
+
|
| 48 |
+
def __iter__(self):
|
| 49 |
+
buffer = []
|
| 50 |
+
for df in self.source_datapipe:
|
| 51 |
+
buffer.append(df)
|
| 52 |
+
if len(buffer) == self.n_batch:
|
| 53 |
+
yield df_wrapper.concat(buffer)
|
| 54 |
+
buffer = []
|
| 55 |
+
if len(buffer):
|
| 56 |
+
yield df_wrapper.concat(buffer)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@functional_datapipe("_dataframes_shuffle", enable_df_api_tracing=True)
|
| 60 |
+
class ShuffleDataFramesPipe(DFIterDataPipe):
|
| 61 |
+
def __init__(self, source_datapipe):
|
| 62 |
+
self.source_datapipe = source_datapipe
|
| 63 |
+
|
| 64 |
+
def __iter__(self):
|
| 65 |
+
size = None
|
| 66 |
+
all_buffer = []
|
| 67 |
+
for df in self.source_datapipe:
|
| 68 |
+
if size is None:
|
| 69 |
+
size = df_wrapper.get_len(df)
|
| 70 |
+
for i in range(df_wrapper.get_len(df)):
|
| 71 |
+
all_buffer.append(df_wrapper.get_item(df, i))
|
| 72 |
+
random.shuffle(all_buffer)
|
| 73 |
+
buffer = []
|
| 74 |
+
for df in all_buffer:
|
| 75 |
+
buffer.append(df)
|
| 76 |
+
if len(buffer) == size:
|
| 77 |
+
yield df_wrapper.concat(buffer)
|
| 78 |
+
buffer = []
|
| 79 |
+
if len(buffer):
|
| 80 |
+
yield df_wrapper.concat(buffer)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@functional_datapipe("_dataframes_filter", enable_df_api_tracing=True)
|
| 84 |
+
class FilterDataFramesPipe(DFIterDataPipe):
|
| 85 |
+
def __init__(self, source_datapipe, filter_fn):
|
| 86 |
+
self.source_datapipe = source_datapipe
|
| 87 |
+
self.filter_fn = filter_fn
|
| 88 |
+
|
| 89 |
+
def __iter__(self):
|
| 90 |
+
size = None
|
| 91 |
+
all_buffer = []
|
| 92 |
+
filter_res = []
|
| 93 |
+
for df in self.source_datapipe:
|
| 94 |
+
if size is None:
|
| 95 |
+
size = len(df.index)
|
| 96 |
+
for i in range(len(df.index)):
|
| 97 |
+
all_buffer.append(df[i : i + 1])
|
| 98 |
+
filter_res.append(self.filter_fn(df.iloc[i]))
|
| 99 |
+
|
| 100 |
+
buffer = []
|
| 101 |
+
for df, res in zip(all_buffer, filter_res):
|
| 102 |
+
if res:
|
| 103 |
+
buffer.append(df)
|
| 104 |
+
if len(buffer) == size:
|
| 105 |
+
yield df_wrapper.concat(buffer)
|
| 106 |
+
buffer = []
|
| 107 |
+
if len(buffer):
|
| 108 |
+
yield df_wrapper.concat(buffer)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
@functional_datapipe("_to_dataframes_pipe", enable_df_api_tracing=True)
|
| 112 |
+
class ExampleAggregateAsDataFrames(DFIterDataPipe):
|
| 113 |
+
def __init__(self, source_datapipe, dataframe_size=10, columns=None):
|
| 114 |
+
self.source_datapipe = source_datapipe
|
| 115 |
+
self.columns = columns
|
| 116 |
+
self.dataframe_size = dataframe_size
|
| 117 |
+
|
| 118 |
+
def _as_list(self, item):
|
| 119 |
+
try:
|
| 120 |
+
return list(item)
|
| 121 |
+
except (
|
| 122 |
+
Exception
|
| 123 |
+
): # TODO(VitalyFedyunin): Replace with better iterable exception
|
| 124 |
+
return [item]
|
| 125 |
+
|
| 126 |
+
def __iter__(self):
|
| 127 |
+
aggregate = []
|
| 128 |
+
for item in self.source_datapipe:
|
| 129 |
+
aggregate.append(self._as_list(item))
|
| 130 |
+
if len(aggregate) == self.dataframe_size:
|
| 131 |
+
yield df_wrapper.create_dataframe(aggregate, columns=self.columns)
|
| 132 |
+
aggregate = []
|
| 133 |
+
if len(aggregate) > 0:
|
| 134 |
+
yield df_wrapper.create_dataframe(aggregate, columns=self.columns)
|
.venv/Lib/site-packages/torch/utils/data/datapipes/dataframe/structures.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper
|
| 3 |
+
from torch.utils.data.datapipes.datapipe import DataChunk
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
__all__ = ["DataChunkDF"]
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class DataChunkDF(DataChunk):
|
| 10 |
+
"""DataChunkDF iterating over individual items inside of DataFrame containers, to access DataFrames user `raw_iterator`."""
|
| 11 |
+
|
| 12 |
+
def __iter__(self):
|
| 13 |
+
for df in self.items:
|
| 14 |
+
yield from df_wrapper.iterate(df)
|
| 15 |
+
|
| 16 |
+
def __len__(self):
|
| 17 |
+
total_len = 0
|
| 18 |
+
for df in self.items:
|
| 19 |
+
total_len += df_wrapper.get_len(df)
|
| 20 |
+
return total_len
|
.venv/Lib/site-packages/torch/utils/data/datapipes/datapipe.py
ADDED
|
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
import pickle
|
| 3 |
+
from typing import Callable, Dict, Iterable, Iterator, List, Optional, TypeVar
|
| 4 |
+
|
| 5 |
+
from torch.utils._import_utils import import_dill
|
| 6 |
+
from torch.utils.data.datapipes._hook_iterator import _SnapshotState
|
| 7 |
+
from torch.utils.data.datapipes._typing import _DataPipeMeta, _IterDataPipeMeta
|
| 8 |
+
from torch.utils.data.datapipes.utils.common import (
|
| 9 |
+
_deprecation_warning,
|
| 10 |
+
_iter_deprecated_functional_names,
|
| 11 |
+
_map_deprecated_functional_names,
|
| 12 |
+
)
|
| 13 |
+
from torch.utils.data.dataset import Dataset, IterableDataset
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
dill = import_dill()
|
| 17 |
+
HAS_DILL = dill is not None
|
| 18 |
+
|
| 19 |
+
__all__ = [
|
| 20 |
+
"DataChunk",
|
| 21 |
+
"DFIterDataPipe",
|
| 22 |
+
"IterDataPipe",
|
| 23 |
+
"MapDataPipe",
|
| 24 |
+
]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
_T = TypeVar("_T")
|
| 28 |
+
_T_co = TypeVar("_T_co", covariant=True)
|
| 29 |
+
|
| 30 |
+
UNTRACABLE_DATAFRAME_PIPES = [
|
| 31 |
+
"batch", # As it returns DataChunks
|
| 32 |
+
"groupby", # As it returns DataChunks
|
| 33 |
+
"_dataframes_as_tuples", # As it unpacks DF
|
| 34 |
+
"trace_as_dataframe", # As it used to mark DF for tracing
|
| 35 |
+
]
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class DataChunk(List[_T]):
|
| 39 |
+
def __init__(self, items: Iterable[_T]) -> None:
|
| 40 |
+
items = list(items)
|
| 41 |
+
super().__init__(items)
|
| 42 |
+
self.items = items
|
| 43 |
+
|
| 44 |
+
def as_str(self, indent: str = "") -> str:
|
| 45 |
+
return indent + "[" + ", ".join(str(i) for i in iter(self)) + "]"
|
| 46 |
+
|
| 47 |
+
def __iter__(self) -> Iterator[_T]:
|
| 48 |
+
yield from super().__iter__()
|
| 49 |
+
|
| 50 |
+
def raw_iterator(self) -> Iterator[_T]:
|
| 51 |
+
yield from self.items
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta):
|
| 55 |
+
r"""
|
| 56 |
+
Iterable-style DataPipe.
|
| 57 |
+
|
| 58 |
+
All DataPipes that represent an iterable of data samples should subclass this.
|
| 59 |
+
This style of DataPipes is particularly useful when data come from a stream, or
|
| 60 |
+
when the number of samples is too large to fit them all in memory. ``IterDataPipe`` is lazily initialized and its
|
| 61 |
+
elements are computed only when ``next()`` is called on the iterator of an ``IterDataPipe``.
|
| 62 |
+
|
| 63 |
+
All subclasses should overwrite :meth:`__iter__`, which would return an
|
| 64 |
+
iterator of samples in this DataPipe. Calling ``__iter__`` of an ``IterDataPipe`` automatically invokes its
|
| 65 |
+
method ``reset()``, which by default performs no operation. When writing a custom ``IterDataPipe``, users should
|
| 66 |
+
override ``reset()`` if necessary. The common usages include resetting buffers, pointers,
|
| 67 |
+
and various state variables within the custom ``IterDataPipe``.
|
| 68 |
+
|
| 69 |
+
Note:
|
| 70 |
+
Only `one` iterator can be valid for each ``IterDataPipe`` at a time,
|
| 71 |
+
and the creation a second iterator will invalidate the first one. This constraint is necessary because
|
| 72 |
+
some ``IterDataPipe`` have internal buffers, whose states can become invalid if there are multiple iterators.
|
| 73 |
+
The code example below presents details on how this constraint looks in practice.
|
| 74 |
+
If you have any feedback related to this constraint, please see `GitHub IterDataPipe Single Iterator Issue`_.
|
| 75 |
+
|
| 76 |
+
These DataPipes can be invoked in two ways, using the class constructor or applying their
|
| 77 |
+
functional form onto an existing ``IterDataPipe`` (recommended, available to most but not all DataPipes).
|
| 78 |
+
You can chain multiple `IterDataPipe` together to form a pipeline that will perform multiple
|
| 79 |
+
operations in succession.
|
| 80 |
+
|
| 81 |
+
.. _GitHub IterDataPipe Single Iterator Issue:
|
| 82 |
+
https://github.com/pytorch/data/issues/45
|
| 83 |
+
|
| 84 |
+
Note:
|
| 85 |
+
When a subclass is used with :class:`~torch.utils.data.DataLoader`, each
|
| 86 |
+
item in the DataPipe will be yielded from the :class:`~torch.utils.data.DataLoader`
|
| 87 |
+
iterator. When :attr:`num_workers > 0`, each worker process will have a
|
| 88 |
+
different copy of the DataPipe object, so it is often desired to configure
|
| 89 |
+
each copy independently to avoid having duplicate data returned from the
|
| 90 |
+
workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker
|
| 91 |
+
process, returns information about the worker. It can be used in either the
|
| 92 |
+
dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's
|
| 93 |
+
:attr:`worker_init_fn` option to modify each copy's behavior.
|
| 94 |
+
|
| 95 |
+
Examples:
|
| 96 |
+
General Usage:
|
| 97 |
+
>>> # xdoctest: +SKIP
|
| 98 |
+
>>> from torchdata.datapipes.iter import IterableWrapper, Mapper
|
| 99 |
+
>>> dp = IterableWrapper(range(10))
|
| 100 |
+
>>> map_dp_1 = Mapper(dp, lambda x: x + 1) # Using class constructor
|
| 101 |
+
>>> map_dp_2 = dp.map(lambda x: x + 1) # Using functional form (recommended)
|
| 102 |
+
>>> list(map_dp_1)
|
| 103 |
+
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
| 104 |
+
>>> list(map_dp_2)
|
| 105 |
+
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
| 106 |
+
>>> filter_dp = map_dp_1.filter(lambda x: x % 2 == 0)
|
| 107 |
+
>>> list(filter_dp)
|
| 108 |
+
[2, 4, 6, 8, 10]
|
| 109 |
+
Single Iterator Constraint Example:
|
| 110 |
+
>>> from torchdata.datapipes.iter import IterableWrapper, Mapper
|
| 111 |
+
>>> source_dp = IterableWrapper(range(10))
|
| 112 |
+
>>> it1 = iter(source_dp)
|
| 113 |
+
>>> list(it1)
|
| 114 |
+
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
| 115 |
+
>>> it1 = iter(source_dp)
|
| 116 |
+
>>> it2 = iter(source_dp) # The creation of a new iterator invalidates `it1`
|
| 117 |
+
>>> next(it2)
|
| 118 |
+
0
|
| 119 |
+
>>> next(it1) # Further usage of `it1` will raise a `RunTimeError`
|
| 120 |
+
"""
|
| 121 |
+
|
| 122 |
+
functions: Dict[str, Callable] = {}
|
| 123 |
+
reduce_ex_hook: Optional[Callable] = None
|
| 124 |
+
getstate_hook: Optional[Callable] = None
|
| 125 |
+
str_hook: Optional[Callable] = None
|
| 126 |
+
repr_hook: Optional[Callable] = None
|
| 127 |
+
_valid_iterator_id: Optional[int] = None
|
| 128 |
+
_number_of_samples_yielded: int = 0
|
| 129 |
+
_snapshot_state: _SnapshotState = _SnapshotState.NotStarted
|
| 130 |
+
_fast_forward_iterator: Optional[Iterator] = None
|
| 131 |
+
|
| 132 |
+
def __iter__(self) -> Iterator[_T_co]:
|
| 133 |
+
return self
|
| 134 |
+
|
| 135 |
+
def __getattr__(self, attribute_name):
|
| 136 |
+
if attribute_name in IterDataPipe.functions:
|
| 137 |
+
if attribute_name in _iter_deprecated_functional_names:
|
| 138 |
+
kwargs = _iter_deprecated_functional_names[attribute_name]
|
| 139 |
+
_deprecation_warning(**kwargs)
|
| 140 |
+
f = IterDataPipe.functions[attribute_name]
|
| 141 |
+
function = functools.partial(f, self)
|
| 142 |
+
functools.update_wrapper(wrapper=function, wrapped=f, assigned=("__doc__",))
|
| 143 |
+
return function
|
| 144 |
+
else:
|
| 145 |
+
raise AttributeError(
|
| 146 |
+
f"'{self.__class__.__name__}' object has no attribute '{attribute_name}"
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
@classmethod
|
| 150 |
+
def register_function(cls, function_name, function):
|
| 151 |
+
cls.functions[function_name] = function
|
| 152 |
+
|
| 153 |
+
@classmethod
|
| 154 |
+
def register_datapipe_as_function(
|
| 155 |
+
cls, function_name, cls_to_register, enable_df_api_tracing=False
|
| 156 |
+
):
|
| 157 |
+
if function_name in cls.functions:
|
| 158 |
+
raise Exception( # noqa: TRY002
|
| 159 |
+
f"Unable to add DataPipe function name {function_name} as it is already taken"
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
def class_function(cls, enable_df_api_tracing, source_dp, *args, **kwargs):
|
| 163 |
+
result_pipe = cls(source_dp, *args, **kwargs)
|
| 164 |
+
if isinstance(result_pipe, IterDataPipe):
|
| 165 |
+
if enable_df_api_tracing or isinstance(source_dp, DFIterDataPipe):
|
| 166 |
+
if function_name not in UNTRACABLE_DATAFRAME_PIPES:
|
| 167 |
+
result_pipe = result_pipe.trace_as_dataframe()
|
| 168 |
+
|
| 169 |
+
return result_pipe
|
| 170 |
+
|
| 171 |
+
function = functools.partial(
|
| 172 |
+
class_function, cls_to_register, enable_df_api_tracing
|
| 173 |
+
)
|
| 174 |
+
functools.update_wrapper(
|
| 175 |
+
wrapper=function, wrapped=cls_to_register, assigned=("__doc__",)
|
| 176 |
+
)
|
| 177 |
+
cls.functions[function_name] = function
|
| 178 |
+
|
| 179 |
+
def __getstate__(self):
|
| 180 |
+
"""
|
| 181 |
+
Serialize `lambda` functions when `dill` is available.
|
| 182 |
+
|
| 183 |
+
If this doesn't cover your custom DataPipe's use case, consider writing custom methods for
|
| 184 |
+
`__getstate__` and `__setstate__`, or use `pickle.dumps` for serialization.
|
| 185 |
+
"""
|
| 186 |
+
state = self.__dict__
|
| 187 |
+
if IterDataPipe.getstate_hook is not None:
|
| 188 |
+
return IterDataPipe.getstate_hook(state)
|
| 189 |
+
return state
|
| 190 |
+
|
| 191 |
+
def __reduce_ex__(self, *args, **kwargs):
|
| 192 |
+
if IterDataPipe.reduce_ex_hook is not None:
|
| 193 |
+
try:
|
| 194 |
+
return IterDataPipe.reduce_ex_hook(self)
|
| 195 |
+
except NotImplementedError:
|
| 196 |
+
pass
|
| 197 |
+
return super().__reduce_ex__(*args, **kwargs)
|
| 198 |
+
|
| 199 |
+
@classmethod
|
| 200 |
+
def set_getstate_hook(cls, hook_fn):
|
| 201 |
+
if IterDataPipe.getstate_hook is not None and hook_fn is not None:
|
| 202 |
+
raise RuntimeError("Attempt to override existing getstate_hook")
|
| 203 |
+
IterDataPipe.getstate_hook = hook_fn
|
| 204 |
+
|
| 205 |
+
@classmethod
|
| 206 |
+
def set_reduce_ex_hook(cls, hook_fn):
|
| 207 |
+
if IterDataPipe.reduce_ex_hook is not None and hook_fn is not None:
|
| 208 |
+
raise RuntimeError("Attempt to override existing reduce_ex_hook")
|
| 209 |
+
IterDataPipe.reduce_ex_hook = hook_fn
|
| 210 |
+
|
| 211 |
+
def __repr__(self):
|
| 212 |
+
if self.repr_hook is not None:
|
| 213 |
+
return self.repr_hook(self)
|
| 214 |
+
# Instead of showing <torch. ... .MapperIterDataPipe object at 0x.....>, return the class name
|
| 215 |
+
return str(self.__class__.__qualname__)
|
| 216 |
+
|
| 217 |
+
def __str__(self):
|
| 218 |
+
if self.str_hook is not None:
|
| 219 |
+
return self.str_hook(self)
|
| 220 |
+
# Instead of showing <torch. ... .MapperIterDataPipe object at 0x.....>, return the class name
|
| 221 |
+
return str(self.__class__.__qualname__)
|
| 222 |
+
|
| 223 |
+
def __dir__(self):
|
| 224 |
+
# for auto-completion in a REPL (e.g. Jupyter notebook)
|
| 225 |
+
return list(super().__dir__()) + list(self.functions.keys())
|
| 226 |
+
|
| 227 |
+
def reset(self) -> None:
|
| 228 |
+
r"""
|
| 229 |
+
Reset the `IterDataPipe` to the initial state.
|
| 230 |
+
|
| 231 |
+
By default, no-op. For subclasses of `IterDataPipe`, depending on their functionalities,
|
| 232 |
+
they may want to override this method with implementations that
|
| 233 |
+
may clear the buffers and reset pointers of the DataPipe.
|
| 234 |
+
The `reset` method is always called when `__iter__` is called as part of `hook_iterator`.
|
| 235 |
+
"""
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
class DFIterDataPipe(IterDataPipe):
|
| 239 |
+
def _is_dfpipe(self):
|
| 240 |
+
return True
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
class MapDataPipe(Dataset[_T_co], metaclass=_DataPipeMeta):
|
| 244 |
+
r"""
|
| 245 |
+
Map-style DataPipe.
|
| 246 |
+
|
| 247 |
+
All datasets that represent a map from keys to data samples should subclass this.
|
| 248 |
+
Subclasses should overwrite :meth:`__getitem__`, supporting fetching a
|
| 249 |
+
data sample for a given, unique key. Subclasses can also optionally overwrite
|
| 250 |
+
:meth:`__len__`, which is expected to return the size of the dataset by many
|
| 251 |
+
:class:`~torch.utils.data.Sampler` implementations and the default options
|
| 252 |
+
of :class:`~torch.utils.data.DataLoader`.
|
| 253 |
+
|
| 254 |
+
These DataPipes can be invoked in two ways, using the class constructor or applying their
|
| 255 |
+
functional form onto an existing `MapDataPipe` (recommend, available to most but not all DataPipes).
|
| 256 |
+
|
| 257 |
+
Note:
|
| 258 |
+
:class:`~torch.utils.data.DataLoader` by default constructs an index
|
| 259 |
+
sampler that yields integral indices. To make it work with a map-style
|
| 260 |
+
DataPipe with non-integral indices/keys, a custom sampler must be provided.
|
| 261 |
+
|
| 262 |
+
Example:
|
| 263 |
+
>>> # xdoctest: +SKIP
|
| 264 |
+
>>> from torchdata.datapipes.map import SequenceWrapper, Mapper
|
| 265 |
+
>>> dp = SequenceWrapper(range(10))
|
| 266 |
+
>>> map_dp_1 = dp.map(lambda x: x + 1) # Using functional form (recommended)
|
| 267 |
+
>>> list(map_dp_1)
|
| 268 |
+
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
| 269 |
+
>>> map_dp_2 = Mapper(dp, lambda x: x + 1) # Using class constructor
|
| 270 |
+
>>> list(map_dp_2)
|
| 271 |
+
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
| 272 |
+
>>> batch_dp = map_dp_1.batch(batch_size=2)
|
| 273 |
+
>>> list(batch_dp)
|
| 274 |
+
[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
|
| 275 |
+
"""
|
| 276 |
+
|
| 277 |
+
functions: Dict[str, Callable] = {}
|
| 278 |
+
reduce_ex_hook: Optional[Callable] = None
|
| 279 |
+
getstate_hook: Optional[Callable] = None
|
| 280 |
+
str_hook: Optional[Callable] = None
|
| 281 |
+
repr_hook: Optional[Callable] = None
|
| 282 |
+
|
| 283 |
+
def __getattr__(self, attribute_name):
|
| 284 |
+
if attribute_name in MapDataPipe.functions:
|
| 285 |
+
if attribute_name in _map_deprecated_functional_names:
|
| 286 |
+
kwargs = _map_deprecated_functional_names[attribute_name]
|
| 287 |
+
_deprecation_warning(**kwargs)
|
| 288 |
+
f = MapDataPipe.functions[attribute_name]
|
| 289 |
+
function = functools.partial(f, self)
|
| 290 |
+
functools.update_wrapper(wrapper=function, wrapped=f, assigned=("__doc__",))
|
| 291 |
+
return function
|
| 292 |
+
else:
|
| 293 |
+
raise AttributeError(
|
| 294 |
+
f"'{self.__class__.__name__}' object has no attribute '{attribute_name}"
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
@classmethod
|
| 298 |
+
def register_function(cls, function_name, function):
|
| 299 |
+
cls.functions[function_name] = function
|
| 300 |
+
|
| 301 |
+
@classmethod
|
| 302 |
+
def register_datapipe_as_function(cls, function_name, cls_to_register):
|
| 303 |
+
if function_name in cls.functions:
|
| 304 |
+
raise Exception( # noqa: TRY002
|
| 305 |
+
f"Unable to add DataPipe function name {function_name} as it is already taken"
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
def class_function(cls, source_dp, *args, **kwargs):
|
| 309 |
+
result_pipe = cls(source_dp, *args, **kwargs)
|
| 310 |
+
return result_pipe
|
| 311 |
+
|
| 312 |
+
function = functools.partial(class_function, cls_to_register)
|
| 313 |
+
functools.update_wrapper(
|
| 314 |
+
wrapper=function, wrapped=cls_to_register, assigned=("__doc__",)
|
| 315 |
+
)
|
| 316 |
+
cls.functions[function_name] = function
|
| 317 |
+
|
| 318 |
+
def __getstate__(self):
|
| 319 |
+
"""
|
| 320 |
+
Serialize `lambda` functions when `dill` is available.
|
| 321 |
+
|
| 322 |
+
If this doesn't cover your custom DataPipe's use case, consider writing custom methods for
|
| 323 |
+
`__getstate__` and `__setstate__`, or use `pickle.dumps` for serialization.
|
| 324 |
+
"""
|
| 325 |
+
state = self.__dict__
|
| 326 |
+
if MapDataPipe.getstate_hook is not None:
|
| 327 |
+
return MapDataPipe.getstate_hook(state)
|
| 328 |
+
return state
|
| 329 |
+
|
| 330 |
+
def __reduce_ex__(self, *args, **kwargs):
|
| 331 |
+
if MapDataPipe.reduce_ex_hook is not None:
|
| 332 |
+
try:
|
| 333 |
+
return MapDataPipe.reduce_ex_hook(self)
|
| 334 |
+
except NotImplementedError:
|
| 335 |
+
pass
|
| 336 |
+
return super().__reduce_ex__(*args, **kwargs)
|
| 337 |
+
|
| 338 |
+
@classmethod
|
| 339 |
+
def set_getstate_hook(cls, hook_fn):
|
| 340 |
+
if MapDataPipe.getstate_hook is not None and hook_fn is not None:
|
| 341 |
+
raise RuntimeError("Attempt to override existing getstate_hook")
|
| 342 |
+
MapDataPipe.getstate_hook = hook_fn
|
| 343 |
+
|
| 344 |
+
@classmethod
|
| 345 |
+
def set_reduce_ex_hook(cls, hook_fn):
|
| 346 |
+
if MapDataPipe.reduce_ex_hook is not None and hook_fn is not None:
|
| 347 |
+
raise RuntimeError("Attempt to override existing reduce_ex_hook")
|
| 348 |
+
MapDataPipe.reduce_ex_hook = hook_fn
|
| 349 |
+
|
| 350 |
+
def __repr__(self):
|
| 351 |
+
if self.repr_hook is not None:
|
| 352 |
+
return self.repr_hook(self)
|
| 353 |
+
# Instead of showing <torch. ... .MapperMapDataPipe object at 0x.....>, return the class name
|
| 354 |
+
return str(self.__class__.__qualname__)
|
| 355 |
+
|
| 356 |
+
def __str__(self):
|
| 357 |
+
if self.str_hook is not None:
|
| 358 |
+
return self.str_hook(self)
|
| 359 |
+
# Instead of showing <torch. ... .MapperMapDataPipe object at 0x.....>, return the class name
|
| 360 |
+
return str(self.__class__.__qualname__)
|
| 361 |
+
|
| 362 |
+
def __dir__(self):
|
| 363 |
+
# for auto-completion in a REPL (e.g. Jupyter notebook)
|
| 364 |
+
return list(super().__dir__()) + list(self.functions.keys())
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
class _DataPipeSerializationWrapper:
|
| 368 |
+
def __init__(self, datapipe):
|
| 369 |
+
self._datapipe = datapipe
|
| 370 |
+
|
| 371 |
+
def __getstate__(self):
|
| 372 |
+
use_dill = False
|
| 373 |
+
try:
|
| 374 |
+
value = pickle.dumps(self._datapipe)
|
| 375 |
+
except Exception:
|
| 376 |
+
if HAS_DILL:
|
| 377 |
+
value = dill.dumps(self._datapipe)
|
| 378 |
+
use_dill = True
|
| 379 |
+
else:
|
| 380 |
+
raise
|
| 381 |
+
return (value, use_dill)
|
| 382 |
+
|
| 383 |
+
def __setstate__(self, state):
|
| 384 |
+
value, use_dill = state
|
| 385 |
+
if use_dill:
|
| 386 |
+
self._datapipe = dill.loads(value)
|
| 387 |
+
else:
|
| 388 |
+
self._datapipe = pickle.loads(value)
|
| 389 |
+
|
| 390 |
+
def __len__(self):
|
| 391 |
+
try:
|
| 392 |
+
return len(self._datapipe)
|
| 393 |
+
except Exception as e:
|
| 394 |
+
raise TypeError(
|
| 395 |
+
f"{type(self).__name__} instance doesn't have valid length"
|
| 396 |
+
) from e
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
class _IterDataPipeSerializationWrapper(_DataPipeSerializationWrapper, IterDataPipe):
|
| 400 |
+
def __init__(self, datapipe: IterDataPipe[_T_co]):
|
| 401 |
+
super().__init__(datapipe)
|
| 402 |
+
self._datapipe_iter: Optional[Iterator[_T_co]] = None
|
| 403 |
+
|
| 404 |
+
def __iter__(self) -> "_IterDataPipeSerializationWrapper":
|
| 405 |
+
self._datapipe_iter = iter(self._datapipe)
|
| 406 |
+
return self
|
| 407 |
+
|
| 408 |
+
def __next__(self) -> _T_co: # type: ignore[type-var]
|
| 409 |
+
assert self._datapipe_iter is not None
|
| 410 |
+
return next(self._datapipe_iter)
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
class _MapDataPipeSerializationWrapper(_DataPipeSerializationWrapper, MapDataPipe):
|
| 414 |
+
def __getitem__(self, idx):
|
| 415 |
+
return self._datapipe[idx]
|
.venv/Lib/site-packages/torch/utils/data/datapipes/datapipe.pyi
ADDED
|
@@ -0,0 +1,697 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
# This base template ("datapipe.pyi.in") is generated from mypy stubgen with minimal editing for code injection
|
| 3 |
+
# The output file will be "datapipe.pyi". This is executed as part of torch/CMakeLists.txt
|
| 4 |
+
# Note that, for mypy, .pyi file takes precedent over .py file, such that we must define the interface for other
|
| 5 |
+
# classes/objects here, even though we are not injecting extra code into them at the moment.
|
| 6 |
+
|
| 7 |
+
from typing import (
|
| 8 |
+
Any,
|
| 9 |
+
Callable,
|
| 10 |
+
Dict,
|
| 11 |
+
Iterable,
|
| 12 |
+
Iterator,
|
| 13 |
+
List,
|
| 14 |
+
Literal,
|
| 15 |
+
Optional,
|
| 16 |
+
Type,
|
| 17 |
+
TypeVar,
|
| 18 |
+
Union,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
from torch.utils.data import Dataset, default_collate, IterableDataset
|
| 22 |
+
from torch.utils.data.datapipes._hook_iterator import _SnapshotState
|
| 23 |
+
from torch.utils.data.datapipes._typing import _DataPipeMeta, _IterDataPipeMeta
|
| 24 |
+
|
| 25 |
+
_T = TypeVar("_T")
|
| 26 |
+
_T_co = TypeVar("_T_co", covariant=True)
|
| 27 |
+
UNTRACABLE_DATAFRAME_PIPES: Any
|
| 28 |
+
|
| 29 |
+
class DataChunk(List[_T]):
|
| 30 |
+
items: List[_T]
|
| 31 |
+
def __init__(self, items: Iterable[_T]) -> None: ...
|
| 32 |
+
def as_str(self, indent: str = "") -> str: ...
|
| 33 |
+
def __iter__(self) -> Iterator[_T]: ...
|
| 34 |
+
def raw_iterator(self) -> Iterator[_T]: ...
|
| 35 |
+
|
| 36 |
+
class MapDataPipe(Dataset[_T_co], metaclass=_DataPipeMeta):
|
| 37 |
+
functions: Dict[str, Callable] = ...
|
| 38 |
+
reduce_ex_hook: Optional[Callable] = ...
|
| 39 |
+
getstate_hook: Optional[Callable] = ...
|
| 40 |
+
str_hook: Optional[Callable] = ...
|
| 41 |
+
repr_hook: Optional[Callable] = ...
|
| 42 |
+
def __getattr__(self, attribute_name: Any): ...
|
| 43 |
+
@classmethod
|
| 44 |
+
def register_function(cls, function_name: Any, function: Any) -> None: ...
|
| 45 |
+
@classmethod
|
| 46 |
+
def register_datapipe_as_function(
|
| 47 |
+
cls,
|
| 48 |
+
function_name: Any,
|
| 49 |
+
cls_to_register: Any,
|
| 50 |
+
): ...
|
| 51 |
+
def __getstate__(self): ...
|
| 52 |
+
def __reduce_ex__(self, *args: Any, **kwargs: Any): ...
|
| 53 |
+
@classmethod
|
| 54 |
+
def set_getstate_hook(cls, hook_fn: Any) -> None: ...
|
| 55 |
+
@classmethod
|
| 56 |
+
def set_reduce_ex_hook(cls, hook_fn: Any) -> None: ...
|
| 57 |
+
# Functional form of 'BatcherMapDataPipe'
|
| 58 |
+
def batch(self, batch_size: int, drop_last: bool = False, wrapper_class: Type[DataChunk] = DataChunk) -> MapDataPipe:
|
| 59 |
+
r"""
|
| 60 |
+
Create mini-batches of data (functional name: ``batch``).
|
| 61 |
+
|
| 62 |
+
An outer dimension will be added as ``batch_size`` if ``drop_last`` is set to ``True``,
|
| 63 |
+
or ``length % batch_size`` for the last batch if ``drop_last`` is set to ``False``.
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
datapipe: Iterable DataPipe being batched
|
| 67 |
+
batch_size: The size of each batch
|
| 68 |
+
drop_last: Option to drop the last batch if it's not full
|
| 69 |
+
|
| 70 |
+
Example:
|
| 71 |
+
>>> # xdoctest: +SKIP
|
| 72 |
+
>>> from torchdata.datapipes.map import SequenceWrapper
|
| 73 |
+
>>> dp = SequenceWrapper(range(10))
|
| 74 |
+
>>> batch_dp = dp.batch(batch_size=2)
|
| 75 |
+
>>> list(batch_dp)
|
| 76 |
+
[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
# Functional form of 'ConcaterMapDataPipe'
|
| 80 |
+
def concat(self, *datapipes: MapDataPipe) -> MapDataPipe:
|
| 81 |
+
r"""
|
| 82 |
+
Concatenate multiple Map DataPipes (functional name: ``concat``).
|
| 83 |
+
|
| 84 |
+
The new index of is the cumulative sum of source DataPipes.
|
| 85 |
+
For example, if there are 2 source DataPipes both with length 5,
|
| 86 |
+
index 0 to 4 of the resulting `ConcatMapDataPipe` would refer to
|
| 87 |
+
elements of the first DataPipe, and 5 to 9 would refer to elements
|
| 88 |
+
of the second DataPipe.
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
datapipes: Map DataPipes being concatenated
|
| 92 |
+
|
| 93 |
+
Example:
|
| 94 |
+
>>> # xdoctest: +SKIP
|
| 95 |
+
>>> from torchdata.datapipes.map import SequenceWrapper
|
| 96 |
+
>>> dp1 = SequenceWrapper(range(3))
|
| 97 |
+
>>> dp2 = SequenceWrapper(range(3))
|
| 98 |
+
>>> concat_dp = dp1.concat(dp2)
|
| 99 |
+
>>> list(concat_dp)
|
| 100 |
+
[0, 1, 2, 0, 1, 2]
|
| 101 |
+
"""
|
| 102 |
+
|
| 103 |
+
# Functional form of 'MapperMapDataPipe'
|
| 104 |
+
def map(self, fn: Callable= ...) -> MapDataPipe:
|
| 105 |
+
r"""
|
| 106 |
+
Apply the input function over each item from the source DataPipe (functional name: ``map``).
|
| 107 |
+
|
| 108 |
+
The function can be any regular Python function or partial object. Lambda
|
| 109 |
+
function is not recommended as it is not supported by pickle.
|
| 110 |
+
|
| 111 |
+
Args:
|
| 112 |
+
datapipe: Source MapDataPipe
|
| 113 |
+
fn: Function being applied to each item
|
| 114 |
+
|
| 115 |
+
Example:
|
| 116 |
+
>>> # xdoctest: +SKIP
|
| 117 |
+
>>> from torchdata.datapipes.map import SequenceWrapper, Mapper
|
| 118 |
+
>>> def add_one(x):
|
| 119 |
+
... return x + 1
|
| 120 |
+
>>> dp = SequenceWrapper(range(10))
|
| 121 |
+
>>> map_dp_1 = dp.map(add_one)
|
| 122 |
+
>>> list(map_dp_1)
|
| 123 |
+
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
| 124 |
+
>>> map_dp_2 = Mapper(dp, lambda x: x + 1)
|
| 125 |
+
>>> list(map_dp_2)
|
| 126 |
+
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
| 127 |
+
"""
|
| 128 |
+
|
| 129 |
+
# Functional form of 'ShufflerIterDataPipe'
|
| 130 |
+
def shuffle(self, *, indices: Optional[List] = None) -> IterDataPipe:
|
| 131 |
+
r"""
|
| 132 |
+
Shuffle the input MapDataPipe via its indices (functional name: ``shuffle``).
|
| 133 |
+
|
| 134 |
+
When it is used with :class:`~torch.utils.data.DataLoader`, the methods to
|
| 135 |
+
set up random seed are different based on :attr:`num_workers`.
|
| 136 |
+
|
| 137 |
+
For single-process mode (:attr:`num_workers == 0`), the random seed is set before
|
| 138 |
+
the :class:`~torch.utils.data.DataLoader` in the main process. For multi-process
|
| 139 |
+
mode (:attr:`num_worker > 0`), ``worker_init_fn`` is used to set up a random seed
|
| 140 |
+
for each worker process.
|
| 141 |
+
|
| 142 |
+
Args:
|
| 143 |
+
datapipe: MapDataPipe being shuffled
|
| 144 |
+
indices: a list of indices of the MapDataPipe. If not provided, we assume it uses 0-based indexing
|
| 145 |
+
|
| 146 |
+
Example:
|
| 147 |
+
>>> # xdoctest: +SKIP
|
| 148 |
+
>>> from torchdata.datapipes.map import SequenceWrapper
|
| 149 |
+
>>> dp = SequenceWrapper(range(10))
|
| 150 |
+
>>> shuffle_dp = dp.shuffle().set_seed(0)
|
| 151 |
+
>>> list(shuffle_dp)
|
| 152 |
+
[7, 8, 1, 5, 3, 4, 2, 0, 9, 6]
|
| 153 |
+
>>> list(shuffle_dp)
|
| 154 |
+
[6, 1, 9, 5, 2, 4, 7, 3, 8, 0]
|
| 155 |
+
>>> # Reset seed for Shuffler
|
| 156 |
+
>>> shuffle_dp = shuffle_dp.set_seed(0)
|
| 157 |
+
>>> list(shuffle_dp)
|
| 158 |
+
[7, 8, 1, 5, 3, 4, 2, 0, 9, 6]
|
| 159 |
+
|
| 160 |
+
Note:
|
| 161 |
+
Even thought this ``shuffle`` operation takes a ``MapDataPipe`` as the input, it would return an
|
| 162 |
+
``IterDataPipe`` rather than a ``MapDataPipe``, because ``MapDataPipe`` should be non-sensitive to
|
| 163 |
+
the order of data order for the sake of random reads, but ``IterDataPipe`` depends on the order
|
| 164 |
+
of data during data-processing.
|
| 165 |
+
"""
|
| 166 |
+
|
| 167 |
+
# Functional form of 'ZipperMapDataPipe'
|
| 168 |
+
def zip(self, *datapipes: MapDataPipe[_T_co]) -> MapDataPipe:
|
| 169 |
+
r"""
|
| 170 |
+
Aggregates elements into a tuple from each of the input DataPipes (functional name: ``zip``).
|
| 171 |
+
|
| 172 |
+
This MataPipe is out of bound as soon as the shortest input DataPipe is exhausted.
|
| 173 |
+
|
| 174 |
+
Args:
|
| 175 |
+
*datapipes: Map DataPipes being aggregated
|
| 176 |
+
|
| 177 |
+
Example:
|
| 178 |
+
>>> # xdoctest: +SKIP
|
| 179 |
+
>>> from torchdata.datapipes.map import SequenceWrapper
|
| 180 |
+
>>> dp1 = SequenceWrapper(range(3))
|
| 181 |
+
>>> dp2 = SequenceWrapper(range(10, 13))
|
| 182 |
+
>>> zip_dp = dp1.zip(dp2)
|
| 183 |
+
>>> list(zip_dp)
|
| 184 |
+
[(0, 10), (1, 11), (2, 12)]
|
| 185 |
+
"""
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta):
|
| 189 |
+
functions: Dict[str, Callable] = ...
|
| 190 |
+
reduce_ex_hook: Optional[Callable] = ...
|
| 191 |
+
getstate_hook: Optional[Callable] = ...
|
| 192 |
+
str_hook: Optional[Callable] = ...
|
| 193 |
+
repr_hook: Optional[Callable] = ...
|
| 194 |
+
_number_of_samples_yielded: int = ...
|
| 195 |
+
_snapshot_state: _SnapshotState = _SnapshotState.Iterating # noqa: PYI015
|
| 196 |
+
_fast_forward_iterator: Optional[Iterator] = ...
|
| 197 |
+
def __getattr__(self, attribute_name: Any): ...
|
| 198 |
+
@classmethod
|
| 199 |
+
def register_function(cls, function_name: Any, function: Any) -> None: ...
|
| 200 |
+
@classmethod
|
| 201 |
+
def register_datapipe_as_function(
|
| 202 |
+
cls,
|
| 203 |
+
function_name: Any,
|
| 204 |
+
cls_to_register: Any,
|
| 205 |
+
enable_df_api_tracing: bool = ...,
|
| 206 |
+
): ...
|
| 207 |
+
def __getstate__(self): ...
|
| 208 |
+
def __reduce_ex__(self, *args: Any, **kwargs: Any): ...
|
| 209 |
+
@classmethod
|
| 210 |
+
def set_getstate_hook(cls, hook_fn: Any) -> None: ...
|
| 211 |
+
@classmethod
|
| 212 |
+
def set_reduce_ex_hook(cls, hook_fn: Any) -> None: ...
|
| 213 |
+
# Functional form of 'BatcherIterDataPipe'
|
| 214 |
+
def batch(self, batch_size: int, drop_last: bool = False, wrapper_class: Type[DataChunk] = DataChunk) -> IterDataPipe:
|
| 215 |
+
r"""
|
| 216 |
+
Creates mini-batches of data (functional name: ``batch``).
|
| 217 |
+
|
| 218 |
+
An outer dimension will be added as ``batch_size`` if ``drop_last`` is set to ``True``, or ``length % batch_size`` for the
|
| 219 |
+
last batch if ``drop_last`` is set to ``False``.
|
| 220 |
+
|
| 221 |
+
Args:
|
| 222 |
+
datapipe: Iterable DataPipe being batched
|
| 223 |
+
batch_size: The size of each batch
|
| 224 |
+
drop_last: Option to drop the last batch if it's not full
|
| 225 |
+
wrapper_class: wrapper to apply onto each batch (type ``List``) before yielding,
|
| 226 |
+
defaults to ``DataChunk``
|
| 227 |
+
|
| 228 |
+
Example:
|
| 229 |
+
>>> # xdoctest: +SKIP
|
| 230 |
+
>>> from torchdata.datapipes.iter import IterableWrapper
|
| 231 |
+
>>> dp = IterableWrapper(range(10))
|
| 232 |
+
>>> dp = dp.batch(batch_size=3, drop_last=True)
|
| 233 |
+
>>> list(dp)
|
| 234 |
+
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
|
| 235 |
+
"""
|
| 236 |
+
|
| 237 |
+
# Functional form of 'CollatorIterDataPipe'
|
| 238 |
+
def collate(self, conversion: Union[Callable[..., Any], Dict[Union[str, Any], Union[Callable, Any]], None] = default_collate, collate_fn: Optional[Callable] = None) -> IterDataPipe:
|
| 239 |
+
r"""
|
| 240 |
+
Collates samples from DataPipe to Tensor(s) by a custom collate function (functional name: ``collate``).
|
| 241 |
+
|
| 242 |
+
By default, it uses :func:`torch.utils.data.default_collate`.
|
| 243 |
+
|
| 244 |
+
.. note::
|
| 245 |
+
While writing a custom collate function, you can import :func:`torch.utils.data.default_collate` for the
|
| 246 |
+
default behavior and `functools.partial` to specify any additional arguments.
|
| 247 |
+
|
| 248 |
+
Args:
|
| 249 |
+
datapipe: Iterable DataPipe being collated
|
| 250 |
+
collate_fn: Customized collate function to collect and combine data or a batch of data.
|
| 251 |
+
Default function collates to Tensor(s) based on data type.
|
| 252 |
+
|
| 253 |
+
Example:
|
| 254 |
+
>>> # xdoctest: +SKIP
|
| 255 |
+
>>> # Convert integer data to float Tensor
|
| 256 |
+
>>> class MyIterDataPipe(torch.utils.data.IterDataPipe):
|
| 257 |
+
... def __init__(self, start, end):
|
| 258 |
+
... super(MyIterDataPipe).__init__()
|
| 259 |
+
... assert end > start, "this example code only works with end >= start"
|
| 260 |
+
... self.start = start
|
| 261 |
+
... self.end = end
|
| 262 |
+
...
|
| 263 |
+
... def __iter__(self):
|
| 264 |
+
... return iter(range(self.start, self.end))
|
| 265 |
+
...
|
| 266 |
+
... def __len__(self):
|
| 267 |
+
... return self.end - self.start
|
| 268 |
+
...
|
| 269 |
+
>>> ds = MyIterDataPipe(start=3, end=7)
|
| 270 |
+
>>> print(list(ds))
|
| 271 |
+
[3, 4, 5, 6]
|
| 272 |
+
>>> def collate_fn(batch):
|
| 273 |
+
... return torch.tensor(batch, dtype=torch.float)
|
| 274 |
+
...
|
| 275 |
+
>>> collated_ds = CollateIterDataPipe(ds, collate_fn=collate_fn)
|
| 276 |
+
>>> print(list(collated_ds))
|
| 277 |
+
[tensor(3.), tensor(4.), tensor(5.), tensor(6.)]
|
| 278 |
+
"""
|
| 279 |
+
|
| 280 |
+
# Functional form of 'ConcaterIterDataPipe'
|
| 281 |
+
def concat(self, *datapipes: IterDataPipe) -> IterDataPipe:
|
| 282 |
+
r"""
|
| 283 |
+
Concatenates multiple Iterable DataPipes (functional name: ``concat``).
|
| 284 |
+
|
| 285 |
+
The resulting DataPipe will yield all the elements from the first input DataPipe, before yielding from the subsequent ones.
|
| 286 |
+
|
| 287 |
+
Args:
|
| 288 |
+
datapipes: Iterable DataPipes being concatenated
|
| 289 |
+
|
| 290 |
+
Example:
|
| 291 |
+
>>> # xdoctest: +REQUIRES(module:torchdata)
|
| 292 |
+
>>> import random
|
| 293 |
+
>>> from torchdata.datapipes.iter import IterableWrapper
|
| 294 |
+
>>> dp1 = IterableWrapper(range(3))
|
| 295 |
+
>>> dp2 = IterableWrapper(range(5))
|
| 296 |
+
>>> list(dp1.concat(dp2))
|
| 297 |
+
[0, 1, 2, 0, 1, 2, 3, 4]
|
| 298 |
+
"""
|
| 299 |
+
|
| 300 |
+
# Functional form of 'DemultiplexerIterDataPipe'
|
| 301 |
+
def demux(self, num_instances: int, classifier_fn: Callable[[_T_co], Optional[int]], drop_none: bool = False, buffer_size: int = 1000) -> List[IterDataPipe]:
|
| 302 |
+
r"""
|
| 303 |
+
Splits the input DataPipe into multiple child DataPipes, using the given classification function (functional name: ``demux``).
|
| 304 |
+
|
| 305 |
+
A list of the child DataPipes is returned from this operation.
|
| 306 |
+
|
| 307 |
+
Args:
|
| 308 |
+
datapipe: Iterable DataPipe being filtered
|
| 309 |
+
num_instances: number of instances of the DataPipe to create
|
| 310 |
+
classifier_fn: a function that maps values to an integer within the range ``[0, num_instances - 1]`` or ``None``
|
| 311 |
+
drop_none: defaults to ``False``, if ``True``, the function will skip over elements classified as ``None``
|
| 312 |
+
buffer_size: this defines the maximum number of inputs that the buffer can hold across all child
|
| 313 |
+
DataPipes while waiting for their values to be yielded.
|
| 314 |
+
Defaults to ``1000``. Use ``-1`` for the unlimited buffer.
|
| 315 |
+
|
| 316 |
+
Examples:
|
| 317 |
+
>>> # xdoctest: +REQUIRES(module:torchdata)
|
| 318 |
+
>>> from torchdata.datapipes.iter import IterableWrapper
|
| 319 |
+
>>> def odd_or_even(n):
|
| 320 |
+
... return n % 2
|
| 321 |
+
>>> source_dp = IterableWrapper(range(5))
|
| 322 |
+
>>> dp1, dp2 = source_dp.demux(num_instances=2, classifier_fn=odd_or_even)
|
| 323 |
+
>>> list(dp1)
|
| 324 |
+
[0, 2, 4]
|
| 325 |
+
>>> list(dp2)
|
| 326 |
+
[1, 3]
|
| 327 |
+
>>> # It can also filter out any element that gets `None` from the `classifier_fn`
|
| 328 |
+
>>> def odd_or_even_no_zero(n):
|
| 329 |
+
... return n % 2 if n != 0 else None
|
| 330 |
+
>>> dp1, dp2 = source_dp.demux(num_instances=2, classifier_fn=odd_or_even_no_zero, drop_none=True)
|
| 331 |
+
>>> list(dp1)
|
| 332 |
+
[2, 4]
|
| 333 |
+
>>> list(dp2)
|
| 334 |
+
[1, 3]
|
| 335 |
+
"""
|
| 336 |
+
|
| 337 |
+
# Functional form of 'FilterIterDataPipe'
|
| 338 |
+
def filter(self, filter_fn: Callable, input_col=None) -> IterDataPipe:
|
| 339 |
+
r"""
|
| 340 |
+
Filters out elements from the source datapipe according to input ``filter_fn`` (functional name: ``filter``).
|
| 341 |
+
|
| 342 |
+
Args:
|
| 343 |
+
datapipe: Iterable DataPipe being filtered
|
| 344 |
+
filter_fn: Customized function mapping an element to a boolean.
|
| 345 |
+
input_col: Index or indices of data which ``filter_fn`` is applied, such as:
|
| 346 |
+
|
| 347 |
+
- ``None`` as default to apply ``filter_fn`` to the data directly.
|
| 348 |
+
- Integer(s) is used for list/tuple.
|
| 349 |
+
- Key(s) is used for dict.
|
| 350 |
+
|
| 351 |
+
Example:
|
| 352 |
+
>>> # xdoctest: +SKIP
|
| 353 |
+
>>> from torchdata.datapipes.iter import IterableWrapper
|
| 354 |
+
>>> def is_even(n):
|
| 355 |
+
... return n % 2 == 0
|
| 356 |
+
>>> dp = IterableWrapper(range(5))
|
| 357 |
+
>>> filter_dp = dp.filter(filter_fn=is_even)
|
| 358 |
+
>>> list(filter_dp)
|
| 359 |
+
[0, 2, 4]
|
| 360 |
+
"""
|
| 361 |
+
|
| 362 |
+
# Functional form of 'ForkerIterDataPipe'
|
| 363 |
+
def fork(self, num_instances: int, buffer_size: int = 1000, copy: Optional[Literal["shallow", "deep"]] = None) -> List[IterDataPipe]:
|
| 364 |
+
r"""
|
| 365 |
+
Creates multiple instances of the same Iterable DataPipe (functional name: ``fork``).
|
| 366 |
+
|
| 367 |
+
Args:
|
| 368 |
+
datapipe: Iterable DataPipe being copied
|
| 369 |
+
num_instances: number of instances of the datapipe to create
|
| 370 |
+
buffer_size: this restricts how far ahead the leading child DataPipe
|
| 371 |
+
can read relative to the slowest child DataPipe.
|
| 372 |
+
Defaults to ``1000``. Use ``-1`` for the unlimited buffer.
|
| 373 |
+
copy: copy strategy to use for items yielded by each branch. Supported
|
| 374 |
+
options are ``None`` for no copying, ``"shallow"`` for shallow object
|
| 375 |
+
copies, and ``"deep"`` for deep object copies. Defaults to ``None``.
|
| 376 |
+
|
| 377 |
+
Note:
|
| 378 |
+
All branches of the forked pipeline return the identical object unless
|
| 379 |
+
the copy parameter is supplied. If the object is mutable or contains
|
| 380 |
+
mutable objects, changing them in one branch will affect all others.
|
| 381 |
+
|
| 382 |
+
Example:
|
| 383 |
+
>>> # xdoctest: +REQUIRES(module:torchdata)
|
| 384 |
+
>>> from torchdata.datapipes.iter import IterableWrapper
|
| 385 |
+
>>> source_dp = IterableWrapper(range(5))
|
| 386 |
+
>>> dp1, dp2 = source_dp.fork(num_instances=2)
|
| 387 |
+
>>> list(dp1)
|
| 388 |
+
[0, 1, 2, 3, 4]
|
| 389 |
+
>>> list(dp2)
|
| 390 |
+
[0, 1, 2, 3, 4]
|
| 391 |
+
"""
|
| 392 |
+
|
| 393 |
+
# Functional form of 'GrouperIterDataPipe'
|
| 394 |
+
def groupby(self, group_key_fn: Callable[[_T_co], Any], *, keep_key: bool = False, buffer_size: int = 10000, group_size: Optional[int] = None, guaranteed_group_size: Optional[int] = None, drop_remaining: bool = False) -> IterDataPipe:
|
| 395 |
+
r"""
|
| 396 |
+
Groups data from IterDataPipe by keys from ``group_key_fn``, yielding a ``DataChunk`` with batch size up to ``group_size``.
|
| 397 |
+
|
| 398 |
+
(functional name: ``groupby``).
|
| 399 |
+
|
| 400 |
+
The samples are read sequentially from the source ``datapipe``, and a batch of samples belonging to the same group
|
| 401 |
+
will be yielded as soon as the size of the batch reaches ``group_size``. When the buffer is full,
|
| 402 |
+
the DataPipe will yield the largest batch with the same key, provided that its size is larger
|
| 403 |
+
than ``guaranteed_group_size``. If its size is smaller, it will be dropped if ``drop_remaining=True``.
|
| 404 |
+
|
| 405 |
+
After iterating through the entirety of source ``datapipe``, everything not dropped due to the buffer capacity
|
| 406 |
+
will be yielded from the buffer, even if the group sizes are smaller than ``guaranteed_group_size``.
|
| 407 |
+
|
| 408 |
+
Args:
|
| 409 |
+
datapipe: Iterable datapipe to be grouped
|
| 410 |
+
group_key_fn: Function used to generate group key from the data of the source datapipe
|
| 411 |
+
keep_key: Option to yield the matching key along with the items in a tuple,
|
| 412 |
+
resulting in `(key, [items])` otherwise returning [items]
|
| 413 |
+
buffer_size: The size of buffer for ungrouped data
|
| 414 |
+
group_size: The max size of each group, a batch is yielded as soon as it reaches this size
|
| 415 |
+
guaranteed_group_size: The guaranteed minimum group size to be yielded in case the buffer is full
|
| 416 |
+
drop_remaining: Specifies if the group smaller than ``guaranteed_group_size`` will be dropped from buffer
|
| 417 |
+
when the buffer is full
|
| 418 |
+
|
| 419 |
+
Example:
|
| 420 |
+
>>> import os
|
| 421 |
+
>>> # xdoctest: +SKIP
|
| 422 |
+
>>> from torchdata.datapipes.iter import IterableWrapper
|
| 423 |
+
>>> def group_fn(file):
|
| 424 |
+
... return os.path.basename(file).split(".")[0]
|
| 425 |
+
>>> source_dp = IterableWrapper(["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"])
|
| 426 |
+
>>> dp0 = source_dp.groupby(group_key_fn=group_fn)
|
| 427 |
+
>>> list(dp0)
|
| 428 |
+
[['a.png', 'a.json', 'a.jpg'], ['b.png', 'b.json'], ['c.json']]
|
| 429 |
+
>>> # A group is yielded as soon as its size equals to `group_size`
|
| 430 |
+
>>> dp1 = source_dp.groupby(group_key_fn=group_fn, group_size=2)
|
| 431 |
+
>>> list(dp1)
|
| 432 |
+
[['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]
|
| 433 |
+
>>> # Scenario where `buffer` is full, and group 'a' needs to be yielded since its size > `guaranteed_group_size`
|
| 434 |
+
>>> dp2 = source_dp.groupby(group_key_fn=group_fn, buffer_size=3, group_size=3, guaranteed_group_size=2)
|
| 435 |
+
>>> list(dp2)
|
| 436 |
+
[['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]
|
| 437 |
+
"""
|
| 438 |
+
|
| 439 |
+
# Functional form of 'FileListerIterDataPipe'
|
| 440 |
+
def list_files(self, masks: Union[str, List[str]] = "", *, recursive: bool = False, abspath: bool = False, non_deterministic: bool = False, length: int = -1) -> IterDataPipe:
|
| 441 |
+
r"""
|
| 442 |
+
Given path(s) to the root directory, yields file pathname(s) (path + filename) of files within the root directory.
|
| 443 |
+
|
| 444 |
+
Multiple root directories can be provided (functional name: ``list_files``).
|
| 445 |
+
|
| 446 |
+
Args:
|
| 447 |
+
root: Root directory or a sequence of root directories
|
| 448 |
+
masks: Unix style filter string or string list for filtering file name(s)
|
| 449 |
+
recursive: Whether to return pathname from nested directories or not
|
| 450 |
+
abspath: Whether to return relative pathname or absolute pathname
|
| 451 |
+
non_deterministic: Whether to return pathname in sorted order or not.
|
| 452 |
+
If ``False``, the results yielded from each root directory will be sorted
|
| 453 |
+
length: Nominal length of the datapipe
|
| 454 |
+
|
| 455 |
+
Example:
|
| 456 |
+
>>> # xdoctest: +SKIP
|
| 457 |
+
>>> from torchdata.datapipes.iter import FileLister
|
| 458 |
+
>>> dp = FileLister(root=".", recursive=True)
|
| 459 |
+
>>> list(dp)
|
| 460 |
+
['example.py', './data/data.tar']
|
| 461 |
+
"""
|
| 462 |
+
|
| 463 |
+
# Functional form of 'MapperIterDataPipe'
def map(self, fn: Callable, input_col=None, output_col=None) -> IterDataPipe:
    r"""
    Applies a function over each item from the source DataPipe (functional name: ``map``).

    The function can be any regular Python function or partial object. Lambda
    functions are not recommended as they are not supported by pickle.

    Args:
        datapipe: Source Iterable DataPipe
        fn: Function being applied over each item
        input_col: Index or indices of data to which ``fn`` is applied, such as:

            - ``None`` as default to apply ``fn`` to the data directly.
            - Integer(s) is used for list/tuple.
            - Key(s) is used for dict.

        output_col: Index of data where result of ``fn`` is placed. ``output_col`` can be specified
            only when ``input_col`` is not ``None``

            - ``None`` as default to replace the index that ``input_col`` specified; For ``input_col`` with
              multiple indices, the left-most one is used, and other indices will be removed.
            - Integer is used for list/tuple. ``-1`` represents appending the result at the end.
            - Key is used for dict. New key is acceptable.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import IterableWrapper, Mapper
        >>> def add_one(x):
        ...     return x + 1
        >>> dp = IterableWrapper(range(10))
        >>> map_dp_1 = dp.map(add_one)  # Invocation via functional form is preferred
        >>> list(map_dp_1)
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        >>> # We discourage the usage of `lambda` functions as they are not serializable with `pickle`
        >>> # Use `functools.partial` or explicitly define the function instead
        >>> map_dp_2 = Mapper(dp, lambda x: x + 1)
        >>> list(map_dp_2)
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    """

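# NOTE (illustrative example, not part of the generated stub): ``input_col``/``output_col``
# address fields of structured samples; assuming the same ``add_one`` helper as above, the
# expected behavior is roughly:
# >>> # xdoctest: +SKIP
# >>> dp = IterableWrapper([(0, 10), (1, 11)]).map(add_one, input_col=1)
# >>> list(dp)
# [(0, 11), (1, 12)]
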
# Functional form of 'MultiplexerIterDataPipe'
def mux(self, *datapipes) -> IterDataPipe:
    r"""
    Yields one element at a time from each of the input Iterable DataPipes (functional name: ``mux``).

    As in, one element from the 1st input DataPipe, then one element from the 2nd DataPipe in the next iteration,
    and so on. It ends when the shortest input DataPipe is exhausted.

    Args:
        datapipes: Iterable DataPipes that will take turns to yield their elements, until the shortest DataPipe is exhausted

    Example:
        >>> # xdoctest: +REQUIRES(module:torchdata)
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> dp1, dp2, dp3 = IterableWrapper(range(3)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25))
        >>> list(dp1.mux(dp2, dp3))
        [0, 10, 20, 1, 11, 21, 2, 12, 22]
    """

# Functional form of 'FileOpenerIterDataPipe'
def open_files(self, mode: str = "r", encoding: Optional[str] = None, length: int = -1) -> IterDataPipe:
    r"""
    Given pathnames, opens files and yields the pathname and file stream in a tuple (functional name: ``open_files``).

    Args:
        datapipe: Iterable datapipe that provides pathnames
        mode: An optional string that specifies the mode in which
            the file is opened by ``open()``. It defaults to ``r``, other options are
            ``b`` for reading in binary mode and ``t`` for text mode.
        encoding: An optional string that specifies the encoding of the
            underlying file. It defaults to ``None`` to match the default encoding of ``open``.
        length: Nominal length of the datapipe

    Note:
        The opened file handles will be closed by Python's GC periodically. Users can choose
        to close them explicitly.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import FileLister, FileOpener, StreamReader
        >>> dp = FileLister(root=".").filter(lambda fname: fname.endswith('.txt'))
        >>> dp = FileOpener(dp)
        >>> dp = StreamReader(dp)
        >>> list(dp)
        [('./abc.txt', 'abc')]
    """

# Functional form of 'StreamReaderIterDataPipe'
def read_from_stream(self, chunk=None) -> IterDataPipe:
    r"""
    Given IO streams and their label names, yields bytes with the label name as a tuple.

    (functional name: ``read_from_stream``).

    Args:
        datapipe: Iterable DataPipe that provides the label/URL and byte stream
        chunk: Number of bytes to be read from the stream per iteration.
            If ``None``, all bytes will be read until the EOF.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import IterableWrapper, StreamReader
        >>> from io import StringIO
        >>> dp = IterableWrapper([("alphabet", StringIO("abcde"))])
        >>> list(StreamReader(dp, chunk=1))
        [('alphabet', 'a'), ('alphabet', 'b'), ('alphabet', 'c'), ('alphabet', 'd'), ('alphabet', 'e')]
    """

# Functional form of 'RoutedDecoderIterDataPipe'
def routed_decode(self, *handlers: Callable, key_fn: Callable= ...) -> IterDataPipe:
    r"""
    Decodes binary streams from the input DataPipe, yields pathname and decoded data in a tuple.

    (functional name: ``routed_decode``)

    Args:
        datapipe: Iterable datapipe that provides pathname and binary stream in tuples
        handlers: Optional user defined decoder handlers. If ``None``, basic and image decoder
            handlers will be set as default. If multiple handlers are provided, the priority
            order follows the order of handlers (the first handler has the top priority)
        key_fn: Function for the decoder to extract a key from the pathname to dispatch handlers.
            Default is set to extract the file extension from the pathname

    Note:
        When ``key_fn`` is specified to return anything other than the extension, the default
        handlers will not work and users need to specify a custom handler. A custom handler
        could use a regex to determine whether it is eligible to handle the data.
    """

# Functional form of 'ShardingFilterIterDataPipe'
def sharding_filter(self, sharding_group_filter=None) -> IterDataPipe:
    r"""
    Wrapper that allows DataPipe to be sharded (functional name: ``sharding_filter``).

    After ``apply_sharding`` is called, each instance of the DataPipe (on different workers) will have every `n`-th element of the
    original DataPipe, where `n` equals the number of instances.

    Args:
        source_datapipe: Iterable DataPipe that will be sharded
    """

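# NOTE (illustrative example, not part of the generated stub): the stub above carries no
# doctest; assuming the usual ``apply_sharding(num_of_instances, instance_id)`` call on the
# sharding-filter DataPipe, usage looks roughly like:
# >>> # xdoctest: +SKIP
# >>> from torchdata.datapipes.iter import IterableWrapper
# >>> dp = IterableWrapper(range(10)).sharding_filter()
# >>> dp.apply_sharding(2, 0)  # 2 instances in total, this one has instance_id 0
# >>> list(dp)
# [0, 2, 4, 6, 8]
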
# Functional form of 'ShufflerIterDataPipe'
def shuffle(self, *, buffer_size: int = 10000, unbatch_level: int = 0) -> IterDataPipe:
    r"""
    Shuffle the input DataPipe with a buffer (functional name: ``shuffle``).

    The buffer with ``buffer_size`` is filled with elements from the datapipe first. Then,
    each item will be yielded from the buffer by reservoir sampling via iterator.

    ``buffer_size`` is required to be larger than ``0``. For ``buffer_size == 1``, the
    datapipe is not shuffled. In order to fully shuffle all elements from the datapipe,
    ``buffer_size`` is required to be greater than or equal to the size of the datapipe.

    When it is used with :class:`torch.utils.data.DataLoader`, the methods to
    set up random seed are different based on :attr:`num_workers`.

    For single-process mode (:attr:`num_workers == 0`), the random seed is set before
    the :class:`~torch.utils.data.DataLoader` in the main process. For multi-process
    mode (:attr:`num_worker > 0`), `worker_init_fn` is used to set up a random seed
    for each worker process.

    Args:
        datapipe: The IterDataPipe being shuffled
        buffer_size: The buffer size for shuffling (defaults to ``10000``)
        unbatch_level: Specifies if it is necessary to unbatch source data before
            applying the shuffle

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> dp = IterableWrapper(range(10))
        >>> shuffle_dp = dp.shuffle()
        >>> list(shuffle_dp)
        [0, 4, 1, 6, 3, 2, 9, 5, 7, 8]
    """

# Functional form of 'UnBatcherIterDataPipe'
def unbatch(self, unbatch_level: int = 1) -> IterDataPipe:
    r"""
    Undoes batching of data (functional name: ``unbatch``).

    In other words, it flattens the data up to the specified level within a batched DataPipe.

    Args:
        datapipe: Iterable DataPipe being un-batched
        unbatch_level: Defaults to ``1`` (only flattening the top level). If set to ``2``,
            it will flatten the top two levels, and ``-1`` will flatten the entire DataPipe.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> source_dp = IterableWrapper([[[0, 1], [2]], [[3, 4], [5]], [[6]]])
        >>> dp1 = source_dp.unbatch()
        >>> list(dp1)
        [[0, 1], [2], [3, 4], [5], [6]]
        >>> dp2 = source_dp.unbatch(unbatch_level=2)
        >>> list(dp2)
        [0, 1, 2, 3, 4, 5, 6]
    """

# Functional form of 'ZipperIterDataPipe'
def zip(self, *datapipes: IterDataPipe) -> IterDataPipe:
    r"""
    Aggregates elements into a tuple from each of the input DataPipes (functional name: ``zip``).

    The output is stopped as soon as the shortest input DataPipe is exhausted.

    Args:
        *datapipes: Iterable DataPipes being aggregated

    Example:
        >>> # xdoctest: +REQUIRES(module:torchdata)
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> dp1, dp2, dp3 = IterableWrapper(range(5)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25))
        >>> list(dp1.zip(dp2, dp3))
        [(0, 10, 20), (1, 11, 21), (2, 12, 22), (3, 13, 23), (4, 14, 24)]
    """


class DFIterDataPipe(IterDataPipe):
    def _is_dfpipe(self): ...
    def __iter__(self): ...

class _DataPipeSerializationWrapper:
    def __init__(self, datapipe): ...
    def __getstate__(self): ...
    def __setstate__(self, state): ...
    def __len__(self): ...

class _IterDataPipeSerializationWrapper(_DataPipeSerializationWrapper, IterDataPipe):
    def __iter__(self): ...

class _MapDataPipeSerializationWrapper(_DataPipeSerializationWrapper, MapDataPipe):
    def __getitem__(self, idx): ...
.venv/Lib/site-packages/torch/utils/data/datapipes/gen_pyi.py
ADDED
@@ -0,0 +1,305 @@
# mypy: allow-untyped-defs
import os
import pathlib
from collections import defaultdict
from typing import Any, Dict, List, Set, Tuple, Union


def materialize_lines(lines: List[str], indentation: int) -> str:
    output = ""
    new_line_with_indent = "\n" + " " * indentation
    for i, line in enumerate(lines):
        if i != 0:
            output += new_line_with_indent
        output += line.replace("\n", new_line_with_indent)
    return output

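# NOTE (illustrative, not part of the original file): `materialize_lines` joins the generated
# lines and re-indents any embedded newlines; main() below uses it to splice method stubs
# into datapipe.pyi.in at a 4-space indentation, e.g.:
# >>> materialize_lines(["a", "b\nc"], 4)
# 'a\n    b\n    c'
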
def gen_from_template(
    dir: str,
    template_name: str,
    output_name: str,
    replacements: List[Tuple[str, Any, int]],
):
    template_path = os.path.join(dir, template_name)
    output_path = os.path.join(dir, output_name)

    with open(template_path) as f:
        content = f.read()
    for placeholder, lines, indentation in replacements:
        with open(output_path, "w") as f:
            content = content.replace(
                placeholder, materialize_lines(lines, indentation)
            )
            f.write(content)

def find_file_paths(dir_paths: List[str], files_to_exclude: Set[str]) -> Set[str]:
    """
    When given a path to a directory, returns the paths to the relevant files within it.

    This function does NOT recursively traverse subdirectories.
    """
    paths: Set[str] = set()
    for dir_path in dir_paths:
        all_files = os.listdir(dir_path)
        python_files = {fname for fname in all_files if ".py" == fname[-3:]}
        filter_files = {
            fname for fname in python_files if fname not in files_to_exclude
        }
        paths.update({os.path.join(dir_path, fname) for fname in filter_files})
    return paths

def extract_method_name(line: str) -> str:
    """Extract method name from decorator in the form of "@functional_datapipe({method_name})"."""
    if '("' in line:
        start_token, end_token = '("', '")'
    elif "('" in line:
        start_token, end_token = "('", "')"
    else:
        raise RuntimeError(
            f"Unable to find appropriate method name within line:\n{line}"
        )
    start, end = line.find(start_token) + len(start_token), line.find(end_token)
    return line[start:end]


def extract_class_name(line: str) -> str:
    """Extract class name from class definition in the form of "class {CLASS_NAME}({Type}):"."""
    start_token = "class "
    end_token = "("
    start, end = line.find(start_token) + len(start_token), line.find(end_token)
    return line[start:end]

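# NOTE (illustrative, not part of the original file): both helpers are plain string scanners
# over a single source line, e.g.:
# >>> extract_method_name('@functional_datapipe("map")')
# 'map'
# >>> extract_class_name("class MapperIterDataPipe(IterDataPipe[T_co]):")
# 'MapperIterDataPipe'
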
def parse_datapipe_file(
    file_path: str,
) -> Tuple[Dict[str, str], Dict[str, str], Set[str], Dict[str, List[str]]]:
    """Given a path to file, parses the file and returns a dictionary of method names to function signatures."""
    method_to_signature, method_to_class_name, special_output_type = {}, {}, set()
    doc_string_dict = defaultdict(list)
    with open(file_path) as f:
        open_paren_count = 0
        method_name, class_name, signature = "", "", ""
        skip = False
        for line in f:
            if line.count('"""') % 2 == 1:
                skip = not skip
            if skip or '"""' in line:  # Saving docstrings
                doc_string_dict[method_name].append(line)
                continue
            if "@functional_datapipe" in line:
                method_name = extract_method_name(line)
                doc_string_dict[method_name] = []
                continue
            if method_name and "class " in line:
                class_name = extract_class_name(line)
                continue
            if method_name and ("def __init__(" in line or "def __new__(" in line):
                if "def __new__(" in line:
                    special_output_type.add(method_name)
                open_paren_count += 1
                start = line.find("(") + len("(")
                line = line[start:]
            if open_paren_count > 0:
                open_paren_count += line.count("(")
                open_paren_count -= line.count(")")
                if open_paren_count == 0:
                    end = line.rfind(")")
                    signature += line[:end]
                    method_to_signature[method_name] = process_signature(signature)
                    method_to_class_name[method_name] = class_name
                    method_name, class_name, signature = "", "", ""
                elif open_paren_count < 0:
                    raise RuntimeError(
                        "open parenthesis count < 0. This shouldn't be possible."
                    )
                else:
                    signature += line.strip("\n").strip(" ")
    return (
        method_to_signature,
        method_to_class_name,
        special_output_type,
        doc_string_dict,
    )

def parse_datapipe_files(
    file_paths: Set[str],
) -> Tuple[Dict[str, str], Dict[str, str], Set[str], Dict[str, List[str]]]:
    (
        methods_and_signatures,
        methods_and_class_names,
        methods_with_special_output_types,
    ) = ({}, {}, set())
    methods_and_doc_strings = {}
    for path in file_paths:
        (
            method_to_signature,
            method_to_class_name,
            methods_needing_special_output_types,
            doc_string_dict,
        ) = parse_datapipe_file(path)
        methods_and_signatures.update(method_to_signature)
        methods_and_class_names.update(method_to_class_name)
        methods_with_special_output_types.update(methods_needing_special_output_types)
        methods_and_doc_strings.update(doc_string_dict)
    return (
        methods_and_signatures,
        methods_and_class_names,
        methods_with_special_output_types,
        methods_and_doc_strings,
    )

def split_outside_bracket(line: str, delimiter: str = ",") -> List[str]:
    """Given a line of text, split it on comma unless the comma is within a bracket '[]'."""
    bracket_count = 0
    curr_token = ""
    res = []
    for char in line:
        if char == "[":
            bracket_count += 1
        elif char == "]":
            bracket_count -= 1
        elif char == delimiter and bracket_count == 0:
            res.append(curr_token)
            curr_token = ""
            continue
        curr_token += char
    res.append(curr_token)
    return res

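# NOTE (illustrative, not part of the original file): commas inside '[]' are kept together
# so subscripted typing annotations survive the split, e.g.:
# >>> split_outside_bracket("self, fn: Callable, col: Union[int, List[int]] = None")
# ['self', ' fn: Callable', ' col: Union[int, List[int]] = None']
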
def process_signature(line: str) -> str:
    """
    Clean up a given raw function signature.

    This includes removing the self-referential datapipe argument, default
    arguments of input functions, newlines, and spaces.
    """
    tokens: List[str] = split_outside_bracket(line)
    for i, token in enumerate(tokens):
        tokens[i] = token.strip(" ")
        if token == "cls":
            tokens[i] = "self"
        elif i > 0 and ("self" == tokens[i - 1]) and (tokens[i][0] != "*"):
            # Remove the datapipe after 'self' or 'cls' unless it has '*'
            tokens[i] = ""
        elif "Callable =" in token:  # Remove default argument if it is a function
            head, default_arg = token.rsplit("=", 2)
            tokens[i] = head.strip(" ") + "= ..."
    tokens = [t for t in tokens if t != ""]
    line = ", ".join(tokens)
    return line

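# NOTE (illustrative, not part of the original file): the raw __init__/__new__ signature is
# rewritten for the stub; the datapipe argument after self/cls is dropped and callable
# defaults are replaced with '...' (the argument names below are made up for the example):
# >>> process_signature("cls, datapipe, fn: Callable = _default_fn, length: int = -1")
# 'self, fn: Callable= ..., length: int = -1'
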
def get_method_definitions(
    file_path: Union[str, List[str]],
    files_to_exclude: Set[str],
    deprecated_files: Set[str],
    default_output_type: str,
    method_to_special_output_type: Dict[str, str],
    root: str = "",
) -> List[str]:
    """
    .pyi generation process for functional DataPipes.

    # 1. Find files that we want to process (exclude the ones we don't)
    # 2. Parse method name and signature
    # 3. Remove first argument after self (unless it is "*datapipes"), default args, and spaces
    """
    if root == "":
        root = str(pathlib.Path(__file__).parent.resolve())
    file_path = [file_path] if isinstance(file_path, str) else file_path
    file_path = [os.path.join(root, path) for path in file_path]
    file_paths = find_file_paths(
        file_path, files_to_exclude=files_to_exclude.union(deprecated_files)
    )
    (
        methods_and_signatures,
        methods_and_class_names,
        methods_w_special_output_types,
        methods_and_doc_strings,
    ) = parse_datapipe_files(file_paths)

    for fn_name in method_to_special_output_type:
        if fn_name not in methods_w_special_output_types:
            methods_w_special_output_types.add(fn_name)

    method_definitions = []
    for method_name, arguments in methods_and_signatures.items():
        class_name = methods_and_class_names[method_name]
        if method_name in methods_w_special_output_types:
            output_type = method_to_special_output_type[method_name]
        else:
            output_type = default_output_type
        doc_string = "".join(methods_and_doc_strings[method_name])
        if doc_string == "":
            doc_string = "    ...\n"
        method_definitions.append(
            f"# Functional form of '{class_name}'\n"
            f"def {method_name}({arguments}) -> {output_type}:\n"
            f"{doc_string}"
        )
    method_definitions.sort(
        key=lambda s: s.split("\n")[1]
    )  # sorting based on method_name

    return method_definitions

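# NOTE (illustrative, not part of the original file): each entry appended above is a
# ready-to-paste stub of the form
#     # Functional form of '<ClassName>'
#     def <method_name>(<arguments>) -> <output_type>:
#         <saved docstring, or '...' when none was found>
# which is exactly the shape of the functional interfaces in datapipe.pyi shown earlier.
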
# Defined outside of main() so they can be imported by TorchData
iterDP_file_path: str = "iter"
iterDP_files_to_exclude: Set[str] = {"__init__.py", "utils.py"}
iterDP_deprecated_files: Set[str] = set()
iterDP_method_to_special_output_type: Dict[str, str] = {
    "demux": "List[IterDataPipe]",
    "fork": "List[IterDataPipe]",
}

mapDP_file_path: str = "map"
mapDP_files_to_exclude: Set[str] = {"__init__.py", "utils.py"}
mapDP_deprecated_files: Set[str] = set()
mapDP_method_to_special_output_type: Dict[str, str] = {"shuffle": "IterDataPipe"}

def main() -> None:
    """
    # Inject file into template datapipe.pyi.in.

    TODO: The current implementation of this script only generates interfaces for built-in methods. To generate
    interfaces for user-defined DataPipes, consider changing `IterDataPipe.register_datapipe_as_function`.
    """
    iter_method_definitions = get_method_definitions(
        iterDP_file_path,
        iterDP_files_to_exclude,
        iterDP_deprecated_files,
        "IterDataPipe",
        iterDP_method_to_special_output_type,
    )

    map_method_definitions = get_method_definitions(
        mapDP_file_path,
        mapDP_files_to_exclude,
        mapDP_deprecated_files,
        "MapDataPipe",
        mapDP_method_to_special_output_type,
    )

    path = pathlib.Path(__file__).parent.resolve()
    replacements = [
        ("${IterDataPipeMethods}", iter_method_definitions, 4),
        ("${MapDataPipeMethods}", map_method_definitions, 4),
    ]
    gen_from_template(
        dir=str(path),
        template_name="datapipe.pyi.in",
        output_name="datapipe.pyi",
        replacements=replacements,
    )


if __name__ == "__main__":
    main()
.venv/Lib/site-packages/torch/utils/data/datapipes/iter/__init__.py
ADDED
@@ -0,0 +1,65 @@
from torch.utils.data.datapipes.iter.callable import (
    CollatorIterDataPipe as Collator,
    MapperIterDataPipe as Mapper,
)
from torch.utils.data.datapipes.iter.combinatorics import (
    SamplerIterDataPipe as Sampler,
    ShufflerIterDataPipe as Shuffler,
)
from torch.utils.data.datapipes.iter.combining import (
    ConcaterIterDataPipe as Concater,
    DemultiplexerIterDataPipe as Demultiplexer,
    ForkerIterDataPipe as Forker,
    MultiplexerIterDataPipe as Multiplexer,
    ZipperIterDataPipe as Zipper,
)
from torch.utils.data.datapipes.iter.filelister import (
    FileListerIterDataPipe as FileLister,
)
from torch.utils.data.datapipes.iter.fileopener import (
    FileOpenerIterDataPipe as FileOpener,
)
from torch.utils.data.datapipes.iter.grouping import (
    BatcherIterDataPipe as Batcher,
    GrouperIterDataPipe as Grouper,
    UnBatcherIterDataPipe as UnBatcher,
)
from torch.utils.data.datapipes.iter.routeddecoder import (
    RoutedDecoderIterDataPipe as RoutedDecoder,
)
from torch.utils.data.datapipes.iter.selecting import FilterIterDataPipe as Filter
from torch.utils.data.datapipes.iter.sharding import (
    ShardingFilterIterDataPipe as ShardingFilter,
)
from torch.utils.data.datapipes.iter.streamreader import (
    StreamReaderIterDataPipe as StreamReader,
)
from torch.utils.data.datapipes.iter.utils import (
    IterableWrapperIterDataPipe as IterableWrapper,
)


__all__ = [
    "Batcher",
    "Collator",
    "Concater",
    "Demultiplexer",
    "FileLister",
    "FileOpener",
    "Filter",
    "Forker",
    "Grouper",
    "IterableWrapper",
    "Mapper",
    "Multiplexer",
    "RoutedDecoder",
    "Sampler",
    "ShardingFilter",
    "Shuffler",
    "StreamReader",
    "UnBatcher",
    "Zipper",
]

# Please keep this list sorted
assert __all__ == sorted(__all__)
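A minimal usage sketch tying these exports together (illustrative only; it assumes a local ./data directory containing .txt files and mirrors the FileLister/FileOpener/StreamReader examples in the stub docstrings above):

from torch.utils.data.datapipes.iter import FileLister, FileOpener, StreamReader

# List .txt files, open each one as a binary stream, then read the full payload
dp = FileLister(root="./data", masks="*.txt")
dp = FileOpener(dp, mode="b")
dp = StreamReader(dp)
for path, payload in dp:
    print(path, len(payload))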
.venv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/callable.cpython-39.pyc
ADDED
Binary file (7.84 kB)