import os
import sys
from typing import List, Optional, Any, Dict, Tuple
import torch
from loguru import logger
from peft import PeftModel
from tqdm import tqdm
from transformers import (
AutoModel,
AutoConfig,
AutoTokenizer,
AutoModelForCausalLM,
BitsAndBytesConfig,
PreTrainedTokenizer,
PreTrainedModel,
)
from transformers.utils.versions import require_version
if sys.version_info >= (3, 9):
from functools import cache
else:
from functools import lru_cache as cache
class BaseModelAdapter:
""" The base and default model adapter. """
model_names = []
def match(self, model_name) -> bool:
"""
Check if the given model name matches any of the predefined model names.
Args:
model_name (str): The model name to check.
Returns:
bool: True if the model name matches any of the predefined model names, False otherwise.
"""
return any(m in model_name for m in self.model_names) if self.model_names else True
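    # Illustrative examples (comment only, not part of the original API): matching is a
    # plain substring test against model_names, and an empty list matches any name.
    #
    #     QwenModelAdapter().match("qwen-7b-chat")  # True  ("qwen" is a substring)
    #     QwenModelAdapter().match("baichuan-13b")  # False
    #     BaseModelAdapter().match("anything")      # True  (empty model_names)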
def load_model(
self,
model_name_or_path: Optional[str] = None,
adapter_model: Optional[str] = None,
**kwargs: Any,
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
"""
Load a model and tokenizer based on the provided model name or path.
Args:
model_name_or_path (str, optional): The name or path of the model. Defaults to None.
adapter_model (str, optional): The adapter model to load the tokenizer from. Defaults to None.
            **kwargs: Additional keyword arguments; recognized keys include device, num_gpus,
                dtype, load_in_8bit, load_in_4bit, device_map, using_ptuning_v2, pre_seq_len,
                resize_embeddings and quantize.
Returns:
Tuple[PreTrainedModel, PreTrainedTokenizer]: A tuple containing the loaded model and tokenizer.
"""
model_name_or_path = model_name_or_path or self.default_model_name_or_path
tokenizer_kwargs = {"trust_remote_code": True, "use_fast": False}
tokenizer_kwargs.update(self.tokenizer_kwargs)
# load a tokenizer from adapter model if it exists.
if adapter_model is not None:
try:
tokenizer = self.tokenizer_class.from_pretrained(
adapter_model, **tokenizer_kwargs,
)
except OSError:
tokenizer = self.tokenizer_class.from_pretrained(
model_name_or_path, **tokenizer_kwargs,
)
else:
tokenizer = self.tokenizer_class.from_pretrained(
model_name_or_path, **tokenizer_kwargs,
)
config_kwargs = self.model_kwargs
device = kwargs.get("device", "cuda")
num_gpus = kwargs.get("num_gpus", 1)
dtype = kwargs.get("dtype", "half")
if device == "cuda":
if "torch_dtype" not in config_kwargs:
if dtype == "half":
config_kwargs["torch_dtype"] = torch.float16
elif dtype == "bfloat16":
config_kwargs["torch_dtype"] = torch.bfloat16
elif dtype == "float32":
config_kwargs["torch_dtype"] = torch.float32
if num_gpus != 1:
config_kwargs["device_map"] = "auto"
                # config_kwargs["device_map"] = "sequential"  # use "sequential" instead when the GPUs have different VRAM sizes
# Quantization configurations (using bitsandbytes library).
if kwargs.get("load_in_8bit", False):
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
config_kwargs["load_in_8bit"] = True
config_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_8bit=True,
llm_int8_threshold=6.0,
)
config_kwargs["device_map"] = "auto" if device == "cuda" else None
logger.info("Quantizing model to 8 bit.")
elif kwargs.get("load_in_4bit", False):
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
config_kwargs["load_in_4bit"] = True
config_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)
config_kwargs["device_map"] = "auto" if device == "cuda" else None
logger.info("Quantizing model to 4 bit.")
if kwargs.get("device_map", None) == "auto":
config_kwargs["device_map"] = "auto"
config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
# Fix config (for Qwen)
if hasattr(config, "fp16") and hasattr(config, "bf16"):
setattr(config, "fp16", dtype == "half")
setattr(config, "bf16", dtype == "bfloat16")
config_kwargs.pop("torch_dtype", None)
if kwargs.get("using_ptuning_v2", False) and adapter_model:
config.pre_seq_len = kwargs.get("pre_seq_len", 128)
# Load and prepare pretrained models (without valuehead).
model = self.model_class.from_pretrained(
model_name_or_path,
config=config,
trust_remote_code=True,
**config_kwargs
)
if device == "cpu":
model = model.float()
# post process for special tokens
tokenizer = self.post_tokenizer(tokenizer)
is_chatglm = "chatglm" in str(type(model))
if adapter_model is not None:
model = self.load_adapter_model(model, tokenizer, adapter_model, is_chatglm, config_kwargs, **kwargs)
if is_chatglm or "baichuan" in str(type(model)) or "xverse" in str(type(model)):
quantize = kwargs.get("quantize", None)
if quantize and quantize != 16:
logger.info(f"Quantizing model to {quantize} bit.")
model = model.quantize(quantize)
if device == "cuda" and num_gpus == 1 and "device_map" not in config_kwargs:
model.to(device)
# inference mode
model.eval()
return model, tokenizer
def load_lora_model(
self, model: PreTrainedModel, adapter_model: str, model_kwargs: Dict,
) -> PeftModel:
"""
Load a LoRA model.
This function loads a LoRA model using the specified pretrained model and adapter model.
Args:
model (PreTrainedModel): The base pretrained model.
adapter_model (str): The name or path of the adapter model.
model_kwargs (dict): Additional keyword arguments for the model.
Returns:
PeftModel: The loaded LoRA model.
"""
return PeftModel.from_pretrained(
model,
adapter_model,
torch_dtype=model_kwargs.get("torch_dtype", torch.float16),
)
def load_adapter_model(
self,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
adapter_model: str,
is_chatglm: bool,
model_kwargs: Dict,
**kwargs: Any,
) -> PreTrainedModel:
using_ptuning_v2 = kwargs.get("using_ptuning_v2", False)
resize_embeddings = kwargs.get("resize_embeddings", False)
if adapter_model and resize_embeddings and not is_chatglm:
            model_vocab_size = model.get_input_embeddings().weight.size(0)
            tokenizer_vocab_size = len(tokenizer)
            logger.info(f"Vocab of the base model: {model_vocab_size}")
            logger.info(f"Vocab of the tokenizer: {tokenizer_vocab_size}")
            if model_vocab_size != tokenizer_vocab_size:
                assert tokenizer_vocab_size > model_vocab_size
                logger.info("Resizing model embeddings to fit the tokenizer")
                model.resize_token_embeddings(tokenizer_vocab_size)
        if using_ptuning_v2:
            # P-tuning v2: the adapter directory is expected to contain a pytorch_model.bin
            # whose transformer.prefix_encoder.* entries are loaded into the prefix encoder.
            prefix_state_dict = torch.load(os.path.join(adapter_model, "pytorch_model.bin"))
            new_prefix_state_dict = {
                k[len("transformer.prefix_encoder."):]: v
                for k, v in prefix_state_dict.items()
                if k.startswith("transformer.prefix_encoder.")
            }
            model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
            model.transformer.prefix_encoder.float()
else:
model = self.load_lora_model(model, adapter_model, model_kwargs)
return model
def post_tokenizer(self, tokenizer) -> PreTrainedTokenizer:
return tokenizer
@property
def model_class(self):
return AutoModelForCausalLM
@property
def model_kwargs(self):
return {}
@property
def tokenizer_class(self):
return AutoTokenizer
@property
def tokenizer_kwargs(self):
return {}
@property
def default_model_name_or_path(self):
return "zpn/llama-7b"
# A global registry for all model adapters
model_adapters: List[BaseModelAdapter] = []
def register_model_adapter(cls):
""" Register a model adapter. """
model_adapters.append(cls())
@cache
def get_model_adapter(model_name: str) -> BaseModelAdapter:
"""
Get a model adapter for a given model name.
Args:
model_name (str): The name of the model.
Returns:
ModelAdapter: The model adapter that matches the given model name.
"""
for adapter in model_adapters:
if adapter.match(model_name):
return adapter
raise ValueError(f"No valid model adapter for {model_name}")
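# Illustrative lookups (added for clarity; they assume the registration order at the
# bottom of this module and that model names are lower-cased, as load_model() does):
#
#     get_model_adapter("qwen-7b-chat")     # -> QwenModelAdapter
#     get_model_adapter("chatglm3-6b")      # -> Chatglm3ModelAdapter
#     get_model_adapter("my-custom-model")  # -> BaseModelAdapter (catch-all fallback)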
def load_model(
model_name: str,
model_name_or_path: Optional[str] = None,
adapter_model: Optional[str] = None,
quantize: Optional[int] = 16,
device: Optional[str] = "cuda",
load_in_8bit: Optional[bool] = False,
**kwargs: Any,
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
"""
Load a pre-trained model and tokenizer.
Args:
model_name (str): The name of the model.
model_name_or_path (Optional[str], optional): The path or name of the pre-trained model. Defaults to None.
adapter_model (Optional[str], optional): The name of the adapter model. Defaults to None.
quantize (Optional[int], optional): The quantization level. Defaults to 16.
device (Optional[str], optional): The device to load the model on. Defaults to "cuda".
load_in_8bit (Optional[bool], optional): Whether to load the model in 8-bit mode. Defaults to False.
**kwargs (Any): Additional keyword arguments.
Returns:
Tuple[PreTrainedModel, PreTrainedTokenizer]: A tuple containing the loaded model and tokenizer.
"""
model_name = model_name.lower()
if "tiger" in model_name:
def skip(*args, **kwargs):
pass
torch.nn.init.kaiming_uniform_ = skip
torch.nn.init.uniform_ = skip
torch.nn.init.normal_ = skip
# get model adapter
adapter = get_model_adapter(model_name)
model, tokenizer = adapter.load_model(
model_name_or_path,
adapter_model,
device=device,
quantize=quantize,
load_in_8bit=load_in_8bit,
**kwargs
)
return model, tokenizer
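# Usage sketch (illustrative only; assumes a CUDA GPU and that the Qwen-7B-Chat weights
# are available locally or on the Hugging Face Hub -- kept as a comment so nothing runs
# at import time):
#
#     model, tokenizer = load_model(
#         "qwen",
#         model_name_or_path="Qwen/Qwen-7B-Chat",
#         device="cuda",
#         load_in_8bit=True,
#     )
#     inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
#     outputs = model.generate(**inputs, max_new_tokens=32)
#     print(tokenizer.decode(outputs[0], skip_special_tokens=True))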
class ChatglmModelAdapter(BaseModelAdapter):
""" https://github.com/THUDM/ChatGLM-6B """
model_names = ["chatglm"]
@property
def model_class(self):
return AutoModel
@property
def default_model_name_or_path(self):
return "THUDM/chatglm2-6b"
class Chatglm3ModelAdapter(ChatglmModelAdapter):
""" https://github.com/THUDM/ChatGLM-6B """
model_names = ["chatglm3"]
@property
def tokenizer_kwargs(self):
return {"encode_special_tokens": True}
@property
def default_model_name_or_path(self):
return "THUDM/chatglm3-6b"
class LlamaModelAdapter(BaseModelAdapter):
""" https://github.com/project-baize/baize-chatbot """
model_names = ["alpaca", "baize", "openbuddy-llama", "ziya-llama", "guanaco", "llama2"]
def post_tokenizer(self, tokenizer):
tokenizer.bos_token = "<s>"
tokenizer.eos_token = "</s>"
tokenizer.unk_token = "<unk>"
return tokenizer
@property
def model_kwargs(self):
return {"low_cpu_mem_usage": True}
class MossModelAdapter(BaseModelAdapter):
""" https://github.com/OpenLMLab/MOSS """
model_names = ["moss"]
@property
def default_model_name_or_path(self):
return "fnlp/moss-moon-003-sft-int4"
class PhoenixModelAdapter(BaseModelAdapter):
""" https://github.com/FreedomIntelligence/LLMZoo """
model_names = ["phoenix"]
@property
def model_kwargs(self):
return {"low_cpu_mem_usage": True}
@property
def tokenizer_kwargs(self):
return {"use_fast": True}
@property
def default_model_name_or_path(self):
return "FreedomIntelligence/phoenix-inst-chat-7b"
class FireflyModelAdapter(BaseModelAdapter):
""" https://github.com/yangjianxin1/Firefly """
model_names = ["firefly"]
@property
def model_kwargs(self):
return {"torch_dtype": torch.float32}
@property
def tokenizer_kwargs(self):
return {"use_fast": True}
@property
def default_model_name_or_path(self):
return "YeungNLP/firefly-2b6"
class YuLanChatModelAdapter(BaseModelAdapter):
""" https://github.com/RUC-GSAI/YuLan-Chat """
model_names = ["yulan"]
def post_tokenizer(self, tokenizer):
tokenizer.bos_token = "<s>"
tokenizer.eos_token = "</s>"
tokenizer.unk_token = "<unk>"
return tokenizer
@property
def model_kwargs(self):
return {"low_cpu_mem_usage": True}
def load_adapter_model(self, model, tokenizer, adapter_model, is_chatglm, model_kwargs, **kwargs):
adapter_model = AutoModelForCausalLM.from_pretrained(
adapter_model, torch_dtype=torch.float16, low_cpu_mem_usage=True
)
if model.model.embed_tokens.weight.size(0) + 1 == adapter_model.model.embed_tokens.weight.size(0):
model.resize_token_embeddings(len(tokenizer))
model.model.embed_tokens.weight.data[-1, :] = 0
logger.info("Applying the delta")
        for name, param in tqdm(model.state_dict().items(), desc="Applying delta"):
            # Add the delta weights from the adapter checkpoint onto the base weights.
            assert name in adapter_model.state_dict()
            param.data += adapter_model.state_dict()[name]
return model
class TigerBotModelAdapter(BaseModelAdapter):
""" https://github.com/TigerResearch/TigerBot """
model_names = ["tiger"]
@property
def tokenizer_kwargs(self):
return {"use_fast": True}
@property
def default_model_name_or_path(self):
return "TigerResearch/tigerbot-7b-sft"
class OpenBuddyFalconModelAdapter(BaseModelAdapter):
""" https://github.com/OpenBuddy/OpenBuddy """
model_names = ["openbuddy-falcon"]
@property
def default_model_name_or_path(self):
return "OpenBuddy/openbuddy-falcon-7b-v5-fp16"
class AnimaModelAdapter(LlamaModelAdapter):
model_names = ["anima"]
def load_lora_model(self, model, adapter_model, model_kwargs):
return PeftModel.from_pretrained(model, adapter_model)
class BaiChuanModelAdapter(BaseModelAdapter):
""" https://github.com/baichuan-inc/Baichuan-13B """
model_names = ["baichuan"]
def load_lora_model(self, model, adapter_model, model_kwargs):
return PeftModel.from_pretrained(model, adapter_model)
@property
def default_model_name_or_path(self):
return "baichuan-inc/Baichuan-13B-Chat"
class InternLMModelAdapter(BaseModelAdapter):
""" https://github.com/InternLM/InternLM """
model_names = ["internlm"]
@property
def default_model_name_or_path(self):
return "internlm/internlm-chat-7b"
class StarCodeModelAdapter(BaseModelAdapter):
""" https://github.com/bigcode-project/starcoder """
model_names = ["starcode", "starchat"]
@property
def tokenizer_kwargs(self):
return {}
@property
def default_model_name_or_path(self):
return "HuggingFaceH4/starchat-beta"
class AquilaModelAdapter(BaseModelAdapter):
""" https://github.com/FlagAI-Open/FlagAI """
model_names = ["aquila"]
@property
def default_model_name_or_path(self):
return "BAAI/AquilaChat-7B"
class QwenModelAdapter(BaseModelAdapter):
""" https://github.com/QwenLM/Qwen-7B """
model_names = ["qwen"]
@property
def default_model_name_or_path(self):
return "Qwen/Qwen-7B-Chat"
class XverseModelAdapter(BaseModelAdapter):
""" https://github.com/xverse-ai/XVERSE-13B """
model_names = ["xverse"]
@property
def default_model_name_or_path(self):
return "xverse/XVERSE-13B-Chat"
class CodeLlamaModelAdapter(LlamaModelAdapter):
""" https://github.com/project-baize/baize-chatbot """
model_names = ["code-llama"]
@property
def tokenizer_class(self):
require_version("transformers>=4.33.1", "To fix: pip install transformers>=4.33.1")
from transformers import CodeLlamaTokenizer
return CodeLlamaTokenizer
@property
def default_model_name_or_path(self):
return "codellama/CodeLlama-7b-Instruct-hf"
# Matching is substring-based and the first registered match wins, so register the more
# specific "chatglm3" adapter before the generic "chatglm" one.
register_model_adapter(Chatglm3ModelAdapter)
register_model_adapter(ChatglmModelAdapter)
register_model_adapter(LlamaModelAdapter)
register_model_adapter(MossModelAdapter)
register_model_adapter(PhoenixModelAdapter)
register_model_adapter(FireflyModelAdapter)
register_model_adapter(YuLanChatModelAdapter)
register_model_adapter(TigerBotModelAdapter)
register_model_adapter(OpenBuddyFalconModelAdapter)
register_model_adapter(AnimaModelAdapter)
register_model_adapter(BaiChuanModelAdapter)
register_model_adapter(InternLMModelAdapter)
register_model_adapter(AquilaModelAdapter)
register_model_adapter(QwenModelAdapter)
register_model_adapter(XverseModelAdapter)
register_model_adapter(CodeLlamaModelAdapter)
# After all adapters, try the default base adapter.
register_model_adapter(BaseModelAdapter)
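
# Minimal smoke test (added sketch, not part of the original module): resolve a few model
# names to their adapters without downloading or loading any weights. It assumes the
# registration order above; run this file directly to execute it.
if __name__ == "__main__":
    for name in ["chatglm3-6b", "qwen-7b-chat", "tigerbot-7b-sft", "some-unknown-model"]:
        adapter = get_model_adapter(name.lower())
        logger.info(f"{name} -> {type(adapter).__name__}")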