# remfx/utils.py
import collections.abc
import logging
from typing import List, Optional, Tuple

import pytorch_lightning as pl
import torch
import torchaudio
from omegaconf import DictConfig
from pytorch_lightning.utilities import rank_zero_only
from torch import nn

def get_logger(name=__name__) -> logging.Logger:
"""Initializes multi-GPU-friendly python command line logger."""
logger = logging.getLogger(name)
# this ensures all logging levels get marked with the rank zero decorator
# otherwise logs would get multiplied for each GPU process in multi-GPU setup
for level in (
"debug",
"info",
"warning",
"error",
"exception",
"fatal",
"critical",
):
setattr(logger, level, rank_zero_only(getattr(logger, level)))
return logger
log = get_logger(__name__)
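
# Example (sketch): any module can create its own rank-zero-safe logger; in a
# multi-GPU run only rank 0 emits records, so logs are not duplicated.
#     log = get_logger(__name__)
#     log.info("Starting training run")
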
@rank_zero_only
def log_hyperparameters(
config: DictConfig,
model: pl.LightningModule,
datamodule: pl.LightningDataModule,
trainer: pl.Trainer,
callbacks: List[pl.Callback],
logger: pl.loggers.logger.Logger,
) -> None:
"""Controls which config parts are saved by Lightning loggers.
Additionaly saves:
- number of model parameters
"""
if not trainer.logger:
return
hparams = {}
# choose which parts of hydra config will be saved to loggers
hparams["model"] = config["model"]
# save number of model parameters
hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
hparams["model/params/trainable"] = sum(
p.numel() for p in model.parameters() if p.requires_grad
)
hparams["model/params/non_trainable"] = sum(
p.numel() for p in model.parameters() if not p.requires_grad
)
hparams["datamodule"] = config["datamodule"]
hparams["trainer"] = config["trainer"]
if "seed" in config:
hparams["seed"] = config["seed"]
if "callbacks" in config:
hparams["callbacks"] = config["callbacks"]
    if isinstance(trainer.logger, pl.loggers.CSVLogger):
        logger.log_hyperparams(hparams)
    else:
        # Assumes a wandb-style logger whose experiment exposes a config dict
        logger.experiment.config.update(hparams)
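
# Example (sketch, assuming a hydra config and a wandb logger; the names below
# are illustrative, not part of this module):
#     log_hyperparameters(
#         config=cfg,
#         model=model,
#         datamodule=datamodule,
#         trainer=trainer,
#         callbacks=callbacks,
#         logger=trainer.logger,
#     )
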
def create_random_chunks(
    audio_file: str, chunk_size: int, num_chunks: int
) -> Tuple[List[int], int]:
    """Create num_chunks random chunks of chunk_size seconds from an audio
    file.

    Return the start sample index of each chunk and the original sample rate.
    """
    audio, sr = torchaudio.load(audio_file)
    chunk_size_in_samples = chunk_size * sr
    # Clamp so torch.randint below always has a non-empty sampling range
    if chunk_size_in_samples >= audio.shape[-1]:
        chunk_size_in_samples = audio.shape[-1] - 1
    chunks = []
    for _ in range(num_chunks):
        start = torch.randint(0, audio.shape[-1] - chunk_size_in_samples, (1,)).item()
        chunks.append(start)
    return chunks, sr
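
# Example (sketch, assuming "clip.wav" is a local audio file):
#     starts, sr = create_random_chunks("clip.wav", chunk_size=2, num_chunks=4)
#     audio, _ = torchaudio.load("clip.wav")
#     first_chunk = audio[:, starts[0] : starts[0] + 2 * sr]
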
def create_sequential_chunks(
audio_file: str, chunk_size: int, sample_rate: int
) -> List[torch.Tensor]:
"""Create sequential chunks of size chunk_size from an audio file.
Return each chunk
"""
chunks = []
audio, sr = torchaudio.load(audio_file)
chunk_starts = torch.arange(0, audio.shape[-1], chunk_size)
for start in chunk_starts:
        # Drop the trailing partial chunk
        if start + chunk_size > audio.shape[-1]:
            break
        chunk = audio[:, start : start + chunk_size]
        resampled_chunk = torchaudio.functional.resample(chunk, sr, sample_rate)
        # Skip chunks that are too short after resampling
        if resampled_chunk.shape[-1] < chunk_size:
            continue
        chunks.append(resampled_chunk)  # append the resampled chunk, not the raw one
return chunks
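
# Example (sketch, assuming "clip.wav" is a 48 kHz file): 1-second chunks.
#     chunks = create_sequential_chunks("clip.wav", chunk_size=48000, sample_rate=48000)
#     batch = torch.stack(chunks)  # valid when every chunk has the same length
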
def select_random_chunk(
    audio_file: str, chunk_size: int, sample_rate: int
) -> Optional[torch.Tensor]:
    """Select a random chunk of chunk_size samples (at sample_rate) from an
    audio file, or return None if the file is too short or too quiet.
    """
    audio, sr = torchaudio.load(audio_file)
    # Source-rate samples needed to yield chunk_size samples after resampling
    new_chunk_size = int(chunk_size * (sr / sample_rate))
if new_chunk_size >= audio.shape[-1]:
return None
max_len = audio.shape[-1] - new_chunk_size
random_start = torch.randint(0, max_len, (1,)).item()
chunk = audio[:, random_start : random_start + new_chunk_size]
# Skip if energy too low
if torch.mean(torch.abs(chunk)) < 1e-4:
return None
resampled_chunk = torchaudio.functional.resample(chunk, sr, sample_rate)
return resampled_chunk
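
# Example (sketch): retry until a sufficiently loud chunk is drawn.
#     chunk = None
#     while chunk is None:
#         chunk = select_random_chunk("clip.wav", chunk_size=48000, sample_rate=48000)
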
def spectrogram(
    x: torch.Tensor,
    window: torch.Tensor,
    n_fft: int,
    hop_length: int,
    alpha: float,
) -> torch.Tensor:
    """Compute an alpha-compressed magnitude spectrogram with shape
    (batch, channels, freq, time)."""
    bs, chs, _ = x.size()
    x = x.view(bs * chs, -1)  # move channels onto batch dim
X = torch.stft(
x,
n_fft=n_fft,
hop_length=hop_length,
window=window,
return_complex=True,
)
    # move channels back
    X = X.view(bs, chs, X.shape[-2], X.shape[-1])
    # alpha-compress the magnitude; the 1e-8 floor keeps the gradient finite at silence
    return torch.pow(X.abs() + 1e-8, alpha)
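
# Example (sketch): a 0.3-compressed magnitude spectrogram of a stereo batch.
#     x = torch.randn(8, 2, 48000)
#     window = torch.hann_window(1024)
#     S = spectrogram(x, window, n_fft=1024, hop_length=256, alpha=0.3)
#     # S.shape == (8, 2, 513, 188)
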
def init_layer(layer):
"""Initialize a Linear or Convolutional layer."""
nn.init.xavier_uniform_(layer.weight)
    if getattr(layer, "bias", None) is not None:
        layer.bias.data.fill_(0.0)
def init_bn(bn):
"""Initialize a Batchnorm layer."""
bn.bias.data.fill_(0.0)
bn.weight.data.fill_(1.0)
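
# Example (sketch, illustrative model): apply the initializers over submodules.
#     net = nn.Sequential(nn.Conv1d(1, 16, 3), nn.BatchNorm1d(16))
#     for m in net.modules():
#         if isinstance(m, (nn.Conv1d, nn.Linear)):
#             init_layer(m)
#         elif isinstance(m, nn.BatchNorm1d):
#             init_bn(m)
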
def _ntuple(n: int):
    """Return a parser that repeats a scalar n times or tuples an iterable."""

    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return tuple(x)
        return tuple([x] * n)

    return parse
single = _ntuple(1)
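
# Example: single(3) returns (3,), while single([3, 4]) returns (3, 4).
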
def concat_complex(a: torch.Tensor, b: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """Concatenate two complex-as-real tensors along the same dimension,
    keeping all real parts ahead of all imaginary parts. Each input stores
    its real half followed by its imaginary half along dim.

    :param a: complex tensor
    :param b: another complex tensor
    :param dim: target dimension
    :return: concatenated tensor
    """
    a_real, a_imag = a.chunk(2, dim)
    b_real, b_imag = b.chunk(2, dim)
    return torch.cat([a_real, b_real, a_imag, b_imag], dim=dim)
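
# Example (sketch): dim-1 halves hold the real and imaginary parts.
#     a = torch.randn(4, 8, 100)  # channels 0-3 real, 4-7 imaginary
#     b = torch.randn(4, 8, 100)
#     out = concat_complex(a, b, dim=1)  # (4, 16, 100): a_real, b_real, a_imag, b_imag
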
def center_crop(x, length: int):
    """Take length samples from the center of the last dimension."""
    start = (x.shape[-1] - length) // 2
    stop = start + length
    return x[..., start:stop]
def causal_crop(x, length: int):
    """Take the length samples that immediately precede the final sample of
    the last dimension (the very last sample is excluded)."""
    stop = x.shape[-1] - 1
    start = stop - length
    return x[..., start:stop]
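
# Example (sketch): crop a 132-sample signal down to 100 samples.
#     y = torch.randn(1, 1, 132)
#     center_crop(y, 100)  # keeps samples 16..115
#     causal_crop(y, 100)  # keeps samples 31..130 (the final sample is dropped)
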