Spaces:
Sleeping
Sleeping
File size: 1,185 Bytes
2fd6166 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
from configs.structured import ProjectConfig
from .model import ConditionalPointCloudDiffusionModel
from .model_coloring import PointCloudColoringModel
from .model_utils import set_requires_grad
from .model_diff_data import ConditionalPCDiffusionSeparateSegm
from .model_hoattn import CrossAttenHODiffusionModel
def get_model(cfg: ProjectConfig):
    """Instantiate the model class selected by ``cfg.model.model_name``.

    The whole ``cfg.model`` mapping is unpacked as keyword arguments into the
    chosen constructor.

    Recognised names:
        * ``'pc2-diff'``            -> ``ConditionalPointCloudDiffusionModel``
        * ``'pc2-diff-ho-sepsegm'`` -> ``ConditionalPCDiffusionSeparateSegm``
        * ``'diff-ho-attn'``        -> ``CrossAttenHODiffusionModel``

    Args:
        cfg: Project configuration; ``cfg.model`` and
            ``cfg.run.freeze_feature_model`` are read.

    Returns:
        The constructed model instance.

    Raises:
        NotImplementedError: If ``cfg.model.model_name`` is not one of the
            names listed above.
    """
    model_name = cfg.model.model_name

    if model_name == 'pc2-diff':
        model = ConditionalPointCloudDiffusionModel(**cfg.model)
    elif model_name == 'pc2-diff-ho-sepsegm':
        model = ConditionalPCDiffusionSeparateSegm(**cfg.model)
        print("Using a separate model to predict segmentation label")
    elif model_name == 'diff-ho-attn':
        model = CrossAttenHODiffusionModel(**cfg.model)
        print("Using separate model for human + object with cross attention.")
    else:
        # Name the offending value so a typo in the config is easy to spot.
        raise NotImplementedError(f"Unknown model name: {model_name!r}")

    if cfg.run.freeze_feature_model:
        # NOTE(review): presumably turns off gradient tracking for the feature
        # sub-model's parameters -- confirm against model_utils.set_requires_grad.
        set_requires_grad(model.feature_model, False)
    return model
def get_coloring_model(cfg: ProjectConfig):
    """Construct a ``PointCloudColoringModel`` from ``cfg.model``.

    ``cfg.model`` is unpacked as keyword arguments into the constructor. When
    ``cfg.run.freeze_feature_model`` is truthy, ``set_requires_grad`` is called
    with the new model's ``feature_model`` and ``False``.

    Args:
        cfg: Project configuration; ``cfg.model`` and
            ``cfg.run.freeze_feature_model`` are read.

    Returns:
        The constructed coloring model instance.
    """
    coloring_net = PointCloudColoringModel(**cfg.model)
    should_freeze = cfg.run.freeze_feature_model
    if should_freeze:
        # NOTE(review): presumably disables gradients on the feature
        # sub-model -- confirm against model_utils.set_requires_grad.
        set_requires_grad(coloring_net.feature_model, False)
    return coloring_net
|