"""Configuration objects and CLI entry point for training a VoxRF model.

``TrainerConfig.run`` builds the model described by its ``model`` field
(a ``VoxConfig``) and hands it, together with the remaining trainer
fields, to ``voxnerf.pipelines.train``.
"""
from typing import List

import numpy as np
from pydantic import validator

from my.config import BaseConf, SingleOrList, dispatch
from my.utils.seed import seed_everything
from voxnerf.vox import VOXRF_REGISTRY
from voxnerf.pipelines import train


class VoxConfig(BaseConf):
    """Model hyper-parameters; ``make()`` turns them into a model instance.

    Every field except ``model_type`` and ``bbox_len`` is forwarded
    unchanged as a keyword argument to the model constructor chosen from
    ``VOXRF_REGISTRY``.
    """

    # Key looked up in VOXRF_REGISTRY by make().
    model_type: str = "VoxRF"
    # make() builds `aabb` as the corners (-bbox_len,)*3 and (+bbox_len,)*3,
    # i.e. this is half the side length of the bounding cube.
    bbox_len: float = 1.5
    # One int (repeated for all three entries by check_gsize) or a
    # 3-element list.
    grid_size: SingleOrList(int) = [128, 128, 128]
    # The fields below are passed straight through to the model constructor;
    # their exact semantics are defined by the registered model class.
    step_ratio: float = 0.5
    density_shift: float = -10.
    ray_march_weight_thres: float = 0.0001
    c: int = 3
    blend_bg_texture: bool = False
    bg_texture_hw: int = 64

    @validator("grid_size")
    def check_gsize(cls, grid_size):
        """Normalize ``grid_size`` to a list of exactly three entries.

        A single int is repeated three times. A sequence must already have
        length 3.

        Raises:
            ValueError: if a sequence of any other length is given (pydantic
                reports this as a ``ValidationError``). An explicit raise is
                used instead of ``assert`` so the check survives ``python -O``.
        """
        if isinstance(grid_size, int):
            return [grid_size] * 3
        if len(grid_size) != 3:
            raise ValueError(
                f"grid_size must be an int or a list of 3 ints, "
                f"got {len(grid_size)} entries: {grid_size!r}"
            )
        return grid_size

    def make(self):
        """Instantiate the model class registered under ``model_type``.

        Returns:
            Whatever the registered constructor returns, called with
            ``aabb=`` plus all remaining config fields as keyword arguments.
        """
        params = self.dict()
        m_type = params.pop("model_type")
        # NOTE(review): behavior on an unknown key depends on the registry's
        # `get` implementation (may return None or raise) — confirm.
        model_fn = VOXRF_REGISTRY.get(m_type)

        radius = params.pop('bbox_len')
        # Row 0 is the all-negative corner and row 1 the all-positive corner
        # of a cube centred at the origin with half-side `radius`.
        aabb = radius * np.array([
            [-1, -1, -1],
            [1, 1, 1]
        ])
        model = model_fn(aabb=aabb, **params)
        return model


class TrainerConfig(BaseConf):
    """Top-level training configuration: the model config plus train() args."""

    model: VoxConfig = VoxConfig()
    scene: str = "lego"
    n_epoch: int = 2
    bs: int = 4096
    lr: float = 0.02

    def run(self):
        """Build the model and call ``train(model, scene=, n_epoch=, bs=, lr=)``."""
        args = self.dict()
        # `model` is consumed here to build the instance; it is not a train() kwarg.
        args.pop("model")

        model = self.model.make()
        train(model, **args)


if __name__ == "__main__":
    # Fixed seed for reproducible runs.
    seed_everything(0)
    # Presumably builds a TrainerConfig from CLI/config input and invokes it
    # — see my.config.dispatch to confirm.
    dispatch(TrainerConfig)