import os

import torch
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup

from utils.dataset import BDTtsDataset
from inference import tts  # reuse the already-loaded model from the inference module

training_config = {
    "learning_rate": 1e-4,
    "batch_size": 16,
    "warmup_steps": 1000,
    "gradient_accumulation_steps": 4,  # defined here but not used by the minimal loop below
    "mixed_precision": True,
    "save_strategy": "steps",
    "save_steps": 500,
    "eval_steps": 100,  # defined here; no eval loop is implemented below
    "num_epochs": 5,
}


def train():
    dataset = BDTtsDataset("./data/train")
    dataloader = DataLoader(dataset, batch_size=training_config["batch_size"], shuffle=True)

    # transformers.AdamW is deprecated; use the PyTorch implementation instead.
    optimizer = torch.optim.AdamW(tts.model.parameters(), lr=training_config["learning_rate"])
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=training_config["warmup_steps"],
        num_training_steps=len(dataloader) * training_config["num_epochs"],
    )

    # GradScaler is only needed when mixed precision is enabled.
    scaler = torch.cuda.amp.GradScaler() if training_config["mixed_precision"] else None

    os.makedirs("checkpoints", exist_ok=True)
    device = next(tts.model.parameters()).device
    tts.model.train()

    step = 0
    for epoch in range(training_config["num_epochs"]):
        for batch in dataloader:
            # Assumes the dataset yields (inputs, targets) tensor pairs.
            inputs, targets = batch
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            with torch.cuda.amp.autocast(enabled=scaler is not None):
                outputs = tts.model(inputs)
                # Prefer the model's own loss if it returns one; otherwise fall back to MSE.
                loss = (
                    outputs.loss
                    if hasattr(outputs, "loss")
                    else torch.nn.functional.mse_loss(outputs, targets)
                )

            if scaler:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()
            scheduler.step()

            step += 1
            if step % training_config["save_steps"] == 0:
                torch.save(tts.model.state_dict(), f"checkpoints/model_step{step}.pth")
                print(f"Saved checkpoint at step {step}")


if __name__ == "__main__":
    train()
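

# ---------------------------------------------------------------------------
# Illustrative sketch only: training_config sets "gradient_accumulation_steps",
# but the train() loop above steps the optimizer on every batch. The function
# below is a minimal, hypothetical example of how accumulation could be wired
# in, omitting mixed precision, the LR scheduler, and checkpointing to keep it
# short. It assumes the same (inputs, targets) tensor batches as train().
# ---------------------------------------------------------------------------
def train_with_grad_accum():
    accum_steps = training_config["gradient_accumulation_steps"]
    dataloader = DataLoader(
        BDTtsDataset("./data/train"),
        batch_size=training_config["batch_size"],
        shuffle=True,
    )
    optimizer = torch.optim.AdamW(tts.model.parameters(), lr=training_config["learning_rate"])
    device = next(tts.model.parameters()).device
    tts.model.train()

    optimizer.zero_grad()
    for step, (inputs, targets) in enumerate(dataloader, start=1):
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = tts.model(inputs)
        loss = (
            outputs.loss
            if hasattr(outputs, "loss")
            else torch.nn.functional.mse_loss(outputs, targets)
        )
        # Divide by accum_steps so the accumulated gradient averages over the window.
        (loss / accum_steps).backward()
        if step % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()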