|
import dataclasses |
|
from typing import Optional |
|
|
|
|
|
@dataclasses.dataclass
class LyraBaichuanParam:
    """Model/runtime hyperparameters for a Baichuan-style transformer.

    Defaults correspond to a 13B-class configuration (40 layers, 40 heads,
    head dim 128, hidden 5120). Field values are validated in
    ``__post_init__``; ``asdict`` exports the config as a plain dict.

    Raises:
        ValueError: if ``shared_contexts_ratio`` falls outside [0.0, 1.0].
    """

    num_heads: int = 40            # attention heads per layer
    size_per_head: int = 128       # dimension of each attention head
    inter_size: int = 13824        # FFN intermediate (hidden) size
    num_layers: int = 40           # number of transformer layers
    vocab_size: int = 39424        # tokenizer vocabulary size
    start_id: Optional[int] = 1    # BOS token id
    end_id: Optional[int] = 2      # EOS token id
    tensor_para_size: int = 1      # tensor-parallel degree
    pipeline_para_size: int = 1    # pipeline-parallel degree
    remove_padding: bool = True    # drop padding tokens during compute
    shared_contexts_ratio: float = 1.0  # must lie in [0.0, 1.0]; validated below
    layernorm_eps: float = 1e-6    # epsilon for layer normalization
    weights_data_type: str = "fp16"     # on-disk weight dtype tag
    rotary_embedding: int = 128    # rotary position-embedding dimension
    use_gptj_residual: bool = False     # GPT-J style parallel residual if True

    def __post_init__(self):
        """Validate field values after dataclass initialization.

        Raises:
            ValueError: if ``shared_contexts_ratio`` is outside [0.0, 1.0].
        """
        # Fix: the error message previously referred to "shared_context_ratio"
        # (singular), which does not match the actual field name.
        if not 0.0 <= self.shared_contexts_ratio <= 1.0:
            raise ValueError(
                f'Got an invalid value of shared_contexts_ratio '
                f'{self.shared_contexts_ratio} - range: [0.0, 1.0]')

    def asdict(self):
        """Return all parameters as a plain ``dict`` (recursively converted)."""
        return dataclasses.asdict(self)
|
|
|
|
|
# Module-level default configuration instance (all dataclass defaults).
LYRA_BAICHUAN_PARAM = LyraBaichuanParam()

# Absolute path to the compiled transformer shared library loaded at runtime.
# NOTE(review): hard-coded path — presumably baked into the deployment image;
# confirm it exists in the target environment.
LIB_SO_PATH = '/usr/lib/ftlib/libth_transformer.so'