Spaces:
Sleeping
Sleeping
from configs.structured import ProjectConfig | |
from .model import ConditionalPointCloudDiffusionModel | |
from .model_coloring import PointCloudColoringModel | |
from .model_utils import set_requires_grad | |
from .model_diff_data import ConditionalPCDiffusionSeparateSegm | |
from .model_hoattn import CrossAttenHODiffusionModel | |
def get_model(cfg: ProjectConfig):
    """Build the model selected by ``cfg.model.model_name``.

    Supported names and the class each one constructs:

    - ``'pc2-diff'``            -> ``ConditionalPointCloudDiffusionModel``
    - ``'pc2-diff-ho-sepsegm'`` -> ``ConditionalPCDiffusionSeparateSegm``
    - ``'diff-ho-attn'``        -> ``CrossAttenHODiffusionModel``

    Args:
        cfg: Project configuration. ``cfg.model`` is unpacked as keyword
            arguments into the selected model class. When
            ``cfg.run.freeze_feature_model`` is truthy, the model's
            ``feature_model`` is passed to ``set_requires_grad(..., False)``.

    Returns:
        The constructed model instance.

    Raises:
        NotImplementedError: If ``cfg.model.model_name`` is not one of the
            supported names above.
    """
    model_name = cfg.model.model_name
    if model_name == 'pc2-diff':
        model = ConditionalPointCloudDiffusionModel(**cfg.model)
    elif model_name == 'pc2-diff-ho-sepsegm':
        model = ConditionalPCDiffusionSeparateSegm(**cfg.model)
        print("Using a separate model to predict segmentation label")
    elif model_name == 'diff-ho-attn':
        model = CrossAttenHODiffusionModel(**cfg.model)
        print("Using separate model for human + object with cross attention.")
    else:
        # Report the offending value and the accepted names so a config
        # typo is diagnosable from the traceback alone.
        raise NotImplementedError(
            f"Unknown model_name {model_name!r}; expected one of "
            "'pc2-diff', 'pc2-diff-ho-sepsegm', 'diff-ho-attn'"
        )
    if cfg.run.freeze_feature_model:
        # Freeze the feature model via set_requires_grad(..., False).
        set_requires_grad(model.feature_model, False)
    return model
def get_coloring_model(cfg: ProjectConfig):
    """Construct a ``PointCloudColoringModel`` from the project config.

    Args:
        cfg: Project configuration. ``cfg.model`` is unpacked as keyword
            arguments into ``PointCloudColoringModel``. When
            ``cfg.run.freeze_feature_model`` is truthy, the model's
            ``feature_model`` is passed to ``set_requires_grad(..., False)``.

    Returns:
        The constructed coloring model.
    """
    coloring_model = PointCloudColoringModel(**cfg.model)

    should_freeze = cfg.run.freeze_feature_model
    if should_freeze:
        set_requires_grad(coloring_model.feature_model, False)

    return coloring_model