from dataclasses import dataclass
import sys


@dataclass
class XftConfig:
    max_seq_len: int = 4096
    beam_width: int = 1
    eos_token_id: int = -1
    pad_token_id: int = -1
    num_return_sequences: int = 1
    is_encoder_decoder: bool = False
    padding: bool = True
    early_stopping: bool = False
    data_type: str = "bf16_fp16"


class XftModel:
    def __init__(self, xft_model, xft_config):
        self.model = xft_model
        self.config = xft_config


def load_xft_model(model_path, xft_config: XftConfig):
    try:
        import xfastertransformer
        from transformers import AutoTokenizer
    except ImportError as e:
        print(f"Error: Failed to load xFasterTransformer. {e}")
        sys.exit(-1)

    # Fall back to the default mixed-precision setting if no data type is configured.
    if xft_config.data_type is None or xft_config.data_type == "":
        data_type = "bf16_fp16"
    else:
        data_type = xft_config.data_type

    tokenizer = AutoTokenizer.from_pretrained(
        model_path, use_fast=False, padding_side="left", trust_remote_code=True
    )
    xft_model = xfastertransformer.AutoModel.from_pretrained(
        model_path, dtype=data_type
    )
    model = XftModel(xft_model=xft_model, xft_config=xft_config)

    # Under multi-rank (MPI) execution, only rank 0 drives generation; all other
    # ranks stay in a loop serving generate() calls dispatched from rank 0.
    if model.model.rank > 0:
        while True:
            model.model.generate()

    return model, tokenizer
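

# Minimal usage sketch (illustrative, not part of the module above): build an
# XftConfig, load a model with load_xft_model, and run a single generation.
# The model path below is a placeholder, and the generate(input_ids, max_length=...)
# call assumes the xFasterTransformer API shown in its upstream examples.
if __name__ == "__main__":
    config = XftConfig(max_seq_len=2048, data_type="bf16_fp16")
    model, tokenizer = load_xft_model("/path/to/xft-converted-model", config)

    # Tokenize a prompt and generate on rank 0 (other ranks never reach this point).
    input_ids = tokenizer("Hello, world!", return_tensors="pt").input_ids
    output_ids = model.model.generate(input_ids, max_length=config.max_seq_len)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))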