#pragma once #include #include #include #include #if defined(USE_ROCM) #include #endif namespace at::cuda::sparse { template struct CuSparseDescriptorDeleter { void operator()(T* x) { if (x != nullptr) { TORCH_CUDASPARSE_CHECK(destructor(x)); } } }; template class CuSparseDescriptor { public: T* descriptor() const { return descriptor_.get(); } T* descriptor() { return descriptor_.get(); } protected: std::unique_ptr> descriptor_; }; #if AT_USE_CUSPARSE_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_CONST_DESCRIPTORS() template struct ConstCuSparseDescriptorDeleter { void operator()(T* x) { if (x != nullptr) { TORCH_CUDASPARSE_CHECK(destructor(x)); } } }; template class ConstCuSparseDescriptor { public: T* descriptor() const { return descriptor_.get(); } T* descriptor() { return descriptor_.get(); } protected: std::unique_ptr> descriptor_; }; #endif // AT_USE_CUSPARSE_CONST_DESCRIPTORS || AT_USE_HIPSPARSE_CONST_DESCRIPTORS #if defined(USE_ROCM) using cusparseMatDescr = std::remove_pointer::type; using cusparseDnMatDescr = std::remove_pointer::type; using cusparseDnVecDescr = std::remove_pointer::type; using cusparseSpMatDescr = std::remove_pointer::type; using cusparseSpMatDescr = std::remove_pointer::type; using cusparseSpGEMMDescr = std::remove_pointer::type; #if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE() using bsrsv2Info = std::remove_pointer::type; using bsrsm2Info = std::remove_pointer::type; #endif #endif class TORCH_CUDA_CPP_API CuSparseMatDescriptor : public CuSparseDescriptor { public: CuSparseMatDescriptor() { cusparseMatDescr_t raw_descriptor; TORCH_CUDASPARSE_CHECK(cusparseCreateMatDescr(&raw_descriptor)); descriptor_.reset(raw_descriptor); } CuSparseMatDescriptor(bool upper, bool unit) { cusparseFillMode_t fill_mode = upper ? CUSPARSE_FILL_MODE_UPPER : CUSPARSE_FILL_MODE_LOWER; cusparseDiagType_t diag_type = unit ? CUSPARSE_DIAG_TYPE_UNIT : CUSPARSE_DIAG_TYPE_NON_UNIT; cusparseMatDescr_t raw_descriptor; TORCH_CUDASPARSE_CHECK(cusparseCreateMatDescr(&raw_descriptor)); TORCH_CUDASPARSE_CHECK(cusparseSetMatFillMode(raw_descriptor, fill_mode)); TORCH_CUDASPARSE_CHECK(cusparseSetMatDiagType(raw_descriptor, diag_type)); descriptor_.reset(raw_descriptor); } }; #if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE() class TORCH_CUDA_CPP_API CuSparseBsrsv2Info : public CuSparseDescriptor { public: CuSparseBsrsv2Info() { bsrsv2Info_t raw_descriptor; TORCH_CUDASPARSE_CHECK(cusparseCreateBsrsv2Info(&raw_descriptor)); descriptor_.reset(raw_descriptor); } }; class TORCH_CUDA_CPP_API CuSparseBsrsm2Info : public CuSparseDescriptor { public: CuSparseBsrsm2Info() { bsrsm2Info_t raw_descriptor; TORCH_CUDASPARSE_CHECK(cusparseCreateBsrsm2Info(&raw_descriptor)); descriptor_.reset(raw_descriptor); } }; #endif // AT_USE_HIPSPARSE_TRIANGULAR_SOLVE #if AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API() cusparseIndexType_t getCuSparseIndexType(const c10::ScalarType& scalar_type); #if AT_USE_HIPSPARSE_GENERIC_52_API() || \ (AT_USE_CUSPARSE_GENERIC_API() && AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS()) class TORCH_CUDA_CPP_API CuSparseDnMatDescriptor : public CuSparseDescriptor { public: explicit CuSparseDnMatDescriptor(const Tensor& input, int64_t batch_offset = -1); }; class TORCH_CUDA_CPP_API CuSparseDnVecDescriptor : public CuSparseDescriptor { public: explicit CuSparseDnVecDescriptor(const Tensor& input); }; class TORCH_CUDA_CPP_API CuSparseSpMatDescriptor : public CuSparseDescriptor {}; //AT_USE_HIPSPARSE_GENERIC_52_API() || (AT_USE_CUSPARSE_GENERIC_API() && AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS()) #elif AT_USE_CUSPARSE_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_CONST_DESCRIPTORS() class TORCH_CUDA_CPP_API CuSparseDnMatDescriptor : public ConstCuSparseDescriptor< cusparseDnMatDescr, &cusparseDestroyDnMat> { public: explicit CuSparseDnMatDescriptor( const Tensor& input, int64_t batch_offset = -1); }; class TORCH_CUDA_CPP_API CuSparseDnVecDescriptor : public ConstCuSparseDescriptor< cusparseDnVecDescr, &cusparseDestroyDnVec> { public: explicit CuSparseDnVecDescriptor(const Tensor& input); }; class TORCH_CUDA_CPP_API CuSparseSpMatDescriptor : public ConstCuSparseDescriptor< cusparseSpMatDescr, &cusparseDestroySpMat> {}; #endif // AT_USE_CUSPARSE_CONST_DESCRIPTORS() class TORCH_CUDA_CPP_API CuSparseSpMatCsrDescriptor : public CuSparseSpMatDescriptor { public: explicit CuSparseSpMatCsrDescriptor(const Tensor& input, int64_t batch_offset = -1); std::tuple get_size() { int64_t rows, cols, nnz; TORCH_CUDASPARSE_CHECK(cusparseSpMatGetSize( this->descriptor(), &rows, &cols, &nnz)); return std::make_tuple(rows, cols, nnz); } void set_tensor(const Tensor& input) { auto crow_indices = input.crow_indices(); auto col_indices = input.col_indices(); auto values = input.values(); TORCH_INTERNAL_ASSERT_DEBUG_ONLY(crow_indices.is_contiguous()); TORCH_INTERNAL_ASSERT_DEBUG_ONLY(col_indices.is_contiguous()); TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.is_contiguous()); TORCH_CUDASPARSE_CHECK(cusparseCsrSetPointers( this->descriptor(), crow_indices.data_ptr(), col_indices.data_ptr(), values.data_ptr())); } #if AT_USE_CUSPARSE_GENERIC_SPSV() void set_mat_fill_mode(bool upper) { cusparseFillMode_t fill_mode = upper ? CUSPARSE_FILL_MODE_UPPER : CUSPARSE_FILL_MODE_LOWER; TORCH_CUDASPARSE_CHECK(cusparseSpMatSetAttribute( this->descriptor(), CUSPARSE_SPMAT_FILL_MODE, &fill_mode, sizeof(fill_mode))); } void set_mat_diag_type(bool unit) { cusparseDiagType_t diag_type = unit ? CUSPARSE_DIAG_TYPE_UNIT : CUSPARSE_DIAG_TYPE_NON_UNIT; TORCH_CUDASPARSE_CHECK(cusparseSpMatSetAttribute( this->descriptor(), CUSPARSE_SPMAT_DIAG_TYPE, &diag_type, sizeof(diag_type))); } #endif }; #if AT_USE_CUSPARSE_GENERIC_SPSV() class TORCH_CUDA_CPP_API CuSparseSpSVDescriptor : public CuSparseDescriptor { public: CuSparseSpSVDescriptor() { cusparseSpSVDescr_t raw_descriptor; TORCH_CUDASPARSE_CHECK(cusparseSpSV_createDescr(&raw_descriptor)); descriptor_.reset(raw_descriptor); } }; #endif #if AT_USE_CUSPARSE_GENERIC_SPSM() class TORCH_CUDA_CPP_API CuSparseSpSMDescriptor : public CuSparseDescriptor { public: CuSparseSpSMDescriptor() { cusparseSpSMDescr_t raw_descriptor; TORCH_CUDASPARSE_CHECK(cusparseSpSM_createDescr(&raw_descriptor)); descriptor_.reset(raw_descriptor); } }; #endif #if (defined(USE_ROCM) && ROCM_VERSION >= 50200) || !defined(USE_ROCM) class TORCH_CUDA_CPP_API CuSparseSpGEMMDescriptor : public CuSparseDescriptor { public: CuSparseSpGEMMDescriptor() { cusparseSpGEMMDescr_t raw_descriptor; TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_createDescr(&raw_descriptor)); descriptor_.reset(raw_descriptor); } }; #endif #endif // AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API() } // namespace at::cuda::sparse