from typing import Any, Dict

import torch
import torch.nn as nn

from .sdf import CrossAttentionPointCloudSDFModel
from .transformer import (
    CLIPImageGridPointDiffusionTransformer,
    CLIPImageGridUpsamplePointDiffusionTransformer,
    CLIPImagePointDiffusionTransformer,
    PointDiffusionTransformer,
    UpsamplePointDiffusionTransformer,
)

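# Canonical hyperparameters for each named model. Notes on recurring fields
# (descriptive comments only; the values below are unchanged):
#   - input_channels=6: each point carries XYZ coordinates plus RGB color.
#   - output_channels=12: twice input_channels, consistent with the diffusion
#     model predicting both a noise value and a variance value per channel.
#   - n_ctx: the number of points the transformer processes per example.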
MODEL_CONFIGS = {
    "base40M-imagevec": {
        "cond_drop_prob": 0.1,
        "heads": 8,
        "init_scale": 0.25,
        "input_channels": 6,
        "layers": 12,
        "n_ctx": 1024,
        "name": "CLIPImagePointDiffusionTransformer",
        "output_channels": 12,
        "time_token_cond": True,
        "token_cond": True,
        "width": 512,
    },
    "base40M-textvec": {
        "cond_drop_prob": 0.1,
        "heads": 8,
        "init_scale": 0.25,
        "input_channels": 6,
        "layers": 12,
        "n_ctx": 1024,
        "name": "CLIPImagePointDiffusionTransformer",
        "output_channels": 12,
        "time_token_cond": True,
        "token_cond": True,
        "width": 512,
    },
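    # The two single-vector configs above intentionally share one class:
    # conditioning is a single CLIP embedding token (an image embedding for
    # -imagevec, a text embedding for -textvec, supplied at sampling time).
    # The "Grid" classes below instead attend over a grid of CLIP image
    # features rather than one pooled vector.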
    "base40M-uncond": {
        "heads": 8,
        "init_scale": 0.25,
        "input_channels": 6,
        "layers": 12,
        "n_ctx": 1024,
        "name": "PointDiffusionTransformer",
        "output_channels": 12,
        "time_token_cond": True,
        "width": 512,
    },
    "base40M": {
        "cond_drop_prob": 0.1,
        "heads": 8,
        "init_scale": 0.25,
        "input_channels": 6,
        "layers": 12,
        "n_ctx": 1024,
        "name": "CLIPImageGridPointDiffusionTransformer",
        "output_channels": 12,
        "time_token_cond": True,
        "width": 512,
    },
    "base300M": {
        "cond_drop_prob": 0.1,
        "heads": 16,
        "init_scale": 0.25,
        "input_channels": 6,
        "layers": 24,
        "n_ctx": 1024,
        "name": "CLIPImageGridPointDiffusionTransformer",
        "output_channels": 12,
        "time_token_cond": True,
        "width": 1024,
    },
    "base1B": {
        "cond_drop_prob": 0.1,
        "heads": 32,
        "init_scale": 0.25,
        "input_channels": 6,
        "layers": 24,
        "n_ctx": 1024,
        "name": "CLIPImageGridPointDiffusionTransformer",
        "output_channels": 12,
        "time_token_cond": True,
        "width": 2048,
    },
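    # The upsampler conditions on the 1024 points produced by a base model
    # (cond_ctx) and generates 3072 further points (n_ctx). channel_scales and
    # channel_biases normalize the raw input channels: XYZ is scaled by 2.0,
    # while RGB bytes are mapped from [0, 255] into roughly [-1, 1] via
    # x * (2 / 255) - 1 (note 0.007843137255 == 2 / 255).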
    "upsample": {
        "channel_biases": [0.0, 0.0, 0.0, -1.0, -1.0, -1.0],
        "channel_scales": [2.0, 2.0, 2.0, 0.007843137255, 0.007843137255, 0.007843137255],
        "cond_ctx": 1024,
        "cond_drop_prob": 0.1,
        "heads": 8,
        "init_scale": 0.25,
        "input_channels": 6,
        "layers": 12,
        "n_ctx": 3072,
        "name": "CLIPImageGridUpsamplePointDiffusionTransformer",
        "output_channels": 12,
        "time_token_cond": True,
        "width": 512,
    },
    "sdf": {
        "decoder_heads": 4,
        "decoder_layers": 4,
        "encoder_heads": 4,
        "encoder_layers": 8,
        "init_scale": 0.25,
        "n_ctx": 4096,
        "name": "CrossAttentionPointCloudSDFModel",
        "width": 256,
    },
}


def model_from_config(config: Dict[str, Any], device: torch.device) -> nn.Module:
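    """
    Instantiate the model class named by config["name"], passing the remaining
    config entries through as keyword arguments.

    :param config: one of the dictionaries in MODEL_CONFIGS, or a copy with
        overrides. The dictionary is not mutated by this call.
    :param device: the torch device to construct the model on.
    :raises ValueError: if config["name"] does not match a known model class.
    """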
    config = config.copy()
    name = config.pop("name")
    if name == "PointDiffusionTransformer":
        return PointDiffusionTransformer(device=device, dtype=torch.float32, **config)
    elif name == "CLIPImagePointDiffusionTransformer":
        return CLIPImagePointDiffusionTransformer(device=device, dtype=torch.float32, **config)
    elif name == "CLIPImageGridPointDiffusionTransformer":
        return CLIPImageGridPointDiffusionTransformer(device=device, dtype=torch.float32, **config)
    elif name == "UpsamplePointDiffusionTransformer":
        return UpsamplePointDiffusionTransformer(device=device, dtype=torch.float32, **config)
    elif name == "CLIPImageGridUpsamplePointDiffusionTransformer":
        return CLIPImageGridUpsamplePointDiffusionTransformer(
            device=device, dtype=torch.float32, **config
        )
    elif name == "CrossAttentionPointCloudSDFModel":
        return CrossAttentionPointCloudSDFModel(device=device, dtype=torch.float32, **config)
    raise ValueError(f"unknown model name: {name}")
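
# Example usage (a minimal sketch; loading pretrained weights into the model
# via load_state_dict is a separate step not shown here):
#
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     model = model_from_config(MODEL_CONFIGS["base40M"], device)
#     model.eval()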