# NOTE(review): the four lines below are web-scrape artifacts (repository page
# chrome: language selector, author, commit message, commit hash), not Python.
# They are commented out so the module remains importable; they carry no code.
# English
# Thomas Male
# Upload 98 files
# a5407e7
from typing import Any, Dict
import torch
import torch.nn as nn
from .sdf import CrossAttentionPointCloudSDFModel
from .transformer import (
CLIPImageGridPointDiffusionTransformer,
CLIPImageGridUpsamplePointDiffusionTransformer,
CLIPImagePointDiffusionTransformer,
PointDiffusionTransformer,
UpsamplePointDiffusionTransformer,
)
# Hyperparameters shared by every point-diffusion transformer config below;
# merged into each entry via ** so only the per-model differences are spelled out.
_DIFFUSION_COMMON = {
    "init_scale": 0.25,
    "input_channels": 6,
    "output_channels": 12,
    "time_token_cond": True,
}

# Pre-defined model configurations, keyed by checkpoint name. Each value is a
# kwargs dict for `model_from_config`; the "name" entry selects the class and
# the remaining keys are passed to its constructor.
MODEL_CONFIGS = {
    "base40M-imagevec": {
        **_DIFFUSION_COMMON,
        "cond_drop_prob": 0.1,
        "heads": 8,
        "layers": 12,
        "n_ctx": 1024,
        "name": "CLIPImagePointDiffusionTransformer",
        "token_cond": True,
        "width": 512,
    },
    "base40M-textvec": {
        **_DIFFUSION_COMMON,
        "cond_drop_prob": 0.1,
        "heads": 8,
        "layers": 12,
        "n_ctx": 1024,
        "name": "CLIPImagePointDiffusionTransformer",
        "token_cond": True,
        "width": 512,
    },
    # Unconditional variant: no CLIP conditioning, hence no cond_drop_prob.
    "base40M-uncond": {
        **_DIFFUSION_COMMON,
        "heads": 8,
        "layers": 12,
        "n_ctx": 1024,
        "name": "PointDiffusionTransformer",
        "width": 512,
    },
    "base40M": {
        **_DIFFUSION_COMMON,
        "cond_drop_prob": 0.1,
        "heads": 8,
        "layers": 12,
        "n_ctx": 1024,
        "name": "CLIPImageGridPointDiffusionTransformer",
        "width": 512,
    },
    "base300M": {
        **_DIFFUSION_COMMON,
        "cond_drop_prob": 0.1,
        "heads": 16,
        "layers": 24,
        "n_ctx": 1024,
        "name": "CLIPImageGridPointDiffusionTransformer",
        "width": 1024,
    },
    "base1B": {
        **_DIFFUSION_COMMON,
        "cond_drop_prob": 0.1,
        "heads": 32,
        "layers": 24,
        "n_ctx": 1024,
        "name": "CLIPImageGridPointDiffusionTransformer",
        "width": 2048,
    },
    # Upsampler: larger context (3072 points) conditioned on a 1024-point cloud.
    "upsample": {
        **_DIFFUSION_COMMON,
        "channel_biases": [0.0, 0.0, 0.0, -1.0, -1.0, -1.0],
        "channel_scales": [2.0, 2.0, 2.0, 0.007843137255, 0.007843137255, 0.007843137255],
        "cond_ctx": 1024,
        "cond_drop_prob": 0.1,
        "heads": 8,
        "layers": 12,
        "n_ctx": 3072,
        "name": "CLIPImageGridUpsamplePointDiffusionTransformer",
        "width": 512,
    },
    # SDF regression model; shares no fields with the diffusion transformers.
    "sdf": {
        "decoder_heads": 4,
        "decoder_layers": 4,
        "encoder_heads": 4,
        "encoder_layers": 8,
        "init_scale": 0.25,
        "n_ctx": 4096,
        "name": "CrossAttentionPointCloudSDFModel",
        "width": 256,
    },
}
def model_from_config(config: Dict[str, Any], device: torch.device) -> nn.Module:
    """Instantiate the model class selected by ``config["name"]``.

    Every entry of *config* other than ``"name"`` is forwarded to the class
    constructor, along with *device* and a fixed ``torch.float32`` dtype.
    The caller's dict is not mutated.

    :param config: model hyperparameters, including a ``"name"`` key that
        selects one of the known model classes (see ``MODEL_CONFIGS``).
    :param device: device the model is constructed on.
    :raises ValueError: if ``config["name"]`` is not a known model name.
    """
    name = config["name"]
    # Build the constructor kwargs once: everything except the selector key,
    # plus the device/dtype shared by every model class.
    kwargs = {key: value for key, value in config.items() if key != "name"}
    kwargs["device"] = device
    kwargs["dtype"] = torch.float32
    # Lazy if/elif dispatch so only the selected class name is ever evaluated.
    if name == "PointDiffusionTransformer":
        return PointDiffusionTransformer(**kwargs)
    elif name == "CLIPImagePointDiffusionTransformer":
        return CLIPImagePointDiffusionTransformer(**kwargs)
    elif name == "CLIPImageGridPointDiffusionTransformer":
        return CLIPImageGridPointDiffusionTransformer(**kwargs)
    elif name == "UpsamplePointDiffusionTransformer":
        return UpsamplePointDiffusionTransformer(**kwargs)
    elif name == "CLIPImageGridUpsamplePointDiffusionTransformer":
        return CLIPImageGridUpsamplePointDiffusionTransformer(**kwargs)
    elif name == "CrossAttentionPointCloudSDFModel":
        return CrossAttentionPointCloudSDFModel(**kwargs)
    raise ValueError(f"unknown model name: {name}")