nancyH committed (verified)
Commit 887ef68 · 1 Parent(s): abbd565

Upload 2 files

Files changed (2)
  1. run_nt.sh +148 -0
  2. train.py +451 -0
run_nt.sh ADDED
#!/bin/bash
set -euo pipefail

# Usage:
#   nohup bash run_nt.sh \
#     ft_data \
#     full_output_multi_tune_hg38_1024 \
#     genomic_bench_tune_hg38_1024 \
#     0 > full_multi_tune_hg38_1024_3e-5.log 2>&1 &
#
# Args:
#   1) data_path     (e.g., ft_data)
#   2) output_path
#   3) project_name
#   4) gpu_id        (optional, default: 0)

source ~/miniconda3/etc/profile.d/conda.sh
conda activate bpe

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

data_path=${1:?"Missing data_path"}
output_path=${2:?"Missing output_path"}
project_name=${3:?"Missing project_name"}
gpu_id=${4:-0}

export CUDA_VISIBLE_DEVICES="${gpu_id}"

BEST_PARAMS_CSV="/home/n5huang/dna_token/best_params_len2_5120_by_task.csv"

MODEL="/home/n5huang/dna_token/pretrain/models/base_5120/checkpoint-100000"
TOKENIZER="/home/n5huang/dna_token/tokenizer_evaluation/baseline_bpe/vocab_5120/5120_tokenizer.json"
MODEL_NAME="base_5120"

# Resolve data_path relative to the script's directory if needed.
if [[ ! -d "${data_path}" && -d "${SCRIPT_DIR}/${data_path}" ]]; then
    data_path="${SCRIPT_DIR}/${data_path}"
fi

if [[ ! -d "${data_path}" ]]; then
    echo "data_path does not exist: ${data_path}" >&2
    exit 1
fi

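# Expected dataset layout (assumed from the check above and the CSV reads in train.py):
#   ${data_path}/${task}/split/{train,dev,test}.csv
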
declare -A TASK_LR
declare -A TASK_WD
declare -A TASK_WR
declare -A TASK_EP
declare -A TASK_SEED

while IFS=, read -r benchmark task metric best_score lr weight_decay warmup_ratio num_train_epochs selected_epoch seed run_name; do
    [[ "${benchmark}" == "benchmark" ]] && continue  # skip the header row
    [[ "${benchmark}" != "NT" ]] && continue         # keep NucleotideTransformer rows only

    TASK_LR["${task}"]="${lr}"
    TASK_WD["${task}"]="${weight_decay}"
    TASK_WR["${task}"]="${warmup_ratio}"
    TASK_EP["${task}"]="${selected_epoch}"
    TASK_SEED["${task}"]="${seed}"
done < "${BEST_PARAMS_CSV}"

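# Column order assumed by the read above:
#   benchmark,task,metric,best_score,lr,weight_decay,warmup_ratio,num_train_epochs,selected_epoch,seed,run_name
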
run_task() {
    local task="$1"
    local model_max_length="$2"

    local split_dir="${data_path}/${task}/split"
    local train_csv="${split_dir}/train.csv"

    if [[ ! -f "${train_csv}" ]]; then
        echo "[WARN] Missing ${train_csv}, skip ${task}"
        return
    fi

    # ":-" guards keep `set -u` from aborting on tasks absent from the CSV.
    local best_lr="${TASK_LR[$task]:-}"
    local best_wd="${TASK_WD[$task]:-}"
    local best_wr="${TASK_WR[$task]:-}"
    local best_ep="${TASK_EP[$task]:-}"
    local best_seed="${TASK_SEED[$task]:-}"

    if [[ -z "${best_lr}" ]]; then
        echo "[WARN] No best params found in CSV for task ${task}, skip"
        return
    fi

    hp_tag="lr${best_lr}_wd${best_wd}_wr${best_wr}_ep${best_ep}_seed${best_seed}"
    run_name="base5120_${task}_${hp_tag}"
    run_output_dir="${output_path}/${task}/${MODEL_NAME}/${hp_tag}"
    result_json="${run_output_dir}/results/${run_name}/eval_results.json"

    if [[ -f "${result_json}" ]]; then
        echo "[SKIP] ${run_name}"
        return
    fi

    mkdir -p "${run_output_dir}"
    echo "[RUN ] ${run_name}"

    cmd=(
        python /home/n5huang/dna_token/mario/Finetune-NucleotideTransformerBenchmarks/train.py
        --model_name_or_path "${MODEL}"
        --tokenizer_path "${TOKENIZER}"
        --trust_remote_code True
        --data_path "${split_dir}"
        --kmer -1
        --run_name "${run_name}"
        --model_max_length "${model_max_length}"
        --per_device_train_batch_size 128
        --per_device_eval_batch_size 128
        --gradient_accumulation_steps 1
        --learning_rate "${best_lr}"
        --weight_decay "${best_wd}"
        --num_train_epochs "${best_ep}"
        --lr_scheduler_type linear
        --warmup_steps 0
        --warmup_ratio "${best_wr}"
        --fp16
        --output_dir "${run_output_dir}"
        --evaluation_strategy epoch
        --save_strategy epoch
        --load_best_model_at_end True
        --metric_for_best_model eval_f1
        --greater_is_better True
        --save_total_limit 1
        --save_model True
        --logging_steps 100
        --overwrite_output_dir True
        --log_level info
        --seed "${best_seed}"
        --find_unused_parameters False
        --project_name "${project_name}"
    )
    "${cmd[@]}"
}

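# Each run writes ${output_path}/${task}/${MODEL_NAME}/<hp_tag>/results/<run_name>/eval_results.json;
# since existing results trigger [SKIP] above, the script is safe to re-run and resume.
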
# Enhancer tasks (model_max_length 100)
for task in enhancers enhancers_types; do
    run_task "${task}" 100
done

# Promoter tasks (model_max_length 80)
for task in promoter_all promoter_no_tata promoter_tata; do
    run_task "${task}" 80
done

# Splice-site tasks (model_max_length 140)
for task in splice_sites_acceptors splice_sites_all splice_sites_donors; do
    run_task "${task}" 140
done

# Histone-mark tasks (model_max_length 220)
for task in H2AFZ H3K27ac H3K27me3 H3K36me3 H3K4me1 H3K4me2 H3K4me3 H3K9ac H3K9me3 H4K20me1; do
    run_task "${task}" 220
done
train.py ADDED
import os
import sys

import wandb

# Read the W&B API key from the environment rather than hard-coding a credential.
wandb.login(key=os.environ["WANDB_API_KEY"])

# Disable triton-based attention kernels before transformers/flash-attn are imported.
os.environ["DISABLE_TRITON"] = "1"
sys.modules['triton'] = None
sys.modules['flash_attn_triton'] = None

import csv
import itertools
import json
import logging
from dataclasses import dataclass, field
from typing import Any, Optional, Dict, Sequence, Tuple, List, Union

import numpy as np
import sklearn.metrics  # import the submodule explicitly; `import sklearn` alone does not expose it
import torch
import transformers
from torch.utils.data import Dataset


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    trust_remote_code: bool = field(
        default=False,
        metadata={"help": "Whether to trust custom code shipped with the model "
                          "(custom architectures, tokenizers, or modeling files), "
                          "whether local or from the Hub."},
    )
    use_lora: bool = field(default=False, metadata={"help": "whether to use LoRA"})
    lora_r: int = field(default=8, metadata={"help": "hidden dimension for LoRA"})
    lora_alpha: int = field(default=32, metadata={"help": "alpha for LoRA"})
    lora_dropout: float = field(default=0.05, metadata={"help": "dropout rate for LoRA"})
    lora_target_modules: str = field(default="query,value", metadata={"help": "where to perform LoRA"})
    tokenizer_path: Optional[str] = field(default="facebook/opt-125m")


@dataclass
class DataArguments:
    data_path: str = field(default=None, metadata={"help": "Path to the training data."})
    kmer: int = field(default=-1, metadata={"help": "k-mer for input sequence. -1 means not using k-mer."})
    customized_tokenizer: Optional[str] = field(default=None)


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    vocab_file: Optional[str] = field(
        default=None,
        metadata={"help": "Path to custom vocabulary file (overrides Hugging Face default)"},
    )
    cache_dir: Optional[str] = field(default=None)
    run_name: str = field(default="run")
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(default=512, metadata={"help": "Maximum sequence length."})
    gradient_accumulation_steps: int = field(default=1)
    per_device_train_batch_size: int = field(default=1)
    per_device_eval_batch_size: int = field(default=1)
    num_train_epochs: int = field(default=1)
    fp16: bool = field(default=False)
    logging_steps: int = field(default=100)
    save_steps: int = field(default=100)
    eval_steps: int = field(default=100)
    evaluation_strategy: str = field(default="steps")
    warmup_steps: int = field(default=50)
    weight_decay: float = field(default=0.01)
    learning_rate: float = field(default=1e-4)
    save_total_limit: int = field(default=3)
    load_best_model_at_end: bool = field(default=False)
    output_dir: str = field(default="output")
    find_unused_parameters: bool = field(default=False)
    checkpointing: bool = field(default=False)
    dataloader_pin_memory: bool = field(default=False)
    eval_and_save_results: bool = field(default=True)
    save_model: bool = field(default=False)
    seed: int = field(default=42)
    project_name: Optional[str] = field(default=None)


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collect the state dict and dump it to disk."""
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa


def get_alter_of_dna_sequence(sequence: str):
    """Get the complement (not reverse complement) of the original DNA sequence."""
    MAP = {"A": "T", "T": "A", "C": "G", "G": "C"}
    # return "".join([MAP[c] for c in reversed(sequence)])  # reverse-complement variant
    return "".join([MAP[c] for c in sequence])


def generate_kmer_str(sequence: str, k: int) -> str:
    """Transform a DNA sequence into a whitespace-separated k-mer string,
    e.g. generate_kmer_str("ATCGA", 3) -> "ATC TCG CGA"."""
    return " ".join([sequence[i:i+k] for i in range(len(sequence) - k + 1)])


def load_or_generate_kmer(data_path: str, texts: List[str], k: int) -> List[str]:
    """Load or generate the k-mer string for each DNA sequence. Generated k-mers are
    cached next to the original data under the same name with a "_{k}mer" suffix."""
    kmer_path = data_path.replace(".csv", f"_{k}mer.json")
    if os.path.exists(kmer_path):
        logging.warning(f"Loading k-mer from {kmer_path}...")
        with open(kmer_path, "r") as f:
            kmer = json.load(f)
    else:
        logging.warning("Generating k-mer...")
        kmer = [generate_kmer_str(text, k) for text in texts]
        with open(kmer_path, "w") as f:
            logging.warning(f"Saving k-mer to {kmer_path}...")
            json.dump(kmer, f)

    return kmer


def load_customized_data(data_path: str, texts: List[str], customized_tokenizer: str) -> List[str]:
    """Load sequences pre-tokenized by a customized tokenizer from a cached JSON file."""
    customize_path = data_path.replace(".csv", f"_{customized_tokenizer}.json")
    if os.path.exists(customize_path):
        logging.warning(f"Loading data by customized tokenizer from {customize_path}...")
        with open(customize_path, "r") as f:
            return json.load(f)
    # Fail loudly if the pre-tokenized cache is missing instead of falling through.
    raise FileNotFoundError(f"Pre-tokenized data not found at {customize_path}")

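# SupervisedDataset infers the task from the CSV column count (the header row is skipped):
#   2 columns: sequence,label             -> single-sequence classification
#   3 columns: sequence1,sequence2,label  -> sequence-pair classification
#   5 columns: NucleotideTransformer benchmark format (label in column 0, sequence in column 4)
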
163
+ """Dataset for supervised fine-tuning."""
164
+
165
+ def __init__(self,
166
+ data_path: str,
167
+ tokenizer: transformers.PreTrainedTokenizer,
168
+ kmer: int = -1,
169
+ customized_tokenizer = None):
170
+
171
+ super(SupervisedDataset, self).__init__()
172
+
173
+ # load data from the disk
174
+ with open(data_path, "r") as f:
175
+ data = list(csv.reader(f))[1:]
176
+ if len(data[0]) == 2:
177
+ # data is in the format of [text, label]
178
+ logging.warning("Perform single sequence classification...")
179
+ texts = [d[0] for d in data]
180
+ labels = [int(d[1]) for d in data]
181
+ elif len(data[0]) == 3:
182
+ # data is in the format of [text1, text2, label]
183
+ logging.warning("Perform sequence-pair classification...")
184
+ texts = [[d[0], d[1]] for d in data]
185
+ labels = [int(d[2]) for d in data]
186
+ elif len(data[0]) == 5:
187
+ logging.warning("Perform single sequence classification on NucleotideTransformer Benchmarks...")
188
+ texts = [d[4] for d in data]
189
+ labels = [int(d[0]) for d in data]
190
+ else:
191
+ raise ValueError("Data format not supported.")
192
+
193
+ if kmer != -1:
194
+
195
+ logging.warning(f"Using {kmer}-mer as input...")
196
+ texts = load_or_generate_kmer(data_path, texts, kmer)
197
+
198
+ elif kmer == -1 and customized_tokenizer:
199
+ logging.warning(f"Using {customized_tokenizer} as input...")
200
+ texts = load_customized_data(data_path, texts, customized_tokenizer)
201
+
202
+ output = tokenizer(
203
+ texts,
204
+ return_tensors="pt",
205
+ padding="longest",
206
+ max_length=tokenizer.model_max_length,
207
+ truncation=True,
208
+ )
209
+ # print(texts, output["input_ids"])
210
+
211
+ self.input_ids = output["input_ids"]
212
+ self.attention_mask = output["attention_mask"]
213
+ self.labels = labels
214
+ self.num_labels = len(set(labels))
215
+
216
+ def __len__(self):
217
+ return len(self.input_ids)
218
+
219
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
220
+ return dict(input_ids=self.input_ids[i], labels=self.labels[i])
221
+
222
+
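# Note: the dataset above tokenizes the whole split up front with padding="longest";
# the collator below stacks those ids and re-derives the attention mask from the pad
# token, so the attention_mask stored on the dataset goes unused.
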
@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.Tensor(labels).long()
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )


def calculate_metric_with_sklearn(predictions: np.ndarray, labels: np.ndarray):
    """Manually calculate accuracy, f1, matthews_correlation, precision, and recall with sklearn."""
    valid_mask = labels != -100  # exclude positions labeled -100 (the ignore index)
    valid_predictions = predictions[valid_mask]
    valid_labels = labels[valid_mask]
    return {
        "accuracy": sklearn.metrics.accuracy_score(valid_labels, valid_predictions),
        "f1": sklearn.metrics.f1_score(
            valid_labels, valid_predictions, average="macro", zero_division=0
        ),
        "matthews_correlation": sklearn.metrics.matthews_corrcoef(
            valid_labels, valid_predictions
        ),
        "precision": sklearn.metrics.precision_score(
            valid_labels, valid_predictions, average="macro", zero_division=0
        ),
        "recall": sklearn.metrics.recall_score(
            valid_labels, valid_predictions, average="macro", zero_division=0
        ),
    }


# from: https://discuss.huggingface.co/t/cuda-out-of-memory-when-using-trainer-with-compute-metrics/2941/13
def preprocess_logits_for_metrics(logits: Union[torch.Tensor, Tuple[torch.Tensor, Any]], _):
    if isinstance(logits, tuple):  # unpack logits if it's a tuple
        logits = logits[0]

    if logits.ndim == 3:
        # reshape logits to 2D if needed
        logits = logits.reshape(-1, logits.shape[-1])

    return torch.argmax(logits, dim=-1)


def compute_metrics(eval_pred):
    """Compute metrics for the Hugging Face Trainer."""
    predictions, labels = eval_pred
    return calculate_metric_with_sklearn(predictions, labels)

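# The Trainer logs these metrics with an "eval_" prefix (eval_accuracy, eval_f1, ...);
# run_nt.sh selects the best epoch via --metric_for_best_model eval_f1, i.e. the macro F1 above.
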
def load_token_v5_1(tokenizer_kwargs):
    # NOTE: MODEL_CLASSES and MotifTokenizer are project-specific and must be
    # importable in this module for the customized-tokenizer paths to work.
    config_class, model_class, tokenizer_class = MODEL_CLASSES['motifBert']
    tokenizer = MotifTokenizer(**tokenizer_kwargs)

    bases = ['A', 'T', 'C', 'G']

    token_wc = [
        f"{operator}_POS_{i}_*_{char}"
        for operator, i, char in itertools.product(['WC'], range(12), bases)
    ]

    motif_wildcarded = []
    with open(os.path.join('/storage2/fs1/btc/Active/yeli/xiaoxiao.zhou/tokenize/tokenizers/tokenizer_v5.1/hg38_NOOP', "motifs_wildcard.txt"), "r") as file:
        for line in file:
            seq, operations = line.strip().split(maxsplit=1)  # split only on the first space
            motif_wildcarded.append(operations.split()[0])  # keep the first operation token

    tokenizer.add_tokens(token_wc + motif_wildcarded)
    return tokenizer


def load_token_v4(tokenizer_kwargs):
    # NOTE: MODEL_CLASSES and MotifTokenizer are project-specific and must be
    # importable in this module for the customized-tokenizer paths to work.
    config_class, model_class, tokenizer_class = MODEL_CLASSES['motifBert']
    tokenizer = MotifTokenizer(**tokenizer_kwargs)

    bases = ['A', 'T', 'C', 'G']
    token_del = [
        f"{operator}_POS_{i}_{char}"
        for operator, i, char in itertools.product(['DEL'], range(12), bases)
    ]
    token_rep = [
        f"{operator}_POS_{i}_{char1}_{char2}"
        for operator, i, char1, char2 in itertools.product(['SUB'], range(12), bases, bases)
        if char1 != char2
    ]
    token_wc = [
        f"{operator}_POS_{i}_*_{char}"
        for operator, i, char in itertools.product(['WC'], range(12), bases)
    ]
    token_ins = [
        f"{operator}_POS_{i}_{char}"
        for operator, i, char in itertools.product(['INS'], range(13), bases)
    ]

    motif_wildcarded = []
    with open(os.path.join('/storage2/fs1/btc/Active/yeli/xiaoxiao.zhou/tokenize/tokenizers/tokenizer_v4/hg38', "motifs_wildcard.txt"), "r") as file:
        for line in file:
            seq, operations = line.strip().split(maxsplit=1)  # split only on the first space
            motif_wildcarded.append(operations.split()[0])  # keep the first operation token

    tokenizer.add_tokens(token_del + token_rep + token_wc + token_ins + motif_wildcarded)
    return tokenizer

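# Each line of motifs_wildcard.txt is assumed to be "<motif-sequence> <token> ...";
# the loaders above keep only the first token after the motif sequence.
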
def train():
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    wandb.init(
        project=training_args.project_name,
    )

    tokenizer_kwargs = {
        "cache_dir": training_args.cache_dir,
        "model_max_length": training_args.model_max_length,
        "padding_side": "right",
        "use_fast": True,
        "trust_remote_code": model_args.trust_remote_code,  # recommended to keep False unless necessary
    }

    if training_args.vocab_file is not None:
        if not os.path.exists(training_args.vocab_file):
            raise ValueError(f"Vocab file not found at: {training_args.vocab_file}")
        tokenizer_kwargs["vocab_file"] = training_args.vocab_file

    if data_args.customized_tokenizer == 'token_v4':
        tokenizer = load_token_v4(tokenizer_kwargs)
    elif data_args.customized_tokenizer == 'token_v5_1':
        tokenizer = load_token_v5_1(tokenizer_kwargs)
    else:
        tokenizer = transformers.PreTrainedTokenizerFast(
            tokenizer_file=model_args.tokenizer_path,
            **tokenizer_kwargs
        )

    tokenizer.pad_token = "[PAD]"
    tokenizer.unk_token = "[UNK]"
    tokenizer.cls_token = "[CLS]"
    tokenizer.sep_token = "[SEP]"
    tokenizer.mask_token = "[MASK]"

    if "InstaDeepAI" in model_args.model_name_or_path:
        tokenizer.eos_token = tokenizer.pad_token

    # define datasets and data collator
    train_dataset = SupervisedDataset(tokenizer=tokenizer,
                                      data_path=os.path.join(data_args.data_path, "train.csv"),
                                      kmer=data_args.kmer,
                                      customized_tokenizer=data_args.customized_tokenizer)
    val_dataset = SupervisedDataset(tokenizer=tokenizer,
                                    data_path=os.path.join(data_args.data_path, "dev.csv"),
                                    kmer=data_args.kmer,
                                    customized_tokenizer=data_args.customized_tokenizer)
    test_dataset = SupervisedDataset(tokenizer=tokenizer,
                                     data_path=os.path.join(data_args.data_path, "test.csv"),
                                     kmer=data_args.kmer,
                                     customized_tokenizer=data_args.customized_tokenizer)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

    config = transformers.AutoConfig.from_pretrained(
        model_args.model_name_or_path,
        num_labels=train_dataset.num_labels,
        trust_remote_code=model_args.trust_remote_code
    )

    model = transformers.AutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        config=config,  # pass the adjusted config
        trust_remote_code=model_args.trust_remote_code
    ).to("cuda")

    # configure LoRA
    if model_args.use_lora:
        # imported lazily so peft is only required when LoRA is enabled
        from peft import LoraConfig, get_peft_model

        lora_config = LoraConfig(
            r=model_args.lora_r,
            lora_alpha=model_args.lora_alpha,
            target_modules=list(model_args.lora_target_modules.split(",")),
            lora_dropout=model_args.lora_dropout,
            bias="none",
            task_type="SEQ_CLS",
            inference_mode=False,
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

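    # e.g., pass --use_lora True (and optionally --lora_target_modules query,value)
    # to train low-rank adapters instead of the full model.
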
    # define trainer
    trainer = transformers.Trainer(model=model,
                                   tokenizer=tokenizer,
                                   args=training_args,
                                   preprocess_logits_for_metrics=preprocess_logits_for_metrics,
                                   compute_metrics=compute_metrics,
                                   train_dataset=train_dataset,
                                   eval_dataset=val_dataset,
                                   data_collator=data_collator)
    trainer.train()

    if training_args.save_model:
        trainer.save_state()
        safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)

    # get the evaluation results from trainer on the held-out test set
    if training_args.eval_and_save_results:
        results_path = os.path.join(training_args.output_dir, "results", training_args.run_name)
        results = trainer.evaluate(eval_dataset=test_dataset)
        os.makedirs(results_path, exist_ok=True)
        with open(os.path.join(results_path, "eval_results.json"), "w") as f:
            json.dump(results, f)


if __name__ == "__main__":
    train()
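
# Example invocation (mirroring run_nt.sh; paths and values are illustrative):
#   python train.py \
#     --model_name_or_path <pretrained_checkpoint> --tokenizer_path <tokenizer.json> \
#     --data_path <task>/split --kmer -1 --model_max_length 100 \
#     --per_device_train_batch_size 128 --learning_rate 3e-5 --num_train_epochs 3 \
#     --evaluation_strategy epoch --save_strategy epoch --load_best_model_at_end True \
#     --metric_for_best_model eval_f1 --output_dir <out_dir> --project_name <wandb_project>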