koichi12 committed on
Commit
f610d77
·
verified ·
1 Parent(s): da0ba90

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader.h +57 -0
  2. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/base.h +255 -0
  3. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/stateful.h +63 -0
  4. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/stateless.h +82 -0
  5. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader_options.h +65 -0
  6. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets.h +9 -0
  7. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/base.h +104 -0
  8. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/chunk.h +529 -0
  9. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/map.h +118 -0
  10. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/mnist.h +48 -0
  11. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/shared.h +83 -0
  12. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/stateful.h +70 -0
  13. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/tensor.h +38 -0
  14. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/detail/data_shuttle.h +87 -0
  15. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/detail/queue.h +84 -0
  16. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/detail/sequencers.h +113 -0
  17. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/example.h +55 -0
  18. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/iterator.h +178 -0
  19. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers.h +9 -0
  20. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/base.h +47 -0
  21. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/custom_batch_request.h +21 -0
  22. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/distributed.h +139 -0
  23. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/random.h +54 -0
  24. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/sequential.h +50 -0
  25. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/serialize.h +28 -0
  26. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/stream.h +63 -0
  27. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms.h +7 -0
  28. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/base.h +53 -0
  29. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/collate.h +35 -0
  30. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/lambda.h +56 -0
  31. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/stack.h +49 -0
  32. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/tensor.h +77 -0
  33. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/worker_exception.h +38 -0
  34. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/detail/TensorDataContainer.h +363 -0
  35. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/detail/static.h +65 -0
  36. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/any.h +372 -0
  37. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/any_module_holder.h +133 -0
  38. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/any_value.h +125 -0
  39. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/functional.h +105 -0
  40. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/moduledict.h +262 -0
  41. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/modulelist.h +274 -0
  42. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/named_any.h +94 -0
  43. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/parameterdict.h +148 -0
  44. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/parameterlist.h +169 -0
  45. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/sequential.h +388 -0
  46. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/dropout.h +190 -0
  47. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/fold.h +87 -0
  48. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/instancenorm.h +153 -0
  49. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/loss.h +805 -0
  50. .venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/normalization.h +198 -0
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader.h ADDED
@@ -0,0 +1,57 @@
+ #pragma once
+
+ #include <torch/data/dataloader/stateful.h>
+ #include <torch/data/dataloader/stateless.h>
+
+ #include <torch/csrc/utils/variadic.h>
+
+ #include <c10/util/Exception.h>
+
+ #include <cstddef>
+ #include <memory>
+ #include <type_traits>
+ #include <utility>
+
+ namespace torch {
+ namespace data {
+
+ /// Creates a `DataLoader` instance for a stateless `dataset`, a `sampler` and
+ /// some `options`.
+ template <typename Dataset, typename Sampler>
+ std::enable_if_t<
+     !Dataset::is_stateful,
+     std::unique_ptr<StatelessDataLoader<Dataset, Sampler>>>
+ make_data_loader(Dataset dataset, Sampler sampler, DataLoaderOptions options) {
+   return std::make_unique<StatelessDataLoader<Dataset, Sampler>>(
+       std::move(dataset), std::move(sampler), std::move(options));
+ }
+
+ /// Creates a `DataLoader` instance for a stateless `dataset` and some
+ /// `options`. A sampler (by default a `RandomSampler`) will be constructed
+ /// from the size of the dataset.
+ template <typename Sampler = samplers::RandomSampler, typename Dataset>
+ std::enable_if_t<
+     !Dataset::is_stateful && std::is_constructible_v<Sampler, size_t>,
+     std::unique_ptr<StatelessDataLoader<Dataset, Sampler>>>
+ make_data_loader(
+     Dataset dataset,
+     DataLoaderOptions options = DataLoaderOptions()) {
+   const std::optional<size_t> size = dataset.size();
+   TORCH_CHECK(
+       size.has_value(),
+       "Expected the dataset to be sized in "
+       "order to construct the Sampler");
+   return make_data_loader(
+       std::move(dataset), Sampler(*size), std::move(options));
+ }
+
+ /// Creates a `DataLoader` for a stateful `dataset` and some `options`.
+ template <typename Dataset, typename = std::enable_if_t<Dataset::is_stateful>>
+ std::unique_ptr<StatefulDataLoader<Dataset>> make_data_loader(
+     Dataset dataset,
+     DataLoaderOptions options = DataLoaderOptions()) {
+   return std::make_unique<StatefulDataLoader<Dataset>>(
+       std::move(dataset), std::move(options));
+ }
+ } // namespace data
+ } // namespace torch
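
As a quick orientation before the remaining headers, here is a minimal usage sketch of the stateless overloads above. `TinyDataset` is hypothetical, written only to make the calls concrete (the `Dataset` base it uses appears in datasets/base.h further down):

```cpp
#include <torch/torch.h>

// Hypothetical ten-example dataset, just to make the overloads concrete.
struct TinyDataset : torch::data::datasets::Dataset<TinyDataset> {
  torch::data::Example<> get(size_t index) override {
    return {torch::full({2}, static_cast<float>(index)),
            torch::tensor(static_cast<int64_t>(index))};
  }
  std::optional<size_t> size() const override {
    return 10;
  }
};

int main() {
  // Overload without a sampler: a RandomSampler is built from dataset.size().
  auto random_loader = torch::data::make_data_loader(
      TinyDataset{}, torch::data::DataLoaderOptions().batch_size(4));

  // Overload with an explicit sampler.
  auto sequential_loader = torch::data::make_data_loader(
      TinyDataset{},
      torch::data::samplers::SequentialSampler(10),
      torch::data::DataLoaderOptions().batch_size(4));
}
```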
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/base.h ADDED
@@ -0,0 +1,255 @@
+ #pragma once
+
+ #include <torch/data/dataloader_options.h>
+ #include <torch/data/detail/data_shuttle.h>
+ #include <torch/data/detail/sequencers.h>
+ #include <torch/data/iterator.h>
+ #include <torch/data/samplers/random.h>
+ #include <torch/data/worker_exception.h>
+ #include <torch/types.h>
+
+ #include <torch/csrc/utils/variadic.h>
+
+ #include <c10/util/Exception.h>
+ #include <c10/util/irange.h>
+
+ #include <cstddef>
+ #include <exception>
+ #include <memory>
+ #include <thread>
+ #include <type_traits>
+ #include <utility>
+ #include <vector>
+
+ namespace torch {
+ namespace data {
+ template <typename Dataset, typename Batch, typename BatchRequest>
+ class DataLoaderBase {
+  public:
+   using BatchType = Batch;
+   using BatchRequestType = BatchRequest;
+
+   /// Constructs a new DataLoader from a `dataset` to sample from, `options`
+   /// to configure the DataLoader with, and a `sampler` that specifies the
+   /// sampling strategy.
+   DataLoaderBase(
+       DataLoaderOptions options,
+       std::unique_ptr<Dataset> main_thread_dataset = nullptr)
+       : options_(std::move(options)),
+         main_thread_dataset_(std::move(main_thread_dataset)),
+         sequencer_(new_sequencer()) {}
+
+   // NOLINTNEXTLINE(bugprone-exception-escape)
+   virtual ~DataLoaderBase() {
+     join();
+   }
+
+   /// Returns an iterator into the DataLoader. The lifetime of the iterator is
+   /// bound to the DataLoader. In C++ standards language, the category of the
+   /// iterator is `OutputIterator`. See
+   /// https://en.cppreference.com/w/cpp/named_req/OutputIterator for what this
+   /// means. In short: you may increment the iterator and dereference it, but
+   /// cannot go back, or step forward more than one position at a time. When
+   /// the DataLoader is exhausted, it will compare equal with the special
+   /// "sentinel" iterator returned by `DataLoader::end()`. Most of the time,
+   /// you should only use range-for loops to loop over the DataLoader, but
+   /// standard algorithms like `std::copy(dataloader.begin(), dataloader.end(),
+   /// output_iterator)` are supported too.
+   Iterator<Batch> begin() {
+     TORCH_CHECK(
+         shuttle_.in_flight_jobs() == 0,
+         "Attempted to get a new DataLoader iterator "
+         "while another iterator is not yet exhausted");
+     reset();
+     return Iterator<Batch>(std::make_unique<detail::ValidIterator<Batch>>(
+         [this] { return this->next(); }));
+   }
+
+   /// Returns a special "sentinel" iterator that compares equal with a
+   /// non-sentinel iterator once the DataLoader is exhausted.
+   Iterator<Batch> end() {
+     return Iterator<Batch>(std::make_unique<detail::SentinelIterator<Batch>>());
+   }
+
+   /// Joins the DataLoader's worker threads and drains internal queues.
+   /// This function may only be invoked from the main thread (in which the
+   /// DataLoader lives).
+   void join() {
+     if (joined_) {
+       return;
+     }
+     shuttle_.drain();
+     // Send one 'quit' message per worker. Since a worker dies (exits its
+     // thread) after receiving this message, each `QuitWorker()` message will
+     // be read by exactly one worker.
+     for (const auto w : c10::irange(options_.workers)) {
+       (void)w; // Suppress unused variable warning
+       push_job(QuitWorker());
+     }
+     for (auto& worker : workers_) {
+       worker.join();
+     }
+     joined_ = true;
+   }
+
+   /// Returns the options with which the DataLoader was configured.
+   const FullDataLoaderOptions& options() const noexcept {
+     return options_;
+   }
+
+  protected:
+   /// Simple mix-in to give something a sequence number.
+   struct Sequenced {
+     Sequenced() = default;
+     Sequenced(size_t sqn) : sequence_number(sqn) {}
+     size_t sequence_number;
+   };
+
+   struct QuitWorker {};
+
+   /// A `Job` is either a `BatchRequest` (new indices to fetch data at) or a
+   /// `QuitWorker` object, to indicate the worker should shut down.
+   struct Job : Sequenced {
+     Job() = default;
+     Job(QuitWorker q, size_t sqn) : Sequenced(sqn), quit(q) {}
+     Job(BatchRequest&& i, size_t sqn)
+         : Sequenced(sqn), batch_request(std::move(i)) {}
+     std::optional<QuitWorker> quit;
+     std::optional<BatchRequest> batch_request;
+   };
+
+   /// The finished result of a job.
+   struct Result : Sequenced {
+     Result() = default;
+     Result(std::optional<Batch>&& b, size_t sqn)
+         : Sequenced(sqn), batch(std::move(b)) {}
+     Result(std::exception_ptr exception, size_t sqn)
+         : Sequenced(sqn), exception(std::move(exception)) {}
+     std::optional<Batch> batch;
+     std::exception_ptr exception;
+   };
+
+   /// Subclass hook for getting the next batch request. The stateless case
+   /// will ask the sampler for a new batch request (e.g. a vector of indices),
+   /// while the stateful one will simply return the batch size.
+   virtual std::optional<BatchRequestType> get_batch_request() = 0;
+
+   /// Resets the internal state of the DataLoader, optionally pre-fetching
+   /// new jobs.
+   virtual void reset() {
+     shuttle_.drain();
+     sequence_number_ = 0;
+     sequencer_ = new_sequencer();
+     prefetch();
+   }
+
+   /// Schedules `requested_jobs` many new batches to be fetched. The actual
+   /// number of jobs scheduled may be less if the DataLoader exhausts.
+   void prefetch(size_t requested_jobs) {
+     for (const auto r : c10::irange(requested_jobs)) {
+       (void)r; // Suppress unused variable
+       if (auto batch_request = get_batch_request()) {
+         this->push_job(std::move(*batch_request));
+       } else {
+         break;
+       }
+     }
+   }
+
+   /// Schedules the maximum number of jobs (based on the `max_jobs` option).
+   void prefetch() {
+     prefetch(options_.max_jobs);
+   }
+
+   /// Returns the next batch of data, or an empty `optional` if the DataLoader
+   /// is exhausted. This operation will block until a batch is available if
+   /// one is still expected.
+   std::optional<BatchType> next() {
+     if (options_.workers > 0) {
+       while (std::optional<Result> result = this->pop_result()) {
+         if (result->exception) {
+           throw WorkerException(result->exception);
+         } else if (result->batch) {
+           prefetch(1);
+           return std::move(result->batch);
+         }
+       }
+     } else if (auto batch_request = get_batch_request()) {
+       return this->main_thread_dataset_->get_batch(std::move(*batch_request));
+     }
+     return nullopt;
+   }
+
+   /// The function that worker threads run.
+   void worker_thread(Dataset& dataset) {
+     while (true) {
+       auto job = shuttle_.pop_job();
+       if (job.quit) {
+         break;
+       }
+       try {
+         auto batch = dataset.get_batch(std::move(*job.batch_request));
+         shuttle_.push_result({std::move(batch), job.sequence_number});
+       } catch (...) {
+         shuttle_.push_result({std::current_exception(), job.sequence_number});
+       }
+     }
+   }
+
+   /// Convenience method that calls `shuttle_.push_job()` with the next
+   /// sequence number.
+   template <typename T>
+   void push_job(T value) {
+     shuttle_.push_job({std::move(value), sequence_number_++});
+   }
+
+   /// Convenience method that gets the next result from the sequencer.
+   std::optional<Result> pop_result() {
+     return sequencer_->next(
+         [this] { return this->shuttle_.pop_result(this->options_.timeout); });
+   }
+
+   /// Convenience method that creates a new sequencer based on the
+   /// `enforce_ordering` option.
+   std::unique_ptr<detail::sequencers::Sequencer<Result>> new_sequencer() {
+     if (options_.enforce_ordering) {
+       return std::make_unique<detail::sequencers::OrderedSequencer<Result>>(
+           options_.max_jobs);
+     }
+     return std::make_unique<detail::sequencers::NoSequencer<Result>>();
+   }
+
+   /// The options the DataLoader was configured with.
+   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
+   const FullDataLoaderOptions options_;
+
+   /// The dataset for the main thread, which only has a value if the number of
+   /// worker threads was configured as zero, meaning the main thread has to do
+   /// all the work (synchronously). NOTE: We really want this to be on the
+   /// heap when empty, therefore `unique_ptr` and not `optional`.
+   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
+   std::unique_ptr<Dataset> main_thread_dataset_;
+
+   /// The sequence number for the *next* batch to be retrieved from the
+   /// dataset.
+   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
+   size_t sequence_number_ = 0;
+
+   /// The worker threads, running the `worker_thread()` method.
+   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
+   std::vector<std::thread> workers_;
+
+   /// The `DataShuttle` which takes care of the life cycle of a job.
+   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
+   detail::DataShuttle<Job, Result> shuttle_;
+
+   /// The `Sequencer`, which handles optional ordering of batches.
+   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
+   std::unique_ptr<detail::sequencers::Sequencer<Result>> sequencer_;
+
+   /// True if the DataLoader has joined its worker threads.
+   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
+   bool joined_ = false;
+ };
+ } // namespace data
+ } // namespace torch
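
The iterator contract documented at `begin()` above amounts to single-pass consumption. A sketch of the typical loop, assuming `loader` came from `make_data_loader` as in the earlier sketch (this is a fragment, not a complete program):

```cpp
// One epoch: each element is a whole batch (std::vector<Example<>> for a
// plain dataset). The iterator is single-pass; exhaustion compares equal
// to loader->end().
for (auto& batch : *loader) {
  for (auto& example : batch) {
    // use example.data and example.target here
  }
}
// begin() resets the loader, so a second range-for starts a fresh epoch.
// Calling begin() while a previous iterator is still unfinished throws.
```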
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/stateful.h ADDED
@@ -0,0 +1,63 @@
+ #pragma once
+
+ #include <c10/util/irange.h>
+ #include <torch/data/dataloader/base.h>
+
+ #include <cstddef>
+ #include <thread>
+ #include <utility>
+
+ namespace torch {
+ namespace data {
+
+ /// A dataloader for stateful datasets.
+ ///
+ /// A dataloader for stateful datasets differs from one for stateless
+ /// datasets in that the dataset is shared among worker threads, and that
+ /// this dataset is itself responsible for producing batches rather than
+ /// depending on a sampler. The statefulness here actually refers to the
+ /// dataset. The StatefulDataLoader simply alters the data loading algorithm
+ /// to accommodate the stateful, shared nature of the dataset. Note that the
+ /// dataset must be thread safe if more than one worker thread is used.
+ ///
+ /// A stateful dataloader is created by calling `make_data_loader` with a
+ /// stateful dataset.
+ template <typename Dataset>
+ class StatefulDataLoader : public DataLoaderBase<
+                                Dataset,
+                                typename Dataset::BatchType::value_type,
+                                typename Dataset::BatchRequestType> {
+  public:
+   using super = DataLoaderBase<
+       Dataset,
+       typename Dataset::BatchType::value_type,
+       typename Dataset::BatchRequestType>;
+   using typename super::BatchRequestType;
+
+   /// Constructs the `StatefulDataLoader` from a `dataset` and some `options`.
+   StatefulDataLoader(Dataset dataset, DataLoaderOptions options)
+       : super(options, std::make_unique<Dataset>(std::move(dataset))) {
+     for ([[maybe_unused]] const auto _ : c10::irange(this->options_.workers)) {
+       // As opposed to the stateless case, here all worker threads access the
+       // same underlying dataset.
+       this->workers_.emplace_back(
+           [this] { this->worker_thread(*this->main_thread_dataset_); });
+     }
+   }
+
+  private:
+   /// Resets the internal state of the dataloader and the dataset.
+   void reset() override {
+     this->main_thread_dataset_->reset();
+     // Call the base class method last because it calls `prefetch()`.
+     super::reset();
+   }
+
+   /// For stateful datasets, the batch request is always the batch size. The
+   /// dataset is responsible for determining what goes into the batch next.
+   std::optional<BatchRequestType> get_batch_request() override {
+     return this->options_.batch_size;
+   }
+ };
+ } // namespace data
+ } // namespace torch
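
For reference, a minimal stateful dataset that this loader could drive might look as follows. This is a sketch: `CountingDataset` is hypothetical, and it assumes the `StatefulDataset` base from datasets/stateful.h (listed above but not shown here), which additionally requires `reset()`, `save()` and `load()` overrides:

```cpp
#include <torch/torch.h>

// Hypothetical stateful dataset: yields `n` examples per epoch, then nullopt.
class CountingDataset
    : public torch::data::datasets::StatefulDataset<
          CountingDataset,
          std::vector<torch::data::Example<>>,
          /*BatchRequest=*/size_t> {
 public:
  explicit CountingDataset(size_t n) : n_(n) {}

  // The batch request is just the batch size (see get_batch_request() above).
  std::optional<std::vector<torch::data::Example<>>> get_batch(
      size_t batch_size) override {
    if (cursor_ >= n_) {
      return std::nullopt; // exhausted: this ends the epoch
    }
    std::vector<torch::data::Example<>> batch;
    for (; batch.size() < batch_size && cursor_ < n_; ++cursor_) {
      batch.emplace_back(
          torch::full({1}, static_cast<float>(cursor_)),
          torch::tensor(static_cast<int64_t>(cursor_)));
    }
    return batch;
  }

  std::optional<size_t> size() const override { return n_; }
  void reset() override { cursor_ = 0; }
  void save(torch::serialize::OutputArchive&) const override {}
  void load(torch::serialize::InputArchive&) override {}

 private:
  size_t n_;
  size_t cursor_ = 0;
};

// Because is_stateful is true, make_data_loader picks the stateful overload:
// auto loader = torch::data::make_data_loader(
//     CountingDataset(100), torch::data::DataLoaderOptions(10));
```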
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/stateless.h ADDED
@@ -0,0 +1,82 @@
+ #pragma once
+
+ #include <torch/data/dataloader/base.h>
+ #include <torch/data/worker_exception.h>
+
+ #include <c10/util/Exception.h>
+ #include <c10/util/irange.h>
+
+ #include <cstddef>
+ #include <thread>
+ #include <utility>
+
+ namespace torch {
+ namespace data {
+
+ /// A dataloader for stateless datasets.
+ ///
+ /// This dataloader follows the traditional PyTorch dataloader design, whereby
+ /// a (possibly) stateful sampler produces *batch requests* for a stateless
+ /// dataset, which acts as a simple batch request to batch mapping. The batch
+ /// request will often be an array of indices, and if the dataset is a simple
+ /// image dataset, the dataset would produce the images at those indices.
+ template <typename Dataset, typename Sampler>
+ class StatelessDataLoader : public DataLoaderBase<
+                                 Dataset,
+                                 typename Dataset::BatchType,
+                                 typename Sampler::BatchRequestType> {
+  public:
+   using super = DataLoaderBase<
+       Dataset,
+       typename Dataset::BatchType,
+       typename Sampler::BatchRequestType>;
+   using typename super::BatchRequestType;
+
+   /// Constructs the `StatelessDataLoader` from a `dataset`, a `sampler` and
+   /// some `options`.
+   StatelessDataLoader(
+       Dataset dataset,
+       Sampler sampler,
+       DataLoaderOptions options)
+       : super(std::move(options)), sampler_(std::move(sampler)) {
+     for (const auto w : c10::irange(this->options_.workers)) {
+       // Here we copy the dataset into the worker thread closure. Each worker
+       // has its own copy of the dataset. This means the dataset must be
+       // trivially copyable, or else we don't expect more than one worker to
+       // be in use.
+       (void)w; // Suppress unused variable warning
+       this->workers_.emplace_back(
+           [this, dataset]() mutable { this->worker_thread(dataset); });
+     }
+     if (this->options_.workers == 0) {
+       this->main_thread_dataset_ =
+           std::make_unique<Dataset>(std::move(dataset));
+     }
+   }
+
+  private:
+   /// Resets the internal state of the dataloader and the sampler.
+   void reset() override {
+     sampler_.reset();
+     // Call the base class method last because it calls `prefetch()`.
+     super::reset();
+   }
+
+   /// Queries the sampler for the next batch request (possibly progressing its
+   /// internal state).
+   std::optional<BatchRequestType> get_batch_request() override {
+     auto indices = sampler_.next(this->options_.batch_size);
+     if (!indices ||
+         (indices->size() < this->options_.batch_size &&
+          this->options_.drop_last)) {
+       return nullopt;
+     }
+     AT_ASSERT(indices->size() > 0);
+     return indices;
+   }
+
+   /// The `Sampler` used to produce batch requests.
+   Sampler sampler_;
+ };
+ } // namespace data
+ } // namespace torch
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader_options.h ADDED
@@ -0,0 +1,65 @@
+ #pragma once
+
+ #include <torch/arg.h>
+ #include <torch/types.h>
+
+ #include <chrono>
+ #include <cstddef>
+
+ namespace torch {
+ namespace data {
+
+ /// Options to configure a `DataLoader`.
+ struct DataLoaderOptions {
+   DataLoaderOptions() = default;
+   /* implicit */ DataLoaderOptions(size_t batch_size)
+       : batch_size_(batch_size) {}
+
+   /// The size of each batch to fetch.
+   TORCH_ARG(size_t, batch_size) = 1;
+
+   /// The number of worker threads to launch. If zero, the main thread will
+   /// synchronously perform the data loading.
+   TORCH_ARG(size_t, workers) = 0;
+
+   /// The maximum number of jobs to enqueue for fetching by worker threads.
+   /// Defaults to two times the number of worker threads.
+   TORCH_ARG(std::optional<size_t>, max_jobs);
+
+   /// An optional limit on the time to wait for the next batch.
+   TORCH_ARG(std::optional<std::chrono::milliseconds>, timeout);
+
+   /// Whether to enforce ordering of batches when multiple are loaded
+   /// asynchronously by worker threads. Set to `false` for better performance
+   /// if you do not care about determinism.
+   TORCH_ARG(bool, enforce_ordering) = true;
+
+   /// Whether to omit the last batch if it contains less than `batch_size`
+   /// examples.
+   TORCH_ARG(bool, drop_last) = false;
+ };
+
+ /// Like `DataLoaderOptions`, but without any unconfigured state.
+ /// `DataLoaderOptions` has some options that depend on other options
+ /// (`max_jobs` => `2 * workers`). In the spirit of properly using the C++ type
+ /// system, `DataLoaderOptions` allows only setting values. To access values,
+ /// you must create a `FullDataLoaderOptions` from a `DataLoaderOptions`
+ /// instance, which will do any necessary coalescing.
+ struct FullDataLoaderOptions {
+   explicit FullDataLoaderOptions(DataLoaderOptions options)
+       : batch_size(options.batch_size()),
+         workers(options.workers()),
+         max_jobs(options.max_jobs().value_or(2 * workers)),
+         timeout(options.timeout()),
+         enforce_ordering(options.enforce_ordering()),
+         drop_last(options.drop_last()) {}
+
+   size_t batch_size;
+   size_t workers;
+   size_t max_jobs;
+   std::optional<std::chrono::milliseconds> timeout;
+   bool enforce_ordering;
+   bool drop_last;
+ };
+ } // namespace data
+ } // namespace torch
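
Since `TORCH_ARG` generates chainable setters (each returns `*this`), options are usually configured fluently, and `FullDataLoaderOptions` then resolves the deferred defaults such as `max_jobs = 2 * workers`. A small sketch:

```cpp
// The implicit size_t constructor means a bare batch size also works:
//   make_data_loader(std::move(dataset), /*batch_size=*/64);
auto options = torch::data::DataLoaderOptions()
                   .batch_size(64)
                   .workers(4)
                   .enforce_ordering(false) // faster; batches may be reordered
                   .drop_last(true)
                   .timeout(std::chrono::seconds(30));
```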
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets.h ADDED
@@ -0,0 +1,9 @@
+ #pragma once
+
+ #include <torch/data/datasets/base.h>
+ #include <torch/data/datasets/chunk.h>
+ #include <torch/data/datasets/map.h>
+ #include <torch/data/datasets/mnist.h>
+ #include <torch/data/datasets/shared.h>
+ #include <torch/data/datasets/stateful.h>
+ #include <torch/data/datasets/tensor.h>
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/base.h ADDED
@@ -0,0 +1,104 @@
+ #pragma once
+
+ #include <torch/data/example.h>
+ #include <torch/types.h>
+
+ #include <c10/util/ArrayRef.h>
+
+ #include <cstddef>
+ #include <cstdint>
+ #include <type_traits>
+ #include <utility>
+ #include <vector>
+
+ namespace torch {
+ namespace data {
+ namespace datasets {
+ template <typename S, typename T>
+ class MapDataset;
+ template <typename D, typename T>
+ MapDataset<D, T> map(D, T); // NOLINT
+ } // namespace datasets
+ } // namespace data
+ } // namespace torch
+
+ namespace torch {
+ namespace data {
+ namespace datasets {
+ namespace detail {
+ template <typename T>
+ struct is_optional : std::false_type {};
+ template <typename T>
+ struct is_optional<std::optional<T>> : std::true_type {};
+ } // namespace detail
+
+ /// A dataset that can yield data only in batches.
+ template <
+     typename Self,
+     typename Batch = std::vector<Example<>>,
+     typename BatchRequest = ArrayRef<size_t>>
+ class BatchDataset {
+  public:
+   using SelfType = Self;
+   using BatchType = Batch;
+   using BatchRequestType = BatchRequest;
+   constexpr static bool is_stateful = detail::is_optional<BatchType>::value;
+
+   virtual ~BatchDataset() = default;
+
+   /// Returns a batch of data given an index.
+   virtual Batch get_batch(BatchRequest request) = 0;
+
+   /// Returns the size of the dataset, or an empty std::optional if it is
+   /// unsized.
+   virtual std::optional<size_t> size() const = 0;
+
+   /// Creates a `MapDataset` that applies the given `transform` to this
+   /// dataset.
+   template <typename TransformType>
+   MapDataset<Self, TransformType> map(TransformType transform) & {
+     return datasets::map(static_cast<Self&>(*this), std::move(transform));
+   }
+
+   /// Creates a `MapDataset` that applies the given `transform` to this
+   /// dataset.
+   template <typename TransformType>
+   MapDataset<Self, TransformType> map(TransformType transform) && {
+     return datasets::map(
+         std::move(static_cast<Self&>(*this)), std::move(transform));
+   }
+ };
+
+ /// A dataset that can yield data in batches, or as individual examples.
+ ///
+ /// A `Dataset` is a `BatchDataset`, because it supports random access and
+ /// therefore batched access is implemented (by default) by calling the random
+ /// access indexing function for each index in the requested batch of indices.
+ /// This can be customized.
+ template <typename Self, typename SingleExample = Example<>>
+ class Dataset : public BatchDataset<Self, std::vector<SingleExample>> {
+  public:
+   using ExampleType = SingleExample;
+
+   /// Returns the example at the given index.
+   virtual ExampleType get(size_t index) = 0;
+
+   /// Returns a batch of data.
+   /// The default implementation calls `get()` for every requested index
+   /// in the batch.
+   std::vector<ExampleType> get_batch(ArrayRef<size_t> indices) override {
+     std::vector<ExampleType> batch;
+     batch.reserve(indices.size());
+     for (const auto i : indices) {
+       batch.push_back(get(i));
+     }
+     return batch;
+   }
+ };
+
+ /// A `StreamDataset` represents a dataset that is a potentially infinite
+ /// stream. It takes as batch index only a number, which is the batch size,
+ /// and yields that many elements from the stream.
+ template <typename Self, typename Batch = std::vector<Example<>>>
+ using StreamDataset = BatchDataset<Self, Batch, /*BatchRequest=*/size_t>;
+ } // namespace datasets
+ } // namespace data
+ } // namespace torch
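
A minimal concrete `Dataset` against this interface (a sketch; the tensor shapes are arbitrary, and the default `get_batch()` shown above is inherited as-is):

```cpp
#include <torch/torch.h>

// Wraps two pre-built tensors as a random-access dataset.
class TensorPairDataset
    : public torch::data::datasets::Dataset<TensorPairDataset> {
 public:
  TensorPairDataset(torch::Tensor data, torch::Tensor targets)
      : data_(std::move(data)), targets_(std::move(targets)) {}

  // Random access; the inherited get_batch() calls this once per index.
  torch::data::Example<> get(size_t index) override {
    return {data_[index], targets_[index]};
  }

  std::optional<size_t> size() const override {
    return data_.size(0);
  }

 private:
  torch::Tensor data_, targets_;
};
```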
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/chunk.h ADDED
@@ -0,0 +1,529 @@
+ #pragma once
+
+ #include <c10/util/irange.h>
+ #include <torch/arg.h>
+ #include <torch/data/datasets/stateful.h>
+ #include <torch/data/samplers.h>
+ #include <queue>
+ #include <thread>
+
+ #include <torch/serialize.h>
+
+ namespace torch {
+ namespace data {
+ namespace datasets {
+
+ /// Interface for chunk reader, which performs data chunking and reading of
+ /// entire chunks.
+ ///
+ /// A chunk could be an entire file, such as an audio data file or an image,
+ /// or part of a file in the case of a large text-file split based on seek
+ /// positions.
+ template <
+     typename ExampleType_,
+     typename ChunkType_ = std::vector<ExampleType_>>
+ class ChunkDataReader {
+  public:
+   virtual ~ChunkDataReader() = default;
+
+   using ChunkType = ChunkType_;
+   using ExampleType = ExampleType_;
+
+   /// Read an entire chunk.
+   virtual ChunkType read_chunk(size_t chunk_index) = 0;
+
+   /// Returns the number of chunks available in this reader.
+   virtual size_t chunk_count() = 0;
+
+   /// This will clear any internal state associated with this reader.
+   virtual void reset() = 0;
+ };
+
+ namespace detail {
+ /// BatchDataBuffer manages a queue of UnwrappedBatchData. After a new chunk
+ /// is loaded, BatchDataBuffer splits it into small batches and pushes them
+ /// into the queue. When get_batch is called from the data loader, it pops
+ /// cached batches and returns them. If the cache is empty, it either waits
+ /// for more chunks to be loaded or returns null if all chunks are loaded.
+ template <
+     typename UnwrappedBatch,
+     typename ExampleSampler = samplers::RandomSampler>
+ class BatchDataBuffer {
+  public:
+   using UnwrappedBatchType = UnwrappedBatch;
+   using BatchType = torch::optional<UnwrappedBatchType>;
+   using BatchRequestType = typename ExampleSampler::BatchRequestType;
+
+   BatchDataBuffer(
+       size_t batch_size,
+       ExampleSampler& example_sampler,
+       size_t queue_capacity)
+       : batch_size_(batch_size),
+         example_sampler_(example_sampler),
+         queue_capacity_(queue_capacity) {}
+
+   /// Return batch data from the queue. Called from the ChunkDataset main
+   /// thread.
+   BatchType get_batch() {
+     std::unique_lock<std::mutex> lock(queue_mutex_);
+     cv_read_.wait(lock, [this] {
+       // wait until there is available data in the queue or all chunks are
+       // loaded (i.e. the dataset is exhausted for this epoch)
+       return (
+           this->total_example_count_in_queue_ >= batch_size_ || this->stop_);
+     });
+     if (batch_queue_.empty()) {
+       AT_ASSERT(stop_);
+       // All batches have been retrieved. Return an empty batch.
+       return nullopt;
+     }
+
+     UnwrappedBatchData batch = std::move(batch_queue_.front());
+     batch_queue_.pop();
+     if (batch.exception) {
+       throw WorkerException(batch.exception);
+     }
+
+     total_example_count_in_queue_ -= batch.batch_data.size();
+     lock.unlock();
+     cv_write_.notify_all();
+
+     return batch.batch_data;
+   }
+
+   /// Push preloaded chunks to the batch queue. Called from the ChunkDataset
+   /// worker threads.
+   void add_chunk_data(UnwrappedBatchType data) {
+     std::unique_lock<std::mutex> lock(queue_mutex_);
+     cv_write_.wait(lock, [this] {
+       // stop loading if we have preloaded enough data.
+       return this->total_example_count_in_queue_ < this->queue_capacity_ ||
+           this->stop_;
+     });
+     if (stop_) {
+       // When stop_ is true, it means no further chunk loading is necessary.
+       // Return without any further processing.
+       return;
+     }
+
+     auto data_size = data.size();
+     auto remaining_size = data_size;
+     example_sampler_.reset(data_size);
+
+     auto fill_batch = [&](size_t example_count, UnwrappedBatchType& batch) {
+       auto batch_example_indices = this->example_sampler_.next(example_count);
+       AT_ASSERT(
+           batch_example_indices &&
+           batch_example_indices.value().size() == example_count);
+       BatchRequestType& indices = batch_example_indices.value();
+       for (size_t i : indices) {
+         TORCH_CHECK(i < data_size, "Index out of range");
+         batch.emplace_back(std::move(data[i]));
+       }
+       remaining_size -= example_count;
+     };
+
+     if (!batch_queue_.empty()) {
+       // if the queue has existing data and the last batch doesn't have
+       // enough examples to fill a batch_size batch, add more examples to
+       // this batch first.
+       auto& batch = batch_queue_.back();
+       size_t current_count = batch.batch_data.size();
+       if (current_count < batch_size_) {
+         auto example_count =
+             std::min(remaining_size, batch_size_ - current_count);
+         fill_batch(example_count, batch.batch_data);
+       }
+     }
+
+     // If we still have data remaining after filling the last pushed batch,
+     // add it to the queue too.
+     // NOLINTNEXTLINE(bugprone-infinite-loop)
+     while (remaining_size > 0) {
+       UnwrappedBatchType current_batch;
+
+       // Allocate the batch memory ahead of time.
+       current_batch.reserve(batch_size_);
+
+       auto example_count = std::min(remaining_size, batch_size_);
+       fill_batch(example_count, current_batch);
+       batch_queue_.emplace(std::move(current_batch));
+     }
+     total_example_count_in_queue_ += data_size;
+     lock.unlock();
+     cv_read_.notify_all();
+   }
+
+   /// Push exceptions thrown during preloading into the batch queue. Called
+   /// from the ChunkDataset worker threads.
+   void add_chunk_data(std::exception_ptr e_ptr) {
+     std::unique_lock<std::mutex> lock(queue_mutex_);
+     cv_write_.wait(lock, [this] {
+       // stop loading if we have preloaded enough data.
+       return (
+           this->total_example_count_in_queue_ < this->queue_capacity_ ||
+           this->stop_);
+     });
+     if (stop_) {
+       // When stop_ is true, it means the current thread needs to be torn
+       // down. The batch buffer will be discarded, so there is no need to
+       // enqueue any new exceptions.
+       return;
+     }
+
+     batch_queue_.emplace(e_ptr);
+     lock.unlock();
+     cv_read_.notify_all();
+   }
+
+   void stop() {
+     {
+       // Hold the lock before changing stop_ to prevent a race condition that
+       // can cause a deadlock. To be more specific, the condition variable
+       // cv_write_ waits on the predicate stop_ in add_chunk_data(). The wait
+       // happens in two steps: 1) while still holding the lock, check if the
+       // predicate is true; 2) if it is true, proceed; otherwise, release the
+       // lock and wait until notified. Without holding the lock, cv_write_'s
+       // notification can happen in between steps 1) and 2). In that case,
+       // cv_write_ is not yet waiting, so the notification is lost and
+       // cv_write_ will sleep forever. By taking the lock before changing the
+       // predicate stop_, it is ensured that updating and evaluating stop_
+       // always happen in a synchronized way.
+       std::lock_guard<std::mutex> lock(queue_mutex_);
+       stop_ = true;
+     }
+
+     // notify all writers, waking them from the wait so they exit the current
+     // method.
+     cv_write_.notify_all();
+     // notify all readers too.
+     cv_read_.notify_all();
+   }
+   /// The batch size is needed to create batches from the chunk data. Similar
+   /// to the regular dataloader, where batches are created with prefetches,
+   /// BatchDataBuffer performs the batch creation using the provided batch
+   /// size.
+   size_t batch_size_ = 0;
+
+   /// count of total examples stored in the queue
+   size_t total_example_count_in_queue_ = 0;
+
+   /// struct that contains a raw unwrapped batch unit. An unwrapped batch unit
+   /// is the raw data without the 'optional' wrapper. It can be a collection
+   /// of images, utterances, etc.
+   struct UnwrappedBatchData {
+     explicit UnwrappedBatchData(UnwrappedBatchType data)
+         : batch_data(std::move(data)) {}
+
+     // NOLINTNEXTLINE(modernize-pass-by-value)
+     explicit UnwrappedBatchData(std::exception_ptr e) : exception(e) {}
+
+     /// batch data to return
+     UnwrappedBatchType batch_data;
+
+     /// exception pointer which captures any abnormal exceptions thrown while
+     /// creating the batch.
+     std::exception_ptr exception;
+   };
+
+   /// local cache to store example batches from the loaded chunk
+   std::queue<UnwrappedBatchData> batch_queue_;
+
+   // sync batch_queue_ updates.
+   std::mutex queue_mutex_;
+
+   std::condition_variable cv_read_;
+   std::condition_variable cv_write_;
+
+   ExampleSampler& example_sampler_;
+
+   // configurable maximum number of elements the queue can hold at one time.
+   size_t queue_capacity_;
+
+   // When set to true, it wakes the writer threads from the wait so they exit
+   // the current function call. This is needed when ChunkDataSet.Reset is
+   // called while the previous epoch is not yet exhausted. When ChunkDataset
+   // is waiting for its preloader to finish previous work before tearing down
+   // the thread, the preloader could still be waiting on the condition
+   // variable, causing the program to hang. This boolean is used to break
+   // this waiting condition.
+   bool stop_ = false;
+ };
+ } // namespace detail
+
+ /// Options to configure a `ChunkDataset`.
+ struct ChunkDatasetOptions {
+   ChunkDatasetOptions() = delete;
+   ChunkDatasetOptions(
+       size_t preloader_count,
+       size_t batch_size,
+       size_t cache_size = 2048,
+       size_t cross_chunk_shuffle_count = 1)
+       : preloader_count_(preloader_count),
+         batch_size_(batch_size),
+         cache_size_(cache_size),
+         cross_chunk_shuffle_count_(cross_chunk_shuffle_count) {
+     TORCH_CHECK(
+         preloader_count_ > 0,
+         "Preloader count is 0. At least one preloader needs to be specified.");
+     TORCH_CHECK(
+         batch_size_ > 0,
+         "Batch size is 0. A positive batch size needs to be specified.");
+     TORCH_CHECK(
+         cache_size_ > 0,
+         "Cache size is 0. A positive cache size needs to be specified.");
+     TORCH_CHECK(
+         cache_size_ >= batch_size_,
+         "Cache size is less than batch size. Cache needs to be large enough to "
+         "hold at least one batch.");
+     TORCH_CHECK(
+         cross_chunk_shuffle_count_ > 0,
+         "cross_chunk_shuffle_count needs to be greater than 0.");
+   }
+
+   /// The number of worker threads to preload chunk data.
+   TORCH_ARG(size_t, preloader_count);
+
+   /// The size of each batch.
+   TORCH_ARG(size_t, batch_size);
+
+   /// The capacity of the queue for batch caching.
+   TORCH_ARG(size_t, cache_size) = 2048;
+
+   // The number of chunks over which to perform cross-chunk shuffling.
+   // Defaults to 1, meaning no cross-chunk shuffling. When it is equal to n
+   // (n > 1), n random chunks will be loaded at once and example shuffling
+   // will be performed across all those n chunks.
+   // Note: Usually the default config (1 chunk shuffle + example shuffle) is
+   // good enough to generate randomly distributed data. Use this parameter
+   // only if you know cross-shuffle is needed in your case. Also, there is a
+   // performance penalty when this value is greater than 1, as we need to do
+   // an extra merge between multiple chunks before performing example
+   // sampling.
+   TORCH_ARG(size_t, cross_chunk_shuffle_count) = 1;
+ };
+
+ /// A stateful dataset that supports hierarchical sampling and prefetching of
+ /// entire chunks.
+ ///
+ /// Unlike a regular dataset, a chunk dataset requires two samplers to operate
+ /// and keeps internal state. The `ChunkSampler` selects which chunk to load
+ /// next, while the `ExampleSampler` determines the order of Examples that are
+ /// returned in each `get_batch` call. The hierarchical sampling approach used
+ /// here is inspired by this paper:
+ /// http://martin.zinkevich.org/publications/nips2010.pdf
+ template <
+     typename ChunkReader,
+     typename ChunkSampler = samplers::RandomSampler,
+     typename ExampleSampler = samplers::RandomSampler>
+ class ChunkDataset final
+     : public StatefulDataset<
+           ChunkDataset<ChunkReader, ChunkSampler, ExampleSampler>,
+           typename ChunkReader::BatchType,
+           size_t> {
+  public:
+   using BatchType = torch::optional<typename ChunkReader::BatchType>;
+   using UnwrappedBatchType = typename ChunkReader::BatchType;
+   using BatchRequestType = size_t;
+   using ChunkSamplerType = ChunkSampler;
+   using ExampleSamplerType = ExampleSampler;
+
+   ChunkDataset(
+       ChunkReader chunk_reader,
+       ChunkSampler chunk_sampler,
+       ExampleSampler example_sampler,
+       ChunkDatasetOptions options,
+       std::function<void(UnwrappedBatchType&)> preprocessing_policy =
+           std::function<void(UnwrappedBatchType&)>())
+       : chunk_reader_(std::move(chunk_reader)),
+         chunk_sampler_(std::move(chunk_sampler)),
+         example_sampler_(std::move(example_sampler)),
+         options_(std::move(options)),
+         preprocessing_policy_(std::move(preprocessing_policy)),
+         quit_worker_(false),
+         running_preloaders_(0),
+         load_checkpoint_(false) {}
+
+   ~ChunkDataset() override {
+     // stop the batch buffer first.
+     if (batch_buffer_) {
+       batch_buffer_->stop();
+     }
+     free_workers();
+   }
+
+   /// Default get_batch method of BatchDataset. This method returns
+   /// Example batches created from the preloaded chunks. The implementation
+   /// is dataset agnostic and does not need overriding in different chunk
+   /// datasets.
+   BatchType get_batch(size_t batch_size) override {
+     TORCH_CHECK(
+         batch_buffer_ != nullptr,
+         "Dataset needs to call reset() before calling get_batch().");
+
+     TORCH_CHECK(
+         batch_size == options_.batch_size(),
+         "The requested batch size does not match with the initialized batch size.\n"
+         " The requested batch size is ",
+         batch_size,
+         ", while the dataset is created with batch size equal to ",
+         options_.batch_size());
+     return batch_buffer_->get_batch();
+   }
+
+   /// Helper method around get_batch, as `batch_size` is not strictly
+   /// necessary.
+   BatchType get_batch() {
+     return get_batch(options_.batch_size());
+   }
+
+   /// This will clear any internal state and start the internal prefetching
+   /// mechanism for the chunk dataset.
+   void reset() override {
+     // We need this to support partial data reads via the dataloader iterator.
+     if (batch_buffer_) {
+       batch_buffer_->stop();
+     }
+     // free workers from a previous reset if there are any.
+     free_workers();
+     preload_threads_.clear();
+
+     if (!load_checkpoint_) {
+       chunk_reader_.reset();
+       chunk_sampler_.reset(chunk_reader_.chunk_count());
+       load_checkpoint_ = false;
+     }
+
+     // Throw out any existing cached batch in the buffer and re-create a new
+     // chunk buffer.
+     batch_buffer_ = std::make_unique<
+         detail::BatchDataBuffer<UnwrappedBatchType, ExampleSamplerType>>(
+         options_.batch_size(), example_sampler_, options_.cache_size());
+
+     // create new workers for this new epoch.
+     quit_worker_ = false;
+
+     AT_ASSERT(running_preloaders_ == 0);
+     running_preloaders_ = options_.preloader_count();
+     for (const auto i : c10::irange(options_.preloader_count())) {
+       preload_threads_.emplace_back([this, i]() { this->preloader(i); });
+     }
+   }
+
+   /// size is not used for chunk datasets.
+   std::optional<size_t> size() const override {
+     return torch::nullopt;
+   }
+
+   // Provides a reference to the chunk sampler. Used mainly in distributed
+   // data loading to set the epoch number for the sampler.
+   ChunkSamplerType& chunk_sampler() {
+     return chunk_sampler_;
+   }
+
+   void save(serialize::OutputArchive& archive) const override {
+     std::lock_guard<std::mutex> lock(chunk_index_guard_);
+     chunk_sampler_.save(archive);
+   }
+
+   void load(serialize::InputArchive& archive) override {
+     std::lock_guard<std::mutex> lock(chunk_index_guard_);
+     chunk_sampler_.load(archive);
+     load_checkpoint_ = true;
+   }
+
+  private:
+   /// runs on a worker thread to preload chunk data.
+   void preloader(size_t id) {
+     while (!quit_worker_.load()) {
+       try {
+         std::vector<size_t> chunk_idx;
+         {
+           std::lock_guard<std::mutex> lock(chunk_index_guard_);
+           if (auto chunk_sampler_result = chunk_sampler_.next(
+                   this->options_.cross_chunk_shuffle_count())) {
+             chunk_idx = chunk_sampler_result.value();
+           } else {
+             break;
+           }
+         }
+         UnwrappedBatchType data = chunk_reader_.read_chunk(chunk_idx[0]);
+         for (const auto i : c10::irange(1, chunk_idx.size())) {
+           auto chunk_data = chunk_reader_.read_chunk(chunk_idx[i]);
+           std::move(
+               chunk_data.begin(), chunk_data.end(), std::back_inserter(data));
+         }
+         if (preprocessing_policy_) {
+           preprocessing_policy_(data);
+         }
+         if (!data.empty()) { // skip empty chunks.
+           batch_buffer_->add_chunk_data(std::move(data));
+         }
+       } catch (...) {
+         batch_buffer_->add_chunk_data(std::current_exception());
+       }
+     }
+     AT_ASSERT(running_preloaders_.load() > 0);
+     --running_preloaders_;
+     if (running_preloaders_.load() == 0) {
+       // all preloaders are completed, so we can notify the batch_buffer.
+       batch_buffer_->stop();
+     }
+   }
+
+   /// Block the current thread until the workers finish execution and exit.
+   void free_workers() {
+     if (!quit_worker_.load()) {
+       quit_worker_ = true;
+       for (auto& worker_thread : preload_threads_) {
+         worker_thread.join();
+       }
+     }
+   }
+
+  private:
+   // Templated class that defines what a chunk is and how to read chunk data.
+   // When a chunk is returned by chunk_reader_, ChunkDataset splits it into
+   // batches and caches them in batch_buffer_.
+   ChunkReader chunk_reader_;
+
+   // chunk sampler to shuffle different chunks
+   ChunkSamplerType chunk_sampler_;
+
+   // example sampler to shuffle examples in a specific chunk
+   ExampleSamplerType example_sampler_;
+
+   // batch data buffer which holds chunk data from the preloading thread.
+   std::shared_ptr<
+       detail::BatchDataBuffer<UnwrappedBatchType, ExampleSamplerType>>
+       batch_buffer_;
+
+   // worker thread pool
+   std::vector<std::thread> preload_threads_;
+
+   /// The options the Dataset was configured with.
+   const ChunkDatasetOptions options_;
+
+   // function pointer wrapper to apply custom processing over chunk data.
+   // This is considered an advanced parameter for developers who want to
+   // apply a pre-process to the chunk data before sampling into minibatches.
+   // Different from the collate function, this policy is applied at the chunk
+   // level instead of the minibatch level. When a chunk of data is loaded
+   // (multiple chunks if cross_chunk_shuffle_count_ is greater than 1), this
+   // policy is applied to the full loaded data. It is useful if developers
+   // want to perform pre-processing (like bucketing) on the chunk data before
+   // the example sampler samples the data. By default it's an empty pointer
+   // and no action will be taken.
+   std::function<void(UnwrappedBatchType&)> preprocessing_policy_;
+
+   // indicates whether the worker thread can be torn down
+   std::atomic<bool> quit_worker_;
+
+   // keep track of running preloaders to notify the batch buffer. A value of
+   // 0 indicates that the chunk loading is completed.
+   std::atomic<size_t> running_preloaders_;
+
+   // mutex to synchronize the chunk sampler's next() call.
+   mutable std::mutex chunk_index_guard_;
+
+   // boolean value to indicate whether we need to load the checkpoint for
+   // chunk_sampler_.
+   bool load_checkpoint_;
+ };
+ } // namespace datasets
+ } // namespace data
+ } // namespace torch
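
To make the moving parts concrete, here is a sketch of a reader plus dataset construction. `IntChunkReader` is hypothetical (ten consecutive integers per chunk), and the sampler sizes are placeholders, since `reset()` above re-sizes both samplers via `reset(chunk_count)` and `reset(data_size)`:

```cpp
#include <torch/torch.h>

#include <numeric>

// Hypothetical reader: chunk i holds the integers [10*i, 10*i + 10).
class IntChunkReader : public torch::data::datasets::ChunkDataReader<int> {
 public:
  explicit IntChunkReader(size_t count) : count_(count) {}

  ChunkType read_chunk(size_t chunk_index) override {
    ChunkType chunk(10);
    std::iota(chunk.begin(), chunk.end(), static_cast<int>(chunk_index * 10));
    return chunk;
  }

  size_t chunk_count() override {
    return count_;
  }

  void reset() override {}

 private:
  size_t count_;
};

int main() {
  // Hierarchical sampling: one RandomSampler over chunks, one over examples.
  torch::data::datasets::ChunkDataset<IntChunkReader> dataset(
      IntChunkReader(/*count=*/100),
      torch::data::samplers::RandomSampler(0),
      torch::data::samplers::RandomSampler(0),
      torch::data::datasets::ChunkDatasetOptions(
          /*preloader_count=*/2, /*batch_size=*/32));

  dataset.reset();                  // starts the preloader threads
  auto batch = dataset.get_batch(); // one 32-example batch, or nullopt
}
```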
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/map.h ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <torch/data/datasets/base.h>
4
+ #include <torch/types.h>
5
+
6
+ #include <c10/util/ArrayRef.h>
7
+
8
+ #include <cstddef>
9
+ #include <type_traits>
10
+ #include <utility>
11
+
12
+ namespace torch {
13
+ namespace data {
14
+ namespace datasets {
15
+ namespace detail {
16
+ template <bool C, typename T>
17
+ using optional_if_t = typename std::conditional<C, torch::optional<T>, T>::type;
18
+ } // namespace detail
19
+
20
+ /// A `MapDataset` is a dataset that applies a transform to a source dataset.
21
+ template <typename SourceDataset, typename AppliedTransform>
22
+ class MapDataset : public BatchDataset<
23
+ MapDataset<SourceDataset, AppliedTransform>,
24
+ detail::optional_if_t<
25
+ SourceDataset::is_stateful,
26
+ typename AppliedTransform::OutputBatchType>,
27
+ typename SourceDataset::BatchRequestType> {
28
+ public:
29
+ using DatasetType = SourceDataset;
30
+ using TransformType = AppliedTransform;
31
+ using BatchRequestType = typename SourceDataset::BatchRequestType;
32
+ using OutputBatchType = detail::optional_if_t<
33
+ SourceDataset::is_stateful,
34
+ typename AppliedTransform::OutputBatchType>;
35
+
36
+ MapDataset(DatasetType dataset, TransformType transform)
37
+ : dataset_(std::move(dataset)), transform_(std::move(transform)) {}
38
+
39
+ /// Gets a batch from the source dataset and applies the transform to it,
40
+ /// returning the result.
41
+ OutputBatchType get_batch(BatchRequestType indices) override {
42
+ return get_batch_impl(std::move(indices));
43
+ }
44
+
45
+ /// Returns the size of the source dataset.
46
+ // NOLINTNEXTLINE(bugprone-exception-escape)
47
+ std::optional<size_t> size() const noexcept override {
48
+ return dataset_.size();
49
+ }
50
+
51
+ /// Calls `reset()` on the underlying dataset.
52
+ /// NOTE: Stateless datasets do not have a reset() method, so a call to this
53
+ /// method will only compile for stateful datasets (which have a reset()
54
+ /// method).
55
+ void reset() {
56
+ dataset_.reset();
57
+ }
58
+
59
+ /// Returns the underlying dataset.
60
+ const SourceDataset& dataset() noexcept {
61
+ return dataset_;
62
+ }
63
+
64
+ /// Returns the transform being applied.
65
+ const AppliedTransform& transform() noexcept {
66
+ return transform_;
67
+ }
68
+
69
+ private:
70
+ /// The implementation of `get_batch()` for the stateless case, which simply
71
+ /// applies the transform to the output of `get_batch()` from the dataset.
72
+ template <
73
+ typename D = SourceDataset,
74
+ typename = std::enable_if_t<!D::is_stateful>>
75
+ OutputBatchType get_batch_impl(BatchRequestType indices) {
76
+ return transform_.apply_batch(dataset_.get_batch(std::move(indices)));
77
+ }
78
+
79
+ /// The implementation of `get_batch()` for the stateful case. Here, we follow
80
+ /// the semantics of `Optional.map()` in many functional languages, which
81
+ /// applies a transformation to the optional's content when the optional
82
+ /// contains a value, and returns a new optional (of a different type) if the
83
+ /// original optional returned by `get_batch()` was empty.
84
+ template <typename D = SourceDataset>
85
+ std::enable_if_t<D::is_stateful, OutputBatchType> get_batch_impl(
86
+ BatchRequestType indices) {
87
+ if (auto batch = dataset_.get_batch(std::move(indices))) {
88
+ return transform_.apply_batch(std::move(*batch));
89
+ }
90
+ return nullopt;
91
+ }
92
+
93
+ /// The underlying dataset being transformed.
94
+ SourceDataset dataset_;
95
+
96
+ /// The transformation that is applied to batches received from the dataset.
97
+ AppliedTransform transform_;
98
+ };
99
+
100
+ /// Creates a `MapDataset` with the given dataset and transform.
101
+ template <typename DatasetType, typename TransformType>
102
+ MapDataset<DatasetType, TransformType> map(
103
+ DatasetType dataset,
104
+ TransformType transform) {
105
+ static_assert(
106
+ std::is_same<
107
+ typename std::conditional<
108
+ DatasetType::is_stateful,
109
+ typename DatasetType::BatchType::value_type,
110
+ typename DatasetType::BatchType>::type,
111
+ typename TransformType::InputBatchType>::value,
112
+ "BatchType type of dataset does not match input type of transform");
113
+ return {std::move(dataset), std::move(transform)};
114
+ }
115
+
116
+ } // namespace datasets
117
+ } // namespace data
118
+ } // namespace torch
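A minimal usage sketch (not part of the diff; the `./mnist` path is a placeholder for an unpacked MNIST directory): the free `map()` function declared above checks at compile time that the dataset's `BatchType` matches the transform's `InputBatchType`, then moves both into a `MapDataset`.

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  namespace td = torch::data;
  // Wrap MNIST in a MapDataset that stacks each batch into one Example<>.
  auto mapped = td::datasets::map(
      td::datasets::MNIST("./mnist"),        // placeholder path
      td::transforms::Stack<td::Example<>>());
  auto batch = mapped.get_batch({0, 1, 2});  // BatchRequestType = vector<size_t>
  std::cout << batch.data.sizes() << '\n';   // [3, 1, 28, 28]
}
```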
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/mnist.h ADDED
@@ -0,0 +1,48 @@
1
+ #pragma once
2
+
3
+ #include <torch/data/datasets/base.h>
4
+ #include <torch/data/example.h>
5
+ #include <torch/types.h>
6
+
7
+ #include <torch/csrc/Export.h>
8
+
9
+ #include <cstddef>
10
+ #include <string>
11
+
12
+ namespace torch {
13
+ namespace data {
14
+ namespace datasets {
15
+ /// The MNIST dataset.
16
+ class TORCH_API MNIST : public Dataset<MNIST> {
17
+ public:
18
+ /// The mode in which the dataset is loaded.
19
+ enum class Mode { kTrain, kTest };
20
+
21
+ /// Loads the MNIST dataset from the `root` path.
22
+ ///
23
+ /// The supplied `root` path should contain the *content* of the unzipped
24
+ /// MNIST dataset, available from http://yann.lecun.com/exdb/mnist.
25
+ explicit MNIST(const std::string& root, Mode mode = Mode::kTrain);
26
+
27
+ /// Returns the `Example` at the given `index`.
28
+ Example<> get(size_t index) override;
29
+
30
+ /// Returns the size of the dataset.
31
+ std::optional<size_t> size() const override;
32
+
33
+ /// Returns true if this is the training subset of MNIST.
34
+ // NOLINTNEXTLINE(bugprone-exception-escape)
35
+ bool is_train() const noexcept;
36
+
37
+ /// Returns all images stacked into a single tensor.
38
+ const Tensor& images() const;
39
+
40
+ /// Returns all targets stacked into a single tensor.
41
+ const Tensor& targets() const;
42
+
43
+ private:
44
+ Tensor images_, targets_;
45
+ };
46
+ } // namespace datasets
47
+ } // namespace data
48
+ } // namespace torch
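A short sketch of loading and indexing the dataset (the `./mnist` directory is a placeholder; it should contain the unzipped IDX files such as `train-images-idx3-ubyte`):

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  torch::data::datasets::MNIST mnist(
      "./mnist", torch::data::datasets::MNIST::Mode::kTrain);
  std::cout << "examples: " << mnist.size().value() << '\n';  // 60000
  auto example = mnist.get(0);
  std::cout << example.data.sizes() << " label "
            << example.target.item<int64_t>() << '\n';        // [1, 28, 28] label ...
}
```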
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/shared.h ADDED
@@ -0,0 +1,83 @@
1
+ #pragma once
2
+
3
+ #include <torch/data/datasets/base.h>
4
+
5
+ #include <memory>
6
+ #include <utility>
7
+
8
+ namespace torch {
9
+ namespace data {
10
+ namespace datasets {
11
+
12
+ /// A dataset that wraps another dataset in a shared pointer and implements the
13
+ /// `BatchDataset` API, delegating all calls to the shared instance. This is
14
+ /// useful when you want all worker threads in the dataloader to access the same
15
+ /// dataset instance. The dataset must take care of synchronization and
16
+ /// thread-safe access itself.
17
+ ///
18
+ /// Use `torch::data::datasets::make_shared_dataset()` to create a new
19
+ /// `SharedBatchDataset` like you would a `std::shared_ptr`.
20
+ template <typename UnderlyingDataset>
21
+ class SharedBatchDataset : public BatchDataset<
22
+ SharedBatchDataset<UnderlyingDataset>,
23
+ typename UnderlyingDataset::BatchType,
24
+ typename UnderlyingDataset::BatchRequestType> {
25
+ public:
26
+ using BatchType = typename UnderlyingDataset::BatchType;
27
+ using BatchRequestType = typename UnderlyingDataset::BatchRequestType;
28
+
29
+ /// Constructs a new `SharedBatchDataset` from a `shared_ptr` to the
30
+ /// `UnderlyingDataset`.
31
+ /* implicit */ SharedBatchDataset(
32
+ std::shared_ptr<UnderlyingDataset> shared_dataset)
33
+ : dataset_(std::move(shared_dataset)) {}
34
+
35
+ /// Calls `get_batch` on the underlying dataset.
36
+ BatchType get_batch(BatchRequestType request) override {
37
+ return dataset_->get_batch(std::move(request));
38
+ }
39
+
40
+ /// Returns the `size` from the underlying dataset.
41
+ std::optional<size_t> size() const override {
42
+ return dataset_->size();
43
+ }
44
+
45
+ /// Accesses the underlying dataset.
46
+ UnderlyingDataset& operator*() {
47
+ return *dataset_;
48
+ }
49
+
50
+ /// Accesses the underlying dataset.
51
+ const UnderlyingDataset& operator*() const {
52
+ return *dataset_;
53
+ }
54
+
55
+ /// Accesses the underlying dataset.
56
+ UnderlyingDataset* operator->() {
57
+ return dataset_.get();
58
+ }
59
+
60
+ /// Accesses the underlying dataset.
61
+ const UnderlyingDataset* operator->() const {
62
+ return dataset_.get();
63
+ }
64
+
65
+ /// Calls `reset()` on the underlying dataset.
66
+ void reset() {
67
+ dataset_->reset();
68
+ }
69
+
70
+ private:
71
+ std::shared_ptr<UnderlyingDataset> dataset_;
72
+ };
73
+
74
+ /// Constructs a new `SharedBatchDataset` by creating a
75
+ /// `shared_ptr<UnderlyingDataset>`. All arguments are forwarded to
76
+ /// `make_shared<UnderlyingDataset>`.
77
+ template <typename UnderlyingDataset, typename... Args>
78
+ SharedBatchDataset<UnderlyingDataset> make_shared_dataset(Args&&... args) {
79
+ return std::make_shared<UnderlyingDataset>(std::forward<Args>(args)...);
80
+ }
81
+ } // namespace datasets
82
+ } // namespace data
83
+ } // namespace torch
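As a sketch (the `./mnist` path is a placeholder), `make_shared_dataset()` can hand one dataset instance to a multi-worker dataloader; MNIST only reads its preloaded tensors in `get_batch()`, so concurrent access is safe here:

```cpp
#include <torch/torch.h>

int main() {
  namespace ds = torch::data::datasets;
  // One MNIST instance shared across all dataloader worker threads.
  auto shared = ds::make_shared_dataset<ds::MNIST>("./mnist");
  auto loader = torch::data::make_data_loader(
      shared,
      torch::data::DataLoaderOptions().batch_size(64).workers(2));
  for (auto& batch : *loader) {
    // batch is a std::vector<torch::data::Example<>> of up to 64 examples.
    (void)batch;
  }
}
```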
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/stateful.h ADDED
@@ -0,0 +1,70 @@
1
+ #pragma once
2
+
3
+ #include <torch/data/datasets/base.h>
4
+ #include <torch/data/example.h>
5
+
6
+ #include <cstddef>
7
+ #include <vector>
8
+
9
+ namespace torch {
10
+ namespace serialize {
11
+ class OutputArchive;
12
+ class InputArchive;
13
+ } // namespace serialize
14
+ } // namespace torch
15
+
16
+ namespace torch {
17
+ namespace data {
18
+ namespace datasets {
19
+
20
+ /// A stateful dataset is a dataset that maintains some internal state, which
21
+ /// will be `reset()` at the beginning of each epoch. Subclasses can override
22
+ /// the `reset()` method to configure this behavior. Further, the return type of
23
+ /// a stateful dataset's `get_batch()` method is always an `optional`. When the
24
+ /// stateful dataset wants to indicate to the dataloader that its epoch has
25
+ /// ended, it should return an empty optional. The dataloader knows to modify
26
+ /// its implementation based on whether the dataset is stateless or stateful.
27
+ ///
28
+ /// Note that when subclassing from `StatefulDataset<Self, T>`, the return
29
+ /// type of `get_batch()`, which the subclass must override, will be
30
+ /// `optional<T>` (i.e. the type specified in the `StatefulDataset`
31
+ /// specialization is automatically boxed into an `optional` for the dataset's
32
+ /// `BatchType`).
33
+ template <
34
+ typename Self,
35
+ typename Batch = std::vector<Example<>>,
36
+ typename BatchRequest = size_t>
37
+ class StatefulDataset
38
+ : public BatchDataset<Self, std::optional<Batch>, BatchRequest> {
39
+ public:
40
+ /// Resets internal state of the dataset.
41
+ virtual void reset() = 0;
42
+
43
+ /// Saves the `StatefulDataset`'s state to the `archive`.
44
+ virtual void save(serialize::OutputArchive& archive) const = 0;
45
+
46
+ /// Deserializes the `StatefulDataset`'s state from the `archive`.
47
+ virtual void load(serialize::InputArchive& archive) = 0;
48
+ };
49
+
50
+ /// Serializes a statefulDataset to `OutputArchive`.
51
+ template <typename... Args>
52
+ serialize::OutputArchive& operator<<(
53
+ serialize::OutputArchive& archive,
54
+ const StatefulDataset<Args...>& statefulDataset) {
55
+ statefulDataset.save(archive);
56
+ return archive;
57
+ }
58
+
59
+ /// Deserializes a statefulDataset from an `InputArchive`.
60
+ template <typename... Args>
61
+ serialize::InputArchive& operator>>(
62
+ serialize::InputArchive& archive,
63
+ StatefulDataset<Args...>& statefulDataset) {
64
+ statefulDataset.load(archive);
65
+ return archive;
66
+ }
67
+
68
+ } // namespace datasets
69
+ } // namespace data
70
+ } // namespace torch
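A minimal subclass sketch (the `CountingDataset` name and its behavior are illustrative, not from the diff): it serves `size` examples per epoch and signals the end of the epoch with an empty optional, exactly as the comment above describes.

```cpp
#include <torch/torch.h>
#include <vector>

struct CountingDataset
    : torch::data::datasets::StatefulDataset<CountingDataset> {
  explicit CountingDataset(size_t size) : size_(size) {}

  // BatchType here is std::optional<std::vector<Example<>>>.
  std::optional<std::vector<torch::data::Example<>>> get_batch(
      size_t batch_size) override {
    if (served_ == size_) {
      return std::nullopt;  // epoch is over
    }
    std::vector<torch::data::Example<>> batch;
    for (size_t i = 0; i < batch_size && served_ < size_; ++i, ++served_) {
      auto value = torch::full({1}, static_cast<int64_t>(served_));
      batch.emplace_back(value, value);
    }
    return batch;
  }

  std::optional<size_t> size() const override { return size_; }
  void reset() override { served_ = 0; }
  // Persistence is left empty in this sketch.
  void save(torch::serialize::OutputArchive&) const override {}
  void load(torch::serialize::InputArchive&) override {}

  size_t size_;
  size_t served_ = 0;
};
```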
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/tensor.h ADDED
@@ -0,0 +1,38 @@
1
+ #pragma once
2
+
3
+ #include <torch/data/datasets/base.h>
4
+ #include <torch/data/example.h>
5
+ #include <torch/types.h>
6
+
7
+ #include <cstddef>
8
+ #include <vector>
9
+
10
+ namespace torch {
11
+ namespace data {
12
+ namespace datasets {
13
+
14
+ /// A dataset of tensors.
15
+ /// Stores a single tensor internally, which is then indexed inside `get()`.
16
+ struct TensorDataset : public Dataset<TensorDataset, TensorExample> {
17
+ /// Creates a `TensorDataset` from a vector of tensors.
18
+ explicit TensorDataset(const std::vector<Tensor>& tensors)
19
+ : TensorDataset(torch::stack(tensors)) {}
20
+
21
+ explicit TensorDataset(torch::Tensor tensor) : tensor(std::move(tensor)) {}
22
+
23
+ /// Returns a single `TensorExample`.
24
+ TensorExample get(size_t index) override {
25
+ return tensor[index];
26
+ }
27
+
28
+ /// Returns the number of tensors in the dataset.
29
+ std::optional<size_t> size() const override {
30
+ return tensor.size(0);
31
+ }
32
+
33
+ Tensor tensor;
34
+ };
35
+
36
+ } // namespace datasets
37
+ } // namespace data
38
+ } // namespace torch
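A quick sketch of constructing and indexing a `TensorDataset`:

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // Three rows of four features, stored as a single [3, 4] tensor.
  torch::data::datasets::TensorDataset dataset(torch::rand({3, 4}));
  std::cout << dataset.size().value() << '\n';  // 3
  // TensorExample converts implicitly to its underlying tensor.
  torch::Tensor row = dataset.get(1);
  std::cout << row.sizes() << '\n';             // [4]
}
```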
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/detail/data_shuttle.h ADDED
@@ -0,0 +1,87 @@
1
+ #pragma once
2
+
3
+ #include <torch/data/detail/queue.h>
4
+ #include <torch/types.h>
5
+
6
+ #include <c10/util/Exception.h>
7
+ #include <optional>
8
+
9
+ #include <chrono>
10
+ #include <utility>
11
+
12
+ namespace torch {
13
+ namespace data {
14
+ namespace detail {
15
+
16
+ /// Encapsulates the full life cycle of DataLoader jobs.
17
+ ///
18
+ /// When a new job is enqueued to the `DataShuttle`, a counter for in-flight
19
+ /// jobs is bumped. This job is said to be "in-flight" until its result is
20
+ /// popped. Worker threads dequeue jobs as soon as they are available. When a
21
+ /// worker finishes a job, it enqueues the result. Only when the main thread
22
+ /// dequeues a result is the count of in-flight jobs decremented. When the main
23
+ /// thread attempts to dequeue a result but no jobs are in-flight, that means the
24
+ /// epoch is complete and `pop_result` returns an empty optional.
25
+ template <typename Job, typename Result>
26
+ class DataShuttle {
27
+ public:
28
+ /// Pushes a new job. Called by the main thread.
29
+ void push_job(Job job) {
30
+ new_jobs_.push(std::move(job));
31
+ ++in_flight_jobs_;
32
+ }
33
+
34
+ /// Pushes the result of a job. Called by worker threads.
35
+ void push_result(Result result) {
36
+ results_.push(std::move(result));
37
+ }
38
+
39
+ /// Returns the next job, blocking until there is one available. Called by
40
+ /// worker threads.
41
+ Job pop_job() {
42
+ return new_jobs_.pop();
43
+ }
44
+
45
+ /// Returns the result of a job, or nullopt if all jobs were exhausted. Called
46
+ /// by the main thread.
47
+ std::optional<Result> pop_result(
48
+ std::optional<std::chrono::milliseconds> timeout = std::nullopt) {
49
+ if (in_flight_jobs_ > 0) {
50
+ auto result = results_.pop(timeout);
51
+ --in_flight_jobs_;
52
+ return result;
53
+ }
54
+ return nullopt;
55
+ }
56
+
57
+ /// Discards any jobs that are not yet in flight, and waits for all in-flight
58
+ /// jobs to finish, discarding their result.
59
+ void drain() {
60
+ // Clear all inputs so that no further jobs are scheduled.
61
+ auto number_cleared = new_jobs_.clear();
62
+ in_flight_jobs_ -= number_cleared;
63
+ // Remove any outstanding results.
64
+ while (in_flight_jobs_ > 0) {
65
+ pop_result();
66
+ }
67
+ }
68
+
69
+ /// Returns the number of jobs that are still in progress.
70
+ /// When this number is zero, an epoch is finished.
71
+ size_t in_flight_jobs() const noexcept {
72
+ return in_flight_jobs_;
73
+ }
74
+
75
+ private:
76
+ /// The queue for jobs that are not yet in flight.
77
+ Queue<Job> new_jobs_;
78
+ /// The number of in-flight jobs.
79
+ /// NOTE: Not atomic because only manipulated by the main thread.
80
+ size_t in_flight_jobs_ = 0;
81
+ /// The queue for results of finished jobs.
82
+ Queue<Result> results_;
83
+ };
84
+
85
+ } // namespace detail
86
+ } // namespace data
87
+ } // namespace torch
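A single-threaded walk-through of the in-flight accounting (a sketch only; in the real `DataLoader`, `pop_job()`/`push_result()` run on worker threads):

```cpp
#include <torch/data/detail/data_shuttle.h>
#include <cassert>

int main() {
  torch::data::detail::DataShuttle<int, int> shuttle;
  shuttle.push_job(7);                 // main thread; in-flight count -> 1
  int job = shuttle.pop_job();         // worker side
  shuttle.push_result(job * job);      // worker side
  auto result = shuttle.pop_result();  // main thread; in-flight count -> 0
  assert(result.has_value() && *result == 49);
  assert(!shuttle.pop_result());       // nothing in flight -> empty optional
}
```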
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/detail/queue.h ADDED
@@ -0,0 +1,84 @@
1
+ #pragma once
2
+
3
+ #include <torch/types.h>
4
+
5
+ #include <c10/util/Exception.h>
6
+
7
+ #include <chrono>
8
+ #include <condition_variable>
9
+ #include <cstddef>
10
+ #include <mutex>
11
+ #include <queue>
12
+
13
+ namespace torch {
14
+ namespace data {
15
+ namespace detail {
16
+
17
+ /// A basic locked, blocking MPMC queue.
18
+ ///
19
+ /// Every `push` and `pop` is guarded by a mutex. A condition variable is used
20
+ /// to communicate insertion of new elements, such that waiting threads will be
21
+ /// woken up if they are currently waiting inside a call to `pop()`.
22
+ ///
23
+ /// Note that this data structure is written specifically for use with the
24
+ /// `DataLoader`. Its behavior is tailored to this use case and may not be
25
+ /// applicable to more general uses.
26
+ template <typename T>
27
+ class Queue {
28
+ public:
29
+ /// Pushes a new value to the back of the `Queue` and notifies one thread on
30
+ /// the waiting side about this event.
31
+ void push(T value) {
32
+ {
33
+ std::lock_guard<std::mutex> lock(mutex_);
34
+ queue_.push(std::move(value));
35
+ }
36
+ cv_.notify_one();
37
+ }
38
+
39
+ /// Blocks until at least one element is ready to be popped from the front of
40
+ /// the queue. An optional `timeout` in milliseconds can be used to limit the time
41
+ /// spent waiting for an element. If the wait times out, an exception is
42
+ /// raised.
43
+ T pop(std::optional<std::chrono::milliseconds> timeout = std::nullopt) {
44
+ std::unique_lock<std::mutex> lock(mutex_);
45
+ if (timeout) {
46
+ if (!cv_.wait_for(
47
+ lock, *timeout, [this] { return !this->queue_.empty(); })) {
48
+ // clang-format off
49
+ AT_ERROR(
50
+ "Timeout in DataLoader queue while waiting for next batch"
51
+ " (timeout was ", timeout->count(), " ms)");
52
+ // clang-format on
53
+ }
54
+ } else {
55
+ cv_.wait(lock, [this] { return !this->queue_.empty(); });
56
+ }
57
+ AT_ASSERT(!queue_.empty());
58
+ T value = queue_.front();
59
+ queue_.pop();
60
+ lock.unlock();
61
+ return value;
62
+ }
63
+
64
+ /// Empties the queue and returns the number of elements that were present at
65
+ /// the start of the function. No threads are notified about this event as it
66
+ /// is assumed to be used to drain the queue during shutdown of a
67
+ /// `DataLoader`.
68
+ size_t clear() {
69
+ std::lock_guard<std::mutex> lock(this->mutex_);
70
+ const auto size = queue_.size();
71
+ while (!queue_.empty()) {
72
+ queue_.pop();
73
+ }
74
+ return size;
75
+ }
76
+
77
+ private:
78
+ std::queue<T> queue_;
79
+ std::mutex mutex_;
80
+ std::condition_variable cv_;
81
+ };
82
+ } // namespace detail
83
+ } // namespace data
84
+ } // namespace torch
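A two-thread sketch of the blocking behavior:

```cpp
#include <torch/data/detail/queue.h>
#include <chrono>
#include <iostream>
#include <thread>

int main() {
  torch::data::detail::Queue<int> queue;
  std::thread producer([&] { queue.push(42); });
  // pop() blocks until push() has run; with the timeout argument, an
  // exception is thrown instead if nothing arrives in time.
  int value = queue.pop(std::chrono::milliseconds(500));
  producer.join();
  std::cout << value << '\n';  // 42
}
```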
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/detail/sequencers.h ADDED
@@ -0,0 +1,113 @@
1
+ #pragma once
2
+
3
+ #include <torch/types.h>
4
+
5
+ #include <algorithm>
6
+ #include <cstddef>
7
+ #include <vector>
8
+
9
+ namespace torch {
10
+ namespace data {
11
+ namespace detail {
12
+ namespace sequencers {
13
+ namespace detail {
14
+ template <typename Result>
15
+ bool buffer_contains_result(const std::vector<std::optional<Result>>& buffer) {
16
+ return std::any_of(
17
+ buffer.begin(), buffer.end(), [](const std::optional<Result>& result) {
18
+ return result.has_value();
19
+ });
20
+ }
21
+ } // namespace detail
22
+
23
+ /// A `Sequencer` accepts a function that yields the next result of a
24
+ /// `DataLoader` and then has the opportunity to influence the order in which
25
+ /// these results are returned. The `NoSequencer` does not enforce any
26
+ /// sequencing and returns any result directly. The `OrderedSequencer` instead
27
+ /// buffers results internally to return them in order of their sequence number.
28
+ template <typename Result>
29
+ struct Sequencer {
30
+ using ResultProducer = std::function<std::optional<Result>()>;
31
+ virtual ~Sequencer() = default;
32
+ virtual std::optional<Result> next(ResultProducer next_result) = 0;
33
+ };
34
+
35
+ /// A `Sequencer` that does not enforce any ordering. It is effectively the
36
+ /// identity function.
37
+ template <typename Result>
38
+ struct NoSequencer final : public Sequencer<Result> {
39
+ using typename Sequencer<Result>::ResultProducer;
40
+ std::optional<Result> next(ResultProducer next_result) override {
41
+ return next_result();
42
+ }
43
+ };
44
+
45
+ /// A `Sequencer` that buffers results and returns them in order of their
46
+ /// sequence number. The `OrderedSequencer` maintains an internal, monotonically
47
+ /// incrementing counter for the next sequence number it expects. If it receives
48
+ /// a result with a higher sequence number, it will buffer it for later (when
49
+ /// the sequence number reaches that of this result). Otherwise, if the sequence
50
+ /// numbers match, the result is returned.
51
+ ///
52
+ /// Implementation note: The `OrderedSequencer` is implemented with a fixed-size
53
+ /// buffer. Let `m` be the maximum number of jobs in the data loader's queue and
54
+ /// `s` be the current sequence number. Assume `m` jobs are scheduled in the
55
+ /// `DataLoader`. Any new result is stored at index `job.sqn mod m` in the
56
+ /// `OrderedSequencer`. Why are we sure sequence numbers of new jobs will not
57
+ /// collide with sequence numbers of buffered jobs? The `OrderedSequencer` will
58
+ /// not return from `next()` until it receives the result with sqn `s`. This
59
+ /// means no new jobs can be scheduled in the `DataLoader` in the meantime,
60
+ /// which enforces that as long as sqn `s` has not been received, `s + m` (which
61
+ /// would cause a collision in the fixed-size buffer) will not yet be scheduled.
62
+ template <typename Result>
63
+ struct OrderedSequencer : public Sequencer<Result> {
64
+ using typename Sequencer<Result>::ResultProducer;
65
+
66
+ /// Constructs the `OrderedSequencer` with the maximum number of results it
67
+ /// will ever hold at one point in time.
68
+ explicit OrderedSequencer(size_t max_jobs) : buffer_(max_jobs) {}
69
+
70
+ /// Buffers results until the next one in the expected order is received.
71
+ std::optional<Result> next(ResultProducer next_result) override {
72
+ // If we already have the result for the next sqn, return it.
73
+ if (auto& maybe_result = buffer(next_sequence_number_)) {
74
+ auto result = std::move(*maybe_result);
75
+ buffer(next_sequence_number_++).reset();
76
+ return result;
77
+ }
78
+ // Otherwise wait for the next result.
79
+ while (true) {
80
+ auto result = next_result();
81
+ if (!result) {
82
+ AT_ASSERT(!detail::buffer_contains_result(buffer_));
83
+ break;
84
+ }
85
+ // If it was not nullopt and the sequence numbers match, return it
86
+ // directly and bump the sequence number.
87
+ if (result->sequence_number == next_sequence_number_) {
88
+ ++next_sequence_number_;
89
+ return result;
90
+ }
91
+ // Stash the result for later.
92
+ AT_ASSERT(!buffer(result->sequence_number).has_value());
93
+ buffer(result->sequence_number) = std::move(result);
94
+ }
95
+ // The result was an empty optional, so we are done with this epoch.
96
+ return nullopt;
97
+ }
98
+
99
+ /// Accesses the buffer at the `index` modulo the buffer size.
100
+ std::optional<Result>& buffer(size_t index) {
101
+ return buffer_.at(index % buffer_.size());
102
+ }
103
+
104
+ /// The monotonically increasing sequence number we expect.
105
+ size_t next_sequence_number_ = 0;
106
+
107
+ /// A fixed-size buffer (after construction).
108
+ std::vector<std::optional<Result>> buffer_;
109
+ };
110
+ } // namespace sequencers
111
+ } // namespace detail
112
+ } // namespace data
113
+ } // namespace torch
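A sketch of the reordering behavior (the `Result` struct is illustrative; the sequencer only requires a `sequence_number` member): results arrive out of order, and `next()` buffers the early one until its turn.

```cpp
#include <torch/data/detail/sequencers.h>
#include <deque>
#include <iostream>
#include <optional>

struct Result {
  size_t sequence_number;
  int payload;
};

int main() {
  namespace seq = torch::data::detail::sequencers;
  // Results arrive out of order: sequence number 1 before 0.
  std::deque<Result> arrivals{{1, 10}, {0, 20}};
  auto producer = [&]() -> std::optional<Result> {
    if (arrivals.empty()) return std::nullopt;
    Result r = arrivals.front();
    arrivals.pop_front();
    return r;
  };
  seq::OrderedSequencer<Result> sequencer(/*max_jobs=*/2);
  // next() stashes sqn 1 in the buffer and hands back sqn 0 first.
  std::cout << sequencer.next(producer)->payload << '\n';  // 20
  std::cout << sequencer.next(producer)->payload << '\n';  // 10
}
```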
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/example.h ADDED
@@ -0,0 +1,55 @@
1
+ #pragma once
2
+
3
+ #include <torch/types.h>
4
+
5
+ namespace torch {
6
+ namespace data {
7
+
8
+ /// An `Example` from a dataset.
9
+ ///
10
+ /// A dataset consists of data and an associated target (label).
11
+ template <typename Data = at::Tensor, typename Target = at::Tensor>
12
+ struct Example {
13
+ using DataType = Data;
14
+ using TargetType = Target;
15
+
16
+ Example() = default;
17
+ Example(Data data, Target target)
18
+ : data(std::move(data)), target(std::move(target)) {}
19
+
20
+ Data data;
21
+ Target target;
22
+ };
23
+
24
+ namespace example {
25
+ using NoTarget = void;
26
+ } // namespace example
27
+
28
+ /// A specialization for `Example` that does not have a target.
29
+ ///
30
+ /// This class exists so that code can be written for a templated `Example`
31
+ /// type, and work both for labeled and unlabeled datasets.
32
+ template <typename Data>
33
+ struct Example<Data, example::NoTarget> {
34
+ using DataType = Data;
35
+ using TargetType = example::NoTarget;
36
+
37
+ Example() = default;
38
+ /* implicit */ Example(Data data) : data(std::move(data)) {}
39
+
40
+ // When a DataLoader returns an Example like this, that example should be
41
+ // implicitly convertible to the underlying data type.
42
+
43
+ operator Data&() {
44
+ return data;
45
+ }
46
+ operator const Data&() const {
47
+ return data;
48
+ }
49
+
50
+ Data data;
51
+ };
52
+
53
+ using TensorExample = Example<at::Tensor, example::NoTarget>;
54
+ } // namespace data
55
+ } // namespace torch
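A short sketch of both forms, including the implicit conversion the `NoTarget` specialization provides:

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // The general case: data plus target (label).
  torch::data::Example<> labeled(torch::rand({2}), torch::tensor(1));
  std::cout << labeled.target.item<int64_t>() << '\n';  // 1

  // The NoTarget specialization converts implicitly to its data tensor.
  torch::data::TensorExample unlabeled(torch::rand({2}));
  torch::Tensor& data = unlabeled;   // via operator Data&()
  std::cout << data.sizes() << '\n'; // [2]
}
```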
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/iterator.h ADDED
@@ -0,0 +1,178 @@
1
+ #pragma once
2
+
3
+ #include <torch/csrc/utils/variadic.h>
4
+ #include <torch/types.h>
5
+
6
+ #include <c10/util/Exception.h>
7
+
8
+ #include <functional>
9
+ #include <iterator>
10
+ #include <memory>
11
+ #include <type_traits>
12
+ #include <utility>
13
+
14
+ namespace torch {
15
+ namespace data {
16
+ namespace detail {
17
+ // For increased safety and more separated logic, this implementation of
18
+ // `Iterator` consists of a `ValidIterator` and a `SentinelIterator`. A
19
+ // `ValidIterator` yields new batches until the `DataLoader` is exhausted. While
20
+ // the `DataLoader` is not exhausted, `ValidIterator`s compare equal if they are
21
+ // the same object. When the `ValidIterator` becomes exhausted, it compares
22
+ // equal to the `SentinelIterator`, but not before. Half the code here is to
23
+ // implement double dispatch for the comparison.
24
+
25
+ template <typename Batch>
26
+ struct ValidIterator;
27
+
28
+ template <typename Batch>
29
+ struct SentinelIterator;
30
+
31
+ /// Base class for the `ValidIterator` and `SentinelIterator`
32
+ template <typename Batch>
33
+ struct IteratorImpl {
34
+ virtual ~IteratorImpl() = default;
35
+ virtual void next() = 0;
36
+ virtual Batch& get() = 0;
37
+ virtual bool operator==(const IteratorImpl& other) const = 0;
38
+ virtual bool operator==(const ValidIterator<Batch>& other) const = 0;
39
+ virtual bool operator==(const SentinelIterator<Batch>& other) const = 0;
40
+ };
41
+
42
+ template <typename Batch>
43
+ struct ValidIterator : public IteratorImpl<Batch> {
44
+ using BatchProducer = std::function<std::optional<Batch>()>;
45
+
46
+ explicit ValidIterator(BatchProducer next_batch)
47
+ : next_batch_(std::move(next_batch)) {}
48
+
49
+ /// Fetches the next batch.
50
+ void next() override {
51
+ // If we didn't get the very first batch yet, get it now.
52
+ lazy_initialize();
53
+ TORCH_CHECK(
54
+ batch_.has_value(), "Attempted to increment iterator past the end");
55
+ // Increment to the next batch.
56
+ batch_ = next_batch_();
57
+ }
58
+
59
+ /// Returns the current batch. The precondition for this operation to not
60
+ /// throw an exception is that it has been compared to the `SentinelIterator`
61
+ /// and did not compare equal.
62
+ Batch& get() override {
63
+ // If we didn't get the very first batch yet, get it now.
64
+ lazy_initialize();
65
+ TORCH_CHECK(
66
+ batch_.has_value(),
67
+ "Attempted to dereference iterator that was past the end");
68
+ return batch_.value();
69
+ }
70
+
71
+ /// Does double dispatch.
72
+ bool operator==(const IteratorImpl<Batch>& other) const override {
73
+ return other == *this;
74
+ }
75
+
76
+ /// A `ValidIterator` is equal to the `SentinelIterator` if and only if the
77
+ /// `ValidIterator` has reached the end of the dataloader.
78
+ bool operator==(const SentinelIterator<Batch>& /* unused */) const override {
79
+ lazy_initialize();
80
+ return !batch_;
81
+ }
82
+
83
+ /// Returns true if the memory address of `other` equals that of `this`.
84
+ bool operator==(const ValidIterator<Batch>& other) const override {
85
+ return &other == this;
86
+ }
87
+
88
+ /// Gets the very first batch if it has not yet been fetched.
89
+ void lazy_initialize() const {
90
+ if (!initialized_) {
91
+ batch_ = next_batch_();
92
+ initialized_ = true;
93
+ }
94
+ }
95
+
96
+ BatchProducer next_batch_;
97
+ mutable std::optional<Batch> batch_;
98
+ mutable bool initialized_ = false;
99
+ };
100
+
101
+ template <typename Batch>
102
+ struct SentinelIterator : public IteratorImpl<Batch> {
103
+ void next() override {
104
+ AT_ERROR(
105
+ "Incrementing the DataLoader's past-the-end iterator is not allowed");
106
+ }
107
+
108
+ Batch& get() override {
109
+ AT_ERROR(
110
+ "Dereferencing the DataLoader's past-the-end iterator is not allowed");
111
+ }
112
+
113
+ /// Does double dispatch.
114
+ bool operator==(const IteratorImpl<Batch>& other) const override {
115
+ return other == *this;
116
+ }
117
+
118
+ /// Calls the comparison operator between `ValidIterator` and
119
+ /// `SentinelIterator`.
120
+ bool operator==(const ValidIterator<Batch>& other) const override {
121
+ return other == *this;
122
+ }
123
+
124
+ /// Sentinel iterators always compare equal.
125
+ bool operator==(const SentinelIterator<Batch>& other) const override {
126
+ return true;
127
+ }
128
+ };
129
+ } // namespace detail
130
+
131
+ template <typename Batch>
132
+ class Iterator {
133
+ public:
134
+ // Type aliases to make the class recognized as a proper iterator.
135
+ using difference_type = std::ptrdiff_t;
136
+ using value_type = Batch;
137
+ using pointer = Batch*;
138
+ using reference = Batch&;
139
+ using iterator_category = std::input_iterator_tag;
140
+
141
+ explicit Iterator(std::unique_ptr<detail::IteratorImpl<Batch>> impl)
142
+ : impl_(std::move(impl)) {}
143
+
144
+ /// Increments the iterator.
145
+ /// Only permitted for valid iterators (not past the end).
146
+ Iterator& operator++() {
147
+ impl_->next();
148
+ return *this;
149
+ }
150
+
151
+ /// Returns the current batch.
152
+ /// Only permitted for valid iterators (not past the end).
153
+ Batch& operator*() {
154
+ return impl_->get();
155
+ }
156
+
157
+ /// Returns a pointer to the current batch.
158
+ /// Only permitted for valid iterators (not past the end).
159
+ Batch* operator->() {
160
+ return &impl_->get();
161
+ }
162
+
163
+ /// Compares two iterators for equality.
164
+ bool operator==(const Iterator& other) const {
165
+ return *impl_ == *other.impl_;
166
+ }
167
+
168
+ /// Compares two iterators for inequality.
169
+ bool operator!=(const Iterator& other) const {
170
+ return !(*this == other);
171
+ }
172
+
173
+ private:
174
+ /// Points either to a `ValidIterator` or to a `SentinelIterator`.
175
+ std::shared_ptr<detail::IteratorImpl<Batch>> impl_;
176
+ };
177
+ } // namespace data
178
+ } // namespace torch
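In practice the two iterator kinds are exercised by a plain range-for over a dataloader; a sketch:

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  auto dataset =
      torch::data::datasets::TensorDataset(torch::rand({10, 3}))
          .map(torch::data::transforms::Stack<torch::data::TensorExample>());
  auto loader =
      torch::data::make_data_loader(std::move(dataset), /*batch_size=*/4);
  // begin() produces a ValidIterator, end() the SentinelIterator; the loop
  // condition compares the two after every increment.
  for (auto& batch : *loader) {
    std::cout << batch.data.sizes() << '\n';  // [4, 3], [4, 3], [2, 3]
  }
}
```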
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers.h ADDED
@@ -0,0 +1,9 @@
1
+ #pragma once
2
+
3
+ #include <torch/data/samplers/base.h>
4
+ #include <torch/data/samplers/custom_batch_request.h>
5
+ #include <torch/data/samplers/distributed.h>
6
+ #include <torch/data/samplers/random.h>
7
+ #include <torch/data/samplers/sequential.h>
8
+ #include <torch/data/samplers/serialize.h>
9
+ #include <torch/data/samplers/stream.h>
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/base.h ADDED
@@ -0,0 +1,47 @@
1
+ #pragma once
2
+
3
+ #include <torch/csrc/Export.h>
4
+ #include <torch/types.h>
5
+
6
+ #include <cstddef>
7
+ #include <mutex>
8
+ #include <vector>
9
+
10
+ namespace torch {
11
+ namespace serialize {
12
+ class OutputArchive;
13
+ class InputArchive;
14
+ } // namespace serialize
15
+ } // namespace torch
16
+
17
+ namespace torch {
18
+ namespace data {
19
+ namespace samplers {
20
+ /// A `Sampler` is an object that yields an index with which to access a
21
+ /// dataset.
22
+ template <typename BatchRequest = std::vector<size_t>>
23
+ class Sampler {
24
+ public:
25
+ using BatchRequestType = BatchRequest;
26
+
27
+ virtual ~Sampler() = default;
28
+
29
+ /// Resets the `Sampler`'s internal state.
30
+ /// Typically called before a new epoch.
31
+ /// Optionally, accepts a new size when resetting the sampler.
32
+ virtual void reset(std::optional<size_t> new_size) = 0;
33
+
34
+ /// Returns the next batch of indices if possible, or an empty optional if the
35
+ /// sampler is exhausted for this epoch.
36
+ virtual std::optional<BatchRequest> next(size_t batch_size) = 0;
37
+
38
+ /// Serializes the `Sampler` to the `archive`.
39
+ virtual void save(serialize::OutputArchive& archive) const = 0;
40
+
41
+ /// Deserializes the `Sampler` from the `archive`.
42
+ virtual void load(serialize::InputArchive& archive) = 0;
43
+ };
44
+
45
+ } // namespace samplers
46
+ } // namespace data
47
+ } // namespace torch
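A minimal implementation sketch of the interface (the `ReverseSampler` name and behavior are invented for illustration): it yields indices in reverse order and leaves serialization empty.

```cpp
#include <torch/torch.h>
#include <vector>

class ReverseSampler : public torch::data::samplers::Sampler<> {
 public:
  explicit ReverseSampler(size_t size) : size_(size), cursor_(size) {}

  void reset(std::optional<size_t> new_size = std::nullopt) override {
    if (new_size) size_ = *new_size;
    cursor_ = size_;
  }

  std::optional<std::vector<size_t>> next(size_t batch_size) override {
    if (cursor_ == 0) return std::nullopt;  // exhausted for this epoch
    std::vector<size_t> indices;
    while (batch_size-- > 0 && cursor_ > 0) {
      indices.push_back(--cursor_);
    }
    return indices;
  }

  // Persistence is left empty in this sketch.
  void save(torch::serialize::OutputArchive&) const override {}
  void load(torch::serialize::InputArchive&) override {}

 private:
  size_t size_;
  size_t cursor_;
};
```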
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/custom_batch_request.h ADDED
@@ -0,0 +1,21 @@
1
+ #pragma once
2
+
3
+ #include <torch/csrc/Export.h>
4
+ #include <cstddef>
5
+
6
+ namespace torch {
7
+ namespace data {
8
+ namespace samplers {
9
+ /// A base class for custom index types.
10
+ struct TORCH_API CustomBatchRequest {
11
+ CustomBatchRequest() = default;
12
+ CustomBatchRequest(const CustomBatchRequest&) = default;
13
+ CustomBatchRequest(CustomBatchRequest&&) noexcept = default;
14
+ virtual ~CustomBatchRequest() = default;
15
+
16
+ /// The number of elements accessed by this index.
17
+ virtual size_t size() const = 0;
18
+ };
19
+ } // namespace samplers
20
+ } // namespace data
21
+ } // namespace torch
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/distributed.h ADDED
@@ -0,0 +1,139 @@
1
+ #pragma once
2
+
3
+ #include <torch/csrc/Export.h>
4
+ #include <torch/data/samplers/base.h>
5
+
6
+ #include <cstddef>
7
+ #include <vector>
8
+
9
+ namespace torch {
10
+ namespace serialize {
11
+ class OutputArchive;
12
+ class InputArchive;
13
+ } // namespace serialize
14
+ } // namespace torch
15
+
16
+ namespace torch {
17
+ namespace data {
18
+ namespace samplers {
19
+
20
+ /// A `Sampler` that selects a subset of indices to sample from and defines a
21
+ /// sampling behavior. In a distributed setting, this selects a subset of the
22
+ /// indices depending on the provided num_replicas and rank parameters. The
23
+ /// `Sampler` performs a rounding operation based on the `allow_duplicates`
24
+ /// parameter to decide the local sample count.
25
+ template <typename BatchRequest = std::vector<size_t>>
26
+ class DistributedSampler : public Sampler<BatchRequest> {
27
+ public:
28
+ DistributedSampler(
29
+ size_t size,
30
+ size_t num_replicas = 1,
31
+ size_t rank = 0,
32
+ bool allow_duplicates = true)
33
+ : size_(size),
34
+ num_replicas_(num_replicas),
35
+ rank_(rank),
36
+ epoch_(0),
37
+ allow_duplicates_(allow_duplicates) {}
38
+
39
+ /// Sets the epoch for the current enumeration. This can be used to alter the
40
+ /// sample selection and shuffling behavior.
41
+ void set_epoch(size_t epoch) {
42
+ epoch_ = epoch;
43
+ }
44
+
45
+ size_t epoch() const {
46
+ return epoch_;
47
+ }
48
+
49
+ protected:
50
+ size_t local_sample_count() {
51
+ if (allow_duplicates_) {
52
+ return (size_ + num_replicas_ - 1) / num_replicas_;
53
+ } else {
54
+ return size_ / num_replicas_;
55
+ }
56
+ }
57
+
58
+ // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
59
+ size_t size_;
60
+ // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
61
+ size_t num_replicas_;
62
+ // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
63
+ size_t rank_;
64
+ // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
65
+ size_t epoch_;
66
+ // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
67
+ bool allow_duplicates_;
68
+ };
69
+
70
+ /// Select samples randomly. The sampling order is shuffled at each `reset()`
71
+ /// call.
72
+ class TORCH_API DistributedRandomSampler : public DistributedSampler<> {
73
+ public:
74
+ DistributedRandomSampler(
75
+ size_t size,
76
+ size_t num_replicas = 1,
77
+ size_t rank = 0,
78
+ bool allow_duplicates = true);
79
+
80
+ /// Resets the `DistributedRandomSampler` to a new set of indices.
81
+ void reset(std::optional<size_t> new_size = std::nullopt) override;
82
+
83
+ /// Returns the next batch of indices.
84
+ std::optional<std::vector<size_t>> next(size_t batch_size) override;
85
+
86
+ /// Serializes the `DistributedRandomSampler` to the `archive`.
87
+ void save(serialize::OutputArchive& archive) const override;
88
+
89
+ /// Deserializes the `DistributedRandomSampler` from the `archive`.
90
+ void load(serialize::InputArchive& archive) override;
91
+
92
+ /// Returns the current index of the `DistributedRandomSampler`.
93
+ size_t index() const noexcept;
94
+
95
+ private:
96
+ void populate_indices();
97
+
98
+ size_t begin_index_;
99
+ size_t end_index_;
100
+ size_t sample_index_;
101
+ std::vector<size_t> all_indices_;
102
+ };
103
+
104
+ /// Select samples sequentially.
105
+ class TORCH_API DistributedSequentialSampler : public DistributedSampler<> {
106
+ public:
107
+ DistributedSequentialSampler(
108
+ size_t size,
109
+ size_t num_replicas = 1,
110
+ size_t rank = 0,
111
+ bool allow_duplicates = true);
112
+
113
+ /// Resets the `DistributedSequentialSampler` to a new set of indices.
114
+ void reset(std::optional<size_t> new_size = std::nullopt) override;
115
+
116
+ /// Returns the next batch of indices.
117
+ std::optional<std::vector<size_t>> next(size_t batch_size) override;
118
+
119
+ /// Serializes the `DistributedSequentialSampler` to the `archive`.
120
+ void save(serialize::OutputArchive& archive) const override;
121
+
122
+ /// Deserializes the `DistributedSequentialSampler` from the `archive`.
123
+ void load(serialize::InputArchive& archive) override;
124
+
125
+ /// Returns the current index of the `DistributedSequentialSampler`.
126
+ size_t index() const noexcept;
127
+
128
+ private:
129
+ void populate_indices();
130
+
131
+ size_t begin_index_;
132
+ size_t end_index_;
133
+ size_t sample_index_;
134
+ std::vector<size_t> all_indices_;
135
+ };
136
+
137
+ } // namespace samplers
138
+ } // namespace data
139
+ } // namespace torch
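A sketch of one replica's view: rank 1 of 2 over 10 samples, where `allow_duplicates = true` gives each rank ceil(10 / 2) = 5 indices per epoch.

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  torch::data::samplers::DistributedRandomSampler sampler(
      /*size=*/10, /*num_replicas=*/2, /*rank=*/1);
  for (size_t epoch = 0; epoch < 2; ++epoch) {
    sampler.set_epoch(epoch);  // vary the shuffle per epoch
    sampler.reset();
    while (auto indices = sampler.next(/*batch_size=*/3)) {
      std::cout << indices->size() << ' ';  // 3 2 per epoch
    }
    std::cout << '\n';
  }
}
```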
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/random.h ADDED
@@ -0,0 +1,54 @@
1
+ #pragma once
2
+
3
+ #include <torch/csrc/Export.h>
4
+ #include <torch/data/samplers/base.h>
5
+ #include <torch/types.h>
6
+
7
+ #include <cstddef>
8
+ #include <vector>
9
+
10
+ namespace torch {
11
+ namespace serialize {
12
+ class OutputArchive;
13
+ class InputArchive;
14
+ } // namespace serialize
15
+ } // namespace torch
16
+
17
+ namespace torch {
18
+ namespace data {
19
+ namespace samplers {
20
+
21
+ /// A `Sampler` that returns random indices.
22
+ class TORCH_API RandomSampler : public Sampler<> {
23
+ public:
24
+ /// Constructs a `RandomSampler` with a size and dtype for the stored indices.
25
+ ///
26
+ /// The constructor will eagerly allocate all required indices, which is the
27
+ /// sequence `0 ... size - 1`. `index_dtype` is the data type of the stored
28
+ /// indices. You can change it to influence memory usage.
29
+ explicit RandomSampler(int64_t size, Dtype index_dtype = torch::kInt64);
30
+
31
+ ~RandomSampler() override;
32
+
33
+ /// Resets the `RandomSampler` to a new set of indices.
34
+ void reset(std::optional<size_t> new_size = std::nullopt) override;
35
+
36
+ /// Returns the next batch of indices.
37
+ std::optional<std::vector<size_t>> next(size_t batch_size) override;
38
+
39
+ /// Serializes the `RandomSampler` to the `archive`.
40
+ void save(serialize::OutputArchive& archive) const override;
41
+
42
+ /// Deserializes the `RandomSampler` from the `archive`.
43
+ void load(serialize::InputArchive& archive) override;
44
+
45
+ /// Returns the current index of the `RandomSampler`.
46
+ size_t index() const noexcept;
47
+
48
+ private:
49
+ at::Tensor indices_;
50
+ int64_t index_ = 0;
51
+ };
52
+ } // namespace samplers
53
+ } // namespace data
54
+ } // namespace torch
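A short usage sketch, draining one epoch of shuffled indices:

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  torch::data::samplers::RandomSampler sampler(/*size=*/5);
  // Batches of shuffled indices until the epoch is exhausted: 2 + 2 + 1.
  while (auto indices = sampler.next(/*batch_size=*/2)) {
    for (size_t i : *indices) {
      std::cout << i << ' ';
    }
  }
  std::cout << '\n';  // a permutation of 0..4
}
```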
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/sequential.h ADDED
@@ -0,0 +1,50 @@
1
+ #pragma once
2
+
3
+ #include <torch/csrc/Export.h>
4
+ #include <torch/data/samplers/base.h>
5
+ #include <torch/types.h>
6
+
7
+ #include <cstddef>
8
+ #include <vector>
9
+
10
+ namespace torch {
11
+ namespace serialize {
12
+ class OutputArchive;
13
+ class InputArchive;
14
+ } // namespace serialize
15
+ } // namespace torch
16
+
17
+ namespace torch {
18
+ namespace data {
19
+ namespace samplers {
20
+
21
+ /// A `Sampler` that returns indices sequentially.
22
+ class TORCH_API SequentialSampler : public Sampler<> {
23
+ public:
24
+ /// Creates a `SequentialSampler` that will return indices in the range
25
+ /// `0...size - 1`.
26
+ explicit SequentialSampler(size_t size);
27
+
28
+ /// Resets the `SequentialSampler` to zero.
29
+ void reset(std::optional<size_t> new_size = std::nullopt) override;
30
+
31
+ /// Returns the next batch of indices.
32
+ std::optional<std::vector<size_t>> next(size_t batch_size) override;
33
+
34
+ /// Serializes the `SequentialSampler` to the `archive`.
35
+ void save(serialize::OutputArchive& archive) const override;
36
+
37
+ /// Deserializes the `SequentialSampler` from the `archive`.
38
+ void load(serialize::InputArchive& archive) override;
39
+
40
+ /// Returns the current index of the `SequentialSampler`.
41
+ size_t index() const noexcept;
42
+
43
+ private:
44
+ size_t size_;
45
+ size_t index_{0};
46
+ };
47
+
48
+ } // namespace samplers
49
+ } // namespace data
50
+ } // namespace torch
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/serialize.h ADDED
@@ -0,0 +1,28 @@
1
+ #pragma once
2
+
3
+ #include <torch/data/samplers/base.h>
4
+ #include <torch/serialize/archive.h>
5
+
6
+ namespace torch {
7
+ namespace data {
8
+ namespace samplers {
9
+ /// Serializes a `Sampler` into an `OutputArchive`.
10
+ template <typename BatchRequest>
11
+ serialize::OutputArchive& operator<<(
12
+ serialize::OutputArchive& archive,
13
+ const Sampler<BatchRequest>& sampler) {
14
+ sampler.save(archive);
15
+ return archive;
16
+ }
17
+
18
+ /// Deserializes a `Sampler` from an `InputArchive`.
19
+ template <typename BatchRequest>
20
+ serialize::InputArchive& operator>>(
21
+ serialize::InputArchive& archive,
22
+ Sampler<BatchRequest>& sampler) {
23
+ sampler.load(archive);
24
+ return archive;
25
+ }
26
+ } // namespace samplers
27
+ } // namespace data
28
+ } // namespace torch
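A round-trip sketch using these stream operators (the `sampler.pt` filename is a placeholder): the restored sampler resumes mid-epoch at the saved index.

```cpp
#include <torch/torch.h>

int main() {
  torch::data::samplers::SequentialSampler sampler(/*size=*/100);
  sampler.next(/*batch_size=*/10);  // advance to index 10

  torch::serialize::OutputArchive out;
  out << sampler;                   // calls sampler.save(out)
  out.save_to("sampler.pt");

  torch::data::samplers::SequentialSampler restored(/*size=*/100);
  torch::serialize::InputArchive in;
  in.load_from("sampler.pt");
  in >> restored;                   // calls restored.load(in); resumes at 10
}
```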
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/stream.h ADDED
@@ -0,0 +1,63 @@
1
+ #pragma once
2
+
3
+ #include <torch/csrc/Export.h>
4
+ #include <torch/data/samplers/base.h>
5
+ #include <torch/data/samplers/custom_batch_request.h>
6
+ #include <torch/types.h>
7
+
8
+ #include <cstddef>
9
+
10
+ namespace torch {
11
+ namespace serialize {
12
+ class InputArchive;
13
+ class OutputArchive;
14
+ } // namespace serialize
15
+ } // namespace torch
16
+
17
+ namespace torch {
18
+ namespace data {
19
+ namespace samplers {
20
+
21
+ /// A wrapper around a batch size value, which implements the
22
+ /// `CustomBatchRequest` interface.
23
+ struct TORCH_API BatchSize : public CustomBatchRequest {
24
+ explicit BatchSize(size_t size);
25
+ size_t size() const noexcept override;
26
+ operator size_t() const noexcept;
27
+ size_t size_;
28
+ };
29
+
30
+ /// A sampler for (potentially infinite) streams of data.
31
+ ///
32
+ /// The major feature of the `StreamSampler` is that it does not return
33
+ /// particular indices, but instead only the number of elements to fetch from
34
+ /// the dataset. The dataset has to decide how to produce those elements.
35
+ class TORCH_API StreamSampler : public Sampler<BatchSize> {
36
+ public:
37
+ /// Constructs the `StreamSampler` with the number of individual examples that
38
+ /// should be fetched until the sampler is exhausted.
39
+ explicit StreamSampler(size_t epoch_size);
40
+
41
+ /// Resets the internal state of the sampler.
42
+ void reset(std::optional<size_t> new_size = std::nullopt) override;
43
+
44
+ /// Returns a `BatchSize` object with the number of elements to fetch in the
45
+ /// next batch. This number is the minimum of the supplied `batch_size` and
46
+ /// the difference between the `epoch_size` and the current index. If the
47
+ /// `epoch_size` has been reached, returns an empty optional.
48
+ std::optional<BatchSize> next(size_t batch_size) override;
49
+
50
+ /// Serializes the `StreamSampler` to the `archive`.
51
+ void save(serialize::OutputArchive& archive) const override;
52
+
53
+ /// Deserializes the `StreamSampler` from the `archive`.
54
+ void load(serialize::InputArchive& archive) override;
55
+
56
+ private:
57
+ size_t examples_retrieved_so_far_ = 0;
58
+ size_t epoch_size_;
59
+ };
60
+
61
+ } // namespace samplers
62
+ } // namespace data
63
+ } // namespace torch
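A sketch showing the clamping at the end of an epoch:

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // An epoch of 10 examples, requested in chunks of at most 4.
  torch::data::samplers::StreamSampler sampler(/*epoch_size=*/10);
  while (auto batch_size = sampler.next(4)) {
    std::cout << batch_size->size() << ' ';  // 4 4 2
  }
  std::cout << '\n';
}
```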
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms.h ADDED
@@ -0,0 +1,7 @@
1
+ #pragma once
2
+
3
+ #include <torch/data/transforms/base.h>
4
+ #include <torch/data/transforms/collate.h>
5
+ #include <torch/data/transforms/lambda.h>
6
+ #include <torch/data/transforms/stack.h>
7
+ #include <torch/data/transforms/tensor.h>
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/base.h ADDED
@@ -0,0 +1,53 @@
1
+ #pragma once
2
+
3
+ #include <torch/types.h>
4
+
5
+ #include <utility>
6
+ #include <vector>
7
+
8
+ namespace torch {
9
+ namespace data {
10
+ namespace transforms {
11
+
12
+ /// A transformation of a batch to a new batch.
13
+ template <typename InputBatch, typename OutputBatch>
14
+ class BatchTransform {
15
+ public:
16
+ using InputBatchType = InputBatch;
17
+ using OutputBatchType = OutputBatch;
18
+
19
+ virtual ~BatchTransform() = default;
20
+
21
+ /// Applies the transformation to the given `input_batch`.
22
+ virtual OutputBatch apply_batch(InputBatch input_batch) = 0;
23
+ };
24
+
25
+ /// A transformation of individual input examples to individual output examples.
26
+ ///
27
+ /// Just like a `Dataset` is a `BatchDataset`, a `Transform` is a
28
+ /// `BatchTransform` that can operate on the level of individual examples rather
29
+ /// than entire batches. The batch-level transform is implemented (by default)
30
+ /// in terms of the example-level transform, though this can be customized.
31
+ template <typename Input, typename Output>
32
+ class Transform
33
+ : public BatchTransform<std::vector<Input>, std::vector<Output>> {
34
+ public:
35
+ using InputType = Input;
36
+ using OutputType = Output;
37
+
38
+ /// Applies the transformation to the given `input`.
39
+ virtual OutputType apply(InputType input) = 0;
40
+
41
+ /// Applies the `transformation` over the entire `input_batch`.
42
+ std::vector<Output> apply_batch(std::vector<Input> input_batch) override {
43
+ std::vector<Output> output_batch;
44
+ output_batch.reserve(input_batch.size());
45
+ for (auto&& input : input_batch) {
46
+ output_batch.push_back(apply(std::move(input)));
47
+ }
48
+ return output_batch;
49
+ }
50
+ };
51
+ } // namespace transforms
52
+ } // namespace data
53
+ } // namespace torch
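A sketch of an example-level transform (the `DoubleData` name is invented): only `apply()` is written, and the inherited `apply_batch()` maps it over a whole vector of examples. A dataset's `.map(DoubleData{})` would then double every example's data tensor.

```cpp
#include <torch/torch.h>

struct DoubleData
    : torch::data::transforms::Transform<torch::data::Example<>,
                                         torch::data::Example<>> {
  torch::data::Example<> apply(torch::data::Example<> input) override {
    input.data = input.data * 2;  // leave the target untouched
    return input;
  }
};
```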
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/collate.h ADDED
@@ -0,0 +1,35 @@
1
+ #pragma once
2
+
3
+ #include <torch/data/example.h>
4
+ #include <torch/data/transforms/lambda.h>
5
+
6
+ #include <vector>
7
+
8
+ namespace torch {
9
+ namespace data {
10
+ namespace transforms {
11
+
12
+ /// A `Collation` is a transform that reduces a batch into a single value.
13
+ /// The result is a `BatchDataset` that has the type of the single value as its
14
+ /// `BatchType`.
15
+ template <typename T, typename BatchType = std::vector<T>>
16
+ using Collation = BatchTransform<BatchType, T>;
17
+
18
+ /// A `Collate` allows passing a custom function to reduce/collate a batch
19
+ /// into a single value. It's effectively the lambda version of `Collation`,
20
+ /// which you could subclass and override `operator()` to achieve the same.
21
+ ///
22
+ /// \rst
23
+ /// .. code-block:: cpp
24
+ /// using namespace torch::data;
25
+ ///
26
+ /// auto dataset = datasets::MNIST("path/to/mnist")
27
+ /// .map(transforms::Collate<Example<>>([](std::vector<Example<>> e) {
28
+ /// return std::move(e.front());
29
+ /// }));
30
+ /// \endrst
31
+ template <typename T, typename BatchType = std::vector<T>>
32
+ using Collate = BatchLambda<BatchType, T>;
33
+ } // namespace transforms
34
+ } // namespace data
35
+ } // namespace torch
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/lambda.h ADDED
@@ -0,0 +1,56 @@
1
+ #pragma once
2
+
3
+ #include <torch/data/transforms/base.h>
4
+
5
+ #include <functional>
6
+ #include <utility>
7
+ #include <vector>
8
+
9
+ namespace torch {
10
+ namespace data {
11
+ namespace transforms {
12
+
13
+ /// A `BatchTransform` that applies a user-provided functor to a batch.
14
+ template <typename Input, typename Output = Input>
15
+ class BatchLambda : public BatchTransform<Input, Output> {
16
+ public:
17
+ using typename BatchTransform<Input, Output>::InputBatchType;
18
+ using typename BatchTransform<Input, Output>::OutputBatchType;
19
+ using FunctionType = std::function<OutputBatchType(InputBatchType)>;
20
+
21
+ /// Constructs the `BatchLambda` from the given `function` object.
22
+ explicit BatchLambda(FunctionType function)
23
+ : function_(std::move(function)) {}
24
+
25
+ /// Applies the user-provided function object to the `input_batch`.
26
+ OutputBatchType apply_batch(InputBatchType input_batch) override {
27
+ return function_(std::move(input_batch));
28
+ }
29
+
30
+ private:
31
+ FunctionType function_;
32
+ };
33
+
34
+ /// A `Transform` that applies a user-provided functor to individual examples.
35
+ template <typename Input, typename Output = Input>
36
+ class Lambda : public Transform<Input, Output> {
37
+ public:
38
+ using typename Transform<Input, Output>::InputType;
39
+ using typename Transform<Input, Output>::OutputType;
40
+ using FunctionType = std::function<Output(Input)>;
41
+
42
+ /// Constructs the `Lambda` from the given `function` object.
43
+ explicit Lambda(FunctionType function) : function_(std::move(function)) {}
44
+
45
+ /// Applies the user-provided function object to the `input`.
46
+ OutputType apply(InputType input) override {
47
+ return function_(std::move(input));
48
+ }
49
+
50
+ private:
51
+ FunctionType function_;
52
+ };
53
+
54
+ } // namespace transforms
55
+ } // namespace data
56
+ } // namespace torch
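A sketch of the example-level `Lambda` in a dataset pipeline (the `./mnist` path is a placeholder):

```cpp
#include <torch/torch.h>

int main() {
  using torch::data::Example;
  // Example-level lambda: negate each image tensor.
  auto dataset = torch::data::datasets::MNIST("./mnist")
                     .map(torch::data::transforms::Lambda<Example<>>(
                         [](Example<> e) {
                           e.data = -e.data;
                           return e;
                         }));
  auto first = dataset.get_batch({0});  // vector with one negated example
}
```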
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/stack.h ADDED
@@ -0,0 +1,49 @@
1
+ #pragma once
2
+
3
+ #include <torch/data/example.h>
4
+ #include <torch/data/transforms/collate.h>
5
+ #include <torch/types.h>
6
+
7
+ #include <utility>
8
+ #include <vector>
9
+
10
+ namespace torch {
11
+ namespace data {
12
+ namespace transforms {
13
+
14
+ template <typename T = Example<>>
15
+ struct Stack;
16
+
17
+ /// A `Collation` for `Example<Tensor, Tensor>` types that stacks all data
18
+ /// tensors into one tensor, and all target (label) tensors into one tensor.
19
+ template <>
20
+ struct Stack<Example<>> : public Collation<Example<>> {
21
+ Example<> apply_batch(std::vector<Example<>> examples) override {
22
+ std::vector<torch::Tensor> data, targets;
23
+ data.reserve(examples.size());
24
+ targets.reserve(examples.size());
25
+ for (auto& example : examples) {
26
+ data.push_back(std::move(example.data));
27
+ targets.push_back(std::move(example.target));
28
+ }
29
+ return {torch::stack(data), torch::stack(targets)};
30
+ }
31
+ };
32
+
33
+ /// A `Collation` for `Example<Tensor, NoTarget>` types that stacks all data
34
+ /// tensors into one tensor.
35
+ template <>
36
+ struct Stack<TensorExample>
37
+ : public Collation<Example<Tensor, example::NoTarget>> {
38
+ TensorExample apply_batch(std::vector<TensorExample> examples) override {
39
+ std::vector<torch::Tensor> data;
40
+ data.reserve(examples.size());
41
+ for (auto& example : examples) {
42
+ data.push_back(std::move(example.data));
43
+ }
44
+ return torch::stack(data);
45
+ }
46
+ };
47
+ } // namespace transforms
48
+ } // namespace data
49
+ } // namespace torch
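A sketch of `Stack` collating a vector of examples into one batched example (the `./mnist` path is a placeholder):

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  auto dataset = torch::data::datasets::MNIST("./mnist")
                     .map(torch::data::transforms::Stack<>());
  auto batch = dataset.get_batch({0, 1, 2, 3});
  std::cout << batch.data.sizes() << ' '     // [4, 1, 28, 28]
            << batch.target.sizes() << '\n'; // [4]
}
```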
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/tensor.h ADDED
@@ -0,0 +1,77 @@
1
+ #pragma once
2
+
3
+ #include <torch/data/example.h>
4
+ #include <torch/data/transforms/base.h>
5
+ #include <torch/types.h>
6
+
7
+ #include <functional>
8
+ #include <utility>
9
+
10
+ namespace torch {
11
+ namespace data {
12
+ namespace transforms {
13
+
14
+ /// A `Transform` that is specialized for the typical `Example<Tensor, Tensor>`
15
+ /// combination. It exposes a single `operator()` interface hook (for
16
+ /// subclasses), and calls this function on input `Example` objects.
17
+ template <typename Target = Tensor>
18
+ class TensorTransform
19
+ : public Transform<Example<Tensor, Target>, Example<Tensor, Target>> {
20
+ public:
21
+ using E = Example<Tensor, Target>;
22
+ using typename Transform<E, E>::InputType;
23
+ using typename Transform<E, E>::OutputType;
24
+
25
+ /// Transforms a single input tensor to an output tensor.
26
+ virtual Tensor operator()(Tensor input) = 0;
27
+
28
+ /// Implementation of `Transform::apply` that calls `operator()`.
29
+ OutputType apply(InputType input) override {
30
+ input.data = (*this)(std::move(input.data));
31
+ return input;
32
+ }
33
+ };
34
+
35
+ /// A `Lambda` specialized for the typical `Example<Tensor, Tensor>` input type.
36
+ template <typename Target = Tensor>
37
+ class TensorLambda : public TensorTransform<Target> {
38
+ public:
39
+ using FunctionType = std::function<Tensor(Tensor)>;
40
+
41
+ /// Creates a `TensorLambda` from the given `function`.
42
+ explicit TensorLambda(FunctionType function)
43
+ : function_(std::move(function)) {}
44
+
45
+ /// Applies the user-provided functor to the input tensor.
46
+ Tensor operator()(Tensor input) override {
47
+ return function_(std::move(input));
48
+ }
49
+
50
+ private:
51
+ FunctionType function_;
52
+ };
53
+
54
+ /// Normalizes input tensors by subtracting the supplied mean and dividing by
55
+ /// the given standard deviation.
56
+ template <typename Target = Tensor>
57
+ struct Normalize : public TensorTransform<Target> {
58
+ /// Constructs a `Normalize` transform. The mean and standard deviation can be
59
+ /// anything that is broadcastable over the input tensors (like single
60
+ /// scalars).
61
+ Normalize(ArrayRef<double> mean, ArrayRef<double> stddev)
62
+ : mean(torch::tensor(mean, torch::kFloat32)
63
+ .unsqueeze(/*dim=*/1)
64
+ .unsqueeze(/*dim=*/2)),
65
+ stddev(torch::tensor(stddev, torch::kFloat32)
66
+ .unsqueeze(/*dim=*/1)
67
+ .unsqueeze(/*dim=*/2)) {}
68
+
69
+ torch::Tensor operator()(Tensor input) override {
70
+ return input.sub(mean).div(stddev);
71
+ }
72
+
73
+ torch::Tensor mean, stddev;
74
+ };
75
+ } // namespace transforms
76
+ } // namespace data
77
+ } // namespace torch
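A sketch chaining `Normalize` before `Stack` (the `./mnist` path is a placeholder; 0.1307/0.3081 are the commonly used MNIST statistics, and the scalars broadcast over each [1, 28, 28] image):

```cpp
#include <torch/torch.h>

int main() {
  auto dataset =
      torch::data::datasets::MNIST("./mnist")
          .map(torch::data::transforms::Normalize<>(/*mean=*/0.1307,
                                                    /*stddev=*/0.3081))
          .map(torch::data::transforms::Stack<>());
  auto batch = dataset.get_batch({0, 1});  // normalized, stacked Example<>
}
```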
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/worker_exception.h ADDED
@@ -0,0 +1,38 @@
1
+ #pragma once
2
+
3
+ #include <exception>
4
+ #include <string>
5
+ #include <utility>
6
+
7
+ namespace torch {
8
+ namespace data {
9
+
10
+ /// An exception thrown when a DataLoader's worker thread throws an exception
11
+ /// that is caught there. A `WorkerException` stores an `exception_ptr` to the
12
+ /// original exception thrown in the worker thread.
13
+ struct WorkerException : public std::exception {
14
+ /// Constructs a `WorkerException` from an `exception_ptr`.
15
+ explicit WorkerException(std::exception_ptr original)
16
+ : original_exception(std::move(original)),
17
+ message("Caught exception in DataLoader worker thread.") {
18
+ try {
19
+ std::rethrow_exception(original_exception);
20
+ } catch (std::exception& e) {
21
+ message += " Original message: ";
22
+ message += e.what();
23
+ }
24
+ }
25
+
26
+ const char* what() const noexcept override {
27
+ return message.c_str();
28
+ }
29
+
30
+ /// The original exception thrown in the worker thread.
31
+ std::exception_ptr original_exception;
32
+
33
+ /// This exception's message (not the original exception's message).
34
+ std::string message;
35
+ };
36
+
37
+ } // namespace data
38
+ } // namespace torch
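For context on how this surfaces to user code: when a worker thread throws, the next call into the DataLoader rethrows the `WorkerException` on the main thread. A minimal handling sketch, assuming a `TensorDataset`-backed loader with two workers (both choices are illustrative):

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      auto dataset = torch::data::datasets::TensorDataset(torch::ones({8, 2}));
      auto loader = torch::data::make_data_loader(
          std::move(dataset), torch::data::DataLoaderOptions().workers(2));
      try {
        for (auto& batch : *loader) {
          (void)batch;  // consume the batch here
        }
      } catch (const torch::data::WorkerException& e) {
        std::cerr << e.what() << '\n';  // includes the original message
        std::rethrow_exception(e.original_exception);  // or inspect it instead
      }
      return 0;
    }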
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/detail/TensorDataContainer.h ADDED
@@ -0,0 +1,363 @@
1
+ #pragma once
2
+
3
+ #include <ATen/Dispatch.h>
4
+ #include <ATen/ScalarOps.h>
5
+ #include <ATen/core/Tensor.h>
6
+ #include <ATen/core/grad_mode.h>
7
+
8
+ #include <c10/util/irange.h>
9
+
10
+ #ifndef AT_PER_OPERATOR_HEADERS
11
+ #include <ATen/Functions.h>
12
+ #else
13
+ #include <ATen/ops/empty.h>
14
+ #include <ATen/ops/tensor.h>
15
+ #endif
16
+
17
+ #include <initializer_list>
18
+
19
+ namespace torch {
20
+
21
+ namespace detail {
22
+
23
+ enum class TensorDataContainerType { Scalar, InitList, Tensor };
24
+
25
+ struct TensorDataContainer;
26
+
27
+ inline std::ostream& operator<<(
28
+ std::ostream& stream,
29
+ const TensorDataContainer& tensor_data_container);
30
+
31
+ inline c10::ScalarType compute_desired_dtype(c10::ScalarType scalar_type) {
32
+ if (scalar_type == at::kInt || scalar_type == at::kLong) {
33
+ // C++ `torch::tensor` with an integer type or an `at::ArrayRef` /
34
+ // `std::vector` / (nested) braced-init-list of integer types always
35
+ // produces a tensor of dtype `at::kLong` (aka. int64_t), matching Python
36
+ // `torch.tensor` behavior.
37
+ return at::kLong;
38
+ } else if (scalar_type == at::kFloat || scalar_type == at::kDouble) {
39
+ // C++ `torch::tensor` with a floating-point type or an `at::ArrayRef` /
40
+ // `std::vector` / (nested) braced-init-list of floating-point types always
41
+ // produces a tensor of dtype `torch::get_default_dtype()`, matching Python
42
+ // `torch.tensor` behavior.
43
+ return at::typeMetaToScalarType(at::get_default_dtype());
44
+ } else {
45
+ return scalar_type;
46
+ }
47
+ }
48
+
49
+ // We use `TensorDataContainer` to support converting the following data
50
+ // container types into the equivalent Tensor:
51
+ //
52
+ // 1. Arbitrarily nested braced-init-list (e.g. `{{1, 2}, {3, 4}}`).
53
+ // 2. `at::ArrayRef` of supported tensor data types.
54
+ // 3. `std::vector` of supported tensor data types.
55
+ //
56
+ // At any time, a `TensorDataContainer` object represents one of the following:
57
+ //
58
+ // 1. A scalar with value `scalar()` and type `scalar_type()`.
59
+ // 2. A Tensor represented in `std::initializer_list<TensorDataContainer>` form,
60
+ // with value `init_list()`, Tensor scalar type `scalar_type()`, and Tensor
61
+ // sizes `sizes()`.
62
+ // 3. A Tensor represented in `at::Tensor` form, with value `tensor()`, scalar
63
+ // type `scalar_type()`,
64
+ // and Tensor sizes `sizes()`.
65
+ //
66
+ // All the infrastructure here is mostly to support converting an arbitrarily
67
+ // nested braced-init-list to the equivalent Tensor successfully. Consider the
68
+ // following example:
69
+ //
70
+ // `torch::tensor({{1}, {2}})`
71
+ //
72
+ // this will call into the `torch::tensor` function:
73
+ //
74
+ // `at::Tensor tensor(detail::TensorDataContainer tensor_data_container, const
75
+ // at::TensorOptions& options = {})`
76
+ //
77
+ // the compiler will first try to convert `{{1}, {2}}` to `TensorDataContainer`
78
+ // type:
79
+ //
80
+ // `TensorDataContainer({{1}, {2}})`
81
+ //
82
+ // which matches to the
83
+ // `TensorDataContainer(std::initializer_list<TensorDataContainer>)`
84
+ // constructor, and in an attempt to convert `{1}` and `{2}` to
85
+ // `TensorDataContainer`, it calls the following:
86
+ //
87
+ // `TensorDataContainer({1})` (same call path happens for `{2}`, and we'll just
88
+ // focus on `{1}` here)
89
+ //
90
+ // At this point, theoretically there are two plausible ways for `{1}` to be
91
+ // matched to one of the constructors of `TensorDataContainer`:
92
+ //
93
+ // 1. It can be a list-initialization of a scalar value, thus matching
94
+ // `TensorDataContainer(int value)`.
95
+ // 2. It can be converted to `std::initializer_list<TensorDataContainer>`, thus
96
+ // matching
97
+ // `TensorDataContainer(std::initializer_list<TensorDataContainer>)`.
98
+ //
99
+ // How does the compiler decide which one to choose? According to
100
+ // `https://en.cppreference.com/w/cpp/language/list_initialization`,
101
+ // braced-init-list always prefers the constructor that takes
102
+ // `std::initializer_list`. Hence we happily move forward with constructor #2,
103
+ // and it calls the following:
104
+ //
105
+ // `TensorDataContainer(1)`
106
+ //
107
+ // Now it matches `TensorDataContainer(int value)`, which stores `1` as a scalar
108
+ // value. All is good.
109
+ struct TensorDataContainer {
110
+ // NOTE: For tensors with zero-size dimensions (e.g. `torch::tensor({{},
111
+ // {}})`), the innermost empty braced-init-list `{}` matches the default
112
+ // constructor of the innermost `TensorDataContainer`.
113
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
114
+ TensorDataContainer()
115
+ : sizes_({0}),
116
+ // NOTE: In Python, the dtype of tensors with zero-size dimensions (e.g.
117
+ // `torch.tensor([[], []])`) depends on the value of
118
+ // `torch.get_default_dtype()`, and we should do the same for the C++
119
+ // equivalent.
120
+ scalar_type_(at::typeMetaToScalarType(at::get_default_dtype())),
121
+ type_(TensorDataContainerType::InitList) {}
122
+ #define TENSOR(T, S) \
123
+ TensorDataContainer(T value) \
124
+ : sizes_(), \
125
+ scalar_type_(at::k##S), \
126
+ type_(TensorDataContainerType::Scalar), \
127
+ scalar_(value) {}
128
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
129
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR)
130
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
131
+ AT_FORALL_COMPLEX_TYPES(TENSOR)
132
+ #undef TENSOR
133
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
134
+ TensorDataContainer(std::initializer_list<TensorDataContainer> init_list)
135
+ : sizes_(),
136
+ scalar_type_(init_list.begin()->scalar_type()),
137
+ type_(TensorDataContainerType::InitList),
138
+ init_list_(init_list) {
139
+ const TensorDataContainer& first_elem = *(init_list.begin());
140
+ for (const auto& elem : init_list) {
141
+ TORCH_CHECK(
142
+ elem.sizes() == first_elem.sizes(),
143
+ "Expected all sub-lists to have sizes: ",
144
+ first_elem.sizes(),
145
+ " (e.g. ",
146
+ first_elem,
147
+ "), ",
148
+ "but got sub-list ",
149
+ elem,
150
+ " with sizes: ",
151
+ elem.sizes());
152
+ TORCH_CHECK(
153
+ elem.scalar_type() == first_elem.scalar_type(),
154
+ "Expected all elements of the tensor to have the same scalar type: ",
155
+ first_elem.scalar_type(),
156
+ ", but got element of scalar type: ",
157
+ elem.scalar_type());
158
+ }
159
+ sizes_.reserve(first_elem.sizes().size() + 1);
160
+ sizes_.push_back(init_list.size());
161
+ sizes_.insert(
162
+ sizes_.end(), first_elem.sizes().begin(), first_elem.sizes().end());
163
+ }
164
+
165
+ #define TENSOR(T, S) \
166
+ TensorDataContainer(at::ArrayRef<T> values) \
167
+ : sizes_({(int64_t)values.size()}), \
168
+ scalar_type_(at::k##S), \
169
+ type_(TensorDataContainerType::Tensor) { \
170
+ at::AutoDispatchBelowAutograd mode; \
171
+ if (scalar_type_ == at::kBool) { \
172
+ tensor_ = at::tensor(values, at::TensorOptions().device(at::kCPU)); \
173
+ } else { \
174
+ tensor_ = at::tensor(values, at::dtype(scalar_type_).device(at::kCPU)); \
175
+ } \
176
+ }
177
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
178
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR)
179
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
180
+ AT_FORALL_COMPLEX_TYPES(TENSOR)
181
+ #undef TENSOR
182
+
183
+ // NOTE: We need to handle `std::vector` explicitly instead of relying on an
184
+ // implicit conversion to `at::ArrayRef`; otherwise the compiler emits an error
185
+ // like the following when calling `torch::tensor(std::vector<int>({1, 2}))`:
186
+ // ```
187
+ // error: no matching function for call to 'tensor(const std::vector<int>&)'
188
+ // no known conversion for argument 1 from 'const std::vector<int>' to
189
+ // 'torch::detail::TensorDataContainer'
190
+ // ```
191
+ //
192
+ // NOTE: `torch::tensor(std::vector<bool>)` is not supported for now, because
193
+ // ArrayRef<bool> cannot be constructed from a std::vector<bool> bitfield.
194
+ #define TENSOR(T, S) \
195
+ TensorDataContainer(const std::vector<T>& values) \
196
+ : TensorDataContainer(at::ArrayRef<T>(values)) {}
197
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
198
+ AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, TENSOR)
199
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
200
+ AT_FORALL_COMPLEX_TYPES(TENSOR)
201
+ #undef TENSOR
202
+
203
+ bool is_scalar() const {
204
+ return type_ == TensorDataContainerType::Scalar;
205
+ }
206
+
207
+ const c10::Scalar& scalar() const {
208
+ TORCH_CHECK(
209
+ is_scalar(),
210
+ "Can only call `scalar()` on a TensorDataContainer that has `is_scalar() == true`");
211
+ return scalar_;
212
+ }
213
+
214
+ bool is_init_list() const {
215
+ return type_ == TensorDataContainerType::InitList;
216
+ }
217
+
218
+ const std::initializer_list<TensorDataContainer>& init_list() const {
219
+ TORCH_CHECK(
220
+ is_init_list(),
221
+ "Can only call `init_list()` on a TensorDataContainer that has `is_init_list() == true`");
222
+ return init_list_;
223
+ }
224
+
225
+ bool is_tensor() const {
226
+ return type_ == TensorDataContainerType::Tensor;
227
+ }
228
+
229
+ const at::Tensor& tensor() const {
230
+ TORCH_CHECK(
231
+ is_tensor(),
232
+ "Can only call `tensor()` on a TensorDataContainer that has `is_tensor() == true`");
233
+ return tensor_;
234
+ }
235
+
236
+ const std::vector<int64_t>& sizes() const {
237
+ return sizes_;
238
+ }
239
+
240
+ const c10::ScalarType& scalar_type() const {
241
+ return scalar_type_;
242
+ }
243
+
244
+ at::Tensor convert_to_tensor(at::TensorOptions options) const {
245
+ if (!options.has_dtype()) {
246
+ options = options.dtype(compute_desired_dtype(scalar_type_));
247
+ }
248
+
249
+ if (is_scalar()) {
250
+ at::AutoDispatchBelowAutograd mode;
251
+ return at::scalar_tensor(scalar_, options);
252
+ } else if (is_init_list()) {
253
+ // NOTE: Here we explicitly choose to initialize the tensor on CPU first,
254
+ // fill each element of the tensor, and then move the tensor to the
255
+ // desired device. For CUDA device, this approach only involves 1 CUDA
256
+ // kernel launch, and is much faster than initializing the tensor on CUDA
257
+ // first and then filling each element of it (which involves `N` CUDA
258
+ // kernel launches where `N` is the number of the elements in the tensor).
259
+ at::Tensor tensor = ([&]() {
260
+ at::AutoDispatchBelowAutograd mode;
261
+ return at::empty(sizes_, options.device(at::kCPU));
262
+ })();
263
+ fill_tensor(tensor);
264
+ return tensor.to(options.device());
265
+ } else if (is_tensor()) {
266
+ auto output = tensor_.to(options);
267
+ TORCH_CHECK(
268
+ !tensor_.is_complex() || output.is_complex(),
269
+ "can not do torch::tensor(complex, dtype=non-complex) because complex can not be casted to real number without loss of information");
270
+ return output;
271
+ } else {
272
+ TORCH_INTERNAL_ASSERT(false, "Invalid TensorDataContainer type");
273
+ }
274
+ }
275
+
276
+ void pretty_print_recursive(std::ostream& stream) const {
277
+ if (is_scalar()) {
278
+ AT_DISPATCH_ALL_TYPES_AND3(
279
+ at::kBool,
280
+ at::kHalf,
281
+ at::kBFloat16,
282
+ scalar_type_,
283
+ "TensorDataContainer_pretty_print_scalar",
284
+ [&] { stream << scalar_.to<scalar_t>(); });
285
+ } else if (is_init_list()) {
286
+ stream << "{";
287
+ for (const TensorDataContainer* it = init_list_.begin();
288
+ it != init_list_.end();
289
+ it++) {
290
+ stream << *it;
291
+ if (std::next(it) != init_list_.end())
292
+ stream << ", ";
293
+ }
294
+ stream << "}";
295
+ } else if (is_tensor()) {
296
+ stream << "{";
297
+ for (const auto i : c10::irange(tensor_.sizes()[0])) {
298
+ AT_DISPATCH_ALL_TYPES_AND3(
299
+ at::kBool,
300
+ at::kHalf,
301
+ at::kBFloat16,
302
+ scalar_type_,
303
+ "TensorDataContainer_pretty_print_tensor_item",
304
+ [&] { stream << tensor_[i].item<scalar_t>(); });
305
+ if (i != tensor_.sizes()[0] - 1)
306
+ stream << ", ";
307
+ }
308
+ stream << "}";
309
+ } else {
310
+ TORCH_INTERNAL_ASSERT(false, "Invalid TensorDataContainer type");
311
+ }
312
+ }
313
+
314
+ private:
315
+ void fill_tensor(at::Tensor& tensor) const {
316
+ if (is_scalar()) {
317
+ TORCH_INTERNAL_ASSERT(
318
+ tensor.dim() == 0,
319
+ "Expected a 0-dim Tensor, but got Tensor with dimensions: ",
320
+ tensor.dim());
321
+ at::NoGradGuard guard;
322
+ tensor.fill_(scalar_);
323
+ } else if (is_init_list()) {
324
+ TORCH_INTERNAL_ASSERT(
325
+ tensor.sizes()[0] == (int64_t)init_list_.size(),
326
+ "Expected a Tensor with size ",
327
+ init_list_.size(),
328
+ " in its first dimension, but got Tensor with size ",
329
+ tensor.sizes()[0],
330
+ " in its first dimension");
331
+ size_t index = 0;
332
+ for (const auto& elem : init_list_) {
333
+ at::Tensor slice = tensor[index];
334
+ elem.fill_tensor(slice);
335
+ index++;
336
+ }
337
+ } else if (is_tensor()) {
338
+ TORCH_INTERNAL_ASSERT(
339
+ false,
340
+ "TensorDataContainer is already a Tensor type, `fill_tensor` should not be called");
341
+ } else {
342
+ TORCH_INTERNAL_ASSERT(false, "Invalid TensorDataContainer type");
343
+ }
344
+ }
345
+
346
+ std::vector<int64_t> sizes_;
347
+ c10::ScalarType scalar_type_;
348
+ TensorDataContainerType type_;
349
+ c10::Scalar scalar_;
350
+ std::initializer_list<TensorDataContainer> init_list_;
351
+ at::Tensor tensor_;
352
+ };
353
+
354
+ inline std::ostream& operator<<(
355
+ std::ostream& stream,
356
+ const TensorDataContainer& tensor_data_container) {
357
+ tensor_data_container.pretty_print_recursive(stream);
358
+ return stream;
359
+ }
360
+
361
+ } // namespace detail
362
+
363
+ } // namespace torch
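The dtype rules encoded in `compute_desired_dtype` and the braced-init-list handling above are easiest to see with concrete calls; the results noted in the comments follow directly from those rules:

    #include <torch/torch.h>

    int main() {
      // Integer literals always produce kLong, matching Python's torch.tensor.
      auto a = torch::tensor({{1, 2}, {3, 4}});         // dtype kLong, sizes [2, 2]
      // Floating-point literals use the default dtype (kFloat unless changed).
      auto b = torch::tensor({1.0, 2.0});               // dtype kFloat, sizes [2]
      // An explicit dtype in the options overrides the deduced one.
      auto c = torch::tensor({1, 2}, torch::kFloat64);  // dtype kDouble
      // Empty inner braced-init-lists yield zero-size dimensions.
      auto d = torch::tensor({{}, {}});                 // sizes [2, 0]
      return 0;
    }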
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/detail/static.h ADDED
@@ -0,0 +1,65 @@
1
+ #pragma once
2
+
3
+ #include <torch/csrc/utils/variadic.h>
4
+ #include <torch/types.h>
5
+
6
+ #include <cstdint>
7
+ #include <type_traits>
8
+
9
+ namespace torch {
10
+ namespace nn {
11
+ class Module;
12
+ } // namespace nn
13
+ } // namespace torch
14
+
15
+ namespace torch {
16
+ namespace detail {
17
+ /// Detects if a type T has a forward() method.
18
+ template <typename T>
19
+ struct has_forward {
20
+ // Declare two types with differing size.
21
+ using yes = int8_t;
22
+ using no = int16_t;
23
+
24
+ // Here we declare two functions. The first is only enabled if `&U::forward`
25
+ // is well-formed and returns the `yes` type. In C++, the ellipsis parameter
26
+ // type (`...`) always puts the function at the bottom of overload resolution.
27
+ // This is specified in the standard as: 1) A standard conversion sequence is
28
+ // always better than a user-defined conversion sequence or an ellipsis
29
+ // conversion sequence. 2) A user-defined conversion sequence is always better
30
+ // than an ellipsis conversion sequence. This means that if the first overload
31
+ // is viable, it will be preferred over the second as long as we pass any
32
+ // convertible type. The type of `&U::forward` is a pointer type, so we can
33
+ // pass e.g. 0.
34
+ template <typename U>
35
+ static yes test(decltype(&U::forward));
36
+ template <typename U>
37
+ static no test(...);
38
+
39
+ // Finally we test statically whether the size of the type returned by the
40
+ // selected overload is the size of the `yes` type.
41
+ static constexpr bool value = (sizeof(test<T>(nullptr)) == sizeof(yes));
42
+ };
43
+
44
+ template <typename Head = void, typename... Tail>
45
+ constexpr bool check_not_lvalue_references() {
46
+ return (!std::is_lvalue_reference<Head>::value ||
47
+ std::is_const<typename std::remove_reference<Head>::type>::value) &&
48
+ check_not_lvalue_references<Tail...>();
49
+ }
50
+
51
+ template <>
52
+ inline constexpr bool check_not_lvalue_references<void>() {
53
+ return true;
54
+ }
55
+
56
+ /// A type trait whose `value` member is true if `M` derives from `Module`.
57
+ template <typename M>
58
+ using is_module =
59
+ std::is_base_of<torch::nn::Module, typename std::decay<M>::type>;
60
+
61
+ template <typename M, typename T = void>
62
+ using enable_if_module_t =
63
+ typename std::enable_if<is_module<M>::value, T>::type;
64
+ } // namespace detail
65
+ } // namespace torch
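A small compile-time sketch of `has_forward` in action; the two module types here are hypothetical, defined only to exercise the trait:

    #include <torch/torch.h>

    // Hypothetical modules used only to exercise the trait.
    struct WithForward : torch::nn::Module {
      torch::Tensor forward(torch::Tensor x) { return x; }
    };
    struct WithoutForward : torch::nn::Module {};

    // &WithForward::forward is well-formed, so the first test overload wins.
    static_assert(torch::detail::has_forward<WithForward>::value,
                  "WithForward exposes a usable forward()");
    // &WithoutForward::forward is ill-formed, so the ellipsis overload wins.
    static_assert(!torch::detail::has_forward<WithoutForward>::value,
                  "WithoutForward has no forward()");

    int main() { return 0; }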
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/any.h ADDED
@@ -0,0 +1,372 @@
1
+ #pragma once
2
+
3
+ #include <torch/detail/static.h>
4
+ #include <torch/nn/module.h>
5
+ #include <torch/nn/modules/container/any_module_holder.h>
6
+ #include <torch/nn/modules/container/any_value.h>
7
+ #include <torch/nn/pimpl.h>
8
+ #include <torch/types.h>
9
+
10
+ #include <torch/csrc/autograd/variable.h>
11
+ #include <torch/csrc/utils/variadic.h>
12
+
13
+ #include <ATen/Device.h>
14
+
15
+ #include <memory>
16
+ #include <type_traits>
17
+ #include <typeinfo>
18
+ #include <utility>
19
+ #include <vector>
20
+
21
+ namespace torch {
22
+ namespace nn {
23
+
24
+ /// Stores a type erased `Module`.
25
+ ///
26
+ /// The PyTorch C++ API does not impose an interface on the signature of
27
+ /// `forward()` in `Module` subclasses. This gives you complete freedom to
28
+ /// design your `forward()` methods to your liking. However, this also means
29
+ /// there is no unified base type you could store in order to call `forward()`
30
+ /// polymorphically for any module. This is where the `AnyModule` comes in.
31
+ /// Instead of inheritance, it relies on type erasure for polymorphism.
32
+ ///
33
+ /// An `AnyModule` can store any `nn::Module` subclass that provides a
34
+ /// `forward()` method. This `forward()` may accept any types and return any
35
+ /// type. Once stored in an `AnyModule`, you can invoke the underlying module's
36
+ /// `forward()` by calling `AnyModule::forward()` with the arguments you would
37
+ /// supply to the stored module (though see one important limitation below).
38
+ /// Example:
39
+ ///
40
+ /// \rst
41
+ /// .. code-block:: cpp
42
+ ///
43
+ /// struct GenericTrainer {
44
+ /// torch::nn::AnyModule module;
45
+ ///
46
+ /// void train(torch::Tensor input) {
47
+ /// module.forward(input);
48
+ /// }
49
+ /// };
50
+ ///
51
+ /// GenericTrainer trainer1{torch::nn::Linear(3, 4)};
52
+ /// GenericTrainer trainer2{torch::nn::Conv2d(3, 4, 2)};
53
+ /// \endrst
54
+ ///
55
+ /// As `AnyModule` erases the static type of the stored module (and its
56
+ /// `forward()` method) to achieve polymorphism, type checking of arguments is
57
+ /// moved to runtime. That is, passing an argument with an incorrect type to an
58
+ /// `AnyModule` will compile, but throw an exception at runtime:
59
+ ///
60
+ /// \rst
61
+ /// .. code-block:: cpp
62
+ ///
63
+ /// torch::nn::AnyModule module(torch::nn::Linear(3, 4));
64
+ /// // Linear takes a tensor as input, but we are passing an integer.
65
+ /// // This will compile, but throw a `torch::Error` exception at runtime.
66
+ /// module.forward(123);
67
+ /// \endrst
68
+ ///
69
+ /// \rst
70
+ /// .. attention::
71
+ /// One noteworthy limitation of `AnyModule` is that its `forward()` method
72
+ /// does not support implicit conversion of argument types. For example, if
73
+ /// the stored module's `forward()` method accepts a `float` and you call
74
+ /// `any_module.forward(3.4)` (where `3.4` is a `double`), this will throw
75
+ /// an exception.
76
+ /// \endrst
77
+ ///
78
+ /// The return type of the `AnyModule`'s `forward()` method is controlled via
79
+ /// the first template argument to `AnyModule::forward()`. It defaults to
80
+ /// `torch::Tensor`. To change it, you can write `any_module.forward<int>()`,
81
+ /// for example.
82
+ ///
83
+ /// \rst
84
+ /// .. code-block:: cpp
85
+ ///
86
+ /// torch::nn::AnyModule module(torch::nn::Linear(3, 4));
87
+ /// auto output = module.forward(torch::ones({2, 3}));
88
+ ///
89
+ /// struct IntModule {
90
+ /// int forward(int x) { return x; }
91
+ /// };
92
+ /// torch::nn::AnyModule module(IntModule{});
93
+ /// int output = module.forward<int>(5);
94
+ /// \endrst
95
+ ///
96
+ /// The only other method an `AnyModule` provides access to on the stored
97
+ /// module is `clone()`. However, you may acquire a handle on the module via
98
+ /// `.ptr()`, which returns a `shared_ptr<nn::Module>`. Further, if you know
99
+ /// the concrete type of the stored module, you can get a concrete handle to it
100
+ /// using `.get<T>()` where `T` is the concrete module type.
101
+ ///
102
+ /// \rst
103
+ /// .. code-block:: cpp
104
+ ///
105
+ /// torch::nn::AnyModule module(torch::nn::Linear(3, 4));
106
+ /// std::shared_ptr<nn::Module> ptr = module.ptr();
107
+ /// torch::nn::Linear linear(module.get<torch::nn::Linear>());
108
+ /// \endrst
109
+ class AnyModule {
110
+ public:
111
+ /// A default-constructed `AnyModule` is in an empty state.
112
+ AnyModule() = default;
113
+
114
+ /// Constructs an `AnyModule` from a `shared_ptr` to concrete module object.
115
+ template <typename ModuleType>
116
+ explicit AnyModule(std::shared_ptr<ModuleType> module);
117
+
118
+ /// Constructs an `AnyModule` from a concrete module object.
119
+ template <
120
+ typename ModuleType,
121
+ typename = torch::detail::enable_if_module_t<ModuleType>>
122
+ explicit AnyModule(ModuleType&& module);
123
+
124
+ /// Constructs an `AnyModule` from a module holder.
125
+ template <typename ModuleType>
126
+ explicit AnyModule(const ModuleHolder<ModuleType>& module_holder);
127
+
128
+ /// Move construction and assignment is allowed, and follows the default
129
+ /// behavior of move for `std::unique_ptr`.
130
+ AnyModule(AnyModule&&) = default;
131
+ AnyModule& operator=(AnyModule&&) = default;
132
+
133
+ /// Creates a shallow copy of an `AnyModule`.
134
+ AnyModule(const AnyModule& other);
135
+ AnyModule& operator=(const AnyModule& other);
136
+
137
+ /// Creates a deep copy of an `AnyModule` if it contains a module, else an
138
+ /// empty `AnyModule` if it is empty.
139
+ AnyModule clone(std::optional<Device> device = std::nullopt) const;
140
+
141
+ /// Assigns a module to the `AnyModule` (to circumvent the explicit
142
+ /// constructor).
143
+ template <typename ModuleType>
144
+ AnyModule& operator=(std::shared_ptr<ModuleType> module);
145
+
146
+ /// Invokes `forward()` on the contained module with the given arguments, and
147
+ /// returns the return value as an `AnyValue`. Use this method when chaining
148
+ /// `AnyModule`s in a loop.
149
+ template <typename... ArgumentTypes>
150
+ AnyValue any_forward(ArgumentTypes&&... arguments);
151
+
152
+ /// Invokes `forward()` on the contained module with the given arguments, and
153
+ /// casts the returned `AnyValue` to the supplied `ReturnType` (which defaults
154
+ /// to `torch::Tensor`).
155
+ template <typename ReturnType = torch::Tensor, typename... ArgumentTypes>
156
+ ReturnType forward(ArgumentTypes&&... arguments);
157
+
158
+ /// Attempts to cast the underlying module to the given module type. Throws an
159
+ /// exception if the types do not match.
160
+ template <typename T, typename = torch::detail::enable_if_module_t<T>>
161
+ T& get();
162
+
163
+ /// Attempts to cast the underlying module to the given module type. Throws an
164
+ /// exception if the types do not match.
165
+ template <typename T, typename = torch::detail::enable_if_module_t<T>>
166
+ const T& get() const;
167
+
168
+ /// Returns the contained module in a `nn::ModuleHolder` subclass if possible
169
+ /// (i.e. if `T` has a constructor for the underlying module type).
170
+ template <typename T, typename ContainedType = typename T::ContainedType>
171
+ T get() const;
172
+
173
+ /// Returns a `std::shared_ptr` whose dynamic type is that of the underlying
174
+ /// module.
175
+ std::shared_ptr<Module> ptr() const;
176
+
177
+ /// Like `ptr()`, but casts the pointer to the given type.
178
+ template <typename T, typename = torch::detail::enable_if_module_t<T>>
179
+ std::shared_ptr<T> ptr() const;
180
+
181
+ /// Returns the `type_info` object of the contained value.
182
+ const std::type_info& type_info() const;
183
+
184
+ /// Returns true if the `AnyModule` does not contain a module.
185
+ bool is_empty() const noexcept;
186
+
187
+ private:
188
+ /// Creates a `unique_ptr<AnyModulePlaceholder>` pointing to a
189
+ /// `AnyModuleHolder` of the correct type. This method is used to deduce the
190
+ /// arguments of the module's `forward()` method.
191
+ template <
192
+ typename ModuleType,
193
+ typename Class,
194
+ typename ReturnType,
195
+ typename... ArgumentTypes>
196
+ std::unique_ptr<AnyModulePlaceholder> make_holder(
197
+ std::shared_ptr<ModuleType>&& module,
198
+ ReturnType (Class::*)(ArgumentTypes...));
199
+
200
+ /// Helper method invoked by const and non-const `get()`.
201
+ template <typename ModuleType, typename ReturnType, typename... ArgumentTypes>
202
+ ModuleType& get_(ReturnType (ModuleType::*)(ArgumentTypes...)) const;
203
+
204
+ /// Helper method invoked by const and non-const `get()`.
205
+ template <typename ModuleType>
206
+ ModuleType& get_() const;
207
+
208
+ /// The type erased module.
209
+ std::unique_ptr<AnyModulePlaceholder> content_;
210
+ };
211
+
212
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModule ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
213
+
214
+ template <typename ModuleType>
215
+ AnyModule::AnyModule(std::shared_ptr<ModuleType> module)
216
+ : content_(make_holder(
217
+ std::move(module),
218
+ &std::remove_reference<ModuleType>::type::forward)) {
219
+ // `AnyModule` can only store an `nn::Module` subclass object that provides
220
+ // a `forward()` method that has a non-templatized return type.
221
+ // (e.g. `AnyModule` cannot store `nn::Sequential`, because `nn::Sequential`'s
222
+ // `forward()` method has a templatized return type.)
223
+ static_assert(
224
+ torch::detail::is_module<ModuleType>::value,
225
+ "Can only store object derived from nn::Module into AnyModule");
226
+ static_assert(
227
+ torch::detail::has_forward<ModuleType>::value,
228
+ "Can only store module with a forward() method that has a non-templatized"
229
+ " argument type and return type into AnyModule (e.g. we cannot store nn::Sequential"
230
+ "into AnyModule, because its forward() method's argument type and return type are templatized."
231
+ " If you need to use nn::Sequentials inside each other you can subclass "
232
+ "nn::Sequential and write a non-templatized forward function for it. You can checkout "
233
+ "https://github.com/pytorch/vision/blob/2f46070f3cb1ea894d82578f3dc5677f82f34958/torchvision/csrc/models/mnasnet.cpp#L59 "
234
+ "for an example on how to do this.).");
235
+ }
236
+
237
+ template <typename ModuleType, typename>
238
+ AnyModule::AnyModule(ModuleType&& module)
239
+ : AnyModule(
240
+ std::make_shared<ModuleType>(std::forward<ModuleType>(module))) {}
241
+
242
+ template <typename ModuleType>
243
+ AnyModule::AnyModule(const ModuleHolder<ModuleType>& module_holder)
244
+ : AnyModule(module_holder.ptr()) {}
245
+
246
+ inline AnyModule::AnyModule(const AnyModule& other)
247
+ : content_(other.content_ ? other.content_->copy() : nullptr) {}
248
+
249
+ inline AnyModule& AnyModule::operator=(const AnyModule& other) {
250
+ if (this != &other) {
251
+ content_ = other.content_ ? other.content_->copy() : nullptr;
252
+ }
253
+ return *this;
254
+ }
255
+
256
+ inline AnyModule AnyModule::clone(std::optional<Device> device) const {
257
+ AnyModule clone;
258
+ clone.content_ = content_ ? content_->clone_module(device) : nullptr;
259
+ return clone;
260
+ }
261
+
262
+ template <typename ModuleType>
263
+ AnyModule& AnyModule::operator=(std::shared_ptr<ModuleType> module) {
264
+ // NOLINTNEXTLINE(cppcoreguidelines-c-copy-assignment-signature)
265
+ return (*this = AnyModule(std::move(module)));
266
+ }
267
+
268
+ template <typename... ArgumentTypes>
269
+ AnyValue AnyModule::any_forward(ArgumentTypes&&... arguments) {
270
+ TORCH_CHECK(!is_empty(), "Cannot call forward() on an empty AnyModule");
271
+ std::vector<AnyValue> values;
272
+ values.reserve(sizeof...(ArgumentTypes));
273
+ torch::apply(
274
+ [&values](AnyValue&& value) { values.push_back(std::move(value)); },
275
+ AnyValue(std::forward<ArgumentTypes>(arguments))...);
276
+ return content_->forward(std::move(values));
277
+ }
278
+
279
+ template <typename ReturnType, typename... ArgumentTypes>
280
+ ReturnType AnyModule::forward(ArgumentTypes&&... arguments) {
281
+ return any_forward(std::forward<ArgumentTypes>(arguments)...)
282
+ .template get<ReturnType>();
283
+ }
284
+
285
+ template <typename T, typename>
286
+ T& AnyModule::get() {
287
+ TORCH_CHECK(!is_empty(), "Cannot call get() on an empty AnyModule");
288
+ return get_<T>();
289
+ }
290
+
291
+ template <typename T, typename>
292
+ const T& AnyModule::get() const {
293
+ TORCH_CHECK(!is_empty(), "Cannot call get() on an empty AnyModule");
294
+ return get_<T>();
295
+ }
296
+
297
+ template <typename T, typename ContainedType>
298
+ T AnyModule::get() const {
299
+ return T(ptr<ContainedType>());
300
+ }
301
+
302
+ inline std::shared_ptr<Module> AnyModule::ptr() const {
303
+ TORCH_CHECK(!is_empty(), "Cannot call ptr() on an empty AnyModule");
304
+ return content_->ptr();
305
+ }
306
+
307
+ template <typename T, typename>
308
+ std::shared_ptr<T> AnyModule::ptr() const {
309
+ TORCH_CHECK(!is_empty(), "Cannot call ptr() on an empty AnyModule");
310
+ // Call get() but discard the value, just to do the type checking.
311
+ get_<T>();
312
+ return std::dynamic_pointer_cast<T>(ptr());
313
+ }
314
+
315
+ inline const std::type_info& AnyModule::type_info() const {
316
+ TORCH_CHECK(!is_empty(), "Cannot call type_info() on an empty AnyModule");
317
+ return content_->type_info;
318
+ }
319
+
320
+ inline bool AnyModule::is_empty() const noexcept {
321
+ return content_ == nullptr;
322
+ }
323
+
324
+ // Private Methods
325
+
326
+ template <
327
+ typename ModuleType,
328
+ typename Class,
329
+ typename ReturnType,
330
+ typename... ArgumentTypes>
331
+ std::unique_ptr<AnyModulePlaceholder> AnyModule::make_holder(
332
+ std::shared_ptr<ModuleType>&& module,
333
+ ReturnType (Class::*)(ArgumentTypes...)) {
334
+ static_assert(
335
+ torch::detail::check_not_lvalue_references<ArgumentTypes...>(),
336
+ "Modules stored inside AnyModule must not take references. "
337
+ "Use pointers instead.");
338
+ static_assert(
339
+ !std::is_void<ReturnType>::value,
340
+ "AnyModule cannot store modules that return void "
341
+ "(you can return a dummy value).");
342
+ return std::make_unique<
343
+ AnyModuleHolder<std::decay_t<ModuleType>, ArgumentTypes...>>(
344
+ std::move(module));
345
+ }
346
+
347
+ template <typename ModuleType>
348
+ ModuleType& AnyModule::get_() const {
349
+ using M = typename std::remove_reference<ModuleType>::type;
350
+ static_assert(
351
+ torch::detail::has_forward<M>::value,
352
+ "Can only call AnyModule::get<T> with a type T that has a forward method");
353
+ return get_(&M::forward);
354
+ }
355
+
356
+ template <typename ModuleType, typename ReturnType, typename... ArgumentTypes>
357
+ ModuleType& AnyModule::get_(
358
+ ReturnType (ModuleType::*)(ArgumentTypes...)) const {
359
+ if (typeid(ModuleType).hash_code() == type_info().hash_code()) {
360
+ return *static_cast<AnyModuleHolder<ModuleType, ArgumentTypes...>&>(
361
+ *content_)
362
+ .module;
363
+ }
364
+ AT_ERROR(
365
+ "Attempted to cast module of type ",
366
+ c10::demangle(type_info().name()),
367
+ " to type ",
368
+ c10::demangle(typeid(ModuleType).name()));
369
+ }
370
+
371
+ } // namespace nn
372
+ } // namespace torch
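The chaining pattern that `any_forward` is designed for (mentioned in its doc comment above) looks roughly like the following sketch; the two modules in the pipeline are arbitrary choices:

    #include <torch/torch.h>
    #include <vector>

    int main() {
      std::vector<torch::nn::AnyModule> pipeline;
      pipeline.emplace_back(torch::nn::Linear(3, 4));
      pipeline.emplace_back(torch::nn::Functional(torch::relu));

      // Thread a single AnyValue through the chain; the concrete return
      // type is only committed to at the very end.
      torch::nn::AnyValue value(torch::ones({2, 3}));
      for (auto& module : pipeline) {
        value = module.any_forward(std::move(value));
      }
      auto output = value.get<torch::Tensor>();
      (void)output;
      return 0;
    }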
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/any_module_holder.h ADDED
@@ -0,0 +1,133 @@
1
+ #pragma once
2
+
3
+ #include <torch/nn/modules/container/any_value.h>
4
+
5
+ namespace torch {
6
+ namespace nn {
7
+
8
+ class Module;
9
+
10
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModulePlaceholder ~~~~~~~~~~~~~~~~~~~~~~~~~~
11
+
12
+ /// The static type of the object we store in the `AnyModule`, which erases
13
+ /// the actual type, but allows us to call `forward()` on the underlying
14
+ /// module.
15
+ struct AnyModulePlaceholder : public AnyValue::Placeholder {
16
+ using AnyValue::Placeholder::Placeholder;
17
+
18
+ /// The "erased" `forward()` method.
19
+ virtual AnyValue forward(std::vector<AnyValue>&& arguments) = 0;
20
+
21
+ /// Returns a `std::shared_ptr<Module>` pointing to the erased module.
22
+ virtual std::shared_ptr<Module> ptr() = 0;
23
+
24
+ /// Returns an `AnyModulePlaceholder` with a shallow copy of this `AnyModule`.
25
+ virtual std::unique_ptr<AnyModulePlaceholder> copy() const = 0;
26
+
27
+ /// Returns an `AnyModulePlaceholder` with a deep copy of this `AnyModule`.
28
+ virtual std::unique_ptr<AnyModulePlaceholder> clone_module(
29
+ std::optional<Device> device) const = 0;
30
+ };
31
+
32
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModuleHolder ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
33
+
34
+ /// The dynamic type of the object stored in the `AnyModule`. It contains the
35
+ /// concrete instance to which all calls are forwarded. It is parameterized
36
+ /// over the concrete type of the module, and the types of the arguments the
37
+ /// module takes in its `forward()` method.
38
+ template <typename ModuleType, typename... ArgumentTypes>
39
+ struct AnyModuleHolder : public AnyModulePlaceholder {
40
+ /// \internal
41
+ struct CheckedGetter {
42
+ template <typename T>
43
+ std::decay_t<T>&& operator()(size_t index) {
44
+ AT_ASSERT(index < arguments_.size());
45
+ auto& value = arguments_[index];
46
+ if (auto* maybe_value = value.template try_get<std::decay_t<T>>()) {
47
+ return std::move(*maybe_value);
48
+ }
49
+ AT_ERROR(
50
+ "Expected argument #",
51
+ index,
52
+ " to be of type ",
53
+ c10::demangle(typeid(T).name()),
54
+ ", but received value of type ",
55
+ c10::demangle(value.type_info().name()));
56
+ }
57
+ std::vector<AnyValue>& arguments_;
58
+ };
59
+
60
+ /// \internal
61
+ struct InvokeForward {
62
+ template <typename... Ts>
63
+ AnyValue operator()(Ts&&... ts) {
64
+ return AnyValue(module_->forward(std::forward<Ts>(ts)...));
65
+ }
66
+ std::shared_ptr<ModuleType>& module_;
67
+ };
68
+
69
+ /// Constructs the `AnyModuleHolder` from a concrete module.
70
+ explicit AnyModuleHolder(std::shared_ptr<ModuleType>&& module_)
71
+ : AnyModulePlaceholder(typeid(ModuleType)), module(std::move(module_)) {}
72
+
73
+ /// Calls `forward()` on the underlying module, casting each `AnyValue` in the
74
+ /// argument vector to a concrete value.
75
+ AnyValue forward(std::vector<AnyValue>&& arguments) override {
76
+ if (module->_forward_has_default_args()) {
77
+ TORCH_CHECK(
78
+ arguments.size() >= module->_forward_num_required_args() &&
79
+ arguments.size() <= sizeof...(ArgumentTypes),
80
+ c10::demangle(type_info.name()),
81
+ "'s forward() method expects at least ",
82
+ module->_forward_num_required_args(),
83
+ " argument(s) and at most ",
84
+ sizeof...(ArgumentTypes),
85
+ " argument(s), but received ",
86
+ arguments.size(),
87
+ ".");
88
+ arguments = std::move(
89
+ module->_forward_populate_default_args(std::move(arguments)));
90
+ } else {
91
+ std::string use_default_args_macro_prompt = " If " +
92
+ c10::demangle(type_info.name()) +
93
+ "'s forward() method has default arguments, " +
94
+ "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro.";
95
+ TORCH_CHECK(
96
+ arguments.size() == sizeof...(ArgumentTypes),
97
+ c10::demangle(type_info.name()),
98
+ "'s forward() method expects ",
99
+ sizeof...(ArgumentTypes),
100
+ " argument(s), but received ",
101
+ arguments.size(),
102
+ ".",
103
+ (arguments.size() < sizeof...(ArgumentTypes))
104
+ ? use_default_args_macro_prompt
105
+ : "");
106
+ }
107
+
108
+ // FYI: During invocation of a module's `forward()` method, the values live
109
+ // in the `arguments` vector inside this function.
110
+ return torch::unpack<AnyValue, ArgumentTypes...>(
111
+ InvokeForward{module}, CheckedGetter{arguments});
112
+ }
113
+
114
+ std::shared_ptr<Module> ptr() override {
115
+ return module;
116
+ }
117
+
118
+ std::unique_ptr<AnyModulePlaceholder> copy() const override {
119
+ return std::make_unique<AnyModuleHolder>(*this);
120
+ }
121
+
122
+ std::unique_ptr<AnyModulePlaceholder> clone_module(
123
+ std::optional<Device> device) const override {
124
+ return std::make_unique<AnyModuleHolder>(
125
+ std::dynamic_pointer_cast<ModuleType>(module->clone(device)));
126
+ }
127
+
128
+ /// The actual concrete module instance.
129
+ std::shared_ptr<ModuleType> module;
130
+ };
131
+
132
+ } // namespace nn
133
+ } // namespace torch
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/any_value.h ADDED
@@ -0,0 +1,125 @@
1
+ #pragma once
2
+
3
+ #include <torch/detail/static.h>
4
+ #include <torch/nn/module.h>
5
+ #include <torch/nn/pimpl.h>
6
+ #include <torch/types.h>
7
+
8
+ #include <torch/csrc/autograd/variable.h>
9
+ #include <torch/csrc/utils/variadic.h>
10
+
11
+ #include <memory>
12
+ #include <type_traits>
13
+ #include <typeinfo>
14
+ #include <utility>
15
+
16
+ namespace torch {
17
+ namespace nn {
18
+
19
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyValue ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
20
+
21
+ /// An implementation of `std::any` which stores
22
+ /// a type erased object, whose concrete value can be retrieved at runtime by
23
+ /// checking if the `typeid()` of a requested type matches the `typeid()` of
24
+ /// the object stored.
25
+ class AnyValue {
26
+ public:
27
+ /// Move construction and assignment is allowed, and follows the default
28
+ /// behavior of move for `std::unique_ptr`.
29
+ AnyValue(AnyValue&&) = default;
30
+ AnyValue& operator=(AnyValue&&) = default;
31
+
32
+ /// Copy construction and assignment is allowed.
33
+ AnyValue(const AnyValue& other) : content_(other.content_->clone()) {}
34
+ AnyValue& operator=(const AnyValue& other) {
35
+ content_ = other.content_->clone();
36
+ return *this;
37
+ }
38
+
39
+ /// Constructs the `AnyValue` from value type.
40
+ template <typename T>
41
+ // NOLINTNEXTLINE(bugprone-forwarding-reference-overload)
42
+ explicit AnyValue(T&& value)
43
+ : content_(
44
+ std::make_unique<Holder<std::decay_t<T>>>(std::forward<T>(value))) {
45
+ }
46
+
47
+ /// Returns a pointer to the value contained in the `AnyValue` if the type
48
+ /// passed as template parameter matches the type of the value stored, and
49
+ /// returns a null pointer otherwise.
50
+ template <typename T>
51
+ T* try_get() {
52
+ static_assert(
53
+ !std::is_reference<T>::value,
54
+ "AnyValue stores decayed types, you cannot cast it to a reference type");
55
+ static_assert(
56
+ !std::is_array<T>::value,
57
+ "AnyValue stores decayed types, you must cast it to T* instead of T[]");
58
+ if (typeid(T).hash_code() == type_info().hash_code()) {
59
+ return &static_cast<Holder<T>&>(*content_).value;
60
+ }
61
+ return nullptr;
62
+ }
63
+
64
+ /// Returns the value contained in the `AnyValue` if the type passed as
65
+ /// template parameter matches the type of the value stored, and throws an
66
+ /// exception otherwise.
67
+ template <typename T>
68
+ T get() {
69
+ if (auto* maybe_value = try_get<T>()) {
70
+ return *maybe_value;
71
+ }
72
+ AT_ERROR(
73
+ "Attempted to cast AnyValue to ",
74
+ c10::demangle(typeid(T).name()),
75
+ ", but its actual type is ",
76
+ c10::demangle(type_info().name()));
77
+ }
78
+
79
+ /// Returns the `type_info` object of the contained value.
80
+ const std::type_info& type_info() const noexcept {
81
+ return content_->type_info;
82
+ }
83
+
84
+ private:
85
+ friend struct AnyModulePlaceholder;
86
+ friend struct TestAnyValue;
87
+
88
+ /// \internal
89
+ /// The static type of the object we store in the `AnyValue`, which erases the
90
+ /// actual object's type, allowing us only to check the `type_info` of the
91
+ /// type stored in the dynamic type.
92
+ struct Placeholder {
93
+ explicit Placeholder(const std::type_info& type_info_) noexcept
94
+ : type_info(type_info_) {}
95
+ Placeholder(const Placeholder&) = default;
96
+ Placeholder(Placeholder&&) = default;
97
+ virtual ~Placeholder() = default;
98
+ virtual std::unique_ptr<Placeholder> clone() const {
99
+ TORCH_CHECK(false, "clone() should only be called on `AnyValue::Holder`");
100
+ }
101
+ const std::type_info& type_info;
102
+ };
103
+
104
+ /// \internal
105
+ /// The dynamic type of the object we store in the `AnyValue`, which hides the
106
+ /// actual object we have erased in this `AnyValue`.
107
+ template <typename T>
108
+ struct Holder : public Placeholder {
109
+ /// A template because T&& would not be a universal reference here.
110
+ template <typename U>
111
+ // NOLINTNEXTLINE(bugprone-forwarding-reference-overload)
112
+ explicit Holder(U&& value_) noexcept
113
+ : Placeholder(typeid(T)), value(std::forward<U>(value_)) {}
114
+ std::unique_ptr<Placeholder> clone() const override {
115
+ return std::make_unique<Holder<T>>(value);
116
+ }
117
+ T value;
118
+ };
119
+
120
+ /// The type erased object.
121
+ std::unique_ptr<Placeholder> content_;
122
+ };
123
+
124
+ } // namespace nn
125
+ } // namespace torch
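A short sketch of the two retrieval paths, `try_get` for a null-on-mismatch check and `get` for a throwing cast (the stored `int` is arbitrary):

    #include <torch/torch.h>
    #include <cassert>
    #include <iostream>

    int main() {
      torch::nn::AnyValue value(42);

      // try_get returns nullptr on a type mismatch instead of throwing.
      if (auto* i = value.try_get<int>()) {
        std::cout << *i << '\n';  // prints 42
      }
      assert(value.try_get<double>() == nullptr);

      // get throws, with demangled type names, when the cast fails.
      int n = value.get<int>();
      (void)n;
      return 0;
    }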
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/functional.h ADDED
@@ -0,0 +1,105 @@
1
+ #pragma once
2
+
3
+ #include <torch/csrc/Export.h>
4
+ #include <torch/csrc/utils/variadic.h>
5
+ #include <torch/nn/cloneable.h>
6
+ #include <torch/nn/pimpl.h>
7
+ #include <torch/types.h>
8
+
9
+ #include <functional>
10
+ #include <utility>
11
+
12
+ namespace torch {
13
+ namespace nn {
14
+
15
+ /// Wraps a function in a `Module`.
16
+ ///
17
+ /// The `Functional` module allows wrapping an arbitrary function or function
18
+ /// object in an `nn::Module`. This is primarily handy for usage in
19
+ /// `Sequential`.
20
+ ///
21
+ /// \rst
22
+ /// .. code-block:: cpp
23
+ ///
24
+ /// Sequential sequential(
25
+ /// Linear(3, 4),
26
+ /// Functional(torch::relu),
27
+ /// BatchNorm1d(3),
28
+ /// Functional(torch::elu, /*alpha=*/1));
29
+ /// \endrst
30
+ ///
31
+ /// While a `Functional` module only accepts a single `Tensor` as input, it is
32
+ /// possible for the wrapped function to accept further arguments. However,
33
+ /// these have to be bound *at construction time*. For example, if
34
+ /// you want to wrap `torch::leaky_relu`, which accepts a `slope` scalar as its
35
+ /// second argument, with a particular value for its `slope` in a `Functional`
36
+ /// module, you could write
37
+ ///
38
+ /// \rst
39
+ /// .. code-block:: cpp
40
+ ///
41
+ /// Functional(torch::leaky_relu, /*slope=*/0.5)
42
+ /// \endrst
43
+ ///
44
+ /// The value of `0.5` is then stored within the `Functional` object and
45
+ /// supplied to the function call at invocation time. Note that such bound
46
+ /// values are evaluated eagerly and stored a single time. See the documentation
47
+ /// of [std::bind](https://en.cppreference.com/w/cpp/utility/functional/bind)
48
+ /// for more information on the semantics of argument binding.
49
+ ///
50
+ /// \rst
51
+ /// .. attention::
52
+ /// After passing any bound arguments, the function must accept a single
53
+ /// tensor and return a single tensor.
54
+ /// \endrst
55
+ ///
56
+ /// Note that `Functional` overloads the call operator (`operator()`) such that
57
+ /// you can invoke it with `my_func(...)`.
58
+ class TORCH_API FunctionalImpl : public torch::nn::Cloneable<FunctionalImpl> {
59
+ public:
60
+ using Function = std::function<Tensor(Tensor)>;
61
+
62
+ /// Constructs a `Functional` from a function object.
63
+ explicit FunctionalImpl(Function function);
64
+
65
+ template <
66
+ typename SomeFunction,
67
+ typename... Args,
68
+ typename = std::enable_if_t<(sizeof...(Args) > 0)>>
69
+ explicit FunctionalImpl(SomeFunction original_function, Args&&... args)
70
+ // NOLINTNEXTLINE(modernize-avoid-bind)
71
+ : function_(std::bind(
72
+ original_function,
73
+ /*input=*/std::placeholders::_1,
74
+ std::forward<Args>(args)...)) {
75
+ // std::bind is normally evil, but (1) gcc is broken w.r.t. handling
76
+ // parameter pack expansion in lambdas and (2) moving parameter packs into
77
+ // a lambda only works with C++14, so std::bind is the more move-aware
78
+ // solution here.
79
+ }
80
+
81
+ void reset() override;
82
+
83
+ /// Pretty prints the `Functional` module into the given `stream`.
84
+ void pretty_print(std::ostream& stream) const override;
85
+
86
+ /// Forwards the `input` tensor to the underlying (bound) function object.
87
+ Tensor forward(Tensor input);
88
+
89
+ /// Calls forward(input).
90
+ Tensor operator()(Tensor input);
91
+
92
+ bool is_serializable() const override;
93
+
94
+ private:
95
+ Function function_;
96
+ };
97
+
98
+ /// A `ModuleHolder` subclass for `FunctionalImpl`.
99
+ /// See the documentation for `FunctionalImpl` class to learn what methods it
100
+ /// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
101
+ /// module storage semantics.
102
+ TORCH_MODULE(Functional);
103
+
104
+ } // namespace nn
105
+ } // namespace torch
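As the final paragraph of the doc comment notes, the module holder is directly callable; a brief sketch with a bound second argument (the slope value is arbitrary):

    #include <torch/torch.h>

    int main() {
      // leaky_relu takes a slope as its second argument; bind it here.
      torch::nn::Functional f(torch::leaky_relu, /*slope=*/0.01);
      auto y = f(torch::randn({3}));  // equivalent to f->forward(...)
      (void)y;
      return 0;
    }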
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/moduledict.h ADDED
@@ -0,0 +1,262 @@
1
+ #pragma once
2
+
3
+ #include <torch/nn/cloneable.h>
4
+ #include <torch/nn/module.h>
5
+ #include <torch/ordered_dict.h>
6
+ #include <vector>
7
+
8
+ namespace torch {
9
+ namespace nn {
10
+
11
+ /// An OrderedDict of `Module`s that registers its elements by their `key`s.
12
+ ///
13
+ /// \rst
14
+ /// .. code-block:: cpp
15
+ ///
16
+ /// torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
17
+ /// {"linear", Linear(10, 3).ptr()},
18
+ /// {"conv", Conv2d(1, 2, 3).ptr()},
19
+ /// {"dropout", Dropout(0.5).ptr()},
20
+ /// };
21
+ /// torch::nn::ModuleDict dict1(ordereddict);
22
+ ///
23
+ /// for (const auto &module : *dict1) {
24
+ /// module->pretty_print(std::cout);
25
+ /// }
26
+ ///
27
+ /// std::vector<std::pair<std::string, std::shared_ptr<Module>>> list = {
28
+ /// {"linear", Linear(10, 3).ptr()},
29
+ /// {"conv", Conv2d(1, 2, 3).ptr()},
30
+ /// {"dropout", Dropout(0.5).ptr()},
31
+ /// };
32
+ /// torch::nn::ModuleDict dict2(list);
33
+ ///
34
+ /// for (const auto &module : *dict2) {
35
+ /// module->pretty_print(std::cout);
36
+ /// }
37
+ ///
38
+ /// \endrst
39
+ ///
40
+ /// Why should you use `ModuleDict` instead of a simple `map` or `OrderedDict`?
41
+ /// The value a `ModuleDict` provides over manually managing an ordered map of
42
+ /// modules is that it allows treating the whole container *as a single module*,
43
+ /// such that performing a transformation on the `ModuleDict` applies to each of
44
+ /// the modules it stores (which are each a registered submodule of the
45
+ /// `ModuleDict`). For example, calling `.to(torch::kCUDA)` on a `ModuleDict`
46
+ /// will move each module in the map to CUDA memory. For example:
47
+ ///
48
+ /// \rst
49
+ /// .. code-block:: cpp
50
+ ///
51
+ /// torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
52
+ /// {"linear", Linear(10, 3).ptr()},
53
+ /// {"conv", Conv2d(1, 2, 3).ptr()},
54
+ /// {"dropout", Dropout(0.5).ptr()},
55
+ /// };
56
+ /// torch::nn::ModuleDict dict(ordereddict);
57
+ ///
58
+ /// // Convert all modules to CUDA.
59
+ /// dict->to(torch::kCUDA);
60
+ ///
61
+ /// \endrst
62
+ ///
63
+ /// Finally, `ModuleDict` provides a lightweight container API, such as allowing
64
+ /// iteration over submodules, positional access, adding new modules from a
65
+ /// vector of key-module pairs or an `OrderedDict` or another `ModuleDict` after
66
+ /// construction via `update`.
67
+ class ModuleDictImpl : public Cloneable<ModuleDictImpl> {
68
+ public:
69
+ using Iterator =
70
+ torch::OrderedDict<std::string, std::shared_ptr<Module>>::Iterator;
71
+ using ConstIterator =
72
+ torch::OrderedDict<std::string, std::shared_ptr<Module>>::ConstIterator;
73
+
74
+ ModuleDictImpl() = default;
75
+
76
+ /// Constructs the `ModuleDict` from a list of string-Module pairs.
77
+ explicit ModuleDictImpl(
78
+ const std::vector<std::pair<std::string, std::shared_ptr<Module>>>&
79
+ modules) {
80
+ update(modules);
81
+ }
82
+
83
+ /// Constructs the `ModuleDict` from an `OrderedDict`.
84
+ explicit ModuleDictImpl(
85
+ const torch::OrderedDict<std::string, std::shared_ptr<Module>>& modules) {
86
+ update(modules);
87
+ }
88
+
89
+ /// Return the items in the `ModuleDict`.
90
+ std::vector<std::pair<std::string, std::shared_ptr<Module>>> items() const {
91
+ return modules_.pairs();
92
+ }
93
+
94
+ /// Return the keys in the `ModuleDict`.
95
+ std::vector<std::string> keys() const {
96
+ return modules_.keys();
97
+ }
98
+
99
+ /// Return the values in the `ModuleDict`.
100
+ std::vector<std::shared_ptr<Module>> values() const {
101
+ return modules_.values();
102
+ }
103
+
104
+ /// Return an iterator to the start of `ModuleDict`.
105
+ Iterator begin() {
106
+ return modules_.begin();
107
+ }
108
+
109
+ /// Return a const iterator to the start of `ModuleDict`.
110
+ ConstIterator begin() const {
111
+ return modules_.begin();
112
+ }
113
+
114
+ /// Return an iterator to the end of `ModuleDict`.
115
+ Iterator end() {
116
+ return modules_.end();
117
+ }
118
+
119
+ /// Return a const iterator to the end of `ModuleDict`.
120
+ ConstIterator end() const {
121
+ return modules_.end();
122
+ }
123
+
124
+ /// Return the number of items currently stored in the `ModuleDict`.
125
+ size_t size() const noexcept {
126
+ return modules_.size();
127
+ }
128
+
129
+ /// Return true if the `ModuleDict` is empty, otherwise return false.
130
+ bool empty() const noexcept {
131
+ return modules_.is_empty();
132
+ }
133
+
134
+ /// Checks whether a module with the given `key` is stored in the `ModuleDict`.
135
+ bool contains(const std::string& key) const noexcept {
136
+ return modules_.contains(key);
137
+ }
138
+
139
+ /// Remove all items from the `ModuleDict`.
140
+ void clear() {
141
+ // Do not remove the registration of modules, to stay consistent with the
142
+ // Python version.
143
+ modules_.clear();
144
+ }
145
+
146
+ /// Special cloning function for `ModuleDict` because it does not use
147
+ /// `reset()`.
148
+ std::shared_ptr<Module> clone(
149
+ const std::optional<Device>& device = std::nullopt) const override {
150
+ auto clone = std::make_shared<ModuleDictImpl>();
151
+ for (const auto& module : modules_) {
152
+ clone->insert(module.key(), module.value()->clone(device));
153
+ }
154
+ return clone;
155
+ }
156
+
157
+ /// `reset()` is empty for `ModuleDict`, since it does not have parameters of
158
+ /// its own.
159
+ void reset() override {}
160
+
161
+ /// Pretty prints the `ModuleDict` into the given `stream`.
162
+ void pretty_print(std::ostream& stream) const override {
163
+ stream << "torch::nn::ModuleDict";
164
+ }
165
+
166
+ /// Attempts to return the `Module` associated with the given `key`. Throws
167
+ /// an exception if no such `key` is stored in the `ModuleDict`. Check
168
+ /// contains(key) first for a non-throwing way to access it.
169
+ std::shared_ptr<Module> operator[](const std::string& key) const {
170
+ return modules_[key];
171
+ }
172
+
173
+ /// Attempts to return the module at the given key as the requested type.
174
+ /// Throws an exception if no such `key` is stored in the `ModuleDict`.
175
+ /// Check contains(key) first for a non-throwing way to access it.
176
+ template <typename T>
177
+ T& at(const std::string& key) {
178
+ static_assert(
179
+ torch::detail::is_module<T>::value,
180
+ "Can only call ModuleList::at with an nn::Module type");
181
+ auto module = modules_[key]->as<T>();
182
+ TORCH_CHECK(
183
+ module,
184
+ "Unable to cast module[",
185
+ key,
186
+ "] to ",
187
+ c10::demangle(typeid(T).name()));
188
+ return *module;
189
+ }
190
+
191
+ /// Attempts to return the module at the given key as the requested type.
192
+ /// Throws an exception if no such `key` is stored in the `ModuleDict`.
193
+ /// Check contains(key) first for a non-throwing way to access it.
194
+ template <typename T>
195
+ const T& at(const std::string& key) const {
196
+ static_assert(
197
+ torch::detail::is_module<T>::value,
198
+ "Can only call ModuleList::at with an nn::Module type");
199
+ const auto module = modules_[key]->as<T>();
200
+ TORCH_CHECK(
201
+ module,
202
+ "Unable to cast module[",
203
+ key,
204
+ "] to ",
205
+ c10::demangle(typeid(T).name()));
206
+ return *module;
207
+ }
208
+
209
+ /// Removes and returns the `Module` associated with the given `key`.
210
+ /// Throws an exception if no such `key` is stored in the `ModuleDict`.
211
+ /// Check contains(key) first for a non-throwing way to access it.
212
+ std::shared_ptr<Module> pop(const std::string& key) {
213
+ auto module = modules_[key];
214
+ modules_.erase(key);
215
+ // Do not remove the registration of the module, to stay consistent with
216
+ // the Python version.
217
+ return module;
218
+ }
219
+
220
+ /// Updates the `ModuleDict` with a vector of key-module pairs.
221
+ void update(
222
+ const std::vector<std::pair<std::string, std::shared_ptr<Module>>>&
223
+ modules) {
224
+ for (auto& item : modules) {
225
+ insert(item.first, item.second);
226
+ }
227
+ }
228
+
229
+ /// Updates the `ModuleDict` with key-value pairs from an `OrderedDict` or
230
+ /// `ModuleDict`.
231
+ template <typename Container>
232
+ void update(const Container& container) {
233
+ for (auto& item : container) {
234
+ insert(item.key(), item.value());
235
+ }
236
+ }
237
+
238
+ private:
239
+ /// Private `OrderedDict` holding the key-Module pairs.
240
+ torch::OrderedDict<std::string, std::shared_ptr<Module>> modules_;
241
+
242
+ /// Inserts a key-module pair, overwriting any existing key, and registers
243
+ /// or replaces the `Module`.
244
+ void insert(const std::string& key, std::shared_ptr<Module> module) {
245
+ if (contains(key)) {
246
+ modules_[key] = std::move(module);
247
+ replace_module(key, modules_[key]);
248
+ } else {
249
+ modules_.insert(key, std::move(module));
250
+ register_module(key, modules_.back().value());
251
+ }
252
+ }
253
+ };
254
+
255
+ /// A `ModuleHolder` subclass for `ModuleDictImpl`.
256
+ /// See the documentation for `ModuleDictImpl` class to learn what methods it
257
+ /// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
258
+ /// module storage semantics.
259
+ TORCH_MODULE(ModuleDict);
260
+
261
+ } // namespace nn
262
+ } // namespace torch
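
A minimal usage sketch of the `ModuleDictImpl` API above, assuming a standard libtorch setup (entries go in through `update()`, since `insert()` is private):

```
#include <torch/torch.h>

int main() {
  torch::nn::ModuleDict dict;

  // Populate via update(); the header's insert() helper is private.
  std::vector<std::pair<std::string, std::shared_ptr<torch::nn::Module>>>
      modules = {{"linear", torch::nn::Linear(10, 3).ptr()},
                 {"dropout", torch::nn::Dropout(0.5).ptr()}};
  dict->update(modules);

  // Typed access; throws if the key is missing or the cast fails.
  auto& linear = dict->at<torch::nn::LinearImpl>("linear");
  auto out = linear.forward(torch::ones({2, 10}));

  // pop() removes the entry but keeps the submodule registration,
  // mirroring the Python ModuleDict semantics noted in the header.
  auto removed = dict->pop("dropout");
}
```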
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/modulelist.h ADDED
@@ -0,0 +1,274 @@
1
+ #pragma once
2
+
3
+ #include <c10/util/irange.h>
4
+ #include <torch/nn/cloneable.h>
5
+ #include <torch/nn/module.h>
6
+
7
+ #include <utility>
8
+ #include <vector>
9
+
10
+ namespace torch {
11
+ namespace nn {
12
+
13
+ /// A list of `Module`s that registers its elements.
14
+ ///
15
+ /// \rst
16
+ /// .. code-block:: cpp
17
+ ///
18
+ /// torch::nn::ModuleList mlist(
19
+ /// torch::nn::Linear(3, 4),
20
+ /// torch::nn::BatchNorm1d(4),
21
+ /// torch::nn::Dropout(0.5)
22
+ /// );
23
+ ///
24
+ /// for (const auto &module : *mlist) {
25
+ /// module->pretty_print(std::cout);
26
+ /// }
27
+ ///
28
+ /// \endrst
29
+ ///
30
+ /// Why should you use `ModuleList` instead of a simple `std::vector`? The value
31
+ /// a `ModuleList` provides over manually calling a sequence of modules is that
32
+ /// it allows treating the whole container *as a single module*, such that
33
+ /// performing a transformation on the `ModuleList` applies to each of the
34
+ /// modules it stores (which are each a registered submodule of the
35
+ /// `ModuleList`). For example, calling
36
+ /// `.to(torch::kCUDA)` on a `ModuleList` will move each module in the list to
37
+ /// CUDA memory. For example:
38
+ ///
39
+ /// \rst
40
+ /// .. code-block:: cpp
41
+ ///
42
+ /// torch::nn::ModuleList mlist(
43
+ /// torch::nn::Linear(3, 4),
44
+ /// torch::nn::BatchNorm1d(4),
45
+ /// torch::nn::Dropout(0.5)
46
+ /// );
47
+ ///
48
+ /// // Convert all modules to CUDA.
49
+ /// mlist->to(torch::kCUDA);
50
+ ///
51
+ /// \endrst
52
+ ///
53
+ /// Finally, `ModuleList` provides a lightweight container API, such as allowing
54
+ /// iteration over submodules, positional access, adding a new module after
55
+ /// construction via `push_back`, as well as joining two `ModuleList`s via
56
+ /// `extend`.
57
+ class ModuleListImpl : public Cloneable<ModuleListImpl> {
58
+ public:
59
+ using Iterator = std::vector<std::shared_ptr<Module>>::iterator;
60
+ using ConstIterator = std::vector<std::shared_ptr<Module>>::const_iterator;
61
+
62
+ ModuleListImpl() = default;
63
+
64
+ /// Constructs the `ModuleList` from a variadic list of modules.
65
+ template <typename... Modules>
66
+ explicit ModuleListImpl(Modules&&... modules) {
67
+ modules_.reserve(sizeof...(Modules));
68
+ push_back_var(std::forward<Modules>(modules)...);
69
+ }
70
+
71
+ /// Special cloning function for `ModuleList` because it does not use
72
+ /// `reset()`.
73
+ std::shared_ptr<Module> clone(
74
+ const std::optional<Device>& device = std::nullopt) const override {
75
+ auto clone = std::make_shared<ModuleListImpl>();
76
+ for (const auto& module : modules_) {
77
+ clone->push_back(module->clone(device));
78
+ }
79
+ return clone;
80
+ }
81
+
82
+ /// `reset()` is empty for `ModuleList`, since it does not have parameters of
83
+ /// its own.
84
+ void reset() override {}
85
+
86
+ /// Pretty prints the `ModuleList` module into the given `stream`.
87
+ void pretty_print(std::ostream& stream) const override {
88
+ stream << "torch::nn::ModuleList";
89
+ }
90
+
91
+ void push_back(std::shared_ptr<Module> module) {
92
+ modules_.push_back(std::move(module));
93
+ const auto index = modules_.size() - 1;
94
+ register_module(std::to_string(index), modules_[index]);
95
+ }
96
+
97
+ /// Adds a new `Module` to the `ModuleList` container, moving or copying
98
+ /// it into a `shared_ptr` internally. This method allows passing value types,
99
+ /// and letting the container deal with the boxing.
100
+ template <typename M, typename = torch::detail::enable_if_module_t<M>>
101
+ void push_back(M&& module) {
102
+ using Type = typename std::remove_reference<M>::type;
103
+ push_back(std::make_shared<Type>(std::forward<M>(module)));
104
+ }
105
+
106
+ /// Unwraps the contained module of a `ModuleHolder` and adds it to the
107
+ /// `ModuleList`.
108
+ template <typename M>
109
+ void push_back(const ModuleHolder<M>& module_holder) {
110
+ push_back(module_holder.ptr());
111
+ }
112
+
113
+ /// Iterates over the container and calls `push_back()` on each value.
114
+ template <typename Container>
115
+ void extend(const Container& container) {
116
+ for (const auto& module : container) {
117
+ push_back(module);
118
+ }
119
+ }
120
+
121
+ /// Returns an iterator to the start of the `ModuleList`.
122
+ Iterator begin() {
123
+ return modules_.begin();
124
+ }
125
+
126
+ /// Returns a const iterator to the start of the `ModuleList`.
127
+ ConstIterator begin() const {
128
+ return modules_.begin();
129
+ }
130
+
131
+ /// Returns an iterator to the end of the `ModuleList`.
132
+ Iterator end() {
133
+ return modules_.end();
134
+ }
135
+
136
+ /// Returns a const iterator to the end of the `ModuleList`.
137
+ ConstIterator end() const {
138
+ return modules_.end();
139
+ }
140
+
141
+ /// Attempts to return the module at the given index as the requested type.
142
+ /// Throws an exception if the index is out of bounds or the types do not
143
+ /// match.
144
+ template <typename T>
145
+ T& at(size_t index) {
146
+ static_assert(
147
+ torch::detail::is_module<T>::value,
148
+ "Can only call ModuleList::at with an nn::Module type");
149
+ TORCH_CHECK(index < size(), "Index out of range");
150
+ auto module = modules_[index]->as<T>();
151
+ TORCH_CHECK(
152
+ module,
153
+ "Unable to cast module[",
154
+ index,
155
+ "] to ",
156
+ c10::demangle(typeid(T).name()));
157
+ return *module;
158
+ }
159
+
160
+ /// Attempts to return the module at the given index as the requested type.
161
+ /// Throws an exception if the index is out of bounds or the types do not
162
+ /// match.
163
+ template <typename T>
164
+ const T& at(size_t index) const {
165
+ static_assert(
166
+ torch::detail::is_module<T>::value,
167
+ "Can only call ModuleList::at with an nn::Module type");
168
+ TORCH_CHECK(index < size(), "Index out of range");
169
+ const auto module = modules_[index]->as<T>();
170
+ TORCH_CHECK(
171
+ module,
172
+ "Unable to cast module[",
173
+ index,
174
+ "] to ",
175
+ c10::demangle(typeid(T).name()));
176
+ return *module;
177
+ }
178
+
179
+ /// Attempts to return a `std::shared_ptr` whose dynamic type is that of the
180
+ /// underlying module at the given index. Throws an exception if the index is
181
+ /// out of bounds.
182
+ std::shared_ptr<Module> ptr(size_t index) const {
183
+ TORCH_CHECK(index < size(), "Index out of range");
184
+ return modules_[index];
185
+ }
186
+
187
+ /// Attempts to return a `std::shared_ptr` whose type is the one provided.
188
+ /// Throws an exception if the index is out of bounds or the types do not
189
+ /// match.
190
+ template <typename T>
191
+ std::shared_ptr<T> ptr(size_t index) const {
192
+ static_assert(
193
+ torch::detail::is_module<T>::value,
194
+ "Can only call ModuleList::ptr with an nn::Module type");
195
+ TORCH_CHECK(index < size(), "Index out of range");
196
+ return std::dynamic_pointer_cast<T>(modules_[index]);
197
+ }
198
+
199
+ /// Like `ptr(index)`.
200
+ std::shared_ptr<Module> operator[](size_t index) const {
201
+ // This is the only method we can call without a type.
202
+ return ptr(index);
203
+ }
204
+
205
+ /// The current size of the `ModuleList` container.
206
+ size_t size() const noexcept {
207
+ return modules_.size();
208
+ }
209
+
210
+ /// True if there are no modules in the `ModuleList`.
211
+ bool is_empty() const noexcept {
212
+ return size() == 0;
213
+ }
214
+
215
+ void insert(size_t index, std::shared_ptr<Module> module) {
216
+ TORCH_CHECK(index <= size(), "Index out of range");
217
+
218
+ if (index == size())
219
+ push_back(std::move(module));
220
+ else {
221
+ modules_.insert(
222
+ modules_.begin() + Iterator::difference_type(index),
223
+ std::move(module));
224
+
225
+ for (const auto i : c10::irange(index, size() - 1)) {
226
+ // Re-register each shifted module under its new index-based name.
227
+ replace_module(std::to_string(i), modules_[i]);
228
+ }
229
+ register_module(std::to_string(size() - 1), modules_.back());
230
+ }
231
+ }
232
+
233
+ /// Unwraps the contained module of a `ModuleHolder` and inserts it in the
234
+ /// `ModuleList`.
235
+ template <typename M>
236
+ void insert(size_t index, const ModuleHolder<M>& module_holder) {
237
+ insert(index, module_holder.ptr());
238
+ }
239
+
240
+ /// Inserts a new `Module` into the `ModuleList` container, moving or copying
241
+ /// it into a `shared_ptr` internally. This method allows passing value types,
242
+ /// and letting the container deal with the boxing.
243
+ template <typename M, typename = torch::detail::enable_if_module_t<M>>
244
+ void insert(size_t index, M&& module) {
245
+ using Type = typename std::remove_reference<M>::type;
246
+ insert(index, std::make_shared<Type>(std::forward<M>(module)));
247
+ }
248
+
249
+ private:
250
+ template <typename Head, typename... Tail>
251
+ void push_back_var(Head&& head, Tail&&... tail) {
252
+ push_back(std::forward<Head>(head));
253
+ // Recursively calls this method, until the parameter pack only has this
254
+ // entry left. Then calls `push_back()` a final time (above).
255
+ push_back_var(std::forward<Tail>(tail)...);
256
+ }
257
+
258
+ /// The base case, when the list of modules is empty.
259
+ void push_back_var() {}
260
+
261
+ // Box the modules in shared_ptrs to give ModuleList reference semantics,
262
+ // like the rest of the API, and to store arbitrary `Module` subclasses
263
+ // polymorphically.
264
+ std::vector<std::shared_ptr<Module>> modules_;
265
+ };
266
+
267
+ /// A `ModuleHolder` subclass for `ModuleListImpl`.
268
+ /// See the documentation for `ModuleListImpl` class to learn what methods it
269
+ /// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
270
+ /// module storage semantics.
271
+ TORCH_MODULE(ModuleList);
272
+
273
+ } // namespace nn
274
+ } // namespace torch
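
A minimal sketch exercising the container API above (variadic construction, `insert()`, typed `at()`), assuming a standard libtorch setup:

```
#include <torch/torch.h>

#include <iostream>

int main() {
  torch::nn::ModuleList mlist(
      torch::nn::Linear(3, 4),
      torch::nn::Dropout(0.5));

  // insert() shifts the later entries and re-registers them by index.
  mlist->insert(1, torch::nn::ReLU());

  // Typed access; throws if the index is out of range or the cast fails.
  auto& linear = mlist->at<torch::nn::LinearImpl>(0);
  auto out = linear.forward(torch::ones({2, 3}));

  for (const auto& module : *mlist) {
    module->pretty_print(std::cout);
    std::cout << '\n';
  }
}
```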
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/named_any.h ADDED
@@ -0,0 +1,94 @@
1
+ #pragma once
2
+
3
+ #include <torch/detail/static.h>
4
+ #include <torch/nn/module.h>
5
+ #include <torch/nn/modules/container/any.h>
6
+ #include <torch/nn/pimpl.h>
7
+ #include <torch/types.h>
8
+
9
+ #include <torch/csrc/autograd/variable.h>
10
+ #include <torch/csrc/utils/variadic.h>
11
+
12
+ #include <ATen/Device.h>
13
+
14
+ #include <initializer_list>
15
+ #include <memory>
16
+ #include <type_traits>
17
+ #include <typeinfo>
18
+ #include <utility>
19
+ #include <vector>
20
+
21
+ namespace torch {
22
+ namespace nn {
23
+
24
+ /// Stores a type-erased `Module` together with its name.
25
+ ///
26
+ /// The `NamedAnyModule` class enables the following API for constructing
27
+ /// `nn::Sequential` with named submodules:
28
+ /// \rst
29
+ /// .. code-block:: cpp
30
+ ///
31
+ /// struct M : torch::nn::Module {
32
+ /// explicit M(int value_) : value(value_) {}
33
+ /// int value;
34
+ /// int forward() {
35
+ /// return value;
36
+ /// }
37
+ /// };
38
+ ///
39
+ /// Sequential sequential({
40
+ ///     {"m1", std::make_shared<M>(1)},  // shared pointer to `Module` is supported
41
+ ///     {std::string("m2"), M(2)},       // `Module` is supported
42
+ /// {"linear1", Linear(10, 3)} // `ModuleHolder` is supported
43
+ /// });
44
+ /// \endrst
45
+ class NamedAnyModule {
46
+ public:
47
+ /// Creates a `NamedAnyModule` from a (boxed) `Module`.
48
+ template <typename ModuleType>
49
+ NamedAnyModule(std::string name, std::shared_ptr<ModuleType> module_ptr)
50
+ : NamedAnyModule(std::move(name), AnyModule(std::move(module_ptr))) {}
51
+
52
+ /// Creates a `NamedAnyModule` from a `Module`, moving or copying it
53
+ /// into a `shared_ptr` internally.
54
+ // NOTE: We need to use `std::remove_reference<M>::type` to get rid of
55
+ // any reference components for make_unique.
56
+ template <typename M, typename = torch::detail::enable_if_module_t<M>>
57
+ NamedAnyModule(std::string name, M&& module)
58
+ : NamedAnyModule(
59
+ std::move(name),
60
+ std::make_shared<typename std::remove_reference<M>::type>(
61
+ std::forward<M>(module))) {}
62
+
63
+ /// Creates a `NamedAnyModule` from a `Module` that is unwrapped from
64
+ /// a `ModuleHolder`.
65
+ template <typename M>
66
+ NamedAnyModule(std::string name, const ModuleHolder<M>& module_holder)
67
+ : NamedAnyModule(std::move(name), module_holder.ptr()) {}
68
+
69
+ /// Creates a `NamedAnyModule` from a type-erased `AnyModule`.
70
+ NamedAnyModule(std::string name, AnyModule any_module)
71
+ : name_(std::move(name)), module_(std::move(any_module)) {}
72
+
73
+ /// Returns a reference to the name.
74
+ const std::string& name() const noexcept {
75
+ return name_;
76
+ }
77
+
78
+ /// Returns a reference to the module.
79
+ AnyModule& module() noexcept {
80
+ return module_;
81
+ }
82
+
83
+ /// Returns a const reference to the module.
84
+ const AnyModule& module() const noexcept {
85
+ return module_;
86
+ }
87
+
88
+ private:
89
+ std::string name_;
90
+ AnyModule module_;
91
+ };
92
+
93
+ } // namespace nn
94
+ } // namespace torch
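
A minimal sketch of the named-submodule construction that `NamedAnyModule` enables; `Linear` and `Functional` are stock libtorch modules used here for illustration:

```
#include <torch/torch.h>

int main() {
  // Each {name, module} pair becomes a NamedAnyModule, so the submodules
  // are registered under the given names instead of numeric indices.
  torch::nn::Sequential seq({
      {"fc", torch::nn::Linear(10, 3)},
      {"act", torch::nn::Functional(torch::relu)}});

  auto out = seq->forward(torch::ones({2, 10}));
}
```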
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/parameterdict.h ADDED
@@ -0,0 +1,148 @@
1
+ #pragma once
2
+
3
+ #include <torch/nn/cloneable.h>
4
+ #include <torch/nn/pimpl.h>
5
+ #include <torch/ordered_dict.h>
6
+ #include <utility>
7
+ #include <vector>
8
+
9
+ namespace torch {
10
+ namespace nn {
11
+
12
+ class ParameterDictImpl : public Cloneable<ParameterDictImpl> {
13
+ public:
14
+ using Iterator = OrderedDict<std::string, Tensor>::Iterator;
15
+ using ConstIterator = OrderedDict<std::string, Tensor>::ConstIterator;
16
+
17
+ ParameterDictImpl() = default;
18
+
19
+ explicit ParameterDictImpl(
20
+ const torch::OrderedDict<std::string, torch::Tensor>& params) {
21
+ parameters_ = params;
22
+ }
23
+
24
+ /// `reset()` is empty for `ParameterDict`, since it does not have
25
+ /// parameters of its own.
26
+ void reset() override {}
27
+
28
+ /// Pretty prints the `ParameterDict` module into the given `stream`.
29
+ void pretty_print(std::ostream& stream) const override {
30
+ stream << "torch::nn::ParameterDict(" << std::endl;
31
+ for (const auto& pair : parameters_) {
32
+ stream << "(" << pair.key() << ")"
33
+ << ": Parameter containing: [" << pair.value().scalar_type()
34
+ << " of size " << pair.value().sizes() << "]";
36
+ stream << std::endl;
37
+ }
38
+ stream << ")";
39
+ }
40
+
41
+ /// Inserts the parameter along with the key into the ParameterDict.
42
+ /// The parameter keeps the `requires_grad` state it was created with.
43
+ Tensor& insert(std::string key, Tensor param) {
44
+ bool requires_grad = param.requires_grad();
45
+ return register_parameter(std::move(key), std::move(param), requires_grad);
46
+ }
47
+
48
+ /// Removes the key from the ParameterDict and returns its value; throws an
49
+ /// exception if the key is not contained. Check contains(key) beforehand
50
+ /// for non-throwing access.
51
+ Tensor pop(const std::string& key) {
52
+ torch::Tensor v = parameters_[key];
53
+ parameters_.erase(key);
54
+ return v;
55
+ }
56
+
57
+ /// Return the keys in the dict
58
+ ::std::vector<std::string> keys() const {
59
+ return parameters_.keys();
60
+ }
61
+
62
+ /// Return the values in the dict
63
+ ::std::vector<torch::Tensor> values() const {
64
+ return parameters_.values();
65
+ }
66
+
67
+ /// Return an iterator to the start of ParameterDict
68
+ Iterator begin() {
69
+ return parameters_.begin();
70
+ }
71
+
72
+ /// Return a const iterator to the start of ParameterDict
73
+ ConstIterator begin() const {
74
+ return parameters_.begin();
75
+ }
76
+
77
+ /// Return an iterator to the end of ParameterDict
78
+ Iterator end() {
79
+ return parameters_.end();
80
+ }
81
+
82
+ /// Return a const iterator to the end of ParameterDict
83
+ ConstIterator end() const {
84
+ return parameters_.end();
85
+ }
86
+
87
+ /// Return the number of items currently stored in the ParameterDict
88
+ size_t size() const noexcept {
89
+ return parameters_.size();
90
+ }
91
+
92
+ /// Return true if the ParameterDict is empty, otherwise return false
93
+ bool empty() const noexcept {
94
+ return parameters_.is_empty();
95
+ }
96
+
97
+ /// Update the ParameterDict with the key-value pairs from
98
+ /// another ParameterDict, overwriting existing keys
99
+ template <typename Container>
100
+ void update(const Container& container) {
101
+ for (auto& item : container) {
102
+ parameters_[item.key()] = item.value();
103
+ }
104
+ }
105
+
106
+ /// Remove all parameters in the ParameterDict
107
+ void clear() {
108
+ parameters_.clear();
109
+ }
110
+
111
+ /// Checks whether a parameter with the given key is in the ParameterDict
112
+ bool contains(const std::string& key) const noexcept {
113
+ return parameters_.contains(key);
114
+ }
115
+
116
+ /// Returns the value associated with the given `key`. Throws an exception if
117
+ /// no such key is stored in the `ParameterDict`. Check contains(key)
118
+ /// beforehand for a non-throwing way of access.
119
+ const Tensor& get(const std::string& key) const {
120
+ return parameters_[key];
121
+ }
122
+
123
+ /// Returns the value associated with the given `key`. Throws an exception if
124
+ /// no such key is stored in the `ParameterDict`. Check contains(key)
125
+ /// beforehand for a non-throwing way of access.
126
+ Tensor& get(const std::string& key) {
127
+ return parameters_[key];
128
+ }
129
+
130
+ /// Returns the value associated with the given `key`. Throws an exception if
131
+ /// no such key is stored in the `ParameterDict`. Check contains(key)
132
+ /// beforehand for a non-throwing way of access.
133
+ Tensor& operator[](const std::string& key) {
134
+ return parameters_[key];
135
+ }
136
+
137
+ /// Returns the value associated with the given `key`. Throws an exception if
138
+ /// no such key is stored in the `ParameterDict`. Check contains(key)
139
+ /// beforehand for a non-throwing way of access.
140
+ const Tensor& operator[](const std::string& key) const {
141
+ return parameters_[key];
142
+ }
143
+ };
144
+
145
+ TORCH_MODULE(ParameterDict);
146
+
147
+ } // namespace nn
148
+ } // namespace torch
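
A minimal usage sketch of the `ParameterDict` API above; the shapes and key names are arbitrary:

```
#include <torch/torch.h>

#include <iostream>

int main() {
  torch::nn::ParameterDict dict;
  dict->insert("weight", torch::randn({4, 4}, torch::requires_grad()));
  dict->insert("bias", torch::zeros({4}));

  // get()/operator[] throw on a missing key; check contains() beforehand.
  if (dict->contains("weight")) {
    torch::NoGradGuard no_grad;
    dict->get("weight").mul_(0.5);  // edits the registered parameter in place
  }

  torch::Tensor bias = dict->pop("bias");  // removed from the dict
  std::cout << dict->size() << '\n';       // prints 1
}
```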
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/parameterlist.h ADDED
@@ -0,0 +1,169 @@
1
+ #pragma once
2
+
3
+ #include <torch/nn/cloneable.h>
4
+ #include <torch/nn/module.h>
5
+
6
+ #include <vector>
7
+
8
+ namespace torch {
9
+ namespace nn {
10
+ class ParameterListImpl : public Cloneable<ParameterListImpl> {
11
+ public:
12
+ using Iterator = typename std::vector<
13
+ OrderedDict<std::string, torch::Tensor>::Item>::iterator;
14
+ using ConstIterator = typename std::vector<
15
+ OrderedDict<std::string, torch::Tensor>::Item>::const_iterator;
16
+
17
+ ParameterListImpl() = default;
18
+
19
+ /// Constructs the `ParameterList` from a variadic list of tensors.
20
+ template <typename... Tensors>
21
+ explicit ParameterListImpl(Tensors&&... params) {
22
+ parameters_.reserve(sizeof...(Tensors));
23
+ push_back_var(std::forward<Tensors>(params)...);
24
+ }
25
+
26
+ template <typename... Tensors>
27
+ explicit ParameterListImpl(const Tensors&... params) {
28
+ parameters_.reserve(sizeof...(Tensors));
29
+ push_back_var(std::forward<Tensors>(params)...);
30
+ }
31
+
32
+ /// `reset()` is empty for `ParameterList`, since it does not have parameters
33
+ /// of its own.
34
+ void reset() override {}
35
+
36
+ /// Pretty prints the `ParameterList` module into the given `stream`.
37
+ void pretty_print(std::ostream& stream) const override {
38
+ stream << "torch::nn::ParameterList(" << std::endl;
39
+ for (const auto& pair : parameters_) {
40
+ stream << "(" << pair.key() << ")"
41
+ << ": Parameter containing: [" << pair.value().scalar_type()
42
+ << " of size " << pair.value().sizes() << "]";
44
+ stream << std::endl;
45
+ }
46
+ stream << ")";
47
+ }
48
+
49
+ /// Pushes the given parameter to the end of the list.
50
+ void append(torch::Tensor&& param) {
51
+ bool requires_grad = param.requires_grad();
52
+ register_parameter(
53
+ std::to_string(parameters_.size()), std::move(param), requires_grad);
54
+ }
55
+
56
+ /// Pushes the given parameter to the end of the list.
57
+ void append(const torch::Tensor& param) {
58
+ bool requires_grad = param.requires_grad();
59
+ register_parameter(
60
+ std::to_string(parameters_.size()), param, requires_grad);
61
+ }
62
+
63
+ /// Pushes the given parameter to the end of the list.
64
+ /// The key of the pair is discarded; only the value
65
+ /// is added to the `ParameterList`.
66
+ void append(const OrderedDict<std::string, torch::Tensor>::Item& pair) {
67
+ register_parameter(
68
+ std::to_string(parameters_.size()),
69
+ pair.value(),
70
+ pair.value().requires_grad());
71
+ }
72
+
73
+ /// extend parameters from a container to the end of the list
74
+ template <typename Container>
75
+ void extend(const Container& container) {
76
+ for (const auto& param : container) {
77
+ append(param);
78
+ }
79
+ }
80
+
81
+ /// Returns an iterator to the start of the ParameterList
82
+ /// the iterator returned will be type of `OrderedDict<std::string,
83
+ /// torch::Tensor>::Item`
84
+ Iterator begin() {
85
+ return parameters_.begin();
86
+ }
87
+
88
+ /// Returns a const iterator to the start of the ParameterList
89
+ /// the iterator returned will be type of `OrderedDict<std::string,
90
+ /// torch::Tensor>::Item`
91
+ ConstIterator begin() const {
92
+ return parameters_.begin();
93
+ }
94
+
95
+ /// Returns an iterator to the end of the ParameterList
96
+ /// the iterator returned will be type of `OrderedDict<std::string,
97
+ /// torch::Tensor>::Item`
98
+ Iterator end() {
99
+ return parameters_.end();
100
+ }
101
+
102
+ /// Returns a const iterator to the end of the ParameterList
103
+ /// the iterator returned will be type of `OrderedDict<std::string,
104
+ /// torch::Tensor>::Item`
105
+ ConstIterator end() const {
106
+ return parameters_.end();
107
+ }
108
+
109
+ /// Returns the parameter at the given index. Throws an exception if the
110
+ /// index is out of range. Check the index against size() beforehand for
111
+ /// a non-throwing way of access.
112
+ at::Tensor& at(size_t idx) {
113
+ TORCH_CHECK(idx < size(), "Index out of range");
114
+ return parameters_[std::to_string(idx)];
115
+ }
116
+
117
+ /// Returns the parameter at the given index. Throws an exception if the
118
+ /// index is out of range. Check the index against size() beforehand for
119
+ /// a non-throwing way of access.
120
+ const at::Tensor& at(size_t idx) const {
121
+ TORCH_CHECK(idx < size(), "Index out of range");
122
+ return parameters_[std::to_string(idx)];
123
+ }
124
+
125
+ /// Returns the parameter at the given index. Throws an exception if the
126
+ /// index is out of range. Check the index against size() beforehand for
127
+ /// a non-throwing way of access.
128
+ at::Tensor& operator[](size_t idx) {
129
+ return at(idx);
130
+ }
131
+
132
+ /// Returns the parameter at the given index. Throws an exception if the
133
+ /// index is out of range. Check the index against size() beforehand for
134
+ /// a non-throwing way of access.
135
+ const at::Tensor& operator[](size_t idx) const {
136
+ return at(idx);
137
+ }
138
+
139
+ /// Return the size of the ParameterList
140
+ size_t size() const noexcept {
141
+ return parameters_.size();
142
+ }
143
+ /// True if the ParameterList is empty
144
+ bool is_empty() const noexcept {
145
+ return parameters_.is_empty();
146
+ }
147
+
148
+ /// Overload the +=, so that two ParameterList could be incrementally added
149
+ template <typename Container>
150
+ Container& operator+=(const Container& other) {
151
+ extend(other);
152
+ return *this;
153
+ }
154
+
155
+ private:
156
+ template <typename Head, typename... Tail>
157
+ void push_back_var(Head&& head, Tail&&... tail) {
158
+ append(std::forward<Head>(head));
159
+ // Recursively calls this method until the parameter pack only has this
160
+ // entry left. Then calls `append()` a final time (above).
161
+ push_back_var(std::forward<Tail>(tail)...);
162
+ }
163
+
164
+ /// The base case, when the list of parameters is empty.
165
+ void push_back_var() {}
166
+ };
167
+ TORCH_MODULE(ParameterList);
168
+ } // namespace nn
169
+ } // namespace torch
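
A minimal sketch of the `ParameterList` API above; parameters are registered under their index, so iteration yields ("0", tensor), ("1", tensor), and so on:

```
#include <torch/torch.h>

#include <iostream>

int main() {
  torch::nn::ParameterList plist(
      torch::randn({2, 2}, torch::requires_grad()),
      torch::zeros({2}));

  plist->append(torch::ones({3}));

  // at()/operator[] throw if the index is out of range.
  const torch::Tensor& first = plist->at(0);
  std::cout << first.sizes() << '\n';  // [2, 2]

  // Iteration yields OrderedDict<std::string, torch::Tensor>::Item values.
  for (const auto& pair : *plist) {
    std::cout << pair.key() << ": " << pair.value().sizes() << '\n';
  }
}
```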
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/sequential.h ADDED
@@ -0,0 +1,388 @@
1
+ #pragma once
2
+
3
+ #include <torch/detail/static.h>
4
+ #include <torch/nn/cloneable.h>
5
+ #include <torch/nn/module.h>
6
+ #include <torch/nn/modules/container/any.h>
7
+ #include <torch/nn/modules/container/named_any.h>
8
+ #include <torch/nn/pimpl.h>
9
+ #include <torch/types.h>
10
+
11
+ #include <c10/util/Exception.h>
12
+
13
+ #include <cstdint>
14
+ #include <memory>
15
+ #include <ostream>
16
+ #include <string>
17
+ #include <type_traits>
18
+ #include <utility>
19
+ #include <vector>
20
+
21
+ namespace torch {
22
+ namespace nn {
23
+
24
+ /// A list of `Module`s that acts as a `Module` itself.
25
+ ///
26
+ /// A `Sequential` is fundamentally a list of `Module`s, each with a `forward()`
27
+ /// method. `Sequential` provides a `forward()` method of its own, which accepts
28
+ /// any input and forwards it to the first module it stores. It then "chains"
29
+ /// outputs to inputs sequentially for each subsequent module, finally returning
30
+ /// the output of the last module. For example:
31
+ ///
32
+ /// \rst
33
+ /// .. code-block:: cpp
34
+ ///
35
+ /// torch::nn::Sequential seq(
36
+ /// torch::nn::Linear(3, 4),
37
+ /// torch::nn::BatchNorm1d(4),
38
+ /// torch::nn::Dropout(0.5)
39
+ /// );
40
+ ///
41
+ /// auto output = seq->forward(torch::ones(3));
42
+ ///
43
+ /// \endrst
44
+ ///
45
+ /// This can conceptually be thought of as the following loop (using Python as
46
+ /// pseudocode):
47
+ ///
48
+ /// \rst
49
+ /// .. code-block:: python
50
+ ///
51
+ /// def forward(sequential, input):
52
+ /// for module in sequential:
53
+ /// input = module(input)
54
+ /// return input
55
+ ///
56
+ /// \endrst
57
+ ///
58
+ /// Why should you use `Sequential` instead of a simple `std::vector`? The value
59
+ /// a `Sequential` provides over manually calling a sequence of modules is that
60
+ /// it allows treating the whole container *as a single module*, such that
61
+ /// performing a transformation on the `Sequential` applies to each of the
62
+ /// modules it stores (which are each a registered submodule of the
63
+ /// `Sequential`). For example, calling
64
+ /// `.to(torch::kCUDA)` on a `Sequential` will move each module in the list to
65
+ /// CUDA memory. For example:
66
+ ///
67
+ /// \rst
68
+ /// .. code-block:: cpp
69
+ ///
70
+ /// torch::nn::Sequential seq(
71
+ /// torch::nn::Linear(3, 4),
72
+ /// torch::nn::BatchNorm1d(4),
73
+ /// torch::nn::Dropout(0.5)
74
+ /// );
75
+ ///
76
+ /// // Convert all modules to CUDA.
77
+ /// seq->to(torch::kCUDA);
78
+ ///
79
+ /// \endrst
80
+ ///
81
+ /// Finally, `Sequential` provides a lightweight container API, such as allowing
82
+ /// iteration over submodules, positional access, adding a new module after
83
+ /// construction via `push_back`, as well as joining two `Sequential`s via
84
+ /// `extend`.
85
+ ///
86
+ /// \rst
87
+ /// .. attention::
88
+ /// One current limitation of `Sequential` is that all except the first module
89
+ /// must accept a single argument. If your modules need to take multiple
90
+ /// arguments, you should define them to take and return tuples.
91
+ /// \endrst
92
+ class SequentialImpl : public Cloneable<SequentialImpl> {
93
+ public:
94
+ using Iterator = std::vector<AnyModule>::iterator;
95
+ using ConstIterator = std::vector<AnyModule>::const_iterator;
96
+
97
+ SequentialImpl() = default;
98
+
99
+ /// Constructs the `Sequential` from a variadic list of modules.
100
+ template <typename... Modules>
101
+ explicit SequentialImpl(Modules&&... modules) {
102
+ modules_.reserve(sizeof...(Modules));
103
+ push_back(std::forward<Modules>(modules)...);
104
+ }
105
+
106
+ /// Constructs the `Sequential` from an `OrderedDict` of named `AnyModule`s.
107
+ explicit SequentialImpl(
108
+ torch::OrderedDict<std::string, AnyModule>&& ordered_dict) {
109
+ modules_.reserve(ordered_dict.size());
110
+ for (auto& item : ordered_dict) {
111
+ push_back(item.key(), std::move(item.value()));
112
+ }
113
+ }
114
+
115
+ /// Constructs the `Sequential` from a braced-init-list of named `AnyModule`s.
116
+ /// It enables the following use case:
117
+ /// `Sequential sequential({{"m1", M(1)}, {"m2", M(2)}})`
118
+ explicit SequentialImpl(std::initializer_list<NamedAnyModule> named_modules) {
119
+ modules_.reserve(named_modules.size());
120
+ for (const auto& named_module : named_modules) {
121
+ push_back(named_module.name(), named_module.module());
122
+ }
123
+ }
124
+
125
+ /// Special cloning function for `Sequential` because it does not use
126
+ /// `reset()`.
127
+ std::shared_ptr<Module> clone(
128
+ const std::optional<Device>& device = std::nullopt) const override {
129
+ auto clone = std::make_shared<SequentialImpl>();
130
+ for (const auto& module : modules_) {
131
+ clone->push_back(module.clone(device));
132
+ }
133
+ return clone;
134
+ }
135
+
136
+ /// `reset()` is empty for `Sequential`, since it does not have parameters of
137
+ /// its own.
138
+ void reset() override {}
139
+
140
+ /// Pretty prints the `Sequential` module into the given `stream`.
141
+ void pretty_print(std::ostream& stream) const override {
142
+ stream << "torch::nn::Sequential";
143
+ }
144
+
145
+ /// Feeds `inputs` to the first module and then chains outputs to inputs,
146
+ /// returning the last output.
147
+ ///
148
+ /// Conceptually the following loop in Python:
149
+ ///
150
+ /// \rst
151
+ /// .. code-block:: python
152
+ ///
153
+ /// def forward(sequential, input):
154
+ /// for module in sequential:
155
+ /// input = module(input)
156
+ /// return input
157
+ ///
158
+ /// \endrst
159
+ ///
160
+ /// The return type is taken as the first template parameter. It defaults to
161
+ /// `Tensor`. If the last module in the `Sequential` returns another type `T`,
162
+ /// you should call `forward<T>(inputs)` instead of just `forward(inputs)`:
163
+ ///
164
+ /// \rst
165
+ /// .. code-block:: cpp
166
+ ///
167
+ /// torch::Tensor tensor = sequential1->forward(inputs);
168
+ /// int integer = sequential2->forward<int>(inputs);
169
+ /// float value = sequential3->forward<float>(inputs);
170
+ ///
171
+ /// \endrst
172
+ template <typename ReturnType = Tensor, typename... InputTypes>
173
+ ReturnType forward(InputTypes&&... inputs) {
174
+ TORCH_CHECK(!is_empty(), "Cannot call forward() on an empty Sequential");
175
+
176
+ auto iterator = modules_.begin();
177
+ auto input = iterator->any_forward(std::forward<InputTypes>(inputs)...);
178
+
179
+ for (++iterator; iterator != modules_.end(); ++iterator) {
180
+ input = iterator->any_forward(std::move(input));
181
+ }
182
+
183
+ // Check the return value and give a nice error message if the requested
184
+ // return type was incorrect.
185
+ if (auto* return_value = input.template try_get<ReturnType>()) {
186
+ return std::move(*return_value);
187
+ }
188
+ AT_ERROR(
189
+ "The type of the return value is ",
190
+ c10::demangle(input.type_info().name()),
191
+ ", but you asked for type ",
192
+ c10::demangle(typeid(ReturnType).name()));
193
+ }
194
+
195
+ /// Adds a new (boxed) `Module` to the `Sequential` container.
196
+ template <typename ModuleType>
197
+ void push_back(std::shared_ptr<ModuleType> module_ptr) {
198
+ push_back(std::to_string(modules_.size()), std::move(module_ptr));
199
+ }
200
+
201
+ /// Adds a new named (boxed) `Module` to the `Sequential` container.
202
+ template <typename ModuleType>
203
+ void push_back(std::string name, std::shared_ptr<ModuleType> module_ptr) {
204
+ push_back(std::move(name), AnyModule(std::move(module_ptr)));
205
+ }
206
+
207
+ /// Adds a new `Module` to the `Sequential` container, moving or copying it
208
+ /// into a `shared_ptr` internally. This method allows passing value types,
209
+ /// and letting the container deal with the boxing. This means you can write
210
+ /// `Sequential(Module(3, 4))` instead of
211
+ /// `Sequential(std::make_shared<Module>(3, 4))`.
212
+ template <typename M, typename = torch::detail::enable_if_module_t<M>>
213
+ void push_back(M&& module) {
214
+ push_back(std::to_string(modules_.size()), std::forward<M>(module));
215
+ }
216
+
217
+ /// Adds a new named `Module` to the `Sequential` container, moving or copying
218
+ /// it into a `shared_ptr` internally. This method allows passing value types,
219
+ /// and letting the container deal with the boxing.
220
+ template <typename M, typename = torch::detail::enable_if_module_t<M>>
221
+ void push_back(std::string name, M&& module) {
222
+ using Type = typename std::remove_reference_t<M>;
223
+ push_back(std::move(name), std::make_shared<Type>(std::forward<M>(module)));
224
+ }
225
+
226
+ /// Unwraps the contained module of a `ModuleHolder` and adds it to the
227
+ /// `Sequential`.
228
+ template <typename M>
229
+ void push_back(const ModuleHolder<M>& module_holder) {
230
+ push_back(std::to_string(modules_.size()), module_holder);
231
+ }
232
+
233
+ /// Unwraps the contained named module of a `ModuleHolder` and adds it to the
234
+ /// `Sequential`.
235
+ template <typename M>
236
+ void push_back(std::string name, const ModuleHolder<M>& module_holder) {
237
+ push_back(std::move(name), module_holder.ptr());
238
+ }
239
+
240
+ /// Iterates over the container and calls `push_back()` on each value.
241
+ template <typename Container>
242
+ void extend(const Container& container) {
243
+ for (const auto& module : container) {
244
+ push_back(module);
245
+ }
246
+ }
247
+
248
+ /// Adds a type-erased `AnyModule` to the `Sequential`.
249
+ void push_back(AnyModule any_module) {
250
+ push_back(std::to_string(modules_.size()), std::move(any_module));
251
+ }
252
+
253
+ void push_back(std::string name, AnyModule any_module) {
254
+ modules_.push_back(std::move(any_module));
255
+ const auto index = modules_.size() - 1;
256
+ register_module(std::move(name), modules_[index].ptr());
257
+ }
258
+
259
+ /// Returns an iterator to the start of the `Sequential`.
260
+ Iterator begin() {
261
+ return modules_.begin();
262
+ }
263
+
264
+ /// Returns a const iterator to the start of the `Sequential`.
265
+ ConstIterator begin() const {
266
+ return modules_.begin();
267
+ }
268
+
269
+ /// Returns an iterator to the end of the `Sequential`.
270
+ Iterator end() {
271
+ return modules_.end();
272
+ }
273
+
274
+ /// Returns a const iterator to the end of the `Sequential`.
275
+ ConstIterator end() const {
276
+ return modules_.end();
277
+ }
278
+
279
+ /// Attempts to return the module at the given index as the requested type.
280
+ /// Throws an exception if the index is out of bounds or the types do not
281
+ /// match.
282
+ template <typename T>
283
+ T& at(size_t index) {
284
+ static_assert(
285
+ torch::detail::is_module<T>::value,
286
+ "Can only call Sequential::at with an nn::Module type");
287
+ TORCH_CHECK(index < size(), "Index out of range");
288
+ return modules_[index].get<T>();
289
+ }
290
+
291
+ /// Attempts to return the module at the given index as the requested type.
292
+ /// Throws an exception if the index is out of bounds or the types do not
293
+ /// match.
294
+ template <typename T>
295
+ const T& at(size_t index) const {
296
+ static_assert(
297
+ torch::detail::is_module<T>::value,
298
+ "Can only call Sequential::at with an nn::Module type");
299
+ TORCH_CHECK(index < size(), "Index out of range");
300
+ return modules_[index].get<T>();
301
+ }
302
+
303
+ /// Attempts to return a `std::shared_ptr` whose dynamic type is that of the
304
+ /// underlying module at the given index. Throws an exception if the index is
305
+ /// out of bounds.
306
+ std::shared_ptr<Module> ptr(size_t index) const {
307
+ TORCH_CHECK(index < size(), "Index out of range");
308
+ return modules_[index].ptr();
309
+ }
310
+
311
+ /// Attempts to return a `std::shared_ptr` whose type is the one provided.
312
+ /// Throws an exception if the index is out of bounds or the types do not
313
+ /// match.
314
+ template <typename T>
315
+ std::shared_ptr<T> ptr(size_t index) const {
316
+ static_assert(
317
+ torch::detail::is_module<T>::value,
318
+ "Can only call Sequential::ptr with an nn::Module type");
319
+ TORCH_CHECK(index < size(), "Index out of range");
320
+ return modules_[index].ptr<T>();
321
+ }
322
+
323
+ /// Like `ptr(index)`.
324
+ std::shared_ptr<Module> operator[](size_t index) const {
325
+ // This is the only method we can call without a type.
326
+ return ptr(index);
327
+ }
328
+
329
+ /// The current size of the `Sequential` container.
330
+ size_t size() const noexcept {
331
+ return modules_.size();
332
+ }
333
+
334
+ /// True if there are no modules in the `Sequential`.
335
+ bool is_empty() const noexcept {
336
+ return size() == 0;
337
+ }
338
+
339
+ private:
340
+ /// Takes a First *and* Second parameter, to avoid ambiguity when a parameter
341
+ /// pack has only one type, in which case the template would be preferred,
342
+ /// even if the other `push_back` functions are better fits (e.g. `unique_ptr`
343
+ /// -> `shared_ptr` overload).
344
+ /// NOTE: We explicitly avoid matching this template with
345
+ /// `push_back(std::string("name"), module)` or `push_back("name", module)`,
346
+ /// since they should be handled by their respective `push_back` functions.
347
+ template <
348
+ typename First,
349
+ typename Second,
350
+ typename... Rest,
351
+ typename = std::enable_if_t<
352
+ !std::is_same_v<First, std::string> &&
353
+ // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
354
+ !std::is_same_v<std::decay_t<First>, std::decay_t<const char (&)[]>>>>
355
+ void push_back(First&& first, Second&& second, Rest&&... rest) {
356
+ push_back(std::forward<First>(first));
357
+ // Recursively calls this method, until the parameter pack only has this
358
+ // entry left. Then calls `push_back()` a final time (above).
359
+ push_back(std::forward<Second>(second), std::forward<Rest>(rest)...);
360
+ }
361
+
362
+ /// The base case, when the list of modules is empty.
363
+ void push_back() {}
364
+
365
+ // Box the AnyModules to give Sequential reference semantics, like the rest of
366
+ // the API. Note that this is not required otherwise, this could just be a
367
+ // `vector<AnyModule>`.
368
+ std::vector<AnyModule> modules_;
369
+ };
370
+
371
+ /// A `ModuleHolder` subclass for `SequentialImpl`.
372
+ /// See the documentation for `SequentialImpl` class to learn what methods it
373
+ /// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
374
+ /// module storage semantics.
375
+ class Sequential : public torch::nn::ModuleHolder<SequentialImpl> {
376
+ public:
377
+ using torch::nn::ModuleHolder<SequentialImpl>::ModuleHolder;
378
+
379
+ Sequential() : ModuleHolder() {}
380
+
381
+ /// Constructs the `Sequential` from a braced-init-list of named `AnyModule`s.
382
+ /// It enables the following use case:
383
+ /// `Sequential sequential({{"m1", M(1)}, {"m2", M(2)}})`
384
+ Sequential(std::initializer_list<NamedAnyModule> named_modules)
385
+ : ModuleHolder(std::make_shared<SequentialImpl>(named_modules)) {}
386
+ };
387
+ } // namespace nn
388
+ } // namespace torch
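
A minimal sketch of `forward<T>()` with a non-Tensor return type; `M` is a hypothetical module in the same spirit as the one in the `NamedAnyModule` docs:

```
#include <torch/torch.h>

// Hypothetical module whose forward() takes and returns int.
struct MImpl : torch::nn::Module {
  explicit MImpl(int value_) : value(value_) {}
  int forward(int x) {
    return x + value;
  }
  int value;
};
TORCH_MODULE(M);

int main() {
  torch::nn::Sequential seq(M(1), M(2));
  // Tensor is only the default return type; ask for int explicitly here.
  int result = seq->forward<int>(0);  // (0 + 1) + 2 == 3
}
```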
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/dropout.h ADDED
@@ -0,0 +1,190 @@
1
+ #pragma once
2
+
3
+ #include <torch/nn/cloneable.h>
4
+ #include <torch/nn/options/dropout.h>
5
+ #include <torch/nn/pimpl.h>
6
+ #include <torch/types.h>
7
+
8
+ #include <torch/csrc/Export.h>
9
+
10
+ #include <cstddef>
11
+ #include <vector>
12
+
13
+ namespace torch {
14
+ namespace nn {
15
+
16
+ namespace detail {
17
+
18
+ template <typename Derived>
19
+ class _DropoutNd : public torch::nn::Cloneable<Derived> {
20
+ public:
21
+ _DropoutNd(double p) : _DropoutNd(DropoutOptions().p(p)) {}
22
+
23
+ explicit _DropoutNd(const DropoutOptions& options_ = {}) : options(options_) {
24
+ // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
25
+ reset();
26
+ }
27
+
28
+ void reset() override {
29
+ TORCH_CHECK(
30
+ options.p() >= 0. && options.p() <= 1.,
31
+ "dropout probability has to be between 0 and 1, but got ",
32
+ options.p());
33
+ }
34
+
35
+ /// The options with which this `Module` was constructed.
36
+ DropoutOptions options;
37
+ };
38
+
39
+ } // namespace detail
40
+
41
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Dropout ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
42
+
43
+ /// Applies dropout over a 1-D input.
44
+ /// See https://pytorch.org/docs/main/nn.html#torch.nn.Dropout to learn
45
+ /// about the exact behavior of this module.
46
+ ///
47
+ /// See the documentation for `torch::nn::DropoutOptions` class to learn what
48
+ /// constructor arguments are supported for this module.
49
+ ///
50
+ /// Example:
51
+ /// ```
52
+ /// Dropout model(DropoutOptions().p(0.42).inplace(true));
53
+ /// ```
54
+ class TORCH_API DropoutImpl : public detail::_DropoutNd<DropoutImpl> {
55
+ public:
56
+ using detail::_DropoutNd<DropoutImpl>::_DropoutNd;
57
+
58
+ Tensor forward(Tensor input);
59
+
60
+ /// Pretty prints the `Dropout` module into the given `stream`.
61
+ void pretty_print(std::ostream& stream) const override;
62
+ };
63
+
64
+ /// A `ModuleHolder` subclass for `DropoutImpl`.
65
+ /// See the documentation for `DropoutImpl` class to learn what methods it
66
+ /// provides, and examples of how to use `Dropout` with
67
+ /// `torch::nn::DropoutOptions`. See the documentation for `ModuleHolder` to
68
+ /// learn about PyTorch's module storage semantics.
69
+ TORCH_MODULE(Dropout);
70
+
71
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Dropout2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
72
+
73
+ /// Applies dropout over a 2-D input.
74
+ /// See https://pytorch.org/docs/main/nn.html#torch.nn.Dropout2d to learn
75
+ /// about the exact behavior of this module.
76
+ ///
77
+ /// See the documentation for `torch::nn::Dropout2dOptions` class to learn what
78
+ /// constructor arguments are supported for this module.
79
+ ///
80
+ /// Example:
81
+ /// ```
82
+ /// Dropout2d model(Dropout2dOptions().p(0.42).inplace(true));
83
+ /// ```
84
+ class TORCH_API Dropout2dImpl : public detail::_DropoutNd<Dropout2dImpl> {
85
+ public:
86
+ using detail::_DropoutNd<Dropout2dImpl>::_DropoutNd;
87
+
88
+ Tensor forward(Tensor input);
89
+
90
+ /// Pretty prints the `Dropout2d` module into the given `stream`.
91
+ void pretty_print(std::ostream& stream) const override;
92
+ };
93
+
94
+ /// A `ModuleHolder` subclass for `Dropout2dImpl`.
95
+ /// See the documentation for `Dropout2dImpl` class to learn what methods it
96
+ /// provides, and examples of how to use `Dropout2d` with
97
+ /// `torch::nn::Dropout2dOptions`. See the documentation for `ModuleHolder` to
98
+ /// learn about PyTorch's module storage semantics.
99
+ TORCH_MODULE(Dropout2d);
100
+
101
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Dropout3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
102
+
103
+ /// Applies dropout over a 3-D input.
104
+ /// See https://pytorch.org/docs/main/nn.html#torch.nn.Dropout3d to learn
105
+ /// about the exact behavior of this module.
106
+ ///
107
+ /// See the documentation for `torch::nn::Dropout3dOptions` class to learn what
108
+ /// constructor arguments are supported for this module.
109
+ ///
110
+ /// Example:
111
+ /// ```
112
+ /// Dropout3d model(Dropout3dOptions().p(0.42).inplace(true));
113
+ /// ```
114
+ class TORCH_API Dropout3dImpl : public detail::_DropoutNd<Dropout3dImpl> {
115
+ public:
116
+ using detail::_DropoutNd<Dropout3dImpl>::_DropoutNd;
117
+
118
+ Tensor forward(Tensor input);
119
+
120
+ /// Pretty prints the `Dropout3d` module into the given `stream`.
121
+ void pretty_print(std::ostream& stream) const override;
122
+ };
123
+
124
+ /// A `ModuleHolder` subclass for `Dropout3dImpl`.
125
+ /// See the documentation for `Dropout3dImpl` class to learn what methods it
126
+ /// provides, and examples of how to use `Dropout3d` with
127
+ /// `torch::nn::Dropout3dOptions`. See the documentation for `ModuleHolder` to
128
+ /// learn about PyTorch's module storage semantics.
129
+ TORCH_MODULE(Dropout3d);
130
+
131
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AlphaDropout ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
132
+
133
+ /// Applies Alpha Dropout over the input.
134
+ /// See https://pytorch.org/docs/main/nn.html#torch.nn.AlphaDropout to learn
135
+ /// about the exact behavior of this module.
136
+ ///
137
+ /// See the documentation for `torch::nn::AlphaDropoutOptions` class to learn
138
+ /// what constructor arguments are supported for this module.
139
+ ///
140
+ /// Example:
141
+ /// ```
142
+ /// AlphaDropout model(AlphaDropoutOptions(0.2).inplace(true));
143
+ /// ```
144
+ class TORCH_API AlphaDropoutImpl : public detail::_DropoutNd<AlphaDropoutImpl> {
145
+ public:
146
+ using detail::_DropoutNd<AlphaDropoutImpl>::_DropoutNd;
147
+
148
+ Tensor forward(const Tensor& input);
149
+
150
+ /// Pretty prints the `AlphaDropout` module into the given `stream`.
151
+ void pretty_print(std::ostream& stream) const override;
152
+ };
153
+
154
+ /// A `ModuleHolder` subclass for `AlphaDropoutImpl`.
155
+ /// See the documentation for `AlphaDropoutImpl` class to learn what methods it
156
+ /// provides, and examples of how to use `AlphaDropout` with
157
+ /// `torch::nn::AlphaDropoutOptions`. See the documentation for `ModuleHolder`
158
+ /// to learn about PyTorch's module storage semantics.
159
+ TORCH_MODULE(AlphaDropout);
160
+
161
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~ FeatureAlphaDropout ~~~~~~~~~~~~~~~~~~~~~~~~~~~
163
+
164
+ /// Applies the FeatureAlphaDropout function over the input.
+ /// See https://pytorch.org/docs/main/nn.html#torch.nn.FeatureAlphaDropout to
+ /// learn about the exact behavior of this module.
+ ///
+ /// See the documentation for `torch::nn::FeatureAlphaDropoutOptions` class to
165
+ /// learn what constructor arguments are supported for this module.
166
+ ///
167
+ /// Example:
168
+ /// ```
169
+ /// FeatureAlphaDropout model(FeatureAlphaDropoutOptions(0.2).inplace(true));
170
+ /// ```
171
+ class TORCH_API FeatureAlphaDropoutImpl
172
+ : public detail::_DropoutNd<FeatureAlphaDropoutImpl> {
173
+ public:
174
+ using detail::_DropoutNd<FeatureAlphaDropoutImpl>::_DropoutNd;
175
+
176
+ Tensor forward(const Tensor& input);
177
+
178
+ /// Pretty prints the `FeatureAlphaDropout` module into the given `stream`.
179
+ void pretty_print(std::ostream& stream) const override;
180
+ };
181
+
182
+ /// A `ModuleHolder` subclass for `FeatureAlphaDropoutImpl`.
183
+ /// See the documentation for `FeatureAlphaDropoutImpl` class to learn what
184
+ /// methods it provides, and examples of how to use `FeatureAlphaDropout` with
185
+ /// `torch::nn::FeatureAlphaDropoutOptions`. See the documentation for
186
+ /// `ModuleHolder` to learn about PyTorch's module storage semantics.
187
+ TORCH_MODULE(FeatureAlphaDropout);
188
+
189
+ } // namespace nn
190
+ } // namespace torch
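
A minimal sketch of the train/eval behaviour shared by the dropout modules above:

```
#include <torch/torch.h>

int main() {
  torch::nn::Dropout drop(torch::nn::DropoutOptions().p(0.42));

  drop->train();
  auto x = torch::ones({2, 8});
  auto y_train = drop->forward(x);  // ~42% of elements zeroed, rest rescaled

  drop->eval();
  auto y_eval = drop->forward(x);   // identity in eval mode: equals x
}
```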
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/fold.h ADDED
@@ -0,0 +1,87 @@
1
+ #pragma once
2
+
3
+ #include <torch/expanding_array.h>
4
+ #include <torch/nn/cloneable.h>
5
+ #include <torch/nn/functional/fold.h>
6
+ #include <torch/nn/options/fold.h>
7
+ #include <torch/nn/pimpl.h>
8
+ #include <torch/types.h>
9
+
10
+ namespace torch {
11
+ namespace nn {
12
+
13
+ /// Applies fold over a 3-D input.
14
+ /// See https://pytorch.org/docs/main/nn.html#torch.nn.Fold to learn about
15
+ /// the exact behavior of this module.
16
+ ///
17
+ /// See the documentation for `torch::nn::FoldOptions` class to learn what
18
+ /// constructor arguments are supported for this module.
19
+ ///
20
+ /// Example:
21
+ /// ```
22
+ /// Fold model(FoldOptions({8, 8}, {3, 3}).dilation(2).padding({2,
23
+ /// 1}).stride(2));
24
+ /// ```
25
+ class TORCH_API FoldImpl : public torch::nn::Cloneable<FoldImpl> {
26
+ public:
27
+ FoldImpl(ExpandingArray<2> output_size, ExpandingArray<2> kernel_size)
28
+ : FoldImpl(FoldOptions(output_size, kernel_size)) {}
29
+ explicit FoldImpl(const FoldOptions& options_);
30
+
31
+ void reset() override;
32
+
33
+ /// Pretty prints the `Fold` module into the given `stream`.
34
+ void pretty_print(std::ostream& stream) const override;
35
+
36
+ Tensor forward(const Tensor& input);
37
+
38
+ /// The options with which this `Module` was constructed.
39
+ FoldOptions options;
40
+ };
41
+
42
+ /// A `ModuleHolder` subclass for `FoldImpl`.
43
+ /// See the documentation for `FoldImpl` class to learn what methods it
44
+ /// provides, and examples of how to use `Fold` with `torch::nn::FoldOptions`.
45
+ /// See the documentation for `ModuleHolder` to learn about PyTorch's
46
+ /// module storage semantics.
47
+ TORCH_MODULE(Fold);
48
+
49
+ // ============================================================================
50
+
51
+ /// Applies unfold over a 4-D input.
52
+ /// See https://pytorch.org/docs/main/nn.html#torch.nn.Unfold to learn about
53
+ /// the exact behavior of this module.
54
+ ///
55
+ /// See the documentation for `torch::nn::UnfoldOptions` class to learn what
56
+ /// constructor arguments are supported for this module.
57
+ ///
58
+ /// Example:
59
+ /// ```
60
+ /// Unfold model(UnfoldOptions({2, 4}).dilation(2).padding({2, 1}).stride(2));
61
+ /// ```
62
+ class TORCH_API UnfoldImpl : public Cloneable<UnfoldImpl> {
63
+ public:
64
+ UnfoldImpl(ExpandingArray<2> kernel_size)
65
+ : UnfoldImpl(UnfoldOptions(kernel_size)) {}
66
+ explicit UnfoldImpl(const UnfoldOptions& options_);
67
+
68
+ void reset() override;
69
+
70
+ /// Pretty prints the `Unfold` module into the given `stream`.
71
+ void pretty_print(std::ostream& stream) const override;
72
+
73
+ Tensor forward(const Tensor& input);
74
+
75
+ /// The options with which this `Module` was constructed.
76
+ UnfoldOptions options;
77
+ };
78
+
79
+ /// A `ModuleHolder` subclass for `UnfoldImpl`.
80
+ /// See the documentation for `UnfoldImpl` class to learn what methods it
81
+ /// provides, and examples of how to use `Unfold` with
82
+ /// `torch::nn::UnfoldOptions`. See the documentation for `ModuleHolder` to
83
+ /// learn about PyTorch's module storage semantics.
84
+ TORCH_MODULE(Unfold);
85
+
86
+ } // namespace nn
87
+ } // namespace torch
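
A minimal round-trip sketch: `Unfold` extracts sliding local blocks and `Fold` re-assembles them onto the output size (overlapping values are summed):

```
#include <torch/torch.h>

int main() {
  auto input = torch::randn({1, 3, 8, 8});

  torch::nn::Unfold unfold(torch::nn::UnfoldOptions({3, 3}));
  auto blocks = unfold->forward(input);  // [1, 3*3*3, 36]

  torch::nn::Fold fold(torch::nn::FoldOptions({8, 8}, {3, 3}));
  auto output = fold->forward(blocks);   // [1, 3, 8, 8], overlaps summed
}
```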
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/instancenorm.h ADDED
@@ -0,0 +1,153 @@
1
+ #pragma once
2
+
3
+ #include <torch/nn/modules/batchnorm.h>
4
+ #include <torch/nn/options/instancenorm.h>
5
+
6
+ namespace torch {
7
+ namespace nn {
8
+
9
+ /// Base class for all (dimension-specialized) instance norm modules
10
+ template <size_t D, typename Derived>
11
+ class InstanceNormImpl
12
+ : public torch::nn::NormImplBase<D, Derived, InstanceNormOptions> {
13
+ private:
14
+ inline Tensor apply_instance_norm(const Tensor& input) {
15
+ return torch::nn::functional::detail::instance_norm(
16
+ input,
17
+ this->running_mean,
18
+ this->running_var,
19
+ this->weight,
20
+ this->bias,
21
+ this->is_training() || !this->options.track_running_stats(),
22
+ this->options.momentum(),
23
+ this->options.eps());
24
+ }
25
+
26
+ inline Tensor handle_no_batch_input(const Tensor& input) {
27
+ return this->apply_instance_norm(input.unsqueeze(0)).squeeze(0);
28
+ }
29
+
30
+ public:
31
+ using torch::nn::NormImplBase<D, Derived, InstanceNormOptions>::NormImplBase;
32
+
33
+ Tensor forward(const Tensor& input) {
34
+ this->_check_input_dim(input);
35
+
36
+ // For InstanceNorm1D, 2D is unbatched and 3D is batched
37
+ // For InstanceNorm2D, 3D is unbatched and 4D is batched
38
+ // For InstanceNorm3D, 4D is unbatched and 5D is batched
39
+ // check if input does not have a batch-dim
40
+ if (input.dim() == D + 1) {
41
+ return this->handle_no_batch_input(input);
42
+ }
43
+
44
+ return this->apply_instance_norm(input);
45
+ }
46
+
47
+ /// Pretty prints the `InstanceNorm{1,2,3}d` module into the given `stream`.
48
+ void pretty_print(std::ostream& stream) const override {
49
+ stream << std::boolalpha << "torch::nn::InstanceNorm" << D << "d("
50
+ << this->options.num_features() << ", "
51
+ << "eps=" << this->options.eps() << ", "
52
+ << "momentum=" << this->options.momentum() << ", "
53
+ << "affine=" << this->options.affine() << ", "
54
+ << "track_running_stats=" << this->options.track_running_stats()
55
+ << ")";
56
+ }
57
+ };
58
+
59
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ InstanceNorm1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
61
+
62
+ /// Applies the InstanceNorm1d function.
63
+ /// See https://pytorch.org/docs/main/nn.html#torch.nn.InstanceNorm1d to learn
64
+ /// about the exact behavior of this module.
65
+ ///
66
+ /// See the documentation for `torch::nn::InstanceNorm1dOptions` class to learn
67
+ /// what constructor arguments are supported for this module.
68
+ ///
69
+ /// Example:
70
+ /// ```
71
+ /// InstanceNorm1d
72
+ /// model(InstanceNorm1dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true));
73
+ /// ```
74
+ class TORCH_API InstanceNorm1dImpl
75
+ : public InstanceNormImpl<1, InstanceNorm1dImpl> {
76
+ protected:
77
+ void _check_input_dim(const Tensor& input) override;
78
+
79
+ public:
80
+ using InstanceNormImpl<1, InstanceNorm1dImpl>::InstanceNormImpl;
81
+ };
82
+
83
+ /// A `ModuleHolder` subclass for `InstanceNorm1dImpl`.
84
+ /// See the documentation for `InstanceNorm1dImpl` class to learn what methods
85
+ /// it provides, and examples of how to use `InstanceNorm1d` with
86
+ /// `torch::nn::InstanceNorm1dOptions`. See the documentation for `ModuleHolder`
87
+ /// to learn about PyTorch's module storage semantics.
88
+ TORCH_MODULE(InstanceNorm1d);
89
+
90
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ InstanceNorm2d
91
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
92
+
93
+ /// Applies the InstanceNorm2d function.
94
+ /// See https://pytorch.org/docs/main/nn.html#torch.nn.InstanceNorm2d to learn
95
+ /// about the exact behavior of this module.
96
+ ///
97
+ /// See the documentation for `torch::nn::InstanceNorm2dOptions` class to learn
98
+ /// what constructor arguments are supported for this module.
99
+ ///
100
+ /// Example:
101
+ /// ```
102
+ /// InstanceNorm2d
103
+ /// model(InstanceNorm2dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true));
104
+ /// ```
105
+ class TORCH_API InstanceNorm2dImpl
106
+ : public InstanceNormImpl<2, InstanceNorm2dImpl> {
107
+ protected:
108
+ void _check_input_dim(const Tensor& input) override;
109
+
110
+ public:
111
+ using InstanceNormImpl<2, InstanceNorm2dImpl>::InstanceNormImpl;
112
+ };
113
+
114
+ /// A `ModuleHolder` subclass for `InstanceNorm2dImpl`.
115
+ /// See the documentation for `InstanceNorm2dImpl` class to learn what methods
116
+ /// it provides, and examples of how to use `InstanceNorm2d` with
117
+ /// `torch::nn::InstanceNorm2dOptions`. See the documentation for `ModuleHolder`
118
+ /// to learn about PyTorch's module storage semantics.
119
+ TORCH_MODULE(InstanceNorm2d);
120
+
121
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ InstanceNorm3d
122
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
123
+
124
+ /// Applies the InstanceNorm3d function.
125
+ /// See https://pytorch.org/docs/main/nn.html#torch.nn.InstanceNorm3d to learn
126
+ /// about the exact behavior of this module.
127
+ ///
128
+ /// See the documentation for `torch::nn::InstanceNorm3dOptions` class to learn
129
+ /// what constructor arguments are supported for this module.
130
+ ///
131
+ /// Example:
132
+ /// ```
133
+ /// InstanceNorm3d
134
+ /// model(InstanceNorm3dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true));
135
+ /// ```
136
+ class TORCH_API InstanceNorm3dImpl
137
+ : public InstanceNormImpl<3, InstanceNorm3dImpl> {
138
+ protected:
139
+ void _check_input_dim(const Tensor& input) override;
140
+
141
+ public:
142
+ using InstanceNormImpl<3, InstanceNorm3dImpl>::InstanceNormImpl;
143
+ };
144
+
145
+ /// A `ModuleHolder` subclass for `InstanceNorm3dImpl`.
146
+ /// See the documentation for `InstanceNorm3dImpl` class to learn what methods
147
+ /// it provides, and examples of how to use `InstanceNorm3d` with
148
+ /// `torch::nn::InstanceNorm3dOptions`. See the documentation for `ModuleHolder`
149
+ /// to learn about PyTorch's module storage semantics.
150
+ TORCH_MODULE(InstanceNorm3d);
151
+
152
+ } // namespace nn
153
+ } // namespace torch
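
A minimal sketch of the `InstanceNorm{1,2,3}d` modules declared above, including the no-batch-dim path that `forward` special-cases when `input.dim() == D + 1` (assuming a standard LibTorch setup; shapes are illustrative):

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  torch::nn::InstanceNorm2d norm(
      torch::nn::InstanceNorm2dOptions(4).affine(true).track_running_stats(true));
  auto batched = torch::randn({8, 4, 16, 16});   // (N, C, H, W)
  auto unbatched = torch::randn({4, 16, 16});    // (C, H, W): dim == D + 1
  std::cout << norm(batched).sizes() << '\n';    // [8, 4, 16, 16]
  std::cout << norm(unbatched).sizes() << '\n';  // [4, 16, 16]
}
```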
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/loss.h ADDED
@@ -0,0 +1,805 @@
+ #pragma once
+
+ #include <torch/expanding_array.h>
+ #include <torch/nn/cloneable.h>
+ #include <torch/nn/functional/loss.h>
+ #include <torch/nn/options/loss.h>
+ #include <torch/nn/pimpl.h>
+ #include <torch/types.h>
+
+ #include <torch/csrc/Export.h>
+
+ #include <cstddef>
+ #include <vector>
+
+ namespace torch {
+ namespace nn {
+
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ L1Loss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ /// Creates a criterion that measures the mean absolute error (MAE) between each
+ /// element in the input :math:`x` and target :math:`y`.
+ /// See https://pytorch.org/docs/main/nn.html#torch.nn.L1Loss to learn
+ /// about the exact behavior of this module.
+ ///
+ /// See the documentation for `torch::nn::L1LossOptions` class to learn what
+ /// constructor arguments are supported for this module.
+ ///
+ /// Example:
+ /// ```
+ /// L1Loss model(L1LossOptions(torch::kNone));
+ /// ```
+ struct TORCH_API L1LossImpl : Cloneable<L1LossImpl> {
+   explicit L1LossImpl(L1LossOptions options_ = {});
+
+   void reset() override;
+
+   /// Pretty prints the `L1Loss` module into the given `stream`.
+   void pretty_print(std::ostream& stream) const override;
+
+   Tensor forward(const Tensor& input, const Tensor& target);
+
+   /// The options with which this `Module` was constructed.
+   L1LossOptions options;
+ };
+
+ /// A `ModuleHolder` subclass for `L1LossImpl`.
+ /// See the documentation for `L1LossImpl` class to learn what methods it
+ /// provides, and examples of how to use `L1Loss` with
+ /// `torch::nn::L1LossOptions`. See the documentation for `ModuleHolder` to
+ /// learn about PyTorch's module storage semantics.
+ TORCH_MODULE(L1Loss);
+
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ KLDivLoss
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ /// The Kullback-Leibler divergence loss measure.
+ /// See https://pytorch.org/docs/main/nn.html#torch.nn.KLDivLoss to learn
+ /// about the exact behavior of this module.
+ ///
+ /// See the documentation for `torch::nn::KLDivLossOptions` class to learn what
+ /// constructor arguments are supported for this module.
+ ///
+ /// Example:
+ /// ```
+ /// KLDivLoss model(KLDivLossOptions().reduction(torch::kNone));
+ /// ```
+ struct TORCH_API KLDivLossImpl : Cloneable<KLDivLossImpl> {
+   explicit KLDivLossImpl(KLDivLossOptions options_ = {});
+
+   void reset() override;
+
+   /// Pretty prints the `KLDivLoss` module into the given `stream`.
+   void pretty_print(std::ostream& stream) const override;
+
+   Tensor forward(const Tensor& input, const Tensor& target);
+
+   /// The options with which this `Module` was constructed.
+   KLDivLossOptions options;
+ };
+
+ /// A `ModuleHolder` subclass for `KLDivLossImpl`.
+ /// See the documentation for `KLDivLossImpl` class to learn what methods it
+ /// provides, and examples of how to use `KLDivLoss` with
+ /// `torch::nn::KLDivLossOptions`. See the documentation for `ModuleHolder` to
+ /// learn about PyTorch's module storage semantics.
+ TORCH_MODULE(KLDivLoss);
+
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MSELoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ /// Creates a criterion that measures the mean squared error (squared L2 norm)
+ /// between each element in the input :math:`x` and target :math:`y`.
+ /// See https://pytorch.org/docs/main/nn.html#torch.nn.MSELoss to learn
+ /// about the exact behavior of this module.
+ ///
+ /// See the documentation for `torch::nn::MSELossOptions` class to learn what
+ /// constructor arguments are supported for this module.
+ ///
+ /// Example:
+ /// ```
+ /// MSELoss model(MSELossOptions(torch::kNone));
+ /// ```
+ struct TORCH_API MSELossImpl : Cloneable<MSELossImpl> {
+   explicit MSELossImpl(MSELossOptions options_ = {});
+
+   void reset() override;
+
+   /// Pretty prints the `MSELoss` module into the given `stream`.
+   void pretty_print(std::ostream& stream) const override;
+
+   Tensor forward(const Tensor& input, const Tensor& target);
+
+   /// The options with which this `Module` was constructed.
+   MSELossOptions options;
+ };
+
+ /// A `ModuleHolder` subclass for `MSELossImpl`.
+ /// See the documentation for `MSELossImpl` class to learn what methods it
+ /// provides, and examples of how to use `MSELoss` with
+ /// `torch::nn::MSELossOptions`. See the documentation for `ModuleHolder` to
+ /// learn about PyTorch's module storage semantics.
+ TORCH_MODULE(MSELoss);
+
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BCELoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ /// Creates a criterion that measures the Binary Cross Entropy
+ /// between the target and the output.
+ /// See https://pytorch.org/docs/main/nn.html#torch.nn.BCELoss to learn
+ /// about the exact behavior of this module.
+ ///
+ /// See the documentation for `torch::nn::BCELossOptions` class to learn what
+ /// constructor arguments are supported for this module.
+ ///
+ /// Example:
+ /// ```
+ /// BCELoss model(BCELossOptions().reduction(torch::kNone).weight(weight));
+ /// ```
+ struct TORCH_API BCELossImpl : Cloneable<BCELossImpl> {
+   explicit BCELossImpl(BCELossOptions options_ = {});
+
+   void reset() override;
+
+   /// Pretty prints the `BCELoss` module into the given `stream`.
+   void pretty_print(std::ostream& stream) const override;
+
+   Tensor forward(const Tensor& input, const Tensor& target);
+
+   /// The options with which this `Module` was constructed.
+   BCELossOptions options;
+ };
+
+ /// A `ModuleHolder` subclass for `BCELossImpl`.
+ /// See the documentation for `BCELossImpl` class to learn what methods it
+ /// provides, and examples of how to use `BCELoss` with
+ /// `torch::nn::BCELossOptions`. See the documentation for `ModuleHolder` to
+ /// learn about PyTorch's module storage semantics.
+ TORCH_MODULE(BCELoss);
+
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ HingeEmbeddingLoss
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ /// Creates a criterion that measures the loss given an input tensor :math:`x`
+ /// and a labels tensor :math:`y` (containing 1 or -1).
+ /// See https://pytorch.org/docs/main/nn.html#torch.nn.HingeEmbeddingLoss to
+ /// learn about the exact behavior of this module.
+ ///
+ /// See the documentation for `torch::nn::HingeEmbeddingLossOptions` class to
+ /// learn what constructor arguments are supported for this module.
+ ///
+ /// Example:
+ /// ```
+ /// HingeEmbeddingLoss
+ /// model(HingeEmbeddingLossOptions().margin(4).reduction(torch::kNone));
+ /// ```
+ struct TORCH_API HingeEmbeddingLossImpl : Cloneable<HingeEmbeddingLossImpl> {
+   explicit HingeEmbeddingLossImpl(HingeEmbeddingLossOptions options_ = {});
+
+   void reset() override;
+
+   /// Pretty prints the `HingeEmbeddingLoss` module into the given `stream`.
+   void pretty_print(std::ostream& stream) const override;
+
+   Tensor forward(const Tensor& input, const Tensor& target);
+
+   /// The options with which this `Module` was constructed.
+   HingeEmbeddingLossOptions options;
+ };
+
+ /// A `ModuleHolder` subclass for `HingeEmbeddingLossImpl`.
+ /// See the documentation for `HingeEmbeddingLossImpl` class to learn what
+ /// methods it provides, and examples of how to use `HingeEmbeddingLoss` with
+ /// `torch::nn::HingeEmbeddingLossOptions`. See the documentation for
+ /// `ModuleHolder` to learn about PyTorch's module storage semantics.
+ TORCH_MODULE(HingeEmbeddingLoss);
+
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MultiMarginLoss
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ /// Creates a criterion that optimizes a multi-class classification hinge
+ /// loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`)
+ /// and output :math:`y` (which is a 1D tensor of target class indices, :math:`0
+ /// \leq y \leq \text{x.size}(1)-1`). See
+ /// https://pytorch.org/docs/main/nn.html#torch.nn.MultiMarginLoss to learn
+ /// about the exact behavior of this module.
+ ///
+ /// See the documentation for `torch::nn::MultiMarginLossOptions` class to learn
+ /// what constructor arguments are supported for this module.
+ ///
+ /// Example:
+ /// ```
+ /// MultiMarginLoss model(MultiMarginLossOptions().margin(2).weight(weight));
+ /// ```
+ struct TORCH_API MultiMarginLossImpl : public Cloneable<MultiMarginLossImpl> {
+   explicit MultiMarginLossImpl(MultiMarginLossOptions options_ = {});
+
+   void reset() override;
+
+   /// Pretty prints the `MultiMarginLoss` module into the given `stream`.
+   void pretty_print(std::ostream& stream) const override;
+
+   Tensor forward(const Tensor& input, const Tensor& target);
+
+   /// The options with which this `Module` was constructed.
+   MultiMarginLossOptions options;
+ };
+
+ /// A `ModuleHolder` subclass for `MultiMarginLossImpl`.
+ /// See the documentation for `MultiMarginLossImpl` class to learn what methods
+ /// it provides, and examples of how to use `MultiMarginLoss` with
+ /// `torch::nn::MultiMarginLossOptions`. See the documentation for
+ /// `ModuleHolder` to learn about PyTorch's module storage semantics.
+ TORCH_MODULE(MultiMarginLoss);
+
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CosineEmbeddingLoss
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ /// Creates a criterion that measures the loss given input tensors
+ /// `input1`, `input2`, and a `Tensor` label `target` with values 1 or
+ /// -1. This is used for measuring whether two inputs are similar or
+ /// dissimilar, using the cosine distance, and is typically used for learning
+ /// nonlinear embeddings or semi-supervised learning.
+ /// See https://pytorch.org/docs/main/nn.html#torch.nn.CosineEmbeddingLoss to
+ /// learn about the exact behavior of this module.
+ ///
+ /// See the documentation for `torch::nn::CosineEmbeddingLossOptions` class to
+ /// learn what constructor arguments are supported for this module.
+ ///
+ /// Example:
+ /// ```
+ /// CosineEmbeddingLoss model(CosineEmbeddingLossOptions().margin(0.5));
+ /// ```
+ struct TORCH_API CosineEmbeddingLossImpl
+     : public Cloneable<CosineEmbeddingLossImpl> {
+   explicit CosineEmbeddingLossImpl(CosineEmbeddingLossOptions options_ = {});
+
+   void reset() override;
+
+   /// Pretty prints the `CosineEmbeddingLoss` module into the given `stream`.
+   void pretty_print(std::ostream& stream) const override;
+
+   Tensor forward(
+       const Tensor& input1,
+       const Tensor& input2,
+       const Tensor& target);
+
+   /// The options with which this `Module` was constructed.
+   CosineEmbeddingLossOptions options;
+ };
+
+ /// A `ModuleHolder` subclass for `CosineEmbeddingLossImpl`.
+ /// See the documentation for `CosineEmbeddingLossImpl` class to learn what
+ /// methods it provides, and examples of how to use `CosineEmbeddingLoss` with
+ /// `torch::nn::CosineEmbeddingLossOptions`. See the documentation for
+ /// `ModuleHolder` to learn about PyTorch's module storage semantics.
+ TORCH_MODULE(CosineEmbeddingLoss);
+
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SmoothL1Loss
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ /// Creates a criterion that uses a squared term if the absolute
+ /// element-wise error falls below beta and an L1 term otherwise.
+ /// It is less sensitive to outliers than the `MSELoss` and in some cases
+ /// prevents exploding gradients (e.g. see the paper `Fast R-CNN` by Ross
+ /// Girshick). See https://pytorch.org/docs/main/nn.html#torch.nn.SmoothL1Loss
+ /// to learn about the exact behavior of this module.
+ ///
+ /// See the documentation for `torch::nn::SmoothL1LossOptions` class to learn
+ /// what constructor arguments are supported for this module.
+ ///
+ /// Example:
+ /// ```
+ /// SmoothL1Loss model(SmoothL1LossOptions().reduction(torch::kNone).beta(0.5));
+ /// ```
+ struct TORCH_API SmoothL1LossImpl : public Cloneable<SmoothL1LossImpl> {
+   explicit SmoothL1LossImpl(SmoothL1LossOptions options = {});
+
+   void reset() override;
+
+   /// Pretty prints the `SmoothL1Loss` module into the given `stream`.
+   void pretty_print(std::ostream& stream) const override;
+
+   Tensor forward(const Tensor& input, const Tensor& target);
+
+   /// The options with which this `Module` was constructed.
+   SmoothL1LossOptions options;
+ };
+
+ /// A `ModuleHolder` subclass for `SmoothL1LossImpl`.
+ /// See the documentation for `SmoothL1LossImpl` class to learn what methods it
+ /// provides, and examples of how to use `SmoothL1Loss` with
+ /// `torch::nn::SmoothL1LossOptions`. See the documentation for `ModuleHolder`
+ /// to learn about PyTorch's module storage semantics.
+ TORCH_MODULE(SmoothL1Loss);
+
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ HuberLoss
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ /// Creates a criterion that uses a squared term if the absolute
+ /// element-wise error falls below delta and a delta-scaled L1 term otherwise.
+ /// See https://pytorch.org/docs/main/nn.html#torch.nn.HuberLoss to learn
+ /// about the exact behavior of this module.
+ ///
+ /// See the documentation for `torch::nn::HuberLossOptions` class to learn what
+ /// constructor arguments are supported for this module.
+ ///
+ /// Example:
+ /// ```
+ /// HuberLoss model(HuberLossOptions().reduction(torch::kNone).delta(0.5));
+ /// ```
+ struct TORCH_API HuberLossImpl : public Cloneable<HuberLossImpl> {
+   explicit HuberLossImpl(HuberLossOptions options_ = {});
+
+   void reset() override;
+
+   /// Pretty prints the `HuberLoss` module into the given `stream`.
+   void pretty_print(std::ostream& stream) const override;
+
+   Tensor forward(const Tensor& input, const Tensor& target);
+
+   /// The options with which this `Module` was constructed.
+   HuberLossOptions options;
+ };
+
+ /// A `ModuleHolder` subclass for `HuberLossImpl`.
+ /// See the documentation for `HuberLossImpl` class to learn what methods it
+ /// provides, and examples of how to use `HuberLoss` with
+ /// `torch::nn::HuberLossOptions`. See the documentation for `ModuleHolder` to
+ /// learn about PyTorch's module storage semantics.
+ TORCH_MODULE(HuberLoss);
+
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MultiLabelMarginLoss
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ /// Creates a criterion that optimizes a multi-class multi-classification
+ /// hinge loss (margin-based loss) between input :math:`x` (a 2D mini-batch
+ /// `Tensor`) and output :math:`y` (which is a 2D `Tensor` of target class
+ /// indices). See
+ /// https://pytorch.org/docs/main/nn.html#torch.nn.MultiLabelMarginLoss to
+ /// learn about the exact behavior of this module.
+ ///
+ /// See the documentation for `torch::nn::MultiLabelMarginLossOptions` class to
+ /// learn what constructor arguments are supported for this module.
+ ///
+ /// Example:
+ /// ```
+ /// MultiLabelMarginLoss model(MultiLabelMarginLossOptions(torch::kNone));
+ /// ```
+ struct TORCH_API MultiLabelMarginLossImpl
+     : public Cloneable<MultiLabelMarginLossImpl> {
+   explicit MultiLabelMarginLossImpl(MultiLabelMarginLossOptions options_ = {});
+
+   void reset() override;
+
+   /// Pretty prints the `MultiLabelMarginLoss` module into the given `stream`.
+   void pretty_print(std::ostream& stream) const override;
+
+   Tensor forward(const Tensor& input, const Tensor& target);
+
+   /// The options with which this `Module` was constructed.
+   MultiLabelMarginLossOptions options;
+ };
+
+ /// A `ModuleHolder` subclass for `MultiLabelMarginLossImpl`.
+ /// See the documentation for `MultiLabelMarginLossImpl` class to learn what
+ /// methods it provides, and examples of how to use `MultiLabelMarginLoss` with
+ /// `torch::nn::MultiLabelMarginLossOptions`. See the documentation for
+ /// `ModuleHolder` to learn about PyTorch's module storage semantics.
+ TORCH_MODULE(MultiLabelMarginLoss);
+
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SoftMarginLoss
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ /// Creates a criterion that optimizes a two-class classification
+ /// logistic loss between input tensor :math:`x` and target tensor :math:`y`
+ /// (containing 1 or -1).
+ /// See https://pytorch.org/docs/main/nn.html#torch.nn.SoftMarginLoss to learn
+ /// about the exact behavior of this module.
+ ///
+ /// See the documentation for `torch::nn::SoftMarginLossOptions` class to learn
+ /// what constructor arguments are supported for this module.
+ ///
+ /// Example:
+ /// ```
+ /// SoftMarginLoss model(SoftMarginLossOptions(torch::kNone));
+ /// ```
+ struct TORCH_API SoftMarginLossImpl : public Cloneable<SoftMarginLossImpl> {
+   explicit SoftMarginLossImpl(SoftMarginLossOptions options_ = {});
+
+   /// Pretty prints the `SoftMarginLoss` module into the given `stream`.
+   void pretty_print(std::ostream& stream) const override;
+
+   void reset() override;
+
+   Tensor forward(const Tensor& input, const Tensor& target);
+
+   /// The options with which this `Module` was constructed.
+   SoftMarginLossOptions options;
+ };
+
+ /// A `ModuleHolder` subclass for `SoftMarginLossImpl`.
+ /// See the documentation for `SoftMarginLossImpl` class to learn what methods
+ /// it provides, and examples of how to use `SoftMarginLoss` with
+ /// `torch::nn::SoftMarginLossOptions`. See the documentation for `ModuleHolder`
+ /// to learn about PyTorch's module storage semantics.
+ TORCH_MODULE(SoftMarginLoss);
+
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MultiLabelSoftMarginLoss
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ /// Creates a criterion that optimizes a multi-label one-versus-all
+ /// loss based on max-entropy, between input :math:`x` and target :math:`y` of
+ /// size :math:`(N, C)`. See
+ /// https://pytorch.org/docs/main/nn.html#torch.nn.MultiLabelSoftMarginLoss to
+ /// learn about the exact behavior of this module.
+ ///
+ /// See the documentation for `torch::nn::MultiLabelSoftMarginLossOptions` class
+ /// to learn what constructor arguments are supported for this module.
+ ///
+ /// Example:
+ /// ```
+ /// MultiLabelSoftMarginLoss
+ /// model(MultiLabelSoftMarginLossOptions().reduction(torch::kNone).weight(weight));
+ /// ```
+ struct TORCH_API MultiLabelSoftMarginLossImpl
+     : public Cloneable<MultiLabelSoftMarginLossImpl> {
+   explicit MultiLabelSoftMarginLossImpl(
+       MultiLabelSoftMarginLossOptions options_ = {});
+
+   /// Pretty prints the `MultiLabelSoftMarginLoss` module into the given
+   /// `stream`.
+   void pretty_print(std::ostream& stream) const override;
+
+   void reset() override;
+
+   Tensor forward(const Tensor& input, const Tensor& target);
+
+   /// The options with which this `Module` was constructed.
+   MultiLabelSoftMarginLossOptions options;
+ };
+
+ /// A `ModuleHolder` subclass for `MultiLabelSoftMarginLossImpl`.
+ /// See the documentation for `MultiLabelSoftMarginLossImpl` class to learn what
+ /// methods it provides, and examples of how to use `MultiLabelSoftMarginLoss`
+ /// with `torch::nn::MultiLabelSoftMarginLossOptions`. See the documentation for
+ /// `ModuleHolder` to learn about PyTorch's module storage semantics.
+ TORCH_MODULE(MultiLabelSoftMarginLoss);
+
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TripletMarginLoss
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ /// Creates a criterion that measures the triplet loss given input
+ /// tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater
+ /// than :math:`0`. This is used for measuring a relative similarity between
+ /// samples. A triplet is composed of `a`, `p` and `n` (i.e., `anchor`,
+ /// `positive examples` and `negative examples` respectively). The
+ /// shapes of all input tensors should be :math:`(N, D)`.
+ /// See https://pytorch.org/docs/main/nn.html#torch.nn.TripletMarginLoss to
+ /// learn about the exact behavior of this module.
+ ///
+ /// See the documentation for `torch::nn::TripletMarginLossOptions` class to
+ /// learn what constructor arguments are supported for this module.
+ ///
+ /// Example:
+ /// ```
+ /// TripletMarginLoss
+ /// model(TripletMarginLossOptions().margin(3).p(2).eps(1e-06).swap(false));
+ /// ```
+ struct TORCH_API TripletMarginLossImpl
+     : public Cloneable<TripletMarginLossImpl> {
+   explicit TripletMarginLossImpl(TripletMarginLossOptions options_ = {});
+
+   void reset() override;
+
+   /// Pretty prints the `TripletMarginLoss` module into the given `stream`.
+   void pretty_print(std::ostream& stream) const override;
+
+   Tensor forward(
+       const Tensor& anchor,
+       const Tensor& positive,
+       const Tensor& negative);
+
+   /// The options with which this `Module` was constructed.
+   TripletMarginLossOptions options;
+ };
+
+ /// A `ModuleHolder` subclass for `TripletMarginLossImpl`.
+ /// See the documentation for `TripletMarginLossImpl` class to learn what
+ /// methods it provides, and examples of how to use `TripletMarginLoss` with
+ /// `torch::nn::TripletMarginLossOptions`. See the documentation for
+ /// `ModuleHolder` to learn about PyTorch's module storage semantics.
+ TORCH_MODULE(TripletMarginLoss);
+
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TripletMarginWithDistanceLoss
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ /// Creates a criterion that measures the triplet loss given input
+ /// tensors :math:`a`, :math:`p`, and :math:`n` (representing anchor,
+ /// positive, and negative examples, respectively); and a nonnegative,
+ /// real-valued function
+ /// ("distance function") used to compute the relationships between the anchor
+ /// and positive example ("positive distance") and the anchor and negative
+ /// example ("negative distance").
+ /// See
+ /// https://pytorch.org/docs/main/nn.html#torch.nn.TripletMarginWithDistanceLoss
+ /// to learn about the exact behavior of this module.
+ ///
+ /// See the documentation for `torch::nn::TripletMarginWithDistanceLossOptions`
+ /// class to learn what constructor arguments are supported for this module.
+ ///
+ /// Example:
+ /// ```
+ /// TripletMarginWithDistanceLoss
+ /// model(TripletMarginWithDistanceLossOptions().margin(3).swap(false));
+ /// ```
+ struct TORCH_API TripletMarginWithDistanceLossImpl
+     : public Cloneable<TripletMarginWithDistanceLossImpl> {
+   explicit TripletMarginWithDistanceLossImpl(
+       TripletMarginWithDistanceLossOptions options_ = {});
+
+   void reset() override;
+
+   /// Pretty prints the `TripletMarginWithDistanceLoss` module into the given
+   /// `stream`.
+   void pretty_print(std::ostream& stream) const override;
+
+   Tensor forward(
+       const Tensor& anchor,
+       const Tensor& positive,
+       const Tensor& negative);
+
+   /// The options with which this `Module` was constructed.
+   TripletMarginWithDistanceLossOptions options;
+ };
+
+ /// A `ModuleHolder` subclass for `TripletMarginWithDistanceLossImpl`.
+ /// See the documentation for `TripletMarginWithDistanceLossImpl` class to learn
+ /// what methods it provides, and examples of how to use
+ /// `TripletMarginWithDistanceLoss` with
+ /// `torch::nn::TripletMarginWithDistanceLossOptions`.
+ /// See the documentation for `ModuleHolder` to learn about PyTorch's
+ /// module storage semantics.
+ TORCH_MODULE(TripletMarginWithDistanceLoss);
+
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CTCLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ /// The Connectionist Temporal Classification loss.
+ /// See https://pytorch.org/docs/main/nn.html#torch.nn.CTCLoss to learn
+ /// about the exact behavior of this module.
+ ///
+ /// See the documentation for `torch::nn::CTCLossOptions` class to learn what
+ /// constructor arguments are supported for this module.
+ ///
+ /// Example:
+ /// ```
+ /// CTCLoss
+ /// model(CTCLossOptions().blank(42).zero_infinity(false).reduction(torch::kSum));
+ /// ```
+ struct TORCH_API CTCLossImpl : public Cloneable<CTCLossImpl> {
+   explicit CTCLossImpl(CTCLossOptions options_ = {});
+
+   void reset() override;
+
+   /// Pretty prints the `CTCLoss` module into the given `stream`.
+   void pretty_print(std::ostream& stream) const override;
+
+   Tensor forward(
+       const Tensor& log_probs,
+       const Tensor& targets,
+       const Tensor& input_lengths,
+       const Tensor& target_lengths);
+
+   /// The options with which this `Module` was constructed.
+   CTCLossOptions options;
+ };
+
+ /// A `ModuleHolder` subclass for `CTCLossImpl`.
+ /// See the documentation for `CTCLossImpl` class to learn what methods it
+ /// provides, and examples of how to use `CTCLoss` with
+ /// `torch::nn::CTCLossOptions`. See the documentation for `ModuleHolder` to
+ /// learn about PyTorch's module storage semantics.
+ TORCH_MODULE(CTCLoss);
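+
+ // Hedged usage sketch of the four-tensor `forward` above (shapes follow the
+ // torch.nn.CTCLoss documentation; values are illustrative, and this comment
+ // block is an addition to this document, not part of the upstream header):
+ //
+ //   CTCLoss ctc(CTCLossOptions().blank(0));
+ //   auto log_probs = torch::randn({50, 16, 20}).log_softmax(2);       // (T, N, C)
+ //   auto targets = torch::randint(1, 20, {16, 30}, torch::kLong);     // (N, S)
+ //   auto input_lengths = torch::full({16}, 50, torch::kLong);
+ //   auto target_lengths = torch::randint(10, 30, {16}, torch::kLong);
+ //   auto loss = ctc(log_probs, targets, input_lengths, target_lengths);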
+
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PoissonNLLLoss
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ /// Negative log likelihood loss with Poisson distribution of target.
+ /// See https://pytorch.org/docs/main/nn.html#torch.nn.PoissonNLLLoss to learn
+ /// about the exact behavior of this module.
+ ///
+ /// See the documentation for `torch::nn::PoissonNLLLossOptions` class to learn
+ /// what constructor arguments are supported for this module.
+ ///
+ /// Example:
+ /// ```
+ /// PoissonNLLLoss
+ /// model(PoissonNLLLossOptions().log_input(false).full(true).eps(0.42).reduction(torch::kSum));
+ /// ```
+ struct TORCH_API PoissonNLLLossImpl : public Cloneable<PoissonNLLLossImpl> {
+   explicit PoissonNLLLossImpl(PoissonNLLLossOptions options_ = {});
+
+   void reset() override;
+
+   /// Pretty prints the `PoissonNLLLoss` module into the given `stream`.
+   void pretty_print(std::ostream& stream) const override;
+
+   Tensor forward(const Tensor& log_input, const Tensor& targets);
+
+   /// The options with which this `Module` was constructed.
+   PoissonNLLLossOptions options;
+ };
+
+ /// A `ModuleHolder` subclass for `PoissonNLLLossImpl`.
+ /// See the documentation for `PoissonNLLLossImpl` class to learn what methods
+ /// it provides, and examples of how to use `PoissonNLLLoss` with
+ /// `torch::nn::PoissonNLLLossOptions`. See the documentation for `ModuleHolder`
+ /// to learn about PyTorch's module storage semantics.
+ TORCH_MODULE(PoissonNLLLoss);
+
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MarginRankingLoss
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ /// Creates a criterion that measures the loss given
+ /// inputs :math:`x1`, :math:`x2`, two 1D mini-batch `Tensors`,
+ /// and a label 1D mini-batch tensor :math:`y` (containing 1 or -1).
+ /// See https://pytorch.org/docs/main/nn.html#torch.nn.MarginRankingLoss to
+ /// learn about the exact behavior of this module.
+ ///
+ /// See the documentation for `torch::nn::MarginRankingLossOptions` class to
+ /// learn what constructor arguments are supported for this module.
+ ///
+ /// Example:
+ /// ```
+ /// MarginRankingLoss
+ /// model(MarginRankingLossOptions().margin(0.5).reduction(torch::kSum));
+ /// ```
+ struct TORCH_API MarginRankingLossImpl
+     : public Cloneable<MarginRankingLossImpl> {
+   explicit MarginRankingLossImpl(MarginRankingLossOptions options_ = {});
+
+   void reset() override;
+
+   /// Pretty prints the `MarginRankingLoss` module into the given `stream`.
+   void pretty_print(std::ostream& stream) const override;
+
+   Tensor forward(
+       const Tensor& input1,
+       const Tensor& input2,
+       const Tensor& targets);
+
+   /// The options with which this `Module` was constructed.
+   MarginRankingLossOptions options;
+ };
+
+ /// A `ModuleHolder` subclass for `MarginRankingLossImpl`.
+ /// See the documentation for `MarginRankingLossImpl` class to learn what
+ /// methods it provides, and examples of how to use `MarginRankingLoss` with
+ /// `torch::nn::MarginRankingLossOptions`. See the documentation for
+ /// `ModuleHolder` to learn about PyTorch's module storage semantics.
+ TORCH_MODULE(MarginRankingLoss);
+
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ NLLLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ /// The negative log likelihood loss. It is useful to train a classification
+ /// problem with `C` classes.
+ /// See https://pytorch.org/docs/main/nn.html#torch.nn.NLLLoss to learn
+ /// about the exact behavior of this module.
+ ///
+ /// See the documentation for `torch::nn::NLLLossOptions` class to learn what
+ /// constructor arguments are supported for this module.
+ ///
+ /// Example:
+ /// ```
+ /// NLLLoss model(NLLLossOptions().ignore_index(-100).reduction(torch::kMean));
+ /// ```
+ struct TORCH_API NLLLossImpl : public Cloneable<NLLLossImpl> {
+   explicit NLLLossImpl(NLLLossOptions options_ = {});
+
+   /// Pretty prints the `NLLLoss` module into the given `stream`.
+   void pretty_print(std::ostream& stream) const override;
+
+   void reset() override;
+
+   Tensor forward(const Tensor& input, const Tensor& target);
+
+   /// The options with which this `Module` was constructed.
+   NLLLossOptions options;
+
+   /// A manual rescaling weight given to each class.
+   Tensor weight;
+ };
+
+ /// A `ModuleHolder` subclass for `NLLLossImpl`.
+ /// See the documentation for `NLLLossImpl` class to learn what methods it
+ /// provides, and examples of how to use `NLLLoss` with
+ /// `torch::nn::NLLLossOptions`. See the documentation for `ModuleHolder` to
+ /// learn about PyTorch's module storage semantics.
+ TORCH_MODULE(NLLLoss);
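+
+ // Hedged usage sketch (an addition to this document, not part of the upstream
+ // header): `forward` expects log-probabilities, so the input is typically
+ // passed through `log_softmax` first:
+ //
+ //   NLLLoss nll;
+ //   auto input = torch::randn({3, 5}).log_softmax(/*dim=*/1);  // (N, C)
+ //   auto target = torch::randint(0, 5, {3}, torch::kLong);     // (N)
+ //   auto loss = nll(input, target);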
+
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CrossEntropyLoss
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ /// Creates a criterion that computes cross entropy loss between input and
+ /// target. See
+ /// https://pytorch.org/docs/main/nn.html#torch.nn.CrossEntropyLoss to learn
+ /// about the exact behavior of this module.
+ ///
+ /// See the documentation for `torch::nn::CrossEntropyLossOptions` class to
+ /// learn what constructor arguments are supported for this module.
+ ///
+ /// Example:
+ /// ```
+ /// CrossEntropyLoss
+ /// model(CrossEntropyLossOptions().ignore_index(-100).reduction(torch::kMean));
+ /// ```
+ struct TORCH_API CrossEntropyLossImpl : public Cloneable<CrossEntropyLossImpl> {
+   explicit CrossEntropyLossImpl(CrossEntropyLossOptions options_ = {});
+
+   void reset() override;
+
+   /// Pretty prints the `CrossEntropyLoss` module into the given `stream`.
+   void pretty_print(std::ostream& stream) const override;
+
+   Tensor forward(const Tensor& input, const Tensor& target);
+
+   /// The options with which this `Module` was constructed.
+   CrossEntropyLossOptions options;
+
+   /// A manual rescaling weight given to each class.
+   Tensor weight;
+ };
+
+ /// A `ModuleHolder` subclass for `CrossEntropyLossImpl`.
+ /// See the documentation for `CrossEntropyLossImpl` class to learn what methods
+ /// it provides, and examples of how to use `CrossEntropyLoss` with
+ /// `torch::nn::CrossEntropyLossOptions`. See the documentation for
+ /// `ModuleHolder` to learn about PyTorch's module storage semantics.
+ TORCH_MODULE(CrossEntropyLoss);
+
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BCEWithLogitsLoss
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ /// This loss combines a `Sigmoid` layer and the `BCELoss` in one single
+ /// class. This version is more numerically stable than using a plain `Sigmoid`
+ /// followed by a `BCELoss` as, by combining the operations into one layer,
+ /// we take advantage of the log-sum-exp trick for numerical stability.
+ /// See https://pytorch.org/docs/main/nn.html#torch.nn.BCEWithLogitsLoss to
+ /// learn about the exact behavior of this module.
+ ///
+ /// See the documentation for `torch::nn::BCEWithLogitsLossOptions` class to
+ /// learn what constructor arguments are supported for this module.
+ ///
+ /// Example:
+ /// ```
+ /// BCEWithLogitsLoss
+ /// model(BCEWithLogitsLossOptions().reduction(torch::kNone).weight(weight));
+ /// ```
+ struct TORCH_API BCEWithLogitsLossImpl
+     : public Cloneable<BCEWithLogitsLossImpl> {
+   explicit BCEWithLogitsLossImpl(BCEWithLogitsLossOptions options_ = {});
+
+   void reset() override;
+
+   /// Pretty prints the `BCEWithLogitsLoss` module into the given `stream`.
+   void pretty_print(std::ostream& stream) const override;
+
+   Tensor forward(const Tensor& input, const Tensor& target);
+
+   /// The options with which this `Module` was constructed.
+   BCEWithLogitsLossOptions options;
+
+   /// A manual rescaling weight given to the loss of each batch element.
+   Tensor weight;
+
+   /// A weight of positive examples.
+   Tensor pos_weight;
+ };
+
+ /// A `ModuleHolder` subclass for `BCEWithLogitsLossImpl`.
+ /// See the documentation for `BCEWithLogitsLossImpl` class to learn what
+ /// methods it provides, and examples of how to use `BCEWithLogitsLoss` with
+ /// `torch::nn::BCEWithLogitsLossOptions`. See the documentation for
+ /// `ModuleHolder` to learn about PyTorch's module storage semantics.
+ TORCH_MODULE(BCEWithLogitsLoss);
+
+ } // namespace nn
+ } // namespace torch
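
All of the loss modules above follow the same pattern: construct with an options object, then call the holder like a function. A short training-style sketch with `CrossEntropyLoss` (hedged; the linear model and random data are placeholders, assuming a standard LibTorch setup):

```cpp
#include <torch/torch.h>

int main() {
  torch::nn::Linear model(10, 5);
  torch::nn::CrossEntropyLoss criterion;
  torch::optim::SGD optimizer(model->parameters(), /*lr=*/0.1);

  auto input = torch::randn({16, 10});                     // (N, in_features)
  auto target = torch::randint(0, 5, {16}, torch::kLong);  // class indices

  optimizer.zero_grad();
  auto loss = criterion(model(input), target);  // takes raw logits
  loss.backward();
  optimizer.step();
}
```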
.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/normalization.h ADDED
@@ -0,0 +1,198 @@
+ #pragma once
+
+ #include <torch/nn/cloneable.h>
+ #include <torch/nn/functional/normalization.h>
+ #include <torch/nn/modules/_functions.h>
+ #include <torch/nn/options/normalization.h>
+ #include <torch/nn/pimpl.h>
+ #include <torch/types.h>
+
+ #include <cstddef>
+ #include <vector>
+
+ namespace torch {
+ namespace nn {
+
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LayerNorm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ /// Applies Layer Normalization over a mini-batch of inputs as described in
+ /// the paper `Layer Normalization`_ .
+ /// See https://pytorch.org/docs/main/nn.html#torch.nn.LayerNorm to learn
+ /// about the exact behavior of this module.
+ ///
+ /// See the documentation for `torch::nn::LayerNormOptions` class to learn what
+ /// constructor arguments are supported for this module.
+ ///
+ /// Example:
+ /// ```
+ /// LayerNorm model(LayerNormOptions({2,
+ /// 2}).elementwise_affine(false).eps(2e-5));
+ /// ```
+ class TORCH_API LayerNormImpl : public torch::nn::Cloneable<LayerNormImpl> {
+  public:
+   LayerNormImpl(std::vector<int64_t> normalized_shape)
+       : LayerNormImpl(LayerNormOptions(normalized_shape)) {}
+   explicit LayerNormImpl(LayerNormOptions options_);
+
+   void reset() override;
+
+   void reset_parameters();
+
+   /// Pretty prints the `LayerNorm` module into the given `stream`.
+   void pretty_print(std::ostream& stream) const override;
+
+   /// Applies layer normalization over a mini-batch of inputs as described in
+   /// the paper `Layer Normalization`_ .
+   ///
+   /// The mean and standard deviation are calculated separately over the last
+   /// certain number of dimensions, which have to be of the shape specified by
+   /// the input `normalized_shape`.
+   ///
+   /// `Layer Normalization`: https://arxiv.org/abs/1607.06450
+   Tensor forward(const Tensor& input);
+
+   /// The options with which this module was constructed.
+   LayerNormOptions options;
+
+   /// The learned weight.
+   /// Initialized to ones if the `elementwise_affine` option is set to `true`
+   /// upon construction.
+   Tensor weight;
+
+   /// The learned bias.
+   /// Initialized to zeros if the `elementwise_affine` option is set to `true`
+   /// upon construction.
+   Tensor bias;
+ };
+
+ /// A `ModuleHolder` subclass for `LayerNormImpl`.
+ /// See the documentation for `LayerNormImpl` class to learn what methods it
+ /// provides, and examples of how to use `LayerNorm` with
+ /// `torch::nn::LayerNormOptions`. See the documentation for `ModuleHolder` to
+ /// learn about PyTorch's module storage semantics.
+ TORCH_MODULE(LayerNorm);
+
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LocalResponseNorm
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ /// Applies local response normalization over an input signal composed
+ /// of several input planes, where channels occupy the second dimension.
+ /// Applies normalization across channels.
+ /// See https://pytorch.org/docs/main/nn.html#torch.nn.LocalResponseNorm to
+ /// learn about the exact behavior of this module.
+ ///
+ /// See the documentation for `torch::nn::LocalResponseNormOptions` class to
+ /// learn what constructor arguments are supported for this module.
+ ///
+ /// Example:
+ /// ```
+ /// LocalResponseNorm
+ /// model(LocalResponseNormOptions(2).alpha(0.0002).beta(0.85).k(2.));
+ /// ```
+ class TORCH_API LocalResponseNormImpl
+     : public Cloneable<LocalResponseNormImpl> {
+  public:
+   LocalResponseNormImpl(int64_t size)
+       : LocalResponseNormImpl(LocalResponseNormOptions(size)) {}
+   explicit LocalResponseNormImpl(const LocalResponseNormOptions& options_);
+
+   Tensor forward(const Tensor& input);
+
+   void reset() override;
+
+   /// Pretty prints the `LocalResponseNormImpl` module into the given `stream`.
+   void pretty_print(std::ostream& stream) const override;
+
+   /// The options with which this `Module` was constructed.
+   LocalResponseNormOptions options;
+ };
+
+ /// A `ModuleHolder` subclass for `LocalResponseNormImpl`.
+ /// See the documentation for `LocalResponseNormImpl` class to learn what
+ /// methods it provides, and examples of how to use `LocalResponseNorm` with
+ /// `torch::nn::LocalResponseNormOptions`. See the documentation for
+ /// `ModuleHolder` to learn about PyTorch's module storage semantics.
+ TORCH_MODULE(LocalResponseNorm);
+
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CrossMapLRN2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ /// See the documentation for `torch::nn::CrossMapLRN2dOptions` class to learn
+ /// what constructor arguments are supported for this module.
+ ///
+ /// Example:
+ /// ```
+ /// CrossMapLRN2d model(CrossMapLRN2dOptions(3).alpha(1e-5).beta(0.1).k(10));
+ /// ```
+ class TORCH_API CrossMapLRN2dImpl
+     : public torch::nn::Cloneable<CrossMapLRN2dImpl> {
+  public:
+   CrossMapLRN2dImpl(int64_t size)
+       : CrossMapLRN2dImpl(CrossMapLRN2dOptions(size)) {}
+   explicit CrossMapLRN2dImpl(const CrossMapLRN2dOptions& options_)
+       : options(options_) {}
+
+   void reset() override;
+
+   /// Pretty prints the `CrossMapLRN2d` module into the given `stream`.
+   void pretty_print(std::ostream& stream) const override;
+
+   torch::Tensor forward(const torch::Tensor& input);
+
+   CrossMapLRN2dOptions options;
+ };
+
+ /// A `ModuleHolder` subclass for `CrossMapLRN2dImpl`.
+ /// See the documentation for `CrossMapLRN2dImpl` class to learn what methods it
+ /// provides, and examples of how to use `CrossMapLRN2d` with
+ /// `torch::nn::CrossMapLRN2dOptions`. See the documentation for `ModuleHolder`
+ /// to learn about PyTorch's module storage semantics.
+ TORCH_MODULE(CrossMapLRN2d);
+
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GroupNorm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ /// Applies Group Normalization over a mini-batch of inputs as described in
+ /// the paper `Group Normalization`_ .
+ /// See https://pytorch.org/docs/main/nn.html#torch.nn.GroupNorm to learn
+ /// about the exact behavior of this module.
+ ///
+ /// See the documentation for `torch::nn::GroupNormOptions` class to learn what
+ /// constructor arguments are supported for this module.
+ ///
+ /// Example:
+ /// ```
+ /// GroupNorm model(GroupNormOptions(2, 2).eps(2e-5).affine(false));
+ /// ```
+ class TORCH_API GroupNormImpl : public torch::nn::Cloneable<GroupNormImpl> {
+  public:
+   GroupNormImpl(int64_t num_groups, int64_t num_channels)
+       : GroupNormImpl(GroupNormOptions(num_groups, num_channels)) {}
+   explicit GroupNormImpl(const GroupNormOptions& options_);
+
+   void reset() override;
+
+   void reset_parameters();
+
+   /// Pretty prints the `GroupNorm` module into the given `stream`.
+   void pretty_print(std::ostream& stream) const override;
+
+   Tensor forward(const Tensor& input);
+
+   /// The options with which this module was constructed.
+   GroupNormOptions options;
+
+   /// The learned weight.
+   Tensor weight;
+
+   /// The learned bias.
+   Tensor bias;
+ };
+
+ /// A `ModuleHolder` subclass for `GroupNormImpl`.
+ /// See the documentation for `GroupNormImpl` class to learn what methods it
+ /// provides, and examples of how to use `GroupNorm` with
+ /// `torch::nn::GroupNormOptions`. See the documentation for `ModuleHolder` to
+ /// learn about PyTorch's module storage semantics.
+ TORCH_MODULE(GroupNorm);
+
+ } // namespace nn
+ } // namespace torch
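
A brief sketch contrasting `LayerNorm` and `GroupNorm` from the header above (hedged; assumes a standard LibTorch setup, with illustrative shapes):

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // LayerNorm normalizes over the trailing `normalized_shape` dimensions.
  torch::nn::LayerNorm layer_norm(torch::nn::LayerNormOptions({16, 16}));
  // GroupNorm normalizes over 2 groups of 3 channels each.
  torch::nn::GroupNorm group_norm(torch::nn::GroupNormOptions(2, 6));

  auto x = torch::randn({4, 6, 16, 16});       // (N, C, H, W)
  std::cout << layer_norm(x).sizes() << '\n';  // [4, 6, 16, 16]
  std::cout << group_norm(x).sizes() << '\n';  // [4, 6, 16, 16]
}
```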