import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from typing import Optional
import bitsandbytes  # needed only when loading quantized models on GPU
import accelerate  # needed only for device_map="auto" weight placement on GPU
from my_model.config import LLAMA2_config as config  # Importing LLAMA2 configuration file
import warnings

# Suppress only FutureWarning from transformers
warnings.filterwarnings("ignore", category=FutureWarning, module="transformers")
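
# The LLAMA2_config module imported above is expected to expose the attributes
# read in Llama2ModelManager.__init__ below. The values here are only an
# illustrative sketch, not the project's actual settings:
#
#   DEVICE = "cuda"                               # 'cuda' or 'cpu'
#   MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"  # assumed example checkpoint
#   TOKENIZER_NAME = MODEL_NAME
#   QUANTIZATION = "4bit"                         # '4bit', '8bit', or None
#   FROM_SAVED = False                            # load from MODEL_PATH instead of the Hub
#   MODEL_PATH = None
#   TRUST_REMOTE = False
#   USE_FAST = True
#   ADD_EOS_TOKEN = True
#   ACCESS_TOKEN = "hf_..."                       # Hugging Face Hub access token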


class Llama2ModelManager:
    """
    Manages loading and configuring the LLaMA-2 model and tokenizer.

    Attributes:
        device (str): Device to use for the model ('cuda' or 'cpu').
        model_name (str): Name or path of the pre-trained model.
        tokenizer_name (str): Name or path of the tokenizer.
        quantization (str or None): Specifies the quantization level ('4bit', '8bit', or None).
        from_saved (bool): Flag to load the model from a saved path.
        model_path (str or None): Path to the saved model if `from_saved` is True.
        trust_remote (bool): Whether to trust remote code when loading the tokenizer.
        use_fast (bool): Whether to use the fast version of the tokenizer.
        add_eos_token (bool): Whether to add an EOS token to the tokenizer.
        access_token (str): Access token for Hugging Face Hub.
        model (AutoModelForCausalLM or None): Loaded model, initially None.
        tokenizer (AutoTokenizer or None): Loaded tokenizer, initially None.
    """

    def __init__(self) -> None:
        """
        Initializes the Llama2ModelManager class with configuration settings.
        """
        self.device: str = config.DEVICE
        self.model_name: str = config.MODEL_NAME
        self.tokenizer_name: str = config.TOKENIZER_NAME
        self.quantization: str = config.QUANTIZATION
        self.from_saved: bool = config.FROM_SAVED
        self.model_path: Optional[str] = config.MODEL_PATH
        self.trust_remote: bool = config.TRUST_REMOTE
        self.use_fast: bool = config.USE_FAST
        self.add_eos_token: bool = config.ADD_EOS_TOKEN
        self.access_token: str = config.ACCESS_TOKEN
        self.model: Optional[AutoModelForCausalLM] = None
        self.tokenizer: Optional[AutoTokenizer] = None

    def create_bnb_config(self) -> Optional[BitsAndBytesConfig]:
        """
        Creates a BitsAndBytes configuration based on the quantization setting.

        Returns:
            BitsAndBytesConfig or None: Configuration for a BitsAndBytes quantized
            model, or None if `quantization` is neither '4bit' nor '8bit'.
        """
        if self.quantization == '4bit':
            return BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,  # also quantize the quantization constants
                bnb_4bit_quant_type="nf4",       # 4-bit NormalFloat data type
                bnb_4bit_compute_dtype=torch.bfloat16
            )
        elif self.quantization == '8bit':
            # Double quantization and the "nf4" type are 4-bit-only options in
            # BitsAndBytesConfig; 8-bit loading only needs load_in_8bit=True.
            return BitsAndBytesConfig(load_in_8bit=True)
        return None

    def load_model(self) -> AutoModelForCausalLM:
        """
        Loads the LLaMA-2 model based on the specified configuration. If the model is already loaded, returns the existing model.

        Returns:
            AutoModelForCausalLM: Loaded LLaMA-2 model.
        """
        if self.model is not None:
            print("Model is already loaded.")
            return self.model

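        # device_map="auto" relies on the accelerate library to place the model
        # weights across the available devices (GPU first, spilling over to CPU).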
        if self.from_saved:
            self.model = AutoModelForCausalLM.from_pretrained(self.model_path, device_map="auto")
        else:
            bnb_config = None if self.quantization is None else self.create_bnb_config()
            self.model = AutoModelForCausalLM.from_pretrained(self.model_name, device_map="auto",
                                                              quantization_config=bnb_config,
                                                              torch_dtype=torch.float16,
                                                              token=self.access_token)

        if self.model is not None:
            quantization = self.quantization if self.quantization else "no"
            print(f"LLAMA2 Model loaded successfully with {quantization} quantization.")
        else:
            print("LLAMA2 Model failed to load.")
        return self.model

    def load_tokenizer(self) -> AutoTokenizer:
        """
        Loads the tokenizer for the LLaMA-2 model with the specified configuration.

        Returns:
            AutoTokenizer: Loaded tokenizer for LLaMA-2 model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=self.use_fast,
                                                       token=self.access_token,
                                                       trust_remote_code=self.trust_remote,
                                                       add_eos_token=self.add_eos_token)

        if self.tokenizer is not None:
            print(f"LLAMA2 Tokenizer loaded successfully.")
        else:
            print("LLAMA2 Tokenizer failed to load.")

        return self.tokenizer

    def load_model_and_tokenizer(self, for_fine_tuning: bool):
        """
        Loads the LLaMA-2 model and tokenizer in one call, adding the special
        tokens when the model is being prepared for fine-tuning.

        Args:
            for_fine_tuning (bool): True to prepare for fine-tuning (adds the
                special tokens), False for inference only.

        Returns:
            The loaded LLaMA-2 model and tokenizer.
        """
        self.tokenizer = self.load_tokenizer()
        self.model = self.load_model()
        if for_fine_tuning:
            self.add_special_tokens()

        return self.model, self.tokenizer


    def add_special_tokens(self, tokens: Optional[list[str]] = None) -> None:
        """
        Adds special tokens to the tokenizer and, if the model is already loaded,
        resizes its token embeddings to match. Requires the tokenizer to be loaded first.

        Args:
            tokens (list of str, optional): Special tokens to add. Defaults to a predefined set.

        Returns:
            None
        """
        if self.tokenizer is None:
            print("Tokenizer is not loaded. Cannot add special tokens.")
            return

        if tokens is None:
            tokens = ['[CAP]', '[/CAP]', '[QES]', '[/QES]', '[OBJ]', '[/OBJ]']
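            # These presumably delimit the caption ([CAP]), question ([QES]) and
            # detected-object ([OBJ]) segments of a multimodal prompt, e.g. (a
            # hypothetical example): "[CAP] a dog on a sofa [/CAP] [QES] what is
            # the dog doing? [/QES] [OBJ] dog, sofa [/OBJ]".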

        # Update the tokenizer with new tokens
        print(f"Original vocabulary size: {len(self.tokenizer)}")
        print(f"Adding the following tokens: {tokens}")
        self.tokenizer.add_tokens(tokens, special_tokens=True)
        self.tokenizer.add_special_tokens({'pad_token': '<pad>'})
        print(f"Adding Padding Token {self.tokenizer.pad_token}")
        self.tokenizer.padding_side = "right"
        print(f'Padding side: {self.tokenizer.padding_side}')

        # Resize the model token embeddings if the model is loaded
        if self.model is not None:
            self.model.resize_token_embeddings(len(self.tokenizer))
            self.model.config.pad_token_id = self.tokenizer.pad_token_id

        print(f'Updated Vocabulary Size: {len(self.tokenizer)}')
        print(f'Padding Token: {self.tokenizer.pad_token}')
        print(f'Special Tokens: {self.tokenizer.added_tokens_decoder}')


if __name__ == "__main__":
    LLAMA2_manager = Llama2ModelManager()
    LLAMA2_model = LLAMA2_manager.load_model()  # First time loading the model
    LLAMA2_tokenizer = LLAMA2_manager.load_tokenizer()
    # add_special_tokens() takes an optional list of tokens, not the model and
    # tokenizer; it operates on the manager's own loaded instances.
    LLAMA2_manager.add_special_tokens()
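
    # A minimal sketch of the combined loader; with for_fine_tuning=True it is
    # equivalent to the three calls above, including the special-token setup:
    # LLAMA2_model, LLAMA2_tokenizer = LLAMA2_manager.load_model_and_tokenizer(for_fine_tuning=True)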