|
#pragma once |
|
|
|
#include <ATen/EmptyTensor.h> |
|
#include <ATen/native/ResizeCommon.h> |
|
|
|
#include <c10/cuda/CUDAGuard.h> |
|
|
|
namespace at { namespace native { |
|
|
|
// Resizes `storage` to `size_bytes` bytes. Only declared here; the definition
// lives in another translation unit (TORCH_CUDA_CPP_API presumably exports it
// from the CUDA library — confirm). NOTE(review): whether this can shrink an
// allocation, and what it throws on failure, is not visible from this header;
// maybe_resize_storage_cuda below only ever calls it to grow storage.
TORCH_CUDA_CPP_API void resize_bytes_cuda(StorageImpl* storage, size_t size_bytes);
|
|
|
// Ensures the storage behind `self` holds at least `new_size_bytes` bytes.
//
// Tensors with zero elements are left untouched: no storage is checked or
// resized for them. Otherwise the storage must be non-null (checked), and it
// is grown via resize_bytes_cuda only when its current byte size is smaller
// than `new_size_bytes`. An existing allocation is never shrunk here.
static inline void maybe_resize_storage_cuda(TensorImpl* self, size_t new_size_bytes) {
  if (self->numel() != 0) {
    const Storage& backing = self->unsafe_storage();
    TORCH_CHECK(backing, "Tensor: invalid null storage");
    // Grow-only: a storage that is already big enough is kept as-is.
    if (backing.nbytes() < new_size_bytes) {
      resize_bytes_cuda(backing.unsafeGetStorageImpl(), new_size_bytes);
    }
  }
}
|
|
|
// Resizes `self` in place to `size` and, if given, `stride`. Grows the
// underlying storage when the new geometry needs more bytes than are
// currently allocated.
//
// Parameters:
//   self          tensor whose sizes/strides (and possibly storage) are updated.
//   size          requested sizes.
//   stride        optional explicit strides. When absent, contiguous strides
//                 are derived from `size` via set_sizes_contiguous.
//   device_guard  when true, an OptionalCUDAGuard is set to the index of the
//                 storage's device for the duration of the call.
// Returns `self`.
//
// Exception safety: once the new geometry has been applied, any exception from
// computing the required byte count or from resizing the storage rolls `self`
// back to its previous sizes and strides before propagating. This keeps the
// tensor from describing more bytes than its storage actually holds.
inline TensorImpl* resize_impl_cuda_(
    TensorImpl* self,
    IntArrayRef size,
    at::OptionalIntArrayRef stride,
    bool device_guard = true) {
  // Fast path: the requested geometry already matches. When no stride is
  // given only sizes are compared, so existing strides are kept as they are.
  if (self->sizes() == size && (!stride || self->strides() == stride)) {
    return self;
  }

  // Presumably so any storage allocation below targets the tensor's own
  // device; callers can opt out with device_guard=false.
  cuda::OptionalCUDAGuard guard;
  if (device_guard) {
    guard.set_index(self->storage().device().index());
  }

  const auto itemsize = self->dtype().itemsize();
  const auto storage_offset = self->storage_offset();

  // Copy the current geometry. sizes()/strides() are views into `self`, so
  // they must be materialized before the setters below overwrite them.
  const auto old_sizes = self->sizes().vec();
  const auto old_strides = self->strides().vec();

  // Apply the new geometry (same order as before: metadata first).
  if (stride) {
    self->set_sizes_and_strides(size, *stride);
  } else {
    self->set_sizes_contiguous(size);
  }

  // From here on `self` has been mutated, so undo that if anything fails.
  try {
    size_t storage_size = 0;
    if (stride) {
      storage_size = at::detail::computeStorageNbytes(
          size, *stride, itemsize, storage_offset);
    } else {
      storage_size = at::detail::computeStorageNbytesContiguous(
          size, itemsize, storage_offset);
    }
    maybe_resize_storage_cuda(self, storage_size);
  } catch (...) {
    // Restore the previous sizes/strides so the metadata matches the storage
    // again, then rethrow the original exception unchanged.
    self->set_sizes_and_strides(IntArrayRef(old_sizes), IntArrayRef(old_strides));
    throw;
  }

  return self;
}
|
|
|
}} |
|
|