import dataclasses
import os
from typing import Optional


@dataclasses.dataclass
class BelleParam:
    """Configuration parameters, validated on construction.

    Field meanings noted below are inferred from their names only; confirm
    them against the code that consumes this object.

    Raises:
        ValueError: if ``shared_contexts_ratio`` is outside ``[0.0, 1.0]``.
            NaN is rejected too, because every comparison with NaN is False.
    """

    # Presumably model architecture sizes (attention heads, per-head width,
    # FFN inner size, layer count, vocabulary size) -- TODO confirm.
    num_heads: int = 32
    size_per_head: int = 128
    inter_size: int = 16384
    num_layers: int = 30
    vocab_size: int = 250880
    # Presumably start/end token ids; typed Optional, so None is permitted.
    start_id: Optional[int] = 1
    end_id: Optional[int] = 2
    # Presumably tensor / pipeline parallelism degrees -- TODO confirm.
    tensor_para_size: int = 1
    pipeline_para_size: int = 1
    remove_padding: bool = True
    # Must lie in [0.0, 1.0]; enforced in __post_init__.
    shared_contexts_ratio: float = 1.0
    weights_data_type: str = "fp16"

    def __post_init__(self):
        """Validate field values after the generated ``__init__`` runs."""
        # Written as `not (lo <= x <= hi)` rather than `x < lo or x > hi`
        # so that NaN is rejected as well.
        if not 0.0 <= self.shared_contexts_ratio <= 1.0:
            raise ValueError(
                'Got an invalid value of shared_contexts_ratio '
                f'{self.shared_contexts_ratio} - range: [0.0, 1.0]')

    def asdict(self):
        """Return the fields as a new dict (via ``dataclasses.asdict``)."""
        return dataclasses.asdict(self)


# Module-level instance holding the default configuration.
BELLE_PARAM = BelleParam()

# Absolute path of 'libth_transformer.so' located next to this module file.
current_dir = os.path.dirname(os.path.abspath(__file__))
LIB_SO_PATH = os.path.join(current_dir, 'libth_transformer.so')