Clemspace's picture
Initial model upload
cb9e677
from dataclasses import dataclass, field
from typing import Optional
from simple_parsing.helpers import Serializable
@dataclass
class LoraArgs(Serializable):
enable: bool = True
rank: int = 16
dropout: float = 0.0
scaling: float = 2.0
def __post_init__(self):
if self.enable:
assert self.rank > 0
assert self.scaling > 0.0
@dataclass
class MoeArgs(Serializable):
num_experts: int = 8
num_experts_per_tok: int = 2
@dataclass
class ModelArgs(Serializable):
dim: int
n_layers: int
head_dim: int
hidden_dim: int
n_heads: int
n_kv_heads: int
norm_eps: float
vocab_size: int
rope_theta: float = 10000.0
lora: LoraArgs = field(default_factory=LoraArgs)
moe: Optional[MoeArgs] = None