|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import numpy as np |
|
from ..octree import DfsOctree as Octree |
|
|
|
|
|
class Strivec(Octree):
    """Octree representation whose primitives use the ``"trivec"`` type.

    Thin configuration wrapper around ``Octree`` (imported as ``DfsOctree``).
    It converts a power-of-two ``resolution`` into an octree ``depth``. It then
    forwards the remaining options to the base class, packing ``rank`` and
    ``dim`` into ``primitive_config``.
    """

    def __init__(
        self,
        resolution: int,
        aabb: list,
        sh_degree: int = 0,
        rank: int = 8,
        dim: int = 8,
        device: str = "cuda",
    ):
        """Validate ``resolution`` and initialise the underlying octree.

        Args:
            resolution: Grid resolution. Must be a power of two that is at
                least 1. The octree depth is ``log2(resolution)``. The value
                is stored unchanged on ``self.resolution``.
            aabb: Forwarded unchanged to ``Octree``. Presumably the scene
                bounding box -- TODO confirm the expected layout.
            sh_degree: Forwarded unchanged to ``Octree``.
            rank: Passed to ``Octree`` as ``primitive_config["rank"]``.
            dim: Passed to ``Octree`` as ``primitive_config["dim"]``.
            device: Forwarded unchanged to ``Octree``.

        Raises:
            ValueError: If ``resolution`` is not a power of two >= 1.
        """
        # Validate explicitly instead of with ``assert`` so the check still
        # runs under ``python -O``. Reject values below 1 up front: for
        # non-positive input ``np.log2`` would emit a warning and return
        # -inf/nan, and fractional input would produce a negative depth.
        if resolution < 1:
            raise ValueError("Resolution must be a power of 2")
        depth = int(np.round(np.log2(resolution)))
        # Confirm exactly instead of testing ``log2(resolution) % 1 == 0``.
        # Float rounding in ``log2`` can make a large non-power-of-two
        # (e.g. ``2**60 + 1``) look like an exact power of two.
        if 2 ** depth != resolution:
            raise ValueError("Resolution must be a power of 2")
        self.resolution = resolution
        super().__init__(
            depth=depth,
            aabb=aabb,
            sh_degree=sh_degree,
            primitive="trivec",
            primitive_config={"rank": rank, "dim": dim},
            device=device,
        )
|
|