|
#include <ATen/ATen.h> |
|
#include <ATen/core/dispatch/Dispatcher.h> |
|
#include <ATen/core/op_registration/op_registration.h> |
|
#include <ATen/native/UnaryOps.h> |
|
#include <ATen/NativeFunctions.h> |
|
#include <ATen/native/Resize.h> |
|
#include <c10/util/irange.h> |
|
#include <torch/library.h> |
|
|
|
namespace at { |
|
namespace native { |
|
|
|
|
|
|
|
|
|
|
|
|
|
struct MathOpFallback { |
|
MathOpFallback(DispatchKey key_, string op_name_) : key(key_), op_name(op_name_) {} |
|
virtual bool is_bit_set(const Tensor&) = 0; |
|
void fallback_impl(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const auto& arguments = op.schema().arguments(); |
|
const auto num_arguments = arguments.size(); |
|
const auto stack_start = stack->size() - num_arguments; |
|
|
|
c10::optional<bool> is_write; |
|
for (const auto i : c10::irange(num_arguments)) { |
|
|
|
|
|
|
|
|
|
const AliasInfo* alias_info = arguments[i].alias_info(); |
|
if (alias_info != nullptr) { |
|
if (is_write.has_value()) { |
|
TORCH_CHECK(*is_write == alias_info->isWrite(), |
|
"Unsupported operator for ", op_name, " fallback: ", op.schema().name(), |
|
op_name, " fallback doesn't work for operators with a mix " |
|
"mutable and non-mutable inputs that alias with outputs, " |
|
"this must be implemented manually. " |
|
"If you got this error on a core op, please report a bug to PyTorch."); |
|
} else { |
|
is_write = alias_info->isWrite(); |
|
} |
|
} |
|
} |
|
|
|
if (is_write.has_value() && !*is_write) { |
|
|
|
|
|
|
|
op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, key), stack); |
|
return; |
|
} |
|
|
|
|
|
std::vector<std::pair<Tensor, Tensor>> mutable_inputs_with_their_clones; |
|
for (const auto i : c10::irange(num_arguments)) { |
|
auto& ivalue = (*stack)[stack_start + i]; |
|
if (!(ivalue.isTensor() || ivalue.isTensorList())) { |
|
continue; |
|
} |
|
const auto& argument = arguments[i]; |
|
bool mut_arg = false; |
|
if (argument.alias_info()) { |
|
|
|
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(argument.alias_info()->isWrite()); |
|
mut_arg = true; |
|
} |
|
if (ivalue.isTensor()) { |
|
if (!is_bit_set(ivalue.toTensor())) { |
|
continue; |
|
} |
|
auto tensor = std::move(ivalue).toTensor(); |
|
auto resolved_tensor = at::clone(tensor); |
|
if (mut_arg) { |
|
TORCH_CHECK(mutable_inputs_with_their_clones.empty(), op_name, " fallback does not support operators with more than one mutable tensors with ", |
|
op_name, "bit set to true."); |
|
mutable_inputs_with_their_clones.emplace_back(std::make_pair(std::move(tensor), resolved_tensor)); |
|
} |
|
(*stack)[stack_start + i] = std::move(resolved_tensor); |
|
} else if (ivalue.isTensorList()) { |
|
auto tensors = std::move(ivalue).toTensorList(); |
|
for(const auto j : c10::irange(tensors.size())) { |
|
const auto& tensor = tensors[j]; |
|
if (!is_bit_set(tensor)) { |
|
continue; |
|
} |
|
TORCH_CHECK(!mut_arg, " fallback doesn't currently support mutable TensorLists with ", |
|
op_name, " inputs. Please materialize all the ", op_name, " input tensor(s) in the mutable TensorList inputs before calling ", |
|
op.schema().name()); |
|
tensors[j] = at::clone(tensor); |
|
} |
|
(*stack)[stack_start + i] = std::move(tensors); |
|
} |
|
} |
|
|
|
op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, key), stack); |
|
|
|
TORCH_INTERNAL_ASSERT(mutable_inputs_with_their_clones.size() <= 1); |
|
|
|
for (std::pair<Tensor, Tensor> mut_tensors: mutable_inputs_with_their_clones) { |
|
auto& mutable_input = mut_tensors.first; |
|
auto& cloned_mutable_input = mut_tensors.second; |
|
auto& ivalue = (*stack)[stack_start]; |
|
auto returned_output = std::move(ivalue).toTensor(); |
|
|
|
|
|
TORCH_INTERNAL_ASSERT(cloned_mutable_input.is_same(returned_output)); |
|
|
|
|
|
at::native::resize_output(mutable_input, returned_output.sizes()); |
|
|
|
mutable_input.copy_(returned_output); |
|
(*stack)[stack_start] = std::move(mutable_input); |
|
} |
|
} |
|
|
|
virtual ~MathOpFallback() = default; |
|
|
|
DispatchKey key; |
|
string op_name; |
|
}; |
|
} |
|
} |
|
|