|
#pragma once |
|
|
|
#include <ATen/core/Tensor.h> |
|
#include <ATen/Parallel.h> |
|
#include <ATen/native/DispatchStub.h> |
|
|
|
namespace at { |
|
namespace native { |
|
|
|
|
|
struct PoolingParams1D { |
|
int64_t NB; |
|
int64_t NC; |
|
int64_t IW; |
|
int64_t OW; |
|
int64_t KW; |
|
int64_t SJ; |
|
int64_t PJ; |
|
int64_t DJ; |
|
|
|
|
|
inline int64_t index(int64_t kj, int64_t oj) const { |
|
return oj * SJ + kj * DJ - PJ; |
|
} |
|
|
|
|
|
inline int64_t valid_output_start(int64_t kj) const { |
|
int64_t ij = index(kj, 0);; |
|
return ij < 0 ? at::divup(-ij, SJ) : 0; |
|
} |
|
|
|
|
|
inline int64_t valid_output_end(int64_t kj) const { |
|
int64_t ij = index(kj, OW - 1); |
|
return ij >= IW ? OW - at::divup(ij - (IW - 1), SJ) : OW; |
|
} |
|
}; |
|
|
|
// Function-pointer type of a 1-D pooling kernel: it receives a mutable
// Tensor, a read-only Tensor and the pooling parameters.
// NOTE(review): presumably the first Tensor is the output and the second the
// input -- confirm against the kernel implementations.
using pooling_fn = void (*)(Tensor&, const Tensor&, const PoolingParams1D&);
|
|
|
// Declares `max_pool1d_stub`, a dispatch stub whose kernels have the
// `pooling_fn` signature. NOTE(review): DECLARE_DISPATCH is presumably the
// macro from ATen/native/DispatchStub.h, with the kernel registered in a
// separate translation unit -- confirm.
DECLARE_DISPATCH(pooling_fn, max_pool1d_stub);
|
|
|
} |
|
} |
|
|