voidful committed
Commit e7affe4 · verified · 1 Parent(s): b2a8a9e

Training in progress, step 200
model-00003-of-00004.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:94fd27588ea0cf8a350c88bacfecdf465ab06e718fe81d15c5d103708821ef5c
+oid sha256:4c2110db2c404990f87047aec54bead5bd1b1220df35153c4e554f8cea41c78c
 size 4988522632
model-00004-of-00004.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:2c2834790fc01798e9e85481469cc95cb32a7ad752c03f32a7d3129dd4fcefea
+oid sha256:87db4707734910285045cd2b90da00c932ea7b3879f6e19bee3b2fb131e77575
 size 2795955204
train_conv_slurm_full.py ADDED
@@ -0,0 +1,241 @@
+import os
+import logging
+
+import datasets
+from datasets import load_dataset
+import torch
+from torch.utils.data import Dataset
+from transformers import Trainer, TrainingArguments
+import wandb
+
+from mmlm.model_full import MMLMConfig, MMLM
+from mmlm.utility import load_audio_to_tensor
+import numpy as np
+
+# ========================
+# Global Configuration
+# ========================
+WANDB_PROJECT_NAME = "mmlm-conv-full"
+# Read the WandB key from the environment; credentials should never be
+# hardcoded in a committed script.
+WANDB_API_KEY = os.environ.get("WANDB_API_KEY", "")
+DATASET = load_dataset("voidful/all_conv_data_filtered_small")['train']
+# DATASET = datasets.load_from_disk("/mnt/home/ntuspeechlabtaipei1/anthony/Soundon-TTS-preprocessing/hf_dialogue_chinese_llama31_70B_user_long_2_with_silence")
+LM_MODEL_NAME = "voidful/Llama-3.2-8B-Whisper"
+OUTPUT_DIR = "/mnt/home/ntuspeechlabtaipei1/mmlm-conv-training-full"
+MODEL_SAVE_PATH = "/mnt/home/ntuspeechlabtaipei1/mmlm-conv-model-full"
+TRAIN_TEST_SPLIT_RATIO = 0.1  # currently unused: the split below is commented out
+EPOCHS = 300
+BATCH_SIZE = 1
+LEARNING_RATE = 8e-4
+GRADIENT_ACCUMULATION_STEPS = 2
+USE_BF16 = True
+USE_FP16 = False
+LOGGING_STEPS = 1
+SAVE_TOTAL_LIMIT = 10
+GRADIENT_CHECKPOINTING = True
+PAD_VALUE = 0.0
+MAX_LENGTH = 8000  # currently unused in this script
+
+# Set up logging
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+logger = logging.getLogger(__name__)
+
+
+def initialize_wandb():
+    """Initialize Weights and Biases for experiment tracking."""
+    wandb.login(key=WANDB_API_KEY)
+    wandb.init(
+        project=WANDB_PROJECT_NAME,
+        config={
+            "epochs": EPOCHS,
+            "batch_size": BATCH_SIZE,
+            "learning_rate": LEARNING_RATE,
+        },
+        group="mmlm",
+    )
+
+
+class CustomDataset(Dataset):
+    """Custom dataset for paired audio-text conversation data."""
+
+    def __init__(self, data, tokenizer):
+        self.data = data
+        self.tokenizer = tokenizer
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        entry = self.data[idx]
+        # The dataset stores the decoded waveform directly, not a file path
+        audio_array = torch.tensor(entry["user_audio_path"]['array'])
+        audio_tensor = load_audio_to_tensor(audio_array)[0]
+        x_vector = entry["x-vector"]
+        text_with_pad = entry["text_with_pad"]
+        # Prepend a pad marker to the user text; for the machine text, strip
+        # the first 5 characters (presumably a leading "[PAD]") and pad the
+        # end instead
+        user_text_with_pad = "[PAD]" + text_with_pad[0]
+        machine_text_with_pad = text_with_pad[1][5:] + "[PAD]"
+        # Append one unit frame of silence (80 ms at 24 kHz = 1920 samples)
+        audio_tensor = torch.cat([audio_tensor[0], torch.zeros(int(24000 * 0.08 * 1))], dim=0).unsqueeze(dim=0)
+        audio_unit = np.array(entry["machine_unit"])
+
+        # Find spans (in samples) where the machine stream is active, i.e.
+        # where the first codebook row of the machine units is nonzero. One
+        # unit frame covers 80 ms at 24 kHz (24000 * 0.08 = 1920 samples).
+        active_spans = []
+        start = None
+        for i, value in enumerate(audio_unit[0]):
+            if value != 0 and start is None:
+                start = i  # start of a nonzero run
+            elif value == 0 and start is not None:
+                active_spans.append((start * 24000 * 0.08, (i - 1) * 24000 * 0.08))
+                start = None
+
+        # Handle a run that extends to the last frame
+        if start is not None:
+            active_spans.append((start * 24000 * 0.08, (len(audio_unit[0]) - 1) * 24000 * 0.08))
+
+        # Silence the input waveform wherever the machine is speaking
+        for span in active_spans:
+            start, end = int(span[0]), int(span[1])
+            end = min(end, audio_tensor.size(1))
+            audio_tensor[0, start:end] = 0.0
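+
+        # Example (illustrative values, not from the dataset): first-codebook
+        # units [0, 3, 5, 0] yield one active run from frame 1 to frame 2,
+        # recorded as the sample span (1 * 1920, 2 * 1920) and silenced above.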
+
+        padding_token = 0
+        bos_token_id = 0
+        eos_token_id = 0
+
+        # Append one padding frame, then delay every codebook row after the
+        # first by one step so the codebooks are staggered in time.
+        audio_unit = np.hstack((audio_unit, np.zeros((audio_unit.shape[0], 1), dtype=int)))
+        for i in range(1, audio_unit.shape[0]):
+            audio_unit[i, 1:] = audio_unit[i, :-1]
+            audio_unit[i, 0] = padding_token
+
+        # Add BOS/EOS columns and build teacher-forcing input/target pairs
+        # offset by one position.
+        matrix_with_bos = np.hstack((np.full((audio_unit.shape[0], 1), bos_token_id), audio_unit))
+        matrix_with_bos_eos = np.hstack((matrix_with_bos, np.full((matrix_with_bos.shape[0], 1), eos_token_id)))
+        input_audio_unit = matrix_with_bos_eos[:, :-1]
+        target_audio_unit = matrix_with_bos_eos[:, 1:]
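+
+        # Worked example (illustrative values): with two codebooks,
+        #   audio_unit = [[1, 2, 3], [4, 5, 6]]
+        # the padding column plus one-step delay give
+        #   [[1, 2, 3, 0], [0, 4, 5, 6]]
+        # and after the BOS/EOS columns:
+        #   input_audio_unit  = [[0, 1, 2, 3, 0], [0, 0, 4, 5, 6]]
+        #   target_audio_unit = [[1, 2, 3, 0, 0], [0, 4, 5, 6, 0]]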
+
+        return {
+            "input_values": audio_tensor,  # already a tensor; avoid re-wrapping
+            "speaker_codecs": torch.tensor(input_audio_unit),
+            "speaker_codec_labels": torch.tensor(target_audio_unit),
+            "speaker_embs": torch.tensor(x_vector[1]),
+            "speaker_texts": self.tokenizer(machine_text_with_pad, add_special_tokens=False,
+                                            return_tensors="pt")["input_ids"],
+            "listener_texts": self.tokenizer(user_text_with_pad, add_special_tokens=False,
+                                             return_tensors="pt")["input_ids"],
+        }
+
+
+class CustomDataCollator:
+    """Custom data collator for batching audio and text inputs."""
+
+    def __init__(self, text_pad_value, audio_pad_value=PAD_VALUE):
+        self.text_pad_value = text_pad_value
+        self.audio_pad_value = audio_pad_value
+
+    def __call__(self, batch):
+        # Plain concatenation along dim 0 only works because BATCH_SIZE is 1;
+        # larger batches would need real padding with the pad values above.
+        return {
+            "input_values": torch.cat([item["input_values"] for item in batch]),
+            "speaker_codecs": torch.cat([item["speaker_codecs"] for item in batch]),
+            "speaker_codec_labels": torch.cat([item["speaker_codec_labels"] for item in batch]),
+            "speaker_embs": torch.cat([item["speaker_embs"] for item in batch]),
+            "speaker_texts": torch.cat([item["speaker_texts"] for item in batch]),
+            "listener_texts": torch.cat([item["listener_texts"] for item in batch]),
+        }
+
+
+def compute_metrics(pred):
+    """Compute cross-entropy loss as a metric.
+
+    Note: with prediction_loss_only=True and do_eval=False below, the Trainer
+    never actually calls this.
+    """
+    pred_logits = pred.predictions
+    labels = pred.label_ids
+    loss_fn = torch.nn.CrossEntropyLoss()
+    return {"loss": loss_fn(torch.tensor(pred_logits), torch.tensor(labels)).item()}
+
+
+def main():
+    # Initialize WandB on the main process only
+    if int(os.environ.get("LOCAL_RANK", "-1")) == 0:
+        initialize_wandb()
+
+    # Load model and tokenizer
+    config = MMLMConfig(lm_model_name=LM_MODEL_NAME)
+    model = MMLM(config)
+    tokenizer = model.tokenizer
+    logger.info("Model and tokenizer loaded.")
+
+    # Load and filter dataset
+    data = DATASET
+    logger.info(f"Loaded {len(data)} samples from dataset.")
+    data = data.filter(lambda x: x["not_aligned_percentage"] < 0.5)
+    logger.info(f"Filtered dataset to {len(data)} samples.")
+
+    # Shuffle and take a small fixed subset (no train/test split is done here)
+    # data = data.train_test_split(test_size=0.5, seed=42)
+    data = data.shuffle(seed=42)
+    subset_size = 100
+    data = data.select(range(subset_size))
+    train_dataset = CustomDataset(data, tokenizer)
+    # eval_dataset = CustomDataset(data['test'], tokenizer)
+    # train_dataset = CustomDataset(data.select([0, 1, 2, 3, 4]), tokenizer)
+    # eval_dataset = CustomDataset(data.select([0, 1, 2, 3, 4]), tokenizer)
+
+    # Data collator
+    data_collator = CustomDataCollator(tokenizer.pad_token_id)
+
+    # Training arguments (eval is disabled; checkpoints every 200 steps)
+    training_args = TrainingArguments(
+        output_dir=OUTPUT_DIR,
+        eval_strategy="no",
+        logging_strategy="steps",
+        logging_steps=LOGGING_STEPS,
+        save_strategy="steps",
+        save_steps=200,
+        save_total_limit=SAVE_TOTAL_LIMIT,
+        num_train_epochs=EPOCHS,
+        per_device_train_batch_size=BATCH_SIZE,
+        per_device_eval_batch_size=BATCH_SIZE,
+        learning_rate=LEARNING_RATE,
+        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
+        bf16=USE_BF16,
+        fp16=USE_FP16,
+        do_eval=False,
+        max_grad_norm=1.0,
+        report_to="wandb",
+        lr_scheduler_type="linear",
+        warmup_steps=100,
+        eval_accumulation_steps=1,
+        run_name=f"{WANDB_PROJECT_NAME}-training",
+        load_best_model_at_end=False,
+        gradient_checkpointing=GRADIENT_CHECKPOINTING,
+        label_names=["listener_text_labels", "speaker_text_labels"],
+        prediction_loss_only=True,
+        remove_unused_columns=False,
+        push_to_hub=True,
+    )
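+
+    # With the accompanying Slurm script (13 nodes x 8 GPUs = 104 processes),
+    # the effective batch is 1 (per device) x 2 (grad accum) x 104 = 208
+    # sequences per optimizer step.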
+
+    # Initialize Trainer
+    trainer = Trainer(
+        model=model,
+        processing_class=tokenizer,
+        args=training_args,
+        train_dataset=train_dataset,
+        data_collator=data_collator,
+        compute_metrics=compute_metrics,
+    )
+
+    # Train (optionally pass resume_from_checkpoint to continue a run)
+    # resume_from_checkpoint = '/mnt/home/ntuspeechlabtaipei1/mmlm-conv-training-fixed-10k/checkpoint-2000/'
+    trainer.train()
+
+    # Save model
+    trainer.save_model(MODEL_SAVE_PATH)
+    logger.info(f"Model and tokenizer saved to '{MODEL_SAVE_PATH}'.")
+
+    # Finalize WandB
+    wandb.finish()
+
+
+if __name__ == "__main__":
+    main()
train_conv_slurm_full.sh ADDED
@@ -0,0 +1,63 @@
+#!/bin/bash
+#SBATCH -N 13
+#SBATCH -p tp1-user
+#SBATCH --exclusive
+#SBATCH --ntasks-per-node=1
+#SBATCH --cpus-per-task=200
+#SBATCH --mem=200G
+#SBATCH --gres=gpu:8
+#SBATCH --time=30-00:00:00
+#SBATCH --output=/mnt/home/ntuspeechlabtaipei1/eric/result/%j-slurm.out
+#SBATCH --exclude=cnode3-004,cnode3-019
+
+module purge
+module load slurm
+
+source /mnt/home/ntuspeechlabtaipei1/miniconda3/etc/profile.d/conda.sh
+conda activate base
+
+CONTAINER_IMAGE="./eric/trl.sqsh"
+GPUS_PER_NODE=8
+echo "SLURM_NNODES=${SLURM_NNODES}"
+echo "NODELIST=${SLURM_JOB_NODELIST}"
+echo "SLURM_NODEID=${SLURM_NODEID}"
+echo "SLURM_ARRAY_TASK_ID=${SLURM_ARRAY_TASK_ID}"
+export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
+export MASTER_PORT=12345
+export TORCH_NCCL_ASYNC_ERROR_HANDLING=1
+# CUDA_LAUNCH_BLOCKING serializes kernel launches; useful for debugging but
+# it slows training and can be dropped for production runs.
+export CUDA_LAUNCH_BLOCKING=1
+
+export LD_LIBRARY_PATH=/mnt/home/ntuspeechlabtaipei1/miniconda3/lib64:/mnt/home/ntuspeechlabtaipei1/local/lib:/mnt/home/ntuspeechlabtaipei1/miniconda3/envs/whisper/lib:/usr/local/cuda/lib64:/usr/local/cuda/compat/lib.real:/usr/local/lib/python3.10/dist-packages/torch/lib:/usr/local/lib/python3.10/dist-packages/torch_tensorrt/lib:/usr/local/cuda/compat/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
+
+SRUN_ARGS=" \
+    --wait=60 \
+    --kill-on-bad-exit=1 \
+    --mpi=pmix \
+    --container-image=${CONTAINER_IMAGE} \
+    --container-writable \
+    --container-mounts=/mnt/home/ntuspeechlabtaipei1/:/mnt/home/ntuspeechlabtaipei1/,/mnt/home/ntuspeechlabtaipei1/.cache:/root/.cache \
+    "
+
+PRE_LAUNCH="export TORCH_DISTRIBUTED_TIMEOUT=7200; source /mnt/home/ntuspeechlabtaipei1/miniconda3/etc/profile.d/conda.sh; conda activate base;"
+
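+# \${SLURM_NODEID} below is escaped so that it expands on each node inside the
+# srun `bash -c` shell, giving every machine its own rank; --num_processes
+# resolves to 13 nodes x 8 GPUs = 104.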
+LAUNCHER="accelerate launch \
+    --num_processes $((SLURM_NNODES * GPUS_PER_NODE)) \
+    --num_machines $SLURM_NNODES \
+    --machine_rank \${SLURM_NODEID} \
+    --rdzv_backend c10d \
+    --main_process_ip $MASTER_ADDR \
+    --main_process_port $MASTER_PORT \
+    --deepspeed_config_file /mnt/home/ntuspeechlabtaipei1/ds_config.json \
+    --deepspeed_hostfile /mnt/home/ntuspeechlabtaipei1/eric/hostfile \
+    --deepspeed_multinode_launcher standard \
+    --dynamo_backend no \
+    --use_deepspeed \
+    --mixed_precision bf16 \
+    "
+
+CMD="/mnt/home/ntuspeechlabtaipei1/train_conv_slurm_full.py"
+
+clear; srun $SRUN_ARGS bash -c "$PRE_LAUNCH$LAUNCHER $CMD"
+echo "END TIME: $(date)"
+
training_args.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:d9701fce783f2c198703e90b589d200d16c6ffa646d130d05bc0cdc13adb039e
+oid sha256:3caa3fd1c46e285f96aeb9d09d1de977f5879a40bd5efb283b0bbc50d1873349
 size 7672