|
from dataclasses import dataclass, field |
|
from typing import Literal |
|
|
|
|
|
@dataclass
class BackboneConfig:
    """Settings for the backbone portion of the model.

    Every field has a default, so ``BackboneConfig()`` is a valid
    configuration. The three container fields use ``default_factory`` so that
    each instance receives its own fresh ``dict``/``list`` instead of sharing
    one mutable default object across instances.

    Field order is part of the interface: instances may be built positionally.
    """

    d_model: int = 1024
    d_intermediate: int = 0
    attn_mlp_d_intermediate: int = 0
    n_layer: int = 16
    # Free-form mapping; its schema is not defined in this file.
    # NOTE(review): presumably forwarded to the SSM layers -- confirm at use site.
    ssm_cfg: dict = field(default_factory=dict)
    # NOTE(review): presumably indices of layers that use attention -- confirm.
    attn_layer_idx: list = field(default_factory=list)
    # Free-form mapping; its schema is not defined in this file.
    attn_cfg: dict = field(default_factory=dict)
    rms_norm: bool = False
    residual_in_fp32: bool = False
    norm_epsilon: float = 1e-5
|
|
|
|
|
@dataclass
class PrefixConditionerConfig:
    """Settings for the prefix conditioner.

    Both fields are required (neither has a default).

    Attributes are documented inline. Note that the ``Literal`` annotation on
    ``projection`` is a static-typing hint only; the dataclass performs no
    runtime validation of the value.
    """

    # One dict per conditioner; the dict schema is not defined in this file.
    conditioners: list[dict]
    # Which projection to use: one of "none", "linear" or "mlp".
    projection: Literal["none", "linear", "mlp"]
|
|
|
|
|
@dataclass
class ZonosConfig:
    """Top-level configuration bundling the backbone and prefix-conditioner
    settings together with two special token ids.

    ``backbone`` and ``prefix_conditioner`` are required; the token ids have
    defaults.
    """

    backbone: BackboneConfig
    prefix_conditioner: PrefixConditionerConfig
    eos_token_id: int = 1024
    masked_token_id: int = 1025

    @classmethod
    def from_dict(cls, d: dict) -> "ZonosConfig":
        """Build a ``ZonosConfig`` from a plain nested dictionary.

        Args:
            d: Mapping with a ``"backbone"`` entry and a
                ``"prefix_conditioner"`` entry, each itself a mapping of
                keyword arguments for ``BackboneConfig`` and
                ``PrefixConditionerConfig`` respectively. Any remaining
                top-level keys are passed to ``ZonosConfig`` as keyword
                arguments.

        Returns:
            The assembled ``ZonosConfig``.

        Raises:
            KeyError: If ``"backbone"`` or ``"prefix_conditioner"`` is missing.
            TypeError: If any mapping contains a key that is not a field of
                the corresponding dataclass.
        """
        # Pop from a shallow copy so the caller's dictionary keeps its keys.
        # Nested values are not copied and are shared with the input.
        remaining = d.copy()
        backbone_config = BackboneConfig(**remaining.pop("backbone"))
        prefix_conditioner_config = PrefixConditionerConfig(
            **remaining.pop("prefix_conditioner")
        )
        return cls(
            backbone=backbone_config,
            prefix_conditioner=prefix_conditioner_config,
            **remaining,
        )
|
|