# NOTE: removed a web-scrape banner ("Spaces: / Runtime error / Runtime error")
# that was captured with this file and is not part of the program.
import os
from datetime import datetime
from pathlib import Path
from typing import Optional

import torch
import typer
from accelerate import Accelerator
from accelerate.utils import LoggerType
from torch import Tensor
from torch.optim import AdamW

# from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from tqdm import tqdm

from data import MusdbDataset
from splitter import Splitter
# Any non-empty DISABLE_TQDM env value disables progress bars. bool() keeps the
# original truthiness semantics while giving the flag a single, consistent type
# (the original was `str` when set, `bool` when not).
DISABLE_TQDM = bool(os.environ.get("DISABLE_TQDM", False))

# CLI entry point; commands are registered with @app.command().
app = typer.Typer(pretty_exceptions_show_locals=False)
def spectrogram_loss(masked_target: Tensor, original: Tensor) -> Tensor:
    """Mean-squared error between a masked STFT and a reference STFT.

    Args:
        masked_target: STFT produced by applying a net's estimated mask for
            source S to the ground-truth STFT for source S.
        original: STFT of the original input mixture.

    Returns:
        Scalar tensor: the mean of the element-wise squared difference.
    """
    return torch.mean((masked_target - original) ** 2)
@app.command()  # FIX: without registration, `app()` has no commands and errors.
def train(
    dataset: str = "data/musdb18-wav",
    # FIX: typer requires Optional[...] for a None default (was `str = None`).
    output_dir: Optional[str] = None,
    fp16: bool = False,
    cpu: bool = True,
    max_steps: int = 100,
    num_train_epochs: int = 1,
    per_device_train_batch_size: int = 1,
    effective_batch_size: int = 4,
    max_grad_norm: float = 0.0,
) -> None:
    """Train the Splitter source-separation model on MUSDB and save it.

    Args:
        dataset: Root directory of the MUSDB18 wav dataset.
        output_dir: Checkpoint/log directory; defaults to a timestamped
            directory under ``experiments/``.
        fp16: Enable mixed-precision training via accelerate.
        cpu: Force CPU execution.
        max_steps: Stop after this many global steps; values <= 0 mean run
            ``num_train_epochs`` full epochs instead.
        num_train_epochs: Epoch count, used only when ``max_steps <= 0``.
        per_device_train_batch_size: Per-process dataloader batch size
            (currently must be 1).
        effective_batch_size: Target batch size reached through gradient
            accumulation.
        max_grad_norm: Clip gradients to this norm when > 0 (0 disables).

    Raises:
        ValueError: If ``per_device_train_batch_size != 1``.
    """
    if not output_dir:
        now_str = datetime.now().strftime("%Y%m%d-%H%M%S")
        output_dir = f"experiments/{now_str}"
    output_dir = Path(output_dir)
    logging_dir = output_dir / "tracker_logs"
    # NOTE(review): `fp16=` and `logging_dir=` are legacy Accelerator kwargs
    # (newer accelerate uses `mixed_precision=` / `project_dir=`); kept as-is
    # for the pinned accelerate version — confirm before upgrading.
    accelerator = Accelerator(
        fp16=fp16,
        cpu=cpu,
        logging_dir=logging_dir,
        log_with=[LoggerType.TENSORBOARD],
    )
    accelerator.init_trackers(logging_dir / "run")
    train_dataset = MusdbDataset(root=dataset, is_train=True)
    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=per_device_train_batch_size,
    )
    model = Splitter(stem_names=list(train_dataset.targets))
    optimizer = AdamW(
        model.parameters(),
        lr=1e-3,
        eps=1e-8,
    )
    model, optimizer, train_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader
    )
    num_train_steps = (
        max_steps if max_steps > 0 else len(train_dataloader) * num_train_epochs
    )
    accelerator.print(f"Num train steps: {num_train_steps}")
    step_batch_size = per_device_train_batch_size * accelerator.num_processes
    # Reach effective_batch_size by accumulating gradients over several
    # forward/backward passes before each optimizer step.
    gradient_accumulation_steps = max(
        1,
        effective_batch_size // step_batch_size,
    )
    accelerator.print(
        f"Gradient Accumulation Steps: {gradient_accumulation_steps}\n"
        f"Effective Batch Size: {gradient_accumulation_steps * step_batch_size}"
    )
    # FIX: hoisted out of the batch loop (loop-invariant) and raised explicitly
    # instead of `assert`, which is stripped under `python -O`.
    if per_device_train_batch_size != 1:
        raise ValueError("For now limit per_device_train_batch_size to 1.")
    global_step = 0
    while global_step < num_train_steps:
        accelerator.wait_for_everyone()
        model.train()
        batch_iterator = tqdm(
            train_dataloader,
            desc="Batch",
            disable=((not accelerator.is_local_main_process) or DISABLE_TQDM),
        )
        for batch_idx, batch in enumerate(batch_iterator):
            # FIX: honour max_steps mid-epoch; the while-condition alone let
            # training overshoot num_train_steps by up to a full epoch.
            if global_step >= num_train_steps:
                break
            x_wav, y_target_wavs = batch
            predictions = model(x_wav)
            stem_losses = []
            for name, masked_stft in predictions.items():
                # Ground-truth STFT for this stem; squeeze() drops the bs=1 dim.
                target_stft, _ = model.compute_stft(y_target_wavs[name].squeeze())
                loss = spectrogram_loss(
                    masked_target=masked_stft,
                    original=target_stft,
                )
                stem_losses.append(loss)
                accelerator.log({f"train-loss-{name}": 1.0 * loss}, step=global_step)
            # Sum per-stem losses; divide so accumulated grads average out to
            # the effective-batch gradient.
            total_loss = (
                torch.sum(torch.stack(stem_losses)) / gradient_accumulation_steps
            )
            accelerator.print(f"global step: {global_step}\tloss: {total_loss:.4f}")
            accelerator.log({"train-loss": 1.0 * total_loss}, step=global_step)
            accelerator.backward(total_loss)
            if (batch_idx + 1) % gradient_accumulation_steps == 0:
                if max_grad_norm > 0:
                    accelerator.clip_grad_norm_(model.parameters(), max_grad_norm)
                optimizer.step()
                optimizer.zero_grad()
            global_step += 1
    accelerator.wait_for_everyone()
    accelerator.end_training()
    accelerator.print(f"Saving model to {output_dir}...")
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(
        output_dir,
        save_function=accelerator.save,
        state_dict=accelerator.get_state_dict(model),
    )
    accelerator.wait_for_everyone()
    accelerator.print("DONE!")
# Script entry point: dispatch to the typer CLI.
if __name__ == "__main__":
    app()