xiexh20's picture
add hdm demo v1
2fd6166
raw
history blame
1.19 kB
from configs.structured import ProjectConfig
from .model import ConditionalPointCloudDiffusionModel
from .model_coloring import PointCloudColoringModel
from .model_utils import set_requires_grad
from .model_diff_data import ConditionalPCDiffusionSeparateSegm
from .model_hoattn import CrossAttenHODiffusionModel
def get_model(cfg: ProjectConfig):
    """Instantiate the diffusion model selected by ``cfg.model.model_name``.

    Supported names and the classes they map to:
      * ``'pc2-diff'``            -> ConditionalPointCloudDiffusionModel
      * ``'pc2-diff-ho-sepsegm'`` -> ConditionalPCDiffusionSeparateSegm
      * ``'diff-ho-attn'``        -> CrossAttenHODiffusionModel

    Every entry of ``cfg.model`` (including ``model_name`` itself) is
    forwarded as a keyword argument to the chosen constructor. That
    presumably means ``cfg.model`` is a mapping that also supports
    attribute access. NOTE(review): confirm against ProjectConfig.

    Args:
        cfg: Project configuration. Reads ``cfg.model`` and
            ``cfg.run.freeze_feature_model``.

    Returns:
        The constructed model. If ``cfg.run.freeze_feature_model`` is
        truthy, ``set_requires_grad(model.feature_model, False)`` has been
        applied to it.

    Raises:
        NotImplementedError: If ``cfg.model.model_name`` is not one of the
            supported names.
    """
    model_name = cfg.model.model_name
    if model_name == 'pc2-diff':
        model = ConditionalPointCloudDiffusionModel(**cfg.model)
    elif model_name == 'pc2-diff-ho-sepsegm':
        model = ConditionalPCDiffusionSeparateSegm(**cfg.model)
        print("Using a separate model to predict segmentation label")
    elif model_name == 'diff-ho-attn':
        model = CrossAttenHODiffusionModel(**cfg.model)
        print("Using separate model for human + object with cross attention.")
    else:
        # Name the rejected value so config typos are easy to diagnose;
        # the exception type is unchanged for any existing callers.
        raise NotImplementedError(
            f"Unknown model_name {model_name!r}; expected one of "
            "'pc2-diff', 'pc2-diff-ho-sepsegm', 'diff-ho-attn'."
        )
    if cfg.run.freeze_feature_model:
        # Disable gradients for the feature-extraction sub-module only.
        set_requires_grad(model.feature_model, False)
    return model
def get_coloring_model(cfg: ProjectConfig):
    """Build the point-cloud coloring model from ``cfg``.

    All entries of ``cfg.model`` are passed as keyword arguments to
    ``PointCloudColoringModel``.

    Args:
        cfg: Project configuration. Reads ``cfg.model`` and
            ``cfg.run.freeze_feature_model``.

    Returns:
        The coloring model. Its ``feature_model`` has gradients disabled
        via ``set_requires_grad`` when ``cfg.run.freeze_feature_model``
        is truthy.
    """
    coloring_model = PointCloudColoringModel(**cfg.model)
    freeze_features = cfg.run.freeze_feature_model
    if freeze_features:
        set_requires_grad(coloring_model.feature_model, False)
    return coloring_model