File size: 767 Bytes
138f509
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from ..octree import DfsOctree as Octree


class Strivec(Octree):
    """Octree whose primitive is ``"trivec"``, configured by ``rank`` and ``dim``.

    This is a thin configuration wrapper around ``DfsOctree``. It converts a
    power-of-two ``resolution`` into the octree ``depth`` and forwards
    everything else to the base class with ``primitive="trivec"``.
    """

    def __init__(
        self,
        resolution: int,
        aabb: list,
        sh_degree: int = 0,
        rank: int = 8,
        dim: int = 8,
        device: str = "cuda",
    ):
        """Build the octree at the depth implied by ``resolution``.

        Args:
            resolution: Grid resolution. Must be a positive power of two
                (1, 2, 4, ...); an integral float such as ``8.0`` is also
                accepted. The octree depth is ``log2(resolution)``.
            aabb: Forwarded unchanged to the base octree. NOTE(review):
                presumably the scene bounding box; confirm against DfsOctree.
            sh_degree: Forwarded unchanged to the base octree.
            rank: Stored in ``primitive_config["rank"]`` for the trivec
                primitive.
            dim: Stored in ``primitive_config["dim"]`` for the trivec
                primitive.
            device: Forwarded unchanged to the base octree.

        Raises:
            ValueError: If ``resolution`` is not a positive power of two.
        """
        # Validate with exact integer arithmetic rather than an `assert`:
        # asserts are stripped under `python -O`, and the former float test
        # `np.log2(r) % 1 == 0` also accepted fractional powers of two
        # (0.5 -> depth -1) and went through -inf/nan for r <= 0.
        try:
            res_int = int(resolution)
        except (TypeError, ValueError, OverflowError) as err:
            raise ValueError("Resolution must be a power of 2") from err
        if res_int != resolution or res_int < 1 or res_int & (res_int - 1):
            raise ValueError("Resolution must be a power of 2")

        self.resolution = resolution
        # For a positive power of two, bit_length() - 1 is exactly log2.
        depth = res_int.bit_length() - 1
        super().__init__(
            depth=depth,
            aabb=aabb,
            sh_degree=sh_degree,
            primitive="trivec",
            primitive_config={"rank": rank, "dim": dim},
            device=device,
        )