# Source: mmlm-conv-training-full / train_conv_slurm_full.py
# (revision e7affe4, commit message: "Training in progress, step 200")
import os
import logging
import datasets
from datasets import load_dataset
import torch
from torch.utils.data import Dataset
from transformers import Trainer, TrainingArguments
import wandb
from mmlm.model_full import MMLMConfig, MMLM
from mmlm.utility import load_audio_to_tensor
import numpy as np
# ========================
# Global Configuration
# ========================
WANDB_PROJECT_NAME = "mmlm-conv-full"
# Credentials must not live in source control: read the key from the
# environment. When it is unset this is None, and `wandb.login(key=None)`
# falls back to wandb's own credential lookup (env var / netrc / prompt).
# NOTE(review): a key was previously committed on this line; it should be
# revoked in the W&B account settings.
WANDB_API_KEY = os.environ.get("WANDB_API_KEY")

# NOTE: the dataset is loaded at import time, so importing this module
# triggers a Hub download / cache lookup.
DATASET = load_dataset("voidful/all_conv_data_filtered_small")['train']
# Alternative local dataset:
# DATASET = datasets.load_from_disk("/mnt/home/ntuspeechlabtaipei1/anthony/Soundon-TTS-preprocessing/hf_dialogue_chinese_llama31_70B_user_long_2_with_silence")
LM_MODEL_NAME = "voidful/Llama-3.2-8B-Whisper"
OUTPUT_DIR = "/mnt/home/ntuspeechlabtaipei1/mmlm-conv-training-full"
MODEL_SAVE_PATH = "/mnt/home/ntuspeechlabtaipei1/mmlm-conv-model-full"

# Training hyper-parameters.
TRAIN_TEST_SPLIT_RATIO = 0.1
EPOCHS = 300
BATCH_SIZE = 1
LEARNING_RATE = 8e-4
GRADIENT_ACCUMULATION_STEPS = 2
USE_BF16 = True
USE_FP16 = False
LOGGING_STEPS = 1
SAVE_TOTAL_LIMIT = 10
GRADIENT_CHECKPOINTING = True
PAD_VALUE = 0.0
MAX_LENGTH = 8000

# Setup logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
def initialize_wandb():
    """Log in to Weights & Biases and start the experiment run.

    The run is created in project ``WANDB_PROJECT_NAME`` under the group
    ``"mmlm"`` and records the epoch count, batch size and learning rate
    taken from the module-level configuration.
    """
    run_config = dict(
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        learning_rate=LEARNING_RATE,
    )
    wandb.login(key=WANDB_API_KEY)
    wandb.init(project=WANDB_PROJECT_NAME, group="mmlm", config=run_config)
class CustomDataset(Dataset):
    """Map-style dataset turning one dialogue row into model-ready tensors.

    Each row of ``data`` is expected to provide the keys read below:
    ``"user_audio_path"`` (a mapping with an ``"array"`` of samples),
    ``"x-vector"``, ``"text_with_pad"`` (index 0 is used for the listener
    text, index 1 for the speaker text) and ``"machine_unit"`` (a 2-D grid
    of codec units).

    Args:
        data: Indexable collection of rows (``data[idx]`` returns a mapping).
        tokenizer: Callable tokenizer returning a mapping with ``"input_ids"``.
    """

    # Frame geometry used to turn codec frame indices into sample offsets.
    # Presumably 24 kHz audio with 80 ms codec frames (1920 samples per
    # frame) -- TODO confirm against the codec configuration.
    SAMPLE_RATE = 24000
    FRAME_SECONDS = 0.08
    SAMPLES_PER_FRAME = int(SAMPLE_RATE * FRAME_SECONDS)

    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _nonzero_frame_spans(frames):
        """Return ``(first, last)`` frame indices of each run of non-zero units."""
        active = np.asarray(frames) != 0
        # Pad with False on both sides so runs touching either edge still
        # produce a rising and a falling edge.
        padded = np.concatenate(([False], active, [False]))
        edges = np.flatnonzero(padded[1:] != padded[:-1])
        # Even positions are run starts, odd positions are the index of the
        # first zero after the run (or len(frames) for a trailing run).
        return [(int(first), int(stop) - 1) for first, stop in zip(edges[0::2], edges[1::2])]

    @staticmethod
    def _silence_spans(audio, frame_spans, samples_per_frame):
        """Zero ``audio[0, first*spf : last*spf]`` in place for every span."""
        total = audio.size(1)
        for first, last in frame_spans:
            start = first * samples_per_frame
            end = min(last * samples_per_frame, total)
            # Scalar assignment is a no-op for empty or out-of-range slices.
            # Building a zeros tensor of size ``end - start`` raised when a
            # span began beyond the end of the audio.
            audio[0, start:end] = 0.0
        return audio

    @staticmethod
    def _shifted_codec_pair(audio_unit, padding_token=0, bos_token_id=0, eos_token_id=0):
        """Build the (input, target) codec grids used for next-unit prediction."""
        units = np.asarray(audio_unit)
        rows = units.shape[0]
        # One extra trailing column makes room for the one-step delay below.
        units = np.hstack((units, np.zeros((rows, 1), dtype=int)))
        # Every row except the first is delayed by one step and left-padded.
        units[1:, 1:] = units[1:, :-1].copy()
        units[1:, 0] = padding_token
        with_bos = np.hstack((np.full((rows, 1), bos_token_id), units))
        with_bos_eos = np.hstack((with_bos, np.full((rows, 1), eos_token_id)))
        # Inputs drop the last column, targets drop the first.
        return with_bos_eos[:, :-1], with_bos_eos[:, 1:]

    def __getitem__(self, idx):
        """Return the tensor dictionary for row ``idx``.

        Returns:
            dict with keys ``input_values``, ``speaker_codecs``,
            ``speaker_codec_labels``, ``speaker_embs``, ``speaker_texts``
            and ``listener_texts``.
        """
        # Fetch the row once; indexing the dataset again for every field may
        # re-materialise (and re-decode) the whole row each time.
        entry = self.data[idx]

        waveform = torch.tensor(entry["user_audio_path"]['array'])
        # NOTE(review): `load_audio_to_tensor` is a project helper; only its
        # `[0]` element is used here -- confirm what it returns.
        audio_tensor = load_audio_to_tensor(waveform)[0]
        # Keep the first row, append one frame of silence, restore the
        # leading dimension.
        trailing_silence = torch.zeros(self.SAMPLES_PER_FRAME)
        audio_tensor = torch.cat([audio_tensor[0], trailing_silence], dim=0).unsqueeze(dim=0)

        x_vector = entry["x-vector"]
        text_with_pad = entry["text_with_pad"]
        user_text_with_pad = "[PAD]" + text_with_pad[0]
        # Drops the first five characters of the speaker text -- presumably
        # a leading "[PAD]" marker; TODO confirm.
        machine_text_with_pad = text_with_pad[1][5:] + "[PAD]"

        audio_unit = np.array(entry["machine_unit"])
        # Silence the listener audio wherever the first codec row is
        # non-zero. NOTE(review): earlier comments called these "zero
        # sequences", but the code has always selected non-zero runs. The
        # muted range also stops at the start of each run's last frame.
        # Both behaviours are preserved; confirm they are intended.
        spans = self._nonzero_frame_spans(audio_unit[0])
        self._silence_spans(audio_tensor, spans, self.SAMPLES_PER_FRAME)

        input_audio_unit, target_audio_unit = self._shifted_codec_pair(audio_unit)

        return {
            # clone().detach() is the supported way to copy an existing
            # tensor; torch.tensor(tensor) emits a UserWarning.
            "input_values": audio_tensor.clone().detach(),
            "speaker_codecs": torch.tensor(input_audio_unit),
            "speaker_codec_labels": torch.tensor(target_audio_unit),
            "speaker_embs": torch.tensor(x_vector[1]),
            "speaker_texts": self.tokenizer(
                machine_text_with_pad, add_special_tokens=False, return_tensors="pt"
            )["input_ids"],
            "listener_texts": self.tokenizer(
                user_text_with_pad, add_special_tokens=False, return_tensors="pt"
            )["input_ids"],
        }
class CustomDataCollator:
    """Collate dataset items by concatenating each field along dim 0.

    No padding is applied: ``text_pad_value`` and ``audio_pad_value`` are
    stored on the instance but not read by ``__call__``. ``torch.cat``
    therefore requires the items of a batch to agree on every dimension
    except the first.
    """

    # Fields copied from every item into the batch, in output order.
    _FIELDS = (
        "input_values",
        "speaker_codecs",
        "speaker_codec_labels",
        "speaker_embs",
        "speaker_texts",
        "listener_texts",
    )

    def __init__(self, text_pad_value, audio_pad_value=PAD_VALUE):
        self.text_pad_value = text_pad_value
        self.audio_pad_value = audio_pad_value

    def __call__(self, batch):
        """Return a dict mapping each field to the concatenation over ``batch``."""
        return {
            field: torch.cat([sample[field] for sample in batch])
            for field in self._FIELDS
        }
def compute_metrics(pred):
    """Return the cross-entropy between predictions and labels.

    Args:
        pred: Object exposing ``predictions`` (logits) and ``label_ids``.

    Returns:
        dict with a single ``"loss"`` entry holding a Python float.
    """
    criterion = torch.nn.CrossEntropyLoss()
    logits = torch.tensor(pred.predictions)
    targets = torch.tensor(pred.label_ids)
    loss_value = criterion(logits, targets)
    return {"loss": loss_value.item()}
def main():
    """Train the MMLM model on a small shuffled subset and save the result.

    Steps: configure wandb on the main process, build the model, filter and
    subsample ``DATASET``, run the Hugging Face ``Trainer`` and save the
    model to ``MODEL_SAVE_PATH``.
    """
    # LOCAL_RANK is unset (-1) in a plain single-process launch and 0 on the
    # first process of a distributed launch. Both must configure the wandb
    # run, otherwise the Trainer's wandb callback creates one without the
    # project and group set in `initialize_wandb`.
    local_rank = int(os.environ.get("LOCAL_RANK", "-1"))
    if local_rank in (-1, 0):
        initialize_wandb()

    # Load model and tokenizer
    config = MMLMConfig(lm_model_name=LM_MODEL_NAME)
    model = MMLM(config)
    tokenizer = model.tokenizer
    logger.info("Model and tokenizer loaded.")

    # Load dataset and drop poorly aligned samples
    data = DATASET
    logger.info("Loaded %d samples from dataset.", len(data))
    data = data.filter(lambda x: x["not_aligned_percentage"] < 0.5)
    logger.info("Filtered dataset to %d samples.", len(data))

    # Train on a fixed-size shuffled subset. Clamp the size so `select` does
    # not raise IndexError when fewer rows survive the filter.
    data = data.shuffle(seed=42)
    requested_subset_size = 100
    subset_size = min(requested_subset_size, len(data))
    if subset_size < requested_subset_size:
        logger.warning(
            "Only %d samples available; requested subset of %d.",
            subset_size,
            requested_subset_size,
        )
    data = data.select(range(subset_size))
    train_dataset = CustomDataset(data, tokenizer)

    # Data collator
    data_collator = CustomDataCollator(tokenizer.pad_token_id)

    # Define training arguments
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        # `eval_strategy` replaces the deprecated `evaluation_strategy`; it
        # exists in every transformers release that also accepts the
        # `processing_class` argument used below.
        eval_strategy="no",
        logging_strategy="steps",
        logging_steps=LOGGING_STEPS,
        save_strategy="steps",
        save_steps=200,
        save_total_limit=SAVE_TOTAL_LIMIT,
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
        bf16=USE_BF16,
        fp16=USE_FP16,
        do_eval=False,
        max_grad_norm=1,
        report_to="wandb",
        lr_scheduler_type="linear",
        warmup_steps=100,
        eval_accumulation_steps=1,
        run_name=f"{WANDB_PROJECT_NAME}-training",
        load_best_model_at_end=False,
        gradient_checkpointing=GRADIENT_CHECKPOINTING,
        # NOTE(review): the collator does not emit these keys; they only
        # matter if evaluation is enabled -- confirm before turning it on.
        label_names=["listener_text_labels", "speaker_text_labels"],
        prediction_loss_only=True,
        # The model consumes the collator's custom keys directly.
        remove_unused_columns=False,
        push_to_hub=True,
    )

    # Initialize Trainer
    trainer = Trainer(
        model=model,
        processing_class=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    # Train the model. To resume, pass `resume_from_checkpoint=<path>`.
    trainer.train()

    # Save model
    trainer.save_model(MODEL_SAVE_PATH)
    logger.info("Model and tokenizer saved to '%s'.", MODEL_SAVE_PATH)

    # Finalize WandB (a no-op on processes without an active run)
    wandb.finish()
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()