import math
import os.path
from collections import defaultdict
from itertools import chain, combinations
from pprint import pprint
from typing import Any, Dict, Iterator, Mapping, Optional, Tuple, Type, TypedDict

import pytorch_lightning as pl
import torch
import torchaudio as ta
import torchmetrics as tm
from torch import nn, optim
from torch.nn import functional as F
from torch.optim import lr_scheduler
from torch.optim.lr_scheduler import LRScheduler
from tqdm import tqdm

from core.types import BatchedInputOutput, OperationMode, RawInputType, SimpleishNamespace
from core.types import (
    InputType,
    OutputType,
    LossOutputType,
    MetricOutputType,
    ModelType,
    OptimizerType,
    SchedulerType,
    MetricType,
    LossType,
    OptimizationBundle,
    LossHandler,
    MetricHandler,
    AugmentationHandler,
    InferenceHandler,
)


class EndToEndLightningSystem(pl.LightningModule):
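    """End-to-end PyTorch Lightning system for source separation.

    Wires a separation model together with its loss, metric, augmentation,
    and inference handlers, and performs chunked overlap-add inference on
    full-length audio at test/predict time.
    """
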
    def __init__(
        self,
        model: ModelType,
        loss_handler: LossHandler,
        metrics: MetricHandler,
        augmentation_handler: AugmentationHandler,
        inference_handler: InferenceHandler,
        optimization_bundle: OptimizationBundle,
        fast_run: bool = False,
        commitment_weight: float = 1.0,
        batch_size: Optional[int] = None,
        effective_batch_size: Optional[int] = None,
    ) -> None:
        super().__init__()

        self.model = model
        self.loss = loss_handler
        self.metrics = metrics
        self.optimization = optimization_bundle
        self.augmentation = augmentation_handler
        self.inference = inference_handler

        self.fast_run = fast_run
        self.model.fast_run = fast_run

        self.commitment_weight = commitment_weight

        self.batch_size = batch_size
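        # `effective_batch_size` is the per-optimizer-step batch size when
        # gradient accumulation is used; `accum_ratio` tells
        # `on_train_batch_end` how many micro-batches make up one step.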
        self.effective_batch_size = (
            effective_batch_size if effective_batch_size is not None else batch_size
        )
        self.accum_ratio = (
            self.effective_batch_size // self.batch_size
            if self.effective_batch_size is not None
            else 1
        )

        self.output_dir = None
        self.split_size = None

    def configure_optimizers(self) -> Any:
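        # Build the optimizer (and optional LR scheduler) from the
        # class/kwargs pairs stored in the optimization bundle.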
        optimizer = self.optimization.optimizer.cls(
            self.model.parameters(),
            **self.optimization.optimizer.kwargs,
        )

        ret = {"optimizer": optimizer}

        if self.optimization.scheduler is not None:
            scheduler = self.optimization.scheduler.cls(
                optimizer,
                **self.optimization.scheduler.kwargs,
            )
            ret["lr_scheduler"] = scheduler

        return ret

    def compute_loss(
        self,
        batch: BatchedInputOutput,
        mode: OperationMode = OperationMode.TRAIN,
    ) -> LossOutputType:
        loss_dict = self.loss(batch)
        return loss_dict

    def update_metrics(
        self,
        batch: BatchedInputOutput,
        mode: OperationMode = OperationMode.TRAIN,
    ) -> None:
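        # Only stems that the model produced estimates for in this batch
        # contribute to the metric update.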
        metrics: MetricType = self.metrics.get_mode(mode)

        for stem, metric in metrics.items():
            if stem not in batch.estimates:
                continue
            metric.update(batch)

    def compute_metrics(self, mode: OperationMode) -> MetricOutputType:
        metrics: MetricType = self.metrics.get_mode(mode)

        metric_dict = {}
        for stem, metric in metrics.items():
            md = metric.compute()
            metric_dict.update({f"{stem}/{k}": v for k, v in md.items()})

        self.log_dict(metric_dict, prog_bar=True, logger=False)

        return metric_dict

    def reset_metrics(self, mode: OperationMode) -> None:
        metrics: MetricType = self.metrics.get_mode(mode)

        for metric in metrics.values():
            metric.reset()

    def forward(self, batch: RawInputType) -> BatchedInputOutput:
        batch = self.model(batch)
        return batch

    def common_step(
        self, batch: RawInputType, mode: OperationMode, batch_idx: int = -1
    ) -> LossOutputType:
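        # Shared train/val path: wrap the raw batch, run the model, compute
        # the loss, and (unless fast_run) accumulate metrics without grads.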
        batch = BatchedInputOutput.from_dict(batch)
        batch = self.forward(batch)

        loss_dict = self.compute_loss(batch, mode=mode)

        if not self.fast_run:
            with torch.no_grad():
                self.update_metrics(batch, mode=mode)

        return loss_dict

    def training_step(self, batch: RawInputType, batch_idx: int) -> LossOutputType:
        self.model.train()

        loss_dict = self.common_step(batch, mode=OperationMode.TRAIN, batch_idx=batch_idx)

        self.log_dict_with_prefix(loss_dict, prefix=OperationMode.TRAIN, prog_bar=True)

        return loss_dict

    def on_train_batch_end(
        self, outputs: OutputType, batch: RawInputType, batch_idx: int
    ) -> None:
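        # Flush training metrics only on gradient-accumulation boundaries so
        # the logged values reflect full effective batches.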
        if self.fast_run:
            return

        if (batch_idx + 1) % self.accum_ratio == 0:
            metric_dict = self.compute_metrics(mode=OperationMode.TRAIN)
            self.log_dict_with_prefix(metric_dict, prefix=OperationMode.TRAIN)
            self.reset_metrics(mode=OperationMode.TRAIN)

    @torch.inference_mode()
    def validation_step(
        self, batch: RawInputType, batch_idx: int, dataloader_idx: int = 0
    ) -> Dict[str, Any]:
        self.model.eval()

        # The `torch.inference_mode` decorator already disables autograd for
        # this step, so no inner context manager is needed.
        loss_dict = self.common_step(batch, mode=OperationMode.VAL)

        self.log_dict_with_prefix(loss_dict, prefix=OperationMode.VAL)

        return loss_dict

    def on_validation_epoch_start(self) -> None:
        self.reset_metrics(mode=OperationMode.VAL)

    def on_validation_epoch_end(self) -> None:
        if self.fast_run:
            return

        metric_dict = self.compute_metrics(mode=OperationMode.VAL)
        self.log_dict_with_prefix(
            metric_dict, OperationMode.VAL, prog_bar=True, add_dataloader_idx=False
        )
        self.reset_metrics(mode=OperationMode.VAL)

    def save_to_audio(self, batch: BatchedInputOutput, batch_idx: int) -> None:
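        # Writes the single estimated target stem of a one-item batch to
        # <log_dir>/audio/<song_id>/<stem>.wav.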
        batch_size = batch["mixture"]["audio"].shape[0]
        assert batch_size == 1, "Batch size must be 1 for inference"

        metadata = batch.metadata
        song_id = metadata["mix"][0]
        stem = metadata["stem"][0]

        log_dir = os.path.join(self.logger.log_dir, "audio")
        os.makedirs(os.path.join(log_dir, song_id), exist_ok=True)

        audio = batch.estimates[stem]["audio"].squeeze(0).cpu()

        audio_path = os.path.join(log_dir, song_id, f"{stem}.wav")
        ta.save(audio_path, audio, self.inference.fs)

    def save_vdbo_to_audio(self, batch: BatchedInputOutput, batch_idx: int) -> None:
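        # Writes every estimated stem of a one-item batch to disk ("vdbo"
        # presumably stands for vocals/drums/bass/other).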
        batch_size = batch["mixture"]["audio"].shape[0]
        assert batch_size == 1, "Batch size must be 1 for inference"

        metadata = batch.metadata
        song_id = metadata["song_id"][0]

        log_dir = os.path.join(self.logger.log_dir, "audio")
        os.makedirs(os.path.join(log_dir, song_id), exist_ok=True)

        for stem, estimate in batch.estimates.items():
            audio = estimate["audio"].squeeze(0).cpu()
            audio_path = os.path.join(log_dir, song_id, f"{stem}.wav")
            ta.save(audio_path, audio, self.inference.fs)

    @torch.inference_mode()
    def chunked_inference(
        self, batch: RawInputType, batch_idx: int = -1, dataloader_idx: int = 0
    ) -> BatchedInputOutput:
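        # Overlap-add inference: split a full-length mixture into overlapping
        # chunks, run the model on mini-batches of chunks, window the
        # outputs, and fold them back into a single waveform.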
        batch = BatchedInputOutput.from_dict(batch)

        audio = batch["mixture"]["audio"]
        b, c, n_samples = audio.shape
        assert b == 1

        fs = self.inference.fs
        chunk_size = int(self.inference.chunk_size_seconds * fs)
        hop_size = int(self.inference.hop_size_seconds * fs)
        batch_size = self.inference.batch_size

        overlap = chunk_size - hop_size
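        # With a periodic Hann window and a hop that divides the chunk size,
        # the overlapping windows sum to chunk_size / (2 * hop_size);
        # `scaler` removes that constant gain after overlap-add.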
        scaler = chunk_size / (2 * hop_size)

        n_chunks = int(math.ceil(
            (n_samples + 4 * overlap - chunk_size) / hop_size
        )) + 1
        pad = (n_chunks - 1) * hop_size + chunk_size - n_samples

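        # Reflect-pad 2 * overlap samples on each side so edge samples get
        # full window coverage, plus `pad` so the length divides evenly
        # into hops.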
        audio = F.pad(
            audio,
            pad=(2 * overlap, 2 * overlap + pad),
            mode="reflect",
        )
        padded_length = audio.shape[-1]
        audio = audio.reshape(c, 1, -1, 1)

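        # Slice the padded signal into overlapping chunks with unfold; the
        # result is reshaped to (n_chunks, channels, chunk_size).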
        chunked_audio = F.unfold(
            audio,
            kernel_size=(chunk_size, 1),
            stride=(hop_size, 1),
        )
        chunked_audio = chunked_audio.permute(2, 0, 1).reshape(-1, c, chunk_size)

        n_chunks = chunked_audio.shape[0]
        n_batch = math.ceil(n_chunks / batch_size)

        outputs = []
        for i in tqdm(range(n_batch)):
            start = i * batch_size
            end = min((i + 1) * batch_size, n_chunks)

            chunked_batch = SimpleishNamespace(
                mixture={"audio": chunked_audio[start:end]},
                query=batch["query"],
                estimates=batch["estimates"],
            )

            output = self.forward(chunked_batch)
            outputs.append(output.estimates["target"]["audio"])

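        # Overlap-add reconstruction: window each chunk, fold back to the
        # padded length, then crop away the reflection padding.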
        output = torch.cat(outputs, dim=0)
        window = torch.hann_window(chunk_size, device=self.device).reshape(1, 1, chunk_size)
        output = output * window / scaler

        output = torch.permute(output, (1, 2, 0))
        output = F.fold(
            output,
            output_size=(padded_length, 1),
            kernel_size=(chunk_size, 1),
            stride=(hop_size, 1),
        )

        output = output[None, :, 0, 2 * overlap : n_samples + 2 * overlap, 0]

        stem = batch.metadata["stem"][0]
        batch["estimates"][stem] = {"audio": output}

        return batch

    @torch.inference_mode()
    def chunked_vdbo_inference(
        self, batch: RawInputType, batch_idx: int = -1, dataloader_idx: int = 0
    ) -> BatchedInputOutput:
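        # Same chunking and overlap-add scheme as `chunked_inference`, but
        # without a query input and producing estimates for all stems.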
        batch = BatchedInputOutput.from_dict(batch)

        audio = batch["mixture"]["audio"]
        b, c, n_samples = audio.shape
        assert b == 1

        fs = self.inference.fs
        chunk_size = int(self.inference.chunk_size_seconds * fs)
        hop_size = int(self.inference.hop_size_seconds * fs)
        batch_size = self.inference.batch_size

        overlap = chunk_size - hop_size
        scaler = chunk_size / (2 * hop_size)

        n_chunks = int(math.ceil(
            (n_samples + 4 * overlap - chunk_size) / hop_size
        )) + 1
        pad = (n_chunks - 1) * hop_size + chunk_size - n_samples

        audio = F.pad(
            audio,
            pad=(2 * overlap, 2 * overlap + pad),
            mode="reflect",
        )
        padded_length = audio.shape[-1]
        audio = audio.reshape(c, 1, -1, 1)

        chunked_audio = F.unfold(
            audio,
            kernel_size=(chunk_size, 1),
            stride=(hop_size, 1),
        )
        chunked_audio = chunked_audio.permute(2, 0, 1).reshape(-1, c, chunk_size)

        n_chunks = chunked_audio.shape[0]
        n_batch = math.ceil(n_chunks / batch_size)

        outputs = defaultdict(list)
        for i in tqdm(range(n_batch)):
            start = i * batch_size
            end = min((i + 1) * batch_size, n_chunks)

            chunked_batch = SimpleishNamespace(
                mixture={"audio": chunked_audio[start:end]},
                estimates=batch["estimates"],
            )

            output = self.forward(chunked_batch)

            for stem, estimate in output.estimates.items():
                outputs[stem].append(estimate["audio"])

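        # Overlap-add each stem back to full length; the Hann window is
        # shared across stems, so build it once outside the loop.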
        window = torch.hann_window(chunk_size, device=self.device).reshape(1, 1, chunk_size)

        for stem, outputs_ in outputs.items():
            output = torch.cat(outputs_, dim=0)
            output = output * window / scaler

            output = torch.permute(output, (1, 2, 0))
            output = F.fold(
                output,
                output_size=(padded_length, 1),
                kernel_size=(chunk_size, 1),
                stride=(hop_size, 1),
            )

            output = output[None, :, 0, 2 * overlap : n_samples + 2 * overlap, 0]

            batch["estimates"][stem] = {"audio": output}

        return batch

    def on_test_epoch_start(self) -> None:
        self.reset_metrics(mode=OperationMode.TEST)

    def test_step(
        self, batch: RawInputType, batch_idx: int, dataloader_idx: int = 0
    ) -> Any:
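        # Metrics are reset before and after each item so the logged test
        # metrics are per-song rather than accumulated over the epoch. The
        # presence of a query selects target-stem vs. all-stem inference.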
        self.model.eval()

        if "query" in batch:
            batch = self.chunked_inference(batch, batch_idx, dataloader_idx)
        else:
            batch = self.chunked_vdbo_inference(batch, batch_idx, dataloader_idx)

        self.reset_metrics(mode=OperationMode.TEST)
        self.update_metrics(batch, mode=OperationMode.TEST)
        metrics = self.compute_metrics(mode=OperationMode.TEST)

        self.log_dict_with_prefix(
            metrics, OperationMode.TEST, on_step=True, on_epoch=False, prog_bar=True
        )
        self.reset_metrics(mode=OperationMode.TEST)

        return batch

    def on_test_epoch_end(self) -> None:
        self.reset_metrics(mode=OperationMode.TEST)

    def set_output_path(self, output_dir: str) -> None:
        self.output_dir = output_dir

    def predict_step(
        self, batch: RawInputType, batch_idx: int, dataloader_idx: int = 0
    ) -> Any:
        self.model.eval()

        if "query" in batch:
            batch = self.chunked_inference(batch, batch_idx, dataloader_idx)
            self.save_to_audio(batch, batch_idx)
        else:
            batch = self.chunked_vdbo_inference(batch, batch_idx, dataloader_idx)
            self.save_vdbo_to_audio(batch, batch_idx)

    def load_state_dict(
        self, state_dict: Mapping[str, Any], strict: bool = False
    ) -> Any:
        # Deliberately loads non-strictly regardless of `strict`, so that
        # checkpoints with missing or extra keys still load.
        return super().load_state_dict(state_dict, strict=False)

    def log_dict_with_prefix(
        self,
        dict_: Dict[str, torch.Tensor],
        prefix: str,
        batch_size: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
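        # Prefix every key with the operation mode, sync across devices, and
        # flush the logger immediately after logging.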
        self.log_dict(
            {f"{prefix}/{k}": v for k, v in dict_.items()},
            batch_size=batch_size,
            logger=True,
            sync_dist=True,
            **kwargs,
        )

        self.logger.save()