import warnings
from typing import List, Optional, Tuple

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Imported for their side effects: quantized loading and device_map="auto"
# require both packages, so importing them here surfaces a missing
# dependency early.
import bitsandbytes  # noqa: F401
import accelerate  # noqa: F401

from my_model.config import LLAMA2_config as config

warnings.filterwarnings("ignore", category=FutureWarning, module="transformers")

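# `LLAMA2_config` is expected to expose the settings read in
# Llama2ModelManager.__init__ below: DEVICE, MODEL_NAME, TOKENIZER_NAME,
# QUANTIZATION ('4bit', '8bit', or None), FROM_SAVED, MODEL_PATH,
# TRUST_REMOTE, USE_FAST, ADD_EOS_TOKEN, and ACCESS_TOKEN.
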

class Llama2ModelManager:
    """
    Manages loading and configuring the LLaMA-2 model and tokenizer.

    Attributes:
        device (str): Device to use for the model ('cuda' or 'cpu').
        model_name (str): Name or path of the pre-trained model.
        tokenizer_name (str): Name or path of the tokenizer.
        quantization (str or None): Quantization level ('4bit', '8bit', or None).
        from_saved (bool): Flag to load the model from a saved path.
        model_path (str or None): Path to the saved model if `from_saved` is True.
        trust_remote (bool): Whether to trust remote code when loading the tokenizer.
        use_fast (bool): Whether to use the fast version of the tokenizer.
        add_eos_token (bool): Whether to add an EOS token to the tokenizer.
        access_token (str): Access token for Hugging Face Hub.
        model (AutoModelForCausalLM or None): Loaded model, initially None.
        tokenizer (AutoTokenizer or None): Loaded tokenizer, initially None.
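
    Example (illustrative usage):
        manager = Llama2ModelManager()
        model, tokenizer = manager.load_model_and_tokenizer(for_fine_tuning=True)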
""" |
|
|
|
def __init__(self) -> None: |
|
""" |
|
Initializes the Llama2ModelManager class with configuration settings. |
|
""" |
|
|
|
self.device: str = config.DEVICE |
|
self.model_name: str = config.MODEL_NAME |
|
self.tokenizer_name: str = config.TOKENIZER_NAME |
|
self.quantization: str = config.QUANTIZATION |
|
self.from_saved: bool = config.FROM_SAVED |
|
self.model_path: Optional[str] = config.MODEL_PATH |
|
self.trust_remote: bool = config.TRUST_REMOTE |
|
self.use_fast: bool = config.USE_FAST |
|
self.add_eos_token: bool = config.ADD_EOS_TOKEN |
|
self.access_token: str = config.ACCESS_TOKEN |
|
self.model: Optional[AutoModelForCausalLM] = None |

    def create_bnb_config(self) -> BitsAndBytesConfig:
        """
        Creates a BitsAndBytes configuration based on the quantization setting.

        Returns:
            BitsAndBytesConfig: Configuration for loading the model with
            bitsandbytes quantization.

        Raises:
            ValueError: If `quantization` is not '4bit' or '8bit'.
        """
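
        # 4-bit NF4 with double quantization (the QLoRA recipe) roughly
        # quarters the weight memory relative to float16, e.g. a 7B model
        # drops from ~14 GB to ~4 GB (approximate figures); 8-bit keeps more
        # precision at about half the float16 footprint.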
        if self.quantization == '4bit':
            return BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16
            )
        elif self.quantization == '8bit':
            # Double quantization and NF4 are 4-bit-only options; plain
            # load_in_8bit is all that 8-bit loading needs.
            return BitsAndBytesConfig(load_in_8bit=True)
        else:
            raise ValueError(f"Unsupported quantization setting: {self.quantization}")

    def load_model(self) -> AutoModelForCausalLM:
        """
        Loads the LLaMA-2 model based on the specified configuration.
        If the model is already loaded, returns the existing model.

        Returns:
            AutoModelForCausalLM: Loaded LLaMA-2 model.
        """

        if self.model is not None:
            print("Model is already loaded.")
            return self.model
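
        # Two loading paths: a locally saved checkpoint, or the Hub model,
        # optionally quantized with the BitsAndBytesConfig built above.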

        if self.from_saved:
            self.model = AutoModelForCausalLM.from_pretrained(self.model_path, device_map="auto")
        else:
            bnb_config = None if self.quantization is None else self.create_bnb_config()
            self.model = AutoModelForCausalLM.from_pretrained(self.model_name, device_map="auto",
                                                              quantization_config=bnb_config,
                                                              torch_dtype=torch.float16,
                                                              token=self.access_token)

        if self.model is not None:
            quantization = self.quantization if self.quantization is not None else "no"
            print(f"LLAMA2 Model loaded successfully in {quantization} quantization.")
        else:
            print("LLAMA2 Model failed to load.")
        return self.model

    def load_tokenizer(self) -> AutoTokenizer:
        """
        Loads the tokenizer for the LLaMA-2 model with the specified configuration.

        Returns:
            AutoTokenizer: Loaded tokenizer for the LLaMA-2 model.
        """

        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=self.use_fast,
                                                       token=self.access_token,
                                                       trust_remote_code=self.trust_remote,
                                                       add_eos_token=self.add_eos_token)

        if self.tokenizer is not None:
            print("LLAMA2 Tokenizer loaded successfully.")
        else:
            print("LLAMA2 Tokenizer failed to load.")

        return self.tokenizer

    def load_model_and_tokenizer(self, for_fine_tuning: bool) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
        """
        Loads the LLaMA-2 model and tokenizer, and optionally adds special tokens for fine-tuning.

        Args:
            for_fine_tuning (bool): Whether to prepare the model and tokenizer for fine-tuning.

        Returns:
            Tuple[AutoModelForCausalLM, AutoTokenizer]: The loaded model and tokenizer.
        """

        self.tokenizer = self.load_tokenizer()
        self.model = self.load_model()
        # Special tokens are only needed when the model will be fine-tuned on
        # data that uses them; plain inference skips this step.
        if for_fine_tuning:
            self.add_special_tokens()

        return self.model, self.tokenizer

    def add_special_tokens(self, tokens: Optional[List[str]] = None) -> None:
        """
        Adds special tokens to the tokenizer and updates the model's token embeddings if the model is loaded.

        Args:
            tokens (Optional[List[str]]): Special tokens to add. Defaults to a predefined set.

        Returns:
            None
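
        Example (illustrative):
            manager.add_special_tokens()  # default caption/question/object tags
            manager.add_special_tokens(['[IMG]', '[/IMG]'])  # custom tokens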
""" |
|
|
|
if self.tokenizer is None: |
|
print("Tokenizer is not loaded. Cannot add special tokens.") |
|
return |
|
|
|
if tokens is None: |
|
tokens = ['[CAP]', '[/CAP]', '[QES]', '[/QES]', '[OBJ]', '[/OBJ]'] |
|
|
|
|
|
print(f"Original vocabulary size: {len(self.tokenizer)}") |
|
print(f"Adding the following tokens: {tokens}") |
|
self.tokenizer.add_tokens(tokens, special_tokens=True) |
|
self.tokenizer.add_special_tokens({'pad_token': '<pad>'}) |
|
print(f"Adding Padding Token {self.tokenizer.pad_token}") |
|
self.tokenizer.padding_side = "right" |
|
print(f'Padding side: {self.tokenizer.padding_side}') |
|
|
|
|
|
if self.model is not None: |
|
self.model.resize_token_embeddings(len(self.tokenizer)) |
|
self.model.config.pad_token_id = self.tokenizer.pad_token_id |
|
|
|
print(f'Updated Vocabulary Size: {len(self.tokenizer)}') |
|
print(f'Padding Token: {self.tokenizer.pad_token}') |
|
print(f'Special Tokens: {self.tokenizer.added_tokens_decoder}') |
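
        # Note: resize_token_embeddings leaves the newly added embedding rows
        # effectively untrained, so fine-tune before relying on these tokens
        # at inference time.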


if __name__ == "__main__":
    # Minimal smoke test: load the model and tokenizer, then register the
    # special tokens used for fine-tuning.
    LLAMA2_manager = Llama2ModelManager()
    LLAMA2_model = LLAMA2_manager.load_model()
    LLAMA2_tokenizer = LLAMA2_manager.load_tokenizer()
    LLAMA2_manager.add_special_tokens()