Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader.h +57 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/base.h +255 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/stateful.h +63 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/stateless.h +82 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader_options.h +65 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets.h +9 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/base.h +104 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/chunk.h +529 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/map.h +118 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/mnist.h +48 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/shared.h +83 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/stateful.h +70 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/tensor.h +38 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/detail/data_shuttle.h +87 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/detail/queue.h +84 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/detail/sequencers.h +113 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/example.h +55 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/iterator.h +178 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers.h +9 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/base.h +47 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/custom_batch_request.h +21 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/distributed.h +139 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/random.h +54 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/sequential.h +50 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/serialize.h +28 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/stream.h +63 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms.h +7 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/base.h +53 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/collate.h +35 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/lambda.h +56 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/stack.h +49 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/tensor.h +77 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/worker_exception.h +38 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/detail/TensorDataContainer.h +363 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/detail/static.h +65 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/any.h +372 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/any_module_holder.h +133 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/any_value.h +125 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/functional.h +105 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/moduledict.h +262 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/modulelist.h +274 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/named_any.h +94 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/parameterdict.h +148 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/parameterlist.h +169 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/sequential.h +388 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/dropout.h +190 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/fold.h +87 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/instancenorm.h +153 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/loss.h +805 -0
- .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/normalization.h +198 -0
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader.h
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/data/dataloader/stateful.h>
|
| 4 |
+
#include <torch/data/dataloader/stateless.h>
|
| 5 |
+
|
| 6 |
+
#include <torch/csrc/utils/variadic.h>
|
| 7 |
+
|
| 8 |
+
#include <c10/util/Exception.h>
|
| 9 |
+
|
| 10 |
+
#include <cstddef>
|
| 11 |
+
#include <memory>
|
| 12 |
+
#include <type_traits>
|
| 13 |
+
#include <utility>
|
| 14 |
+
|
| 15 |
+
namespace torch {
|
| 16 |
+
namespace data {
|
| 17 |
+
|
| 18 |
+
/// Creates a `DataLoader` instance for a stateless `dataset`, a `sampler` and
|
| 19 |
+
/// some `options`.
|
| 20 |
+
template <typename Dataset, typename Sampler>
|
| 21 |
+
std::enable_if_t<
|
| 22 |
+
!Dataset::is_stateful,
|
| 23 |
+
std::unique_ptr<StatelessDataLoader<Dataset, Sampler>>>
|
| 24 |
+
make_data_loader(Dataset dataset, Sampler sampler, DataLoaderOptions options) {
|
| 25 |
+
return std::make_unique<StatelessDataLoader<Dataset, Sampler>>(
|
| 26 |
+
std::move(dataset), std::move(sampler), std::move(options));
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
/// Creates a `DataLoader` instance for a stateless `dataset` and some
|
| 30 |
+
/// `options`. A sampler (by default a `RandomSampler`) will be constructed from
|
| 31 |
+
/// the size of the dataset.
|
| 32 |
+
template <typename Sampler = samplers::RandomSampler, typename Dataset>
|
| 33 |
+
std::enable_if_t<
|
| 34 |
+
!Dataset::is_stateful && std::is_constructible_v<Sampler, size_t>,
|
| 35 |
+
std::unique_ptr<StatelessDataLoader<Dataset, Sampler>>>
|
| 36 |
+
make_data_loader(
|
| 37 |
+
Dataset dataset,
|
| 38 |
+
DataLoaderOptions options = DataLoaderOptions()) {
|
| 39 |
+
const std::optional<size_t> size = dataset.size();
|
| 40 |
+
TORCH_CHECK(
|
| 41 |
+
size.has_value(),
|
| 42 |
+
"Expected the dataset to be sized in "
|
| 43 |
+
"order to construct the Sampler");
|
| 44 |
+
return make_data_loader(
|
| 45 |
+
std::move(dataset), Sampler(*size), std::move(options));
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
/// Creates a `DataLoader` for a stateful `dataset` and some `options`.
|
| 49 |
+
template <typename Dataset, typename = std::enable_if_t<Dataset::is_stateful>>
|
| 50 |
+
std::unique_ptr<StatefulDataLoader<Dataset>> make_data_loader(
|
| 51 |
+
Dataset dataset,
|
| 52 |
+
DataLoaderOptions options = DataLoaderOptions()) {
|
| 53 |
+
return std::make_unique<StatefulDataLoader<Dataset>>(
|
| 54 |
+
std::move(dataset), std::move(options));
|
| 55 |
+
}
|
| 56 |
+
} // namespace data
|
| 57 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/base.h
ADDED
|
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/data/dataloader_options.h>
|
| 4 |
+
#include <torch/data/detail/data_shuttle.h>
|
| 5 |
+
#include <torch/data/detail/sequencers.h>
|
| 6 |
+
#include <torch/data/iterator.h>
|
| 7 |
+
#include <torch/data/samplers/random.h>
|
| 8 |
+
#include <torch/data/worker_exception.h>
|
| 9 |
+
#include <torch/types.h>
|
| 10 |
+
|
| 11 |
+
#include <torch/csrc/utils/variadic.h>
|
| 12 |
+
|
| 13 |
+
#include <c10/util/Exception.h>
|
| 14 |
+
#include <c10/util/irange.h>
|
| 15 |
+
|
| 16 |
+
#include <cstddef>
|
| 17 |
+
#include <exception>
|
| 18 |
+
#include <memory>
|
| 19 |
+
#include <thread>
|
| 20 |
+
#include <type_traits>
|
| 21 |
+
#include <utility>
|
| 22 |
+
#include <vector>
|
| 23 |
+
|
| 24 |
+
namespace torch {
|
| 25 |
+
namespace data {
|
| 26 |
+
template <typename Dataset, typename Batch, typename BatchRequest>
|
| 27 |
+
class DataLoaderBase {
|
| 28 |
+
public:
|
| 29 |
+
using BatchType = Batch;
|
| 30 |
+
using BatchRequestType = BatchRequest;
|
| 31 |
+
|
| 32 |
+
/// Constructs a new DataLoader from a `dataset` to sample from, `options`
|
| 33 |
+
/// to configure the DataLoader with, and a `sampler` that specifies the
|
| 34 |
+
/// sampling strategy.
|
| 35 |
+
DataLoaderBase(
|
| 36 |
+
DataLoaderOptions options,
|
| 37 |
+
std::unique_ptr<Dataset> main_thread_dataset = nullptr)
|
| 38 |
+
: options_(std::move(options)),
|
| 39 |
+
main_thread_dataset_(std::move(main_thread_dataset)),
|
| 40 |
+
sequencer_(new_sequencer()) {}
|
| 41 |
+
|
| 42 |
+
// NOLINTNEXTLINE(bugprone-exception-escape)
|
| 43 |
+
virtual ~DataLoaderBase() {
|
| 44 |
+
join();
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
/// Returns an iterator into the DataLoader. The lifetime of the iterator is
|
| 48 |
+
/// bound to the DataLoader. In C++ standards language, the category of the
|
| 49 |
+
/// iterator is `OutputIterator`. See
|
| 50 |
+
/// https://en.cppreference.com/w/cpp/named_req/OutputIterator for what this
|
| 51 |
+
/// means. In short: you may increment the iterator and dereference it, but
|
| 52 |
+
/// cannot go back, or step forward more than one position at a time. When the
|
| 53 |
+
/// DataLoader is exhausted, it will compare equal with the special
|
| 54 |
+
/// "sentinel" iterator returned by `DataLoader::end()`. Most of the time, you
|
| 55 |
+
/// should only use range-for loops to loop over the DataLoader, but
|
| 56 |
+
/// standard algorithms like `std::copy(dataloader.begin(), dataloader.end(),
|
| 57 |
+
/// output_iterator)` are supported too.
|
| 58 |
+
Iterator<Batch> begin() {
|
| 59 |
+
TORCH_CHECK(
|
| 60 |
+
shuttle_.in_flight_jobs() == 0,
|
| 61 |
+
"Attempted to get a new DataLoader iterator "
|
| 62 |
+
"while another iterator is not yet exhausted");
|
| 63 |
+
reset();
|
| 64 |
+
return Iterator<Batch>(std::make_unique<detail::ValidIterator<Batch>>(
|
| 65 |
+
[this] { return this->next(); }));
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
/// Returns a special "sentinel" iterator that compares equal with a
|
| 69 |
+
/// non-sentinel iterator once the DataLoader is exhausted.
|
| 70 |
+
Iterator<Batch> end() {
|
| 71 |
+
return Iterator<Batch>(std::make_unique<detail::SentinelIterator<Batch>>());
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
/// Joins the DataLoader's worker threads and drains internal queues.
|
| 75 |
+
/// This function may only be invoked from the main thread (in which the
|
| 76 |
+
/// DataLoader lives).
|
| 77 |
+
void join() {
|
| 78 |
+
if (joined_) {
|
| 79 |
+
return;
|
| 80 |
+
}
|
| 81 |
+
shuttle_.drain();
|
| 82 |
+
// Send one 'quit' message per worker. Since a worker dies (exits its
|
| 83 |
+
// thread) after receiving this message, each `QuitWorker()` message will be
|
| 84 |
+
// read by exactly one worker.
|
| 85 |
+
for (const auto w : c10::irange(options_.workers)) {
|
| 86 |
+
(void)w; // Suppress unused variable warning
|
| 87 |
+
push_job(QuitWorker());
|
| 88 |
+
}
|
| 89 |
+
for (auto& worker : workers_) {
|
| 90 |
+
worker.join();
|
| 91 |
+
}
|
| 92 |
+
joined_ = true;
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
/// Returns the options with which the DataLoader was configured.
|
| 96 |
+
const FullDataLoaderOptions& options() const noexcept {
|
| 97 |
+
return options_;
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
protected:
|
| 101 |
+
/// Simple mix-in to give something a sequence number.
|
| 102 |
+
struct Sequenced {
|
| 103 |
+
Sequenced() = default;
|
| 104 |
+
Sequenced(size_t sqn) : sequence_number(sqn) {}
|
| 105 |
+
size_t sequence_number;
|
| 106 |
+
};
|
| 107 |
+
|
| 108 |
+
struct QuitWorker {};
|
| 109 |
+
|
| 110 |
+
/// A `Job` is either a `BatchRequest` (new indices to fetch data at) or a
|
| 111 |
+
/// `QuitWorker` object, to indicate the worker should shut down.
|
| 112 |
+
struct Job : Sequenced {
|
| 113 |
+
Job() = default;
|
| 114 |
+
Job(QuitWorker q, size_t sqn) : Sequenced(sqn), quit(q) {}
|
| 115 |
+
Job(BatchRequest&& i, size_t sqn)
|
| 116 |
+
: Sequenced(sqn), batch_request(std::move(i)) {}
|
| 117 |
+
std::optional<QuitWorker> quit;
|
| 118 |
+
std::optional<BatchRequest> batch_request;
|
| 119 |
+
};
|
| 120 |
+
|
| 121 |
+
/// The finished result of a job.
|
| 122 |
+
struct Result : Sequenced {
|
| 123 |
+
Result() = default;
|
| 124 |
+
Result(std::optional<Batch>&& b, size_t sqn)
|
| 125 |
+
: Sequenced(sqn), batch(std::move(b)) {}
|
| 126 |
+
Result(std::exception_ptr exception, size_t sqn)
|
| 127 |
+
: Sequenced(sqn), exception(std::move(exception)) {}
|
| 128 |
+
std::optional<Batch> batch;
|
| 129 |
+
std::exception_ptr exception;
|
| 130 |
+
};
|
| 131 |
+
|
| 132 |
+
/// Subclass hook for getting the next batch request. The stateless case will
|
| 133 |
+
/// ask the sampler for a new batch request (e.g. a vector of indices), while
|
| 134 |
+
/// the stateful one will simply return the batch size.
|
| 135 |
+
virtual std::optional<BatchRequestType> get_batch_request() = 0;
|
| 136 |
+
|
| 137 |
+
/// Resets the internal state of the DataLoader, optionally pre-fetching
|
| 138 |
+
/// new jobs.
|
| 139 |
+
virtual void reset() {
|
| 140 |
+
shuttle_.drain();
|
| 141 |
+
sequence_number_ = 0;
|
| 142 |
+
sequencer_ = new_sequencer();
|
| 143 |
+
prefetch();
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
/// Schedules `requested_jobs` many new batches to be fetched. The actual
|
| 147 |
+
/// number of jobs scheduled may be less if the DataLoader exhausts.
|
| 148 |
+
void prefetch(size_t requested_jobs) {
|
| 149 |
+
for (const auto r : c10::irange(requested_jobs)) {
|
| 150 |
+
(void)r; // Suppress unused variable
|
| 151 |
+
if (auto batch_request = get_batch_request()) {
|
| 152 |
+
this->push_job(std::move(*batch_request));
|
| 153 |
+
} else {
|
| 154 |
+
break;
|
| 155 |
+
}
|
| 156 |
+
}
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
/// Schedules the maximum number of jobs (based on the `max_jobs` option).
|
| 160 |
+
void prefetch() {
|
| 161 |
+
prefetch(options_.max_jobs);
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
/// Returns the next batch of data, or an empty `optional` if the DataLoader
|
| 165 |
+
/// is exhausted. This operation will block until a batch is available if one
|
| 166 |
+
/// is still expected.
|
| 167 |
+
std::optional<BatchType> next() {
|
| 168 |
+
if (options_.workers > 0) {
|
| 169 |
+
while (std::optional<Result> result = this->pop_result()) {
|
| 170 |
+
if (result->exception) {
|
| 171 |
+
throw WorkerException(result->exception);
|
| 172 |
+
} else if (result->batch) {
|
| 173 |
+
prefetch(1);
|
| 174 |
+
return std::move(result->batch);
|
| 175 |
+
}
|
| 176 |
+
}
|
| 177 |
+
} else if (auto batch_request = get_batch_request()) {
|
| 178 |
+
return this->main_thread_dataset_->get_batch(std::move(*batch_request));
|
| 179 |
+
}
|
| 180 |
+
return nullopt;
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
/// The function that worker threads run.
|
| 184 |
+
void worker_thread(Dataset& dataset) {
|
| 185 |
+
while (true) {
|
| 186 |
+
auto job = shuttle_.pop_job();
|
| 187 |
+
if (job.quit) {
|
| 188 |
+
break;
|
| 189 |
+
}
|
| 190 |
+
try {
|
| 191 |
+
auto batch = dataset.get_batch(std::move(*job.batch_request));
|
| 192 |
+
shuttle_.push_result({std::move(batch), job.sequence_number});
|
| 193 |
+
} catch (...) {
|
| 194 |
+
shuttle_.push_result({std::current_exception(), job.sequence_number});
|
| 195 |
+
}
|
| 196 |
+
}
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
/// Convenience method that calls `shuttle_.push_job()` with the next sequence
|
| 200 |
+
/// number.
|
| 201 |
+
template <typename T>
|
| 202 |
+
void push_job(T value) {
|
| 203 |
+
shuttle_.push_job({std::move(value), sequence_number_++});
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
/// Convenience method that gets the next result from the sequencer.
|
| 207 |
+
std::optional<Result> pop_result() {
|
| 208 |
+
return sequencer_->next(
|
| 209 |
+
[this] { return this->shuttle_.pop_result(this->options_.timeout); });
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
/// Convenience method that creates a new sequencer based on the
|
| 213 |
+
/// `enforce_ordering` option.
|
| 214 |
+
std::unique_ptr<detail::sequencers::Sequencer<Result>> new_sequencer() {
|
| 215 |
+
if (options_.enforce_ordering) {
|
| 216 |
+
return std::make_unique<detail::sequencers::OrderedSequencer<Result>>(
|
| 217 |
+
options_.max_jobs);
|
| 218 |
+
}
|
| 219 |
+
return std::make_unique<detail::sequencers::NoSequencer<Result>>();
|
| 220 |
+
}
|
| 221 |
+
|
| 222 |
+
/// The options the DataLoader was configured with.
|
| 223 |
+
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
| 224 |
+
const FullDataLoaderOptions options_;
|
| 225 |
+
|
| 226 |
+
/// The dataset for the main thread, only has a value if the number of
|
| 227 |
+
/// worker threads was configured as zero, meaning the main thread has to do
|
| 228 |
+
/// all the work (synchronously). NOTE: Really want this to be on the heap
|
| 229 |
+
/// when empty, therefore `unique_ptr` and not `optional`.
|
| 230 |
+
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
| 231 |
+
std::unique_ptr<Dataset> main_thread_dataset_;
|
| 232 |
+
|
| 233 |
+
/// The sequence number for the *next* batch to be retrieved from the
|
| 234 |
+
/// dataset.
|
| 235 |
+
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
| 236 |
+
size_t sequence_number_ = 0;
|
| 237 |
+
|
| 238 |
+
/// The worker threads, running the `worker_thread()` method.
|
| 239 |
+
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
| 240 |
+
std::vector<std::thread> workers_;
|
| 241 |
+
|
| 242 |
+
/// The `DataShuttle` which takes care of the life cycle of a job.
|
| 243 |
+
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
| 244 |
+
detail::DataShuttle<Job, Result> shuttle_;
|
| 245 |
+
|
| 246 |
+
/// The `Sequencer`, which handles optional ordering of batches.
|
| 247 |
+
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
| 248 |
+
std::unique_ptr<detail::sequencers::Sequencer<Result>> sequencer_;
|
| 249 |
+
|
| 250 |
+
/// True if the DataLoader has joined its worker threads.
|
| 251 |
+
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
| 252 |
+
bool joined_ = false;
|
| 253 |
+
};
|
| 254 |
+
} // namespace data
|
| 255 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/stateful.h
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/util/irange.h>
|
| 4 |
+
#include <torch/data/dataloader/base.h>
|
| 5 |
+
|
| 6 |
+
#include <cstddef>
|
| 7 |
+
#include <thread>
|
| 8 |
+
#include <utility>
|
| 9 |
+
|
| 10 |
+
namespace torch {
|
| 11 |
+
namespace data {
|
| 12 |
+
|
| 13 |
+
/// A dataloader for stateful datasets.
|
| 14 |
+
///
|
| 15 |
+
/// A dataloader for stateful datatasets differs from one for stateless
|
| 16 |
+
/// datasets one in that the dataset is shared among worker threads, and that
|
| 17 |
+
/// this dataset is itself responsible for producing batches rather than
|
| 18 |
+
/// depending on a sampler. The statefulness here actually refers to the
|
| 19 |
+
/// dataset. The StatefulDataLoader simply alters the data loading algorithm to
|
| 20 |
+
/// accommodate the stateful, shared nature of the dataset. Note that the
|
| 21 |
+
/// dataset must be thread safe if more than one worker thread is used.
|
| 22 |
+
///
|
| 23 |
+
/// A stateful dataloader is created by calling `make_data_loader` with a
|
| 24 |
+
/// stateful dataset.
|
| 25 |
+
template <typename Dataset>
|
| 26 |
+
class StatefulDataLoader : public DataLoaderBase<
|
| 27 |
+
Dataset,
|
| 28 |
+
typename Dataset::BatchType::value_type,
|
| 29 |
+
typename Dataset::BatchRequestType> {
|
| 30 |
+
public:
|
| 31 |
+
using super = DataLoaderBase<
|
| 32 |
+
Dataset,
|
| 33 |
+
typename Dataset::BatchType::value_type,
|
| 34 |
+
typename Dataset::BatchRequestType>;
|
| 35 |
+
using typename super::BatchRequestType;
|
| 36 |
+
|
| 37 |
+
/// Constructs the `StatefulDataLoader` from a `dataset` and some `options`.
|
| 38 |
+
StatefulDataLoader(Dataset dataset, DataLoaderOptions options)
|
| 39 |
+
: super(options, std::make_unique<Dataset>(std::move(dataset))) {
|
| 40 |
+
for ([[maybe_unused]] const auto _ : c10::irange(this->options_.workers)) {
|
| 41 |
+
// As opposed to the stateless case, here all worker threads access the
|
| 42 |
+
// same underlying dataset.
|
| 43 |
+
this->workers_.emplace_back(
|
| 44 |
+
[this] { this->worker_thread(*this->main_thread_dataset_); });
|
| 45 |
+
}
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
private:
|
| 49 |
+
/// Resets the internal state of the dataloader and the dataset.
|
| 50 |
+
void reset() override {
|
| 51 |
+
this->main_thread_dataset_->reset();
|
| 52 |
+
// Call the base class method last because it calls `prefetch()`
|
| 53 |
+
super::reset();
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
/// For stateful datasets, the batch request is always the batch size. The
|
| 57 |
+
/// dataset is responsible for determining what goes into the batch next.
|
| 58 |
+
std::optional<BatchRequestType> get_batch_request() override {
|
| 59 |
+
return this->options_.batch_size;
|
| 60 |
+
}
|
| 61 |
+
};
|
| 62 |
+
} // namespace data
|
| 63 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/stateless.h
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/data/dataloader/base.h>
|
| 4 |
+
#include <torch/data/worker_exception.h>
|
| 5 |
+
|
| 6 |
+
#include <c10/util/Exception.h>
|
| 7 |
+
#include <c10/util/irange.h>
|
| 8 |
+
|
| 9 |
+
#include <cstddef>
|
| 10 |
+
#include <thread>
|
| 11 |
+
#include <utility>
|
| 12 |
+
|
| 13 |
+
namespace torch {
|
| 14 |
+
namespace data {
|
| 15 |
+
|
| 16 |
+
/// A dataloader for stateless datasets.
|
| 17 |
+
///
|
| 18 |
+
/// This dataloader follows the traditional PyTorch dataloader design, whereby a
|
| 19 |
+
/// (posssibly) stateful sampler produces *batch requests* for a stateless
|
| 20 |
+
/// dataset, which acts as a simple batch request to batch mapping. The batch
|
| 21 |
+
/// request will often be an array of indices, and if the dataset is a simple
|
| 22 |
+
/// image dataset, the dataset would produce the images at those indices.
|
| 23 |
+
template <typename Dataset, typename Sampler>
|
| 24 |
+
class StatelessDataLoader : public DataLoaderBase<
|
| 25 |
+
Dataset,
|
| 26 |
+
typename Dataset::BatchType,
|
| 27 |
+
typename Sampler::BatchRequestType> {
|
| 28 |
+
public:
|
| 29 |
+
using super = DataLoaderBase<
|
| 30 |
+
Dataset,
|
| 31 |
+
typename Dataset::BatchType,
|
| 32 |
+
typename Sampler::BatchRequestType>;
|
| 33 |
+
using typename super::BatchRequestType;
|
| 34 |
+
|
| 35 |
+
/// Constructs the `StatelessDataLoader` from a `dataset`, a `sampler` and
|
| 36 |
+
/// some `options`.
|
| 37 |
+
StatelessDataLoader(
|
| 38 |
+
Dataset dataset,
|
| 39 |
+
Sampler sampler,
|
| 40 |
+
DataLoaderOptions options)
|
| 41 |
+
: super(std::move(options)), sampler_(std::move(sampler)) {
|
| 42 |
+
for (const auto w : c10::irange(this->options_.workers)) {
|
| 43 |
+
// Here we copy the dataset into the worker thread closure. Each worker
|
| 44 |
+
// has its own copy of the dataset. This means the dataset must be
|
| 45 |
+
// trivially copiable, or else we don't expect more than one worker to
|
| 46 |
+
// be in use.
|
| 47 |
+
(void)w; // Suppress unused variable warning
|
| 48 |
+
this->workers_.emplace_back(
|
| 49 |
+
[this, dataset]() mutable { this->worker_thread(dataset); });
|
| 50 |
+
}
|
| 51 |
+
if (this->options_.workers == 0) {
|
| 52 |
+
this->main_thread_dataset_ =
|
| 53 |
+
std::make_unique<Dataset>(std::move(dataset));
|
| 54 |
+
}
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
private:
|
| 58 |
+
/// Resets the internal state of the dataloader and the sampler.
|
| 59 |
+
void reset() override {
|
| 60 |
+
sampler_.reset();
|
| 61 |
+
// Call the base class method last because it calls `prefetch()`
|
| 62 |
+
super::reset();
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
/// Queries the sampler for the next batch request (possibly progressing its
|
| 66 |
+
/// internal state).
|
| 67 |
+
std::optional<BatchRequestType> get_batch_request() override {
|
| 68 |
+
auto indices = sampler_.next(this->options_.batch_size);
|
| 69 |
+
if (!indices ||
|
| 70 |
+
(indices->size() < this->options_.batch_size &&
|
| 71 |
+
this->options_.drop_last)) {
|
| 72 |
+
return nullopt;
|
| 73 |
+
}
|
| 74 |
+
AT_ASSERT(indices->size() > 0);
|
| 75 |
+
return indices;
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
/// The `Sampler` used to produce batch requests.
|
| 79 |
+
Sampler sampler_;
|
| 80 |
+
};
|
| 81 |
+
} // namespace data
|
| 82 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader_options.h
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/arg.h>
|
| 4 |
+
#include <torch/types.h>
|
| 5 |
+
|
| 6 |
+
#include <chrono>
|
| 7 |
+
#include <cstddef>
|
| 8 |
+
|
| 9 |
+
namespace torch {
|
| 10 |
+
namespace data {
|
| 11 |
+
|
| 12 |
+
/// Options to configure a `DataLoader`.
|
| 13 |
+
struct DataLoaderOptions {
|
| 14 |
+
DataLoaderOptions() = default;
|
| 15 |
+
/* implicit */ DataLoaderOptions(size_t batch_size)
|
| 16 |
+
: batch_size_(batch_size) {}
|
| 17 |
+
|
| 18 |
+
/// The size of each batch to fetch.
|
| 19 |
+
TORCH_ARG(size_t, batch_size) = 1;
|
| 20 |
+
|
| 21 |
+
/// The number of worker threads to launch. If zero, the main thread will
|
| 22 |
+
/// synchronously perform the data loading.
|
| 23 |
+
TORCH_ARG(size_t, workers) = 0;
|
| 24 |
+
|
| 25 |
+
/// The maximum number of jobs to enqueue for fetching by worker threads.
|
| 26 |
+
/// Defaults to two times the number of worker threads.
|
| 27 |
+
TORCH_ARG(std::optional<size_t>, max_jobs);
|
| 28 |
+
|
| 29 |
+
/// An optional limit on the time to wait for the next batch.
|
| 30 |
+
TORCH_ARG(std::optional<std::chrono::milliseconds>, timeout);
|
| 31 |
+
|
| 32 |
+
/// Whether to enforce ordering of batches when multiple are loaded
|
| 33 |
+
/// asynchronously by worker threads. Set to `false` for better performance if
|
| 34 |
+
/// you do not care about determinism.
|
| 35 |
+
TORCH_ARG(bool, enforce_ordering) = true;
|
| 36 |
+
|
| 37 |
+
/// Whether to omit the last batch if it contains less than `batch_size`
|
| 38 |
+
/// examples.
|
| 39 |
+
TORCH_ARG(bool, drop_last) = false;
|
| 40 |
+
};
|
| 41 |
+
|
| 42 |
+
/// Like `DataLoaderOptions`, but without any unconfigured state.
|
| 43 |
+
/// `DataLoaderOptions` has some options that depend on other options
|
| 44 |
+
/// (`max_jobs` => `2 * workers`). In the spirit of properly using the C++ type
|
| 45 |
+
/// system, `DataLoaderOptions` allows only setting values. To access values,
|
| 46 |
+
/// you must create a `FullDataLoaderOptions` from a `DataLoaderOptions`
|
| 47 |
+
/// instance, which will do any necessary coalescing.
|
| 48 |
+
struct FullDataLoaderOptions {
|
| 49 |
+
explicit FullDataLoaderOptions(DataLoaderOptions options)
|
| 50 |
+
: batch_size(options.batch_size()),
|
| 51 |
+
workers(options.workers()),
|
| 52 |
+
max_jobs(options.max_jobs().value_or(2 * workers)),
|
| 53 |
+
timeout(options.timeout()),
|
| 54 |
+
enforce_ordering(options.enforce_ordering()),
|
| 55 |
+
drop_last(options.drop_last()) {}
|
| 56 |
+
|
| 57 |
+
size_t batch_size;
|
| 58 |
+
size_t workers;
|
| 59 |
+
size_t max_jobs;
|
| 60 |
+
std::optional<std::chrono::milliseconds> timeout;
|
| 61 |
+
bool enforce_ordering;
|
| 62 |
+
bool drop_last;
|
| 63 |
+
};
|
| 64 |
+
} // namespace data
|
| 65 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets.h
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/data/datasets/base.h>
|
| 4 |
+
#include <torch/data/datasets/chunk.h>
|
| 5 |
+
#include <torch/data/datasets/map.h>
|
| 6 |
+
#include <torch/data/datasets/mnist.h>
|
| 7 |
+
#include <torch/data/datasets/shared.h>
|
| 8 |
+
#include <torch/data/datasets/stateful.h>
|
| 9 |
+
#include <torch/data/datasets/tensor.h>
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/base.h
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/data/example.h>
|
| 4 |
+
#include <torch/types.h>
|
| 5 |
+
|
| 6 |
+
#include <c10/util/ArrayRef.h>
|
| 7 |
+
|
| 8 |
+
#include <cstddef>
|
| 9 |
+
#include <cstdint>
|
| 10 |
+
#include <type_traits>
|
| 11 |
+
#include <utility>
|
| 12 |
+
#include <vector>
|
| 13 |
+
|
| 14 |
+
namespace torch {
|
| 15 |
+
namespace data {
|
| 16 |
+
namespace datasets {
|
| 17 |
+
template <typename S, typename T>
|
| 18 |
+
class MapDataset;
|
| 19 |
+
template <typename D, typename T>
|
| 20 |
+
MapDataset<D, T> map(D, T); // NOLINT
|
| 21 |
+
} // namespace datasets
|
| 22 |
+
} // namespace data
|
| 23 |
+
} // namespace torch
|
| 24 |
+
|
| 25 |
+
namespace torch {
|
| 26 |
+
namespace data {
|
| 27 |
+
namespace datasets {
|
| 28 |
+
namespace detail {
|
| 29 |
+
template <typename T>
|
| 30 |
+
struct is_optional : std::false_type {};
|
| 31 |
+
template <typename T>
|
| 32 |
+
struct is_optional<std::optional<T>> : std::true_type {};
|
| 33 |
+
} // namespace detail
|
| 34 |
+
|
| 35 |
+
/// A dataset that can yield data only in batches.
|
| 36 |
+
template <
|
| 37 |
+
typename Self,
|
| 38 |
+
typename Batch = std::vector<Example<>>,
|
| 39 |
+
typename BatchRequest = ArrayRef<size_t>>
|
| 40 |
+
class BatchDataset {
|
| 41 |
+
public:
|
| 42 |
+
using SelfType = Self;
|
| 43 |
+
using BatchType = Batch;
|
| 44 |
+
using BatchRequestType = BatchRequest;
|
| 45 |
+
constexpr static bool is_stateful = detail::is_optional<BatchType>::value;
|
| 46 |
+
|
| 47 |
+
virtual ~BatchDataset() = default;
|
| 48 |
+
|
| 49 |
+
/// Returns a batch of data given an index.
|
| 50 |
+
virtual Batch get_batch(BatchRequest request) = 0;
|
| 51 |
+
|
| 52 |
+
/// Returns the size of the dataset, or an empty std::optional if it is
|
| 53 |
+
/// unsized.
|
| 54 |
+
virtual std::optional<size_t> size() const = 0;
|
| 55 |
+
|
| 56 |
+
/// Creates a `MapDataset` that applies the given `transform` to this dataset.
|
| 57 |
+
template <typename TransformType>
|
| 58 |
+
MapDataset<Self, TransformType> map(TransformType transform) & {
|
| 59 |
+
return datasets::map(static_cast<Self&>(*this), std::move(transform));
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
/// Creates a `MapDataset` that applies the given `transform` to this dataset.
|
| 63 |
+
template <typename TransformType>
|
| 64 |
+
MapDataset<Self, TransformType> map(TransformType transform) && {
|
| 65 |
+
return datasets::map(
|
| 66 |
+
std::move(static_cast<Self&>(*this)), std::move(transform));
|
| 67 |
+
}
|
| 68 |
+
};
|
| 69 |
+
|
| 70 |
+
/// A dataset that can yield data in batches, or as individual examples.
|
| 71 |
+
///
|
| 72 |
+
/// A `Dataset` is a `BatchDataset`, because it supports random access and
|
| 73 |
+
/// therefore batched access is implemented (by default) by calling the random
|
| 74 |
+
/// access indexing function for each index in the requested batch of indices.
|
| 75 |
+
/// This can be customized.
|
| 76 |
+
template <typename Self, typename SingleExample = Example<>>
|
| 77 |
+
class Dataset : public BatchDataset<Self, std::vector<SingleExample>> {
|
| 78 |
+
public:
|
| 79 |
+
using ExampleType = SingleExample;
|
| 80 |
+
|
| 81 |
+
/// Returns the example at the given index.
|
| 82 |
+
virtual ExampleType get(size_t index) = 0;
|
| 83 |
+
|
| 84 |
+
/// Returns a batch of data.
|
| 85 |
+
/// The default implementation calls `get()` for every requested index
|
| 86 |
+
/// in the batch.
|
| 87 |
+
std::vector<ExampleType> get_batch(ArrayRef<size_t> indices) override {
|
| 88 |
+
std::vector<ExampleType> batch;
|
| 89 |
+
batch.reserve(indices.size());
|
| 90 |
+
for (const auto i : indices) {
|
| 91 |
+
batch.push_back(get(i));
|
| 92 |
+
}
|
| 93 |
+
return batch;
|
| 94 |
+
}
|
| 95 |
+
};
|
| 96 |
+
|
| 97 |
+
/// A `StreamDataset` represents a dataset that is a potentially infinite
|
| 98 |
+
/// stream. It takes as batch index only a number, which is the batch size, and
|
| 99 |
+
/// yields that many elements from the stream.
|
| 100 |
+
template <typename Self, typename Batch = std::vector<Example<>>>
|
| 101 |
+
using StreamDataset = BatchDataset<Self, Batch, /*BatchRequest=*/size_t>;
|
| 102 |
+
} // namespace datasets
|
| 103 |
+
} // namespace data
|
| 104 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/chunk.h
ADDED
|
@@ -0,0 +1,529 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/util/irange.h>
|
| 4 |
+
#include <torch/arg.h>
|
| 5 |
+
#include <torch/data/datasets/stateful.h>
|
| 6 |
+
#include <torch/data/samplers.h>
|
| 7 |
+
#include <queue>
|
| 8 |
+
#include <thread>
|
| 9 |
+
|
| 10 |
+
#include <torch/serialize.h>
|
| 11 |
+
|
| 12 |
+
namespace torch {
|
| 13 |
+
namespace data {
|
| 14 |
+
namespace datasets {
|
| 15 |
+
|
| 16 |
+
/// Interface for chunk reader, which performs data chunking and reading of
|
| 17 |
+
/// entire chunks.
|
| 18 |
+
///
|
| 19 |
+
/// A chunk could be an entire file, such as an audio data file or an image,
|
| 20 |
+
/// or part of a file in the case of a large text-file split based on seek
|
| 21 |
+
/// positions.
|
| 22 |
+
template <
|
| 23 |
+
typename ExampleType_,
|
| 24 |
+
typename ChunkType_ = std::vector<ExampleType_>>
|
| 25 |
+
class ChunkDataReader {
|
| 26 |
+
public:
|
| 27 |
+
virtual ~ChunkDataReader() = default;
|
| 28 |
+
|
| 29 |
+
using ChunkType = ChunkType_;
|
| 30 |
+
using ExampleType = ExampleType_;
|
| 31 |
+
|
| 32 |
+
/// Read an entire chunk.
|
| 33 |
+
virtual ChunkType read_chunk(size_t chunk_index) = 0;
|
| 34 |
+
|
| 35 |
+
/// Returns the number of chunks available in this reader.
|
| 36 |
+
virtual size_t chunk_count() = 0;
|
| 37 |
+
|
| 38 |
+
/// This will clear any internal state associate with this reader.
|
| 39 |
+
virtual void reset() = 0;
|
| 40 |
+
};
|
| 41 |
+
|
| 42 |
+
namespace detail {
|
| 43 |
+
/// BatchDataBuffer manages a queue of UnwrappedBatchData. After a new chunk is
|
| 44 |
+
/// loaded, BatchDataBuffer splits it into small batches and push them into the
|
| 45 |
+
/// queue. When get_batch is called from data loader, it pops cached batches and
|
| 46 |
+
/// return. If the cache is empty, it either waits to load more chunks or return
|
| 47 |
+
/// null if all chunks are loaded.
|
| 48 |
+
template <
|
| 49 |
+
typename UnwrappedBatch,
|
| 50 |
+
typename ExampleSampler = samplers::RandomSampler>
|
| 51 |
+
class BatchDataBuffer {
|
| 52 |
+
public:
|
| 53 |
+
using UnwrappedBatchType = UnwrappedBatch;
|
| 54 |
+
using BatchType = torch::optional<UnwrappedBatchType>;
|
| 55 |
+
using BatchRequestType = typename ExampleSampler::BatchRequestType;
|
| 56 |
+
|
| 57 |
+
BatchDataBuffer(
|
| 58 |
+
size_t batch_size,
|
| 59 |
+
ExampleSampler& example_sampler,
|
| 60 |
+
size_t queue_capacity)
|
| 61 |
+
: batch_size_(batch_size),
|
| 62 |
+
example_sampler_(example_sampler),
|
| 63 |
+
queue_capacity_(queue_capacity) {}
|
| 64 |
+
|
| 65 |
+
/// Return batch data from the queue. Called from the ChunkDataset main
|
| 66 |
+
/// thread.
|
| 67 |
+
BatchType get_batch() {
|
| 68 |
+
std::unique_lock<std::mutex> lock(queue_mutex_);
|
| 69 |
+
cv_read_.wait(lock, [this] {
|
| 70 |
+
// wait till there is available data in the queue or if all chunks are
|
| 71 |
+
// loaded (i.e. the dataset is exhausted for this epoch)
|
| 72 |
+
return (
|
| 73 |
+
this->total_example_count_in_queue_ >= batch_size_ || this->stop_);
|
| 74 |
+
});
|
| 75 |
+
if (batch_queue_.empty()) {
|
| 76 |
+
AT_ASSERT(stop_);
|
| 77 |
+
// All batches have been retrieved. Return an empty batch.
|
| 78 |
+
return nullopt;
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
UnwrappedBatchData batch = std::move(batch_queue_.front());
|
| 82 |
+
batch_queue_.pop();
|
| 83 |
+
if (batch.exception) {
|
| 84 |
+
throw WorkerException(batch.exception);
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
total_example_count_in_queue_ -= batch.batch_data.size();
|
| 88 |
+
lock.unlock();
|
| 89 |
+
cv_write_.notify_all();
|
| 90 |
+
|
| 91 |
+
return batch.batch_data;
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
/// Push preloaded chunks to batch queue. Called from the ChunkDataset worker
|
| 95 |
+
/// threads.
|
| 96 |
+
void add_chunk_data(UnwrappedBatchType data) {
|
| 97 |
+
std::unique_lock<std::mutex> lock(queue_mutex_);
|
| 98 |
+
cv_write_.wait(lock, [this] {
|
| 99 |
+
// stop loading if we have preloaded enough data.
|
| 100 |
+
return this->total_example_count_in_queue_ < this->queue_capacity_ ||
|
| 101 |
+
this->stop_;
|
| 102 |
+
});
|
| 103 |
+
if (stop_) {
|
| 104 |
+
// When stop_ is true, it means no further chunk loading is necessary.
|
| 105 |
+
// Return without any further processing.
|
| 106 |
+
return;
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
auto data_size = data.size();
|
| 110 |
+
auto remaining_size = data_size;
|
| 111 |
+
example_sampler_.reset(data_size);
|
| 112 |
+
|
| 113 |
+
auto fill_batch = [&](size_t example_count, UnwrappedBatchType& batch) {
|
| 114 |
+
auto batch_example_indices = this->example_sampler_.next(example_count);
|
| 115 |
+
AT_ASSERT(
|
| 116 |
+
batch_example_indices &&
|
| 117 |
+
batch_example_indices.value().size() == example_count);
|
| 118 |
+
BatchRequestType& indices = batch_example_indices.value();
|
| 119 |
+
for (size_t i : indices) {
|
| 120 |
+
TORCH_CHECK(i < data_size, "Index out of range");
|
| 121 |
+
batch.emplace_back(std::move(data[i]));
|
| 122 |
+
}
|
| 123 |
+
remaining_size -= example_count;
|
| 124 |
+
};
|
| 125 |
+
|
| 126 |
+
if (!batch_queue_.empty()) {
|
| 127 |
+
// if the queue has existing data, and the last batch doesn't have enough
|
| 128 |
+
// examples to fill a batch_size batch, add more example to this batch
|
| 129 |
+
// first.
|
| 130 |
+
auto& batch = batch_queue_.back();
|
| 131 |
+
size_t current_count = batch.batch_data.size();
|
| 132 |
+
if (current_count < batch_size_) {
|
| 133 |
+
auto example_count =
|
| 134 |
+
std::min(remaining_size, batch_size_ - current_count);
|
| 135 |
+
fill_batch(example_count, batch.batch_data);
|
| 136 |
+
}
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
// If we still have data remaining after filling the last pushed batch, add
|
| 140 |
+
// them to the queue too.
|
| 141 |
+
// NOLINTNEXTLINE(bugprone-infinite-loop)
|
| 142 |
+
while (remaining_size > 0) {
|
| 143 |
+
UnwrappedBatchType current_batch;
|
| 144 |
+
|
| 145 |
+
// Allocate the batch memory ahead of time.
|
| 146 |
+
current_batch.reserve(batch_size_);
|
| 147 |
+
|
| 148 |
+
auto example_count = std::min(remaining_size, batch_size_);
|
| 149 |
+
fill_batch(example_count, current_batch);
|
| 150 |
+
batch_queue_.emplace(std::move(current_batch));
|
| 151 |
+
}
|
| 152 |
+
total_example_count_in_queue_ += data_size;
|
| 153 |
+
lock.unlock();
|
| 154 |
+
cv_read_.notify_all();
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
/// Push exceptions thrown during preloading into batch queue. Called from
|
| 158 |
+
/// the ChunkDataset worker threads.
|
| 159 |
+
void add_chunk_data(std::exception_ptr e_ptr) {
|
| 160 |
+
std::unique_lock<std::mutex> lock(queue_mutex_);
|
| 161 |
+
cv_write_.wait(lock, [this] {
|
| 162 |
+
// stop loading if we have preloaded enough data.
|
| 163 |
+
return (
|
| 164 |
+
this->total_example_count_in_queue_ < this->queue_capacity_ ||
|
| 165 |
+
this->stop_);
|
| 166 |
+
});
|
| 167 |
+
if (stop_) {
|
| 168 |
+
// When stop_ is true, it means this current thread needs to be tore down,
|
| 169 |
+
// the batch buffer will be discarded, so no need to enqueue any new
|
| 170 |
+
// exceptions.
|
| 171 |
+
return;
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
batch_queue_.emplace(e_ptr);
|
| 175 |
+
lock.unlock();
|
| 176 |
+
cv_read_.notify_all();
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
void stop() {
|
| 180 |
+
{
|
| 181 |
+
// Hold the lock before changing stop_ to prevent a race condition which
|
| 182 |
+
// can cause a deadlock. To be more specific, conditional variable
|
| 183 |
+
// cv_write_ waits on predicate stop_ in add_chunk_data(). The wait
|
| 184 |
+
// happens in two steps: 1) while still holding the lock, check if
|
| 185 |
+
// predicate is true; 2) if it is true, proceeds, otherwise, release the
|
| 186 |
+
// lock and wait until notified. Without holding a lock, cv_write_'s
|
| 187 |
+
// notification can happen in between step 1) and 2). In that case, as
|
| 188 |
+
// cv_write_ is not in waiting status yet, so the notification is lost and
|
| 189 |
+
// cv_write_ will sleep forever. By taking a lock before changing
|
| 190 |
+
// predicate stop_, it is ensured updating and evaluating stop_ always
|
| 191 |
+
// happen in a synchronized way
|
| 192 |
+
std::lock_guard<std::mutex> lock(queue_mutex_);
|
| 193 |
+
stop_ = true;
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
// notify all writers, wake them from wait to exit current method.
|
| 197 |
+
cv_write_.notify_all();
|
| 198 |
+
// notify all readers too.
|
| 199 |
+
cv_read_.notify_all();
|
| 200 |
+
}
|
| 201 |
+
/// The batch size is needed to create batches from the chunk data. Similar to
|
| 202 |
+
/// regular dataloader where the batches are created with prefetches,
|
| 203 |
+
/// BatchDataBuffer perform the batch creation using the provided batch size.
|
| 204 |
+
size_t batch_size_ = 0;
|
| 205 |
+
|
| 206 |
+
/// count of total example stored in the queue
|
| 207 |
+
size_t total_example_count_in_queue_ = 0;
|
| 208 |
+
|
| 209 |
+
/// struct that contains a raw unwrapped batch unit. An unwrapped batch unit
|
| 210 |
+
/// is the raw data without 'optional' wrapper. It can be a collection of
|
| 211 |
+
/// images, utterances, e.t.c.
|
| 212 |
+
struct UnwrappedBatchData {
|
| 213 |
+
explicit UnwrappedBatchData(UnwrappedBatchType data)
|
| 214 |
+
: batch_data(std::move(data)) {}
|
| 215 |
+
|
| 216 |
+
// NOLINTNEXTLINE(modernize-pass-by-value)
|
| 217 |
+
explicit UnwrappedBatchData(std::exception_ptr e) : exception(e) {}
|
| 218 |
+
|
| 219 |
+
/// batch data to return
|
| 220 |
+
UnwrappedBatchType batch_data;
|
| 221 |
+
|
| 222 |
+
/// exception pointer which captures any abnormal exceptions while creating
|
| 223 |
+
/// the batch.
|
| 224 |
+
std::exception_ptr exception;
|
| 225 |
+
};
|
| 226 |
+
|
| 227 |
+
/// local cache to store example batches from loaded chunk
|
| 228 |
+
std::queue<UnwrappedBatchData> batch_queue_;
|
| 229 |
+
|
| 230 |
+
// sync batch_queue_ update.
|
| 231 |
+
std::mutex queue_mutex_;
|
| 232 |
+
|
| 233 |
+
std::condition_variable cv_read_;
|
| 234 |
+
std::condition_variable cv_write_;
|
| 235 |
+
|
| 236 |
+
ExampleSampler& example_sampler_;
|
| 237 |
+
|
| 238 |
+
// configurable maximun number of elements the queue can hold at one time.
|
| 239 |
+
size_t queue_capacity_;
|
| 240 |
+
|
| 241 |
+
// When set to true, it wakes the writer threads from the wait and exit
|
| 242 |
+
// current function call. This is needed when ChunkDataSet.Reset is called
|
| 243 |
+
// while the previous epoch is not exhausted yet. When ChunkDataset is waiting
|
| 244 |
+
// its preloader to finish previous work before tearing down the thread, the
|
| 245 |
+
// preloader could be still waiting for the conditional variable, thus cause
|
| 246 |
+
// the program to hang. This boolean is used to break this waiting condition.
|
| 247 |
+
bool stop_ = false;
|
| 248 |
+
};
|
| 249 |
+
} // namespace detail
|
| 250 |
+
|
| 251 |
+
/// Options to configure a `ChunkDataset`.
|
| 252 |
+
struct ChunkDatasetOptions {
|
| 253 |
+
ChunkDatasetOptions() = delete;
|
| 254 |
+
ChunkDatasetOptions(
|
| 255 |
+
size_t preloader_count,
|
| 256 |
+
size_t batch_size,
|
| 257 |
+
size_t cache_size = 2048,
|
| 258 |
+
size_t cross_chunk_shuffle_count = 1)
|
| 259 |
+
: preloader_count_(preloader_count),
|
| 260 |
+
batch_size_(batch_size),
|
| 261 |
+
cache_size_(cache_size),
|
| 262 |
+
cross_chunk_shuffle_count_(cross_chunk_shuffle_count) {
|
| 263 |
+
TORCH_CHECK(
|
| 264 |
+
preloader_count_ > 0,
|
| 265 |
+
"Preloader count is 0. At least one preloader needs to be specified.");
|
| 266 |
+
TORCH_CHECK(
|
| 267 |
+
batch_size_ > 0,
|
| 268 |
+
"Batch size is 0. A positive batch size needs to be specified.");
|
| 269 |
+
TORCH_CHECK(
|
| 270 |
+
cache_size_ > 0,
|
| 271 |
+
"Cache size is 0. A positive cache size needs to be specified.");
|
| 272 |
+
TORCH_CHECK(
|
| 273 |
+
cache_size_ >= batch_size_,
|
| 274 |
+
"Cache size is less than batch size. Cache needs to be large enough to "
|
| 275 |
+
"hold at least one batch.");
|
| 276 |
+
TORCH_CHECK(
|
| 277 |
+
cross_chunk_shuffle_count_ > 0,
|
| 278 |
+
"cross_chunk_shuffle_count needs to be greater than 0.");
|
| 279 |
+
}
|
| 280 |
+
|
| 281 |
+
/// The number of worker thread to preload chunk data.
|
| 282 |
+
TORCH_ARG(size_t, preloader_count);
|
| 283 |
+
|
| 284 |
+
/// The size of each batch.
|
| 285 |
+
TORCH_ARG(size_t, batch_size);
|
| 286 |
+
|
| 287 |
+
/// The capacity of the queue for batch caching.
|
| 288 |
+
TORCH_ARG(size_t, cache_size) = 2048;
|
| 289 |
+
|
| 290 |
+
// The number of chunks to perfrom cross-chunk shuffling. Default to 1 meaning
|
| 291 |
+
// no cross-chunk shuffling. When it is equal to n (n > 1), n random
|
| 292 |
+
// chunks will be loaded at once and example shuffling will be performed
|
| 293 |
+
// across all those n chunks.
|
| 294 |
+
// Note: Usually the default config (1 chunk shuffle + example shuffle) is
|
| 295 |
+
// good enough to generate random distributed data. Use this parameter only if
|
| 296 |
+
// you know cross-shuffle is needed in your case. Also there is a performance
|
| 297 |
+
// penalty when this value is greater than 1, as we need to do extra merge
|
| 298 |
+
// between multiple chunks before performing example sampling.
|
| 299 |
+
TORCH_ARG(size_t, cross_chunk_shuffle_count) = 1;
|
| 300 |
+
};
|
| 301 |
+
|
| 302 |
+
/// A stateful dataset that support hierarchical sampling and prefetching of
|
| 303 |
+
/// entre chunks.
|
| 304 |
+
///
|
| 305 |
+
/// Unlike regular dataset, chunk dataset require two samplers to operate and
|
| 306 |
+
/// keeps an internal state. `ChunkSampler` selects, which chunk to load next,
|
| 307 |
+
/// while the `ExampleSampler` determins the order of Examples that are returned
|
| 308 |
+
/// in each `get_batch` call. The hierarchical sampling approach used here is
|
| 309 |
+
/// inspired by this paper http://martin.zinkevich.org/publications/nips2010.pdf
|
| 310 |
+
template <
|
| 311 |
+
typename ChunkReader,
|
| 312 |
+
typename ChunkSampler = samplers::RandomSampler,
|
| 313 |
+
typename ExampleSampler = samplers::RandomSampler>
|
| 314 |
+
class ChunkDataset final
|
| 315 |
+
: public StatefulDataset<
|
| 316 |
+
ChunkDataset<ChunkReader, ChunkSampler, ExampleSampler>,
|
| 317 |
+
typename ChunkReader::BatchType,
|
| 318 |
+
size_t> {
|
| 319 |
+
public:
|
| 320 |
+
using BatchType = torch::optional<typename ChunkReader::BatchType>;
|
| 321 |
+
using UnwrappedBatchType = typename ChunkReader::BatchType;
|
| 322 |
+
using BatchRequestType = size_t;
|
| 323 |
+
using ChunkSamplerType = ChunkSampler;
|
| 324 |
+
using ExampleSamplerType = ExampleSampler;
|
| 325 |
+
|
| 326 |
+
ChunkDataset(
|
| 327 |
+
ChunkReader chunk_reader,
|
| 328 |
+
ChunkSampler chunk_sampler,
|
| 329 |
+
ExampleSampler example_sampler,
|
| 330 |
+
ChunkDatasetOptions options,
|
| 331 |
+
std::function<void(UnwrappedBatchType&)> preprocessing_policy =
|
| 332 |
+
std::function<void(UnwrappedBatchType&)>())
|
| 333 |
+
: chunk_reader_(std::move(chunk_reader)),
|
| 334 |
+
chunk_sampler_(std::move(chunk_sampler)),
|
| 335 |
+
example_sampler_(std::move(example_sampler)),
|
| 336 |
+
options_(std::move(options)),
|
| 337 |
+
preprocessing_policy_(std::move(preprocessing_policy)),
|
| 338 |
+
quit_worker_(false),
|
| 339 |
+
running_preloaders_(0),
|
| 340 |
+
load_checkpoint_(false) {}
|
| 341 |
+
|
| 342 |
+
~ChunkDataset() override {
|
| 343 |
+
// stop batch buffer first.
|
| 344 |
+
if (batch_buffer_) {
|
| 345 |
+
batch_buffer_->stop();
|
| 346 |
+
}
|
| 347 |
+
free_workers();
|
| 348 |
+
}
|
| 349 |
+
|
| 350 |
+
/// Default get_batch method of BatchDataset. This method returns
|
| 351 |
+
/// Example batches created from the preloaded chunks. The implemenation
|
| 352 |
+
/// is dataset agnostic and does not need overriding in different chunk
|
| 353 |
+
/// datasets.
|
| 354 |
+
BatchType get_batch(size_t batch_size) override {
|
| 355 |
+
TORCH_CHECK(
|
| 356 |
+
batch_buffer_ != nullptr,
|
| 357 |
+
"Dataset needs to call reset() before calling get_batch().");
|
| 358 |
+
|
| 359 |
+
TORCH_CHECK(
|
| 360 |
+
batch_size == options_.batch_size(),
|
| 361 |
+
"The requested batch size does not match with the initialized batch size.\n"
|
| 362 |
+
" The requested batch size is ",
|
| 363 |
+
batch_size,
|
| 364 |
+
", while the dataset is created with batch size equal to ",
|
| 365 |
+
options_.batch_size());
|
| 366 |
+
return batch_buffer_->get_batch();
|
| 367 |
+
}
|
| 368 |
+
|
| 369 |
+
/// Helper method around get_batch as `batch_size` is not strictly necessary
|
| 370 |
+
BatchType get_batch() {
|
| 371 |
+
return get_batch(options_.batch_size());
|
| 372 |
+
}
|
| 373 |
+
|
| 374 |
+
/// This will clear any internal state and starts the internal prefetching
|
| 375 |
+
/// mechanism for the chunk dataset.
|
| 376 |
+
void reset() override {
|
| 377 |
+
// We need this to support partial data reads via dataloader iterator.
|
| 378 |
+
if (batch_buffer_) {
|
| 379 |
+
batch_buffer_->stop();
|
| 380 |
+
}
|
| 381 |
+
// free workers from previous reset if there is any.
|
| 382 |
+
free_workers();
|
| 383 |
+
preload_threads_.clear();
|
| 384 |
+
|
| 385 |
+
if (!load_checkpoint_) {
|
| 386 |
+
chunk_reader_.reset();
|
| 387 |
+
chunk_sampler_.reset(chunk_reader_.chunk_count());
|
| 388 |
+
load_checkpoint_ = false;
|
| 389 |
+
}
|
| 390 |
+
|
| 391 |
+
// Throw out any existing cached batch in the buffer and re-creates a new
|
| 392 |
+
// chunk buffer.
|
| 393 |
+
batch_buffer_ = std::make_unique<
|
| 394 |
+
detail::BatchDataBuffer<UnwrappedBatchType, ExampleSamplerType>>(
|
| 395 |
+
options_.batch_size(), example_sampler_, options_.cache_size());
|
| 396 |
+
|
| 397 |
+
// create new workers for this new epoch.
|
| 398 |
+
quit_worker_ = false;
|
| 399 |
+
|
| 400 |
+
AT_ASSERT(running_preloaders_ == 0);
|
| 401 |
+
running_preloaders_ = options_.preloader_count();
|
| 402 |
+
for (const auto i : c10::irange(options_.preloader_count())) {
|
| 403 |
+
preload_threads_.emplace_back([this, i]() { this->preloader(i); });
|
| 404 |
+
}
|
| 405 |
+
}
|
| 406 |
+
|
| 407 |
+
/// size is not used for chunk dataset.
|
| 408 |
+
std::optional<size_t> size() const override {
|
| 409 |
+
return torch::nullopt;
|
| 410 |
+
}
|
| 411 |
+
|
| 412 |
+
// provide a references to chunk sampler. Used mainly in distributed data
|
| 413 |
+
// loading to set the epoch number for the sampler.
|
| 414 |
+
ChunkSamplerType& chunk_sampler() {
|
| 415 |
+
return chunk_sampler_;
|
| 416 |
+
}
|
| 417 |
+
|
| 418 |
+
void save(serialize::OutputArchive& archive) const override {
|
| 419 |
+
std::lock_guard<std::mutex> lock(chunk_index_guard_);
|
| 420 |
+
chunk_sampler_.save(archive);
|
| 421 |
+
}
|
| 422 |
+
|
| 423 |
+
void load(serialize::InputArchive& archive) override {
|
| 424 |
+
std::lock_guard<std::mutex> lock(chunk_index_guard_);
|
| 425 |
+
chunk_sampler_.load(archive);
|
| 426 |
+
load_checkpoint_ = true;
|
| 427 |
+
}
|
| 428 |
+
|
| 429 |
+
private:
|
| 430 |
+
/// running on worker thread to preload chunk data.
|
| 431 |
+
void preloader(size_t id) {
|
| 432 |
+
while (!quit_worker_.load()) {
|
| 433 |
+
try {
|
| 434 |
+
std::vector<size_t> chunk_idx;
|
| 435 |
+
{
|
| 436 |
+
std::lock_guard<std::mutex> lock(chunk_index_guard_);
|
| 437 |
+
if (auto chunk_sampler_result = chunk_sampler_.next(
|
| 438 |
+
this->options_.cross_chunk_shuffle_count())) {
|
| 439 |
+
chunk_idx = chunk_sampler_result.value();
|
| 440 |
+
} else {
|
| 441 |
+
break;
|
| 442 |
+
}
|
| 443 |
+
}
|
| 444 |
+
UnwrappedBatchType data = chunk_reader_.read_chunk(chunk_idx[0]);
|
| 445 |
+
for (const auto i : c10::irange(1, chunk_idx.size())) {
|
| 446 |
+
auto chunk_data = chunk_reader_.read_chunk(chunk_idx[i]);
|
| 447 |
+
std::move(
|
| 448 |
+
chunk_data.begin(), chunk_data.end(), std::back_inserter(data));
|
| 449 |
+
}
|
| 450 |
+
if (preprocessing_policy_) {
|
| 451 |
+
preprocessing_policy_(data);
|
| 452 |
+
}
|
| 453 |
+
if (!data.empty()) { // skip empty chunks.
|
| 454 |
+
batch_buffer_->add_chunk_data(std::move(data));
|
| 455 |
+
}
|
| 456 |
+
} catch (...) {
|
| 457 |
+
batch_buffer_->add_chunk_data(std::current_exception());
|
| 458 |
+
}
|
| 459 |
+
}
|
| 460 |
+
AT_ASSERT(running_preloaders_.load() > 0);
|
| 461 |
+
--running_preloaders_;
|
| 462 |
+
if (running_preloaders_.load() == 0) {
|
| 463 |
+
// all preloaders are completed, so we can notify the batch_buffer.
|
| 464 |
+
batch_buffer_->stop();
|
| 465 |
+
}
|
| 466 |
+
}
|
| 467 |
+
|
| 468 |
+
/// Block the current thread until the workers finish execution and exit.
|
| 469 |
+
void free_workers() {
|
| 470 |
+
if (!quit_worker_.load()) {
|
| 471 |
+
quit_worker_ = true;
|
| 472 |
+
for (auto& worker_thread : preload_threads_) {
|
| 473 |
+
worker_thread.join();
|
| 474 |
+
}
|
| 475 |
+
}
|
| 476 |
+
}
|
| 477 |
+
|
| 478 |
+
private:
|
| 479 |
+
// Templated class that defines what is a chunk and how to read chunk data.
|
| 480 |
+
// When a chunk is returned by chunk_reader_, ChunkDataset split it into
|
| 481 |
+
// batches and caches them in batch_buffer_.
|
| 482 |
+
ChunkReader chunk_reader_;
|
| 483 |
+
|
| 484 |
+
// chunk sampler to shuffle different chunks
|
| 485 |
+
ChunkSamplerType chunk_sampler_;
|
| 486 |
+
|
| 487 |
+
// example sampler to shuffle examples in a specific chunk
|
| 488 |
+
ExampleSamplerType example_sampler_;
|
| 489 |
+
|
| 490 |
+
// batch data buffer which holds chunk data from preloading thread.
|
| 491 |
+
std::shared_ptr<
|
| 492 |
+
detail::BatchDataBuffer<UnwrappedBatchType, ExampleSamplerType>>
|
| 493 |
+
batch_buffer_;
|
| 494 |
+
|
| 495 |
+
// worker thread pool
|
| 496 |
+
std::vector<std::thread> preload_threads_;
|
| 497 |
+
|
| 498 |
+
/// The options the Dataset was configured with.
|
| 499 |
+
const ChunkDatasetOptions options_;
|
| 500 |
+
|
| 501 |
+
// function pointer wrapper to apply custom processing over chunk data. This
|
| 502 |
+
// is considered an advanced parameter for developers who want to apply a
|
| 503 |
+
// pre-process to the chunk data before sampling into minibatch.
|
| 504 |
+
// Different than the collate function, this policy is applied on the chunk
|
| 505 |
+
// level, instead of minibatch level. When a chunk of data is loaded (multiple
|
| 506 |
+
// chunks if cross_chunk_shuffle_count_ is greater than 1), this policy is
|
| 507 |
+
// applied to the full loaded data. It is useful if developers want to
|
| 508 |
+
// perform pre-processing (like bucketing) to the chunk data before
|
| 509 |
+
// example sampler samples the data. By default it's an empty pointer and no
|
| 510 |
+
// action will be taken.
|
| 511 |
+
std::function<void(UnwrappedBatchType&)> preprocessing_policy_;
|
| 512 |
+
|
| 513 |
+
// indicate whether the worker thread can be teared down
|
| 514 |
+
std::atomic<bool> quit_worker_;
|
| 515 |
+
|
| 516 |
+
// keep track of running preloaders to notify batch buffer. A value 0
|
| 517 |
+
// indicates that the chunk loading is completed.
|
| 518 |
+
std::atomic<size_t> running_preloaders_;
|
| 519 |
+
|
| 520 |
+
// mutex to synchronize chunk sampler next() call.
|
| 521 |
+
mutable std::mutex chunk_index_guard_;
|
| 522 |
+
|
| 523 |
+
// boolean value to indicate whether we need to load the checkpoint for
|
| 524 |
+
// chunk_sampler_.
|
| 525 |
+
bool load_checkpoint_;
|
| 526 |
+
};
|
| 527 |
+
} // namespace datasets
|
| 528 |
+
} // namespace data
|
| 529 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/map.h
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/data/datasets/base.h>
|
| 4 |
+
#include <torch/types.h>
|
| 5 |
+
|
| 6 |
+
#include <c10/util/ArrayRef.h>
|
| 7 |
+
|
| 8 |
+
#include <cstddef>
|
| 9 |
+
#include <type_traits>
|
| 10 |
+
#include <utility>
|
| 11 |
+
|
| 12 |
+
namespace torch {
|
| 13 |
+
namespace data {
|
| 14 |
+
namespace datasets {
|
| 15 |
+
namespace detail {
|
| 16 |
+
template <bool C, typename T>
|
| 17 |
+
using optional_if_t = typename std::conditional<C, torch::optional<T>, T>::type;
|
| 18 |
+
} // namespace detail
|
| 19 |
+
|
| 20 |
+
/// A `MapDataset` is a dataset that applies a transform to a source dataset.
|
| 21 |
+
template <typename SourceDataset, typename AppliedTransform>
|
| 22 |
+
class MapDataset : public BatchDataset<
|
| 23 |
+
MapDataset<SourceDataset, AppliedTransform>,
|
| 24 |
+
detail::optional_if_t<
|
| 25 |
+
SourceDataset::is_stateful,
|
| 26 |
+
typename AppliedTransform::OutputBatchType>,
|
| 27 |
+
typename SourceDataset::BatchRequestType> {
|
| 28 |
+
public:
|
| 29 |
+
using DatasetType = SourceDataset;
|
| 30 |
+
using TransformType = AppliedTransform;
|
| 31 |
+
using BatchRequestType = typename SourceDataset::BatchRequestType;
|
| 32 |
+
using OutputBatchType = detail::optional_if_t<
|
| 33 |
+
SourceDataset::is_stateful,
|
| 34 |
+
typename AppliedTransform::OutputBatchType>;
|
| 35 |
+
|
| 36 |
+
MapDataset(DatasetType dataset, TransformType transform)
|
| 37 |
+
: dataset_(std::move(dataset)), transform_(std::move(transform)) {}
|
| 38 |
+
|
| 39 |
+
/// Gets a batch from the source dataset and applies the transform to it,
|
| 40 |
+
/// returning the result.
|
| 41 |
+
OutputBatchType get_batch(BatchRequestType indices) override {
|
| 42 |
+
return get_batch_impl(std::move(indices));
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
/// Returns the size of the source dataset.
|
| 46 |
+
// NOLINTNEXTLINE(bugprone-exception-escape)
|
| 47 |
+
std::optional<size_t> size() const noexcept override {
|
| 48 |
+
return dataset_.size();
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
/// Calls `reset()` on the underlying dataset.
|
| 52 |
+
/// NOTE: Stateless datasets do not have a reset() method, so a call to this
|
| 53 |
+
/// method will only compile for stateful datasets (which have a reset()
|
| 54 |
+
/// method).
|
| 55 |
+
void reset() {
|
| 56 |
+
dataset_.reset();
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
/// Returns the underlying dataset.
|
| 60 |
+
const SourceDataset& dataset() noexcept {
|
| 61 |
+
return dataset_;
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
/// Returns the transform being applied.
|
| 65 |
+
const AppliedTransform& transform() noexcept {
|
| 66 |
+
return transform_;
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
private:
|
| 70 |
+
/// The implementation of `get_batch()` for the stateless case, which simply
|
| 71 |
+
/// applies the transform to the output of `get_batch()` from the dataset.
|
| 72 |
+
template <
|
| 73 |
+
typename D = SourceDataset,
|
| 74 |
+
typename = std::enable_if_t<!D::is_stateful>>
|
| 75 |
+
OutputBatchType get_batch_impl(BatchRequestType indices) {
|
| 76 |
+
return transform_.apply_batch(dataset_.get_batch(std::move(indices)));
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
/// The implementation of `get_batch()` for the stateful case. Here, we follow
|
| 80 |
+
/// the semantics of `Optional.map()` in many functional languages, which
|
| 81 |
+
/// applies a transformation to the optional's content when the optional
|
| 82 |
+
/// contains a value, and returns a new optional (of a different type) if the
|
| 83 |
+
/// original optional returned by `get_batch()` was empty.
|
| 84 |
+
template <typename D = SourceDataset>
|
| 85 |
+
std::enable_if_t<D::is_stateful, OutputBatchType> get_batch_impl(
|
| 86 |
+
BatchRequestType indices) {
|
| 87 |
+
if (auto batch = dataset_.get_batch(std::move(indices))) {
|
| 88 |
+
return transform_.apply_batch(std::move(*batch));
|
| 89 |
+
}
|
| 90 |
+
return nullopt;
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
/// The underlying dataset being transformed.
|
| 94 |
+
SourceDataset dataset_;
|
| 95 |
+
|
| 96 |
+
// The transformation that is applied to batches received from the dataset.
|
| 97 |
+
AppliedTransform transform_;
|
| 98 |
+
};
|
| 99 |
+
|
| 100 |
+
/// Creates a `MapDataset` with the given dataset and transform.
|
| 101 |
+
template <typename DatasetType, typename TransformType>
|
| 102 |
+
MapDataset<DatasetType, TransformType> map(
|
| 103 |
+
DatasetType dataset,
|
| 104 |
+
TransformType transform) {
|
| 105 |
+
static_assert(
|
| 106 |
+
std::is_same<
|
| 107 |
+
typename std::conditional<
|
| 108 |
+
DatasetType::is_stateful,
|
| 109 |
+
typename DatasetType::BatchType::value_type,
|
| 110 |
+
typename DatasetType::BatchType>::type,
|
| 111 |
+
typename TransformType::InputBatchType>::value,
|
| 112 |
+
"BatchType type of dataset does not match input type of transform");
|
| 113 |
+
return {std::move(dataset), std::move(transform)};
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
} // namespace datasets
|
| 117 |
+
} // namespace data
|
| 118 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/mnist.h
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/data/datasets/base.h>
|
| 4 |
+
#include <torch/data/example.h>
|
| 5 |
+
#include <torch/types.h>
|
| 6 |
+
|
| 7 |
+
#include <torch/csrc/Export.h>
|
| 8 |
+
|
| 9 |
+
#include <cstddef>
|
| 10 |
+
#include <string>
|
| 11 |
+
|
| 12 |
+
namespace torch {
|
| 13 |
+
namespace data {
|
| 14 |
+
namespace datasets {
|
| 15 |
+
/// The MNIST dataset.
|
| 16 |
+
class TORCH_API MNIST : public Dataset<MNIST> {
|
| 17 |
+
public:
|
| 18 |
+
/// The mode in which the dataset is loaded.
|
| 19 |
+
enum class Mode { kTrain, kTest };
|
| 20 |
+
|
| 21 |
+
/// Loads the MNIST dataset from the `root` path.
|
| 22 |
+
///
|
| 23 |
+
/// The supplied `root` path should contain the *content* of the unzipped
|
| 24 |
+
/// MNIST dataset, available from http://yann.lecun.com/exdb/mnist.
|
| 25 |
+
explicit MNIST(const std::string& root, Mode mode = Mode::kTrain);
|
| 26 |
+
|
| 27 |
+
/// Returns the `Example` at the given `index`.
|
| 28 |
+
Example<> get(size_t index) override;
|
| 29 |
+
|
| 30 |
+
/// Returns the size of the dataset.
|
| 31 |
+
std::optional<size_t> size() const override;
|
| 32 |
+
|
| 33 |
+
/// Returns true if this is the training subset of MNIST.
|
| 34 |
+
// NOLINTNEXTLINE(bugprone-exception-escape)
|
| 35 |
+
bool is_train() const noexcept;
|
| 36 |
+
|
| 37 |
+
/// Returns all images stacked into a single tensor.
|
| 38 |
+
const Tensor& images() const;
|
| 39 |
+
|
| 40 |
+
/// Returns all targets stacked into a single tensor.
|
| 41 |
+
const Tensor& targets() const;
|
| 42 |
+
|
| 43 |
+
private:
|
| 44 |
+
Tensor images_, targets_;
|
| 45 |
+
};
|
| 46 |
+
} // namespace datasets
|
| 47 |
+
} // namespace data
|
| 48 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/shared.h
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/data/datasets/base.h>
|
| 4 |
+
|
| 5 |
+
#include <memory>
|
| 6 |
+
#include <utility>
|
| 7 |
+
|
| 8 |
+
namespace torch {
|
| 9 |
+
namespace data {
|
| 10 |
+
namespace datasets {
|
| 11 |
+
|
| 12 |
+
/// A dataset that wraps another dataset in a shared pointer and implements the
|
| 13 |
+
/// `BatchDataset` API, delegating all calls to the shared instance. This is
|
| 14 |
+
/// useful when you want all worker threads in the dataloader to access the same
|
| 15 |
+
/// dataset instance. The dataset must take care of synchronization and
|
| 16 |
+
/// thread-safe access itself.
|
| 17 |
+
///
|
| 18 |
+
/// Use `torch::data::datasets::make_shared_dataset()` to create a new
|
| 19 |
+
/// `SharedBatchDataset` like you would a `std::shared_ptr`.
|
| 20 |
+
template <typename UnderlyingDataset>
|
| 21 |
+
class SharedBatchDataset : public BatchDataset<
|
| 22 |
+
SharedBatchDataset<UnderlyingDataset>,
|
| 23 |
+
typename UnderlyingDataset::BatchType,
|
| 24 |
+
typename UnderlyingDataset::BatchRequestType> {
|
| 25 |
+
public:
|
| 26 |
+
using BatchType = typename UnderlyingDataset::BatchType;
|
| 27 |
+
using BatchRequestType = typename UnderlyingDataset::BatchRequestType;
|
| 28 |
+
|
| 29 |
+
/// Constructs a new `SharedBatchDataset` from a `shared_ptr` to the
|
| 30 |
+
/// `UnderlyingDataset`.
|
| 31 |
+
/* implicit */ SharedBatchDataset(
|
| 32 |
+
std::shared_ptr<UnderlyingDataset> shared_dataset)
|
| 33 |
+
: dataset_(std::move(shared_dataset)) {}
|
| 34 |
+
|
| 35 |
+
/// Calls `get_batch` on the underlying dataset.
|
| 36 |
+
BatchType get_batch(BatchRequestType request) override {
|
| 37 |
+
return dataset_->get_batch(std::move(request));
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
/// Returns the `size` from the underlying dataset.
|
| 41 |
+
std::optional<size_t> size() const override {
|
| 42 |
+
return dataset_->size();
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
/// Accesses the underlying dataset.
|
| 46 |
+
UnderlyingDataset& operator*() {
|
| 47 |
+
return *dataset_;
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
/// Accesses the underlying dataset.
|
| 51 |
+
const UnderlyingDataset& operator*() const {
|
| 52 |
+
return *dataset_;
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
/// Accesses the underlying dataset.
|
| 56 |
+
UnderlyingDataset* operator->() {
|
| 57 |
+
return dataset_.get();
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
/// Accesses the underlying dataset.
|
| 61 |
+
const UnderlyingDataset* operator->() const {
|
| 62 |
+
return dataset_.get();
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
/// Calls `reset()` on the underlying dataset.
|
| 66 |
+
void reset() {
|
| 67 |
+
dataset_->reset();
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
private:
|
| 71 |
+
std::shared_ptr<UnderlyingDataset> dataset_;
|
| 72 |
+
};
|
| 73 |
+
|
| 74 |
+
/// Constructs a new `SharedBatchDataset` by creating a
|
| 75 |
+
/// `shared_ptr<UnderlyingDatase>`. All arguments are forwarded to
|
| 76 |
+
/// `make_shared<UnderlyingDataset>`.
|
| 77 |
+
template <typename UnderlyingDataset, typename... Args>
|
| 78 |
+
SharedBatchDataset<UnderlyingDataset> make_shared_dataset(Args&&... args) {
|
| 79 |
+
return std::make_shared<UnderlyingDataset>(std::forward<Args>(args)...);
|
| 80 |
+
}
|
| 81 |
+
} // namespace datasets
|
| 82 |
+
} // namespace data
|
| 83 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/stateful.h
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/data/datasets/base.h>
|
| 4 |
+
#include <torch/data/example.h>
|
| 5 |
+
|
| 6 |
+
#include <cstddef>
|
| 7 |
+
#include <vector>
|
| 8 |
+
|
| 9 |
+
namespace torch {
|
| 10 |
+
namespace serialize {
|
| 11 |
+
class OutputArchive;
|
| 12 |
+
class InputArchive;
|
| 13 |
+
} // namespace serialize
|
| 14 |
+
} // namespace torch
|
| 15 |
+
|
| 16 |
+
namespace torch {
|
| 17 |
+
namespace data {
|
| 18 |
+
namespace datasets {
|
| 19 |
+
|
| 20 |
+
/// A stateful dataset is a dataset that maintains some internal state, which
|
| 21 |
+
/// will be `reset()` at the beginning of each epoch. Subclasses can override
|
| 22 |
+
/// the `reset()` method to configure this behavior. Further, the return type of
|
| 23 |
+
/// a stateful dataset's `get_batch()` method is always an `optional`. When the
|
| 24 |
+
/// stateful dataset wants to indicate to the dataloader that its epoch has
|
| 25 |
+
/// ended, it should return an empty optional. The dataloader knows to modify
|
| 26 |
+
/// its implementation based on whether the dataset is stateless or stateful.
|
| 27 |
+
///
|
| 28 |
+
/// Note that when subclassing a from `StatefulDataset<Self, T>`, the return
|
| 29 |
+
/// type of `get_batch()`, which the subclass must override, will be
|
| 30 |
+
/// `optional<T>` (i.e. the type specified in the `StatefulDataset`
|
| 31 |
+
/// specialization is automatically boxed into an `optional` for the dataset's
|
| 32 |
+
/// `BatchType`).
|
| 33 |
+
template <
|
| 34 |
+
typename Self,
|
| 35 |
+
typename Batch = std::vector<Example<>>,
|
| 36 |
+
typename BatchRequest = size_t>
|
| 37 |
+
class StatefulDataset
|
| 38 |
+
: public BatchDataset<Self, std::optional<Batch>, BatchRequest> {
|
| 39 |
+
public:
|
| 40 |
+
/// Resets internal state of the dataset.
|
| 41 |
+
virtual void reset() = 0;
|
| 42 |
+
|
| 43 |
+
/// Saves the statefulDataset's state to OutputArchive.
|
| 44 |
+
virtual void save(serialize::OutputArchive& archive) const = 0;
|
| 45 |
+
|
| 46 |
+
/// Deserializes the statefulDataset's state from the `archive`.
|
| 47 |
+
virtual void load(serialize::InputArchive& archive) = 0;
|
| 48 |
+
};
|
| 49 |
+
|
| 50 |
+
/// Serializes a statefulDataset to `OutputArchive`.
|
| 51 |
+
template <typename... Args>
|
| 52 |
+
serialize::OutputArchive& operator<<(
|
| 53 |
+
serialize::OutputArchive& archive,
|
| 54 |
+
const StatefulDataset<Args...>& statefulDataset) {
|
| 55 |
+
statefulDataset.save(archive);
|
| 56 |
+
return archive;
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
/// Deserializes a statefulDataset from an `InputArchive`.
|
| 60 |
+
template <typename... Args>
|
| 61 |
+
serialize::InputArchive& operator>>(
|
| 62 |
+
serialize::InputArchive& archive,
|
| 63 |
+
StatefulDataset<Args...>& statefulDataset) {
|
| 64 |
+
statefulDataset.load(archive);
|
| 65 |
+
return archive;
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
} // namespace datasets
|
| 69 |
+
} // namespace data
|
| 70 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/tensor.h
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/data/datasets/base.h>
|
| 4 |
+
#include <torch/data/example.h>
|
| 5 |
+
#include <torch/types.h>
|
| 6 |
+
|
| 7 |
+
#include <cstddef>
|
| 8 |
+
#include <vector>
|
| 9 |
+
|
| 10 |
+
namespace torch {
|
| 11 |
+
namespace data {
|
| 12 |
+
namespace datasets {
|
| 13 |
+
|
| 14 |
+
/// A dataset of tensors.
|
| 15 |
+
/// Stores a single tensor internally, which is then indexed inside `get()`.
|
| 16 |
+
struct TensorDataset : public Dataset<TensorDataset, TensorExample> {
|
| 17 |
+
/// Creates a `TensorDataset` from a vector of tensors.
|
| 18 |
+
explicit TensorDataset(const std::vector<Tensor>& tensors)
|
| 19 |
+
: TensorDataset(torch::stack(tensors)) {}
|
| 20 |
+
|
| 21 |
+
explicit TensorDataset(torch::Tensor tensor) : tensor(std::move(tensor)) {}
|
| 22 |
+
|
| 23 |
+
/// Returns a single `TensorExample`.
|
| 24 |
+
TensorExample get(size_t index) override {
|
| 25 |
+
return tensor[index];
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
/// Returns the number of tensors in the dataset.
|
| 29 |
+
std::optional<size_t> size() const override {
|
| 30 |
+
return tensor.size(0);
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
Tensor tensor;
|
| 34 |
+
};
|
| 35 |
+
|
| 36 |
+
} // namespace datasets
|
| 37 |
+
} // namespace data
|
| 38 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/detail/data_shuttle.h
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/data/detail/queue.h>
|
| 4 |
+
#include <torch/types.h>
|
| 5 |
+
|
| 6 |
+
#include <c10/util/Exception.h>
|
| 7 |
+
#include <optional>
|
| 8 |
+
|
| 9 |
+
#include <chrono>
|
| 10 |
+
#include <utility>
|
| 11 |
+
|
| 12 |
+
namespace torch {
|
| 13 |
+
namespace data {
|
| 14 |
+
namespace detail {
|
| 15 |
+
|
| 16 |
+
/// Encapsulates the full life cycle of DataLoader jobs.
|
| 17 |
+
///
|
| 18 |
+
/// When a new job is enqueued to the `DataShuttle`, a counter for in-flight
|
| 19 |
+
/// jobs is bumped. This job is said to be "in-flight" until its result is
|
| 20 |
+
/// popped. Worker threads dequeue jobs as soon as they are available. When a
|
| 21 |
+
/// worker finishes a job, it enqueues the result. Only when the main thread
|
| 22 |
+
/// dequeues a result is the count of in-flight jobs decremented. When the main
|
| 23 |
+
/// thread attempts to dequeue a job but no jobs are in-flight, that means the
|
| 24 |
+
/// epoch is complete and `pop_result` returns an empty optional.
|
| 25 |
+
template <typename Job, typename Result>
|
| 26 |
+
class DataShuttle {
|
| 27 |
+
public:
|
| 28 |
+
/// Pushes a new job. Called by the main thread.
|
| 29 |
+
void push_job(Job job) {
|
| 30 |
+
new_jobs_.push(std::move(job));
|
| 31 |
+
++in_flight_jobs_;
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
/// Pushes the result of a job. Called by worker threads.
|
| 35 |
+
void push_result(Result result) {
|
| 36 |
+
results_.push(std::move(result));
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
/// Returns the next job, blocking until there is one available. Called by
|
| 40 |
+
/// worker threads.
|
| 41 |
+
Job pop_job() {
|
| 42 |
+
return new_jobs_.pop();
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
/// Returns the result of a job, or nullopt if all jobs were exhausted. Called
|
| 46 |
+
/// by the main thread.
|
| 47 |
+
std::optional<Result> pop_result(
|
| 48 |
+
std::optional<std::chrono::milliseconds> timeout = std::nullopt) {
|
| 49 |
+
if (in_flight_jobs_ > 0) {
|
| 50 |
+
auto result = results_.pop(timeout);
|
| 51 |
+
--in_flight_jobs_;
|
| 52 |
+
return result;
|
| 53 |
+
}
|
| 54 |
+
return nullopt;
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
/// Discards any jobs that are not yet in flight, and waits for all in-flight
|
| 58 |
+
/// jobs to finish, discarding their result.
|
| 59 |
+
void drain() {
|
| 60 |
+
// Clear all inputs so that no further jobs are scheduled.
|
| 61 |
+
auto number_cleared = new_jobs_.clear();
|
| 62 |
+
in_flight_jobs_ -= number_cleared;
|
| 63 |
+
// Remove any outstanding results.
|
| 64 |
+
while (in_flight_jobs_ > 0) {
|
| 65 |
+
pop_result();
|
| 66 |
+
}
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
/// Returns the number of jobs that are still in progress.
|
| 70 |
+
/// When this number is zero, an epoch is finished.
|
| 71 |
+
size_t in_flight_jobs() const noexcept {
|
| 72 |
+
return in_flight_jobs_;
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
private:
|
| 76 |
+
/// The queue for jobs that are not yet in flight.
|
| 77 |
+
Queue<Job> new_jobs_;
|
| 78 |
+
/// The number of in-flight jobs.
|
| 79 |
+
/// NOTE: Not atomic because only manipulated by the main thread.
|
| 80 |
+
size_t in_flight_jobs_ = 0;
|
| 81 |
+
/// The queue for results of finished jobs.
|
| 82 |
+
Queue<Result> results_;
|
| 83 |
+
};
|
| 84 |
+
|
| 85 |
+
} // namespace detail
|
| 86 |
+
} // namespace data
|
| 87 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/detail/queue.h
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/types.h>
|
| 4 |
+
|
| 5 |
+
#include <c10/util/Exception.h>
|
| 6 |
+
|
| 7 |
+
#include <chrono>
|
| 8 |
+
#include <condition_variable>
|
| 9 |
+
#include <cstddef>
|
| 10 |
+
#include <mutex>
|
| 11 |
+
#include <queue>
|
| 12 |
+
|
| 13 |
+
namespace torch {
|
| 14 |
+
namespace data {
|
| 15 |
+
namespace detail {
|
| 16 |
+
|
| 17 |
+
/// A basic locked, blocking MPMC queue.
|
| 18 |
+
///
|
| 19 |
+
/// Every `push` and `pop` is guarded by a mutex. A condition variable is used
|
| 20 |
+
/// to communicate insertion of new elements, such that waiting threads will be
|
| 21 |
+
/// woken up if they are currently waiting inside a call to `pop()`.
|
| 22 |
+
///
|
| 23 |
+
/// Note that this data structure is written specifically for use with the
|
| 24 |
+
/// `DataLoader`. Its behavior is tailored to this use case and may not be
|
| 25 |
+
/// applicable to more general uses.
|
| 26 |
+
template <typename T>
|
| 27 |
+
class Queue {
|
| 28 |
+
public:
|
| 29 |
+
/// Pushes a new value to the back of the `Queue` and notifies one thread on
|
| 30 |
+
/// the waiting side about this event.
|
| 31 |
+
void push(T value) {
|
| 32 |
+
{
|
| 33 |
+
std::lock_guard<std::mutex> lock(mutex_);
|
| 34 |
+
queue_.push(std::move(value));
|
| 35 |
+
}
|
| 36 |
+
cv_.notify_one();
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
/// Blocks until at least one element is ready to be popped from the front of
|
| 40 |
+
/// the queue. An optional `timeout` in seconds can be used to limit the time
|
| 41 |
+
/// spent waiting for an element. If the wait times out, an exception is
|
| 42 |
+
/// raised.
|
| 43 |
+
T pop(std::optional<std::chrono::milliseconds> timeout = std::nullopt) {
|
| 44 |
+
std::unique_lock<std::mutex> lock(mutex_);
|
| 45 |
+
if (timeout) {
|
| 46 |
+
if (!cv_.wait_for(
|
| 47 |
+
lock, *timeout, [this] { return !this->queue_.empty(); })) {
|
| 48 |
+
// clang-format off
|
| 49 |
+
AT_ERROR(
|
| 50 |
+
"Timeout in DataLoader queue while waiting for next batch"
|
| 51 |
+
" (timeout was ", timeout->count(), " ms)");
|
| 52 |
+
// clang-format on
|
| 53 |
+
}
|
| 54 |
+
} else {
|
| 55 |
+
cv_.wait(lock, [this] { return !this->queue_.empty(); });
|
| 56 |
+
}
|
| 57 |
+
AT_ASSERT(!queue_.empty());
|
| 58 |
+
T value = queue_.front();
|
| 59 |
+
queue_.pop();
|
| 60 |
+
lock.unlock();
|
| 61 |
+
return value;
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
/// Empties the queue and returns the number of elements that were present at
|
| 65 |
+
/// the start of the function. No threads are notified about this event as it
|
| 66 |
+
/// is assumed to be used to drain the queue during shutdown of a
|
| 67 |
+
/// `DataLoader`.
|
| 68 |
+
size_t clear() {
|
| 69 |
+
std::lock_guard<std::mutex> lock(this->mutex_);
|
| 70 |
+
const auto size = queue_.size();
|
| 71 |
+
while (!queue_.empty()) {
|
| 72 |
+
queue_.pop();
|
| 73 |
+
}
|
| 74 |
+
return size;
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
private:
|
| 78 |
+
std::queue<T> queue_;
|
| 79 |
+
std::mutex mutex_;
|
| 80 |
+
std::condition_variable cv_;
|
| 81 |
+
};
|
| 82 |
+
} // namespace detail
|
| 83 |
+
} // namespace data
|
| 84 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/detail/sequencers.h
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/types.h>
|
| 4 |
+
|
| 5 |
+
#include <algorithm>
|
| 6 |
+
#include <cstddef>
|
| 7 |
+
#include <vector>
|
| 8 |
+
|
| 9 |
+
namespace torch {
|
| 10 |
+
namespace data {
|
| 11 |
+
namespace detail {
|
| 12 |
+
namespace sequencers {
|
| 13 |
+
namespace detail {
|
| 14 |
+
template <typename Result>
|
| 15 |
+
bool buffer_contains_result(const std::vector<std::optional<Result>>& buffer) {
|
| 16 |
+
return std::any_of(
|
| 17 |
+
buffer.begin(), buffer.end(), [](const std::optional<Result>& result) {
|
| 18 |
+
return result.has_value();
|
| 19 |
+
});
|
| 20 |
+
}
|
| 21 |
+
} // namespace detail
|
| 22 |
+
|
| 23 |
+
/// A `Sequencer` accepts a function that yields the next result of a
|
| 24 |
+
/// `DataLoader` and then has the opportunity to influence the order in which
|
| 25 |
+
/// these results are returned. The `NoSequencer` does not enforce any
|
| 26 |
+
/// sequencing and returns any result directly. The `OrderedSequencer` instead
|
| 27 |
+
/// buffers results internally to return them in order of their sequence number.
|
| 28 |
+
template <typename Result>
|
| 29 |
+
struct Sequencer {
|
| 30 |
+
using ResultProducer = std::function<std::optional<Result>()>;
|
| 31 |
+
virtual ~Sequencer() = default;
|
| 32 |
+
virtual std::optional<Result> next(ResultProducer next_result) = 0;
|
| 33 |
+
};
|
| 34 |
+
|
| 35 |
+
/// A `Sequencer` that does not enforce any ordering. It is effectively the
|
| 36 |
+
/// identity function.
|
| 37 |
+
template <typename Result>
|
| 38 |
+
struct NoSequencer final : public Sequencer<Result> {
|
| 39 |
+
using typename Sequencer<Result>::ResultProducer;
|
| 40 |
+
std::optional<Result> next(ResultProducer next_result) override {
|
| 41 |
+
return next_result();
|
| 42 |
+
}
|
| 43 |
+
};
|
| 44 |
+
|
| 45 |
+
/// A `Sequencer` that buffers results and returns them in order of their
|
| 46 |
+
/// sequence number. The `OrderedSequencer` maintains an internal, monotonically
|
| 47 |
+
/// incrementing counter for the next sequence number it expects. If it receives
|
| 48 |
+
/// a result with a higher sequence number, it will buffer it for later (when
|
| 49 |
+
/// the sequence number reaches that of this result). Otherwise, if the sequence
|
| 50 |
+
/// numbers match, the result is returned.
|
| 51 |
+
///
|
| 52 |
+
/// Implementation note: The `OrderedSequencer` is implemented with a fixed-size
|
| 53 |
+
/// buffer. Let `m` be the maximum number of jobs in the data loader's queue and
|
| 54 |
+
/// `s` be the current sequence number. Assume `m` jobs are scheduled in the
|
| 55 |
+
/// `DataLoader`. Any new result is stored at index `job.sqn mod m` in the
|
| 56 |
+
/// `OrderedSequencer`. Why are we sure sequence numbers of new jobs will not
|
| 57 |
+
/// collide with sequence numbers of buffered jobs? The `OrderedSequencer` will
|
| 58 |
+
/// not return from `next()` until it receives the result with sqn `s`. This
|
| 59 |
+
/// means no new jobs can be scheduled in the `DataLoader` in the meantime,
|
| 60 |
+
/// which enforces that as long as sqn `s` has not been received, `s + m` (which
|
| 61 |
+
/// would cause a collision in the fixed-size buffer) will not yet be scheduled.
|
| 62 |
+
template <typename Result>
|
| 63 |
+
struct OrderedSequencer : public Sequencer<Result> {
|
| 64 |
+
using typename Sequencer<Result>::ResultProducer;
|
| 65 |
+
|
| 66 |
+
/// Constructs the `OrderedSequencer` with the maximum number of results it
|
| 67 |
+
/// will ever hold at one point in time.
|
| 68 |
+
explicit OrderedSequencer(size_t max_jobs) : buffer_(max_jobs) {}
|
| 69 |
+
|
| 70 |
+
/// Buffers results until the next one in the expected order is received.
|
| 71 |
+
std::optional<Result> next(ResultProducer next_result) override {
|
| 72 |
+
// If we already have the result for the next sqn, return it.
|
| 73 |
+
if (auto& maybe_result = buffer(next_sequence_number_)) {
|
| 74 |
+
auto result = std::move(*maybe_result);
|
| 75 |
+
buffer(next_sequence_number_++).reset();
|
| 76 |
+
return result;
|
| 77 |
+
}
|
| 78 |
+
// Otherwise wait for the next result.
|
| 79 |
+
while (true) {
|
| 80 |
+
auto result = next_result();
|
| 81 |
+
if (!result) {
|
| 82 |
+
AT_ASSERT(!detail::buffer_contains_result(buffer_));
|
| 83 |
+
break;
|
| 84 |
+
}
|
| 85 |
+
// If it was not nullopt and the sequence numbers match, return it
|
| 86 |
+
// directly and bump the sequence number.
|
| 87 |
+
if (result->sequence_number == next_sequence_number_) {
|
| 88 |
+
++next_sequence_number_;
|
| 89 |
+
return result;
|
| 90 |
+
}
|
| 91 |
+
// Stash the result for later.
|
| 92 |
+
AT_ASSERT(!buffer(result->sequence_number).has_value());
|
| 93 |
+
buffer(result->sequence_number) = std::move(result);
|
| 94 |
+
}
|
| 95 |
+
// The result was an empty optional, so we are done with this epoch.
|
| 96 |
+
return nullopt;
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
/// Accesses the buffer at the `index` modulo the buffer size.
|
| 100 |
+
std::optional<Result>& buffer(size_t index) {
|
| 101 |
+
return buffer_.at(index % buffer_.size());
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
/// The monotonically increasing sequence number we expect.
|
| 105 |
+
size_t next_sequence_number_ = 0;
|
| 106 |
+
|
| 107 |
+
/// A fixed-size buffer (after construction).
|
| 108 |
+
std::vector<std::optional<Result>> buffer_;
|
| 109 |
+
};
|
| 110 |
+
} // namespace sequencers
|
| 111 |
+
} // namespace detail
|
| 112 |
+
} // namespace data
|
| 113 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/example.h
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/types.h>
|
| 4 |
+
|
| 5 |
+
namespace torch {
|
| 6 |
+
namespace data {
|
| 7 |
+
|
| 8 |
+
/// An `Example` from a dataset.
|
| 9 |
+
///
|
| 10 |
+
/// A dataset consists of data and an associated target (label).
|
| 11 |
+
template <typename Data = at::Tensor, typename Target = at::Tensor>
|
| 12 |
+
struct Example {
|
| 13 |
+
using DataType = Data;
|
| 14 |
+
using TargetType = Target;
|
| 15 |
+
|
| 16 |
+
Example() = default;
|
| 17 |
+
Example(Data data, Target target)
|
| 18 |
+
: data(std::move(data)), target(std::move(target)) {}
|
| 19 |
+
|
| 20 |
+
Data data;
|
| 21 |
+
Target target;
|
| 22 |
+
};
|
| 23 |
+
|
| 24 |
+
namespace example {
|
| 25 |
+
using NoTarget = void;
|
| 26 |
+
} // namespace example
|
| 27 |
+
|
| 28 |
+
/// A specialization for `Example` that does not have a target.
|
| 29 |
+
///
|
| 30 |
+
/// This class exists so that code can be written for a templated `Example`
|
| 31 |
+
/// type, and work both for labeled and unlabeled datasets.
|
| 32 |
+
template <typename Data>
|
| 33 |
+
struct Example<Data, example::NoTarget> {
|
| 34 |
+
using DataType = Data;
|
| 35 |
+
using TargetType = example::NoTarget;
|
| 36 |
+
|
| 37 |
+
Example() = default;
|
| 38 |
+
/* implicit */ Example(Data data) : data(std::move(data)) {}
|
| 39 |
+
|
| 40 |
+
// When a DataLoader returns an Example like this, that example should be
|
| 41 |
+
// implicitly convertible to the underlying data type.
|
| 42 |
+
|
| 43 |
+
operator Data&() {
|
| 44 |
+
return data;
|
| 45 |
+
}
|
| 46 |
+
operator const Data&() const {
|
| 47 |
+
return data;
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
Data data;
|
| 51 |
+
};
|
| 52 |
+
|
| 53 |
+
using TensorExample = Example<at::Tensor, example::NoTarget>;
|
| 54 |
+
} // namespace data
|
| 55 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/iterator.h
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/csrc/utils/variadic.h>
|
| 4 |
+
#include <torch/types.h>
|
| 5 |
+
|
| 6 |
+
#include <c10/util/Exception.h>
|
| 7 |
+
|
| 8 |
+
#include <functional>
|
| 9 |
+
#include <iterator>
|
| 10 |
+
#include <memory>
|
| 11 |
+
#include <type_traits>
|
| 12 |
+
#include <utility>
|
| 13 |
+
|
| 14 |
+
namespace torch {
|
| 15 |
+
namespace data {
|
| 16 |
+
namespace detail {
|
| 17 |
+
// For increased safety and more separated logic, this implementation of
|
| 18 |
+
// `Iterator` consists of a `ValidIterator` and a `SentinelIterator`. A
|
| 19 |
+
// `ValidIterator` yields new batches until the `DataLoader` is exhausted. While
|
| 20 |
+
// the `DataLoader` is not exhausted, `ValidIterator`s compare equal if they are
|
| 21 |
+
// the same object. When the `ValidIterator` becomes exhausted, it compares
|
| 22 |
+
// equal to the `SentinelIterator`, but not before. Half the code here is to
|
| 23 |
+
// implement double dispatch for the comparison. Got damnit, C++.
|
| 24 |
+
|
| 25 |
+
template <typename Batch>
|
| 26 |
+
struct ValidIterator;
|
| 27 |
+
|
| 28 |
+
template <typename Batch>
|
| 29 |
+
struct SentinelIterator;
|
| 30 |
+
|
| 31 |
+
/// Base class for the `ValidIterator` and `SentinelIterator`
|
| 32 |
+
template <typename Batch>
|
| 33 |
+
struct IteratorImpl {
|
| 34 |
+
virtual ~IteratorImpl() = default;
|
| 35 |
+
virtual void next() = 0;
|
| 36 |
+
virtual Batch& get() = 0;
|
| 37 |
+
virtual bool operator==(const IteratorImpl& other) const = 0;
|
| 38 |
+
virtual bool operator==(const ValidIterator<Batch>& other) const = 0;
|
| 39 |
+
virtual bool operator==(const SentinelIterator<Batch>& other) const = 0;
|
| 40 |
+
};
|
| 41 |
+
|
| 42 |
+
template <typename Batch>
|
| 43 |
+
struct ValidIterator : public IteratorImpl<Batch> {
|
| 44 |
+
using BatchProducer = std::function<std::optional<Batch>()>;
|
| 45 |
+
|
| 46 |
+
explicit ValidIterator(BatchProducer next_batch)
|
| 47 |
+
: next_batch_(std::move(next_batch)) {}
|
| 48 |
+
|
| 49 |
+
/// Fetches the next batch.
|
| 50 |
+
void next() override {
|
| 51 |
+
// If we didn't get the very first batch yet, get it now.
|
| 52 |
+
lazy_initialize();
|
| 53 |
+
TORCH_CHECK(
|
| 54 |
+
batch_.has_value(), "Attempted to increment iterator past the end");
|
| 55 |
+
// Increment to the next batch.
|
| 56 |
+
batch_ = next_batch_();
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
/// Returns the current batch. The precondition for this operation to not
|
| 60 |
+
/// throw an exception is that it has been compared to the `SentinelIterator`
|
| 61 |
+
/// and did not compare equal.
|
| 62 |
+
Batch& get() override {
|
| 63 |
+
// If we didn't get the very first batch yet, get it now.
|
| 64 |
+
lazy_initialize();
|
| 65 |
+
TORCH_CHECK(
|
| 66 |
+
batch_.has_value(),
|
| 67 |
+
"Attempted to dereference iterator that was past the end");
|
| 68 |
+
return batch_.value();
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
/// Does double dispatch.
|
| 72 |
+
bool operator==(const IteratorImpl<Batch>& other) const override {
|
| 73 |
+
return other == *this;
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
/// A `ValidIterator` is equal to the `SentinelIterator` iff. the
|
| 77 |
+
/// `ValidIterator` has reached the end of the dataloader.
|
| 78 |
+
bool operator==(const SentinelIterator<Batch>& /* unused */) const override {
|
| 79 |
+
lazy_initialize();
|
| 80 |
+
return !batch_;
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
/// Returns true if the memory address of `other` equals that of `this`.
|
| 84 |
+
bool operator==(const ValidIterator<Batch>& other) const override {
|
| 85 |
+
return &other == this;
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
/// Gets the very first batch if it has not yet been fetched.
|
| 89 |
+
void lazy_initialize() const {
|
| 90 |
+
if (!initialized_) {
|
| 91 |
+
batch_ = next_batch_();
|
| 92 |
+
initialized_ = true;
|
| 93 |
+
}
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
BatchProducer next_batch_;
|
| 97 |
+
mutable std::optional<Batch> batch_;
|
| 98 |
+
mutable bool initialized_ = false;
|
| 99 |
+
};
|
| 100 |
+
|
| 101 |
+
template <typename Batch>
|
| 102 |
+
struct SentinelIterator : public IteratorImpl<Batch> {
|
| 103 |
+
void next() override {
|
| 104 |
+
AT_ERROR(
|
| 105 |
+
"Incrementing the DataLoader's past-the-end iterator is not allowed");
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
Batch& get() override {
|
| 109 |
+
AT_ERROR(
|
| 110 |
+
"Dereferencing the DataLoader's past-the-end iterator is not allowed");
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
/// Does double dispatch.
|
| 114 |
+
bool operator==(const IteratorImpl<Batch>& other) const override {
|
| 115 |
+
return other == *this;
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
/// Calls the comparison operator between `ValidIterator` and
|
| 119 |
+
/// `SentinelIterator`.
|
| 120 |
+
bool operator==(const ValidIterator<Batch>& other) const override {
|
| 121 |
+
return other == *this;
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
/// Sentinel iterators always compare equal.
|
| 125 |
+
bool operator==(const SentinelIterator<Batch>& other) const override {
|
| 126 |
+
return true;
|
| 127 |
+
}
|
| 128 |
+
};
|
| 129 |
+
} // namespace detail
|
| 130 |
+
|
| 131 |
+
template <typename Batch>
|
| 132 |
+
class Iterator {
|
| 133 |
+
public:
|
| 134 |
+
// Type aliases to make the class recognized as a proper iterator.
|
| 135 |
+
using difference_type = std::ptrdiff_t;
|
| 136 |
+
using value_type = Batch;
|
| 137 |
+
using pointer = Batch*;
|
| 138 |
+
using reference = Batch&;
|
| 139 |
+
using iterator_category = std::input_iterator_tag;
|
| 140 |
+
|
| 141 |
+
explicit Iterator(std::unique_ptr<detail::IteratorImpl<Batch>> impl)
|
| 142 |
+
: impl_(std::move(impl)) {}
|
| 143 |
+
|
| 144 |
+
/// Increments the iterator.
|
| 145 |
+
/// Only permitted for valid iterators (not past the end).
|
| 146 |
+
Iterator& operator++() {
|
| 147 |
+
impl_->next();
|
| 148 |
+
return *this;
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
/// Returns the current batch.
|
| 152 |
+
/// Only permitted for valid iterators (not past the end).
|
| 153 |
+
Batch& operator*() {
|
| 154 |
+
return impl_->get();
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
/// Returns a pointer to the current batch.
|
| 158 |
+
/// Only permitted for valid iterators (not past the end).
|
| 159 |
+
Batch* operator->() {
|
| 160 |
+
return &impl_->get();
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
/// Compares two iterators for equality.
|
| 164 |
+
bool operator==(const Iterator& other) const {
|
| 165 |
+
return *impl_ == *other.impl_;
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
/// Compares two iterators for inequality.
|
| 169 |
+
bool operator!=(const Iterator& other) const {
|
| 170 |
+
return !(*this == other);
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
private:
|
| 174 |
+
/// Points either to a `ValidIterator` or to a `SentinelIterator`.
|
| 175 |
+
std::shared_ptr<detail::IteratorImpl<Batch>> impl_;
|
| 176 |
+
};
|
| 177 |
+
} // namespace data
|
| 178 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers.h
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/data/samplers/base.h>
|
| 4 |
+
#include <torch/data/samplers/custom_batch_request.h>
|
| 5 |
+
#include <torch/data/samplers/distributed.h>
|
| 6 |
+
#include <torch/data/samplers/random.h>
|
| 7 |
+
#include <torch/data/samplers/sequential.h>
|
| 8 |
+
#include <torch/data/samplers/serialize.h>
|
| 9 |
+
#include <torch/data/samplers/stream.h>
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/base.h
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/csrc/Export.h>
|
| 4 |
+
#include <torch/types.h>
|
| 5 |
+
|
| 6 |
+
#include <cstddef>
|
| 7 |
+
#include <mutex>
|
| 8 |
+
#include <vector>
|
| 9 |
+
|
| 10 |
+
namespace torch {
|
| 11 |
+
namespace serialize {
|
| 12 |
+
class OutputArchive;
|
| 13 |
+
class InputArchive;
|
| 14 |
+
} // namespace serialize
|
| 15 |
+
} // namespace torch
|
| 16 |
+
|
| 17 |
+
namespace torch {
|
| 18 |
+
namespace data {
|
| 19 |
+
namespace samplers {
|
| 20 |
+
/// A `Sampler` is an object that yields an index with which to access a
|
| 21 |
+
/// dataset.
|
| 22 |
+
template <typename BatchRequest = std::vector<size_t>>
|
| 23 |
+
class Sampler {
|
| 24 |
+
public:
|
| 25 |
+
using BatchRequestType = BatchRequest;
|
| 26 |
+
|
| 27 |
+
virtual ~Sampler() = default;
|
| 28 |
+
|
| 29 |
+
/// Resets the `Sampler`'s internal state.
|
| 30 |
+
/// Typically called before a new epoch.
|
| 31 |
+
/// Optionally, accepts a new size when reseting the sampler.
|
| 32 |
+
virtual void reset(std::optional<size_t> new_size) = 0;
|
| 33 |
+
|
| 34 |
+
/// Returns the next index if possible, or an empty optional if the
|
| 35 |
+
/// sampler is exhausted for this epoch.
|
| 36 |
+
virtual std::optional<BatchRequest> next(size_t batch_size) = 0;
|
| 37 |
+
|
| 38 |
+
/// Serializes the `Sampler` to the `archive`.
|
| 39 |
+
virtual void save(serialize::OutputArchive& archive) const = 0;
|
| 40 |
+
|
| 41 |
+
/// Deserializes the `Sampler` from the `archive`.
|
| 42 |
+
virtual void load(serialize::InputArchive& archive) = 0;
|
| 43 |
+
};
|
| 44 |
+
|
| 45 |
+
} // namespace samplers
|
| 46 |
+
} // namespace data
|
| 47 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/custom_batch_request.h
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/csrc/Export.h>
|
| 4 |
+
#include <cstddef>
|
| 5 |
+
|
| 6 |
+
namespace torch {
|
| 7 |
+
namespace data {
|
| 8 |
+
namespace samplers {
|
| 9 |
+
/// A base class for custom index types.
|
| 10 |
+
struct TORCH_API CustomBatchRequest {
|
| 11 |
+
CustomBatchRequest() = default;
|
| 12 |
+
CustomBatchRequest(const CustomBatchRequest&) = default;
|
| 13 |
+
CustomBatchRequest(CustomBatchRequest&&) noexcept = default;
|
| 14 |
+
virtual ~CustomBatchRequest() = default;
|
| 15 |
+
|
| 16 |
+
/// The number of elements accessed by this index.
|
| 17 |
+
virtual size_t size() const = 0;
|
| 18 |
+
};
|
| 19 |
+
} // namespace samplers
|
| 20 |
+
} // namespace data
|
| 21 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/distributed.h
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/csrc/Export.h>
|
| 4 |
+
#include <torch/data/samplers/base.h>
|
| 5 |
+
|
| 6 |
+
#include <cstddef>
|
| 7 |
+
#include <vector>
|
| 8 |
+
|
| 9 |
+
namespace torch {
|
| 10 |
+
namespace serialize {
|
| 11 |
+
class OutputArchive;
|
| 12 |
+
class InputArchive;
|
| 13 |
+
} // namespace serialize
|
| 14 |
+
} // namespace torch
|
| 15 |
+
|
| 16 |
+
namespace torch {
|
| 17 |
+
namespace data {
|
| 18 |
+
namespace samplers {
|
| 19 |
+
|
| 20 |
+
/// A `Sampler` that selects a subset of indices to sample from and defines a
|
| 21 |
+
/// sampling behavior. In a distributed setting, this selects a subset of the
|
| 22 |
+
/// indices depending on the provided num_replicas and rank parameters. The
|
| 23 |
+
/// `Sampler` performs a rounding operation based on the `allow_duplicates`
|
| 24 |
+
/// parameter to decide the local sample count.
|
| 25 |
+
template <typename BatchRequest = std::vector<size_t>>
|
| 26 |
+
class DistributedSampler : public Sampler<BatchRequest> {
|
| 27 |
+
public:
|
| 28 |
+
DistributedSampler(
|
| 29 |
+
size_t size,
|
| 30 |
+
size_t num_replicas = 1,
|
| 31 |
+
size_t rank = 0,
|
| 32 |
+
bool allow_duplicates = true)
|
| 33 |
+
: size_(size),
|
| 34 |
+
num_replicas_(num_replicas),
|
| 35 |
+
rank_(rank),
|
| 36 |
+
epoch_(0),
|
| 37 |
+
allow_duplicates_(allow_duplicates) {}
|
| 38 |
+
|
| 39 |
+
/// Set the epoch for the current enumeration. This can be used to alter the
|
| 40 |
+
/// sample selection and shuffling behavior.
|
| 41 |
+
void set_epoch(size_t epoch) {
|
| 42 |
+
epoch_ = epoch;
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
size_t epoch() const {
|
| 46 |
+
return epoch_;
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
protected:
|
| 50 |
+
size_t local_sample_count() {
|
| 51 |
+
if (allow_duplicates_) {
|
| 52 |
+
return (size_ + num_replicas_ - 1) / num_replicas_;
|
| 53 |
+
} else {
|
| 54 |
+
return size_ / num_replicas_;
|
| 55 |
+
}
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
| 59 |
+
size_t size_;
|
| 60 |
+
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
| 61 |
+
size_t num_replicas_;
|
| 62 |
+
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
| 63 |
+
size_t rank_;
|
| 64 |
+
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
| 65 |
+
size_t epoch_;
|
| 66 |
+
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
| 67 |
+
bool allow_duplicates_;
|
| 68 |
+
};
|
| 69 |
+
|
| 70 |
+
/// Select samples randomly. The sampling order is shuffled at each `reset()`
|
| 71 |
+
/// call.
|
| 72 |
+
class TORCH_API DistributedRandomSampler : public DistributedSampler<> {
|
| 73 |
+
public:
|
| 74 |
+
DistributedRandomSampler(
|
| 75 |
+
size_t size,
|
| 76 |
+
size_t num_replicas = 1,
|
| 77 |
+
size_t rank = 0,
|
| 78 |
+
bool allow_duplicates = true);
|
| 79 |
+
|
| 80 |
+
/// Resets the `DistributedRandomSampler` to a new set of indices.
|
| 81 |
+
void reset(std::optional<size_t> new_size = std::nullopt) override;
|
| 82 |
+
|
| 83 |
+
/// Returns the next batch of indices.
|
| 84 |
+
std::optional<std::vector<size_t>> next(size_t batch_size) override;
|
| 85 |
+
|
| 86 |
+
/// Serializes the `DistributedRandomSampler` to the `archive`.
|
| 87 |
+
void save(serialize::OutputArchive& archive) const override;
|
| 88 |
+
|
| 89 |
+
/// Deserializes the `DistributedRandomSampler` from the `archive`.
|
| 90 |
+
void load(serialize::InputArchive& archive) override;
|
| 91 |
+
|
| 92 |
+
/// Returns the current index of the `DistributedRandomSampler`.
|
| 93 |
+
size_t index() const noexcept;
|
| 94 |
+
|
| 95 |
+
private:
|
| 96 |
+
void populate_indices();
|
| 97 |
+
|
| 98 |
+
size_t begin_index_;
|
| 99 |
+
size_t end_index_;
|
| 100 |
+
size_t sample_index_;
|
| 101 |
+
std::vector<size_t> all_indices_;
|
| 102 |
+
};
|
| 103 |
+
|
| 104 |
+
/// Select samples sequentially.
|
| 105 |
+
class TORCH_API DistributedSequentialSampler : public DistributedSampler<> {
|
| 106 |
+
public:
|
| 107 |
+
DistributedSequentialSampler(
|
| 108 |
+
size_t size,
|
| 109 |
+
size_t num_replicas = 1,
|
| 110 |
+
size_t rank = 0,
|
| 111 |
+
bool allow_duplicates = true);
|
| 112 |
+
|
| 113 |
+
/// Resets the `DistributedSequentialSampler` to a new set of indices.
|
| 114 |
+
void reset(std::optional<size_t> new_size = std::nullopt) override;
|
| 115 |
+
|
| 116 |
+
/// Returns the next batch of indices.
|
| 117 |
+
std::optional<std::vector<size_t>> next(size_t batch_size) override;
|
| 118 |
+
|
| 119 |
+
/// Serializes the `DistributedSequentialSampler` to the `archive`.
|
| 120 |
+
void save(serialize::OutputArchive& archive) const override;
|
| 121 |
+
|
| 122 |
+
/// Deserializes the `DistributedSequentialSampler` from the `archive`.
|
| 123 |
+
void load(serialize::InputArchive& archive) override;
|
| 124 |
+
|
| 125 |
+
/// Returns the current index of the `DistributedSequentialSampler`.
|
| 126 |
+
size_t index() const noexcept;
|
| 127 |
+
|
| 128 |
+
private:
|
| 129 |
+
void populate_indices();
|
| 130 |
+
|
| 131 |
+
size_t begin_index_;
|
| 132 |
+
size_t end_index_;
|
| 133 |
+
size_t sample_index_;
|
| 134 |
+
std::vector<size_t> all_indices_;
|
| 135 |
+
};
|
| 136 |
+
|
| 137 |
+
} // namespace samplers
|
| 138 |
+
} // namespace data
|
| 139 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/random.h
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/csrc/Export.h>
|
| 4 |
+
#include <torch/data/samplers/base.h>
|
| 5 |
+
#include <torch/types.h>
|
| 6 |
+
|
| 7 |
+
#include <cstddef>
|
| 8 |
+
#include <vector>
|
| 9 |
+
|
| 10 |
+
namespace torch {
|
| 11 |
+
namespace serialize {
|
| 12 |
+
class OutputArchive;
|
| 13 |
+
class InputArchive;
|
| 14 |
+
} // namespace serialize
|
| 15 |
+
} // namespace torch
|
| 16 |
+
|
| 17 |
+
namespace torch {
|
| 18 |
+
namespace data {
|
| 19 |
+
namespace samplers {
|
| 20 |
+
|
| 21 |
+
/// A `Sampler` that returns random indices.
|
| 22 |
+
class TORCH_API RandomSampler : public Sampler<> {
|
| 23 |
+
public:
|
| 24 |
+
/// Constructs a `RandomSampler` with a size and dtype for the stored indices.
|
| 25 |
+
///
|
| 26 |
+
/// The constructor will eagerly allocate all required indices, which is the
|
| 27 |
+
/// sequence `0 ... size - 1`. `index_dtype` is the data type of the stored
|
| 28 |
+
/// indices. You can change it to influence memory usage.
|
| 29 |
+
explicit RandomSampler(int64_t size, Dtype index_dtype = torch::kInt64);
|
| 30 |
+
|
| 31 |
+
~RandomSampler() override;
|
| 32 |
+
|
| 33 |
+
/// Resets the `RandomSampler` to a new set of indices.
|
| 34 |
+
void reset(std::optional<size_t> new_size = std::nullopt) override;
|
| 35 |
+
|
| 36 |
+
/// Returns the next batch of indices.
|
| 37 |
+
std::optional<std::vector<size_t>> next(size_t batch_size) override;
|
| 38 |
+
|
| 39 |
+
/// Serializes the `RandomSampler` to the `archive`.
|
| 40 |
+
void save(serialize::OutputArchive& archive) const override;
|
| 41 |
+
|
| 42 |
+
/// Deserializes the `RandomSampler` from the `archive`.
|
| 43 |
+
void load(serialize::InputArchive& archive) override;
|
| 44 |
+
|
| 45 |
+
/// Returns the current index of the `RandomSampler`.
|
| 46 |
+
size_t index() const noexcept;
|
| 47 |
+
|
| 48 |
+
private:
|
| 49 |
+
at::Tensor indices_;
|
| 50 |
+
int64_t index_ = 0;
|
| 51 |
+
};
|
| 52 |
+
} // namespace samplers
|
| 53 |
+
} // namespace data
|
| 54 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/sequential.h
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/csrc/Export.h>
|
| 4 |
+
#include <torch/data/samplers/base.h>
|
| 5 |
+
#include <torch/types.h>
|
| 6 |
+
|
| 7 |
+
#include <cstddef>
|
| 8 |
+
#include <vector>
|
| 9 |
+
|
| 10 |
+
namespace torch {
|
| 11 |
+
namespace serialize {
|
| 12 |
+
class OutputArchive;
|
| 13 |
+
class InputArchive;
|
| 14 |
+
} // namespace serialize
|
| 15 |
+
} // namespace torch
|
| 16 |
+
|
| 17 |
+
namespace torch {
|
| 18 |
+
namespace data {
|
| 19 |
+
namespace samplers {
|
| 20 |
+
|
| 21 |
+
/// A `Sampler` that returns indices sequentially.
|
| 22 |
+
class TORCH_API SequentialSampler : public Sampler<> {
|
| 23 |
+
public:
|
| 24 |
+
/// Creates a `SequentialSampler` that will return indices in the range
|
| 25 |
+
/// `0...size - 1`.
|
| 26 |
+
explicit SequentialSampler(size_t size);
|
| 27 |
+
|
| 28 |
+
/// Resets the `SequentialSampler` to zero.
|
| 29 |
+
void reset(std::optional<size_t> new_size = std::nullopt) override;
|
| 30 |
+
|
| 31 |
+
/// Returns the next batch of indices.
|
| 32 |
+
std::optional<std::vector<size_t>> next(size_t batch_size) override;
|
| 33 |
+
|
| 34 |
+
/// Serializes the `SequentialSampler` to the `archive`.
|
| 35 |
+
void save(serialize::OutputArchive& archive) const override;
|
| 36 |
+
|
| 37 |
+
/// Deserializes the `SequentialSampler` from the `archive`.
|
| 38 |
+
void load(serialize::InputArchive& archive) override;
|
| 39 |
+
|
| 40 |
+
/// Returns the current index of the `SequentialSampler`.
|
| 41 |
+
size_t index() const noexcept;
|
| 42 |
+
|
| 43 |
+
private:
|
| 44 |
+
size_t size_;
|
| 45 |
+
size_t index_{0};
|
| 46 |
+
};
|
| 47 |
+
|
| 48 |
+
} // namespace samplers
|
| 49 |
+
} // namespace data
|
| 50 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/serialize.h
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/data/samplers/base.h>
|
| 4 |
+
#include <torch/serialize/archive.h>
|
| 5 |
+
|
| 6 |
+
namespace torch {
|
| 7 |
+
namespace data {
|
| 8 |
+
namespace samplers {
|
| 9 |
+
/// Serializes a `Sampler` into an `OutputArchive`.
|
| 10 |
+
template <typename BatchRequest>
|
| 11 |
+
serialize::OutputArchive& operator<<(
|
| 12 |
+
serialize::OutputArchive& archive,
|
| 13 |
+
const Sampler<BatchRequest>& sampler) {
|
| 14 |
+
sampler.save(archive);
|
| 15 |
+
return archive;
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
/// Deserializes a `Sampler` from an `InputArchive`.
|
| 19 |
+
template <typename BatchRequest>
|
| 20 |
+
serialize::InputArchive& operator>>(
|
| 21 |
+
serialize::InputArchive& archive,
|
| 22 |
+
Sampler<BatchRequest>& sampler) {
|
| 23 |
+
sampler.load(archive);
|
| 24 |
+
return archive;
|
| 25 |
+
}
|
| 26 |
+
} // namespace samplers
|
| 27 |
+
} // namespace data
|
| 28 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/stream.h
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/csrc/Export.h>
|
| 4 |
+
#include <torch/data/samplers/base.h>
|
| 5 |
+
#include <torch/data/samplers/custom_batch_request.h>
|
| 6 |
+
#include <torch/types.h>
|
| 7 |
+
|
| 8 |
+
#include <cstddef>
|
| 9 |
+
|
| 10 |
+
namespace torch {
|
| 11 |
+
namespace serialize {
|
| 12 |
+
class InputArchive;
|
| 13 |
+
class OutputArchive;
|
| 14 |
+
} // namespace serialize
|
| 15 |
+
} // namespace torch
|
| 16 |
+
|
| 17 |
+
namespace torch {
|
| 18 |
+
namespace data {
|
| 19 |
+
namespace samplers {
|
| 20 |
+
|
| 21 |
+
/// A wrapper around a batch size value, which implements the
|
| 22 |
+
/// `CustomBatchRequest` interface.
|
| 23 |
+
struct TORCH_API BatchSize : public CustomBatchRequest {
|
| 24 |
+
explicit BatchSize(size_t size);
|
| 25 |
+
size_t size() const noexcept override;
|
| 26 |
+
operator size_t() const noexcept;
|
| 27 |
+
size_t size_;
|
| 28 |
+
};
|
| 29 |
+
|
| 30 |
+
/// A sampler for (potentially infinite) streams of data.
|
| 31 |
+
///
|
| 32 |
+
/// The major feature of the `StreamSampler` is that it does not return
|
| 33 |
+
/// particular indices, but instead only the number of elements to fetch from
|
| 34 |
+
/// the dataset. The dataset has to decide how to produce those elements.
|
| 35 |
+
class TORCH_API StreamSampler : public Sampler<BatchSize> {
|
| 36 |
+
public:
|
| 37 |
+
/// Constructs the `StreamSampler` with the number of individual examples that
|
| 38 |
+
/// should be fetched until the sampler is exhausted.
|
| 39 |
+
explicit StreamSampler(size_t epoch_size);
|
| 40 |
+
|
| 41 |
+
/// Resets the internal state of the sampler.
|
| 42 |
+
void reset(std::optional<size_t> new_size = std::nullopt) override;
|
| 43 |
+
|
| 44 |
+
/// Returns a `BatchSize` object with the number of elements to fetch in the
|
| 45 |
+
/// next batch. This number is the minimum of the supplied `batch_size` and
|
| 46 |
+
/// the difference between the `epoch_size` and the current index. If the
|
| 47 |
+
/// `epoch_size` has been reached, returns an empty optional.
|
| 48 |
+
std::optional<BatchSize> next(size_t batch_size) override;
|
| 49 |
+
|
| 50 |
+
/// Serializes the `StreamSampler` to the `archive`.
|
| 51 |
+
void save(serialize::OutputArchive& archive) const override;
|
| 52 |
+
|
| 53 |
+
/// Deserializes the `StreamSampler` from the `archive`.
|
| 54 |
+
void load(serialize::InputArchive& archive) override;
|
| 55 |
+
|
| 56 |
+
private:
|
| 57 |
+
size_t examples_retrieved_so_far_ = 0;
|
| 58 |
+
size_t epoch_size_;
|
| 59 |
+
};
|
| 60 |
+
|
| 61 |
+
} // namespace samplers
|
| 62 |
+
} // namespace data
|
| 63 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms.h
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/data/transforms/base.h>
|
| 4 |
+
#include <torch/data/transforms/collate.h>
|
| 5 |
+
#include <torch/data/transforms/lambda.h>
|
| 6 |
+
#include <torch/data/transforms/stack.h>
|
| 7 |
+
#include <torch/data/transforms/tensor.h>
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/base.h
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/types.h>
|
| 4 |
+
|
| 5 |
+
#include <utility>
|
| 6 |
+
#include <vector>
|
| 7 |
+
|
| 8 |
+
namespace torch {
|
| 9 |
+
namespace data {
|
| 10 |
+
namespace transforms {
|
| 11 |
+
|
| 12 |
+
/// A transformation of a batch to a new batch.
|
| 13 |
+
template <typename InputBatch, typename OutputBatch>
|
| 14 |
+
class BatchTransform {
|
| 15 |
+
public:
|
| 16 |
+
using InputBatchType = InputBatch;
|
| 17 |
+
using OutputBatchType = OutputBatch;
|
| 18 |
+
|
| 19 |
+
virtual ~BatchTransform() = default;
|
| 20 |
+
|
| 21 |
+
/// Applies the transformation to the given `input_batch`.
|
| 22 |
+
virtual OutputBatch apply_batch(InputBatch input_batch) = 0;
|
| 23 |
+
};
|
| 24 |
+
|
| 25 |
+
/// A transformation of individual input examples to individual output examples.
|
| 26 |
+
///
|
| 27 |
+
/// Just like a `Dataset` is a `BatchDataset`, a `Transform` is a
|
| 28 |
+
/// `BatchTransform` that can operate on the level of individual examples rather
|
| 29 |
+
/// than entire batches. The batch-level transform is implemented (by default)
|
| 30 |
+
/// in terms of the example-level transform, though this can be customized.
|
| 31 |
+
template <typename Input, typename Output>
|
| 32 |
+
class Transform
|
| 33 |
+
: public BatchTransform<std::vector<Input>, std::vector<Output>> {
|
| 34 |
+
public:
|
| 35 |
+
using InputType = Input;
|
| 36 |
+
using OutputType = Output;
|
| 37 |
+
|
| 38 |
+
/// Applies the transformation to the given `input`.
|
| 39 |
+
virtual OutputType apply(InputType input) = 0;
|
| 40 |
+
|
| 41 |
+
/// Applies the `transformation` over the entire `input_batch`.
|
| 42 |
+
std::vector<Output> apply_batch(std::vector<Input> input_batch) override {
|
| 43 |
+
std::vector<Output> output_batch;
|
| 44 |
+
output_batch.reserve(input_batch.size());
|
| 45 |
+
for (auto&& input : input_batch) {
|
| 46 |
+
output_batch.push_back(apply(std::move(input)));
|
| 47 |
+
}
|
| 48 |
+
return output_batch;
|
| 49 |
+
}
|
| 50 |
+
};
|
| 51 |
+
} // namespace transforms
|
| 52 |
+
} // namespace data
|
| 53 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/collate.h
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/data/example.h>
|
| 4 |
+
#include <torch/data/transforms/lambda.h>
|
| 5 |
+
|
| 6 |
+
#include <vector>
|
| 7 |
+
|
| 8 |
+
namespace torch {
|
| 9 |
+
namespace data {
|
| 10 |
+
namespace transforms {
|
| 11 |
+
|
| 12 |
+
/// A `Collation` is a transform that reduces a batch into a single value.
|
| 13 |
+
/// The result is a `BatchDataset` that has the type of the single value as its
|
| 14 |
+
/// `BatchType`.
|
| 15 |
+
template <typename T, typename BatchType = std::vector<T>>
|
| 16 |
+
using Collation = BatchTransform<BatchType, T>;
|
| 17 |
+
|
| 18 |
+
/// A `Collate` allows passing a custom function to reduce/collate a batch
|
| 19 |
+
/// into a single value. It's effectively the lambda version of `Collation`,
|
| 20 |
+
/// which you could subclass and override `operator()` to achieve the same.
|
| 21 |
+
///
|
| 22 |
+
/// \rst
|
| 23 |
+
/// .. code-block:: cpp
|
| 24 |
+
/// using namespace torch::data;
|
| 25 |
+
///
|
| 26 |
+
/// auto dataset = datasets::MNIST("path/to/mnist")
|
| 27 |
+
/// .map(transforms::Collate<Example<>>([](std::vector<Example<>> e) {
|
| 28 |
+
/// return std::move(e.front());
|
| 29 |
+
/// }));
|
| 30 |
+
/// \endrst
|
| 31 |
+
template <typename T, typename BatchType = std::vector<T>>
|
| 32 |
+
using Collate = BatchLambda<BatchType, T>;
|
| 33 |
+
} // namespace transforms
|
| 34 |
+
} // namespace data
|
| 35 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/lambda.h
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/data/transforms/base.h>
|
| 4 |
+
|
| 5 |
+
#include <functional>
|
| 6 |
+
#include <utility>
|
| 7 |
+
#include <vector>
|
| 8 |
+
|
| 9 |
+
namespace torch {
|
| 10 |
+
namespace data {
|
| 11 |
+
namespace transforms {
|
| 12 |
+
|
| 13 |
+
/// A `BatchTransform` that applies a user-provided functor to a batch.
|
| 14 |
+
template <typename Input, typename Output = Input>
|
| 15 |
+
class BatchLambda : public BatchTransform<Input, Output> {
|
| 16 |
+
public:
|
| 17 |
+
using typename BatchTransform<Input, Output>::InputBatchType;
|
| 18 |
+
using typename BatchTransform<Input, Output>::OutputBatchType;
|
| 19 |
+
using FunctionType = std::function<OutputBatchType(InputBatchType)>;
|
| 20 |
+
|
| 21 |
+
/// Constructs the `BatchLambda` from the given `function` object.
|
| 22 |
+
explicit BatchLambda(FunctionType function)
|
| 23 |
+
: function_(std::move(function)) {}
|
| 24 |
+
|
| 25 |
+
/// Applies the user-provided function object to the `input_batch`.
|
| 26 |
+
OutputBatchType apply_batch(InputBatchType input_batch) override {
|
| 27 |
+
return function_(std::move(input_batch));
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
private:
|
| 31 |
+
FunctionType function_;
|
| 32 |
+
};
|
| 33 |
+
|
| 34 |
+
// A `Transform` that applies a user-provided functor to individual examples.
|
| 35 |
+
template <typename Input, typename Output = Input>
|
| 36 |
+
class Lambda : public Transform<Input, Output> {
|
| 37 |
+
public:
|
| 38 |
+
using typename Transform<Input, Output>::InputType;
|
| 39 |
+
using typename Transform<Input, Output>::OutputType;
|
| 40 |
+
using FunctionType = std::function<Output(Input)>;
|
| 41 |
+
|
| 42 |
+
/// Constructs the `Lambda` from the given `function` object.
|
| 43 |
+
explicit Lambda(FunctionType function) : function_(std::move(function)) {}
|
| 44 |
+
|
| 45 |
+
/// Applies the user-provided function object to the `input`.
|
| 46 |
+
OutputType apply(InputType input) override {
|
| 47 |
+
return function_(std::move(input));
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
private:
|
| 51 |
+
FunctionType function_;
|
| 52 |
+
};
|
| 53 |
+
|
| 54 |
+
} // namespace transforms
|
| 55 |
+
} // namespace data
|
| 56 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/stack.h
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/data/example.h>
|
| 4 |
+
#include <torch/data/transforms/collate.h>
|
| 5 |
+
#include <torch/types.h>
|
| 6 |
+
|
| 7 |
+
#include <utility>
|
| 8 |
+
#include <vector>
|
| 9 |
+
|
| 10 |
+
namespace torch {
|
| 11 |
+
namespace data {
|
| 12 |
+
namespace transforms {
|
| 13 |
+
|
| 14 |
+
template <typename T = Example<>>
|
| 15 |
+
struct Stack;
|
| 16 |
+
|
| 17 |
+
/// A `Collation` for `Example<Tensor, Tensor>` types that stacks all data
|
| 18 |
+
/// tensors into one tensor, and all target (label) tensors into one tensor.
|
| 19 |
+
template <>
|
| 20 |
+
struct Stack<Example<>> : public Collation<Example<>> {
|
| 21 |
+
Example<> apply_batch(std::vector<Example<>> examples) override {
|
| 22 |
+
std::vector<torch::Tensor> data, targets;
|
| 23 |
+
data.reserve(examples.size());
|
| 24 |
+
targets.reserve(examples.size());
|
| 25 |
+
for (auto& example : examples) {
|
| 26 |
+
data.push_back(std::move(example.data));
|
| 27 |
+
targets.push_back(std::move(example.target));
|
| 28 |
+
}
|
| 29 |
+
return {torch::stack(data), torch::stack(targets)};
|
| 30 |
+
}
|
| 31 |
+
};
|
| 32 |
+
|
| 33 |
+
/// A `Collation` for `Example<Tensor, NoTarget>` types that stacks all data
|
| 34 |
+
/// tensors into one tensor.
|
| 35 |
+
template <>
|
| 36 |
+
struct Stack<TensorExample>
|
| 37 |
+
: public Collation<Example<Tensor, example::NoTarget>> {
|
| 38 |
+
TensorExample apply_batch(std::vector<TensorExample> examples) override {
|
| 39 |
+
std::vector<torch::Tensor> data;
|
| 40 |
+
data.reserve(examples.size());
|
| 41 |
+
for (auto& example : examples) {
|
| 42 |
+
data.push_back(std::move(example.data));
|
| 43 |
+
}
|
| 44 |
+
return torch::stack(data);
|
| 45 |
+
}
|
| 46 |
+
};
|
| 47 |
+
} // namespace transforms
|
| 48 |
+
} // namespace data
|
| 49 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/tensor.h
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/data/example.h>
|
| 4 |
+
#include <torch/data/transforms/base.h>
|
| 5 |
+
#include <torch/types.h>
|
| 6 |
+
|
| 7 |
+
#include <functional>
|
| 8 |
+
#include <utility>
|
| 9 |
+
|
| 10 |
+
namespace torch {
|
| 11 |
+
namespace data {
|
| 12 |
+
namespace transforms {
|
| 13 |
+
|
| 14 |
+
/// A `Transform` that is specialized for the typical `Example<Tensor, Tensor>`
|
| 15 |
+
/// combination. It exposes a single `operator()` interface hook (for
|
| 16 |
+
/// subclasses), and calls this function on input `Example` objects.
|
| 17 |
+
template <typename Target = Tensor>
|
| 18 |
+
class TensorTransform
|
| 19 |
+
: public Transform<Example<Tensor, Target>, Example<Tensor, Target>> {
|
| 20 |
+
public:
|
| 21 |
+
using E = Example<Tensor, Target>;
|
| 22 |
+
using typename Transform<E, E>::InputType;
|
| 23 |
+
using typename Transform<E, E>::OutputType;
|
| 24 |
+
|
| 25 |
+
/// Transforms a single input tensor to an output tensor.
|
| 26 |
+
virtual Tensor operator()(Tensor input) = 0;
|
| 27 |
+
|
| 28 |
+
/// Implementation of `Transform::apply` that calls `operator()`.
|
| 29 |
+
OutputType apply(InputType input) override {
|
| 30 |
+
input.data = (*this)(std::move(input.data));
|
| 31 |
+
return input;
|
| 32 |
+
}
|
| 33 |
+
};
|
| 34 |
+
|
| 35 |
+
/// A `Lambda` specialized for the typical `Example<Tensor, Tensor>` input type.
|
| 36 |
+
template <typename Target = Tensor>
|
| 37 |
+
class TensorLambda : public TensorTransform<Target> {
|
| 38 |
+
public:
|
| 39 |
+
using FunctionType = std::function<Tensor(Tensor)>;
|
| 40 |
+
|
| 41 |
+
/// Creates a `TensorLambda` from the given `function`.
|
| 42 |
+
explicit TensorLambda(FunctionType function)
|
| 43 |
+
: function_(std::move(function)) {}
|
| 44 |
+
|
| 45 |
+
/// Applies the user-provided functor to the input tensor.
|
| 46 |
+
Tensor operator()(Tensor input) override {
|
| 47 |
+
return function_(std::move(input));
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
private:
|
| 51 |
+
FunctionType function_;
|
| 52 |
+
};
|
| 53 |
+
|
| 54 |
+
/// Normalizes input tensors by subtracting the supplied mean and dividing by
|
| 55 |
+
/// the given standard deviation.
|
| 56 |
+
template <typename Target = Tensor>
|
| 57 |
+
struct Normalize : public TensorTransform<Target> {
|
| 58 |
+
/// Constructs a `Normalize` transform. The mean and standard deviation can be
|
| 59 |
+
/// anything that is broadcastable over the input tensors (like single
|
| 60 |
+
/// scalars).
|
| 61 |
+
Normalize(ArrayRef<double> mean, ArrayRef<double> stddev)
|
| 62 |
+
: mean(torch::tensor(mean, torch::kFloat32)
|
| 63 |
+
.unsqueeze(/*dim=*/1)
|
| 64 |
+
.unsqueeze(/*dim=*/2)),
|
| 65 |
+
stddev(torch::tensor(stddev, torch::kFloat32)
|
| 66 |
+
.unsqueeze(/*dim=*/1)
|
| 67 |
+
.unsqueeze(/*dim=*/2)) {}
|
| 68 |
+
|
| 69 |
+
torch::Tensor operator()(Tensor input) override {
|
| 70 |
+
return input.sub(mean).div(stddev);
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
torch::Tensor mean, stddev;
|
| 74 |
+
};
|
| 75 |
+
} // namespace transforms
|
| 76 |
+
} // namespace data
|
| 77 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/worker_exception.h
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <exception>
|
| 4 |
+
#include <string>
|
| 5 |
+
#include <utility>
|
| 6 |
+
|
| 7 |
+
namespace torch {
|
| 8 |
+
namespace data {
|
| 9 |
+
|
| 10 |
+
/// An exception thrown when a DataLoader's worker thread throws an exception,
|
| 11 |
+
/// which is caught. A `WorkerException` stores an `exception_ptr` to the
|
| 12 |
+
/// original exception thrown in the worker thread.
|
| 13 |
+
struct WorkerException : public std::exception {
|
| 14 |
+
/// Constructs a `WorkerException` from an `exception_ptr`.
|
| 15 |
+
explicit WorkerException(std::exception_ptr original)
|
| 16 |
+
: original_exception(std::move(original)),
|
| 17 |
+
message("Caught exception in DataLoader worker thread.") {
|
| 18 |
+
try {
|
| 19 |
+
std::rethrow_exception(original_exception);
|
| 20 |
+
} catch (std::exception& e) {
|
| 21 |
+
message += " Original message: ";
|
| 22 |
+
message += e.what();
|
| 23 |
+
}
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
const char* what() const noexcept override {
|
| 27 |
+
return message.c_str();
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
/// The original exception thrown in the worker thread.
|
| 31 |
+
std::exception_ptr original_exception;
|
| 32 |
+
|
| 33 |
+
/// This exception's message (not the original exception's message).
|
| 34 |
+
std::string message;
|
| 35 |
+
};
|
| 36 |
+
|
| 37 |
+
} // namespace data
|
| 38 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/detail/TensorDataContainer.h
ADDED
|
@@ -0,0 +1,363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/Dispatch.h>
|
| 4 |
+
#include <ATen/ScalarOps.h>
|
| 5 |
+
#include <ATen/core/Tensor.h>
|
| 6 |
+
#include <ATen/core/grad_mode.h>
|
| 7 |
+
|
| 8 |
+
#include <c10/util/irange.h>
|
| 9 |
+
|
| 10 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 11 |
+
#include <ATen/Functions.h>
|
| 12 |
+
#else
|
| 13 |
+
#include <ATen/ops/empty.h>
|
| 14 |
+
#include <ATen/ops/tensor.h>
|
| 15 |
+
#endif
|
| 16 |
+
|
| 17 |
+
#include <initializer_list>
|
| 18 |
+
|
| 19 |
+
namespace torch {
|
| 20 |
+
|
| 21 |
+
namespace detail {
|
| 22 |
+
|
| 23 |
+
enum class TensorDataContainerType { Scalar, InitList, Tensor };
|
| 24 |
+
|
| 25 |
+
struct TensorDataContainer;
|
| 26 |
+
|
| 27 |
+
inline std::ostream& operator<<(
|
| 28 |
+
std::ostream& stream,
|
| 29 |
+
const TensorDataContainer& tensor_data_container);
|
| 30 |
+
|
| 31 |
+
inline c10::ScalarType compute_desired_dtype(c10::ScalarType scalar_type) {
|
| 32 |
+
if (scalar_type == at::kInt || scalar_type == at::kLong) {
|
| 33 |
+
// C++ `torch::tensor` with an integer type or an `at::ArrayRef` /
|
| 34 |
+
// `std::vector` / (nested) braced-init-list of integer types always
|
| 35 |
+
// produces a tensor of dtype `at::kLong` (aka. int64_t), matching Python
|
| 36 |
+
// `torch.tensor` behavior.
|
| 37 |
+
return at::kLong;
|
| 38 |
+
} else if (scalar_type == at::kFloat || scalar_type == at::kDouble) {
|
| 39 |
+
// C++ `torch::tensor` with a floating-point type or an `at::ArrayRef` /
|
| 40 |
+
// `std::vector` / (nested) braced-init-list of floating-point types always
|
| 41 |
+
// produces a tensor of dtype `torch::get_default_dtype()`, matching Python
|
| 42 |
+
// `torch.tensor` behavior.
|
| 43 |
+
return at::typeMetaToScalarType(at::get_default_dtype());
|
| 44 |
+
} else {
|
| 45 |
+
return scalar_type;
|
| 46 |
+
}
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
// We use `TensorDataContainer` to support converting the following data
|
| 50 |
+
// container types into the equivalent Tensor:
|
| 51 |
+
//
|
| 52 |
+
// 1. Arbitrarily nested braced-init-list (e.g. `{{1, 2}, {3, 4}}`).
|
| 53 |
+
// 2. `at::ArrayRef` of supported tensor data types.
|
| 54 |
+
// 3. `std::vector` of supported tensor data types.
|
| 55 |
+
//
|
| 56 |
+
// At any time, a `TensorDataContainer` object represents one of the following:
|
| 57 |
+
//
|
| 58 |
+
// 1. A scalar with value `scalar()` and type `scalar_type()`.
|
| 59 |
+
// 2. A Tensor represented in `std::initializer_list<TensorDataContainer>` form,
|
| 60 |
+
// with value `init_list()`, Tensor scalar type `scalar_type()`, and Tensor
|
| 61 |
+
// sizes `sizes()`.
|
| 62 |
+
// 3. A Tensor represented in `at::Tensor` form, with value `tensor()`, scalar
|
| 63 |
+
// type `scalar_type()`,
|
| 64 |
+
// and Tensor sizes `sizes()`.
|
| 65 |
+
//
|
| 66 |
+
// All the infrastructure here is mostly to support converting an arbitrarily
|
| 67 |
+
// nested braced-init-list to the equivalent Tensor successfully. Consider the
|
| 68 |
+
// following example:
|
| 69 |
+
//
|
| 70 |
+
// `torch::tensor({{1}, {2}})`
|
| 71 |
+
//
|
| 72 |
+
// this will call into the `torch::tensor` function:
|
| 73 |
+
//
|
| 74 |
+
// `at::Tensor tensor(detail::TensorDataContainer tensor_data_container, const
|
| 75 |
+
// at::TensorOptions& options = {})`
|
| 76 |
+
//
|
| 77 |
+
// the compiler will first try to convert `{{1}, {2}}` to `TensorDataContainer`
|
| 78 |
+
// type:
|
| 79 |
+
//
|
| 80 |
+
// `TensorDataContainer({{1}, {2}})`
|
| 81 |
+
//
|
| 82 |
+
// which matches to the
|
| 83 |
+
// `TensorDataContainer(std::initializer_list<TensorDataContainer>)`
|
| 84 |
+
// constructor, and in an attempt to convert `{1}` and `{2}` to
|
| 85 |
+
// `TensorDataContainer`, it calls the following:
|
| 86 |
+
//
|
| 87 |
+
// `TensorDataContainer({1})` (same call path happens for `{2}`, and we'll just
|
| 88 |
+
// focus on `{1}` here)
|
| 89 |
+
//
|
| 90 |
+
// At this point, theoretically there are two plausible ways for `{1}` to be
|
| 91 |
+
// matched to one of the constructors of `TensorDataContainer`:
|
| 92 |
+
//
|
| 93 |
+
// 1. It can be a list-initialization of a scalar value, thus matching
|
| 94 |
+
// `TensorDataContainer(int value)`.
|
| 95 |
+
// 2. It can be converted to `std::initializer_list<TensorDataContainer>`, thus
|
| 96 |
+
// matching
|
| 97 |
+
// `TensorDataContainer(std::initializer_list<TensorDataContainer>)`.
|
| 98 |
+
//
|
| 99 |
+
// How does the compiler decide which one to choose? According to
|
| 100 |
+
// `https://en.cppreference.com/w/cpp/language/list_initialization`,
|
| 101 |
+
// braced-init-list always prefers the constructor that takes
|
| 102 |
+
// `std::initializer_list`. Hence we happily move forward with constructor #2,
|
| 103 |
+
// and it calls the following:
|
| 104 |
+
//
|
| 105 |
+
// `TensorDataContainer(1)`
|
| 106 |
+
//
|
| 107 |
+
// Now it matches `TensorDataContainer(int value)`, which stores `1` as a scalar
|
| 108 |
+
// value. All is good.
|
| 109 |
+
struct TensorDataContainer {
|
| 110 |
+
// NOTE: For tensors with zero-size dimensions (e.g. `torch::tensor({{},
|
| 111 |
+
// {}})`), the innermost empty braced-init-list `{}` matches the default
|
| 112 |
+
// constructor of the innermost `TensorDataContainer`.
|
| 113 |
+
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
| 114 |
+
TensorDataContainer()
|
| 115 |
+
: sizes_({0}),
|
| 116 |
+
// NOTE: In Python, the dtype of tensors with zero-size dimensions (e.g.
|
| 117 |
+
// `torch.tensor([[], []])`) depends on the value of
|
| 118 |
+
// `torch.get_default_dtype()`, and we should do the same for the C++
|
| 119 |
+
// equivalent.
|
| 120 |
+
scalar_type_(at::typeMetaToScalarType(at::get_default_dtype())),
|
| 121 |
+
type_(TensorDataContainerType::InitList) {}
|
| 122 |
+
#define TENSOR(T, S) \
|
| 123 |
+
TensorDataContainer(T value) \
|
| 124 |
+
: sizes_(), \
|
| 125 |
+
scalar_type_(at::k##S), \
|
| 126 |
+
type_(TensorDataContainerType::Scalar), \
|
| 127 |
+
scalar_(value) {}
|
| 128 |
+
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
| 129 |
+
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR)
|
| 130 |
+
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
| 131 |
+
AT_FORALL_COMPLEX_TYPES(TENSOR)
|
| 132 |
+
#undef TENSOR
|
| 133 |
+
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
| 134 |
+
TensorDataContainer(std::initializer_list<TensorDataContainer> init_list)
|
| 135 |
+
: sizes_(),
|
| 136 |
+
scalar_type_(init_list.begin()->scalar_type()),
|
| 137 |
+
type_(TensorDataContainerType::InitList),
|
| 138 |
+
init_list_(init_list) {
|
| 139 |
+
const TensorDataContainer& first_elem = *(init_list.begin());
|
| 140 |
+
for (const auto& elem : init_list) {
|
| 141 |
+
TORCH_CHECK(
|
| 142 |
+
elem.sizes() == first_elem.sizes(),
|
| 143 |
+
"Expected all sub-lists to have sizes: ",
|
| 144 |
+
first_elem.sizes(),
|
| 145 |
+
" (e.g. ",
|
| 146 |
+
first_elem,
|
| 147 |
+
"), ",
|
| 148 |
+
"but got sub-list ",
|
| 149 |
+
elem,
|
| 150 |
+
" with sizes: ",
|
| 151 |
+
elem.sizes());
|
| 152 |
+
TORCH_CHECK(
|
| 153 |
+
elem.scalar_type() == first_elem.scalar_type(),
|
| 154 |
+
"Expected all elements of the tensor to have the same scalar type: ",
|
| 155 |
+
first_elem.scalar_type(),
|
| 156 |
+
", but got element of scalar type: ",
|
| 157 |
+
elem.scalar_type());
|
| 158 |
+
}
|
| 159 |
+
sizes_.reserve(first_elem.sizes().size() + 1);
|
| 160 |
+
sizes_.push_back(init_list.size());
|
| 161 |
+
sizes_.insert(
|
| 162 |
+
sizes_.end(), first_elem.sizes().begin(), first_elem.sizes().end());
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
#define TENSOR(T, S) \
|
| 166 |
+
TensorDataContainer(at::ArrayRef<T> values) \
|
| 167 |
+
: sizes_({(int64_t)values.size()}), \
|
| 168 |
+
scalar_type_(at::k##S), \
|
| 169 |
+
type_(TensorDataContainerType::Tensor) { \
|
| 170 |
+
at::AutoDispatchBelowAutograd mode; \
|
| 171 |
+
if (scalar_type_ == at::kBool) { \
|
| 172 |
+
tensor_ = at::tensor(values, at::TensorOptions().device(at::kCPU)); \
|
| 173 |
+
} else { \
|
| 174 |
+
tensor_ = at::tensor(values, at::dtype(scalar_type_).device(at::kCPU)); \
|
| 175 |
+
} \
|
| 176 |
+
}
|
| 177 |
+
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
| 178 |
+
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR)
|
| 179 |
+
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
| 180 |
+
AT_FORALL_COMPLEX_TYPES(TENSOR)
|
| 181 |
+
#undef TENSOR
|
| 182 |
+
|
| 183 |
+
// NOTE: We need to handle `std::vector` explicitly instead of relying on an
|
| 184 |
+
// implicit conversion to `at::ArrayRef`, otherwise the following error can be
|
| 185 |
+
// thrown when calling `torch::tensor(std::vector<int>({1, 2}))`:
|
| 186 |
+
// ```
|
| 187 |
+
// error: no matching function for call to 'tensor(const std::vector<int>&)'
|
| 188 |
+
// no known conversion for argument 1 from 'const std::vector<int>' to
|
| 189 |
+
// 'torch::detail::TensorDataContainer'
|
| 190 |
+
// ```
|
| 191 |
+
//
|
| 192 |
+
// NOTE: `torch::tensor(std::vector<bool>)` is not supported for now, because
|
| 193 |
+
// ArrayRef<bool> cannot be constructed from a std::vector<bool> bitfield.
|
| 194 |
+
#define TENSOR(T, S) \
|
| 195 |
+
TensorDataContainer(const std::vector<T>& values) \
|
| 196 |
+
: TensorDataContainer(at::ArrayRef<T>(values)) {}
|
| 197 |
+
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
| 198 |
+
AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, TENSOR)
|
| 199 |
+
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
| 200 |
+
AT_FORALL_COMPLEX_TYPES(TENSOR)
|
| 201 |
+
#undef TENSOR
|
| 202 |
+
|
| 203 |
+
bool is_scalar() const {
|
| 204 |
+
return type_ == TensorDataContainerType::Scalar;
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
const c10::Scalar& scalar() const {
|
| 208 |
+
TORCH_CHECK(
|
| 209 |
+
is_scalar(),
|
| 210 |
+
"Can only call `scalar()` on a TensorDataContainer that has `is_scalar() == true`");
|
| 211 |
+
return scalar_;
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
bool is_init_list() const {
|
| 215 |
+
return type_ == TensorDataContainerType::InitList;
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
const std::initializer_list<TensorDataContainer>& init_list() const {
|
| 219 |
+
TORCH_CHECK(
|
| 220 |
+
is_init_list(),
|
| 221 |
+
"Can only call `init_list()` on a TensorDataContainer that has `is_init_list() == true`");
|
| 222 |
+
return init_list_;
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
bool is_tensor() const {
|
| 226 |
+
return type_ == TensorDataContainerType::Tensor;
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
const at::Tensor& tensor() const {
|
| 230 |
+
TORCH_CHECK(
|
| 231 |
+
is_tensor(),
|
| 232 |
+
"Can only call `tensor()` on a TensorDataContainer that has `is_tensor() == true`");
|
| 233 |
+
return tensor_;
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
const std::vector<int64_t>& sizes() const {
|
| 237 |
+
return sizes_;
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
const c10::ScalarType& scalar_type() const {
|
| 241 |
+
return scalar_type_;
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
at::Tensor convert_to_tensor(at::TensorOptions options) const {
|
| 245 |
+
if (!options.has_dtype()) {
|
| 246 |
+
options = options.dtype(compute_desired_dtype(scalar_type_));
|
| 247 |
+
}
|
| 248 |
+
|
| 249 |
+
if (is_scalar()) {
|
| 250 |
+
at::AutoDispatchBelowAutograd mode;
|
| 251 |
+
return at::scalar_tensor(scalar_, options);
|
| 252 |
+
} else if (is_init_list()) {
|
| 253 |
+
// NOTE: Here we explicitly choose to initialize the tensor on CPU first,
|
| 254 |
+
// fill each element of the tensor, and then move the tensor to the
|
| 255 |
+
// desired device. For CUDA device, this approach only involves 1 CUDA
|
| 256 |
+
// kernel launch, and is much faster than initializing the tensor on CUDA
|
| 257 |
+
// first and then filling each element of it (which involves `N` CUDA
|
| 258 |
+
// kernel launches where `N` is the number of the elements in the tensor).
|
| 259 |
+
at::Tensor tensor = ([&]() {
|
| 260 |
+
at::AutoDispatchBelowAutograd mode;
|
| 261 |
+
return at::empty(sizes_, options.device(at::kCPU));
|
| 262 |
+
})();
|
| 263 |
+
fill_tensor(tensor);
|
| 264 |
+
return tensor.to(options.device());
|
| 265 |
+
} else if (is_tensor()) {
|
| 266 |
+
auto output = tensor_.to(options);
|
| 267 |
+
TORCH_CHECK(
|
| 268 |
+
!tensor_.is_complex() || output.is_complex(),
|
| 269 |
+
"can not do torch::tensor(complex, dtype=non-complex) because complex can not be casted to real number without loss of information");
|
| 270 |
+
return output;
|
| 271 |
+
} else {
|
| 272 |
+
TORCH_INTERNAL_ASSERT(false, "Invalid TensorDataContainer type");
|
| 273 |
+
}
|
| 274 |
+
}
|
| 275 |
+
|
| 276 |
+
void pretty_print_recursive(std::ostream& stream) const {
|
| 277 |
+
if (is_scalar()) {
|
| 278 |
+
AT_DISPATCH_ALL_TYPES_AND3(
|
| 279 |
+
at::kBool,
|
| 280 |
+
at::kHalf,
|
| 281 |
+
at::kBFloat16,
|
| 282 |
+
scalar_type_,
|
| 283 |
+
"TensorDataContainer_pretty_print_scalar",
|
| 284 |
+
[&] { stream << scalar_.to<scalar_t>(); });
|
| 285 |
+
} else if (is_init_list()) {
|
| 286 |
+
stream << "{";
|
| 287 |
+
for (const TensorDataContainer* it = init_list_.begin();
|
| 288 |
+
it != init_list_.end();
|
| 289 |
+
it++) {
|
| 290 |
+
stream << *it;
|
| 291 |
+
if (std::next(it) != init_list_.end())
|
| 292 |
+
stream << ", ";
|
| 293 |
+
}
|
| 294 |
+
stream << "}";
|
| 295 |
+
} else if (is_tensor()) {
|
| 296 |
+
stream << "{";
|
| 297 |
+
for (const auto i : c10::irange(tensor_.sizes()[0])) {
|
| 298 |
+
AT_DISPATCH_ALL_TYPES_AND3(
|
| 299 |
+
at::kBool,
|
| 300 |
+
at::kHalf,
|
| 301 |
+
at::kBFloat16,
|
| 302 |
+
scalar_type_,
|
| 303 |
+
"TensorDataContainer_pretty_print_tensor_item",
|
| 304 |
+
[&] { stream << tensor_[i].item<scalar_t>(); });
|
| 305 |
+
if (i != tensor_.sizes()[0] - 1)
|
| 306 |
+
stream << ", ";
|
| 307 |
+
}
|
| 308 |
+
stream << "}";
|
| 309 |
+
} else {
|
| 310 |
+
TORCH_INTERNAL_ASSERT(false, "Invalid TensorDataContainer type");
|
| 311 |
+
}
|
| 312 |
+
}
|
| 313 |
+
|
| 314 |
+
private:
|
| 315 |
+
void fill_tensor(at::Tensor& tensor) const {
|
| 316 |
+
if (is_scalar()) {
|
| 317 |
+
TORCH_INTERNAL_ASSERT(
|
| 318 |
+
tensor.dim() == 0,
|
| 319 |
+
"Expected a 0-dim Tensor, but got Tensor with dimensions: ",
|
| 320 |
+
tensor.dim());
|
| 321 |
+
at::NoGradGuard guard;
|
| 322 |
+
tensor.fill_(scalar_);
|
| 323 |
+
} else if (is_init_list()) {
|
| 324 |
+
TORCH_INTERNAL_ASSERT(
|
| 325 |
+
tensor.sizes()[0] == (int64_t)init_list_.size(),
|
| 326 |
+
"Expected a Tensor with size ",
|
| 327 |
+
init_list_.size(),
|
| 328 |
+
" in its first dimension, but got Tensor with size ",
|
| 329 |
+
tensor.sizes()[0],
|
| 330 |
+
" in its first dimension");
|
| 331 |
+
size_t index = 0;
|
| 332 |
+
for (const auto& elem : init_list_) {
|
| 333 |
+
at::Tensor slice = tensor[index];
|
| 334 |
+
elem.fill_tensor(slice);
|
| 335 |
+
index++;
|
| 336 |
+
}
|
| 337 |
+
} else if (is_tensor()) {
|
| 338 |
+
TORCH_INTERNAL_ASSERT(
|
| 339 |
+
false,
|
| 340 |
+
"TensorDataContainer is already a Tensor type, `fill_tensor` should not be called");
|
| 341 |
+
} else {
|
| 342 |
+
TORCH_INTERNAL_ASSERT(false, "Invalid TensorDataContainer type");
|
| 343 |
+
}
|
| 344 |
+
}
|
| 345 |
+
|
| 346 |
+
std::vector<int64_t> sizes_;
|
| 347 |
+
c10::ScalarType scalar_type_;
|
| 348 |
+
TensorDataContainerType type_;
|
| 349 |
+
c10::Scalar scalar_;
|
| 350 |
+
std::initializer_list<TensorDataContainer> init_list_;
|
| 351 |
+
at::Tensor tensor_;
|
| 352 |
+
};
|
| 353 |
+
|
| 354 |
+
inline std::ostream& operator<<(
|
| 355 |
+
std::ostream& stream,
|
| 356 |
+
const TensorDataContainer& tensor_data_container) {
|
| 357 |
+
tensor_data_container.pretty_print_recursive(stream);
|
| 358 |
+
return stream;
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
} // namespace detail
|
| 362 |
+
|
| 363 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/detail/static.h
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/csrc/utils/variadic.h>
|
| 4 |
+
#include <torch/types.h>
|
| 5 |
+
|
| 6 |
+
#include <cstdint>
|
| 7 |
+
#include <type_traits>
|
| 8 |
+
|
| 9 |
+
namespace torch {
|
| 10 |
+
namespace nn {
|
| 11 |
+
class Module;
|
| 12 |
+
} // namespace nn
|
| 13 |
+
} // namespace torch
|
| 14 |
+
|
| 15 |
+
namespace torch {
|
| 16 |
+
namespace detail {
|
| 17 |
+
/// Detects if a type T has a forward() method.
|
| 18 |
+
template <typename T>
|
| 19 |
+
struct has_forward {
|
| 20 |
+
// Declare two types with differing size.
|
| 21 |
+
using yes = int8_t;
|
| 22 |
+
using no = int16_t;
|
| 23 |
+
|
| 24 |
+
// Here we declare two functions. The first is only enabled if `&U::forward`
|
| 25 |
+
// is well-formed and returns the `yes` type. In C++, the ellipsis parameter
|
| 26 |
+
// type (`...`) always puts the function at the bottom of overload resolution.
|
| 27 |
+
// This is specified in the standard as: 1) A standard conversion sequence is
|
| 28 |
+
// always better than a user-defined conversion sequence or an ellipsis
|
| 29 |
+
// conversion sequence. 2) A user-defined conversion sequence is always better
|
| 30 |
+
// than an ellipsis conversion sequence This means that if the first overload
|
| 31 |
+
// is viable, it will be preferred over the second as long as we pass any
|
| 32 |
+
// convertible type. The type of `&U::forward` is a pointer type, so we can
|
| 33 |
+
// pass e.g. 0.
|
| 34 |
+
template <typename U>
|
| 35 |
+
static yes test(decltype(&U::forward));
|
| 36 |
+
template <typename U>
|
| 37 |
+
static no test(...);
|
| 38 |
+
|
| 39 |
+
// Finally we test statically whether the size of the type returned by the
|
| 40 |
+
// selected overload is the size of the `yes` type.
|
| 41 |
+
static constexpr bool value = (sizeof(test<T>(nullptr)) == sizeof(yes));
|
| 42 |
+
};
|
| 43 |
+
|
| 44 |
+
template <typename Head = void, typename... Tail>
|
| 45 |
+
constexpr bool check_not_lvalue_references() {
|
| 46 |
+
return (!std::is_lvalue_reference<Head>::value ||
|
| 47 |
+
std::is_const<typename std::remove_reference<Head>::type>::value) &&
|
| 48 |
+
check_not_lvalue_references<Tail...>();
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
template <>
|
| 52 |
+
inline constexpr bool check_not_lvalue_references<void>() {
|
| 53 |
+
return true;
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
/// A type trait whose `value` member is true if `M` derives from `Module`.
|
| 57 |
+
template <typename M>
|
| 58 |
+
using is_module =
|
| 59 |
+
std::is_base_of<torch::nn::Module, typename std::decay<M>::type>;
|
| 60 |
+
|
| 61 |
+
template <typename M, typename T = void>
|
| 62 |
+
using enable_if_module_t =
|
| 63 |
+
typename std::enable_if<is_module<M>::value, T>::type;
|
| 64 |
+
} // namespace detail
|
| 65 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/any.h
ADDED
|
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/detail/static.h>
|
| 4 |
+
#include <torch/nn/module.h>
|
| 5 |
+
#include <torch/nn/modules/container/any_module_holder.h>
|
| 6 |
+
#include <torch/nn/modules/container/any_value.h>
|
| 7 |
+
#include <torch/nn/pimpl.h>
|
| 8 |
+
#include <torch/types.h>
|
| 9 |
+
|
| 10 |
+
#include <torch/csrc/autograd/variable.h>
|
| 11 |
+
#include <torch/csrc/utils/variadic.h>
|
| 12 |
+
|
| 13 |
+
#include <ATen/Device.h>
|
| 14 |
+
|
| 15 |
+
#include <memory>
|
| 16 |
+
#include <type_traits>
|
| 17 |
+
#include <typeinfo>
|
| 18 |
+
#include <utility>
|
| 19 |
+
#include <vector>
|
| 20 |
+
|
| 21 |
+
namespace torch {
|
| 22 |
+
namespace nn {
|
| 23 |
+
|
| 24 |
+
/// Stores a type erased `Module`.
|
| 25 |
+
///
|
| 26 |
+
/// The PyTorch C++ API does not impose an interface on the signature of
|
| 27 |
+
/// `forward()` in `Module` subclasses. This gives you complete freedom to
|
| 28 |
+
/// design your `forward()` methods to your liking. However, this also means
|
| 29 |
+
/// there is no unified base type you could store in order to call `forward()`
|
| 30 |
+
/// polymorphically for any module. This is where the `AnyModule` comes in.
|
| 31 |
+
/// Instead of inheritance, it relies on type erasure for polymorphism.
|
| 32 |
+
///
|
| 33 |
+
/// An `AnyModule` can store any `nn::Module` subclass that provides a
|
| 34 |
+
/// `forward()` method. This `forward()` may accept any types and return any
|
| 35 |
+
/// type. Once stored in an `AnyModule`, you can invoke the underlying module's
|
| 36 |
+
/// `forward()` by calling `AnyModule::forward()` with the arguments you would
|
| 37 |
+
/// supply to the stored module (though see one important limitation below).
|
| 38 |
+
/// Example:
|
| 39 |
+
///
|
| 40 |
+
/// \rst
|
| 41 |
+
/// .. code-block:: cpp
|
| 42 |
+
///
|
| 43 |
+
/// struct GenericTrainer {
|
| 44 |
+
/// torch::nn::AnyModule module;
|
| 45 |
+
///
|
| 46 |
+
/// void train(torch::Tensor input) {
|
| 47 |
+
/// module.forward(input);
|
| 48 |
+
/// }
|
| 49 |
+
/// };
|
| 50 |
+
///
|
| 51 |
+
/// GenericTrainer trainer1{torch::nn::Linear(3, 4)};
|
| 52 |
+
/// GenericTrainer trainer2{torch::nn::Conv2d(3, 4, 2)};
|
| 53 |
+
/// \endrst
|
| 54 |
+
///
|
| 55 |
+
/// As `AnyModule` erases the static type of the stored module (and its
|
| 56 |
+
/// `forward()` method) to achieve polymorphism, type checking of arguments is
|
| 57 |
+
/// moved to runtime. That is, passing an argument with an incorrect type to an
|
| 58 |
+
/// `AnyModule` will compile, but throw an exception at runtime:
|
| 59 |
+
///
|
| 60 |
+
/// \rst
|
| 61 |
+
/// .. code-block:: cpp
|
| 62 |
+
///
|
| 63 |
+
/// torch::nn::AnyModule module(torch::nn::Linear(3, 4));
|
| 64 |
+
/// // Linear takes a tensor as input, but we are passing an integer.
|
| 65 |
+
/// // This will compile, but throw a `torch::Error` exception at runtime.
|
| 66 |
+
/// module.forward(123);
|
| 67 |
+
/// \endrst
|
| 68 |
+
///
|
| 69 |
+
/// \rst
|
| 70 |
+
/// .. attention::
|
| 71 |
+
/// One noteworthy limitation of `AnyModule` is that its `forward()` method
|
| 72 |
+
/// does not support implicit conversion of argument types. For example, if
|
| 73 |
+
/// the stored module's `forward()` method accepts a `float` and you call
|
| 74 |
+
/// `any_module.forward(3.4)` (where `3.4` is a `double`), this will throw
|
| 75 |
+
/// an exception.
|
| 76 |
+
/// \endrst
|
| 77 |
+
///
|
| 78 |
+
/// The return type of the `AnyModule`'s `forward()` method is controlled via
|
| 79 |
+
/// the first template argument to `AnyModule::forward()`. It defaults to
|
| 80 |
+
/// `torch::Tensor`. To change it, you can write `any_module.forward<int>()`,
|
| 81 |
+
/// for example.
|
| 82 |
+
///
|
| 83 |
+
/// \rst
|
| 84 |
+
/// .. code-block:: cpp
|
| 85 |
+
///
|
| 86 |
+
/// torch::nn::AnyModule module(torch::nn::Linear(3, 4));
|
| 87 |
+
/// auto output = module.forward(torch::ones({2, 3}));
|
| 88 |
+
///
|
| 89 |
+
/// struct IntModule {
|
| 90 |
+
/// int forward(int x) { return x; }
|
| 91 |
+
/// };
|
| 92 |
+
/// torch::nn::AnyModule module(IntModule{});
|
| 93 |
+
/// int output = module.forward<int>(5);
|
| 94 |
+
/// \endrst
|
| 95 |
+
///
|
| 96 |
+
/// The only other method an `AnyModule` provides access to on the stored
|
| 97 |
+
/// module is `clone()`. However, you may acquire a handle on the module via
|
| 98 |
+
/// `.ptr()`, which returns a `shared_ptr<nn::Module>`. Further, if you know
|
| 99 |
+
/// the concrete type of the stored module, you can get a concrete handle to it
|
| 100 |
+
/// using `.get<T>()` where `T` is the concrete module type.
|
| 101 |
+
///
|
| 102 |
+
/// \rst
|
| 103 |
+
/// .. code-block:: cpp
|
| 104 |
+
///
|
| 105 |
+
/// torch::nn::AnyModule module(torch::nn::Linear(3, 4));
|
| 106 |
+
/// std::shared_ptr<nn::Module> ptr = module.ptr();
|
| 107 |
+
/// torch::nn::Linear linear(module.get<torch::nn::Linear>());
|
| 108 |
+
/// \endrst
|
| 109 |
+
class AnyModule {
|
| 110 |
+
public:
|
| 111 |
+
/// A default-constructed `AnyModule` is in an empty state.
|
| 112 |
+
AnyModule() = default;
|
| 113 |
+
|
| 114 |
+
/// Constructs an `AnyModule` from a `shared_ptr` to concrete module object.
|
| 115 |
+
template <typename ModuleType>
|
| 116 |
+
explicit AnyModule(std::shared_ptr<ModuleType> module);
|
| 117 |
+
|
| 118 |
+
/// Constructs an `AnyModule` from a concrete module object.
|
| 119 |
+
template <
|
| 120 |
+
typename ModuleType,
|
| 121 |
+
typename = torch::detail::enable_if_module_t<ModuleType>>
|
| 122 |
+
explicit AnyModule(ModuleType&& module);
|
| 123 |
+
|
| 124 |
+
/// Constructs an `AnyModule` from a module holder.
|
| 125 |
+
template <typename ModuleType>
|
| 126 |
+
explicit AnyModule(const ModuleHolder<ModuleType>& module_holder);
|
| 127 |
+
|
| 128 |
+
/// Move construction and assignment is allowed, and follows the default
|
| 129 |
+
/// behavior of move for `std::unique_ptr`.
|
| 130 |
+
AnyModule(AnyModule&&) = default;
|
| 131 |
+
AnyModule& operator=(AnyModule&&) = default;
|
| 132 |
+
|
| 133 |
+
/// Creates a shallow copy of an `AnyModule`.
|
| 134 |
+
AnyModule(const AnyModule& other);
|
| 135 |
+
AnyModule& operator=(const AnyModule& other);
|
| 136 |
+
|
| 137 |
+
/// Creates a deep copy of an `AnyModule` if it contains a module, else an
|
| 138 |
+
/// empty `AnyModule` if it is empty.
|
| 139 |
+
AnyModule clone(std::optional<Device> device = std::nullopt) const;
|
| 140 |
+
|
| 141 |
+
/// Assigns a module to the `AnyModule` (to circumvent the explicit
|
| 142 |
+
/// constructor).
|
| 143 |
+
template <typename ModuleType>
|
| 144 |
+
AnyModule& operator=(std::shared_ptr<ModuleType> module);
|
| 145 |
+
|
| 146 |
+
/// Invokes `forward()` on the contained module with the given arguments, and
|
| 147 |
+
/// returns the return value as an `AnyValue`. Use this method when chaining
|
| 148 |
+
/// `AnyModule`s in a loop.
|
| 149 |
+
template <typename... ArgumentTypes>
|
| 150 |
+
AnyValue any_forward(ArgumentTypes&&... arguments);
|
| 151 |
+
|
| 152 |
+
/// Invokes `forward()` on the contained module with the given arguments, and
|
| 153 |
+
/// casts the returned `AnyValue` to the supplied `ReturnType` (which defaults
|
| 154 |
+
/// to `torch::Tensor`).
|
| 155 |
+
template <typename ReturnType = torch::Tensor, typename... ArgumentTypes>
|
| 156 |
+
ReturnType forward(ArgumentTypes&&... arguments);
|
| 157 |
+
|
| 158 |
+
/// Attempts to cast the underlying module to the given module type. Throws an
|
| 159 |
+
/// exception if the types do not match.
|
| 160 |
+
template <typename T, typename = torch::detail::enable_if_module_t<T>>
|
| 161 |
+
T& get();
|
| 162 |
+
|
| 163 |
+
/// Attempts to cast the underlying module to the given module type. Throws an
|
| 164 |
+
/// exception if the types do not match.
|
| 165 |
+
template <typename T, typename = torch::detail::enable_if_module_t<T>>
|
| 166 |
+
const T& get() const;
|
| 167 |
+
|
| 168 |
+
/// Returns the contained module in a `nn::ModuleHolder` subclass if possible
|
| 169 |
+
/// (i.e. if `T` has a constructor for the underlying module type).
|
| 170 |
+
template <typename T, typename ContainedType = typename T::ContainedType>
|
| 171 |
+
T get() const;
|
| 172 |
+
|
| 173 |
+
/// Returns a `std::shared_ptr` whose dynamic type is that of the underlying
|
| 174 |
+
/// module.
|
| 175 |
+
std::shared_ptr<Module> ptr() const;
|
| 176 |
+
|
| 177 |
+
/// Like `ptr()`, but casts the pointer to the given type.
|
| 178 |
+
template <typename T, typename = torch::detail::enable_if_module_t<T>>
|
| 179 |
+
std::shared_ptr<T> ptr() const;
|
| 180 |
+
|
| 181 |
+
/// Returns the `type_info` object of the contained value.
|
| 182 |
+
const std::type_info& type_info() const;
|
| 183 |
+
|
| 184 |
+
/// Returns true if the `AnyModule` does not contain a module.
|
| 185 |
+
bool is_empty() const noexcept;
|
| 186 |
+
|
| 187 |
+
private:
|
| 188 |
+
/// Creates a `unique_ptr<AnyModulePlaceholder>` pointing to a
|
| 189 |
+
/// `AnyModuleHolder` of the correct type. This method is used to deduce the
|
| 190 |
+
/// arguments of the module's `forward()` method.
|
| 191 |
+
template <
|
| 192 |
+
typename ModuleType,
|
| 193 |
+
typename Class,
|
| 194 |
+
typename ReturnType,
|
| 195 |
+
typename... ArgumentTypes>
|
| 196 |
+
std::unique_ptr<AnyModulePlaceholder> make_holder(
|
| 197 |
+
std::shared_ptr<ModuleType>&& module,
|
| 198 |
+
ReturnType (Class::*)(ArgumentTypes...));
|
| 199 |
+
|
| 200 |
+
/// Helper method invoked by const and non-const `get()`.
|
| 201 |
+
template <typename ModuleType, typename ReturnType, typename... ArgumentTypes>
|
| 202 |
+
ModuleType& get_(ReturnType (ModuleType::*)(ArgumentTypes...)) const;
|
| 203 |
+
|
| 204 |
+
/// Helper method invoked by const and non-const `get()`.
|
| 205 |
+
template <typename ModuleType>
|
| 206 |
+
ModuleType& get_() const;
|
| 207 |
+
|
| 208 |
+
/// The type erased module.
|
| 209 |
+
std::unique_ptr<AnyModulePlaceholder> content_;
|
| 210 |
+
};
|
| 211 |
+
|
| 212 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModule ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 213 |
+
|
| 214 |
+
template <typename ModuleType>
|
| 215 |
+
AnyModule::AnyModule(std::shared_ptr<ModuleType> module)
|
| 216 |
+
: content_(make_holder(
|
| 217 |
+
std::move(module),
|
| 218 |
+
&std::remove_reference<ModuleType>::type::forward)) {
|
| 219 |
+
// `AnyModule` can only store an `nn::Module` subclass object that provides
|
| 220 |
+
// a `forward()` method that has a non-templatized return type.
|
| 221 |
+
// (e.g. `AnyModule` cannot store `nn::Sequential`, because `nn::Sequential`'s
|
| 222 |
+
// `forward()` method has a templatized return type.)
|
| 223 |
+
static_assert(
|
| 224 |
+
torch::detail::is_module<ModuleType>::value,
|
| 225 |
+
"Can only store object derived from nn::Module into AnyModule");
|
| 226 |
+
static_assert(
|
| 227 |
+
torch::detail::has_forward<ModuleType>::value,
|
| 228 |
+
"Can only store module with a forward() method that has a non-templatized"
|
| 229 |
+
" argument type and return type into AnyModule (e.g. we cannot store nn::Sequential"
|
| 230 |
+
"into AnyModule, because its forward() method's argument type and return type are templatized."
|
| 231 |
+
" If you need to use nn::Sequentials inside each other you can subclass "
|
| 232 |
+
"nn::Sequential and write a non-templatized forward function for it. You can checkout "
|
| 233 |
+
"https://github.com/pytorch/vision/blob/2f46070f3cb1ea894d82578f3dc5677f82f34958/torchvision/csrc/models/mnasnet.cpp#L59 "
|
| 234 |
+
"for an example on how to do this.).");
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
template <typename ModuleType, typename>
|
| 238 |
+
AnyModule::AnyModule(ModuleType&& module)
|
| 239 |
+
: AnyModule(
|
| 240 |
+
std::make_shared<ModuleType>(std::forward<ModuleType>(module))) {}
|
| 241 |
+
|
| 242 |
+
template <typename ModuleType>
|
| 243 |
+
AnyModule::AnyModule(const ModuleHolder<ModuleType>& module_holder)
|
| 244 |
+
: AnyModule(module_holder.ptr()) {}
|
| 245 |
+
|
| 246 |
+
inline AnyModule::AnyModule(const AnyModule& other)
|
| 247 |
+
: content_(other.content_ ? other.content_->copy() : nullptr) {}
|
| 248 |
+
|
| 249 |
+
inline AnyModule& AnyModule::operator=(const AnyModule& other) {
|
| 250 |
+
if (this != &other) {
|
| 251 |
+
content_ = other.content_ ? other.content_->copy() : nullptr;
|
| 252 |
+
}
|
| 253 |
+
return *this;
|
| 254 |
+
}
|
| 255 |
+
|
| 256 |
+
inline AnyModule AnyModule::clone(std::optional<Device> device) const {
|
| 257 |
+
AnyModule clone;
|
| 258 |
+
clone.content_ = content_ ? content_->clone_module(device) : nullptr;
|
| 259 |
+
return clone;
|
| 260 |
+
}
|
| 261 |
+
|
| 262 |
+
template <typename ModuleType>
|
| 263 |
+
AnyModule& AnyModule::operator=(std::shared_ptr<ModuleType> module) {
|
| 264 |
+
// NOLINTNEXTLINE(cppcoreguidelines-c-copy-assignment-signature)
|
| 265 |
+
return (*this = AnyModule(std::move(module)));
|
| 266 |
+
}
|
| 267 |
+
|
| 268 |
+
template <typename... ArgumentTypes>
|
| 269 |
+
AnyValue AnyModule::any_forward(ArgumentTypes&&... arguments) {
|
| 270 |
+
TORCH_CHECK(!is_empty(), "Cannot call forward() on an empty AnyModule");
|
| 271 |
+
std::vector<AnyValue> values;
|
| 272 |
+
values.reserve(sizeof...(ArgumentTypes));
|
| 273 |
+
torch::apply(
|
| 274 |
+
[&values](AnyValue&& value) { values.push_back(std::move(value)); },
|
| 275 |
+
AnyValue(std::forward<ArgumentTypes>(arguments))...);
|
| 276 |
+
return content_->forward(std::move(values));
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
template <typename ReturnType, typename... ArgumentTypes>
|
| 280 |
+
ReturnType AnyModule::forward(ArgumentTypes&&... arguments) {
|
| 281 |
+
return any_forward(std::forward<ArgumentTypes>(arguments)...)
|
| 282 |
+
.template get<ReturnType>();
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
template <typename T, typename>
|
| 286 |
+
T& AnyModule::get() {
|
| 287 |
+
TORCH_CHECK(!is_empty(), "Cannot call get() on an empty AnyModule");
|
| 288 |
+
return get_<T>();
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
+
template <typename T, typename>
|
| 292 |
+
const T& AnyModule::get() const {
|
| 293 |
+
TORCH_CHECK(!is_empty(), "Cannot call get() on an empty AnyModule");
|
| 294 |
+
return get_<T>();
|
| 295 |
+
}
|
| 296 |
+
|
| 297 |
+
template <typename T, typename ContainedType>
|
| 298 |
+
T AnyModule::get() const {
|
| 299 |
+
return T(ptr<ContainedType>());
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
inline std::shared_ptr<Module> AnyModule::ptr() const {
|
| 303 |
+
TORCH_CHECK(!is_empty(), "Cannot call ptr() on an empty AnyModule");
|
| 304 |
+
return content_->ptr();
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
template <typename T, typename>
|
| 308 |
+
std::shared_ptr<T> AnyModule::ptr() const {
|
| 309 |
+
TORCH_CHECK(!is_empty(), "Cannot call ptr() on an empty AnyModule");
|
| 310 |
+
// Call get() but discard the value, just to do the type checking.
|
| 311 |
+
get_<T>();
|
| 312 |
+
return std::dynamic_pointer_cast<T>(ptr());
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
+
inline const std::type_info& AnyModule::type_info() const {
|
| 316 |
+
TORCH_CHECK(!is_empty(), "Cannot call type_info() on an empty AnyModule");
|
| 317 |
+
return content_->type_info;
|
| 318 |
+
}
|
| 319 |
+
|
| 320 |
+
inline bool AnyModule::is_empty() const noexcept {
|
| 321 |
+
return content_ == nullptr;
|
| 322 |
+
}
|
| 323 |
+
|
| 324 |
+
// Private Methods
|
| 325 |
+
|
| 326 |
+
template <
|
| 327 |
+
typename ModuleType,
|
| 328 |
+
typename Class,
|
| 329 |
+
typename ReturnType,
|
| 330 |
+
typename... ArgumentTypes>
|
| 331 |
+
std::unique_ptr<AnyModulePlaceholder> AnyModule::make_holder(
|
| 332 |
+
std::shared_ptr<ModuleType>&& module,
|
| 333 |
+
ReturnType (Class::*)(ArgumentTypes...)) {
|
| 334 |
+
static_assert(
|
| 335 |
+
torch::detail::check_not_lvalue_references<ArgumentTypes...>(),
|
| 336 |
+
"Modules stored inside AnyModule must not take references. "
|
| 337 |
+
"Use pointers instead.");
|
| 338 |
+
static_assert(
|
| 339 |
+
!std::is_void<ReturnType>::value,
|
| 340 |
+
"AnyModule cannot store modules that return void "
|
| 341 |
+
"(you can return a dummy value).");
|
| 342 |
+
return std::make_unique<
|
| 343 |
+
AnyModuleHolder<std::decay_t<ModuleType>, ArgumentTypes...>>(
|
| 344 |
+
std::move(module));
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
template <typename ModuleType>
|
| 348 |
+
ModuleType& AnyModule::get_() const {
|
| 349 |
+
using M = typename std::remove_reference<ModuleType>::type;
|
| 350 |
+
static_assert(
|
| 351 |
+
torch::detail::has_forward<M>::value,
|
| 352 |
+
"Can only call AnyModule::get<T> with a type T that has a forward method");
|
| 353 |
+
return get_(&M::forward);
|
| 354 |
+
}
|
| 355 |
+
|
| 356 |
+
template <typename ModuleType, typename ReturnType, typename... ArgumentTypes>
|
| 357 |
+
ModuleType& AnyModule::get_(
|
| 358 |
+
ReturnType (ModuleType::*)(ArgumentTypes...)) const {
|
| 359 |
+
if (typeid(ModuleType).hash_code() == type_info().hash_code()) {
|
| 360 |
+
return *static_cast<AnyModuleHolder<ModuleType, ArgumentTypes...>&>(
|
| 361 |
+
*content_)
|
| 362 |
+
.module;
|
| 363 |
+
}
|
| 364 |
+
AT_ERROR(
|
| 365 |
+
"Attempted to cast module of type ",
|
| 366 |
+
c10::demangle(type_info().name()),
|
| 367 |
+
" to type ",
|
| 368 |
+
c10::demangle(typeid(ModuleType).name()));
|
| 369 |
+
}
|
| 370 |
+
|
| 371 |
+
} // namespace nn
|
| 372 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/any_module_holder.h
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/nn/modules/container/any_value.h>
|
| 4 |
+
|
| 5 |
+
namespace torch {
|
| 6 |
+
namespace nn {
|
| 7 |
+
|
| 8 |
+
class Module;
|
| 9 |
+
|
| 10 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModulePlaceholder ~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 11 |
+
|
| 12 |
+
/// The static type of the object we store in the `AnyModule`, which erases
|
| 13 |
+
/// the actual type, but allows us to call `forward()` on the underlying
|
| 14 |
+
/// module.
|
| 15 |
+
struct AnyModulePlaceholder : public AnyValue::Placeholder {
|
| 16 |
+
using AnyValue::Placeholder::Placeholder;
|
| 17 |
+
|
| 18 |
+
/// The "erased" `forward()` method.
|
| 19 |
+
virtual AnyValue forward(std::vector<AnyValue>&& arguments) = 0;
|
| 20 |
+
|
| 21 |
+
/// Returns std::shared_ptr<Module> pointing to the erased module.
|
| 22 |
+
virtual std::shared_ptr<Module> ptr() = 0;
|
| 23 |
+
|
| 24 |
+
/// Returns a `AnyModulePlaceholder` with a shallow copy of this `AnyModule`.
|
| 25 |
+
virtual std::unique_ptr<AnyModulePlaceholder> copy() const = 0;
|
| 26 |
+
|
| 27 |
+
/// Returns a `AnyModulePlaceholder` with a deep copy of this `AnyModule`.
|
| 28 |
+
virtual std::unique_ptr<AnyModulePlaceholder> clone_module(
|
| 29 |
+
std::optional<Device> device) const = 0;
|
| 30 |
+
};
|
| 31 |
+
|
| 32 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModuleHolder ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 33 |
+
|
| 34 |
+
/// The dynamic type of the object stored in the `AnyModule`. It contains the
|
| 35 |
+
/// concrete instance to which all calls are forwarded. It is parameterized
|
| 36 |
+
/// over the concrete type of the module, and the types of the arguments the
|
| 37 |
+
/// module takes in its `forward()` method.
|
| 38 |
+
template <typename ModuleType, typename... ArgumentTypes>
|
| 39 |
+
struct AnyModuleHolder : public AnyModulePlaceholder {
|
| 40 |
+
/// \internal
|
| 41 |
+
struct CheckedGetter {
|
| 42 |
+
template <typename T>
|
| 43 |
+
std::decay_t<T>&& operator()(size_t index) {
|
| 44 |
+
AT_ASSERT(index < arguments_.size());
|
| 45 |
+
auto& value = arguments_[index];
|
| 46 |
+
if (auto* maybe_value = value.template try_get<std::decay_t<T>>()) {
|
| 47 |
+
return std::move(*maybe_value);
|
| 48 |
+
}
|
| 49 |
+
AT_ERROR(
|
| 50 |
+
"Expected argument #",
|
| 51 |
+
index,
|
| 52 |
+
" to be of type ",
|
| 53 |
+
c10::demangle(typeid(T).name()),
|
| 54 |
+
", but received value of type ",
|
| 55 |
+
c10::demangle(value.type_info().name()));
|
| 56 |
+
}
|
| 57 |
+
std::vector<AnyValue>& arguments_;
|
| 58 |
+
};
|
| 59 |
+
|
| 60 |
+
/// \internal
|
| 61 |
+
struct InvokeForward {
|
| 62 |
+
template <typename... Ts>
|
| 63 |
+
AnyValue operator()(Ts&&... ts) {
|
| 64 |
+
return AnyValue(module_->forward(std::forward<Ts>(ts)...));
|
| 65 |
+
}
|
| 66 |
+
std::shared_ptr<ModuleType>& module_;
|
| 67 |
+
};
|
| 68 |
+
|
| 69 |
+
/// Constructs the `AnyModuleHolder` from a concrete module.
|
| 70 |
+
explicit AnyModuleHolder(std::shared_ptr<ModuleType>&& module_)
|
| 71 |
+
: AnyModulePlaceholder(typeid(ModuleType)), module(std::move(module_)) {}
|
| 72 |
+
|
| 73 |
+
/// Calls `forward()` on the underlying module, casting each `AnyValue` in the
|
| 74 |
+
/// argument vector to a concrete value.
|
| 75 |
+
AnyValue forward(std::vector<AnyValue>&& arguments) override {
|
| 76 |
+
if (module->_forward_has_default_args()) {
|
| 77 |
+
TORCH_CHECK(
|
| 78 |
+
arguments.size() >= module->_forward_num_required_args() &&
|
| 79 |
+
arguments.size() <= sizeof...(ArgumentTypes),
|
| 80 |
+
c10::demangle(type_info.name()),
|
| 81 |
+
"'s forward() method expects at least ",
|
| 82 |
+
module->_forward_num_required_args(),
|
| 83 |
+
" argument(s) and at most ",
|
| 84 |
+
sizeof...(ArgumentTypes),
|
| 85 |
+
" argument(s), but received ",
|
| 86 |
+
arguments.size(),
|
| 87 |
+
".");
|
| 88 |
+
arguments = std::move(
|
| 89 |
+
module->_forward_populate_default_args(std::move(arguments)));
|
| 90 |
+
} else {
|
| 91 |
+
std::string use_default_args_macro_prompt = " If " +
|
| 92 |
+
c10::demangle(type_info.name()) +
|
| 93 |
+
"'s forward() method has default arguments, " +
|
| 94 |
+
"please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro.";
|
| 95 |
+
TORCH_CHECK(
|
| 96 |
+
arguments.size() == sizeof...(ArgumentTypes),
|
| 97 |
+
c10::demangle(type_info.name()),
|
| 98 |
+
"'s forward() method expects ",
|
| 99 |
+
sizeof...(ArgumentTypes),
|
| 100 |
+
" argument(s), but received ",
|
| 101 |
+
arguments.size(),
|
| 102 |
+
".",
|
| 103 |
+
(arguments.size() < sizeof...(ArgumentTypes))
|
| 104 |
+
? use_default_args_macro_prompt
|
| 105 |
+
: "");
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
// FYI: During invocation of a module's `forward()` method, the values live
|
| 109 |
+
// in the `arguments` vector inside this function.
|
| 110 |
+
return torch::unpack<AnyValue, ArgumentTypes...>(
|
| 111 |
+
InvokeForward{module}, CheckedGetter{arguments});
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
std::shared_ptr<Module> ptr() override {
|
| 115 |
+
return module;
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
std::unique_ptr<AnyModulePlaceholder> copy() const override {
|
| 119 |
+
return std::make_unique<AnyModuleHolder>(*this);
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
std::unique_ptr<AnyModulePlaceholder> clone_module(
|
| 123 |
+
std::optional<Device> device) const override {
|
| 124 |
+
return std::make_unique<AnyModuleHolder>(
|
| 125 |
+
std::dynamic_pointer_cast<ModuleType>(module->clone(device)));
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
/// The actual concrete module instance.
|
| 129 |
+
std::shared_ptr<ModuleType> module;
|
| 130 |
+
};
|
| 131 |
+
|
| 132 |
+
} // namespace nn
|
| 133 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/any_value.h
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/detail/static.h>
|
| 4 |
+
#include <torch/nn/module.h>
|
| 5 |
+
#include <torch/nn/pimpl.h>
|
| 6 |
+
#include <torch/types.h>
|
| 7 |
+
|
| 8 |
+
#include <torch/csrc/autograd/variable.h>
|
| 9 |
+
#include <torch/csrc/utils/variadic.h>
|
| 10 |
+
|
| 11 |
+
#include <memory>
|
| 12 |
+
#include <type_traits>
|
| 13 |
+
#include <typeinfo>
|
| 14 |
+
#include <utility>
|
| 15 |
+
|
| 16 |
+
namespace torch {
|
| 17 |
+
namespace nn {
|
| 18 |
+
|
| 19 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyValue ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 20 |
+
|
| 21 |
+
/// An implementation of `std::any` which stores
|
| 22 |
+
/// a type erased object, whose concrete value can be retrieved at runtime by
|
| 23 |
+
/// checking if the `typeid()` of a requested type matches the `typeid()` of
|
| 24 |
+
/// the object stored.
|
| 25 |
+
class AnyValue {
|
| 26 |
+
public:
|
| 27 |
+
/// Move construction and assignment is allowed, and follows the default
|
| 28 |
+
/// behavior of move for `std::unique_ptr`.
|
| 29 |
+
AnyValue(AnyValue&&) = default;
|
| 30 |
+
AnyValue& operator=(AnyValue&&) = default;
|
| 31 |
+
|
| 32 |
+
/// Copy construction and assignment is allowed.
|
| 33 |
+
AnyValue(const AnyValue& other) : content_(other.content_->clone()) {}
|
| 34 |
+
AnyValue& operator=(const AnyValue& other) {
|
| 35 |
+
content_ = other.content_->clone();
|
| 36 |
+
return *this;
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
/// Constructs the `AnyValue` from value type.
|
| 40 |
+
template <typename T>
|
| 41 |
+
// NOLINTNEXTLINE(bugprone-forwarding-reference-overload)
|
| 42 |
+
explicit AnyValue(T&& value)
|
| 43 |
+
: content_(
|
| 44 |
+
std::make_unique<Holder<std::decay_t<T>>>(std::forward<T>(value))) {
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
/// Returns a pointer to the value contained in the `AnyValue` if the type
|
| 48 |
+
/// passed as template parameter matches the type of the value stored, and
|
| 49 |
+
/// returns a null pointer otherwise.
|
| 50 |
+
template <typename T>
|
| 51 |
+
T* try_get() {
|
| 52 |
+
static_assert(
|
| 53 |
+
!std::is_reference<T>::value,
|
| 54 |
+
"AnyValue stores decayed types, you cannot cast it to a reference type");
|
| 55 |
+
static_assert(
|
| 56 |
+
!std::is_array<T>::value,
|
| 57 |
+
"AnyValue stores decayed types, you must cast it to T* instead of T[]");
|
| 58 |
+
if (typeid(T).hash_code() == type_info().hash_code()) {
|
| 59 |
+
return &static_cast<Holder<T>&>(*content_).value;
|
| 60 |
+
}
|
| 61 |
+
return nullptr;
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
/// Returns the value contained in the `AnyValue` if the type passed as
|
| 65 |
+
/// template parameter matches the type of the value stored, and throws an
|
| 66 |
+
/// exception otherwise.
|
| 67 |
+
template <typename T>
|
| 68 |
+
T get() {
|
| 69 |
+
if (auto* maybe_value = try_get<T>()) {
|
| 70 |
+
return *maybe_value;
|
| 71 |
+
}
|
| 72 |
+
AT_ERROR(
|
| 73 |
+
"Attempted to cast AnyValue to ",
|
| 74 |
+
c10::demangle(typeid(T).name()),
|
| 75 |
+
", but its actual type is ",
|
| 76 |
+
c10::demangle(type_info().name()));
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
/// Returns the `type_info` object of the contained value.
|
| 80 |
+
const std::type_info& type_info() const noexcept {
|
| 81 |
+
return content_->type_info;
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
private:
|
| 85 |
+
friend struct AnyModulePlaceholder;
|
| 86 |
+
friend struct TestAnyValue;
|
| 87 |
+
|
| 88 |
+
/// \internal
|
| 89 |
+
/// The static type of the object we store in the `AnyValue`, which erases the
|
| 90 |
+
/// actual object's type, allowing us only to check the `type_info` of the
|
| 91 |
+
/// type stored in the dynamic type.
|
| 92 |
+
struct Placeholder {
|
| 93 |
+
explicit Placeholder(const std::type_info& type_info_) noexcept
|
| 94 |
+
: type_info(type_info_) {}
|
| 95 |
+
Placeholder(const Placeholder&) = default;
|
| 96 |
+
Placeholder(Placeholder&&) = default;
|
| 97 |
+
virtual ~Placeholder() = default;
|
| 98 |
+
virtual std::unique_ptr<Placeholder> clone() const {
|
| 99 |
+
TORCH_CHECK(false, "clone() should only be called on `AnyValue::Holder`");
|
| 100 |
+
}
|
| 101 |
+
const std::type_info& type_info;
|
| 102 |
+
};
|
| 103 |
+
|
| 104 |
+
/// \internal
|
| 105 |
+
/// The dynamic type of the object we store in the `AnyValue`, which hides the
|
| 106 |
+
/// actual object we have erased in this `AnyValue`.
|
| 107 |
+
template <typename T>
|
| 108 |
+
struct Holder : public Placeholder {
|
| 109 |
+
/// A template because T&& would not be universal reference here.
|
| 110 |
+
template <typename U>
|
| 111 |
+
// NOLINTNEXTLINE(bugprone-forwarding-reference-overload)
|
| 112 |
+
explicit Holder(U&& value_) noexcept
|
| 113 |
+
: Placeholder(typeid(T)), value(std::forward<U>(value_)) {}
|
| 114 |
+
std::unique_ptr<Placeholder> clone() const override {
|
| 115 |
+
return std::make_unique<Holder<T>>(value);
|
| 116 |
+
}
|
| 117 |
+
T value;
|
| 118 |
+
};
|
| 119 |
+
|
| 120 |
+
/// The type erased object.
|
| 121 |
+
std::unique_ptr<Placeholder> content_;
|
| 122 |
+
};
|
| 123 |
+
|
| 124 |
+
} // namespace nn
|
| 125 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/functional.h
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/csrc/Export.h>
|
| 4 |
+
#include <torch/csrc/utils/variadic.h>
|
| 5 |
+
#include <torch/nn/cloneable.h>
|
| 6 |
+
#include <torch/nn/pimpl.h>
|
| 7 |
+
#include <torch/types.h>
|
| 8 |
+
|
| 9 |
+
#include <functional>
|
| 10 |
+
#include <utility>
|
| 11 |
+
|
| 12 |
+
namespace torch {
|
| 13 |
+
namespace nn {
|
| 14 |
+
|
| 15 |
+
/// Wraps a function in a `Module`.
|
| 16 |
+
///
|
| 17 |
+
/// The `Functional` module allows wrapping an arbitrary function or function
|
| 18 |
+
/// object in an `nn::Module`. This is primarily handy for usage in
|
| 19 |
+
/// `Sequential`.
|
| 20 |
+
///
|
| 21 |
+
/// \rst
|
| 22 |
+
/// .. code-block:: cpp
|
| 23 |
+
///
|
| 24 |
+
/// Sequential sequential(
|
| 25 |
+
/// Linear(3, 4),
|
| 26 |
+
/// Functional(torch::relu),
|
| 27 |
+
/// BatchNorm1d(3),
|
| 28 |
+
/// Functional(torch::elu, /*alpha=*/1));
|
| 29 |
+
/// \endrst
|
| 30 |
+
///
|
| 31 |
+
/// While a `Functional` module only accepts a single `Tensor` as input, it is
|
| 32 |
+
/// possible for the wrapped function to accept further arguments. However,
|
| 33 |
+
/// these have to be bound *at construction time*. For example, if
|
| 34 |
+
/// you want to wrap `torch::leaky_relu`, which accepts a `slope` scalar as its
|
| 35 |
+
/// second argument, with a particular value for its `slope` in a `Functional`
|
| 36 |
+
/// module, you could write
|
| 37 |
+
///
|
| 38 |
+
/// \rst
|
| 39 |
+
/// .. code-block:: cpp
|
| 40 |
+
///
|
| 41 |
+
/// Functional(torch::leaky_relu, /*slope=*/0.5)
|
| 42 |
+
/// \endrst
|
| 43 |
+
///
|
| 44 |
+
/// The value of `0.5` is then stored within the `Functional` object and
|
| 45 |
+
/// supplied to the function call at invocation time. Note that such bound
|
| 46 |
+
/// values are evaluated eagerly and stored a single time. See the documentation
|
| 47 |
+
/// of [std::bind](https://en.cppreference.com/w/cpp/utility/functional/bind)
|
| 48 |
+
/// for more information on the semantics of argument binding.
|
| 49 |
+
///
|
| 50 |
+
/// \rst
|
| 51 |
+
/// .. attention::
|
| 52 |
+
/// After passing any bound arguments, the function must accept a single
|
| 53 |
+
/// tensor and return a single tensor.
|
| 54 |
+
/// \endrst
|
| 55 |
+
///
|
| 56 |
+
/// Note that `Functional` overloads the call operator (`operator()`) such that
|
| 57 |
+
/// you can invoke it with `my_func(...)`.
|
| 58 |
+
class TORCH_API FunctionalImpl : public torch::nn::Cloneable<FunctionalImpl> {
|
| 59 |
+
public:
|
| 60 |
+
using Function = std::function<Tensor(Tensor)>;
|
| 61 |
+
|
| 62 |
+
/// Constructs a `Functional` from a function object.
|
| 63 |
+
explicit FunctionalImpl(Function function);
|
| 64 |
+
|
| 65 |
+
template <
|
| 66 |
+
typename SomeFunction,
|
| 67 |
+
typename... Args,
|
| 68 |
+
typename = std::enable_if_t<(sizeof...(Args) > 0)>>
|
| 69 |
+
explicit FunctionalImpl(SomeFunction original_function, Args&&... args)
|
| 70 |
+
// NOLINTNEXTLINE(modernize-avoid-bind)
|
| 71 |
+
: function_(std::bind(
|
| 72 |
+
original_function,
|
| 73 |
+
/*input=*/std::placeholders::_1,
|
| 74 |
+
std::forward<Args>(args)...)) {
|
| 75 |
+
// std::bind is normally evil, but (1) gcc is broken w.r.t. handling
|
| 76 |
+
// parameter pack expansion in lambdas and (2) moving parameter packs into
|
| 77 |
+
// a lambda only works with C++14, so std::bind is the more move-aware
|
| 78 |
+
// solution here.
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
void reset() override;
|
| 82 |
+
|
| 83 |
+
/// Pretty prints the `Functional` module into the given `stream`.
|
| 84 |
+
void pretty_print(std::ostream& stream) const override;
|
| 85 |
+
|
| 86 |
+
/// Forwards the `input` tensor to the underlying (bound) function object.
|
| 87 |
+
Tensor forward(Tensor input);
|
| 88 |
+
|
| 89 |
+
/// Calls forward(input).
|
| 90 |
+
Tensor operator()(Tensor input);
|
| 91 |
+
|
| 92 |
+
bool is_serializable() const override;
|
| 93 |
+
|
| 94 |
+
private:
|
| 95 |
+
Function function_;
|
| 96 |
+
};
|
| 97 |
+
|
| 98 |
+
/// A `ModuleHolder` subclass for `FunctionalImpl`.
|
| 99 |
+
/// See the documentation for `FunctionalImpl` class to learn what methods it
|
| 100 |
+
/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
|
| 101 |
+
/// module storage semantics.
|
| 102 |
+
TORCH_MODULE(Functional);
|
| 103 |
+
|
| 104 |
+
} // namespace nn
|
| 105 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/moduledict.h
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/nn/cloneable.h>
|
| 4 |
+
#include <torch/nn/module.h>
|
| 5 |
+
#include <torch/ordered_dict.h>
|
| 6 |
+
#include <vector>
|
| 7 |
+
|
| 8 |
+
namespace torch {
|
| 9 |
+
namespace nn {
|
| 10 |
+
|
| 11 |
+
/// An OrderedDict of `Module`s that registers its elements by their `key`s.
|
| 12 |
+
///
|
| 13 |
+
/// \rst
|
| 14 |
+
/// .. code-block:: cpp
|
| 15 |
+
///
|
| 16 |
+
/// torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
|
| 17 |
+
/// {"linear", Linear(10, 3).ptr()},
|
| 18 |
+
/// {"conv", Conv2d(1, 2, 3).ptr()},
|
| 19 |
+
/// {"dropout", Dropout(0.5).ptr()},
|
| 20 |
+
/// };
|
| 21 |
+
/// torch::nn::ModuleDict dict1(ordereddict);
|
| 22 |
+
///
|
| 23 |
+
/// for (const auto &module : *dict1) {
|
| 24 |
+
/// module->pretty_print(std::cout);
|
| 25 |
+
/// }
|
| 26 |
+
///
|
| 27 |
+
/// std::vector<std::pair<std::string, std::shared_ptr<Module>>> list = {
|
| 28 |
+
/// {"linear", Linear(10, 3).ptr()},
|
| 29 |
+
/// {"conv", Conv2d(1, 2, 3).ptr()},
|
| 30 |
+
/// {"dropout", Dropout(0.5).ptr()},
|
| 31 |
+
/// };
|
| 32 |
+
/// torch::nn::ModuleDict dict2(list);
|
| 33 |
+
///
|
| 34 |
+
/// for (const auto &module : *dict2) {
|
| 35 |
+
/// module->pretty_print(std::cout);
|
| 36 |
+
/// }
|
| 37 |
+
///
|
| 38 |
+
/// \endrst
|
| 39 |
+
///
|
| 40 |
+
/// Why should you use `ModuleDict` instead of a simple `map` or `OrderedDict`?
|
| 41 |
+
/// The value a `ModuleDict` provides over manually calling an ordered map of
|
| 42 |
+
/// modules is that it allows treating the whole container *as a single module*,
|
| 43 |
+
/// such that performing a transformation on the `ModuleDict` applies to each of
|
| 44 |
+
/// the modules it stores (which are each a registered submodule of the
|
| 45 |
+
/// `ModuleDict`). For example, calling `.to(torch::kCUDA)` on a `ModuleDict`
|
| 46 |
+
/// will move each module in the map to CUDA memory. For example:
|
| 47 |
+
///
|
| 48 |
+
/// \rst
|
| 49 |
+
/// .. code-block:: cpp
|
| 50 |
+
///
|
| 51 |
+
/// torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
|
| 52 |
+
/// {"linear", Linear(10, 3).ptr()},
|
| 53 |
+
/// {"conv", Conv2d(1, 2, 3).ptr()},
|
| 54 |
+
/// {"dropout", Dropout(0.5).ptr()},
|
| 55 |
+
/// };
|
| 56 |
+
/// torch::nn::ModuleDict dict(ordereddict);
|
| 57 |
+
///
|
| 58 |
+
/// // Convert all modules to CUDA.
|
| 59 |
+
/// dict->to(torch::kCUDA);
|
| 60 |
+
///
|
| 61 |
+
/// \endrst
|
| 62 |
+
///
|
| 63 |
+
/// Finally, `ModuleDict` provides a lightweight container API, such as allowing
|
| 64 |
+
/// iteration over submodules, positional access, adding new modules from a
|
| 65 |
+
/// vector of key-module pairs or an `OrderedDict` or another `ModuleDict` after
|
| 66 |
+
/// construction via `update`.
|
| 67 |
+
class ModuleDictImpl : public Cloneable<ModuleDictImpl> {
|
| 68 |
+
public:
|
| 69 |
+
using Iterator =
|
| 70 |
+
torch::OrderedDict<std::string, std::shared_ptr<Module>>::Iterator;
|
| 71 |
+
using ConstIterator =
|
| 72 |
+
torch::OrderedDict<std::string, std::shared_ptr<Module>>::ConstIterator;
|
| 73 |
+
|
| 74 |
+
ModuleDictImpl() = default;
|
| 75 |
+
|
| 76 |
+
/// Constructs the `ModuleDict` from a list of string-Module pairs.
|
| 77 |
+
explicit ModuleDictImpl(
|
| 78 |
+
const std::vector<std::pair<std::string, std::shared_ptr<Module>>>&
|
| 79 |
+
modules) {
|
| 80 |
+
update(modules);
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
/// Constructs the `ModuleDict` from an `OrderedDict`.
|
| 84 |
+
explicit ModuleDictImpl(
|
| 85 |
+
const torch::OrderedDict<std::string, std::shared_ptr<Module>>& modules) {
|
| 86 |
+
update(modules);
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
/// Return the items in the `ModuleDict`.
|
| 90 |
+
std::vector<std::pair<std::string, std::shared_ptr<Module>>> items() const {
|
| 91 |
+
return modules_.pairs();
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
/// Return the keys in the `ModuleDict`.
|
| 95 |
+
std::vector<std::string> keys() const {
|
| 96 |
+
return modules_.keys();
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
/// Return the values in the `ModuleDict`.
|
| 100 |
+
std::vector<std::shared_ptr<Module>> values() const {
|
| 101 |
+
return modules_.values();
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
/// Return an iterator to the start of `ModuleDict`.
|
| 105 |
+
Iterator begin() {
|
| 106 |
+
return modules_.begin();
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
/// Return a const iterator to the start of `ModuleDict`.
|
| 110 |
+
ConstIterator begin() const {
|
| 111 |
+
return modules_.begin();
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
/// Return an iterator to the end of `ModuleDict`.
|
| 115 |
+
Iterator end() {
|
| 116 |
+
return modules_.end();
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
/// Return a const iterator to the end of `ModuleDict`.
|
| 120 |
+
ConstIterator end() const {
|
| 121 |
+
return modules_.end();
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
/// Return the number of items currently stored in the `ModuleDict`.
|
| 125 |
+
size_t size() const noexcept {
|
| 126 |
+
return modules_.size();
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
/// Return true if the `ModuleDict` is empty, otherwise return false.
|
| 130 |
+
bool empty() const noexcept {
|
| 131 |
+
return modules_.is_empty();
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
/// Check if the centain parameter with the key in the `ModuleDict`.
|
| 135 |
+
bool contains(const std::string& key) const noexcept {
|
| 136 |
+
return modules_.contains(key);
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
/// Remove all items from the `ModuleDict`.
|
| 140 |
+
void clear() {
|
| 141 |
+
// Not remove the registration of modules to make it consistent with python
|
| 142 |
+
// version.
|
| 143 |
+
modules_.clear();
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
/// Special cloning function for `ModuleDict` because it does not use
|
| 147 |
+
/// `reset()`.
|
| 148 |
+
std::shared_ptr<Module> clone(
|
| 149 |
+
const std::optional<Device>& device = std::nullopt) const override {
|
| 150 |
+
auto clone = std::make_shared<ModuleDictImpl>();
|
| 151 |
+
for (const auto& module : modules_) {
|
| 152 |
+
clone->insert(module.key(), module.value()->clone(device));
|
| 153 |
+
}
|
| 154 |
+
return clone;
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
/// `reset()` is empty for `ModuleDict`, since it does not have parameters of
|
| 158 |
+
/// its own.
|
| 159 |
+
void reset() override {}
|
| 160 |
+
|
| 161 |
+
/// Pretty prints the `ModuleDict` into the given `stream`.
|
| 162 |
+
void pretty_print(std::ostream& stream) const override {
|
| 163 |
+
stream << "torch::nn::ModuleDict";
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
/// Attempts to returns the `Module` associated with the given `key`. Throws
|
| 167 |
+
/// an exception if no such `key` is stored in the `ModuleDict`. Check
|
| 168 |
+
/// contains(key) before for a non-throwing way of access.
|
| 169 |
+
std::shared_ptr<Module> operator[](const std::string& key) const {
|
| 170 |
+
return modules_[key];
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
/// Attempts to return the module at the given key as the requested type.
|
| 174 |
+
/// Throws an exception if no such `key` is stored in the `ModuleDict`.
|
| 175 |
+
/// Check contains(key) before for a non-throwing way of access.
|
| 176 |
+
template <typename T>
|
| 177 |
+
T& at(const std::string& key) {
|
| 178 |
+
static_assert(
|
| 179 |
+
torch::detail::is_module<T>::value,
|
| 180 |
+
"Can only call ModuleList::at with an nn::Module type");
|
| 181 |
+
auto module = modules_[key]->as<T>();
|
| 182 |
+
TORCH_CHECK(
|
| 183 |
+
module,
|
| 184 |
+
"Unable to cast module[",
|
| 185 |
+
key,
|
| 186 |
+
"] to ",
|
| 187 |
+
c10::demangle(typeid(T).name()));
|
| 188 |
+
return *module;
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
/// Attempts to return the module at the given key as the requested type.
|
| 192 |
+
/// Throws an exception if no such `key` is stored in the `ModuleDict`.
|
| 193 |
+
/// Check contains(key) before for a non-throwing way of access.
|
| 194 |
+
template <typename T>
|
| 195 |
+
const T& at(const std::string& key) const {
|
| 196 |
+
static_assert(
|
| 197 |
+
torch::detail::is_module<T>::value,
|
| 198 |
+
"Can only call ModuleList::at with an nn::Module type");
|
| 199 |
+
const auto module = modules_[key]->as<T>();
|
| 200 |
+
TORCH_CHECK(
|
| 201 |
+
module,
|
| 202 |
+
"Unable to cast module[",
|
| 203 |
+
key,
|
| 204 |
+
"] to ",
|
| 205 |
+
c10::demangle(typeid(T).name()));
|
| 206 |
+
return *module;
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
/// Removes and returns the `Module` associated with the given `key`.
|
| 210 |
+
/// Throws an exception if no such `key` is stored in the `ModuleDict`.
|
| 211 |
+
/// Check contains(key) before for a non-throwing way of access.
|
| 212 |
+
std::shared_ptr<Module> pop(const std::string& key) {
|
| 213 |
+
auto module = modules_[key];
|
| 214 |
+
modules_.erase(key);
|
| 215 |
+
// Not remove the registration of the module to make it consistent with
|
| 216 |
+
// python version.
|
| 217 |
+
return module;
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
/// Updated the `ModuleDict` with a vector of key-module pairs.
|
| 221 |
+
void update(
|
| 222 |
+
const std::vector<std::pair<std::string, std::shared_ptr<Module>>>&
|
| 223 |
+
modules) {
|
| 224 |
+
for (auto& item : modules) {
|
| 225 |
+
insert(item.first, item.second);
|
| 226 |
+
}
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
/// Updated the `ModuleDict` with key-value pairs from `OrderedDict` or
|
| 230 |
+
/// `ModuleDict`.
|
| 231 |
+
template <typename Container>
|
| 232 |
+
void update(const Container& container) {
|
| 233 |
+
for (auto& item : container) {
|
| 234 |
+
insert(item.key(), item.value());
|
| 235 |
+
}
|
| 236 |
+
}
|
| 237 |
+
|
| 238 |
+
private:
|
| 239 |
+
/// Private `OrderedDict` holding the key-Module pairs.
|
| 240 |
+
torch::OrderedDict<std::string, std::shared_ptr<Module>> modules_;
|
| 241 |
+
|
| 242 |
+
/// Insert a key-module pair by overwriting existing keys,
|
| 243 |
+
/// and register or replace the `Module`.
|
| 244 |
+
void insert(const std::string& key, std::shared_ptr<Module> module) {
|
| 245 |
+
if (contains(key)) {
|
| 246 |
+
modules_[key] = std::move(module);
|
| 247 |
+
replace_module(key, modules_[key]);
|
| 248 |
+
} else {
|
| 249 |
+
modules_.insert(key, std::move(module));
|
| 250 |
+
register_module(key, modules_.back().value());
|
| 251 |
+
}
|
| 252 |
+
}
|
| 253 |
+
};
|
| 254 |
+
|
| 255 |
+
/// A `ModuleHolder` subclass for `ModuleDictImpl`.
|
| 256 |
+
/// See the documentation for `ModuleDictImpl` class to learn what methods it
|
| 257 |
+
/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
|
| 258 |
+
/// module storage semantics.
|
| 259 |
+
TORCH_MODULE(ModuleDict);
|
| 260 |
+
|
| 261 |
+
} // namespace nn
|
| 262 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/modulelist.h
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/util/irange.h>
|
| 4 |
+
#include <torch/nn/cloneable.h>
|
| 5 |
+
#include <torch/nn/module.h>
|
| 6 |
+
|
| 7 |
+
#include <utility>
|
| 8 |
+
#include <vector>
|
| 9 |
+
|
| 10 |
+
namespace torch {
|
| 11 |
+
namespace nn {
|
| 12 |
+
|
| 13 |
+
/// A list of `Module`s that registers its elements.
|
| 14 |
+
///
|
| 15 |
+
/// \rst
|
| 16 |
+
/// .. code-block:: cpp
|
| 17 |
+
///
|
| 18 |
+
/// torch::nn::ModuleList mlist(
|
| 19 |
+
/// torch::nn::Linear(3, 4),
|
| 20 |
+
/// torch::nn::BatchNorm1d(4),
|
| 21 |
+
/// torch::nn::Dropout(0.5)
|
| 22 |
+
/// );
|
| 23 |
+
///
|
| 24 |
+
/// for (const auto &module : *mlist) {
|
| 25 |
+
/// module->pretty_print(std::cout);
|
| 26 |
+
/// }
|
| 27 |
+
///
|
| 28 |
+
/// \endrst
|
| 29 |
+
///
|
| 30 |
+
/// Why should you use `ModuleList` instead of a simple `std::vector`? The value
|
| 31 |
+
/// a `ModuleList` provides over manually calling a sequence of modules is that
|
| 32 |
+
/// it allows treating the whole container *as a single module*, such that
|
| 33 |
+
/// performing a transformation on the `ModuleList` applies to each of the
|
| 34 |
+
/// modules it stores (which are each a registered submodule of the
|
| 35 |
+
/// `ModuleList`). For example, calling
|
| 36 |
+
/// `.to(torch::kCUDA)` on a `ModuleList` will move each module in the list to
|
| 37 |
+
/// CUDA memory. For example:
|
| 38 |
+
///
|
| 39 |
+
/// \rst
|
| 40 |
+
/// .. code-block:: cpp
|
| 41 |
+
///
|
| 42 |
+
/// torch::nn::ModuleList mlist(
|
| 43 |
+
/// torch::nn::Linear(3, 4),
|
| 44 |
+
/// torch::nn::BatchNorm1d(4),
|
| 45 |
+
/// torch::nn::Dropout(0.5)
|
| 46 |
+
/// );
|
| 47 |
+
///
|
| 48 |
+
/// // Convert all modules to CUDA.
|
| 49 |
+
/// mlist->to(torch::kCUDA);
|
| 50 |
+
///
|
| 51 |
+
/// \endrst
|
| 52 |
+
///
|
| 53 |
+
/// Finally, `ModuleList` provides a lightweight container API, such as allowing
|
| 54 |
+
/// iteration over submodules, positional access, adding a new module after
|
| 55 |
+
/// construction via `push_back`, as well as joining two `ModuleList`s via
|
| 56 |
+
/// `extend`.
|
| 57 |
+
class ModuleListImpl : public Cloneable<ModuleListImpl> {
|
| 58 |
+
public:
|
| 59 |
+
using Iterator = std::vector<std::shared_ptr<Module>>::iterator;
|
| 60 |
+
using ConstIterator = std::vector<std::shared_ptr<Module>>::const_iterator;
|
| 61 |
+
|
| 62 |
+
ModuleListImpl() = default;
|
| 63 |
+
|
| 64 |
+
/// Constructs the `ModuleList` from a variadic list of modules.
|
| 65 |
+
template <typename... Modules>
|
| 66 |
+
explicit ModuleListImpl(Modules&&... modules) {
|
| 67 |
+
modules_.reserve(sizeof...(Modules));
|
| 68 |
+
push_back_var(std::forward<Modules>(modules)...);
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
/// Special cloning function for `ModuleList` because it does not use
|
| 72 |
+
/// `reset()`.
|
| 73 |
+
std::shared_ptr<Module> clone(
|
| 74 |
+
const std::optional<Device>& device = std::nullopt) const override {
|
| 75 |
+
auto clone = std::make_shared<ModuleListImpl>();
|
| 76 |
+
for (const auto& module : modules_) {
|
| 77 |
+
clone->push_back(module->clone(device));
|
| 78 |
+
}
|
| 79 |
+
return clone;
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
/// `reset()` is empty for `ModuleList`, since it does not have parameters of
|
| 83 |
+
/// its own.
|
| 84 |
+
void reset() override {}
|
| 85 |
+
|
| 86 |
+
/// Pretty prints the `ModuleList` module into the given `stream`.
|
| 87 |
+
void pretty_print(std::ostream& stream) const override {
|
| 88 |
+
stream << "torch::nn::ModuleList";
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
void push_back(std::shared_ptr<Module> module) {
|
| 92 |
+
modules_.push_back(std::move(module));
|
| 93 |
+
const auto index = modules_.size() - 1;
|
| 94 |
+
register_module(std::to_string(index), modules_[index]);
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
/// Adds a new `Module` to the `ModuleList` container, moving or copying
|
| 98 |
+
/// it into a `shared_ptr` internally. This method allows passing value types,
|
| 99 |
+
/// and letting the container deal with the boxing.
|
| 100 |
+
template <typename M, typename = torch::detail::enable_if_module_t<M>>
|
| 101 |
+
void push_back(M&& module) {
|
| 102 |
+
using Type = typename std::remove_reference<M>::type;
|
| 103 |
+
push_back(std::make_shared<Type>(std::forward<M>(module)));
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
/// Unwraps the contained module of a `ModuleHolder` and adds it to the
|
| 107 |
+
/// `ModuleList`.
|
| 108 |
+
template <typename M>
|
| 109 |
+
void push_back(const ModuleHolder<M>& module_holder) {
|
| 110 |
+
push_back(module_holder.ptr());
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
/// Iterates over the container and calls `push_back()` on each value.
|
| 114 |
+
template <typename Container>
|
| 115 |
+
void extend(const Container& container) {
|
| 116 |
+
for (const auto& module : container) {
|
| 117 |
+
push_back(module);
|
| 118 |
+
}
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
/// Returns an iterator to the start of the `ModuleList`.
|
| 122 |
+
Iterator begin() {
|
| 123 |
+
return modules_.begin();
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
/// Returns a const iterator to the start of the `ModuleList`.
|
| 127 |
+
ConstIterator begin() const {
|
| 128 |
+
return modules_.begin();
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
/// Returns an iterator to the end of the `ModuleList`.
|
| 132 |
+
Iterator end() {
|
| 133 |
+
return modules_.end();
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
/// Returns a const iterator to the end of the `ModuleList`.
|
| 137 |
+
ConstIterator end() const {
|
| 138 |
+
return modules_.end();
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
/// Attempts to return the module at the given index as the requested type.
|
| 142 |
+
/// Throws an exception if the index is out of bounds or the types do not
|
| 143 |
+
/// match.
|
| 144 |
+
template <typename T>
|
| 145 |
+
T& at(size_t index) {
|
| 146 |
+
static_assert(
|
| 147 |
+
torch::detail::is_module<T>::value,
|
| 148 |
+
"Can only call ModuleList::at with an nn::Module type");
|
| 149 |
+
TORCH_CHECK(index < size(), "Index out of range");
|
| 150 |
+
auto module = modules_[index]->as<T>();
|
| 151 |
+
TORCH_CHECK(
|
| 152 |
+
module,
|
| 153 |
+
"Unable to cast module[",
|
| 154 |
+
index,
|
| 155 |
+
"] to ",
|
| 156 |
+
c10::demangle(typeid(T).name()));
|
| 157 |
+
return *module;
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
/// Attempts to return the module at the given index as the requested type.
|
| 161 |
+
/// Throws an exception if the index is out of bounds or the types do not
|
| 162 |
+
/// match.
|
| 163 |
+
template <typename T>
|
| 164 |
+
const T& at(size_t index) const {
|
| 165 |
+
static_assert(
|
| 166 |
+
torch::detail::is_module<T>::value,
|
| 167 |
+
"Can only call ModuleList::at with an nn::Module type");
|
| 168 |
+
TORCH_CHECK(index < size(), "Index out of range");
|
| 169 |
+
const auto module = modules_[index]->as<T>();
|
| 170 |
+
TORCH_CHECK(
|
| 171 |
+
module,
|
| 172 |
+
"Unable to cast module[",
|
| 173 |
+
index,
|
| 174 |
+
"] to ",
|
| 175 |
+
c10::demangle(typeid(T).name()));
|
| 176 |
+
return *module;
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
/// Attempts to return a `std::shared_ptr` whose dynamic type is that of the
|
| 180 |
+
/// underlying module at the given index. Throws an exception if the index is
|
| 181 |
+
/// out of bounds.
|
| 182 |
+
std::shared_ptr<Module> ptr(size_t index) const {
|
| 183 |
+
TORCH_CHECK(index < size(), "Index out of range");
|
| 184 |
+
return modules_[index];
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
/// Attempts to return a `std::shared_ptr` whose type is the one provided.
|
| 188 |
+
/// Throws an exception if the index is out of bounds or the types do not
|
| 189 |
+
/// match.
|
| 190 |
+
template <typename T>
|
| 191 |
+
std::shared_ptr<T> ptr(size_t index) const {
|
| 192 |
+
static_assert(
|
| 193 |
+
torch::detail::is_module<T>::value,
|
| 194 |
+
"Can only call ModuleList::ptr with an nn::Module type");
|
| 195 |
+
TORCH_CHECK(index < size(), "Index out of range");
|
| 196 |
+
return std::dynamic_pointer_cast<T>(modules_[index]);
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
/// Like `ptr(index)`.
|
| 200 |
+
std::shared_ptr<Module> operator[](size_t index) const {
|
| 201 |
+
// This is the only method we can call without a type.
|
| 202 |
+
return ptr(index);
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
/// The current size of the `ModuleList` container.
|
| 206 |
+
size_t size() const noexcept {
|
| 207 |
+
return modules_.size();
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
/// True if there are no modules in the `ModuleList`.
|
| 211 |
+
bool is_empty() const noexcept {
|
| 212 |
+
return size() == 0;
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
void insert(size_t index, std::shared_ptr<Module> module) {
|
| 216 |
+
TORCH_CHECK(index <= size(), "Index out of range");
|
| 217 |
+
|
| 218 |
+
if (index == size())
|
| 219 |
+
push_back(std::move(module));
|
| 220 |
+
else {
|
| 221 |
+
modules_.insert(
|
| 222 |
+
modules_.begin() + Iterator::difference_type(index),
|
| 223 |
+
std::move(module));
|
| 224 |
+
|
| 225 |
+
for (const auto i : c10::irange(index, size() - 1)) {
|
| 226 |
+
(void)i; // Suppress unused variable warning
|
| 227 |
+
replace_module(std::to_string(index), modules_[index]);
|
| 228 |
+
}
|
| 229 |
+
register_module(std::to_string(size() - 1), modules_.back());
|
| 230 |
+
}
|
| 231 |
+
}
|
| 232 |
+
|
| 233 |
+
/// Unwraps the contained module of a `ModuleHolder` and inserts it in the
|
| 234 |
+
/// `ModuleList`.
|
| 235 |
+
template <typename M>
|
| 236 |
+
void insert(size_t index, const ModuleHolder<M>& module_holder) {
|
| 237 |
+
insert(index, module_holder.ptr());
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
/// inserts a new `Module` to the `ModuleList` container, moving or copying
|
| 241 |
+
/// it into a `shared_ptr` internally. This method allows passing value types,
|
| 242 |
+
/// and letting the container deal with the boxing.
|
| 243 |
+
template <typename M, typename = torch::detail::enable_if_module_t<M>>
|
| 244 |
+
void insert(size_t index, M&& module) {
|
| 245 |
+
using Type = typename std::remove_reference<M>::type;
|
| 246 |
+
insert(index, std::make_shared<Type>(std::forward<M>(module)));
|
| 247 |
+
}
|
| 248 |
+
|
| 249 |
+
private:
|
| 250 |
+
template <typename Head, typename... Tail>
|
| 251 |
+
void push_back_var(Head&& head, Tail&&... tail) {
|
| 252 |
+
push_back(std::forward<Head>(head));
|
| 253 |
+
// Recursively calls this method, until the parameter pack only thas this
|
| 254 |
+
// entry left. Then calls `push_back()` a final time (above).
|
| 255 |
+
push_back_var(std::forward<Tail>(tail)...);
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
/// The base case, when the list of modules is empty.
|
| 259 |
+
void push_back_var() {}
|
| 260 |
+
|
| 261 |
+
// Box the AnyModules to give ModuleList reference semantics, like the rest of
|
| 262 |
+
// the API. Note that this is not required otherwise, this could just be a
|
| 263 |
+
// `vector<AnyModule>`.
|
| 264 |
+
std::vector<std::shared_ptr<Module>> modules_;
|
| 265 |
+
};
|
| 266 |
+
|
| 267 |
+
/// A `ModuleHolder` subclass for `ModuleListImpl`.
|
| 268 |
+
/// See the documentation for `ModuleListImpl` class to learn what methods it
|
| 269 |
+
/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
|
| 270 |
+
/// module storage semantics.
|
| 271 |
+
TORCH_MODULE(ModuleList);
|
| 272 |
+
|
| 273 |
+
} // namespace nn
|
| 274 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/named_any.h
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/detail/static.h>
|
| 4 |
+
#include <torch/nn/module.h>
|
| 5 |
+
#include <torch/nn/modules/container/any.h>
|
| 6 |
+
#include <torch/nn/pimpl.h>
|
| 7 |
+
#include <torch/types.h>
|
| 8 |
+
|
| 9 |
+
#include <torch/csrc/autograd/variable.h>
|
| 10 |
+
#include <torch/csrc/utils/variadic.h>
|
| 11 |
+
|
| 12 |
+
#include <ATen/Device.h>
|
| 13 |
+
|
| 14 |
+
#include <initializer_list>
|
| 15 |
+
#include <memory>
|
| 16 |
+
#include <type_traits>
|
| 17 |
+
#include <typeinfo>
|
| 18 |
+
#include <utility>
|
| 19 |
+
#include <vector>
|
| 20 |
+
|
| 21 |
+
namespace torch {
|
| 22 |
+
namespace nn {
|
| 23 |
+
|
| 24 |
+
/// Stores a type erased `Module` with name.
|
| 25 |
+
///
|
| 26 |
+
/// The `NamedAnyModule` class enables the following API for constructing
|
| 27 |
+
/// `nn::Sequential` with named submodules:
|
| 28 |
+
/// \rst
|
| 29 |
+
/// .. code-block:: cpp
|
| 30 |
+
///
|
| 31 |
+
/// struct M : torch::nn::Module {
|
| 32 |
+
/// explicit M(int value_) : value(value_) {}
|
| 33 |
+
/// int value;
|
| 34 |
+
/// int forward() {
|
| 35 |
+
/// return value;
|
| 36 |
+
/// }
|
| 37 |
+
/// };
|
| 38 |
+
///
|
| 39 |
+
/// Sequential sequential({
|
| 40 |
+
/// {"m1", std::make_shared<M>(1)}, // shared pointer to `Module` is
|
| 41 |
+
/// supported {std::string("m2"), M(2)}, // `Module` is supported
|
| 42 |
+
/// {"linear1", Linear(10, 3)} // `ModuleHolder` is supported
|
| 43 |
+
/// });
|
| 44 |
+
/// \endrst
|
| 45 |
+
class NamedAnyModule {
|
| 46 |
+
public:
|
| 47 |
+
/// Creates a `NamedAnyModule` from a (boxed) `Module`.
|
| 48 |
+
template <typename ModuleType>
|
| 49 |
+
NamedAnyModule(std::string name, std::shared_ptr<ModuleType> module_ptr)
|
| 50 |
+
: NamedAnyModule(std::move(name), AnyModule(std::move(module_ptr))) {}
|
| 51 |
+
|
| 52 |
+
/// Creates a `NamedAnyModule` from a `Module`, moving or copying it
|
| 53 |
+
/// into a `shared_ptr` internally.
|
| 54 |
+
// NOTE: We need to use `std::remove_reference<M>::type` to get rid of
|
| 55 |
+
// any reference components for make_unique.
|
| 56 |
+
template <typename M, typename = torch::detail::enable_if_module_t<M>>
|
| 57 |
+
NamedAnyModule(std::string name, M&& module)
|
| 58 |
+
: NamedAnyModule(
|
| 59 |
+
std::move(name),
|
| 60 |
+
std::make_shared<typename std::remove_reference<M>::type>(
|
| 61 |
+
std::forward<M>(module))) {}
|
| 62 |
+
|
| 63 |
+
/// Creates a `NamedAnyModule` from a `Module` that is unwrapped from
|
| 64 |
+
/// a `ModuleHolder`.
|
| 65 |
+
template <typename M>
|
| 66 |
+
NamedAnyModule(std::string name, const ModuleHolder<M>& module_holder)
|
| 67 |
+
: NamedAnyModule(std::move(name), module_holder.ptr()) {}
|
| 68 |
+
|
| 69 |
+
/// Creates a `NamedAnyModule` from a type-erased `AnyModule`.
|
| 70 |
+
NamedAnyModule(std::string name, AnyModule any_module)
|
| 71 |
+
: name_(std::move(name)), module_(std::move(any_module)) {}
|
| 72 |
+
|
| 73 |
+
/// Returns a reference to the name.
|
| 74 |
+
const std::string& name() const noexcept {
|
| 75 |
+
return name_;
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
/// Returns a reference to the module.
|
| 79 |
+
AnyModule& module() noexcept {
|
| 80 |
+
return module_;
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
/// Returns a const reference to the module.
|
| 84 |
+
const AnyModule& module() const noexcept {
|
| 85 |
+
return module_;
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
private:
|
| 89 |
+
std::string name_;
|
| 90 |
+
AnyModule module_;
|
| 91 |
+
};
|
| 92 |
+
|
| 93 |
+
} // namespace nn
|
| 94 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/parameterdict.h
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/nn/cloneable.h>
|
| 4 |
+
#include <torch/nn/pimpl.h>
|
| 5 |
+
#include <torch/ordered_dict.h>
|
| 6 |
+
#include <utility>
|
| 7 |
+
#include <vector>
|
| 8 |
+
|
| 9 |
+
namespace torch {
|
| 10 |
+
namespace nn {
|
| 11 |
+
|
| 12 |
+
class ParameterDictImpl : public Cloneable<ParameterDictImpl> {
|
| 13 |
+
public:
|
| 14 |
+
using Iterator = OrderedDict<std::string, Tensor>::Iterator;
|
| 15 |
+
using ConstIterator = OrderedDict<std::string, Tensor>::ConstIterator;
|
| 16 |
+
|
| 17 |
+
ParameterDictImpl() = default;
|
| 18 |
+
|
| 19 |
+
explicit ParameterDictImpl(
|
| 20 |
+
const torch::OrderedDict<std::string, torch::Tensor>& params) {
|
| 21 |
+
parameters_ = params;
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
/// `reset()` is empty for `ParameterDict`, since it does not have
|
| 25 |
+
/// parameters of its own.
|
| 26 |
+
void reset() override {}
|
| 27 |
+
|
| 28 |
+
/// Pretty prints the `ParameterDict` module into the given `stream`.
|
| 29 |
+
void pretty_print(std::ostream& stream) const override {
|
| 30 |
+
stream << "torch::nn::ParameterDict(" << std::endl;
|
| 31 |
+
for (const auto& pair : parameters_) {
|
| 32 |
+
stream << "(" << pair.key() << ")"
|
| 33 |
+
<< ": Parameter containing: [" << pair.value().scalar_type()
|
| 34 |
+
<< " of size " << pair.value().sizes() << "]";
|
| 35 |
+
;
|
| 36 |
+
stream << std::endl;
|
| 37 |
+
}
|
| 38 |
+
stream << ")";
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
/// Insert the parameter along with the key into ParameterDict
|
| 42 |
+
/// The parameter is set to be require grad by default
|
| 43 |
+
Tensor& insert(std::string key, Tensor param) {
|
| 44 |
+
bool requires_grad = param.requires_grad();
|
| 45 |
+
return register_parameter(std::move(key), std::move(param), requires_grad);
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
/// Remove key from the ParameterDict and return its value, throw exception
|
| 49 |
+
/// if the key is not contained. Please check contains(key) before for a
|
| 50 |
+
/// non-throwing access.
|
| 51 |
+
Tensor pop(const std::string& key) {
|
| 52 |
+
torch::Tensor v = parameters_[key];
|
| 53 |
+
parameters_.erase(key);
|
| 54 |
+
return v;
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
/// Return the keys in the dict
|
| 58 |
+
::std::vector<std::string> keys() const {
|
| 59 |
+
return parameters_.keys();
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
/// Return the Values in the dict
|
| 63 |
+
::std::vector<torch::Tensor> values() const {
|
| 64 |
+
return parameters_.values();
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
/// Return an iterator to the start of ParameterDict
|
| 68 |
+
Iterator begin() {
|
| 69 |
+
return parameters_.begin();
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
/// Return a const iterator to the start of ParameterDict
|
| 73 |
+
ConstIterator begin() const {
|
| 74 |
+
return parameters_.begin();
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
/// Return an iterator to the end of ParameterDict
|
| 78 |
+
Iterator end() {
|
| 79 |
+
return parameters_.end();
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
/// Return a const iterator to the end of ParameterDict
|
| 83 |
+
ConstIterator end() const {
|
| 84 |
+
return parameters_.end();
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
/// Return the number of items currently stored in the ParameterDict
|
| 88 |
+
size_t size() const noexcept {
|
| 89 |
+
return parameters_.size();
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
/// Return true if the ParameterDict is empty, otherwise return false
|
| 93 |
+
bool empty() const noexcept {
|
| 94 |
+
return parameters_.is_empty();
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
/// Update the ParameterDict with the key-value pairs from
|
| 98 |
+
/// another ParameterDict, overwriting existing key
|
| 99 |
+
template <typename Container>
|
| 100 |
+
void update(const Container& container) {
|
| 101 |
+
for (auto& item : container) {
|
| 102 |
+
parameters_[item.key()] = item.value();
|
| 103 |
+
}
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
/// Remove all parameters in the ParameterDict
|
| 107 |
+
void clear() {
|
| 108 |
+
parameters_.clear();
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
/// Check if the centain parameter with the key in the ParameterDict
|
| 112 |
+
bool contains(const std::string& key) const noexcept {
|
| 113 |
+
return parameters_.contains(key);
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
/// Returns the value associated with the given `key`. Throws an exception if
|
| 117 |
+
/// no such key is stored in the `ParameterDict`. Check contains(key) before
|
| 118 |
+
/// for a non-throwing way of access
|
| 119 |
+
const Tensor& get(const std::string& key) const {
|
| 120 |
+
return parameters_[key];
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
/// Returns the value associated with the given `key`. Throws an exception if
|
| 124 |
+
/// no such key is stored in the `ParameterDict`. Check contains(key) before
|
| 125 |
+
/// for a non-throwing way of access
|
| 126 |
+
Tensor& get(const std::string& key) {
|
| 127 |
+
return parameters_[key];
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
/// Returns the value associated with the given `key`. Throws an exception if
|
| 131 |
+
/// no such key is stored in the `ParameterDict`. Check contains(key) before
|
| 132 |
+
/// for a non-throwing way of access
|
| 133 |
+
Tensor& operator[](const std::string& key) {
|
| 134 |
+
return parameters_[key];
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
/// Returns the value associated with the given `key`. Throws an exception if
|
| 138 |
+
/// no such key is stored in the `ParameterDict`. Check contains(key) before
|
| 139 |
+
/// for a non-throwing way of access
|
| 140 |
+
const Tensor& operator[](const std::string& key) const {
|
| 141 |
+
return parameters_[key];
|
| 142 |
+
}
|
| 143 |
+
};
|
| 144 |
+
|
| 145 |
+
TORCH_MODULE(ParameterDict);
|
| 146 |
+
|
| 147 |
+
} // namespace nn
|
| 148 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/parameterlist.h
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/nn/cloneable.h>
|
| 4 |
+
#include <torch/nn/module.h>
|
| 5 |
+
|
| 6 |
+
#include <vector>
|
| 7 |
+
|
| 8 |
+
namespace torch {
|
| 9 |
+
namespace nn {
|
| 10 |
+
class ParameterListImpl : public Cloneable<ParameterListImpl> {
|
| 11 |
+
public:
|
| 12 |
+
using Iterator = typename std::vector<
|
| 13 |
+
OrderedDict<std::string, torch::Tensor>::Item>::iterator;
|
| 14 |
+
using ConstIterator = typename std::vector<
|
| 15 |
+
OrderedDict<std::string, torch::Tensor>::Item>::const_iterator;
|
| 16 |
+
|
| 17 |
+
ParameterListImpl() = default;
|
| 18 |
+
|
| 19 |
+
/// Constructs the `ParameterList` from a variadic list of ParameterList.
|
| 20 |
+
template <typename... Tensors>
|
| 21 |
+
explicit ParameterListImpl(Tensors&&... params) {
|
| 22 |
+
parameters_.reserve(sizeof...(Tensors));
|
| 23 |
+
push_back_var(std::forward<Tensors>(params)...);
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
template <typename... Tensors>
|
| 27 |
+
explicit ParameterListImpl(const Tensors&... params) {
|
| 28 |
+
parameters_.reserve(sizeof...(Tensors));
|
| 29 |
+
push_back_var(std::forward<Tensors>(params)...);
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
/// `reset()` is empty for `ParameterList`, since it does not have parameters
|
| 33 |
+
/// of its own.
|
| 34 |
+
void reset() override {}
|
| 35 |
+
|
| 36 |
+
/// Pretty prints the `ParameterList` module into the given `stream`.
|
| 37 |
+
void pretty_print(std::ostream& stream) const override {
|
| 38 |
+
stream << "torch::nn::ParameterList(" << std::endl;
|
| 39 |
+
for (const auto& pair : parameters_) {
|
| 40 |
+
stream << "(" << pair.key() << ")"
|
| 41 |
+
<< ": Parameter containing: [" << pair.value().scalar_type()
|
| 42 |
+
<< " of size " << pair.value().sizes() << "]";
|
| 43 |
+
;
|
| 44 |
+
stream << std::endl;
|
| 45 |
+
}
|
| 46 |
+
stream << ")";
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
/// push the a given parameter at the end of the list
|
| 50 |
+
void append(torch::Tensor&& param) {
|
| 51 |
+
bool requires_grad = param.requires_grad();
|
| 52 |
+
register_parameter(
|
| 53 |
+
std::to_string(parameters_.size()), std::move(param), requires_grad);
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
/// push the a given parameter at the end of the list
|
| 57 |
+
void append(const torch::Tensor& param) {
|
| 58 |
+
bool requires_grad = param.requires_grad();
|
| 59 |
+
register_parameter(
|
| 60 |
+
std::to_string(parameters_.size()), param, requires_grad);
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
/// push the a given parameter at the end of the list
|
| 64 |
+
/// And the key of the pair will be discarded, only the value
|
| 65 |
+
/// will be added into the `ParameterList`
|
| 66 |
+
void append(const OrderedDict<std::string, torch::Tensor>::Item& pair) {
|
| 67 |
+
register_parameter(
|
| 68 |
+
std::to_string(parameters_.size()),
|
| 69 |
+
pair.value(),
|
| 70 |
+
pair.value().requires_grad());
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
/// extend parameters from a container to the end of the list
|
| 74 |
+
template <typename Container>
|
| 75 |
+
void extend(const Container& container) {
|
| 76 |
+
for (const auto& param : container) {
|
| 77 |
+
append(param);
|
| 78 |
+
}
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
/// Returns an iterator to the start of the ParameterList
|
| 82 |
+
/// the iterator returned will be type of `OrderedDict<std::string,
|
| 83 |
+
/// torch::Tensor>::Item`
|
| 84 |
+
Iterator begin() {
|
| 85 |
+
return parameters_.begin();
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
/// Returns a const iterator to the start of the ParameterList
|
| 89 |
+
/// the iterator returned will be type of `OrderedDict<std::string,
|
| 90 |
+
/// torch::Tensor>::Item`
|
| 91 |
+
ConstIterator begin() const {
|
| 92 |
+
return parameters_.begin();
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
/// Returns an iterator to the end of the ParameterList
|
| 96 |
+
/// the iterator returned will be type of `OrderedDict<std::string,
|
| 97 |
+
/// torch::Tensor>::Item`
|
| 98 |
+
Iterator end() {
|
| 99 |
+
return parameters_.end();
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
/// Returns a const iterator to the end of the ParameterList
|
| 103 |
+
/// the iterator returned will be type of `OrderedDict<std::string,
|
| 104 |
+
/// torch::Tensor>::Item`
|
| 105 |
+
ConstIterator end() const {
|
| 106 |
+
return parameters_.end();
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
/// Returns the value associated with the given `key`. Throws an exception if
|
| 110 |
+
/// no such key is stored in the `ParameterList`. Check contains(key) before
|
| 111 |
+
/// for a non-throwing way of access
|
| 112 |
+
at::Tensor& at(size_t idx) {
|
| 113 |
+
TORCH_CHECK(idx < size(), "Index out of range");
|
| 114 |
+
return parameters_[std::to_string(idx)];
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
/// Returns the value associated with the given `key`. Throws an exception if
|
| 118 |
+
/// no such key is stored in the `ParameterList`. Check contains(key) before
|
| 119 |
+
/// for a non-throwing way of access
|
| 120 |
+
const at::Tensor& at(size_t idx) const {
|
| 121 |
+
TORCH_CHECK(idx < size(), "Index out of range");
|
| 122 |
+
return parameters_[std::to_string(idx)];
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
/// Returns the value associated with the given `key`. Throws an exception if
|
| 126 |
+
/// no such key is stored in the `ParameterList`. Check contains(key) before
|
| 127 |
+
/// for a non-throwing way of access
|
| 128 |
+
at::Tensor& operator[](size_t idx) {
|
| 129 |
+
return at(idx);
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
/// Returns the value associated with the given `key`. Throws an exception if
|
| 133 |
+
/// no such key is stored in the `ParameterList`. Check contains(key) before
|
| 134 |
+
/// for a non-throwing way of access
|
| 135 |
+
const at::Tensor& operator[](size_t idx) const {
|
| 136 |
+
return at(idx);
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
/// Return the size of the ParameterList
|
| 140 |
+
size_t size() const noexcept {
|
| 141 |
+
return parameters_.size();
|
| 142 |
+
}
|
| 143 |
+
/// True if the ParameterList is empty
|
| 144 |
+
bool is_empty() const noexcept {
|
| 145 |
+
return parameters_.is_empty();
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
/// Overload the +=, so that two ParameterList could be incrementally added
|
| 149 |
+
template <typename Container>
|
| 150 |
+
Container& operator+=(const Container& other) {
|
| 151 |
+
extend(other);
|
| 152 |
+
return *this;
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
private:
|
| 156 |
+
template <typename Head, typename... Tail>
|
| 157 |
+
void push_back_var(Head&& head, Tail&&... tail) {
|
| 158 |
+
append(std::forward<Head>(head));
|
| 159 |
+
// Recursively calls this method, until the parameter pack only thas this
|
| 160 |
+
// entry left. Then calls `push_back()` a final time (above).
|
| 161 |
+
push_back_var(std::forward<Tail>(tail)...);
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
/// The base case, when the list of modules is empty.
|
| 165 |
+
void push_back_var() {}
|
| 166 |
+
};
|
| 167 |
+
TORCH_MODULE(ParameterList);
|
| 168 |
+
} // namespace nn
|
| 169 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/sequential.h
ADDED
|
@@ -0,0 +1,388 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/detail/static.h>
|
| 4 |
+
#include <torch/nn/cloneable.h>
|
| 5 |
+
#include <torch/nn/module.h>
|
| 6 |
+
#include <torch/nn/modules/container/any.h>
|
| 7 |
+
#include <torch/nn/modules/container/named_any.h>
|
| 8 |
+
#include <torch/nn/pimpl.h>
|
| 9 |
+
#include <torch/types.h>
|
| 10 |
+
|
| 11 |
+
#include <c10/util/Exception.h>
|
| 12 |
+
|
| 13 |
+
#include <cstdint>
|
| 14 |
+
#include <memory>
|
| 15 |
+
#include <ostream>
|
| 16 |
+
#include <string>
|
| 17 |
+
#include <type_traits>
|
| 18 |
+
#include <utility>
|
| 19 |
+
#include <vector>
|
| 20 |
+
|
| 21 |
+
namespace torch {
|
| 22 |
+
namespace nn {
|
| 23 |
+
|
| 24 |
+
/// A list of `Module`s that acts as a `Module` itself.
|
| 25 |
+
///
|
| 26 |
+
/// A `Sequential` is fundamentally a list of `Module`s, each with a `forward()`
|
| 27 |
+
/// method. `Sequential` provides a `forward()` method of its own, which accepts
|
| 28 |
+
/// any input and forwards it to the first module it stores. It then "chains"
|
| 29 |
+
/// outputs to inputs sequentially for each subsequent module, finally returning
|
| 30 |
+
/// the output of the last module. For example:
|
| 31 |
+
///
|
| 32 |
+
/// \rst
|
| 33 |
+
/// .. code-block:: cpp
|
| 34 |
+
///
|
| 35 |
+
/// torch::nn::Sequential seq(
|
| 36 |
+
/// torch::nn::Linear(3, 4),
|
| 37 |
+
/// torch::nn::BatchNorm1d(4),
|
| 38 |
+
/// torch::nn::Dropout(0.5)
|
| 39 |
+
/// );
|
| 40 |
+
///
|
| 41 |
+
/// auto output = seq->forward(torch::ones(3));
|
| 42 |
+
///
|
| 43 |
+
/// \endrst
|
| 44 |
+
///
|
| 45 |
+
/// This can conceptually be thought of as the following loop (using Python as
|
| 46 |
+
/// pseudocode):
|
| 47 |
+
///
|
| 48 |
+
/// \rst
|
| 49 |
+
/// .. code-block:: python
|
| 50 |
+
///
|
| 51 |
+
/// def forward(sequential, input):
|
| 52 |
+
/// for module in sequential:
|
| 53 |
+
/// input = module(input)
|
| 54 |
+
/// return input
|
| 55 |
+
///
|
| 56 |
+
/// \endrst
|
| 57 |
+
///
|
| 58 |
+
/// Why should you use `Sequential` instead of a simple `std::vector`? The value
|
| 59 |
+
/// a `Sequential` provides over manually calling a sequence of modules is that
|
| 60 |
+
/// it allows treating the whole container *as a single module*, such that
|
| 61 |
+
/// performing a transformation on the `Sequential` applies to each of the
|
| 62 |
+
/// modules it stores (which are each a registered submodule of the
|
| 63 |
+
/// `Sequential`). For example, calling
|
| 64 |
+
/// `.to(torch::kCUDA)` on a `Sequential` will move each module in the list to
|
| 65 |
+
/// CUDA memory. For example:
|
| 66 |
+
///
|
| 67 |
+
/// \rst
|
| 68 |
+
/// .. code-block:: cpp
|
| 69 |
+
///
|
| 70 |
+
/// torch::nn::Sequential seq(
|
| 71 |
+
/// torch::nn::Linear(3, 4),
|
| 72 |
+
/// torch::nn::BatchNorm1d(4),
|
| 73 |
+
/// torch::nn::Dropout(0.5)
|
| 74 |
+
/// );
|
| 75 |
+
///
|
| 76 |
+
/// // Convert all modules to CUDA.
|
| 77 |
+
/// seq->to(torch::kCUDA);
|
| 78 |
+
///
|
| 79 |
+
/// \endrst
|
| 80 |
+
///
|
| 81 |
+
/// Finally, `Sequential` provides a lightweight container API, such as allowing
|
| 82 |
+
/// iteration over submodules, positional access, adding a new module after
|
| 83 |
+
/// construction via `push_back`, as well as joining two `Sequential`s via
|
| 84 |
+
/// `extend`.
|
| 85 |
+
///
|
| 86 |
+
/// \rst
|
| 87 |
+
/// .. attention::
|
| 88 |
+
/// One current limitation of `Sequential` is that all except the first module
|
| 89 |
+
/// must accept a single argument. If your modules need to take multiple
|
| 90 |
+
/// arguments, you should define them to take and return tuples.
|
| 91 |
+
/// \endrst
|
| 92 |
+
class SequentialImpl : public Cloneable<SequentialImpl> {
|
| 93 |
+
public:
|
| 94 |
+
using Iterator = std::vector<AnyModule>::iterator;
|
| 95 |
+
using ConstIterator = std::vector<AnyModule>::const_iterator;
|
| 96 |
+
|
| 97 |
+
SequentialImpl() = default;
|
| 98 |
+
|
| 99 |
+
/// Constructs the `Sequential` from a variadic list of modules.
|
| 100 |
+
template <typename... Modules>
|
| 101 |
+
explicit SequentialImpl(Modules&&... modules) {
|
| 102 |
+
modules_.reserve(sizeof...(Modules));
|
| 103 |
+
push_back(std::forward<Modules>(modules)...);
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
/// Constructs the `Sequential` from an `OrderedDict` of named `AnyModule`s.
|
| 107 |
+
explicit SequentialImpl(
|
| 108 |
+
torch::OrderedDict<std::string, AnyModule>&& ordered_dict) {
|
| 109 |
+
modules_.reserve(ordered_dict.size());
|
| 110 |
+
for (auto& item : ordered_dict) {
|
| 111 |
+
push_back(item.key(), std::move(item.value()));
|
| 112 |
+
}
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
/// Constructs the `Sequential` from a braced-init-list of named `AnyModule`s.
|
| 116 |
+
/// It enables the following use case:
|
| 117 |
+
/// `Sequential sequential({{"m1", M(1)}, {"m2", M(2)}})`
|
| 118 |
+
explicit SequentialImpl(std::initializer_list<NamedAnyModule> named_modules) {
|
| 119 |
+
modules_.reserve(named_modules.size());
|
| 120 |
+
for (const auto& named_module : named_modules) {
|
| 121 |
+
push_back(named_module.name(), named_module.module());
|
| 122 |
+
}
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
/// Special cloning function for `Sequential` because it does not use
|
| 126 |
+
/// `reset()`.
|
| 127 |
+
std::shared_ptr<Module> clone(
|
| 128 |
+
const std::optional<Device>& device = std::nullopt) const override {
|
| 129 |
+
auto clone = std::make_shared<SequentialImpl>();
|
| 130 |
+
for (const auto& module : modules_) {
|
| 131 |
+
clone->push_back(module.clone(device));
|
| 132 |
+
}
|
| 133 |
+
return clone;
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
/// `reset()` is empty for `Sequential`, since it does not have parameters of
|
| 137 |
+
/// its own.
|
| 138 |
+
void reset() override {}
|
| 139 |
+
|
| 140 |
+
/// Pretty prints the `Sequential` module into the given `stream`.
|
| 141 |
+
void pretty_print(std::ostream& stream) const override {
|
| 142 |
+
stream << "torch::nn::Sequential";
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
/// Feeds `inputs` to the first module and then chains outputs to inputs,
|
| 146 |
+
/// returning the last output.
|
| 147 |
+
///
|
| 148 |
+
/// Conceptually the following loop in Python:
|
| 149 |
+
///
|
| 150 |
+
/// \rst
|
| 151 |
+
/// .. code-block:: python
|
| 152 |
+
///
|
| 153 |
+
/// def forward(sequential, input):
|
| 154 |
+
/// for module in sequential:
|
| 155 |
+
/// input = module(input)
|
| 156 |
+
/// return input
|
| 157 |
+
///
|
| 158 |
+
/// \endrst
|
| 159 |
+
///
|
| 160 |
+
/// The return type is taken as the first template parameter. It defaults to
|
| 161 |
+
/// `Tensor`. If the last module in the `Sequential` returns another type `T`,
|
| 162 |
+
/// you should call `forward<T>(inputs)` instead of just `forward(inputs)`:
|
| 163 |
+
///
|
| 164 |
+
/// \rst
|
| 165 |
+
/// .. code-block:: cpp
|
| 166 |
+
///
|
| 167 |
+
/// torch::Tensor tensor = sequential1->forward(inputs);
|
| 168 |
+
/// int integer = sequential2->forward<int>(inputs);
|
| 169 |
+
/// float value = sequential3->forward<float>(inputs);
|
| 170 |
+
///
|
| 171 |
+
/// \endrst
|
| 172 |
+
template <typename ReturnType = Tensor, typename... InputTypes>
|
| 173 |
+
ReturnType forward(InputTypes&&... inputs) {
|
| 174 |
+
TORCH_CHECK(!is_empty(), "Cannot call forward() on an empty Sequential");
|
| 175 |
+
|
| 176 |
+
auto iterator = modules_.begin();
|
| 177 |
+
auto input = iterator->any_forward(std::forward<InputTypes>(inputs)...);
|
| 178 |
+
|
| 179 |
+
for (++iterator; iterator != modules_.end(); ++iterator) {
|
| 180 |
+
input = iterator->any_forward(std::move(input));
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
// Check the return value and give a nice error message if the requested
|
| 184 |
+
// return type was incorrect.
|
| 185 |
+
if (auto* return_value = input.template try_get<ReturnType>()) {
|
| 186 |
+
return std::move(*return_value);
|
| 187 |
+
}
|
| 188 |
+
AT_ERROR(
|
| 189 |
+
"The type of the return value is ",
|
| 190 |
+
c10::demangle(input.type_info().name()),
|
| 191 |
+
", but you asked for type ",
|
| 192 |
+
c10::demangle(typeid(ReturnType).name()));
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
/// Adds a new (boxed) `Module` to the `Sequential` container.
|
| 196 |
+
template <typename ModuleType>
|
| 197 |
+
void push_back(std::shared_ptr<ModuleType> module_ptr) {
|
| 198 |
+
push_back(std::to_string(modules_.size()), std::move(module_ptr));
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
/// Adds a new named (boxed) `Module` to the `Sequential` container.
|
| 202 |
+
template <typename ModuleType>
|
| 203 |
+
void push_back(std::string name, std::shared_ptr<ModuleType> module_ptr) {
|
| 204 |
+
push_back(std::move(name), AnyModule(std::move(module_ptr)));
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
/// Adds a new `Module` to the `Sequential` container, moving or copying it
|
| 208 |
+
/// into a `shared_ptr` internally. This method allows passing value types,
|
| 209 |
+
/// and letting the container deal with the boxing. This means you can write
|
| 210 |
+
/// `Sequential(Module(3, 4))` instead of
|
| 211 |
+
/// `Sequential(std::make_shared<Module>(3, 4))`.
|
| 212 |
+
template <typename M, typename = torch::detail::enable_if_module_t<M>>
|
| 213 |
+
void push_back(M&& module) {
|
| 214 |
+
push_back(std::to_string(modules_.size()), std::forward<M>(module));
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
/// Adds a new named `Module` to the `Sequential` container, moving or copying
|
| 218 |
+
/// it into a `shared_ptr` internally. This method allows passing value types,
|
| 219 |
+
/// and letting the container deal with the boxing.
|
| 220 |
+
template <typename M, typename = torch::detail::enable_if_module_t<M>>
|
| 221 |
+
void push_back(std::string name, M&& module) {
|
| 222 |
+
using Type = typename std::remove_reference_t<M>;
|
| 223 |
+
push_back(std::move(name), std::make_shared<Type>(std::forward<M>(module)));
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
/// Unwraps the contained module of a `ModuleHolder` and adds it to the
|
| 227 |
+
/// `Sequential`.
|
| 228 |
+
template <typename M>
|
| 229 |
+
void push_back(const ModuleHolder<M>& module_holder) {
|
| 230 |
+
push_back(std::to_string(modules_.size()), module_holder);
|
| 231 |
+
}
|
| 232 |
+
|
| 233 |
+
/// Unwraps the contained named module of a `ModuleHolder` and adds it to the
|
| 234 |
+
/// `Sequential`.
|
| 235 |
+
template <typename M>
|
| 236 |
+
void push_back(std::string name, const ModuleHolder<M>& module_holder) {
|
| 237 |
+
push_back(std::move(name), module_holder.ptr());
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
/// Iterates over the container and calls `push_back()` on each value.
|
| 241 |
+
template <typename Container>
|
| 242 |
+
void extend(const Container& container) {
|
| 243 |
+
for (const auto& module : container) {
|
| 244 |
+
push_back(module);
|
| 245 |
+
}
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
/// Adds a type-erased `AnyModule` to the `Sequential`.
|
| 249 |
+
void push_back(AnyModule any_module) {
|
| 250 |
+
push_back(std::to_string(modules_.size()), std::move(any_module));
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
void push_back(std::string name, AnyModule any_module) {
|
| 254 |
+
modules_.push_back(std::move(any_module));
|
| 255 |
+
const auto index = modules_.size() - 1;
|
| 256 |
+
register_module(std::move(name), modules_[index].ptr());
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
/// Returns an iterator to the start of the `Sequential`.
|
| 260 |
+
Iterator begin() {
|
| 261 |
+
return modules_.begin();
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
/// Returns a const iterator to the start of the `Sequential`.
|
| 265 |
+
ConstIterator begin() const {
|
| 266 |
+
return modules_.begin();
|
| 267 |
+
}
|
| 268 |
+
|
| 269 |
+
/// Returns an iterator to the end of the `Sequential`.
|
| 270 |
+
Iterator end() {
|
| 271 |
+
return modules_.end();
|
| 272 |
+
}
|
| 273 |
+
|
| 274 |
+
/// Returns a const iterator to the end of the `Sequential`.
|
| 275 |
+
ConstIterator end() const {
|
| 276 |
+
return modules_.end();
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
/// Attempts to return the module at the given index as the requested type.
|
| 280 |
+
/// Throws an exception if the index is out of bounds or the types do not
|
| 281 |
+
/// match.
|
| 282 |
+
template <typename T>
|
| 283 |
+
T& at(size_t index) {
|
| 284 |
+
static_assert(
|
| 285 |
+
torch::detail::is_module<T>::value,
|
| 286 |
+
"Can only call Sequential::at with an nn::Module type");
|
| 287 |
+
TORCH_CHECK(index < size(), "Index out of range");
|
| 288 |
+
return modules_[index].get<T>();
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
+
/// Attempts to return the module at the given index as the requested type.
|
| 292 |
+
/// Throws an exception if the index is out of bounds or the types do not
|
| 293 |
+
/// match.
|
| 294 |
+
template <typename T>
|
| 295 |
+
const T& at(size_t index) const {
|
| 296 |
+
static_assert(
|
| 297 |
+
torch::detail::is_module<T>::value,
|
| 298 |
+
"Can only call Sequential::at with an nn::Module type");
|
| 299 |
+
TORCH_CHECK(index < size(), "Index out of range");
|
| 300 |
+
return modules_[index].get<T>();
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
/// Attempts to return a `std::shared_ptr` whose dynamic type is that of the
|
| 304 |
+
/// underlying module at the given index. Throws an exception if the index is
|
| 305 |
+
/// out of bounds.
|
| 306 |
+
std::shared_ptr<Module> ptr(size_t index) const {
|
| 307 |
+
TORCH_CHECK(index < size(), "Index out of range");
|
| 308 |
+
return modules_[index].ptr();
|
| 309 |
+
}
|
| 310 |
+
|
| 311 |
+
/// Attempts to return a `std::shared_ptr` whose type is the one provided.
|
| 312 |
+
/// Throws an exception if the index is out of bounds or the types do not
|
| 313 |
+
/// match.
|
| 314 |
+
template <typename T>
|
| 315 |
+
std::shared_ptr<T> ptr(size_t index) const {
|
| 316 |
+
static_assert(
|
| 317 |
+
torch::detail::is_module<T>::value,
|
| 318 |
+
"Can only call Sequential::ptr with an nn::Module type");
|
| 319 |
+
TORCH_CHECK(index < size(), "Index out of range");
|
| 320 |
+
return modules_[index].ptr<T>();
|
| 321 |
+
}
|
| 322 |
+
|
| 323 |
+
/// Like `ptr(index)`.
|
| 324 |
+
std::shared_ptr<Module> operator[](size_t index) const {
|
| 325 |
+
// This is the only method we can call without a type.
|
| 326 |
+
return ptr(index);
|
| 327 |
+
}
|
| 328 |
+
|
| 329 |
+
/// The current size of the `Sequential` container.
|
| 330 |
+
size_t size() const noexcept {
|
| 331 |
+
return modules_.size();
|
| 332 |
+
}
|
| 333 |
+
|
| 334 |
+
/// True if there are no modules in the `Sequential`.
|
| 335 |
+
bool is_empty() const noexcept {
|
| 336 |
+
return size() == 0;
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
private:
|
| 340 |
+
/// Takes a First *and* Second parameter, to avoid ambiguity when a parameter
|
| 341 |
+
/// pack has only one type, in which case the template would be preferred,
|
| 342 |
+
/// even if the other `push_back` functions are better fits (e.g. `unique_ptr`
|
| 343 |
+
/// -> `shared_ptr` overload).
|
| 344 |
+
/// NOTE: We explicitly avoid matching this template with
|
| 345 |
+
/// `push_back(std::string("name"), module)` or `push_back("name", module)`,
|
| 346 |
+
/// since they should be handled by their respective `push_back` functions.
|
| 347 |
+
template <
|
| 348 |
+
typename First,
|
| 349 |
+
typename Second,
|
| 350 |
+
typename... Rest,
|
| 351 |
+
typename = std::enable_if_t<
|
| 352 |
+
!std::is_same_v<First, std::string> &&
|
| 353 |
+
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
|
| 354 |
+
!std::is_same_v<std::decay_t<First>, std::decay_t<const char (&)[]>>>>
|
| 355 |
+
void push_back(First&& first, Second&& second, Rest&&... rest) {
|
| 356 |
+
push_back(std::forward<First>(first));
|
| 357 |
+
// Recursively calls this method, until the parameter pack only thas this
|
| 358 |
+
// entry left. Then calls `push_back()` a final time (above).
|
| 359 |
+
push_back(std::forward<Second>(second), std::forward<Rest>(rest)...);
|
| 360 |
+
}
|
| 361 |
+
|
| 362 |
+
/// The base case, when the list of modules is empty.
|
| 363 |
+
void push_back() {}
|
| 364 |
+
|
| 365 |
+
// Box the AnyModules to give Sequential reference semantics, like the rest of
|
| 366 |
+
// the API. Note that this is not required otherwise, this could just be a
|
| 367 |
+
// `vector<AnyModule>`.
|
| 368 |
+
std::vector<AnyModule> modules_;
|
| 369 |
+
};
|
| 370 |
+
|
| 371 |
+
/// A `ModuleHolder` subclass for `SequentialImpl`.
|
| 372 |
+
/// See the documentation for `SequentialImpl` class to learn what methods it
|
| 373 |
+
/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
|
| 374 |
+
/// module storage semantics.
|
| 375 |
+
class Sequential : public torch::nn::ModuleHolder<SequentialImpl> {
|
| 376 |
+
public:
|
| 377 |
+
using torch::nn::ModuleHolder<SequentialImpl>::ModuleHolder;
|
| 378 |
+
|
| 379 |
+
Sequential() : ModuleHolder() {}
|
| 380 |
+
|
| 381 |
+
/// Constructs the `Sequential` from a braced-init-list of named `AnyModule`s.
|
| 382 |
+
/// It enables the following use case:
|
| 383 |
+
/// `Sequential sequential({{"m1", M(1)}, {"m2", M(2)}})`
|
| 384 |
+
Sequential(std::initializer_list<NamedAnyModule> named_modules)
|
| 385 |
+
: ModuleHolder(std::make_shared<SequentialImpl>(named_modules)) {}
|
| 386 |
+
};
|
| 387 |
+
} // namespace nn
|
| 388 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/dropout.h
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/nn/cloneable.h>
|
| 4 |
+
#include <torch/nn/options/dropout.h>
|
| 5 |
+
#include <torch/nn/pimpl.h>
|
| 6 |
+
#include <torch/types.h>
|
| 7 |
+
|
| 8 |
+
#include <torch/csrc/Export.h>
|
| 9 |
+
|
| 10 |
+
#include <cstddef>
|
| 11 |
+
#include <vector>
|
| 12 |
+
|
| 13 |
+
namespace torch {
|
| 14 |
+
namespace nn {
|
| 15 |
+
|
| 16 |
+
namespace detail {
|
| 17 |
+
|
| 18 |
+
template <typename Derived>
|
| 19 |
+
class _DropoutNd : public torch::nn::Cloneable<Derived> {
|
| 20 |
+
public:
|
| 21 |
+
_DropoutNd(double p) : _DropoutNd(DropoutOptions().p(p)){};
|
| 22 |
+
|
| 23 |
+
explicit _DropoutNd(const DropoutOptions& options_ = {}) : options(options_) {
|
| 24 |
+
// NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
|
| 25 |
+
reset();
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
void reset() override {
|
| 29 |
+
TORCH_CHECK(
|
| 30 |
+
options.p() >= 0. && options.p() <= 1.,
|
| 31 |
+
"dropout probability has to be between 0 and 1, but got ",
|
| 32 |
+
options.p());
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
/// The options with which this `Module` was constructed.
|
| 36 |
+
DropoutOptions options;
|
| 37 |
+
};
|
| 38 |
+
|
| 39 |
+
} // namespace detail
|
| 40 |
+
|
| 41 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Dropout ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 42 |
+
|
| 43 |
+
/// Applies dropout over a 1-D input.
|
| 44 |
+
/// See https://pytorch.org/docs/main/nn.html#torch.nn.Dropout to learn
|
| 45 |
+
/// about the exact behavior of this module.
|
| 46 |
+
///
|
| 47 |
+
/// See the documentation for `torch::nn::DropoutOptions` class to learn what
|
| 48 |
+
/// constructor arguments are supported for this module.
|
| 49 |
+
///
|
| 50 |
+
/// Example:
|
| 51 |
+
/// ```
|
| 52 |
+
/// Dropout model(DropoutOptions().p(0.42).inplace(true));
|
| 53 |
+
/// ```
|
| 54 |
+
class TORCH_API DropoutImpl : public detail::_DropoutNd<DropoutImpl> {
|
| 55 |
+
public:
|
| 56 |
+
using detail::_DropoutNd<DropoutImpl>::_DropoutNd;
|
| 57 |
+
|
| 58 |
+
Tensor forward(Tensor input);
|
| 59 |
+
|
| 60 |
+
/// Pretty prints the `Dropout` module into the given `stream`.
|
| 61 |
+
void pretty_print(std::ostream& stream) const override;
|
| 62 |
+
};
|
| 63 |
+
|
| 64 |
+
/// A `ModuleHolder` subclass for `DropoutImpl`.
|
| 65 |
+
/// See the documentation for `DropoutImpl` class to learn what methods it
|
| 66 |
+
/// provides, and examples of how to use `Dropout` with
|
| 67 |
+
/// `torch::nn::DropoutOptions`. See the documentation for `ModuleHolder` to
|
| 68 |
+
/// learn about PyTorch's module storage semantics.
|
| 69 |
+
TORCH_MODULE(Dropout);
|
| 70 |
+
|
| 71 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Dropout2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 72 |
+
|
| 73 |
+
/// Applies dropout over a 2-D input.
|
| 74 |
+
/// See https://pytorch.org/docs/main/nn.html#torch.nn.Dropout2d to learn
|
| 75 |
+
/// about the exact behavior of this module.
|
| 76 |
+
///
|
| 77 |
+
/// See the documentation for `torch::nn::Dropout2dOptions` class to learn what
|
| 78 |
+
/// constructor arguments are supported for this module.
|
| 79 |
+
///
|
| 80 |
+
/// Example:
|
| 81 |
+
/// ```
|
| 82 |
+
/// Dropout2d model(Dropout2dOptions().p(0.42).inplace(true));
|
| 83 |
+
/// ```
|
| 84 |
+
class TORCH_API Dropout2dImpl : public detail::_DropoutNd<Dropout2dImpl> {
|
| 85 |
+
public:
|
| 86 |
+
using detail::_DropoutNd<Dropout2dImpl>::_DropoutNd;
|
| 87 |
+
|
| 88 |
+
Tensor forward(Tensor input);
|
| 89 |
+
|
| 90 |
+
/// Pretty prints the `Dropout2d` module into the given `stream`.
|
| 91 |
+
void pretty_print(std::ostream& stream) const override;
|
| 92 |
+
};
|
| 93 |
+
|
| 94 |
+
/// A `ModuleHolder` subclass for `Dropout2dImpl`.
|
| 95 |
+
/// See the documentation for `Dropout2dImpl` class to learn what methods it
|
| 96 |
+
/// provides, and examples of how to use `Dropout2d` with
|
| 97 |
+
/// `torch::nn::Dropout2dOptions`. See the documentation for `ModuleHolder` to
|
| 98 |
+
/// learn about PyTorch's module storage semantics.
|
| 99 |
+
TORCH_MODULE(Dropout2d);
|
| 100 |
+
|
| 101 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Dropout3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 102 |
+
|
| 103 |
+
/// Applies dropout over a 3-D input.
|
| 104 |
+
/// See https://pytorch.org/docs/main/nn.html#torch.nn.Dropout3d to learn
|
| 105 |
+
/// about the exact behavior of this module.
|
| 106 |
+
///
|
| 107 |
+
/// See the documentation for `torch::nn::Dropout3dOptions` class to learn what
|
| 108 |
+
/// constructor arguments are supported for this module.
|
| 109 |
+
///
|
| 110 |
+
/// Example:
|
| 111 |
+
/// ```
|
| 112 |
+
/// Dropout3d model(Dropout3dOptions().p(0.42).inplace(true));
|
| 113 |
+
/// ```
|
| 114 |
+
class TORCH_API Dropout3dImpl : public detail::_DropoutNd<Dropout3dImpl> {
|
| 115 |
+
public:
|
| 116 |
+
using detail::_DropoutNd<Dropout3dImpl>::_DropoutNd;
|
| 117 |
+
|
| 118 |
+
Tensor forward(Tensor input);
|
| 119 |
+
|
| 120 |
+
/// Pretty prints the `Dropout3d` module into the given `stream`.
|
| 121 |
+
void pretty_print(std::ostream& stream) const override;
|
| 122 |
+
};
|
| 123 |
+
|
| 124 |
+
/// A `ModuleHolder` subclass for `Dropout3dImpl`.
|
| 125 |
+
/// See the documentation for `Dropout3dImpl` class to learn what methods it
|
| 126 |
+
/// provides, and examples of how to use `Dropout3d` with
|
| 127 |
+
/// `torch::nn::Dropout3dOptions`. See the documentation for `ModuleHolder` to
|
| 128 |
+
/// learn about PyTorch's module storage semantics.
|
| 129 |
+
TORCH_MODULE(Dropout3d);
|
| 130 |
+
|
| 131 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AlphaDropout ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 132 |
+
|
| 133 |
+
/// Applies Alpha Dropout over the input.
|
| 134 |
+
/// See https://pytorch.org/docs/main/nn.html#torch.nn.AlphaDropout to learn
|
| 135 |
+
/// about the exact behavior of this module.
|
| 136 |
+
///
|
| 137 |
+
/// See the documentation for `torch::nn::AlphaDropoutOptions` class to learn
|
| 138 |
+
/// what constructor arguments are supported for this module.
|
| 139 |
+
///
|
| 140 |
+
/// Example:
|
| 141 |
+
/// ```
|
| 142 |
+
/// AlphaDropout model(AlphaDropoutOptions(0.2).inplace(true));
|
| 143 |
+
/// ```
|
| 144 |
+
class TORCH_API AlphaDropoutImpl : public detail::_DropoutNd<AlphaDropoutImpl> {
|
| 145 |
+
public:
|
| 146 |
+
using detail::_DropoutNd<AlphaDropoutImpl>::_DropoutNd;
|
| 147 |
+
|
| 148 |
+
Tensor forward(const Tensor& input);
|
| 149 |
+
|
| 150 |
+
/// Pretty prints the `AlphaDropout` module into the given `stream`.
|
| 151 |
+
void pretty_print(std::ostream& stream) const override;
|
| 152 |
+
};
|
| 153 |
+
|
| 154 |
+
/// A `ModuleHolder` subclass for `AlphaDropoutImpl`.
|
| 155 |
+
/// See the documentation for `AlphaDropoutImpl` class to learn what methods it
|
| 156 |
+
/// provides, and examples of how to use `AlphaDropout` with
|
| 157 |
+
/// `torch::nn::AlphaDropoutOptions`. See the documentation for `ModuleHolder`
|
| 158 |
+
/// to learn about PyTorch's module storage semantics.
|
| 159 |
+
TORCH_MODULE(AlphaDropout);
|
| 160 |
+
|
| 161 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FeatureAlphaDropout
|
| 162 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 163 |
+
|
| 164 |
+
/// See the documentation for `torch::nn::FeatureAlphaDropoutOptions` class to
|
| 165 |
+
/// learn what constructor arguments are supported for this module.
|
| 166 |
+
///
|
| 167 |
+
/// Example:
|
| 168 |
+
/// ```
|
| 169 |
+
/// FeatureAlphaDropout model(FeatureAlphaDropoutOptions(0.2).inplace(true));
|
| 170 |
+
/// ```
|
| 171 |
+
class TORCH_API FeatureAlphaDropoutImpl
|
| 172 |
+
: public detail::_DropoutNd<FeatureAlphaDropoutImpl> {
|
| 173 |
+
public:
|
| 174 |
+
using detail::_DropoutNd<FeatureAlphaDropoutImpl>::_DropoutNd;
|
| 175 |
+
|
| 176 |
+
Tensor forward(const Tensor& input);
|
| 177 |
+
|
| 178 |
+
/// Pretty prints the `FeatureAlphaDropout` module into the given `stream`.
|
| 179 |
+
void pretty_print(std::ostream& stream) const override;
|
| 180 |
+
};
|
| 181 |
+
|
| 182 |
+
/// A `ModuleHolder` subclass for `FeatureAlphaDropoutImpl`.
|
| 183 |
+
/// See the documentation for `FeatureAlphaDropoutImpl` class to learn what
|
| 184 |
+
/// methods it provides, and examples of how to use `FeatureAlphaDropout` with
|
| 185 |
+
/// `torch::nn::FeatureAlphaDropoutOptions`. See the documentation for
|
| 186 |
+
/// `ModuleHolder` to learn about PyTorch's module storage semantics.
|
| 187 |
+
TORCH_MODULE(FeatureAlphaDropout);
|
| 188 |
+
|
| 189 |
+
} // namespace nn
|
| 190 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/fold.h
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/expanding_array.h>
|
| 4 |
+
#include <torch/nn/cloneable.h>
|
| 5 |
+
#include <torch/nn/functional/fold.h>
|
| 6 |
+
#include <torch/nn/options/fold.h>
|
| 7 |
+
#include <torch/nn/pimpl.h>
|
| 8 |
+
#include <torch/types.h>
|
| 9 |
+
|
| 10 |
+
namespace torch {
|
| 11 |
+
namespace nn {
|
| 12 |
+
|
| 13 |
+
/// Applies fold over a 3-D input.
|
| 14 |
+
/// See https://pytorch.org/docs/main/nn.html#torch.nn.Fold to learn about
|
| 15 |
+
/// the exact behavior of this module.
|
| 16 |
+
///
|
| 17 |
+
/// See the documentation for `torch::nn::FoldOptions` class to learn what
|
| 18 |
+
/// constructor arguments are supported for this module.
|
| 19 |
+
///
|
| 20 |
+
/// Example:
|
| 21 |
+
/// ```
|
| 22 |
+
/// Fold model(FoldOptions({8, 8}, {3, 3}).dilation(2).padding({2,
|
| 23 |
+
/// 1}).stride(2));
|
| 24 |
+
/// ```
|
| 25 |
+
class TORCH_API FoldImpl : public torch::nn::Cloneable<FoldImpl> {
|
| 26 |
+
public:
|
| 27 |
+
FoldImpl(ExpandingArray<2> output_size, ExpandingArray<2> kernel_size)
|
| 28 |
+
: FoldImpl(FoldOptions(output_size, kernel_size)) {}
|
| 29 |
+
explicit FoldImpl(const FoldOptions& options_);
|
| 30 |
+
|
| 31 |
+
void reset() override;
|
| 32 |
+
|
| 33 |
+
/// Pretty prints the `Fold` module into the given `stream`.
|
| 34 |
+
void pretty_print(std::ostream& stream) const override;
|
| 35 |
+
|
| 36 |
+
Tensor forward(const Tensor& input);
|
| 37 |
+
|
| 38 |
+
/// The options with which this `Module` was constructed.
|
| 39 |
+
FoldOptions options;
|
| 40 |
+
};
|
| 41 |
+
|
| 42 |
+
/// A `ModuleHolder` subclass for `FoldImpl`.
|
| 43 |
+
/// See the documentation for `FoldImpl` class to learn what methods it
|
| 44 |
+
/// provides, and examples of how to use `Fold` with `torch::nn::FoldOptions`.
|
| 45 |
+
/// See the documentation for `ModuleHolder` to learn about PyTorch's
|
| 46 |
+
/// module storage semantics.
|
| 47 |
+
TORCH_MODULE(Fold);
|
| 48 |
+
|
| 49 |
+
// ============================================================================
|
| 50 |
+
|
| 51 |
+
/// Applies unfold over a 4-D input.
|
| 52 |
+
/// See https://pytorch.org/docs/main/nn.html#torch.nn.Unfold to learn about
|
| 53 |
+
/// the exact behavior of this module.
|
| 54 |
+
///
|
| 55 |
+
/// See the documentation for `torch::nn::UnfoldOptions` class to learn what
|
| 56 |
+
/// constructor arguments are supported for this module.
|
| 57 |
+
///
|
| 58 |
+
/// Example:
|
| 59 |
+
/// ```
|
| 60 |
+
/// Unfold model(UnfoldOptions({2, 4}).dilation(2).padding({2, 1}).stride(2));
|
| 61 |
+
/// ```
|
| 62 |
+
class TORCH_API UnfoldImpl : public Cloneable<UnfoldImpl> {
|
| 63 |
+
public:
|
| 64 |
+
UnfoldImpl(ExpandingArray<2> kernel_size)
|
| 65 |
+
: UnfoldImpl(UnfoldOptions(kernel_size)) {}
|
| 66 |
+
explicit UnfoldImpl(const UnfoldOptions& options_);
|
| 67 |
+
|
| 68 |
+
void reset() override;
|
| 69 |
+
|
| 70 |
+
/// Pretty prints the `Unfold` module into the given `stream`.
|
| 71 |
+
void pretty_print(std::ostream& stream) const override;
|
| 72 |
+
|
| 73 |
+
Tensor forward(const Tensor& input);
|
| 74 |
+
|
| 75 |
+
/// The options with which this `Module` was constructed.
|
| 76 |
+
UnfoldOptions options;
|
| 77 |
+
};
|
| 78 |
+
|
| 79 |
+
/// A `ModuleHolder` subclass for `UnfoldImpl`.
|
| 80 |
+
/// See the documentation for `UnfoldImpl` class to learn what methods it
|
| 81 |
+
/// provides, and examples of how to use `Unfold` with
|
| 82 |
+
/// `torch::nn::UnfoldOptions`. See the documentation for `ModuleHolder` to
|
| 83 |
+
/// learn about PyTorch's module storage semantics.
|
| 84 |
+
TORCH_MODULE(Unfold);
|
| 85 |
+
|
| 86 |
+
} // namespace nn
|
| 87 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/instancenorm.h
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/nn/modules/batchnorm.h>
|
| 4 |
+
#include <torch/nn/options/instancenorm.h>
|
| 5 |
+
|
| 6 |
+
namespace torch {
|
| 7 |
+
namespace nn {
|
| 8 |
+
|
| 9 |
+
/// Base class for all (dimension-specialized) instance norm modules
|
| 10 |
+
template <size_t D, typename Derived>
|
| 11 |
+
class InstanceNormImpl
|
| 12 |
+
: public torch::nn::NormImplBase<D, Derived, InstanceNormOptions> {
|
| 13 |
+
private:
|
| 14 |
+
inline Tensor apply_instance_norm(const Tensor& input) {
|
| 15 |
+
return torch::nn::functional::detail::instance_norm(
|
| 16 |
+
input,
|
| 17 |
+
this->running_mean,
|
| 18 |
+
this->running_var,
|
| 19 |
+
this->weight,
|
| 20 |
+
this->bias,
|
| 21 |
+
this->is_training() || !this->options.track_running_stats(),
|
| 22 |
+
this->options.momentum(),
|
| 23 |
+
this->options.eps());
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
inline Tensor handle_no_batch_input(const Tensor& input) {
|
| 27 |
+
return this->apply_instance_norm(input.unsqueeze(0)).squeeze(0);
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
public:
|
| 31 |
+
using torch::nn::NormImplBase<D, Derived, InstanceNormOptions>::NormImplBase;
|
| 32 |
+
|
| 33 |
+
Tensor forward(const Tensor& input) {
|
| 34 |
+
this->_check_input_dim(input);
|
| 35 |
+
|
| 36 |
+
// For InstanceNorm1D, 2D is unbatched and 3D is batched
|
| 37 |
+
// For InstanceNorm2D, 3D is unbatched and 4D is batched
|
| 38 |
+
// For InstanceNorm3D, 4D is unbatched and 5D is batched
|
| 39 |
+
// check if input does not have a batch-dim
|
| 40 |
+
if (input.dim() == D + 1) {
|
| 41 |
+
return this->handle_no_batch_input(input);
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
return this->apply_instance_norm(input);
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
/// Pretty prints the `InstanceNorm{1,2,3}d` module into the given `stream`.
|
| 48 |
+
void pretty_print(std::ostream& stream) const override {
|
| 49 |
+
stream << std::boolalpha << "torch::nn::InstanceNorm" << D << "d("
|
| 50 |
+
<< this->options.num_features() << ", "
|
| 51 |
+
<< "eps=" << this->options.eps() << ", "
|
| 52 |
+
<< "momentum=" << this->options.momentum() << ", "
|
| 53 |
+
<< "affine=" << this->options.affine() << ", "
|
| 54 |
+
<< "track_running_stats=" << this->options.track_running_stats()
|
| 55 |
+
<< ")";
|
| 56 |
+
}
|
| 57 |
+
};
|
| 58 |
+
|
| 59 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ InstanceNorm1d
|
| 60 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 61 |
+
|
| 62 |
+
/// Applies the InstanceNorm1d function.
|
| 63 |
+
/// See https://pytorch.org/docs/main/nn.html#torch.nn.InstanceNorm1d to learn
|
| 64 |
+
/// about the exact behavior of this module.
|
| 65 |
+
///
|
| 66 |
+
/// See the documentation for `torch::nn::InstanceNorm1dOptions` class to learn
|
| 67 |
+
/// what constructor arguments are supported for this module.
|
| 68 |
+
///
|
| 69 |
+
/// Example:
|
| 70 |
+
/// ```
|
| 71 |
+
/// InstanceNorm1d
|
| 72 |
+
/// model(InstanceNorm1dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true));
|
| 73 |
+
/// ```
|
| 74 |
+
class TORCH_API InstanceNorm1dImpl
|
| 75 |
+
: public InstanceNormImpl<1, InstanceNorm1dImpl> {
|
| 76 |
+
protected:
|
| 77 |
+
void _check_input_dim(const Tensor& input) override;
|
| 78 |
+
|
| 79 |
+
public:
|
| 80 |
+
using InstanceNormImpl<1, InstanceNorm1dImpl>::InstanceNormImpl;
|
| 81 |
+
};
|
| 82 |
+
|
| 83 |
+
/// A `ModuleHolder` subclass for `InstanceNorm1dImpl`.
|
| 84 |
+
/// See the documentation for `InstanceNorm1dImpl` class to learn what methods
|
| 85 |
+
/// it provides, and examples of how to use `InstanceNorm1d` with
|
| 86 |
+
/// `torch::nn::InstanceNorm1dOptions`. See the documentation for `ModuleHolder`
|
| 87 |
+
/// to learn about PyTorch's module storage semantics.
|
| 88 |
+
TORCH_MODULE(InstanceNorm1d);
|
| 89 |
+
|
| 90 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ InstanceNorm2d
|
| 91 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 92 |
+
|
| 93 |
+
/// Applies the InstanceNorm2d function.
|
| 94 |
+
/// See https://pytorch.org/docs/main/nn.html#torch.nn.InstanceNorm2d to learn
|
| 95 |
+
/// about the exact behavior of this module.
|
| 96 |
+
///
|
| 97 |
+
/// See the documentation for `torch::nn::InstanceNorm2dOptions` class to learn
|
| 98 |
+
/// what constructor arguments are supported for this module.
|
| 99 |
+
///
|
| 100 |
+
/// Example:
|
| 101 |
+
/// ```
|
| 102 |
+
/// InstanceNorm2d
|
| 103 |
+
/// model(InstanceNorm2dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true));
|
| 104 |
+
/// ```
|
| 105 |
+
class TORCH_API InstanceNorm2dImpl
|
| 106 |
+
: public InstanceNormImpl<2, InstanceNorm2dImpl> {
|
| 107 |
+
protected:
|
| 108 |
+
void _check_input_dim(const Tensor& input) override;
|
| 109 |
+
|
| 110 |
+
public:
|
| 111 |
+
using InstanceNormImpl<2, InstanceNorm2dImpl>::InstanceNormImpl;
|
| 112 |
+
};
|
| 113 |
+
|
| 114 |
+
/// A `ModuleHolder` subclass for `InstanceNorm2dImpl`.
|
| 115 |
+
/// See the documentation for `InstanceNorm2dImpl` class to learn what methods
|
| 116 |
+
/// it provides, and examples of how to use `InstanceNorm2d` with
|
| 117 |
+
/// `torch::nn::InstanceNorm2dOptions`. See the documentation for `ModuleHolder`
|
| 118 |
+
/// to learn about PyTorch's module storage semantics.
|
| 119 |
+
TORCH_MODULE(InstanceNorm2d);
|
| 120 |
+
|
| 121 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ InstanceNorm3d
|
| 122 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 123 |
+
|
| 124 |
+
/// Applies the InstanceNorm3d function.
|
| 125 |
+
/// See https://pytorch.org/docs/main/nn.html#torch.nn.InstanceNorm3d to learn
|
| 126 |
+
/// about the exact behavior of this module.
|
| 127 |
+
///
|
| 128 |
+
/// See the documentation for `torch::nn::InstanceNorm3dOptions` class to learn
|
| 129 |
+
/// what constructor arguments are supported for this module.
|
| 130 |
+
///
|
| 131 |
+
/// Example:
|
| 132 |
+
/// ```
|
| 133 |
+
/// InstanceNorm3d
|
| 134 |
+
/// model(InstanceNorm3dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true));
|
| 135 |
+
/// ```
|
| 136 |
+
class TORCH_API InstanceNorm3dImpl
|
| 137 |
+
: public InstanceNormImpl<3, InstanceNorm3dImpl> {
|
| 138 |
+
protected:
|
| 139 |
+
void _check_input_dim(const Tensor& input) override;
|
| 140 |
+
|
| 141 |
+
public:
|
| 142 |
+
using InstanceNormImpl<3, InstanceNorm3dImpl>::InstanceNormImpl;
|
| 143 |
+
};
|
| 144 |
+
|
| 145 |
+
/// A `ModuleHolder` subclass for `InstanceNorm3dImpl`.
|
| 146 |
+
/// See the documentation for `InstanceNorm3dImpl` class to learn what methods
|
| 147 |
+
/// it provides, and examples of how to use `InstanceNorm3d` with
|
| 148 |
+
/// `torch::nn::InstanceNorm3dOptions`. See the documentation for `ModuleHolder`
|
| 149 |
+
/// to learn about PyTorch's module storage semantics.
|
| 150 |
+
TORCH_MODULE(InstanceNorm3d);
|
| 151 |
+
|
| 152 |
+
} // namespace nn
|
| 153 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/loss.h
ADDED
|
@@ -0,0 +1,805 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/expanding_array.h>
|
| 4 |
+
#include <torch/nn/cloneable.h>
|
| 5 |
+
#include <torch/nn/functional/loss.h>
|
| 6 |
+
#include <torch/nn/options/loss.h>
|
| 7 |
+
#include <torch/nn/pimpl.h>
|
| 8 |
+
#include <torch/types.h>
|
| 9 |
+
|
| 10 |
+
#include <torch/csrc/Export.h>
|
| 11 |
+
|
| 12 |
+
#include <cstddef>
|
| 13 |
+
#include <vector>
|
| 14 |
+
|
| 15 |
+
namespace torch {
|
| 16 |
+
namespace nn {
|
| 17 |
+
|
| 18 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ L1Loss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 19 |
+
|
| 20 |
+
/// Creates a criterion that measures the mean absolute error (MAE) between each
|
| 21 |
+
/// element in the input : math :`x` and target : `y`.
|
| 22 |
+
/// See https://pytorch.org/docs/main/nn.html#torch.nn.L1Loss to learn
|
| 23 |
+
/// about the exact behavior of this module.
|
| 24 |
+
///
|
| 25 |
+
/// See the documentation for `torch::nn::L1LossOptions` class to learn what
|
| 26 |
+
/// constructor arguments are supported for this module.
|
| 27 |
+
///
|
| 28 |
+
/// Example:
|
| 29 |
+
/// ```
|
| 30 |
+
/// L1Loss model(L1LossOptions(torch::kNone));
|
| 31 |
+
/// ```
|
| 32 |
+
struct TORCH_API L1LossImpl : Cloneable<L1LossImpl> {
|
| 33 |
+
explicit L1LossImpl(L1LossOptions options_ = {});
|
| 34 |
+
|
| 35 |
+
void reset() override;
|
| 36 |
+
|
| 37 |
+
/// Pretty prints the `L1Loss` module into the given `stream`.
|
| 38 |
+
void pretty_print(std::ostream& stream) const override;
|
| 39 |
+
|
| 40 |
+
Tensor forward(const Tensor& input, const Tensor& target);
|
| 41 |
+
|
| 42 |
+
/// The options with which this `Module` was constructed.
|
| 43 |
+
L1LossOptions options;
|
| 44 |
+
};
|
| 45 |
+
|
| 46 |
+
/// A `ModuleHolder` subclass for `L1LossImpl`.
|
| 47 |
+
/// See the documentation for `L1LossImpl` class to learn what methods it
|
| 48 |
+
/// provides, and examples of how to use `L1Loss` with
|
| 49 |
+
/// `torch::nn::L1LossOptions`. See the documentation for `ModuleHolder` to
|
| 50 |
+
/// learn about PyTorch's module storage semantics.
|
| 51 |
+
TORCH_MODULE(L1Loss);
|
| 52 |
+
|
| 53 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ KLDivLoss
|
| 54 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 55 |
+
|
| 56 |
+
/// The Kullback-Leibler divergence loss measure
|
| 57 |
+
/// See https://pytorch.org/docs/main/nn.html#torch.nn.KLDivLoss to learn
|
| 58 |
+
/// about the exact behavior of this module.
|
| 59 |
+
///
|
| 60 |
+
/// See the documentation for `torch::nn::KLDivLossOptions` class to learn what
|
| 61 |
+
/// constructor arguments are supported for this module.
|
| 62 |
+
///
|
| 63 |
+
/// Example:
|
| 64 |
+
/// ```
|
| 65 |
+
/// KLDivLoss model(KLDivLossOptions().reduction(torch::kNone));
|
| 66 |
+
/// ```
|
| 67 |
+
struct TORCH_API KLDivLossImpl : Cloneable<KLDivLossImpl> {
|
| 68 |
+
explicit KLDivLossImpl(KLDivLossOptions options_ = {});
|
| 69 |
+
|
| 70 |
+
void reset() override;
|
| 71 |
+
|
| 72 |
+
/// Pretty prints the `KLDivLoss` module into the given `stream`.
|
| 73 |
+
void pretty_print(std::ostream& stream) const override;
|
| 74 |
+
|
| 75 |
+
Tensor forward(const Tensor& input, const Tensor& target);
|
| 76 |
+
|
| 77 |
+
/// The options with which this `Module` was constructed.
|
| 78 |
+
KLDivLossOptions options;
|
| 79 |
+
};
|
| 80 |
+
|
| 81 |
+
/// A `ModuleHolder` subclass for `KLDivLossImpl`.
|
| 82 |
+
/// See the documentation for `KLDivLossImpl` class to learn what methods it
|
| 83 |
+
/// provides, and examples of how to use `KLDivLoss` with
|
| 84 |
+
/// `torch::nn::KLDivLossOptions`. See the documentation for `ModuleHolder` to
|
| 85 |
+
/// learn about PyTorch's module storage semantics.
|
| 86 |
+
TORCH_MODULE(KLDivLoss);
|
| 87 |
+
|
| 88 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MSELoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 89 |
+
|
| 90 |
+
/// Creates a criterion that measures the mean squared error (squared L2 norm)
|
| 91 |
+
/// between each element in the input :math:`x` and target :math:`y`.
|
| 92 |
+
/// See https://pytorch.org/docs/main/nn.html#torch.nn.MSELoss to learn
|
| 93 |
+
/// about the exact behavior of this module.
|
| 94 |
+
///
|
| 95 |
+
/// See the documentation for `torch::nn::MSELossOptions` class to learn what
|
| 96 |
+
/// constructor arguments are supported for this module.
|
| 97 |
+
///
|
| 98 |
+
/// Example:
|
| 99 |
+
/// ```
|
| 100 |
+
/// MSELoss model(MSELossOptions(torch::kNone));
|
| 101 |
+
/// ```
|
| 102 |
+
struct TORCH_API MSELossImpl : Cloneable<MSELossImpl> {
|
| 103 |
+
explicit MSELossImpl(MSELossOptions options_ = {});
|
| 104 |
+
|
| 105 |
+
void reset() override;
|
| 106 |
+
|
| 107 |
+
/// Pretty prints the `MSELoss` module into the given `stream`.
|
| 108 |
+
void pretty_print(std::ostream& stream) const override;
|
| 109 |
+
|
| 110 |
+
Tensor forward(const Tensor& input, const Tensor& target);
|
| 111 |
+
|
| 112 |
+
/// The options with which this `Module` was constructed.
|
| 113 |
+
MSELossOptions options;
|
| 114 |
+
};
|
| 115 |
+
|
| 116 |
+
/// A `ModuleHolder` subclass for `MSELossImpl`.
|
| 117 |
+
/// See the documentation for `MSELossImpl` class to learn what methods it
|
| 118 |
+
/// provides, and examples of how to use `MSELoss` with
|
| 119 |
+
/// `torch::nn::MSELossOptions`. See the documentation for `ModuleHolder` to
|
| 120 |
+
/// learn about PyTorch's module storage semantics.
|
| 121 |
+
TORCH_MODULE(MSELoss);
|
| 122 |
+
|
| 123 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BCELoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 124 |
+
|
| 125 |
+
/// Creates a criterion that measures the Binary Cross Entropy
|
| 126 |
+
/// between the target and the output.
|
| 127 |
+
/// See https://pytorch.org/docs/main/nn.html#torch.nn.BCELoss to learn
|
| 128 |
+
/// about the exact behavior of this module.
|
| 129 |
+
///
|
| 130 |
+
/// See the documentation for `torch::nn::BCELossOptions` class to learn what
|
| 131 |
+
/// constructor arguments are supported for this module.
|
| 132 |
+
///
|
| 133 |
+
/// Example:
|
| 134 |
+
/// ```
|
| 135 |
+
/// BCELoss model(BCELossOptions().reduction(torch::kNone).weight(weight));
|
| 136 |
+
/// ```
|
| 137 |
+
struct TORCH_API BCELossImpl : Cloneable<BCELossImpl> {
|
| 138 |
+
explicit BCELossImpl(BCELossOptions options_ = {});
|
| 139 |
+
|
| 140 |
+
void reset() override;
|
| 141 |
+
|
| 142 |
+
/// Pretty prints the `BCELoss` module into the given `stream`.
|
| 143 |
+
void pretty_print(std::ostream& stream) const override;
|
| 144 |
+
|
| 145 |
+
Tensor forward(const Tensor& input, const Tensor& target);
|
| 146 |
+
|
| 147 |
+
/// The options with which this `Module` was constructed.
|
| 148 |
+
BCELossOptions options;
|
| 149 |
+
};
|
| 150 |
+
|
| 151 |
+
/// A `ModuleHolder` subclass for `BCELossImpl`.
|
| 152 |
+
/// See the documentation for `BCELossImpl` class to learn what methods it
|
| 153 |
+
/// provides, and examples of how to use `BCELoss` with
|
| 154 |
+
/// `torch::nn::BCELossOptions`. See the documentation for `ModuleHolder` to
|
| 155 |
+
/// learn about PyTorch's module storage semantics.
|
| 156 |
+
TORCH_MODULE(BCELoss);
|
| 157 |
+
|
| 158 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ HingeEmbeddingLoss
|
| 159 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 160 |
+
|
| 161 |
+
/// Creates a criterion that measures the loss given an input tensor :math:`x`
|
| 162 |
+
/// and a labels tensor :math:`y` (containing 1 or -1).
|
| 163 |
+
/// See https://pytorch.org/docs/main/nn.html#torch.nn.HingeEmbeddingLoss to
|
| 164 |
+
/// learn about the exact behavior of this module.
|
| 165 |
+
///
|
| 166 |
+
/// See the documentation for `torch::nn::HingeEmbeddingLossOptions` class to
|
| 167 |
+
/// learn what constructor arguments are supported for this module.
|
| 168 |
+
///
|
| 169 |
+
/// Example:
|
| 170 |
+
/// ```
|
| 171 |
+
/// HingeEmbeddingLoss
|
| 172 |
+
/// model(HingeEmbeddingLossOptions().margin(4).reduction(torch::kNone));
|
| 173 |
+
/// ```
|
| 174 |
+
struct TORCH_API HingeEmbeddingLossImpl : Cloneable<HingeEmbeddingLossImpl> {
|
| 175 |
+
explicit HingeEmbeddingLossImpl(HingeEmbeddingLossOptions options_ = {});
|
| 176 |
+
|
| 177 |
+
void reset() override;
|
| 178 |
+
|
| 179 |
+
/// Pretty prints the `HingeEmbeddingLoss` module into the given `stream`.
|
| 180 |
+
void pretty_print(std::ostream& stream) const override;
|
| 181 |
+
|
| 182 |
+
Tensor forward(const Tensor& input, const Tensor& target);
|
| 183 |
+
|
| 184 |
+
/// The options with which this `Module` was constructed.
|
| 185 |
+
HingeEmbeddingLossOptions options;
|
| 186 |
+
};
|
| 187 |
+
|
| 188 |
+
/// A `ModuleHolder` subclass for `HingeEmbeddingLossImpl`.
|
| 189 |
+
/// See the documentation for `HingeEmbeddingLossImpl` class to learn what
|
| 190 |
+
/// methods it provides, and examples of how to use `HingeEmbeddingLoss` with
|
| 191 |
+
/// `torch::nn::HingeEmbeddingLossOptions`. See the documentation for
|
| 192 |
+
/// `ModuleHolder` to learn about PyTorch's module storage semantics.
|
| 193 |
+
TORCH_MODULE(HingeEmbeddingLoss);
|
| 194 |
+
|
| 195 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MultiMarginLoss
|
| 196 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 197 |
+
|
| 198 |
+
/// Creates a criterion that optimizes a multi-class classification hinge
|
| 199 |
+
/// loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`)
|
| 200 |
+
/// and output :math:`y` (which is a 1D tensor of target class indices, :math:`0
|
| 201 |
+
/// \leq y \leq \text{x.size}(1)-1`). See
|
| 202 |
+
/// https://pytorch.org/docs/main/nn.html#torch.nn.MultiMarginLoss to learn
|
| 203 |
+
/// about the exact behavior of this module.
|
| 204 |
+
///
|
| 205 |
+
/// See the documentation for `torch::nn::MultiMarginLossOptions` class to learn
|
| 206 |
+
/// what constructor arguments are supported for this module.
|
| 207 |
+
///
|
| 208 |
+
/// Example:
|
| 209 |
+
/// ```
|
| 210 |
+
/// MultiMarginLoss model(MultiMarginLossOptions().margin(2).weight(weight));
|
| 211 |
+
/// ```
|
| 212 |
+
struct TORCH_API MultiMarginLossImpl : public Cloneable<MultiMarginLossImpl> {
|
| 213 |
+
explicit MultiMarginLossImpl(MultiMarginLossOptions options_ = {});
|
| 214 |
+
|
| 215 |
+
void reset() override;
|
| 216 |
+
|
| 217 |
+
/// Pretty prints the `MultiMarginLoss` module into the given `stream`.
|
| 218 |
+
void pretty_print(std::ostream& stream) const override;
|
| 219 |
+
|
| 220 |
+
Tensor forward(const Tensor& input, const Tensor& target);
|
| 221 |
+
|
| 222 |
+
/// The options with which this `Module` was constructed.
|
| 223 |
+
MultiMarginLossOptions options;
|
| 224 |
+
};
|
| 225 |
+
|
| 226 |
+
/// A `ModuleHolder` subclass for `MultiMarginLossImpl`.
|
| 227 |
+
/// See the documentation for `MultiMarginLossImpl` class to learn what methods
|
| 228 |
+
/// it provides, and examples of how to use `MultiMarginLoss` with
|
| 229 |
+
/// `torch::nn::MultiMarginLossOptions`. See the documentation for
|
| 230 |
+
/// `ModuleHolder` to learn about PyTorch's module storage semantics.
|
| 231 |
+
TORCH_MODULE(MultiMarginLoss);
|
| 232 |
+
|
| 233 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CosineEmbeddingLoss
|
| 234 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 235 |
+
|
| 236 |
+
/// Creates a criterion that measures the loss given input tensors
|
| 237 |
+
/// `input1`, `input2`, and a `Tensor` label `target` with values 1 or
|
| 238 |
+
/// -1. This is used for measuring whether two inputs are similar or
|
| 239 |
+
/// dissimilar, using the cosine distance, and is typically used for learning
|
| 240 |
+
/// nonlinear embeddings or semi-supervised learning.
|
| 241 |
+
/// See https://pytorch.org/docs/main/nn.html#torch.nn.CosineEmbeddingLoss to
|
| 242 |
+
/// learn about the exact behavior of this module.
|
| 243 |
+
///
|
| 244 |
+
/// See the documentation for `torch::nn::CosineEmbeddingLossOptions` class to
|
| 245 |
+
/// learn what constructor arguments are supported for this module.
|
| 246 |
+
///
|
| 247 |
+
/// Example:
|
| 248 |
+
/// ```
|
| 249 |
+
/// CosineEmbeddingLoss model(CosineEmbeddingLossOptions().margin(0.5));
|
| 250 |
+
/// ```
|
| 251 |
+
struct TORCH_API CosineEmbeddingLossImpl
|
| 252 |
+
: public Cloneable<CosineEmbeddingLossImpl> {
|
| 253 |
+
explicit CosineEmbeddingLossImpl(CosineEmbeddingLossOptions options_ = {});
|
| 254 |
+
|
| 255 |
+
void reset() override;
|
| 256 |
+
|
| 257 |
+
/// Pretty prints the `CosineEmbeddingLoss` module into the given `stream`.
|
| 258 |
+
void pretty_print(std::ostream& stream) const override;
|
| 259 |
+
|
| 260 |
+
Tensor forward(
|
| 261 |
+
const Tensor& input1,
|
| 262 |
+
const Tensor& input2,
|
| 263 |
+
const Tensor& target);
|
| 264 |
+
|
| 265 |
+
/// The options with which this `Module` was constructed.
|
| 266 |
+
CosineEmbeddingLossOptions options;
|
| 267 |
+
};
|
| 268 |
+
|
| 269 |
+
/// A `ModuleHolder` subclass for `CosineEmbeddingLossImpl`.
|
| 270 |
+
/// See the documentation for `CosineEmbeddingLossImpl` class to learn what
|
| 271 |
+
/// methods it provides, and examples of how to use `CosineEmbeddingLoss` with
|
| 272 |
+
/// `torch::nn::CosineEmbeddingLossOptions`. See the documentation for
|
| 273 |
+
/// `ModuleHolder` to learn about PyTorch's module storage semantics.
|
| 274 |
+
TORCH_MODULE(CosineEmbeddingLoss);
|
| 275 |
+
|
| 276 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SmoothL1Loss
|
| 277 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 278 |
+
|
| 279 |
+
/// Creates a criterion that uses a squared term if the absolute
|
| 280 |
+
/// element-wise error falls below beta and an L1 term otherwise.
|
| 281 |
+
/// It is less sensitive to outliers than the `MSELoss` and in some cases
|
| 282 |
+
/// prevents exploding gradients (e.g. see the paper `Fast R-CNN` by Ross
|
| 283 |
+
/// Girshick). See https://pytorch.org/docs/main/nn.html#torch.nn.SmoothL1Loss
|
| 284 |
+
/// to learn about the exact behavior of this module.
|
| 285 |
+
///
|
| 286 |
+
/// See the documentation for `torch::nn::SmoothL1LossOptions` class to learn
|
| 287 |
+
/// what constructor arguments are supported for this module.
|
| 288 |
+
///
|
| 289 |
+
/// Example:
|
| 290 |
+
/// ```
|
| 291 |
+
/// SmoothL1Loss model(SmoothL1LossOptions().reduction(torch::kNone).beta(0.5));
|
| 292 |
+
/// ```
|
| 293 |
+
struct TORCH_API SmoothL1LossImpl : public Cloneable<SmoothL1LossImpl> {
|
| 294 |
+
explicit SmoothL1LossImpl(SmoothL1LossOptions options = {});
|
| 295 |
+
|
| 296 |
+
void reset() override;
|
| 297 |
+
|
| 298 |
+
/// Pretty prints the `L1Loss` module into the given `stream`.
|
| 299 |
+
void pretty_print(std::ostream& stream) const override;
|
| 300 |
+
|
| 301 |
+
Tensor forward(const Tensor& input, const Tensor& target);
|
| 302 |
+
|
| 303 |
+
/// The options with which this `Module` was constructed.
|
| 304 |
+
SmoothL1LossOptions options;
|
| 305 |
+
};
|
| 306 |
+
|
| 307 |
+
/// A `ModuleHolder` subclass for `SmoothL1LossImpl`.
|
| 308 |
+
/// See the documentation for `SmoothL1LossImpl` class to learn what methods it
|
| 309 |
+
/// provides, and examples of how to use `SmoothL1Loss` with
|
| 310 |
+
/// `torch::nn::SmoothL1LossOptions`. See the documentation for `ModuleHolder`
|
| 311 |
+
/// to learn about PyTorch's module storage semantics.
|
| 312 |
+
TORCH_MODULE(SmoothL1Loss);
|
| 313 |
+
|
| 314 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ HuberLoss
|
| 315 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 316 |
+
|
| 317 |
+
/// Creates a criterion that uses a squared term if the absolute
|
| 318 |
+
/// element-wise error falls below delta and a delta-scaled L1 term otherwise.
|
| 319 |
+
/// See https://pytorch.org/docs/main/nn.html#torch.nn.HuberLoss to learn
|
| 320 |
+
/// about the exact behavior of this module.
|
| 321 |
+
///
|
| 322 |
+
/// See the documentation for `torch::nn::HuberLossOptions` class to learn what
|
| 323 |
+
/// constructor arguments are supported for this module.
|
| 324 |
+
///
|
| 325 |
+
/// Example:
|
| 326 |
+
/// ```
|
| 327 |
+
/// HuberLoss model(HuberLossOptions().reduction(torch::kNone).delta(0.5));
|
| 328 |
+
/// ```
|
| 329 |
+
struct TORCH_API HuberLossImpl : public Cloneable<HuberLossImpl> {
|
| 330 |
+
explicit HuberLossImpl(HuberLossOptions options_ = {});
|
| 331 |
+
|
| 332 |
+
void reset() override;
|
| 333 |
+
|
| 334 |
+
/// Pretty prints the `HuberLoss` module into the given `stream`.
|
| 335 |
+
void pretty_print(std::ostream& stream) const override;
|
| 336 |
+
|
| 337 |
+
Tensor forward(const Tensor& input, const Tensor& target);
|
| 338 |
+
|
| 339 |
+
/// The options with which this `Module` was constructed.
|
| 340 |
+
HuberLossOptions options;
|
| 341 |
+
};
|
| 342 |
+
|
| 343 |
+
/// A `ModuleHolder` subclass for `HuberLossImpl`.
|
| 344 |
+
/// See the documentation for `HuberLossImpl` class to learn what methods it
|
| 345 |
+
/// provides, and examples of how to use `HuberLoss` with
|
| 346 |
+
/// `torch::nn::HuberLossOptions`. See the documentation for `ModuleHolder` to
|
| 347 |
+
/// learn about PyTorch's module storage semantics.
|
| 348 |
+
TORCH_MODULE(HuberLoss);
|
| 349 |
+
|
| 350 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MultiLabelMarginLoss
|
| 351 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 352 |
+
|
| 353 |
+
/// Creates a criterion that optimizes a multi-class multi-classification
|
| 354 |
+
/// hinge loss (margin-based loss) between input :math:`x` (a 2D mini-batch
|
| 355 |
+
/// `Tensor`) and output :math:`y` (which is a 2D `Tensor` of target class
|
| 356 |
+
/// indices). See
|
| 357 |
+
/// https://pytorch.org/docs/main/nn.html#torch.nn.MultiLabelMarginLoss to
|
| 358 |
+
/// learn about the exact behavior of this module.
|
| 359 |
+
///
|
| 360 |
+
/// See the documentation for `torch::nn::MultiLabelMarginLossOptions` class to
|
| 361 |
+
/// learn what constructor arguments are supported for this module.
|
| 362 |
+
///
|
| 363 |
+
/// Example:
|
| 364 |
+
/// ```
|
| 365 |
+
/// MultiLabelMarginLoss model(MultiLabelMarginLossOptions(torch::kNone));
|
| 366 |
+
/// ```
|
| 367 |
+
struct TORCH_API MultiLabelMarginLossImpl
|
| 368 |
+
: public Cloneable<MultiLabelMarginLossImpl> {
|
| 369 |
+
explicit MultiLabelMarginLossImpl(MultiLabelMarginLossOptions options_ = {});
|
| 370 |
+
|
| 371 |
+
void reset() override;
|
| 372 |
+
|
| 373 |
+
/// Pretty prints the `L1Loss` module into the given `stream`.
|
| 374 |
+
void pretty_print(std::ostream& stream) const override;
|
| 375 |
+
|
| 376 |
+
Tensor forward(const Tensor& input, const Tensor& target);
|
| 377 |
+
|
| 378 |
+
/// The options with which this `Module` was constructed.
|
| 379 |
+
MultiLabelMarginLossOptions options;
|
| 380 |
+
};
|
| 381 |
+
|
| 382 |
+
/// A `ModuleHolder` subclass for `MultiLabelMarginLossImpl`.
|
| 383 |
+
/// See the documentation for `MultiLabelMarginLossImpl` class to learn what
|
| 384 |
+
/// methods it provides, and examples of how to use `MultiLabelMarginLoss` with
|
| 385 |
+
/// `torch::nn::MultiLabelMarginLossOptions`. See the documentation for
|
| 386 |
+
/// `ModuleHolder` to learn about PyTorch's module storage semantics.
|
| 387 |
+
TORCH_MODULE(MultiLabelMarginLoss);
|
| 388 |
+
|
| 389 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SoftMarginLoss
|
| 390 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 391 |
+
|
| 392 |
+
/// Creates a criterion that optimizes a two-class classification
|
| 393 |
+
/// logistic loss between input tensor :math:`x` and target tensor :math:`y`
|
| 394 |
+
/// (containing 1 or -1).
|
| 395 |
+
/// See https://pytorch.org/docs/main/nn.html#torch.nn.SoftMarginLoss to learn
|
| 396 |
+
/// about the exact behavior of this module.
|
| 397 |
+
///
|
| 398 |
+
/// See the documentation for `torch::nn::SoftMarginLossOptions` class to learn
|
| 399 |
+
/// what constructor arguments are supported for this module.
|
| 400 |
+
///
|
| 401 |
+
/// Example:
|
| 402 |
+
/// ```
|
| 403 |
+
/// SoftMarginLoss model(SoftMarginLossOptions(torch::kNone));
|
| 404 |
+
/// ```
|
| 405 |
+
struct TORCH_API SoftMarginLossImpl : public Cloneable<SoftMarginLossImpl> {
|
| 406 |
+
explicit SoftMarginLossImpl(SoftMarginLossOptions options_ = {});
|
| 407 |
+
|
| 408 |
+
/// Pretty prints the `SoftMarginLoss` module into the given `stream`.
|
| 409 |
+
void pretty_print(std::ostream& stream) const override;
|
| 410 |
+
|
| 411 |
+
void reset() override;
|
| 412 |
+
|
| 413 |
+
Tensor forward(const Tensor& input, const Tensor& target);
|
| 414 |
+
|
| 415 |
+
/// The options with which this `Module` was constructed.
|
| 416 |
+
SoftMarginLossOptions options;
|
| 417 |
+
};
|
| 418 |
+
|
| 419 |
+
/// A `ModuleHolder` subclass for `SoftMarginLossImpl`.
|
| 420 |
+
/// See the documentation for `SoftMarginLossImpl` class to learn what methods
|
| 421 |
+
/// it provides, and examples of how to use `SoftMarginLoss` with
|
| 422 |
+
/// `torch::nn::SoftMarginLossOptions`. See the documentation for `ModuleHolder`
|
| 423 |
+
/// to learn about PyTorch's module storage semantics.
|
| 424 |
+
TORCH_MODULE(SoftMarginLoss);
|
| 425 |
+
|
| 426 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MultiLabelSoftMarginLoss
|
| 427 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 428 |
+
|
| 429 |
+
/// Creates a criterion that optimizes a multi-label one-versus-all
|
| 430 |
+
/// loss based on max-entropy, between input :math:`x` and target :math:`y` of
|
| 431 |
+
/// size :math:`(N, C)`. See
|
| 432 |
+
/// https://pytorch.org/docs/main/nn.html#torch.nn.MultiLabelSoftMarginLoss to
|
| 433 |
+
/// learn about the exact behavior of this module.
|
| 434 |
+
///
|
| 435 |
+
/// See the documentation for `torch::nn::MultiLabelSoftMarginLossOptions` class
|
| 436 |
+
/// to learn what constructor arguments are supported for this module.
|
| 437 |
+
///
|
| 438 |
+
/// Example:
|
| 439 |
+
/// ```
|
| 440 |
+
/// MultiLabelSoftMarginLoss
|
| 441 |
+
/// model(MultiLabelSoftMarginLossOptions().reduction(torch::kNone).weight(weight));
|
| 442 |
+
/// ```
|
| 443 |
+
struct TORCH_API MultiLabelSoftMarginLossImpl
|
| 444 |
+
: public Cloneable<MultiLabelSoftMarginLossImpl> {
|
| 445 |
+
explicit MultiLabelSoftMarginLossImpl(
|
| 446 |
+
MultiLabelSoftMarginLossOptions options_ = {});
|
| 447 |
+
|
| 448 |
+
/// Pretty prints the `MultiLabelSoftMarginLoss` module into the given
|
| 449 |
+
/// `stream`.
|
| 450 |
+
void pretty_print(std::ostream& stream) const override;
|
| 451 |
+
|
| 452 |
+
void reset() override;
|
| 453 |
+
|
| 454 |
+
Tensor forward(const Tensor& input, const Tensor& target);
|
| 455 |
+
|
| 456 |
+
/// The options with which this `Module` was constructed.
|
| 457 |
+
MultiLabelSoftMarginLossOptions options;
|
| 458 |
+
};
|
| 459 |
+
|
| 460 |
+
/// A `ModuleHolder` subclass for `MultiLabelSoftMarginLossImpl`.
|
| 461 |
+
/// See the documentation for `MultiLabelSoftMarginLossImpl` class to learn what
|
| 462 |
+
/// methods it provides, and examples of how to use `MultiLabelSoftMarginLoss`
|
| 463 |
+
/// with `torch::nn::MultiLabelSoftMarginLossOptions`. See the documentation for
|
| 464 |
+
/// `ModuleHolder` to learn about PyTorch's module storage semantics.
|
| 465 |
+
TORCH_MODULE(MultiLabelSoftMarginLoss);
|
| 466 |
+
|
| 467 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TripletMarginLoss
|
| 468 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 469 |
+
|
| 470 |
+
/// Creates a criterion that measures the triplet loss given an input
|
| 471 |
+
/// tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater
|
| 472 |
+
/// than :math:`0`. This is used for measuring a relative similarity between
|
| 473 |
+
/// samples. A triplet is composed by `a`, `p` and `n` (i.e., `anchor`,
|
| 474 |
+
/// `positive examples` and `negative examples` respectively). The
|
| 475 |
+
/// shapes of all input tensors should be :math:`(N, D)`.
|
| 476 |
+
/// See https://pytorch.org/docs/main/nn.html#torch.nn.TripletMarginLoss to
|
| 477 |
+
/// learn about the exact behavior of this module.
|
| 478 |
+
///
|
| 479 |
+
/// See the documentation for `torch::nn::TripletMarginLossOptions` class to
|
| 480 |
+
/// learn what constructor arguments are supported for this module.
|
| 481 |
+
///
|
| 482 |
+
/// Example:
|
| 483 |
+
/// ```
|
| 484 |
+
/// TripletMarginLoss
|
| 485 |
+
/// model(TripletMarginLossOptions().margin(3).p(2).eps(1e-06).swap(false));
|
| 486 |
+
/// ```
|
| 487 |
+
struct TORCH_API TripletMarginLossImpl
|
| 488 |
+
: public Cloneable<TripletMarginLossImpl> {
|
| 489 |
+
explicit TripletMarginLossImpl(TripletMarginLossOptions options_ = {});
|
| 490 |
+
|
| 491 |
+
void reset() override;
|
| 492 |
+
|
| 493 |
+
/// Pretty prints the `TripletMarginLoss` module into the given `stream`.
|
| 494 |
+
void pretty_print(std::ostream& stream) const override;
|
| 495 |
+
|
| 496 |
+
Tensor forward(
|
| 497 |
+
const Tensor& anchor,
|
| 498 |
+
const Tensor& positive,
|
| 499 |
+
const Tensor& negative);
|
| 500 |
+
|
| 501 |
+
/// The options with which this `Module` was constructed.
|
| 502 |
+
TripletMarginLossOptions options;
|
| 503 |
+
};
|
| 504 |
+
|
| 505 |
+
/// A `ModuleHolder` subclass for `TripletMarginLossImpl`.
|
| 506 |
+
/// See the documentation for `TripletMarginLossImpl` class to learn what
|
| 507 |
+
/// methods it provides, and examples of how to use `TripletMarginLoss` with
|
| 508 |
+
/// `torch::nn::TripletMarginLossOptions`. See the documentation for
|
| 509 |
+
/// `ModuleHolder` to learn about PyTorch's module storage semantics.
|
| 510 |
+
TORCH_MODULE(TripletMarginLoss);
|
| 511 |
+
|
| 512 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TripletMarginWithDistanceLoss
|
| 513 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 514 |
+
|
| 515 |
+
/// Creates a criterion that measures the triplet loss given input
|
| 516 |
+
/// tensors :math:`a`, :math:`p`, and :math:`n` (representing anchor,
|
| 517 |
+
/// positive, and negative examples, respectively); and a nonnegative,
|
| 518 |
+
/// real-valued function
|
| 519 |
+
/// ("distance function") used to compute the relationships between the anchor
|
| 520 |
+
/// and positive example ("positive distance") and the anchor and negative
|
| 521 |
+
/// example ("negative distance").
|
| 522 |
+
/// See
|
| 523 |
+
/// https://pytorch.org/docs/main/nn.html#torch.nn.TripletMarginWithDistanceLoss
|
| 524 |
+
/// to learn about the exact behavior of this module.
|
| 525 |
+
///
|
| 526 |
+
/// See the documentation for `torch::nn::TripletMarginWithDistanceLossOptions`
|
| 527 |
+
/// class to learn what constructor arguments are supported for this module.
|
| 528 |
+
///
|
| 529 |
+
/// Example:
|
| 530 |
+
/// ```
|
| 531 |
+
/// TripletMarginWithDistanceLoss
|
| 532 |
+
/// model(TripletMarginWithDistanceLossOptions().margin(3).swap(false));
|
| 533 |
+
/// ```
|
| 534 |
+
struct TORCH_API TripletMarginWithDistanceLossImpl
|
| 535 |
+
: public Cloneable<TripletMarginWithDistanceLossImpl> {
|
| 536 |
+
explicit TripletMarginWithDistanceLossImpl(
|
| 537 |
+
TripletMarginWithDistanceLossOptions options_ = {});
|
| 538 |
+
|
| 539 |
+
void reset() override;
|
| 540 |
+
|
| 541 |
+
/// Pretty prints the `TripletMarginWithDistanceLoss` module into the given
|
| 542 |
+
/// `stream`.
|
| 543 |
+
void pretty_print(std::ostream& stream) const override;
|
| 544 |
+
|
| 545 |
+
Tensor forward(
|
| 546 |
+
const Tensor& anchor,
|
| 547 |
+
const Tensor& positive,
|
| 548 |
+
const Tensor& negative);
|
| 549 |
+
|
| 550 |
+
/// The options with which this `Module` was constructed.
|
| 551 |
+
TripletMarginWithDistanceLossOptions options;
|
| 552 |
+
};
|
| 553 |
+
|
| 554 |
+
/// A `ModuleHolder` subclass for `TripletMarginWithDistanceLossImpl`.
|
| 555 |
+
/// See the documentation for `TripletMarginWithDistanceLossImpl` class to learn
|
| 556 |
+
/// what methods it provides, and examples of how to use
|
| 557 |
+
/// `TripletMarginWithDistanceLoss` with
|
| 558 |
+
/// `torch::nn::TripletMarginWithDistanceLossOptions`.
|
| 559 |
+
/// See the documentation for `ModuleHolder` to learn about PyTorch's
|
| 560 |
+
/// module storage semantics.
|
| 561 |
+
TORCH_MODULE(TripletMarginWithDistanceLoss);
|
| 562 |
+
|
| 563 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CTCLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 564 |
+
|
| 565 |
+
/// The Connectionist Temporal Classification loss.
|
| 566 |
+
/// See https://pytorch.org/docs/main/nn.html#torch.nn.CTCLoss to learn
|
| 567 |
+
/// about the exact behavior of this module.
|
| 568 |
+
///
|
| 569 |
+
/// See the documentation for `torch::nn::CTCLossOptions` class to learn what
|
| 570 |
+
/// constructor arguments are supported for this module.
|
| 571 |
+
///
|
| 572 |
+
/// Example:
|
| 573 |
+
/// ```
|
| 574 |
+
/// CTCLoss
|
| 575 |
+
/// model(CTCLossOptions().blank(42).zero_infinity(false).reduction(torch::kSum));
|
| 576 |
+
/// ```
|
| 577 |
+
struct TORCH_API CTCLossImpl : public Cloneable<CTCLossImpl> {
|
| 578 |
+
explicit CTCLossImpl(CTCLossOptions options_ = {});
|
| 579 |
+
|
| 580 |
+
void reset() override;
|
| 581 |
+
|
| 582 |
+
/// Pretty prints the `CTCLoss` module into the given `stream`.
|
| 583 |
+
void pretty_print(std::ostream& stream) const override;
|
| 584 |
+
|
| 585 |
+
Tensor forward(
|
| 586 |
+
const Tensor& log_probs,
|
| 587 |
+
const Tensor& targets,
|
| 588 |
+
const Tensor& input_lengths,
|
| 589 |
+
const Tensor& target_lengths);
|
| 590 |
+
|
| 591 |
+
/// The options with which this `Module` was constructed.
|
| 592 |
+
CTCLossOptions options;
|
| 593 |
+
};
|
| 594 |
+
|
| 595 |
+
/// A `ModuleHolder` subclass for `CTCLossImpl`.
|
| 596 |
+
/// See the documentation for `CTCLossImpl` class to learn what methods it
|
| 597 |
+
/// provides, and examples of how to use `CTCLoss` with
|
| 598 |
+
/// `torch::nn::CTCLossOptions`. See the documentation for `ModuleHolder` to
|
| 599 |
+
/// learn about PyTorch's module storage semantics.
|
| 600 |
+
TORCH_MODULE(CTCLoss);
|
| 601 |
+
|
| 602 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PoissonNLLLoss
|
| 603 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 604 |
+
|
| 605 |
+
/// Negative log likelihood loss with Poisson distribution of target.
|
| 606 |
+
/// See https://pytorch.org/docs/main/nn.html#torch.nn.PoissonNLLLoss to learn
|
| 607 |
+
/// about the exact behavior of this module.
|
| 608 |
+
///
|
| 609 |
+
/// See the documentation for `torch::nn::PoissonNLLLossOptions` class to learn
|
| 610 |
+
/// what constructor arguments are supported for this module.
|
| 611 |
+
///
|
| 612 |
+
/// Example:
|
| 613 |
+
/// ```
|
| 614 |
+
/// PoissonNLLLoss
|
| 615 |
+
/// model(PoissonNLLLossOptions().log_input(false).full(true).eps(0.42).reduction(torch::kSum));
|
| 616 |
+
/// ```
|
| 617 |
+
struct TORCH_API PoissonNLLLossImpl : public Cloneable<PoissonNLLLossImpl> {
|
| 618 |
+
explicit PoissonNLLLossImpl(PoissonNLLLossOptions options_ = {});
|
| 619 |
+
|
| 620 |
+
void reset() override;
|
| 621 |
+
|
| 622 |
+
/// Pretty prints the `PoissonNLLLoss` module into the given `stream`.
|
| 623 |
+
void pretty_print(std::ostream& stream) const override;
|
| 624 |
+
|
| 625 |
+
Tensor forward(const Tensor& log_input, const Tensor& targets);
|
| 626 |
+
|
| 627 |
+
/// The options with which this `Module` was constructed.
|
| 628 |
+
PoissonNLLLossOptions options;
|
| 629 |
+
};
|
| 630 |
+
|
| 631 |
+
/// A `ModuleHolder` subclass for `PoissonNLLLossImpl`.
|
| 632 |
+
/// See the documentation for `PoissonNLLLossImpl` class to learn what methods
|
| 633 |
+
/// it provides, and examples of how to use `PoissonNLLLoss` with
|
| 634 |
+
/// `torch::nn::PoissonNLLLossOptions`. See the documentation for `ModuleHolder`
|
| 635 |
+
/// to learn about PyTorch's module storage semantics.
|
| 636 |
+
TORCH_MODULE(PoissonNLLLoss);
|
| 637 |
+
|
| 638 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MarginRankingLoss
|
| 639 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 640 |
+
|
| 641 |
+
/// Creates a criterion that measures the loss given
|
| 642 |
+
/// inputs :math:`x1`, :math:`x2`, two 1D mini-batch `Tensors`,
|
| 643 |
+
/// and a label 1D mini-batch tensor :math:`y` (containing 1 or -1).
|
| 644 |
+
/// See https://pytorch.org/docs/main/nn.html#torch.nn.MarginRankingLoss to
|
| 645 |
+
/// learn about the exact behavior of this module.
|
| 646 |
+
///
|
| 647 |
+
/// See the documentation for `torch::nn::MarginRankingLossOptions` class to
|
| 648 |
+
/// learn what constructor arguments are supported for this module.
|
| 649 |
+
///
|
| 650 |
+
/// Example:
|
| 651 |
+
/// ```
|
| 652 |
+
/// MarginRankingLoss
|
| 653 |
+
/// model(MarginRankingLossOptions().margin(0.5).reduction(torch::kSum));
|
| 654 |
+
/// ```
|
| 655 |
+
struct TORCH_API MarginRankingLossImpl
|
| 656 |
+
: public Cloneable<MarginRankingLossImpl> {
|
| 657 |
+
explicit MarginRankingLossImpl(MarginRankingLossOptions options_ = {});
|
| 658 |
+
|
| 659 |
+
void reset() override;
|
| 660 |
+
|
| 661 |
+
/// Pretty prints the `MarginRankingLoss` module into the given `stream`.
|
| 662 |
+
void pretty_print(std::ostream& stream) const override;
|
| 663 |
+
|
| 664 |
+
Tensor forward(
|
| 665 |
+
const Tensor& input1,
|
| 666 |
+
const Tensor& input2,
|
| 667 |
+
const Tensor& targets);
|
| 668 |
+
|
| 669 |
+
/// The options with which this `Module` was constructed.
|
| 670 |
+
MarginRankingLossOptions options;
|
| 671 |
+
};
|
| 672 |
+
|
| 673 |
+
/// A `ModuleHolder` subclass for `MarginRankingLossImpl`.
|
| 674 |
+
/// See the documentation for `MarginRankingLossImpl` class to learn what
|
| 675 |
+
/// methods it provides, and examples of how to use `MarginRankingLoss` with
|
| 676 |
+
/// `torch::nn::MarginRankingLossOptions`. See the documentation for
|
| 677 |
+
/// `ModuleHolder` to learn about PyTorch's module storage semantics.
|
| 678 |
+
TORCH_MODULE(MarginRankingLoss);
|
| 679 |
+
|
| 680 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ NLLLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 681 |
+
|
| 682 |
+
/// The negative log likelihood loss. It is useful to train a classification
|
| 683 |
+
/// problem with `C` classes.
|
| 684 |
+
/// See https://pytorch.org/docs/main/nn.html#torch.nn.NLLLoss to learn
|
| 685 |
+
/// about the exact behavior of this module.
|
| 686 |
+
///
|
| 687 |
+
/// See the documentation for `torch::nn::NLLLossOptions` class to learn what
|
| 688 |
+
/// constructor arguments are supported for this module.
|
| 689 |
+
///
|
| 690 |
+
/// Example:
|
| 691 |
+
/// ```
|
| 692 |
+
/// NLLLoss model(NLLLossOptions().ignore_index(-100).reduction(torch::kMean));
|
| 693 |
+
/// ```
|
| 694 |
+
struct TORCH_API NLLLossImpl : public Cloneable<NLLLossImpl> {
|
| 695 |
+
explicit NLLLossImpl(NLLLossOptions options_ = {});
|
| 696 |
+
|
| 697 |
+
/// Pretty prints the `NLLLoss` module into the given `stream`.
|
| 698 |
+
void pretty_print(std::ostream& stream) const override;
|
| 699 |
+
|
| 700 |
+
void reset() override;
|
| 701 |
+
|
| 702 |
+
Tensor forward(const Tensor& input, const Tensor& target);
|
| 703 |
+
|
| 704 |
+
/// The options with which this `Module` was constructed.
|
| 705 |
+
NLLLossOptions options;
|
| 706 |
+
|
| 707 |
+
/// A manual rescaling weight given to to each class.
|
| 708 |
+
Tensor weight;
|
| 709 |
+
};
|
| 710 |
+
|
| 711 |
+
/// A `ModuleHolder` subclass for `NLLLossImpl`.
|
| 712 |
+
/// See the documentation for `NLLLossImpl` class to learn what methods it
|
| 713 |
+
/// provides, and examples of how to use `NLLLoss` with
|
| 714 |
+
/// `torch::nn::NLLLossOptions`. See the documentation for `ModuleHolder` to
|
| 715 |
+
/// learn about PyTorch's module storage semantics.
|
| 716 |
+
TORCH_MODULE(NLLLoss);
|
| 717 |
+
|
| 718 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CrossEntropyLoss
|
| 719 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 720 |
+
|
| 721 |
+
/// Creates a criterion that computes cross entropy loss between input and
|
| 722 |
+
/// target. See
|
| 723 |
+
/// https://pytorch.org/docs/main/nn.html#torch.nn.CrossEntropyLoss to learn
|
| 724 |
+
/// about the exact behavior of this module.
|
| 725 |
+
///
|
| 726 |
+
/// See the documentation for `torch::nn::CrossEntropyLossOptions` class to
|
| 727 |
+
/// learn what constructor arguments are supported for this module.
|
| 728 |
+
///
|
| 729 |
+
/// Example:
|
| 730 |
+
/// ```
|
| 731 |
+
/// CrossEntropyLoss
|
| 732 |
+
/// model(CrossEntropyLossOptions().ignore_index(-100).reduction(torch::kMean));
|
| 733 |
+
/// ```
|
| 734 |
+
struct TORCH_API CrossEntropyLossImpl : public Cloneable<CrossEntropyLossImpl> {
|
| 735 |
+
explicit CrossEntropyLossImpl(CrossEntropyLossOptions options_ = {});
|
| 736 |
+
|
| 737 |
+
void reset() override;
|
| 738 |
+
|
| 739 |
+
/// Pretty prints the `CrossEntropyLoss` module into the given `stream`.
|
| 740 |
+
void pretty_print(std::ostream& stream) const override;
|
| 741 |
+
|
| 742 |
+
Tensor forward(const Tensor& input, const Tensor& target);
|
| 743 |
+
|
| 744 |
+
/// The options with which this `Module` was constructed.
|
| 745 |
+
CrossEntropyLossOptions options;
|
| 746 |
+
|
| 747 |
+
/// A manual rescaling weight given to to each class.
|
| 748 |
+
Tensor weight;
|
| 749 |
+
};
|
| 750 |
+
|
| 751 |
+
/// A `ModuleHolder` subclass for `CrossEntropyLossImpl`.
|
| 752 |
+
/// See the documentation for `CrossEntropyLossImpl` class to learn what methods
|
| 753 |
+
/// it provides, and examples of how to use `CrossEntropyLoss` with
|
| 754 |
+
/// `torch::nn::CrossEntropyLossOptions`. See the documentation for
|
| 755 |
+
/// `ModuleHolder` to learn about PyTorch's module storage semantics.
|
| 756 |
+
TORCH_MODULE(CrossEntropyLoss);
|
| 757 |
+
|
| 758 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BCEWithLogitsLoss
|
| 759 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 760 |
+
|
| 761 |
+
/// This loss combines a `Sigmoid` layer and the `BCELoss` in one single
|
| 762 |
+
/// class. This version is more numerically stable than using a plain `Sigmoid`
|
| 763 |
+
/// followed by a `BCELoss` as, by combining the operations into one layer,
|
| 764 |
+
/// we take advantage of the log-sum-exp trick for numerical stability.
|
| 765 |
+
/// See https://pytorch.org/docs/main/nn.html#torch.nn.BCEWithLogitsLoss to
|
| 766 |
+
/// learn about the exact behavior of this module.
|
| 767 |
+
///
|
| 768 |
+
/// See the documentation for `torch::nn::BCEWithLogitsLossOptions` class to
|
| 769 |
+
/// learn what constructor arguments are supported for this module.
|
| 770 |
+
///
|
| 771 |
+
/// Example:
|
| 772 |
+
/// ```
|
| 773 |
+
/// BCEWithLogitsLoss
|
| 774 |
+
/// model(BCEWithLogitsLossOptions().reduction(torch::kNone).weight(weight));
|
| 775 |
+
/// ```
|
| 776 |
+
struct TORCH_API BCEWithLogitsLossImpl
|
| 777 |
+
: public Cloneable<BCEWithLogitsLossImpl> {
|
| 778 |
+
explicit BCEWithLogitsLossImpl(BCEWithLogitsLossOptions options_ = {});
|
| 779 |
+
|
| 780 |
+
void reset() override;
|
| 781 |
+
|
| 782 |
+
/// Pretty prints the `BCEWithLogitsLoss` module into the given `stream`.
|
| 783 |
+
void pretty_print(std::ostream& stream) const override;
|
| 784 |
+
|
| 785 |
+
Tensor forward(const Tensor& input, const Tensor& target);
|
| 786 |
+
|
| 787 |
+
/// The options with which this `Module` was constructed.
|
| 788 |
+
BCEWithLogitsLossOptions options;
|
| 789 |
+
|
| 790 |
+
/// A manual rescaling weight given to the loss of each batch element.
|
| 791 |
+
Tensor weight;
|
| 792 |
+
|
| 793 |
+
/// A weight of positive examples.
|
| 794 |
+
Tensor pos_weight;
|
| 795 |
+
};
|
| 796 |
+
|
| 797 |
+
/// A `ModuleHolder` subclass for `BCEWithLogitsLossImpl`.
|
| 798 |
+
/// See the documentation for `BCEWithLogitsLossImpl` class to learn what
|
| 799 |
+
/// methods it provides, and examples of how to use `BCEWithLogitsLoss` with
|
| 800 |
+
/// `torch::nn::BCEWithLogitsLossOptions`. See the documentation for
|
| 801 |
+
/// `ModuleHolder` to learn about PyTorch's module storage semantics.
|
| 802 |
+
TORCH_MODULE(BCEWithLogitsLoss);
|
| 803 |
+
|
| 804 |
+
} // namespace nn
|
| 805 |
+
} // namespace torch
|
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/normalization.h
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/nn/cloneable.h>
|
| 4 |
+
#include <torch/nn/functional/normalization.h>
|
| 5 |
+
#include <torch/nn/modules/_functions.h>
|
| 6 |
+
#include <torch/nn/options/normalization.h>
|
| 7 |
+
#include <torch/nn/pimpl.h>
|
| 8 |
+
#include <torch/types.h>
|
| 9 |
+
|
| 10 |
+
#include <cstddef>
|
| 11 |
+
#include <vector>
|
| 12 |
+
|
| 13 |
+
namespace torch {
|
| 14 |
+
namespace nn {
|
| 15 |
+
|
| 16 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LayerNorm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 17 |
+
|
| 18 |
+
/// Applies Layer Normalization over a mini-batch of inputs as described in
|
| 19 |
+
/// the paper `Layer Normalization`_ .
|
| 20 |
+
/// See https://pytorch.org/docs/main/nn.html#torch.nn.LayerNorm to learn
|
| 21 |
+
/// about the exact behavior of this module.
|
| 22 |
+
///
|
| 23 |
+
/// See the documentation for `torch::nn::LayerNormOptions` class to learn what
|
| 24 |
+
/// constructor arguments are supported for this module.
|
| 25 |
+
///
|
| 26 |
+
/// Example:
|
| 27 |
+
/// ```
|
| 28 |
+
/// LayerNorm model(LayerNormOptions({2,
|
| 29 |
+
/// 2}).elementwise_affine(false).eps(2e-5));
|
| 30 |
+
/// ```
|
| 31 |
+
class TORCH_API LayerNormImpl : public torch::nn::Cloneable<LayerNormImpl> {
|
| 32 |
+
public:
|
| 33 |
+
LayerNormImpl(std::vector<int64_t> normalized_shape)
|
| 34 |
+
: LayerNormImpl(LayerNormOptions(normalized_shape)) {}
|
| 35 |
+
explicit LayerNormImpl(LayerNormOptions options_);
|
| 36 |
+
|
| 37 |
+
void reset() override;
|
| 38 |
+
|
| 39 |
+
void reset_parameters();
|
| 40 |
+
|
| 41 |
+
/// Pretty prints the `LayerNorm` module into the given `stream`.
|
| 42 |
+
void pretty_print(std::ostream& stream) const override;
|
| 43 |
+
|
| 44 |
+
/// Applies layer normalization over a mini-batch of inputs as described in
|
| 45 |
+
/// the paper `Layer Normalization`_ .
|
| 46 |
+
///
|
| 47 |
+
/// The mean and standard-deviation are calculated separately over the last
|
| 48 |
+
/// certain number dimensions which have to be of the shape specified by
|
| 49 |
+
/// input `normalized_shape`.
|
| 50 |
+
///
|
| 51 |
+
/// `Layer Normalization`: https://arxiv.org/abs/1607.06450
|
| 52 |
+
Tensor forward(const Tensor& input);
|
| 53 |
+
|
| 54 |
+
/// The options with which this module was constructed.
|
| 55 |
+
LayerNormOptions options;
|
| 56 |
+
|
| 57 |
+
/// The learned weight.
|
| 58 |
+
/// Initialized to ones if the `elementwise_affine` option is set to `true`
|
| 59 |
+
/// upon construction.
|
| 60 |
+
Tensor weight;
|
| 61 |
+
|
| 62 |
+
/// The learned bias.
|
| 63 |
+
/// Initialized to zeros `elementwise_affine` option is set to `true` upon
|
| 64 |
+
/// construction.
|
| 65 |
+
Tensor bias;
|
| 66 |
+
};
|
| 67 |
+
|
| 68 |
+
/// A `ModuleHolder` subclass for `LayerNormImpl`.
|
| 69 |
+
/// See the documentation for `LayerNormImpl` class to learn what methods it
|
| 70 |
+
/// provides, and examples of how to use `LayerNorm` with
|
| 71 |
+
/// `torch::nn::LayerNormOptions`. See the documentation for `ModuleHolder` to
|
| 72 |
+
/// learn about PyTorch's module storage semantics.
|
| 73 |
+
TORCH_MODULE(LayerNorm);
|
| 74 |
+
|
| 75 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LocalResponseNorm
|
| 76 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 77 |
+
|
| 78 |
+
/// Applies local response normalization over an input signal composed
|
| 79 |
+
/// of several input planes, where channels occupy the second dimension.
|
| 80 |
+
/// Applies normalization across channels.
|
| 81 |
+
/// See https://pytorch.org/docs/main/nn.html#torch.nn.LocalResponseNorm to
|
| 82 |
+
/// learn about the exact behavior of this module.
|
| 83 |
+
///
|
| 84 |
+
/// See the documentation for `torch::nn::LocalResponseNormOptions` class to
|
| 85 |
+
/// learn what constructor arguments are supported for this module.
|
| 86 |
+
///
|
| 87 |
+
/// Example:
|
| 88 |
+
/// ```
|
| 89 |
+
/// LocalResponseNorm
|
| 90 |
+
/// model(LocalResponseNormOptions(2).alpha(0.0002).beta(0.85).k(2.));
|
| 91 |
+
/// ```
|
| 92 |
+
class TORCH_API LocalResponseNormImpl
|
| 93 |
+
: public Cloneable<LocalResponseNormImpl> {
|
| 94 |
+
public:
|
| 95 |
+
LocalResponseNormImpl(int64_t size)
|
| 96 |
+
: LocalResponseNormImpl(LocalResponseNormOptions(size)) {}
|
| 97 |
+
explicit LocalResponseNormImpl(const LocalResponseNormOptions& options_);
|
| 98 |
+
|
| 99 |
+
Tensor forward(const Tensor& input);
|
| 100 |
+
|
| 101 |
+
void reset() override;
|
| 102 |
+
|
| 103 |
+
/// Pretty prints the `LocalResponseNormImpl` module into the given `stream`.
|
| 104 |
+
void pretty_print(std::ostream& stream) const override;
|
| 105 |
+
|
| 106 |
+
/// The options with which this `Module` was constructed.
|
| 107 |
+
LocalResponseNormOptions options;
|
| 108 |
+
};
|
| 109 |
+
|
| 110 |
+
/// A `ModuleHolder` subclass for `LocalResponseNormImpl`.
|
| 111 |
+
/// See the documentation for `LocalResponseNormImpl` class to learn what
|
| 112 |
+
/// methods it provides, and examples of how to use `LocalResponseNorm` with
|
| 113 |
+
/// `torch::nn::LocalResponseNormOptions`. See the documentation for
|
| 114 |
+
/// `ModuleHolder` to learn about PyTorch's module storage semantics.
|
| 115 |
+
TORCH_MODULE(LocalResponseNorm);
|
| 116 |
+
|
| 117 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CrossMapLRN2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 118 |
+
|
| 119 |
+
/// See the documentation for `torch::nn::CrossMapLRN2dOptions` class to learn
|
| 120 |
+
/// what constructor arguments are supported for this module.
|
| 121 |
+
///
|
| 122 |
+
/// Example:
|
| 123 |
+
/// ```
|
| 124 |
+
/// CrossMapLRN2d model(CrossMapLRN2dOptions(3).alpha(1e-5).beta(0.1).k(10));
|
| 125 |
+
/// ```
|
| 126 |
+
class TORCH_API CrossMapLRN2dImpl
|
| 127 |
+
: public torch::nn::Cloneable<CrossMapLRN2dImpl> {
|
| 128 |
+
public:
|
| 129 |
+
CrossMapLRN2dImpl(int64_t size)
|
| 130 |
+
: CrossMapLRN2dImpl(CrossMapLRN2dOptions(size)) {}
|
| 131 |
+
explicit CrossMapLRN2dImpl(const CrossMapLRN2dOptions& options_)
|
| 132 |
+
: options(options_) {}
|
| 133 |
+
|
| 134 |
+
void reset() override;
|
| 135 |
+
|
| 136 |
+
/// Pretty prints the `CrossMapLRN2d` module into the given `stream`.
|
| 137 |
+
void pretty_print(std::ostream& stream) const override;
|
| 138 |
+
|
| 139 |
+
torch::Tensor forward(const torch::Tensor& input);
|
| 140 |
+
|
| 141 |
+
CrossMapLRN2dOptions options;
|
| 142 |
+
};
|
| 143 |
+
|
| 144 |
+
/// A `ModuleHolder` subclass for `CrossMapLRN2dImpl`.
|
| 145 |
+
/// See the documentation for `CrossMapLRN2dImpl` class to learn what methods it
|
| 146 |
+
/// provides, and examples of how to use `CrossMapLRN2d` with
|
| 147 |
+
/// `torch::nn::CrossMapLRN2dOptions`. See the documentation for `ModuleHolder`
|
| 148 |
+
/// to learn about PyTorch's module storage semantics.
|
| 149 |
+
TORCH_MODULE(CrossMapLRN2d);
|
| 150 |
+
|
| 151 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GroupNorm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 152 |
+
|
| 153 |
+
/// Applies Group Normalization over a mini-batch of inputs as described in
|
| 154 |
+
/// the paper `Group Normalization`_ .
|
| 155 |
+
/// See https://pytorch.org/docs/main/nn.html#torch.nn.GroupNorm to learn
|
| 156 |
+
/// about the exact behavior of this module.
|
| 157 |
+
///
|
| 158 |
+
/// See the documentation for `torch::nn::GroupNormOptions` class to learn what
|
| 159 |
+
/// constructor arguments are supported for this module.
|
| 160 |
+
///
|
| 161 |
+
/// Example:
|
| 162 |
+
/// ```
|
| 163 |
+
/// GroupNorm model(GroupNormOptions(2, 2).eps(2e-5).affine(false));
|
| 164 |
+
/// ```
|
| 165 |
+
class TORCH_API GroupNormImpl : public torch::nn::Cloneable<GroupNormImpl> {
|
| 166 |
+
public:
|
| 167 |
+
GroupNormImpl(int64_t num_groups, int64_t num_channels)
|
| 168 |
+
: GroupNormImpl(GroupNormOptions(num_groups, num_channels)) {}
|
| 169 |
+
explicit GroupNormImpl(const GroupNormOptions& options_);
|
| 170 |
+
|
| 171 |
+
void reset() override;
|
| 172 |
+
|
| 173 |
+
void reset_parameters();
|
| 174 |
+
|
| 175 |
+
/// Pretty prints the `GroupNorm` module into the given `stream`.
|
| 176 |
+
void pretty_print(std::ostream& stream) const override;
|
| 177 |
+
|
| 178 |
+
Tensor forward(const Tensor& input);
|
| 179 |
+
|
| 180 |
+
/// The options with which this module was constructed.
|
| 181 |
+
GroupNormOptions options;
|
| 182 |
+
|
| 183 |
+
/// The learned weight.
|
| 184 |
+
Tensor weight;
|
| 185 |
+
|
| 186 |
+
/// The learned bias.
|
| 187 |
+
Tensor bias;
|
| 188 |
+
};
|
| 189 |
+
|
| 190 |
+
/// A `ModuleHolder` subclass for `GroupNormImpl`.
|
| 191 |
+
/// See the documentation for `GroupNormImpl` class to learn what methods it
|
| 192 |
+
/// provides, and examples of how to use `GroupNorm` with
|
| 193 |
+
/// `torch::nn::GroupNormOptions`. See the documentation for `ModuleHolder` to
|
| 194 |
+
/// learn about PyTorch's module storage semantics.
|
| 195 |
+
TORCH_MODULE(GroupNorm);
|
| 196 |
+
|
| 197 |
+
} // namespace nn
|
| 198 |
+
} // namespace torch
|