# mini-omni / litgpt/config.py
# File-viewer page residue (not part of the module): gpt-omni's picture | init | 5e4b316 verified | raw | history blame | 7.65 kB
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
from copy import deepcopy
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal, Optional, Type, Union
import torch
import yaml
from typing_extensions import Self
import litgpt.model
from litgpt.utils import find_multiple
@dataclass
class Config:
name: str = ""
hf_config: dict = field(default_factory=dict)
scale_embeddings: bool = False
block_size: int = 4096
vocab_size: int = 50254
padding_multiple: int = 512
padded_vocab_size: Optional[int] = None
n_layer: int = 16
n_head: int = 32
head_size: Optional[int] = None
n_embd: int = 4096
rotary_percentage: float = 0.25
parallel_residual: bool = True
bias: bool = True
lm_head_bias: bool = False
# to use multi-head attention (MHA), set this to `n_head` (default)
# to use multi-query attention (MQA), set this to 1
# to use grouped-query attention (GQA), set this to a value in between
# Example with `n_head=4`
# ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
# │ v ││ v ││ v ││ v │ │ v │ │ v │ │ v │
# └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
# │ │ │ │ │ │ │
# ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
# │ k ││ k ││ k ││ k │ │ k │ │ k │ │ k │
# └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
# │ │ │ │ ┌──┴──┐ ┌──┴──┐ ┌────┬──┴─┬────┐
# ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐
# │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │
# └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘
# ◀──────────────────▶ ◀──────────────────▶ ◀──────────────────▶
# MHA GQA MQA
# n_query_groups=4 n_query_groups=2 n_query_groups=1
#
# credit https://arxiv.org/pdf/2305.13245.pdf
n_query_groups: Optional[int] = None
shared_attention_norm: bool = False
norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
norm_eps: float = 1e-5
mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = (
"GptNeoxMLP"
)
gelu_approximate: str = "none"
intermediate_size: Optional[int] = None
rope_condense_ratio: int = 1
rope_base: int = 10000
n_expert: int = 0
n_expert_per_token: int = 0
add_qkv_bias: Optional[bool] = None
prompt_vocab_size: Optional[int] = None
attn_dropout: float = 0.0
pos_type: str = "rope"
force_align: bool = False
use_pretrain_phoneme_emb: bool = False
tie_word_embeddings: bool = False
# setting for mini-omni
text_vocab_size:int = 152000
cat_audio_vocab_size: int = 29120
audio_vocab_size: int = 4160
whisper_adapter_dim: int = 768
post_adapter: bool = False
post_adapter_layers: int = 6
asr_adapter: str = "llamamlp"
def __post_init__(self):
if not self.name:
self.name = self.hf_config.get("name", self.name)
if self.head_size is None:
assert self.n_embd % self.n_head == 0
self.head_size = self.n_embd // self.n_head
# vocab size should be a power of 2 to be optimal on hardware. compute the closest value
if self.padded_vocab_size is None:
self.padded_vocab_size = find_multiple(
self.vocab_size, self.padding_multiple
)
else:
# vocab size shouldn't be larger than padded vocab size
self.vocab_size = min(self.vocab_size, self.padded_vocab_size)
# compute the number of query groups
if self.n_query_groups is not None:
assert self.n_head % self.n_query_groups == 0
else:
self.n_query_groups = self.n_head
# compute the intermediate size for MLP if not set
if self.intermediate_size is None:
if self.mlp_class_name == "LLaMAMLP":
raise ValueError(
f"The config {self.name!r}, needs to set the `intermediate_size`"
)
self.intermediate_size = 4 * self.n_embd
self.rope_n_elem = int(self.rotary_percentage * self.head_size)
if self.add_qkv_bias is None:
self.add_qkv_bias = self.bias
@classmethod
def from_name(cls, name: str, **kwargs: Any) -> Optional[Self]:
if name not in name_to_config:
# search through all `config['hf_config']['name']`
try:
conf_dict = next(
config
for config in configs
if name == config["hf_config"]["name"]
or config["hf_config"]["org"] + "/" + config["hf_config"]["name"]
== name
)
except StopIteration:
raise ValueError(f"{name!r} is not a supported config name")
else:
conf_dict = name_to_config[name]
conf_dict = conf_dict.copy()
conf_dict.update(kwargs)
return cls(**conf_dict)
@classmethod
def from_file(cls, path: Union[str, Path], **kwargs: Any) -> Self:
with open(path, encoding="utf-8") as fp:
file_kwargs = yaml.safe_load(fp)
if file_kwargs is None:
raise ValueError(f"{path} is empty which is likely unexpected.")
file_kwargs.update(kwargs)
return cls(**file_kwargs)
@classmethod
def from_checkpoint(cls, path: Path, **kwargs: Any) -> Self:
"""Automatically load `model_config.yaml` and if it doesn't exist - a matching config from `litgpt/config.py`."""
if (config_path := path / "model_config.yaml").is_file():
return cls.from_file(config_path, **kwargs)
if (model_name := path.name) in name_to_config:
return cls.from_name(model_name, **kwargs)
raise FileNotFoundError(
f"For {str(path)!r} neither 'model_config.yaml' nor matching config exists."
)
@property
def mlp_class(self) -> Type:
# `self.mlp_class_name` cannot be the type to keep the config serializable
return getattr(litgpt.model, self.mlp_class_name)
@property
def norm_class(self) -> Type:
# `self.norm_class_name` cannot be the type to keep the config serializable
if self.norm_class_name == "RMSNorm":
from functools import partial
from litgpt.model import RMSNorm
return partial(RMSNorm, add_unit_offset="Gemma" in self.name)
return getattr(torch.nn, self.norm_class_name)
# Registry of known model configurations. Each entry is a dict that carries at
# least a "name" key; `Config.from_name` expands an entry into keyword
# arguments. The registry is empty in this file.
configs = []

# Index over the registry, keyed by each entry's "name".
name_to_config = {entry["name"]: entry for entry in configs}