from dataclasses import dataclass, field
import sys
from typing import Optional


@dataclass
class ExllamaConfig:
    """User-facing options consumed by :func:`load_exllama_model`."""

    # Sequence length copied onto the ExLlamaV2 config after ``prepare()``.
    max_seq_len: int
    # Comma-separated numbers (e.g. "20,24"); each entry is parsed with
    # ``float`` and the resulting list is handed to the model's ``load()``.
    # ``None`` or an empty string means ``load(None)`` is called instead.
    gpu_split: Optional[str] = None
    # Selects ``ExLlamaV2Cache_8bit`` over ``ExLlamaV2Cache`` when true.
    cache_8bit: bool = False


class ExllamaModel:
    """Bundle a loaded ExLlamaV2 model with the cache built for it."""

    def __init__(self, exllama_model, exllama_cache):
        """Store the model and cache, and expose the model's ``config``.

        Args:
            exllama_model: Loaded model object; must have a ``config``
                attribute.
            exllama_cache: Cache object created for ``exllama_model``.
        """
        self.model = exllama_model
        self.cache = exllama_cache
        # Re-exported so callers can read the config without going through
        # ``self.model``.
        self.config = self.model.config


def load_exllama_model(model_path, exllama_config: ExllamaConfig):
    """Load an ExLlamaV2 model, its tokenizer and a matching cache.

    Args:
        model_path: Model directory, assigned to ``ExLlamaV2Config.model_dir``.
        exllama_config: Loader options (sequence length, GPU split, cache
            type).

    Returns:
        A ``(model, tokenizer)`` tuple where ``model`` is an
        :class:`ExllamaModel` and ``tokenizer`` is an ``ExLlamaV2Tokenizer``.

    Raises:
        SystemExit: With code ``-1`` if ``exllamav2`` cannot be imported.
        ValueError: If an entry of ``exllama_config.gpu_split`` is not a
            valid float literal.
    """
    # Imported lazily so this module can be imported without exllamav2.
    try:
        from exllamav2 import (
            ExLlamaV2Config,
            ExLlamaV2Tokenizer,
            ExLlamaV2,
            ExLlamaV2Cache,
            ExLlamaV2Cache_8bit,
        )
    except ImportError as e:
        # Diagnostics belong on stderr so they do not mix with program output.
        print(f"Error: Failed to load Exllamav2. {e}", file=sys.stderr)
        sys.exit(-1)

    exllamav2_config = ExLlamaV2Config()
    exllamav2_config.model_dir = model_path
    exllamav2_config.prepare()
    # Applied after prepare() so the caller's values are the ones in effect.
    exllamav2_config.max_seq_len = exllama_config.max_seq_len
    exllamav2_config.cache_8bit = exllama_config.cache_8bit

    exllama_model = ExLlamaV2(exllamav2_config)
    tokenizer = ExLlamaV2Tokenizer(exllamav2_config)

    split = None
    if exllama_config.gpu_split:
        split = [float(alloc) for alloc in exllama_config.gpu_split.split(",")]
    exllama_model.load(split)

    cache_class = (
        ExLlamaV2Cache_8bit if exllamav2_config.cache_8bit else ExLlamaV2Cache
    )
    exllama_cache = cache_class(exllama_model)

    model = ExllamaModel(exllama_model=exllama_model, exllama_cache=exllama_cache)
    return model, tokenizer