Jerome2046 commited on
Commit
1108b2e
1 Parent(s): d64db71

Create arguments.py

Browse files
Files changed (1) hide show
  1. arguments.py +223 -0
arguments.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import Optional
3
+
4
+
5
+ @dataclass
6
+ class ModelArguments:
7
+ """
8
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
9
+ """
10
+
11
+ model_name_or_path: str = field(
12
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
13
+ )
14
+ ptuning_checkpoint: str = field(
15
+ default=None, metadata={"help": "Path to p-tuning v2 checkpoints"}
16
+ )
17
+ config_name: Optional[str] = field(
18
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
19
+ )
20
+ tokenizer_name: Optional[str] = field(
21
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
22
+ )
23
+ cache_dir: Optional[str] = field(
24
+ default=None,
25
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
26
+ )
27
+ use_fast_tokenizer: bool = field(
28
+ default=True,
29
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
30
+ )
31
+ model_revision: str = field(
32
+ default="main",
33
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
34
+ )
35
+ use_auth_token: bool = field(
36
+ default=False,
37
+ metadata={
38
+ "help": (
39
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
40
+ "with private models)."
41
+ )
42
+ },
43
+ )
44
+ resize_position_embeddings: Optional[bool] = field(
45
+ default=None,
46
+ metadata={
47
+ "help": (
48
+ "Whether to automatically resize the position embeddings if `max_source_length` exceeds "
49
+ "the model's position embeddings."
50
+ )
51
+ },
52
+ )
53
+ quantization_bit: Optional[int] = field(
54
+ default=None
55
+ )
56
+ pre_seq_len: Optional[int] = field(
57
+ default=None
58
+ )
59
+ prefix_projection: bool = field(
60
+ default=False
61
+ )
62
+
63
+
64
+ @dataclass
65
+ class DataTrainingArguments:
66
+ """
67
+ Arguments pertaining to what data we are going to input our model for training and eval.
68
+ """
69
+
70
+ lang: Optional[str] = field(default=None, metadata={"help": "Language id for summarization."})
71
+
72
+ dataset_name: Optional[str] = field(
73
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
74
+ )
75
+ dataset_config_name: Optional[str] = field(
76
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
77
+ )
78
+ prompt_column: Optional[str] = field(
79
+ default=None,
80
+ metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
81
+ )
82
+ response_column: Optional[str] = field(
83
+ default=None,
84
+ metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
85
+ )
86
+ history_column: Optional[str] = field(
87
+ default=None,
88
+ metadata={"help": "The name of the column in the datasets containing the history of chat."},
89
+ )
90
+ train_file: Optional[str] = field(
91
+ default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
92
+ )
93
+ validation_file: Optional[str] = field(
94
+ default=None,
95
+ metadata={
96
+ "help": (
97
+ "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
98
+ )
99
+ },
100
+ )
101
+ test_file: Optional[str] = field(
102
+ default=None,
103
+ metadata={
104
+ "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
105
+ },
106
+ )
107
+ overwrite_cache: bool = field(
108
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
109
+ )
110
+ preprocessing_num_workers: Optional[int] = field(
111
+ default=None,
112
+ metadata={"help": "The number of processes to use for the preprocessing."},
113
+ )
114
+ max_source_length: Optional[int] = field(
115
+ default=1024,
116
+ metadata={
117
+ "help": (
118
+ "The maximum total input sequence length after tokenization. Sequences longer "
119
+ "than this will be truncated, sequences shorter will be padded."
120
+ )
121
+ },
122
+ )
123
+ max_target_length: Optional[int] = field(
124
+ default=128,
125
+ metadata={
126
+ "help": (
127
+ "The maximum total sequence length for target text after tokenization. Sequences longer "
128
+ "than this will be truncated, sequences shorter will be padded."
129
+ )
130
+ },
131
+ )
132
+ val_max_target_length: Optional[int] = field(
133
+ default=None,
134
+ metadata={
135
+ "help": (
136
+ "The maximum total sequence length for validation target text after tokenization. Sequences longer "
137
+ "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
138
+ "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
139
+ "during ``evaluate`` and ``predict``."
140
+ )
141
+ },
142
+ )
143
+ pad_to_max_length: bool = field(
144
+ default=False,
145
+ metadata={
146
+ "help": (
147
+ "Whether to pad all samples to model maximum sentence length. "
148
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
149
+ "efficient on GPU but very bad for TPU."
150
+ )
151
+ },
152
+ )
153
+ max_train_samples: Optional[int] = field(
154
+ default=None,
155
+ metadata={
156
+ "help": (
157
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
158
+ "value if set."
159
+ )
160
+ },
161
+ )
162
+ max_eval_samples: Optional[int] = field(
163
+ default=None,
164
+ metadata={
165
+ "help": (
166
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
167
+ "value if set."
168
+ )
169
+ },
170
+ )
171
+ max_predict_samples: Optional[int] = field(
172
+ default=None,
173
+ metadata={
174
+ "help": (
175
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
176
+ "value if set."
177
+ )
178
+ },
179
+ )
180
+ num_beams: Optional[int] = field(
181
+ default=None,
182
+ metadata={
183
+ "help": (
184
+ "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
185
+ "which is used during ``evaluate`` and ``predict``."
186
+ )
187
+ },
188
+ )
189
+ ignore_pad_token_for_loss: bool = field(
190
+ default=True,
191
+ metadata={
192
+ "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
193
+ },
194
+ )
195
+ source_prefix: Optional[str] = field(
196
+ default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
197
+ )
198
+
199
+ forced_bos_token: Optional[str] = field(
200
+ default=None,
201
+ metadata={
202
+ "help": (
203
+ "The token to force as the first generated token after the decoder_start_token_id."
204
+ "Useful for multilingual models like mBART where the first generated token"
205
+ "needs to be the target language token (Usually it is the target language token)"
206
+ )
207
+ },
208
+ )
209
+
210
+
211
+
212
+ def __post_init__(self):
213
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None and self.test_file is None:
214
+ raise ValueError("Need either a dataset name or a training/validation/test file.")
215
+ else:
216
+ if self.train_file is not None:
217
+ extension = self.train_file.split(".")[-1]
218
+ assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
219
+ if self.validation_file is not None:
220
+ extension = self.validation_file.split(".")[-1]
221
+ assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
222
+ if self.val_max_target_length is None:
223
+ self.val_max_target_length = self.max_target_length