rishiraj committed
Commit 065a39e
1 Parent(s): 7047a96

Update configs.py

Files changed (1): configs.py (+79, -0)
configs.py CHANGED
@@ -18,6 +18,85 @@ from dataclasses import dataclass, field
 from typing import Any, Dict, List, NewType, Optional, Tuple
 
 
+@dataclass
+class ModelArguments:
+    """
+    Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
+    """
+
+    base_model_revision: Optional[str] = field(
+        default=None,
+        metadata={"help": ("The base model checkpoint for weights initialization with PEFT adapters.")},
+    )
+    model_name_or_path: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
+            )
+        },
+    )
+    model_revision: str = field(
+        default="main",
+        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
+    )
+    model_code_revision: Optional[str] = field(default=None, metadata={"help": "The branch of the IFT model"})
+    torch_dtype: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
+                "dtype will be automatically derived from the model's weights."
+            ),
+            "choices": ["auto", "bfloat16", "float16", "float32"],
+        },
+    )
+    trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."})
+    use_flash_attention_2: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Whether to use flash attention 2. You must install this manually by running `pip install flash-attn --no-build-isolation`"
+            )
+        },
+    )
+    use_peft: bool = field(
+        default=False,
+        metadata={"help": ("Whether to use PEFT or not for training.")},
+    )
+    lora_r: Optional[int] = field(
+        default=16,
+        metadata={"help": ("LoRA R value.")},
+    )
+    lora_alpha: Optional[int] = field(
+        default=32,
+        metadata={"help": ("LoRA alpha.")},
+    )
+    lora_dropout: Optional[float] = field(
+        default=0.05,
+        metadata={"help": ("LoRA dropout.")},
+    )
+    lora_target_modules: Optional[List[str]] = field(
+        default=None,
+        metadata={"help": ("LoRA target modules.")},
+    )
+    lora_modules_to_save: Optional[List[str]] = field(
+        default=None,
+        metadata={"help": ("Model layers to unfreeze & train")},
+    )
+    load_in_8bit: bool = field(default=False, metadata={"help": "Use 8-bit precision"})
+    load_in_4bit: bool = field(default=False, metadata={"help": "Use 4-bit precision"})
+
+    bnb_4bit_quant_type: Optional[str] = field(
+        default="nf4", metadata={"help": "Specify the quantization type (fp4 or nf4)"}
+    )
+    use_bnb_nested_quant: bool = field(default=False, metadata={"help": "Use nested quantization"})
+
+    def __post_init__(self):
+        if self.load_in_8bit and self.load_in_4bit:
+            raise ValueError("You can't use 8-bit and 4-bit precision at the same time")
+
+
 @dataclass
 class DataArguments:
     """