Spaces:
Build error
Build error
from typing import List

import numpy as np
from pydantic import validator

from my.config import BaseConf, SingleOrList, dispatch
from my.utils.seed import seed_everything
from voxnerf.pipelines import train
from voxnerf.vox import VOXRF_REGISTRY
class VoxConfig(BaseConf):
    """Configuration for a voxel radiance-field model, instantiated by ``make()``.

    Every field except ``model_type`` and ``bbox_len`` is forwarded unchanged
    as a keyword argument to the model constructor that ``VOXRF_REGISTRY``
    returns for ``model_type``.
    """

    model_type: str = "VoxRF"  # registry key passed to VOXRF_REGISTRY.get
    bbox_len: float = 1.5  # half-extent of the cube bounding box built in make()
    grid_size: SingleOrList(int) = [128, 128, 128]  # one int or 3 ints; see check_gsize
    step_ratio: float = 0.5
    density_shift: float = -10.
    ray_march_weight_thres: float = 0.0001
    c: int = 3
    blend_bg_texture: bool = False
    bg_texture_hw: int = 64

    # Without this decorator pydantic never calls check_gsize, and a single
    # int grid_size would reach the model constructor un-expanded.
    @validator("grid_size")
    def check_gsize(cls, grid_size):
        """Normalize ``grid_size`` to exactly three entries.

        A single int ``n`` becomes ``[n, n, n]``; a sequence is returned
        unchanged provided it has length 3.

        Raises:
            ValueError: if a sequence whose length is not 3 is supplied
                (pydantic surfaces this as a validation error). Unlike the
                previous ``assert``, this check survives ``python -O``.
        """
        if isinstance(grid_size, int):
            return [grid_size] * 3
        if len(grid_size) != 3:
            raise ValueError(
                f"grid_size must be an int or a sequence of 3 ints, got {grid_size!r}"
            )
        return grid_size

    def make(self):
        """Instantiate the registered model described by this config.

        Returns:
            The result of calling the ``VOXRF_REGISTRY`` entry for
            ``model_type`` with ``aabb=`` plus all remaining fields as kwargs.
        """
        params = self.dict()
        m_type = params.pop("model_type")
        # NOTE(review): what happens for an unknown model_type depends on the
        # registry's ``get`` (it may raise, or return None and fail below) — confirm.
        model_fn = VOXRF_REGISTRY.get(m_type)

        radius = params.pop("bbox_len")
        # (2, 3) array: row 0 is the min corner, row 1 the max corner, so the
        # box is the cube [-radius, radius] along each axis.
        aabb = radius * np.array([
            [-1, -1, -1],
            [1, 1, 1],
        ])
        return model_fn(aabb=aabb, **params)
class TrainerConfig(BaseConf):
    """Top-level training configuration: a model sub-config plus train() options.

    ``run()`` builds the model from the ``model`` sub-config and passes every
    other field (``scene``, ``n_epoch``, ``bs``, ``lr``) to ``train`` as
    keyword arguments.
    """

    model: VoxConfig = VoxConfig()  # consumed by run(); not forwarded to train()
    scene: str = "lego"
    n_epoch: int = 2
    bs: int = 4096  # presumably the batch size — TODO confirm against train()
    lr: float = 0.02

    def run(self):
        """Construct the model, then launch training with the remaining fields."""
        train_kwargs = {
            field_name: field_value
            for field_name, field_value in self.dict().items()
            if field_name != "model"
        }
        train(self.model.make(), **train_kwargs)
if __name__ == "__main__":
    # Seed before anything else runs so results are repeatable (assuming
    # seed_everything covers every RNG used downstream — confirm).
    seed = 0
    seed_everything(seed)
    # Presumably dispatch builds a TrainerConfig from the command line and
    # invokes it — verify against my.config.dispatch.
    dispatch(TrainerConfig)