hzxie's picture
feat: citydreamer inference (bugs to be fixed).
79df973 verified
// Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, check out LICENSE.md
#ifndef VOXLIB_COMMON_H
#define VOXLIB_COMMON_H
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
#define CHECK_CPU(x) \
TORCH_CHECK(x.device().is_cpu(), #x " must be a CPU tensor")
#include <cuda.h>
#include <cuda_runtime.h>
// CUDA vector math functions
__host__ __device__ __forceinline__ int floor_div(int a, int b) {
int c = a / b;
if (c * b > a) {
c--;
}
return c;
}
template <typename scalar_t>
__host__ __forceinline__ void cross(scalar_t *r, const scalar_t *a,
const scalar_t *b) {
r[0] = a[1] * b[2] - a[2] * b[1];
r[1] = a[2] * b[0] - a[0] * b[2];
r[2] = a[0] * b[1] - a[1] * b[0];
}
__device__ __host__ __forceinline__ float dot(const float *a, const float *b) {
return a[0] * b[0] + a[1] * b[1] + a[2] * b[2];
}
template <typename scalar_t, int ndim>
__device__ __host__ __forceinline__ void copyarr(scalar_t *r,
const scalar_t *a) {
#pragma unroll
for (int i = 0; i < ndim; i++) {
r[i] = a[i];
}
}
// TODO: use rsqrt to speed up
// inplace version
template <typename scalar_t, int ndim>
__device__ __host__ __forceinline__ void normalize(scalar_t *a) {
scalar_t vec_len = 0.0f;
#pragma unroll
for (int i = 0; i < ndim; i++) {
vec_len += a[i] * a[i];
}
vec_len = sqrtf(vec_len);
#pragma unroll
for (int i = 0; i < ndim; i++) {
a[i] /= vec_len;
}
}
// normalize + copy
template <typename scalar_t, int ndim>
__device__ __host__ __forceinline__ void normalize(scalar_t *r,
const scalar_t *a) {
scalar_t vec_len = 0.0f;
#pragma unroll
for (int i = 0; i < ndim; i++) {
vec_len += a[i] * a[i];
}
vec_len = sqrtf(vec_len);
#pragma unroll
for (int i = 0; i < ndim; i++) {
r[i] = a[i] / vec_len;
}
}
#endif // VOXLIB_COMMON_H