import itertools
from pathlib import Path
from typing import Dict, Optional

import argbind
import torch
from tensorboardX import SummaryWriter
from torch.optim import AdamW
from transformers import get_cosine_schedule_with_warmup

from voxcpm.model import VoxCPMModel
from voxcpm.model.voxcpm import LoRAConfig
from voxcpm.training import (
    Accelerator,
    BatchProcessor,
    TrainingTracker,
    build_dataloader,
    load_audio_text_datasets,
)


@argbind.bind(without_prefix=True)
def train(
    pretrained_path: str,
    train_manifest: str,
    val_manifest: str = "",
    sample_rate: int = 16_000,
    batch_size: int = 1,
    grad_accum_steps: int = 1,
    num_workers: int = 2,
    num_iters: int = 100_000,
    log_interval: int = 100,
    valid_interval: int = 1_000,
    save_interval: int = 10_000,
    learning_rate: float = 1e-4,
    weight_decay: float = 1e-2,
    warmup_steps: int = 1_000,
    max_steps: int = 100_000,
    max_batch_tokens: int = 0,
    save_path: str = "checkpoints",
    tensorboard: str = "",
    lambdas: Dict[str, float] = {"loss/diff": 1.0, "loss/stop": 1.0},
    lora: Optional[dict] = None,
    config_path: str = "",
):
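    """Fine-tune a pretrained VoxCPM model on audio-text manifests.

    Arguments are exposed on the command line via argbind; ``config_path``
    is consumed in ``__main__`` before dispatch and is unused here.
    """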
    _ = config_path
    accelerator = Accelerator(amp=True)

    save_dir = Path(save_path)
    save_dir.mkdir(parents=True, exist_ok=True)
    tb_dir = Path(tensorboard) if tensorboard else save_dir / "logs"
    tb_dir.mkdir(parents=True, exist_ok=True)

    # Only rank 0 writes TensorBoard events; the tracker handles rank-aware logging.
    writer = SummaryWriter(log_dir=str(tb_dir)) if accelerator.rank == 0 else None
    tracker = TrainingTracker(writer=writer, log_file=str(save_dir / "train.log"), rank=accelerator.rank)
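
    # Load the pretrained model, optionally wrapping it with LoRA adapters.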
    base_model = VoxCPMModel.from_local(
        pretrained_path,
        optimize=False,
        training=True,
        lora_config=LoRAConfig(**lora) if lora else None,
    )
    tokenizer = base_model.text_tokenizer
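
    # Build train/val datasets from the manifests and pre-tokenize the text column.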
    train_ds, val_ds = load_audio_text_datasets(
        train_manifest=train_manifest,
        val_manifest=val_manifest,
        sample_rate=sample_rate,
    )

    def tokenize(batch):
        text_list = batch["text"]
        text_ids = [tokenizer(text) for text in text_list]
        return {"text_ids": text_ids}

    train_ds = train_ds.map(tokenize, batched=True, remove_columns=["text"])
    if val_ds is not None:
        val_ds = val_ds.map(tokenize, batched=True, remove_columns=["text"])

    dataset_cnt = int(max(train_ds["dataset_id"])) + 1 if "dataset_id" in train_ds.column_names else 1
    num_train_samples = len(train_ds)
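
    # Optionally drop samples whose estimated token length would overflow the
    # per-batch token budget (max_batch_tokens).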
    if max_batch_tokens and max_batch_tokens > 0:
        from voxcpm.training.data import compute_sample_lengths

        est_lengths = compute_sample_lengths(
            train_ds,
            audio_vae_fps=25,
            patch_size=base_model.config.patch_size,
        )
        max_sample_len = max_batch_tokens // batch_size if batch_size > 0 else max(est_lengths)
        keep_indices = [i for i, L in enumerate(est_lengths) if L <= max_sample_len]

        if len(keep_indices) < len(train_ds) and accelerator.rank == 0:
            tracker.print(
                f"Filtering {len(train_ds) - len(keep_indices)} / {len(train_ds)} "
                f"training samples longer than {max_sample_len} tokens "
                f"(max_batch_tokens={max_batch_tokens})."
            )
        train_ds = train_ds.select(keep_indices)

    train_loader = build_dataloader(
        train_ds,
        accelerator=accelerator,
        batch_size=batch_size,
        num_workers=num_workers,
        drop_last=True,
    )
    val_loader = (
        build_dataloader(
            val_ds,
            accelerator=accelerator,
            batch_size=batch_size,
            num_workers=num_workers,
            drop_last=False,
        )
        if val_ds is not None
        else None
    )
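
    # Wrap the model for distributed / mixed-precision training; keep an
    # unwrapped handle for direct attribute access.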
    model = accelerator.prepare_model(base_model)
    unwrapped_model = accelerator.unwrap(model)
    unwrapped_model.train()

    batch_processor = BatchProcessor(
        config=unwrapped_model.config,
        audio_vae=unwrapped_model.audio_vae,
        dataset_cnt=dataset_cnt,
        device=accelerator.device,
    )

    # Report which parameters are trainable (all of them for full fine-tuning,
    # only the adapter weights under LoRA).
    if accelerator.rank == 0:
        for name, param in model.named_parameters():
            print(name, param.requires_grad)
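
    # Optimize only the parameters that require gradients.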
    optimizer = AdamW(
        (p for p in model.parameters() if p.requires_grad),
        lr=learning_rate,
        weight_decay=weight_decay,
    )

    # Cosine decay with linear warmup; the schedule horizon defaults to
    # num_iters when max_steps is unset.
    total_training_steps = max_steps if max_steps > 0 else num_iters
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_training_steps,
    )

    # Cycle the dataloader indefinitely; training length is bounded by num_iters.
    train_iter = iter(itertools.cycle(train_loader))
    grad_accum_steps = max(int(grad_accum_steps), 1)
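
    # Main training loop: accumulate gradients over grad_accum_steps
    # micro-batches, then take one optimizer step.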
    with tracker.live():
        for step in range(num_iters):
            tracker.step = step
            optimizer.zero_grad(set_to_none=True)

            loss_dict = {}
            for micro_step in range(grad_accum_steps):
                batch = next(train_iter)
                processed = batch_processor(batch)

                with accelerator.autocast(dtype=torch.bfloat16):
                    outputs = model(
                        processed["text_tokens"],
                        processed["text_mask"],
                        processed["audio_feats"],
                        processed["audio_mask"],
                        processed["loss_mask"],
                        processed["position_ids"],
                        processed["labels"],
                        progress=step / max(1, num_iters),
                    )

                # Weighted sum of loss terms, scaled down for gradient accumulation.
                total_loss = 0.0
                for key, value in outputs.items():
                    if key.startswith("loss/"):
                        weight = lambdas.get(key, 1.0)
                        loss_value = value * weight / grad_accum_steps
                        total_loss = total_loss + loss_value
                        loss_dict[key] = value.detach()

                accelerator.backward(total_loss)

            # Unscale before measuring the gradient norm when a GradScaler is in use.
            scaler = getattr(accelerator, "scaler", None)
            if scaler is not None:
                scaler.unscale_(optimizer)

            # max_norm=1e9 is effectively no clipping; this call only records
            # the gradient norm for logging.
            grad_norm = torch.nn.utils.clip_grad_norm_(unwrapped_model.parameters(), max_norm=1e9)

            accelerator.step(optimizer)
            accelerator.update()
            scheduler.step()

            if step % log_interval == 0:
                loss_values = {k: v.item() if isinstance(v, torch.Tensor) else float(v) for k, v in loss_dict.items()}
                loss_values["lr"] = float(optimizer.param_groups[0]["lr"])

                epoch = (step * grad_accum_steps * batch_size) / max(1, num_train_samples)
                loss_values["epoch"] = float(epoch)
                loss_values["grad_norm"] = float(grad_norm)
                tracker.log_metrics(loss_values, split="train")

            if val_loader is not None and step % valid_interval == 0 and step != 0:
                validate(model, val_loader, batch_processor, accelerator, tracker, lambdas)

            if step % save_interval == 0 and accelerator.rank == 0:
                save_checkpoint(model, optimizer, scheduler, save_dir, step)

    # Final checkpoint after the last iteration.
    if accelerator.rank == 0:
        save_checkpoint(model, optimizer, scheduler, save_dir, num_iters)
    if writer:
        writer.close()


def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas):
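    """Average the weighted loss over up to 10 validation batches."""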
    model.eval()
    losses = []
    with torch.no_grad():
        for batch in itertools.islice(val_loader, 0, 10):
            processed = batch_processor(batch)
            with accelerator.autocast(dtype=torch.bfloat16):
                outputs = model(
                    processed["text_tokens"],
                    processed["text_mask"],
                    processed["audio_feats"],
                    processed["audio_mask"],
                    processed["loss_mask"],
                    processed["position_ids"],
                    processed["labels"],
                    progress=0.0,
                    sample_generate=False,
                )
            total = 0.0
            for key, value in outputs.items():
                if key.startswith("loss/"):
                    total += lambdas.get(key, 1.0) * value
            losses.append(total.detach())
    if losses:
        mean_loss = torch.stack(losses).mean()
        tracker.log_metrics({"loss": mean_loss.item()}, split="val")
    model.train()


def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int):
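    """Save model, optimizer, and scheduler state under save_dir/<tag>/."""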
    save_dir.mkdir(parents=True, exist_ok=True)
    tag = "latest" if step == 0 else f"step_{step:07d}"
    folder = save_dir / tag
    folder.mkdir(parents=True, exist_ok=True)

    # For LoRA runs, persist only the adapter weights; otherwise save the full model.
    unwrapped = model.module if hasattr(model, "module") else model
    full_state = unwrapped.state_dict()
    lora_cfg = unwrapped.lora_config
    if lora_cfg is not None:
        state_dict = {k: v for k, v in full_state.items() if ("lora_A" in k or "lora_B" in k)}
    else:
        state_dict = full_state

    torch.save({"state_dict": state_dict}, folder / "generator.pth")
    torch.save(optimizer.state_dict(), folder / "optimizer.pth")
    torch.save(scheduler.state_dict(), folder / "scheduler.pth")


if __name__ == "__main__":
    from voxcpm.training.config import load_yaml_config

    args = argbind.parse_args()
    config_file = args.get("config_path")

    if config_file:
        # A YAML config, when given, takes precedence over command-line binding.
        yaml_args = load_yaml_config(config_file)
        train(**yaml_args)
    else:
        with argbind.scope(args):
            train()