|
#include <ATen/ATen.h> |
|
#include <cstdint> |
|
|
|
#ifdef USE_FBGEMM |
|
#include <fbgemm/FbgemmEmbedding.h> |
|
#endif |
|
|
|
namespace at { |
|
namespace native { |
|
|
|
void check_arguments( |
|
const Tensor& weight, |
|
const Tensor& indices, |
|
const Tensor& offsets, |
|
const int64_t mode, |
|
const c10::optional<Tensor>& per_sample_weights, |
|
bool include_last_offset); |
|
|
|
void make_bag_size_out( |
|
Tensor& bag_size_out, |
|
const Tensor& offsets, |
|
const Tensor& indices, |
|
const int64_t mode, |
|
const bool include_last_offset, |
|
const bool requires_grad); |
|
|
|
void make_max_indices_out( |
|
Tensor& max_indices_out, |
|
const Tensor& weight, |
|
const Tensor& indices, |
|
const Tensor& offsets, |
|
const Tensor& bag_size, |
|
const int64_t mode, |
|
bool include_last_offset); |
|
|
|
void make_offset2bag_out( |
|
Tensor& offset2bag, |
|
Tensor& output, |
|
const Tensor& weight, |
|
const Tensor& indices, |
|
const Tensor& offsets, |
|
const int64_t mode, |
|
const c10::optional<Tensor>& per_sample_weights, |
|
const int64_t padding_idx = -1); |
|
|
|
#ifdef USE_FBGEMM |
|
|
|
template<bool has_weight, typename TIndex, typename TData> |
|
struct _CallbackAndBlockSize { |
|
using TCallback = typename fbgemm::EmbeddingSpMDMKernelSignature<TData, TIndex, TIndex, TData>::Type; |
|
|
|
int64_t blockSize = -1; |
|
TCallback callback = nullptr; |
|
|
|
static TCallback generateCallback(int64_t block_size) { |
|
return fbgemm::GenerateEmbeddingSpMDM<TData, TIndex, TIndex, TData>( |
|
block_size, |
|
has_weight, |
|
false, |
|
16, |
|
false, |
|
true); |
|
} |
|
|
|
_CallbackAndBlockSize() = default; |
|
|
|
explicit _CallbackAndBlockSize(c10::optional<int64_t> maybe_block_size) |
|
: blockSize(maybe_block_size.value_or(-1)) |
|
, callback(maybe_block_size.has_value() ? generateCallback(maybe_block_size.value()) : nullptr) |
|
{} |
|
}; |
|
|
|
template<typename... StorageMixins> |
|
struct _EmbeddingBagKernelCacheImpl : private StorageMixins... { |
|
|
|
_EmbeddingBagKernelCacheImpl() = default; |
|
|
|
explicit _EmbeddingBagKernelCacheImpl(c10::optional<int64_t> maybe_block_size) |
|
: StorageMixins(maybe_block_size)... |
|
{} |
|
|
|
|
|
template<bool has_weight, typename TIndex, typename TData> |
|
typename _CallbackAndBlockSize<has_weight, TIndex, TData>::TCallback |
|
getCallback(int64_t block_size) const { |
|
|
|
|
|
|
|
if (block_size != _CallbackAndBlockSize<has_weight, TIndex, TData>::blockSize) { |
|
return _CallbackAndBlockSize<has_weight, TIndex, TData>::generateCallback(block_size); |
|
} |
|
|
|
return _CallbackAndBlockSize<has_weight, TIndex, TData>::callback; |
|
} |
|
}; |
|
|
|
|
|
|
|
using _EmbeddingBagKernelCache = _EmbeddingBagKernelCacheImpl< |
|
_CallbackAndBlockSize<true, int32_t, float>, |
|
_CallbackAndBlockSize<false, int32_t, float>, |
|
_CallbackAndBlockSize<true, int64_t, float>, |
|
_CallbackAndBlockSize<false, int64_t, float>, |
|
_CallbackAndBlockSize<true, int32_t, unsigned short>, |
|
_CallbackAndBlockSize<false, int32_t, unsigned short>, |
|
_CallbackAndBlockSize<true, int64_t, unsigned short>, |
|
_CallbackAndBlockSize<false, int64_t, unsigned short>>; |
|
#else |
|
struct _EmbeddingBagKernelCache { |
|
explicit _EmbeddingBagKernelCache(c10::optional<int64_t> ) {} |
|
}; |
|
#endif |
|
|
|
void _embedding_bag_cpu_impl_out(Tensor& output, Tensor& offset2bag, |
|
Tensor& bag_size, Tensor* max_indices, |
|
const Tensor &weight, const Tensor &indices, |
|
const Tensor &offsets, const int64_t mode = 0, |
|
const c10::optional<Tensor>& per_sample_weights = c10::nullopt, |
|
bool include_last_offset = false, |
|
int64_t padding_idx = -1, |
|
_EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr); |
|
|
|
void _embedding_bag_cpu_out( |
|
at::Tensor& output, |
|
at::Tensor& offset2bag, |
|
at::Tensor& bag_size, |
|
at::Tensor* p_max_indices, |
|
const at::Tensor& weight, |
|
const at::Tensor& indices, |
|
const at::Tensor& offsets, |
|
const bool scale_grad_by_freq, |
|
const int64_t mode, |
|
const bool sparse, |
|
const c10::optional<at::Tensor>& per_sample_weights, |
|
const bool include_last_offset, |
|
const c10::optional<int64_t>& padding_idx, |
|
_EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr); |
|
|
|
} |
|
} |
|
|