|
#pragma once |
|
|
|
#include <ATen/core/Tensor.h> |
|
#include <c10/util/irange.h> |
|
|
|
#ifndef AT_PER_OPERATOR_HEADERS |
|
#include <ATen/NativeFunctions.h> |
|
#else |
|
#include <ATen/ops/result_type_native.h> |
|
#endif |
|
|
|
namespace at { |
|
namespace native { |
|
namespace { |
|
|
|
// Returns true if any tensor in `tensors` has an integral scalar type.
// When `includeBool` is true, Bool counts as an integral type.
bool has_integral_tensor(TensorList tensors, const bool includeBool) {
  for (const auto& tensor : tensors) {
    if (at::isIntegralType(tensor.scalar_type(), includeBool)) {
      return true;
    }
  }
  return false;
}
|
|
|
// Returns true if any tensor in `tensors` has scalar type Bool.
bool has_bool_tensor(TensorList tensors) {
  for (const auto& tensor : tensors) {
    if (tensor.scalar_type() == ScalarType::Bool) {
      return true;
    }
  }
  return false;
}
|
|
|
|
|
|
|
|
|
|
|
// Foreach APIs reject empty tensor lists; throws (via TORCH_CHECK) otherwise.
void check_foreach_api_restrictions(TensorList tensors) {
  TORCH_CHECK(!tensors.empty(), "Tensor list must have at least one tensor.");
}
|
|
|
// Validates the tensor list and requires one scalar per tensor.
void check_foreach_api_restrictions(TensorList tensors, ArrayRef<Scalar> scalars) {
  check_foreach_api_restrictions(tensors);
  TORCH_CHECK(scalars.size() == tensors.size(), "Tensor list must have same number of elements as scalar list.");
}
|
|
|
// Both lists must be non-empty and of equal length.
void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2) {
  // Check non-emptiness in argument order so the failing list is reported first.
  for (const auto& tensor_list : {tensors1, tensors2}) {
    TORCH_CHECK(!tensor_list.empty(), "Tensor list must have at least one tensor.");
  }
  TORCH_CHECK(tensors2.size() == tensors1.size(), "Tensor lists must have the same number of tensors, got ", tensors1.size(), " and ", tensors2.size());
}
|
|
|
// All three lists must be non-empty and of equal length.
// (tensors2.size() == tensors3.size() follows by transitivity.)
void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2, TensorList tensors3) {
  // Check non-emptiness in argument order so the failing list is reported first.
  for (const auto& tensor_list : {tensors1, tensors2, tensors3}) {
    TORCH_CHECK(!tensor_list.empty(), "Tensor list must have at least one tensor.");
  }
  TORCH_CHECK(tensors2.size() == tensors1.size(), "Tensor lists must have the same number of tensors, got ", tensors1.size(), " and ", tensors2.size());
  TORCH_CHECK(tensors3.size() == tensors1.size(), "Tensor lists must have the same number of tensors, got ", tensors1.size(), " and ", tensors3.size());
}
|
|
|
// Validates the three tensor lists and requires one scalar per tensor.
void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2, TensorList tensors3, ArrayRef<Scalar> scalars) {
  check_foreach_api_restrictions(tensors1, tensors2, tensors3);
  TORCH_CHECK(scalars.size() == tensors1.size(), "Tensor list must have same number of elements as scalar list, got ", tensors1.size(), " and ", scalars.size());
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Returns true when all tensors across `tensorLists` satisfy the restrictions
// of the foreach fast path:
//   - every tensor shares the dtype and device of tensorLists[0][0],
//     is strided, and is non-overlapping-and-dense;
//   - corresponding tensors across lists have identical sizes and strides;
//   - if the op promotes integer inputs to float, the inputs are not
//     integral (bool included);
//   - if scalars are supplied, applying each scalar must not promote its
//     tensor's dtype.
//
// Preconditions (callers enforce via check_foreach_api_restrictions):
// tensorLists and tensorLists[0] are non-empty, and all inner lists have the
// same length.
bool check_fast_path_restrictions(
    ArrayRef<TensorList> tensorLists,
    ArrayRef<Scalar> scalarList = {},
    bool does_op_promote_integer_inputs_to_float = false) {
  const auto expected_dtype = tensorLists[0][0].dtype();
  const auto expected_device = tensorLists[0][0].device();

  auto is_tensor_okay = [&](const Tensor& tensor) {
    return tensor.dtype() == expected_dtype &&
           tensor.device() == expected_device &&
           tensor.layout() == at::kStrided &&
           tensor.is_non_overlapping_and_dense();
  };

  for (const auto& tensorList : tensorLists) {
    for (const auto& tensor : tensorList) {
      if (!is_tensor_okay(tensor)) {
        return false;
      }
    }
  }

  // Corresponding tensors across lists must match in sizes and strides.
  for (const auto& tensor_list : tensorLists) {
    for (const auto j : c10::irange(tensorLists[0].size())) {
      if (tensorLists[0][j].sizes() != tensor_list[j].sizes()) {
        return false;
      }
      if (tensorLists[0][j].strides() != tensor_list[j].strides()) {
        return false;
      }
    }
  }

  // At this point every tensor has `expected_dtype` (enforced by the first
  // loop), so the integer-promotion check is loop-invariant: test it once
  // instead of once per element.
  if (does_op_promote_integer_inputs_to_float &&
      at::isIntegralType(tensorLists[0][0].scalar_type(), /*includeBool=*/true)) {
    return false;
  }

  if (!scalarList.empty()) {
    for (const auto i : c10::irange(tensorLists[0].size())) {
      // A single scalar is broadcast to every tensor; otherwise scalars pair
      // up with tensors element-wise.
      const auto& scalar = scalarList.size() == 1 ? scalarList[0] : scalarList[i];
      const auto& tensor = tensorLists[0][i];
      // The fast path cannot handle type promotion caused by the scalar.
      if (tensor.scalar_type() != at::native::result_type(scalar, tensor)) {
        return false;
      }
    }
  }

  return true;
}
|
|
|
// Returns true when the foreach fast path may be taken for `tensorLists`
// (optionally paired with `scalarList`); thin forwarder to
// check_fast_path_restrictions, which holds the actual conditions.
bool can_use_fast_route(ArrayRef<TensorList> tensorLists,
                        ArrayRef<Scalar> scalarList = {},
                        bool does_op_promote_integer_inputs_to_float = false) {
  return check_fast_path_restrictions(tensorLists, scalarList, does_op_promote_integer_inputs_to_float);
}
|
|
|
// Two-list convenience overload: wraps both lists and delegates with an
// empty scalar list.
bool can_use_fast_route(TensorList tensors1, TensorList tensors2, bool does_op_promote_integer_inputs_to_float = false) {
  return can_use_fast_route({tensors1, tensors2}, {}, does_op_promote_integer_inputs_to_float);
}
|
|
|
} |
|
}} |
|
|