import dataclasses
from typing import Optional


@dataclasses.dataclass
class LyraBaichuanParam:
    """Bundle of model/runtime parameters with built-in defaults.

    Every field has a default, so ``LyraBaichuanParam()`` is a valid,
    fully-populated configuration. The only validation performed is the
    range check on ``shared_contexts_ratio`` in ``__post_init__``.

    NOTE(review): the field names look like they mirror the options of the
    transformer library referenced by ``LIB_SO_PATH`` -- confirm against the
    code that consumes this object before renaming or re-typing anything.

    Raises:
        ValueError: if ``shared_contexts_ratio`` is outside ``[0.0, 1.0]``.
    """

    # Model shape.
    num_heads: int = 40
    size_per_head: int = 128
    inter_size: int = 13824
    num_layers: int = 40
    vocab_size: int = 39424

    # Special token ids; typed Optional, so None is an accepted value.
    start_id: Optional[int] = 1
    end_id: Optional[int] = 2

    # Parallelism degrees.
    tensor_para_size: int = 1
    pipeline_para_size: int = 1

    # Runtime options.
    remove_padding: bool = True
    shared_contexts_ratio: float = 1.0  # validated to lie in [0.0, 1.0]
    layernorm_eps: float = 1e-6
    weights_data_type: str = "fp16"
    rotary_embedding: int = 128
    use_gptj_residual: bool = False

    def __post_init__(self):
        """Reject a ``shared_contexts_ratio`` outside the closed unit interval."""
        # Written as "not (lo <= x <= hi)" so that NaN is rejected as well.
        if not 0.0 <= self.shared_contexts_ratio <= 1.0:
            # The message names the actual field (shared_contexts_ratio) so
            # the user can find the offending keyword argument.
            raise ValueError(
                'Got an invalid value of shared_contexts_ratio '
                f'{self.shared_contexts_ratio} - range: [0.0, 1.0]')

    def asdict(self):
        """Return the fields as a plain ``dict`` keyed by field name."""
        return dataclasses.asdict(self)


# Module-level instance built entirely from the defaults above.
LYRA_BAICHUAN_PARAM = LyraBaichuanParam()

# Absolute path of the shared library used with these parameters.
LIB_SO_PATH = '/usr/lib/ftlib/libth_transformer.so'