/* * Copyright (C) 2023, Inria * GRAPHDECO research group, https://team.inria.fr/graphdeco * All rights reserved. * * This software is free for non-commercial, research and evaluation use * under the terms of the LICENSE.md file. * * For inquiries contact george.drettakis@inria.fr */ #include #include #include #include #include #include #include #include #include #include "cuda_rasterizer/config.h" #include "cuda_rasterizer/rasterizer.h" #include #include #include std::function resizeFunctional(torch::Tensor& t) { auto lambda = [&t](size_t N) { t.resize_({(long long)N}); return reinterpret_cast(t.contiguous().data_ptr()); }; return lambda; } std::tuple RasterizeGaussiansCUDA( const torch::Tensor& background, const torch::Tensor& means3D, const torch::Tensor& colors, const torch::Tensor& opacity, const torch::Tensor& scales, const torch::Tensor& rotations, const float scale_modifier, const torch::Tensor& cov3D_precomp, const torch::Tensor& viewmatrix, const torch::Tensor& projmatrix, const float tan_fovx, const float tan_fovy, const int image_height, const int image_width, const torch::Tensor& sh, const int degree, const torch::Tensor& campos, const bool prefiltered, const bool debug) { if (means3D.ndimension() != 2 || means3D.size(1) != 3) { AT_ERROR("means3D must have dimensions (num_points, 3)"); } const int P = means3D.size(0); const int H = image_height; const int W = image_width; auto int_opts = means3D.options().dtype(torch::kInt32); auto float_opts = means3D.options().dtype(torch::kFloat32); torch::Tensor out_color = torch::full({NUM_CHANNELS, H, W}, 0.0, float_opts); torch::Tensor out_depth = torch::full({1, H, W}, 0.0, float_opts); torch::Tensor out_alpha = torch::full({1, H, W}, 0.0, float_opts); torch::Tensor radii = torch::full({P}, 0, means3D.options().dtype(torch::kInt32)); torch::Device device(torch::kCUDA); torch::TensorOptions options(torch::kByte); torch::Tensor geomBuffer = torch::empty({0}, options.device(device)); torch::Tensor binningBuffer = torch::empty({0}, options.device(device)); torch::Tensor imgBuffer = torch::empty({0}, options.device(device)); std::function geomFunc = resizeFunctional(geomBuffer); std::function binningFunc = resizeFunctional(binningBuffer); std::function imgFunc = resizeFunctional(imgBuffer); int rendered = 0; if(P != 0) { int M = 0; if(sh.size(0) != 0) { M = sh.size(1); } rendered = CudaRasterizer::Rasterizer::forward( geomFunc, binningFunc, imgFunc, P, degree, M, background.contiguous().data(), W, H, means3D.contiguous().data(), sh.contiguous().data_ptr(), colors.contiguous().data(), opacity.contiguous().data(), scales.contiguous().data_ptr(), scale_modifier, rotations.contiguous().data_ptr(), cov3D_precomp.contiguous().data(), viewmatrix.contiguous().data(), projmatrix.contiguous().data(), campos.contiguous().data(), tan_fovx, tan_fovy, prefiltered, out_color.contiguous().data(), out_depth.contiguous().data(), out_alpha.contiguous().data(), radii.contiguous().data(), debug); } return std::make_tuple(rendered, out_color, out_depth, out_alpha, radii, geomBuffer, binningBuffer, imgBuffer); } std::tuple RasterizeGaussiansBackwardCUDA( const torch::Tensor& background, const torch::Tensor& means3D, const torch::Tensor& radii, const torch::Tensor& colors, const torch::Tensor& scales, const torch::Tensor& rotations, const float scale_modifier, const torch::Tensor& cov3D_precomp, const torch::Tensor& viewmatrix, const torch::Tensor& projmatrix, const float tan_fovx, const float tan_fovy, const torch::Tensor& dL_dout_color, const torch::Tensor& dL_dout_depth, const torch::Tensor& dL_dout_alpha, const torch::Tensor& sh, const int degree, const torch::Tensor& campos, const torch::Tensor& geomBuffer, const int R, const torch::Tensor& binningBuffer, const torch::Tensor& imageBuffer, const torch::Tensor& alphas, const bool debug) { const int P = means3D.size(0); const int H = dL_dout_color.size(1); const int W = dL_dout_color.size(2); int M = 0; if(sh.size(0) != 0) { M = sh.size(1); } torch::Tensor dL_dmeans3D = torch::zeros({P, 3}, means3D.options()); torch::Tensor dL_dmeans2D = torch::zeros({P, 3}, means3D.options()); torch::Tensor dL_dcolors = torch::zeros({P, NUM_CHANNELS}, means3D.options()); torch::Tensor dL_ddepths = torch::zeros({P, 1}, means3D.options()); torch::Tensor dL_dconic = torch::zeros({P, 2, 2}, means3D.options()); torch::Tensor dL_dopacity = torch::zeros({P, 1}, means3D.options()); torch::Tensor dL_dcov3D = torch::zeros({P, 6}, means3D.options()); torch::Tensor dL_dsh = torch::zeros({P, M, 3}, means3D.options()); torch::Tensor dL_dscales = torch::zeros({P, 3}, means3D.options()); torch::Tensor dL_drotations = torch::zeros({P, 4}, means3D.options()); if(P != 0) { CudaRasterizer::Rasterizer::backward(P, degree, M, R, background.contiguous().data(), W, H, means3D.contiguous().data(), sh.contiguous().data(), colors.contiguous().data(), alphas.contiguous().data(), scales.data_ptr(), scale_modifier, rotations.data_ptr(), cov3D_precomp.contiguous().data(), viewmatrix.contiguous().data(), projmatrix.contiguous().data(), campos.contiguous().data(), tan_fovx, tan_fovy, radii.contiguous().data(), reinterpret_cast(geomBuffer.contiguous().data_ptr()), reinterpret_cast(binningBuffer.contiguous().data_ptr()), reinterpret_cast(imageBuffer.contiguous().data_ptr()), dL_dout_color.contiguous().data(), dL_dout_depth.contiguous().data(), dL_dout_alpha.contiguous().data(), dL_dmeans2D.contiguous().data(), dL_dconic.contiguous().data(), dL_dopacity.contiguous().data(), dL_dcolors.contiguous().data(), dL_ddepths.contiguous().data(), dL_dmeans3D.contiguous().data(), dL_dcov3D.contiguous().data(), dL_dsh.contiguous().data(), dL_dscales.contiguous().data(), dL_drotations.contiguous().data(), debug); } return std::make_tuple(dL_dmeans2D, dL_dcolors, dL_dopacity, dL_dmeans3D, dL_dcov3D, dL_dsh, dL_dscales, dL_drotations); } torch::Tensor markVisible( torch::Tensor& means3D, torch::Tensor& viewmatrix, torch::Tensor& projmatrix) { const int P = means3D.size(0); torch::Tensor present = torch::full({P}, false, means3D.options().dtype(at::kBool)); if(P != 0) { CudaRasterizer::Rasterizer::markVisible(P, means3D.contiguous().data(), viewmatrix.contiguous().data(), projmatrix.contiguous().data(), present.contiguous().data()); } return present; }