import os
import sys
from typing import List, Optional, Any, Dict, Tuple
import torch
from loguru import logger
from peft import PeftModel
from tqdm import tqdm
from transformers import (
AutoModel,
AutoConfig,
AutoTokenizer,
AutoModelForCausalLM,
BitsAndBytesConfig,
PreTrainedTokenizer,
PreTrainedModel,
)
from transformers.utils.versions import require_version
if sys.version_info >= (3, 9):
from functools import cache
else:
from functools import lru_cache as cache
class BaseModelAdapter:
""" The base and default model adapter. """
model_names = []
    def match(self, model_name: str) -> bool:
"""
Check if the given model name matches any of the predefined model names.
Args:
model_name (str): The model name to check.
Returns:
bool: True if the model name matches any of the predefined model names, False otherwise.
"""
return any(m in model_name for m in self.model_names) if self.model_names else True
def load_model(
self,
model_name_or_path: Optional[str] = None,
adapter_model: Optional[str] = None,
**kwargs: Any,
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
"""
Load a model and tokenizer based on the provided model name or path.
Args:
model_name_or_path (str, optional): The name or path of the model. Defaults to None.
adapter_model (str, optional): The adapter model to load the tokenizer from. Defaults to None.
**kwargs: Additional keyword arguments.
Returns:
Tuple[PreTrainedModel, PreTrainedTokenizer]: A tuple containing the loaded model and tokenizer.
"""
model_name_or_path = model_name_or_path or self.default_model_name_or_path
tokenizer_kwargs = {"trust_remote_code": True, "use_fast": False}
tokenizer_kwargs.update(self.tokenizer_kwargs)
# load a tokenizer from adapter model if it exists.
if adapter_model is not None:
try:
tokenizer = self.tokenizer_class.from_pretrained(
adapter_model, **tokenizer_kwargs,
)
except OSError:
tokenizer = self.tokenizer_class.from_pretrained(
model_name_or_path, **tokenizer_kwargs,
)
else:
tokenizer = self.tokenizer_class.from_pretrained(
model_name_or_path, **tokenizer_kwargs,
)
config_kwargs = self.model_kwargs
device = kwargs.get("device", "cuda")
num_gpus = kwargs.get("num_gpus", 1)
dtype = kwargs.get("dtype", "half")
if device == "cuda":
if "torch_dtype" not in config_kwargs:
if dtype == "half":
config_kwargs["torch_dtype"] = torch.float16
elif dtype == "bfloat16":
config_kwargs["torch_dtype"] = torch.bfloat16
elif dtype == "float32":
config_kwargs["torch_dtype"] = torch.float32
if num_gpus != 1:
config_kwargs["device_map"] = "auto"
# model_kwargs["device_map"] = "sequential" # This is important for not the same VRAM sizes
# Quantization configurations (using bitsandbytes library).
if kwargs.get("load_in_8bit", False):
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
config_kwargs["load_in_8bit"] = True
config_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_8bit=True,
llm_int8_threshold=6.0,
)
config_kwargs["device_map"] = "auto" if device == "cuda" else None
logger.info("Quantizing model to 8 bit.")
elif kwargs.get("load_in_4bit", False):
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
config_kwargs["load_in_4bit"] = True
config_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)
config_kwargs["device_map"] = "auto" if device == "cuda" else None
logger.info("Quantizing model to 4 bit.")
if kwargs.get("device_map", None) == "auto":
config_kwargs["device_map"] = "auto"
config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
# Fix config (for Qwen)
if hasattr(config, "fp16") and hasattr(config, "bf16"):
setattr(config, "fp16", dtype == "half")
setattr(config, "bf16", dtype == "bfloat16")
config_kwargs.pop("torch_dtype", None)
if kwargs.get("using_ptuning_v2", False) and adapter_model:
config.pre_seq_len = kwargs.get("pre_seq_len", 128)
# Load and prepare pretrained models (without valuehead).
model = self.model_class.from_pretrained(
model_name_or_path,
config=config,
trust_remote_code=True,
**config_kwargs
)
if device == "cpu":
model = model.float()
# post process for special tokens
tokenizer = self.post_tokenizer(tokenizer)
is_chatglm = "chatglm" in str(type(model))
if adapter_model is not None:
model = self.load_adapter_model(model, tokenizer, adapter_model, is_chatglm, config_kwargs, **kwargs)
if is_chatglm or "baichuan" in str(type(model)) or "xverse" in str(type(model)):
quantize = kwargs.get("quantize", None)
if quantize and quantize != 16:
logger.info(f"Quantizing model to {quantize} bit.")
model = model.quantize(quantize)
if device == "cuda" and num_gpus == 1 and "device_map" not in config_kwargs:
model.to(device)
# inference mode
model.eval()
return model, tokenizer
def load_lora_model(
self, model: PreTrainedModel, adapter_model: str, model_kwargs: Dict,
) -> PeftModel:
"""
Load a LoRA model.
This function loads a LoRA model using the specified pretrained model and adapter model.
Args:
model (PreTrainedModel): The base pretrained model.
adapter_model (str): The name or path of the adapter model.
model_kwargs (dict): Additional keyword arguments for the model.
Returns:
PeftModel: The loaded LoRA model.
"""
return PeftModel.from_pretrained(
model,
adapter_model,
torch_dtype=model_kwargs.get("torch_dtype", torch.float16),
)
def load_adapter_model(
self,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
adapter_model: str,
is_chatglm: bool,
model_kwargs: Dict,
**kwargs: Any,
) -> PreTrainedModel:
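        """
        Attach adapter weights to a base model.

        Depending on `using_ptuning_v2`, either loads P-Tuning v2 prefix-encoder weights from
        `adapter_model` or wraps the base model with a LoRA adapter via `load_lora_model`.
        Optionally resizes the token embeddings to match the tokenizer first.
        """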
using_ptuning_v2 = kwargs.get("using_ptuning_v2", False)
resize_embeddings = kwargs.get("resize_embeddings", False)
if adapter_model and resize_embeddings and not is_chatglm:
model_vocab_size = model.get_input_embeddings().weight.size(0)
            tokenizer_vocab_size = len(tokenizer)
            logger.info(f"Vocab of the base model: {model_vocab_size}")
            logger.info(f"Vocab of the tokenizer: {tokenizer_vocab_size}")
            if model_vocab_size != tokenizer_vocab_size:
                assert tokenizer_vocab_size > model_vocab_size
                logger.info("Resize model embeddings to fit tokenizer")
                model.resize_token_embeddings(tokenizer_vocab_size)
if using_ptuning_v2:
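            # P-Tuning v2: pull only the prefix-encoder tensors out of the adapter checkpoint,
            # strip their "transformer.prefix_encoder." prefix, and keep them in float32.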
prefix_state_dict = torch.load(os.path.join(adapter_model, "pytorch_model.bin"))
new_prefix_state_dict = {
k[len("transformer.prefix_encoder."):]: v
for k, v in prefix_state_dict.items()
if k.startswith("transformer.prefix_encoder.")
}
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
model.transformer.prefix_encoder.float()
else:
model = self.load_lora_model(model, adapter_model, model_kwargs)
return model
def post_tokenizer(self, tokenizer) -> PreTrainedTokenizer:
return tokenizer
@property
def model_class(self):
return AutoModelForCausalLM
@property
def model_kwargs(self):
return {}
@property
def tokenizer_class(self):
return AutoTokenizer
@property
def tokenizer_kwargs(self):
return {}
@property
def default_model_name_or_path(self):
return "zpn/llama-7b"
# A global registry for all model adapters
model_adapters: List[BaseModelAdapter] = []
def register_model_adapter(cls):
""" Register a model adapter. """
model_adapters.append(cls())
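# Adapters are tried in registration order, so more specific adapters must be registered
# before the catch-all BaseModelAdapter; lookups are cached per model name via @cache.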
@cache
def get_model_adapter(model_name: str) -> BaseModelAdapter:
"""
Get a model adapter for a given model name.
Args:
model_name (str): The name of the model.
Returns:
        BaseModelAdapter: The model adapter that matches the given model name.
"""
for adapter in model_adapters:
if adapter.match(model_name):
return adapter
raise ValueError(f"No valid model adapter for {model_name}")
def load_model(
model_name: str,
model_name_or_path: Optional[str] = None,
adapter_model: Optional[str] = None,
quantize: Optional[int] = 16,
device: Optional[str] = "cuda",
load_in_8bit: Optional[bool] = False,
**kwargs: Any,
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
"""
Load a pre-trained model and tokenizer.
Args:
model_name (str): The name of the model.
model_name_or_path (Optional[str], optional): The path or name of the pre-trained model. Defaults to None.
adapter_model (Optional[str], optional): The name of the adapter model. Defaults to None.
quantize (Optional[int], optional): The quantization level. Defaults to 16.
device (Optional[str], optional): The device to load the model on. Defaults to "cuda".
load_in_8bit (Optional[bool], optional): Whether to load the model in 8-bit mode. Defaults to False.
**kwargs (Any): Additional keyword arguments.
Returns:
Tuple[PreTrainedModel, PreTrainedTokenizer]: A tuple containing the loaded model and tokenizer.
"""
model_name = model_name.lower()
if "tiger" in model_name:
def skip(*args, **kwargs):
pass
torch.nn.init.kaiming_uniform_ = skip
torch.nn.init.uniform_ = skip
torch.nn.init.normal_ = skip
# get model adapter
adapter = get_model_adapter(model_name)
model, tokenizer = adapter.load_model(
model_name_or_path,
adapter_model,
device=device,
quantize=quantize,
load_in_8bit=load_in_8bit,
**kwargs
)
return model, tokenizer
class ChatglmModelAdapter(BaseModelAdapter):
""" https://github.com/THUDM/ChatGLM-6B """
model_names = ["chatglm"]
@property
def model_class(self):
return AutoModel
@property
def default_model_name_or_path(self):
return "THUDM/chatglm2-6b"
class Chatglm3ModelAdapter(ChatglmModelAdapter):
""" https://github.com/THUDM/ChatGLM-6B """
model_names = ["chatglm3"]
@property
def tokenizer_kwargs(self):
return {"encode_special_tokens": True}
@property
def default_model_name_or_path(self):
return "THUDM/chatglm3-6b"
class LlamaModelAdapter(BaseModelAdapter):
""" https://github.com/project-baize/baize-chatbot """
model_names = ["alpaca", "baize", "openbuddy-llama", "ziya-llama", "guanaco", "llama2"]
def post_tokenizer(self, tokenizer):
tokenizer.bos_token = ""
tokenizer.eos_token = ""
tokenizer.unk_token = ""
return tokenizer
@property
def model_kwargs(self):
return {"low_cpu_mem_usage": True}
class MossModelAdapter(BaseModelAdapter):
""" https://github.com/OpenLMLab/MOSS """
model_names = ["moss"]
@property
def default_model_name_or_path(self):
return "fnlp/moss-moon-003-sft-int4"
class PhoenixModelAdapter(BaseModelAdapter):
""" https://github.com/FreedomIntelligence/LLMZoo """
model_names = ["phoenix"]
@property
def model_kwargs(self):
return {"low_cpu_mem_usage": True}
@property
def tokenizer_kwargs(self):
return {"use_fast": True}
@property
def default_model_name_or_path(self):
return "FreedomIntelligence/phoenix-inst-chat-7b"
class FireflyModelAdapter(BaseModelAdapter):
""" https://github.com/yangjianxin1/Firefly """
model_names = ["firefly"]
@property
def model_kwargs(self):
return {"torch_dtype": torch.float32}
@property
def tokenizer_kwargs(self):
return {"use_fast": True}
@property
def default_model_name_or_path(self):
return "YeungNLP/firefly-2b6"
class YuLanChatModelAdapter(BaseModelAdapter):
""" https://github.com/RUC-GSAI/YuLan-Chat """
model_names = ["yulan"]
def post_tokenizer(self, tokenizer):
tokenizer.bos_token = ""
tokenizer.eos_token = ""
tokenizer.unk_token = ""
return tokenizer
@property
def model_kwargs(self):
return {"low_cpu_mem_usage": True}
def load_adapter_model(self, model, tokenizer, adapter_model, is_chatglm, model_kwargs, **kwargs):
adapter_model = AutoModelForCausalLM.from_pretrained(
adapter_model, torch_dtype=torch.float16, low_cpu_mem_usage=True
)
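        # The delta checkpoint carries one extra token in its embedding matrix; grow the
        # base model's embeddings to match and zero-initialize the new row.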
if model.model.embed_tokens.weight.size(0) + 1 == adapter_model.model.embed_tokens.weight.size(0):
model.resize_token_embeddings(len(tokenizer))
model.model.embed_tokens.weight.data[-1, :] = 0
logger.info("Applying the delta")
        # The checkpoint stores delta weights: reconstruct the full model by adding the
        # base weights to each delta tensor, then return the reconstructed model.
        for name, param in tqdm(adapter_model.state_dict().items(), desc="Applying delta"):
            assert name in model.state_dict()
            param.data += model.state_dict()[name]
        return adapter_model
class TigerBotModelAdapter(BaseModelAdapter):
""" https://github.com/TigerResearch/TigerBot """
model_names = ["tiger"]
@property
def tokenizer_kwargs(self):
return {"use_fast": True}
@property
def default_model_name_or_path(self):
return "TigerResearch/tigerbot-7b-sft"
class OpenBuddyFalconModelAdapter(BaseModelAdapter):
""" https://github.com/OpenBuddy/OpenBuddy """
model_names = ["openbuddy-falcon"]
@property
def default_model_name_or_path(self):
return "OpenBuddy/openbuddy-falcon-7b-v5-fp16"
class AnimaModelAdapter(LlamaModelAdapter):
model_names = ["anima"]
def load_lora_model(self, model, adapter_model, model_kwargs):
return PeftModel.from_pretrained(model, adapter_model)
class BaiChuanModelAdapter(BaseModelAdapter):
""" https://github.com/baichuan-inc/Baichuan-13B """
model_names = ["baichuan"]
def load_lora_model(self, model, adapter_model, model_kwargs):
return PeftModel.from_pretrained(model, adapter_model)
@property
def default_model_name_or_path(self):
return "baichuan-inc/Baichuan-13B-Chat"
class InternLMModelAdapter(BaseModelAdapter):
""" https://github.com/InternLM/InternLM """
model_names = ["internlm"]
@property
def default_model_name_or_path(self):
return "internlm/internlm-chat-7b"
class StarCodeModelAdapter(BaseModelAdapter):
""" https://github.com/bigcode-project/starcoder """
model_names = ["starcode", "starchat"]
@property
def tokenizer_kwargs(self):
return {}
@property
def default_model_name_or_path(self):
return "HuggingFaceH4/starchat-beta"
class AquilaModelAdapter(BaseModelAdapter):
""" https://github.com/FlagAI-Open/FlagAI """
model_names = ["aquila"]
@property
def default_model_name_or_path(self):
return "BAAI/AquilaChat-7B"
class QwenModelAdapter(BaseModelAdapter):
""" https://github.com/QwenLM/Qwen-7B """
model_names = ["qwen"]
@property
def default_model_name_or_path(self):
return "Qwen/Qwen-7B-Chat"
class XverseModelAdapter(BaseModelAdapter):
""" https://github.com/xverse-ai/XVERSE-13B """
model_names = ["xverse"]
@property
def default_model_name_or_path(self):
return "xverse/XVERSE-13B-Chat"
class CodeLlamaModelAdapter(LlamaModelAdapter):
""" https://github.com/project-baize/baize-chatbot """
model_names = ["code-llama"]
@property
def tokenizer_class(self):
require_version("transformers>=4.33.1", "To fix: pip install transformers>=4.33.1")
from transformers import CodeLlamaTokenizer
return CodeLlamaTokenizer
@property
def default_model_name_or_path(self):
return "codellama/CodeLlama-7b-Instruct-hf"
register_model_adapter(ChatglmModelAdapter)
register_model_adapter(Chatglm3ModelAdapter)
register_model_adapter(LlamaModelAdapter)
register_model_adapter(MossModelAdapter)
register_model_adapter(PhoenixModelAdapter)
register_model_adapter(FireflyModelAdapter)
register_model_adapter(YuLanChatModelAdapter)
register_model_adapter(TigerBotModelAdapter)
register_model_adapter(OpenBuddyFalconModelAdapter)
register_model_adapter(AnimaModelAdapter)
register_model_adapter(BaiChuanModelAdapter)
register_model_adapter(InternLMModelAdapter)
register_model_adapter(StarCodeModelAdapter)
register_model_adapter(AquilaModelAdapter)
register_model_adapter(QwenModelAdapter)
register_model_adapter(XverseModelAdapter)
register_model_adapter(CodeLlamaModelAdapter)
# After all adapters, try the default base adapter.
register_model_adapter(BaseModelAdapter)
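# Minimal usage sketch (illustrative, not part of the library): resolve an adapter by
# name and load the model. The checkpoint id and device below are assumptions.
if __name__ == "__main__":
    demo_model, demo_tokenizer = load_model(
        "qwen",                                    # matched to QwenModelAdapter by substring
        model_name_or_path="Qwen/Qwen-7B-Chat",    # assumed checkpoint; a local path also works
        device="cuda",
        dtype="half",
    )
    logger.info(f"Loaded {type(demo_model).__name__}; tokenizer vocab size: {len(demo_tokenizer)}")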