import os
import sys
from typing import List, Optional, Any, Dict, Tuple

import torch
from loguru import logger
from peft import PeftModel
from tqdm import tqdm
from transformers import (
    AutoModel,
    AutoConfig,
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    PreTrainedTokenizer,
    PreTrainedModel,
)
from transformers.utils.versions import require_version

if sys.version_info >= (3, 9):
    from functools import cache
else:
    from functools import lru_cache as cache


class BaseModelAdapter:
    """The base and default model adapter."""

    model_names = []

    def match(self, model_name) -> bool:
        """
        Check if the given model name matches any of the predefined model names.

        Args:
            model_name (str): The model name to check.

        Returns:
            bool: True if the model name matches any of the predefined model names, False otherwise.
        """
        return any(m in model_name for m in self.model_names) if self.model_names else True

    def load_model(
        self,
        model_name_or_path: Optional[str] = None,
        adapter_model: Optional[str] = None,
        **kwargs: Any,
    ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
        """
        Load a model and tokenizer based on the provided model name or path.

        Args:
            model_name_or_path (str, optional): The name or path of the model. Defaults to None.
            adapter_model (str, optional): The adapter model (e.g. LoRA or P-Tuning v2 weights) to load;
                also used to load the tokenizer if it provides one. Defaults to None.
            **kwargs: Additional keyword arguments.

        Returns:
            Tuple[PreTrainedModel, PreTrainedTokenizer]: A tuple containing the loaded model and tokenizer.
        """
        model_name_or_path = model_name_or_path or self.default_model_name_or_path

        tokenizer_kwargs = {"trust_remote_code": True, "use_fast": False}
        tokenizer_kwargs.update(self.tokenizer_kwargs)

        # Load the tokenizer from the adapter model if it ships one; otherwise fall back to the base model.
        if adapter_model is not None:
            try:
                tokenizer = self.tokenizer_class.from_pretrained(adapter_model, **tokenizer_kwargs)
            except OSError:
                tokenizer = self.tokenizer_class.from_pretrained(model_name_or_path, **tokenizer_kwargs)
        else:
            tokenizer = self.tokenizer_class.from_pretrained(model_name_or_path, **tokenizer_kwargs)

        config_kwargs = self.model_kwargs
        device = kwargs.get("device", "cuda")
        num_gpus = kwargs.get("num_gpus", 1)
        dtype = kwargs.get("dtype", "half")

        if device == "cuda":
            if "torch_dtype" not in config_kwargs:
                if dtype == "half":
                    config_kwargs["torch_dtype"] = torch.float16
                elif dtype == "bfloat16":
                    config_kwargs["torch_dtype"] = torch.bfloat16
                elif dtype == "float32":
                    config_kwargs["torch_dtype"] = torch.float32

            if num_gpus != 1:
                config_kwargs["device_map"] = "auto"
                # config_kwargs["device_map"] = "sequential"  # useful when the GPUs do not have the same VRAM size

        # Quantization configurations (using the bitsandbytes library).
if kwargs.get("load_in_8bit", False): require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0") config_kwargs["load_in_8bit"] = True config_kwargs["quantization_config"] = BitsAndBytesConfig( load_in_8bit=True, llm_int8_threshold=6.0, ) config_kwargs["device_map"] = "auto" if device == "cuda" else None logger.info("Quantizing model to 8 bit.") elif kwargs.get("load_in_4bit", False): require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0") require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git") config_kwargs["load_in_4bit"] = True config_kwargs["quantization_config"] = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", ) config_kwargs["device_map"] = "auto" if device == "cuda" else None logger.info("Quantizing model to 4 bit.") if kwargs.get("device_map", None) == "auto": config_kwargs["device_map"] = "auto" config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True) # Fix config (for Qwen) if hasattr(config, "fp16") and hasattr(config, "bf16"): setattr(config, "fp16", dtype == "half") setattr(config, "bf16", dtype == "bfloat16") config_kwargs.pop("torch_dtype", None) if kwargs.get("using_ptuning_v2", False) and adapter_model: config.pre_seq_len = kwargs.get("pre_seq_len", 128) # Load and prepare pretrained models (without valuehead). model = self.model_class.from_pretrained( model_name_or_path, config=config, trust_remote_code=True, **config_kwargs ) if device == "cpu": model = model.float() # post process for special tokens tokenizer = self.post_tokenizer(tokenizer) is_chatglm = "chatglm" in str(type(model)) if adapter_model is not None: model = self.load_adapter_model(model, tokenizer, adapter_model, is_chatglm, config_kwargs, **kwargs) if is_chatglm or "baichuan" in str(type(model)) or "xverse" in str(type(model)): quantize = kwargs.get("quantize", None) if quantize and quantize != 16: logger.info(f"Quantizing model to {quantize} bit.") model = model.quantize(quantize) if device == "cuda" and num_gpus == 1 and "device_map" not in config_kwargs: model.to(device) # inference mode model.eval() return model, tokenizer def load_lora_model( self, model: PreTrainedModel, adapter_model: str, model_kwargs: Dict, ) -> PeftModel: """ Load a LoRA model. This function loads a LoRA model using the specified pretrained model and adapter model. Args: model (PreTrainedModel): The base pretrained model. adapter_model (str): The name or path of the adapter model. model_kwargs (dict): Additional keyword arguments for the model. Returns: PeftModel: The loaded LoRA model. 
""" return PeftModel.from_pretrained( model, adapter_model, torch_dtype=model_kwargs.get("torch_dtype", torch.float16), ) def load_adapter_model( self, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, adapter_model: str, is_chatglm: bool, model_kwargs: Dict, **kwargs: Any, ) -> PreTrainedModel: using_ptuning_v2 = kwargs.get("using_ptuning_v2", False) resize_embeddings = kwargs.get("resize_embeddings", False) if adapter_model and resize_embeddings and not is_chatglm: model_vocab_size = model.get_input_embeddings().weight.size(0) tokenzier_vocab_size = len(tokenizer) logger.info(f"Vocab of the base model: {model_vocab_size}") logger.info(f"Vocab of the tokenizer: {tokenzier_vocab_size}") if model_vocab_size != tokenzier_vocab_size: assert tokenzier_vocab_size > model_vocab_size logger.info("Resize model embeddings to fit tokenizer") model.resize_token_embeddings(tokenzier_vocab_size) if using_ptuning_v2: prefix_state_dict = torch.load(os.path.join(adapter_model, "pytorch_model.bin")) new_prefix_state_dict = { k[len("transformer.prefix_encoder."):]: v for k, v in prefix_state_dict.items() if k.startswith("transformer.prefix_encoder.") } model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict) model.transformer.prefix_encoder.float() else: model = self.load_lora_model(model, adapter_model, model_kwargs) return model def post_tokenizer(self, tokenizer) -> PreTrainedTokenizer: return tokenizer @property def model_class(self): return AutoModelForCausalLM @property def model_kwargs(self): return {} @property def tokenizer_class(self): return AutoTokenizer @property def tokenizer_kwargs(self): return {} @property def default_model_name_or_path(self): return "zpn/llama-7b" # A global registry for all model adapters model_adapters: List[BaseModelAdapter] = [] def register_model_adapter(cls): """ Register a model adapter. """ model_adapters.append(cls()) @cache def get_model_adapter(model_name: str) -> BaseModelAdapter: """ Get a model adapter for a given model name. Args: model_name (str): The name of the model. Returns: ModelAdapter: The model adapter that matches the given model name. """ for adapter in model_adapters: if adapter.match(model_name): return adapter raise ValueError(f"No valid model adapter for {model_name}") def load_model( model_name: str, model_name_or_path: Optional[str] = None, adapter_model: Optional[str] = None, quantize: Optional[int] = 16, device: Optional[str] = "cuda", load_in_8bit: Optional[bool] = False, **kwargs: Any, ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]: """ Load a pre-trained model and tokenizer. Args: model_name (str): The name of the model. model_name_or_path (Optional[str], optional): The path or name of the pre-trained model. Defaults to None. adapter_model (Optional[str], optional): The name of the adapter model. Defaults to None. quantize (Optional[int], optional): The quantization level. Defaults to 16. device (Optional[str], optional): The device to load the model on. Defaults to "cuda". load_in_8bit (Optional[bool], optional): Whether to load the model in 8-bit mode. Defaults to False. **kwargs (Any): Additional keyword arguments. Returns: Tuple[PreTrainedModel, PreTrainedTokenizer]: A tuple containing the loaded model and tokenizer. 
""" model_name = model_name.lower() if "tiger" in model_name: def skip(*args, **kwargs): pass torch.nn.init.kaiming_uniform_ = skip torch.nn.init.uniform_ = skip torch.nn.init.normal_ = skip # get model adapter adapter = get_model_adapter(model_name) model, tokenizer = adapter.load_model( model_name_or_path, adapter_model, device=device, quantize=quantize, load_in_8bit=load_in_8bit, **kwargs ) return model, tokenizer class ChatglmModelAdapter(BaseModelAdapter): """ https://github.com/THUDM/ChatGLM-6B """ model_names = ["chatglm"] @property def model_class(self): return AutoModel @property def default_model_name_or_path(self): return "THUDM/chatglm2-6b" class Chatglm3ModelAdapter(ChatglmModelAdapter): """ https://github.com/THUDM/ChatGLM-6B """ model_names = ["chatglm3"] @property def tokenizer_kwargs(self): return {"encode_special_tokens": True} @property def default_model_name_or_path(self): return "THUDM/chatglm3-6b" class LlamaModelAdapter(BaseModelAdapter): """ https://github.com/project-baize/baize-chatbot """ model_names = ["alpaca", "baize", "openbuddy-llama", "ziya-llama", "guanaco", "llama2"] def post_tokenizer(self, tokenizer): tokenizer.bos_token = "" tokenizer.eos_token = "" tokenizer.unk_token = "" return tokenizer @property def model_kwargs(self): return {"low_cpu_mem_usage": True} class MossModelAdapter(BaseModelAdapter): """ https://github.com/OpenLMLab/MOSS """ model_names = ["moss"] @property def default_model_name_or_path(self): return "fnlp/moss-moon-003-sft-int4" class PhoenixModelAdapter(BaseModelAdapter): """ https://github.com/FreedomIntelligence/LLMZoo """ model_names = ["phoenix"] @property def model_kwargs(self): return {"low_cpu_mem_usage": True} @property def tokenizer_kwargs(self): return {"use_fast": True} @property def default_model_name_or_path(self): return "FreedomIntelligence/phoenix-inst-chat-7b" class FireflyModelAdapter(BaseModelAdapter): """ https://github.com/yangjianxin1/Firefly """ model_names = ["firefly"] @property def model_kwargs(self): return {"torch_dtype": torch.float32} @property def tokenizer_kwargs(self): return {"use_fast": True} @property def default_model_name_or_path(self): return "YeungNLP/firefly-2b6" class YuLanChatModelAdapter(BaseModelAdapter): """ https://github.com/RUC-GSAI/YuLan-Chat """ model_names = ["yulan"] def post_tokenizer(self, tokenizer): tokenizer.bos_token = "" tokenizer.eos_token = "" tokenizer.unk_token = "" return tokenizer @property def model_kwargs(self): return {"low_cpu_mem_usage": True} def load_adapter_model(self, model, tokenizer, adapter_model, is_chatglm, model_kwargs, **kwargs): adapter_model = AutoModelForCausalLM.from_pretrained( adapter_model, torch_dtype=torch.float16, low_cpu_mem_usage=True ) if model.model.embed_tokens.weight.size(0) + 1 == adapter_model.model.embed_tokens.weight.size(0): model.resize_token_embeddings(len(tokenizer)) model.model.embed_tokens.weight.data[-1, :] = 0 logger.info("Applying the delta") for name, param in tqdm(model.state_dict().items(), desc="Applying delta"): assert name in model.state_dict() param.data += model.state_dict()[name] return model class TigerBotModelAdapter(BaseModelAdapter): """ https://github.com/TigerResearch/TigerBot """ model_names = ["tiger"] @property def tokenizer_kwargs(self): return {"use_fast": True} @property def default_model_name_or_path(self): return "TigerResearch/tigerbot-7b-sft" class OpenBuddyFalconModelAdapter(BaseModelAdapter): """ https://github.com/OpenBuddy/OpenBuddy """ model_names = ["openbuddy-falcon"] @property def 
default_model_name_or_path(self): return "OpenBuddy/openbuddy-falcon-7b-v5-fp16" class AnimaModelAdapter(LlamaModelAdapter): model_names = ["anima"] def load_lora_model(self, model, adapter_model, model_kwargs): return PeftModel.from_pretrained(model, adapter_model) class BaiChuanModelAdapter(BaseModelAdapter): """ https://github.com/baichuan-inc/Baichuan-13B """ model_names = ["baichuan"] def load_lora_model(self, model, adapter_model, model_kwargs): return PeftModel.from_pretrained(model, adapter_model) @property def default_model_name_or_path(self): return "baichuan-inc/Baichuan-13B-Chat" class InternLMModelAdapter(BaseModelAdapter): """ https://github.com/InternLM/InternLM """ model_names = ["internlm"] @property def default_model_name_or_path(self): return "internlm/internlm-chat-7b" class StarCodeModelAdapter(BaseModelAdapter): """ https://github.com/bigcode-project/starcoder """ model_names = ["starcode", "starchat"] @property def tokenizer_kwargs(self): return {} @property def default_model_name_or_path(self): return "HuggingFaceH4/starchat-beta" class AquilaModelAdapter(BaseModelAdapter): """ https://github.com/FlagAI-Open/FlagAI """ model_names = ["aquila"] @property def default_model_name_or_path(self): return "BAAI/AquilaChat-7B" class QwenModelAdapter(BaseModelAdapter): """ https://github.com/QwenLM/Qwen-7B """ model_names = ["qwen"] @property def default_model_name_or_path(self): return "Qwen/Qwen-7B-Chat" class XverseModelAdapter(BaseModelAdapter): """ https://github.com/xverse-ai/XVERSE-13B """ model_names = ["xverse"] @property def default_model_name_or_path(self): return "xverse/XVERSE-13B-Chat" class CodeLlamaModelAdapter(LlamaModelAdapter): """ https://github.com/project-baize/baize-chatbot """ model_names = ["code-llama"] @property def tokenizer_class(self): require_version("transformers>=4.33.1", "To fix: pip install transformers>=4.33.1") from transformers import CodeLlamaTokenizer return CodeLlamaTokenizer @property def default_model_name_or_path(self): return "codellama/CodeLlama-7b-Instruct-hf" register_model_adapter(ChatglmModelAdapter) register_model_adapter(Chatglm3ModelAdapter) register_model_adapter(LlamaModelAdapter) register_model_adapter(MossModelAdapter) register_model_adapter(PhoenixModelAdapter) register_model_adapter(FireflyModelAdapter) register_model_adapter(YuLanChatModelAdapter) register_model_adapter(TigerBotModelAdapter) register_model_adapter(OpenBuddyFalconModelAdapter) register_model_adapter(AnimaModelAdapter) register_model_adapter(BaiChuanModelAdapter) register_model_adapter(InternLMModelAdapter) register_model_adapter(AquilaModelAdapter) register_model_adapter(QwenModelAdapter) register_model_adapter(XverseModelAdapter) register_model_adapter(CodeLlamaModelAdapter) # After all adapters, try the default base adapter. register_model_adapter(BaseModelAdapter)