rishiraj committed
Commit 9f309ad (1 parent: 9f718e0)

Upload model_utils.py

Files changed (1)
  1. model_utils.py +101 -0
model_utils.py ADDED
@@ -0,0 +1,101 @@
+ # coding=utf-8
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Dict
+
+ import torch
+ from transformers import AutoTokenizer, BitsAndBytesConfig, PreTrainedTokenizer
+
+ from accelerate import Accelerator
+ from huggingface_hub import list_repo_files
+ from peft import LoraConfig, PeftConfig
+
+ from .configs import DataArguments, ModelArguments
+ from .data import DEFAULT_CHAT_TEMPLATE
+
+
+ def get_current_device() -> int | str:
+     """Get the current device: the local process index on GPU (to support multi-GPU training), or "cpu" otherwise."""
+     return Accelerator().local_process_index if torch.cuda.is_available() else "cpu"
+
+
+ def get_kbit_device_map() -> Dict[str, int] | None:
+     """Useful for running inference with quantized models by setting `device_map=get_kbit_device_map()`."""
+     return {"": get_current_device()} if torch.cuda.is_available() else None
+
+
+ def get_quantization_config(model_args: ModelArguments) -> BitsAndBytesConfig | None:
+     """Build a `BitsAndBytesConfig` for 4-bit or 8-bit loading, or return None when no quantization is requested."""
+     if model_args.load_in_4bit:
+         quantization_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             # For consistency with the model weights, we use the same value as `torch_dtype`, which is float16 for PEFT models
+             bnb_4bit_compute_dtype=torch.float16,
+             bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
+             bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
+         )
+     elif model_args.load_in_8bit:
+         quantization_config = BitsAndBytesConfig(
+             load_in_8bit=True,
+         )
+     else:
+         quantization_config = None
+
+     return quantization_config
+
+
+ def get_tokenizer(model_args: ModelArguments, data_args: DataArguments) -> PreTrainedTokenizer:
+     """Get the tokenizer for the model."""
+     tokenizer = AutoTokenizer.from_pretrained(
+         model_args.model_name_or_path,
+         revision=model_args.model_revision,
+     )
+     # Some tokenizers (e.g. Llama's) ship without a pad token; fall back to the EOS token
+     if tokenizer.pad_token_id is None:
+         tokenizer.pad_token_id = tokenizer.eos_token_id
+
+     if data_args.truncation_side is not None:
+         tokenizer.truncation_side = data_args.truncation_side
+
+     # Set a reasonable default for models that report no meaningful max length
+     if tokenizer.model_max_length > 100_000:
+         tokenizer.model_max_length = 2048
+
+     # Prefer an explicit template from the data args; otherwise keep the model's own, falling back to the default
+     if data_args.chat_template is not None:
+         tokenizer.chat_template = data_args.chat_template
+     elif tokenizer.chat_template is None:
+         tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
+
+     return tokenizer
+
+
+ def get_peft_config(model_args: ModelArguments) -> PeftConfig | None:
+     """Build a LoRA config from the model arguments, or return None when PEFT is disabled."""
+     if not model_args.use_peft:
+         return None
+
+     peft_config = LoraConfig(
+         r=model_args.lora_r,
+         lora_alpha=model_args.lora_alpha,
+         lora_dropout=model_args.lora_dropout,
+         bias="none",
+         task_type="CAUSAL_LM",
+         target_modules=model_args.lora_target_modules,
+         modules_to_save=model_args.lora_modules_to_save,
+     )
+
+     return peft_config
+
+
+ def is_adapter_model(model_name_or_path: str, revision: str = "main") -> bool:
+     """Check whether a Hub repo contains PEFT adapter weights rather than a full model."""
+     repo_files = list_repo_files(model_name_or_path, revision=revision)
+     return "adapter_model.safetensors" in repo_files or "adapter_model.bin" in repo_files