|
import os
|
|
import torch
|
|
import torchaudio
|
|
|
|
from functools import wraps
|
|
from types import SimpleNamespace
|
|
from torch.nn import SyncBatchNorm
|
|
from hyperpyyaml import load_hyperpyyaml
|
|
|
|
from torch.nn import DataParallel as DP
|
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
|
|
# Nesting depth of active ``MainProcessContext`` regions in this process.
# ``MainProcessContext.__enter__`` increments it and ``__exit__`` decrements it;
# while it is >= 1, ``ddp_barrier`` returns immediately instead of synchronising.
MAIN_PROC_ONLY: int = 0
|
|
|
|
def fetch(filename, source):
    """Resolve ``filename`` inside the local directory ``source``.

    Only path manipulation is performed: nothing is downloaded or copied.

    Arguments
    ---------
    filename : str
        File name, relative to ``source``.
    source : str
        Directory that contains the file.

    Returns
    -------
    str
        The absolute, normalised path of the joined location.
    """
    joined_path = os.path.join(source, filename)
    return os.path.abspath(joined_path)
|
|
|
|
def run_on_main(
    func,
    args=None,
    kwargs=None,
    post_func=None,
    post_args=None,
    post_kwargs=None,
    run_post_on_main=False,
):
    """Run ``func`` on the main process only, then synchronise all processes.

    Arguments
    ---------
    func : callable
        Executed only where ``if_main_process()`` is true.
    args : list, optional
        Positional arguments for ``func`` (none by default).
    kwargs : dict, optional
        Keyword arguments for ``func`` (none by default).
    post_func : callable, optional
        Executed after ``func`` and the first barrier.
    post_args : list, optional
        Positional arguments for ``post_func`` (none by default).
    post_kwargs : dict, optional
        Keyword arguments for ``post_func`` (none by default).
    run_post_on_main : bool
        If True, ``post_func`` runs on every process (the main one included)
        with no trailing barrier. If False, it runs on every process except the
        main one and is followed by a second barrier.
    """
    call_args = [] if args is None else args
    call_kwargs = {} if kwargs is None else kwargs
    after_args = [] if post_args is None else post_args
    after_kwargs = {} if post_kwargs is None else post_kwargs

    main_process_only(func)(*call_args, **call_kwargs)
    ddp_barrier()

    if post_func is None:
        return

    if run_post_on_main:
        post_func(*after_args, **after_kwargs)
        return

    # The opposite of the first step: everyone except the main process.
    if not if_main_process():
        post_func(*after_args, **after_kwargs)
    ddp_barrier()
|
|
|
|
def is_distributed_initialized():
    """Tell whether torch.distributed is usable and a process group is initialized.

    Returns
    -------
    bool
        False when torch was built without distributed support or no process
        group has been initialized; True otherwise.
    """
    if not torch.distributed.is_available():
        return False
    return torch.distributed.is_initialized()
|
|
|
|
def if_main_process():
    """Return True on the rank-0 process, or always when not running distributed."""
    return not is_distributed_initialized() or torch.distributed.get_rank() == 0
|
|
|
|
class MainProcessContext:
    """Context manager marking a region executed on behalf of the main process.

    On entry it increments the module-level ``MAIN_PROC_ONLY`` counter and on
    exit it decrements it. While the counter is positive, ``ddp_barrier``
    returns without synchronising. Exceptions raised inside the block are not
    suppressed.
    """

    def __enter__(self):
        global MAIN_PROC_ONLY
        MAIN_PROC_ONLY = MAIN_PROC_ONLY + 1
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        global MAIN_PROC_ONLY
        MAIN_PROC_ONLY = MAIN_PROC_ONLY - 1
|
|
|
|
def main_process_only(function):
    """Decorate ``function`` so that it only executes on the main process.

    On non-main processes the wrapper returns ``None`` without calling
    ``function``. On every process the check (and the call) happens inside a
    ``MainProcessContext``, so any ``ddp_barrier`` reached from within
    ``function`` returns early instead of waiting for the other processes.
    """

    @wraps(function)
    def wrapper(*args, **kwargs):
        with MainProcessContext():
            if not if_main_process():
                return None
            return function(*args, **kwargs)

    return wrapper
|
|
|
|
def ddp_barrier():
    """Block until all processes reach this point.

    Does nothing inside a main-process-only region (``MAIN_PROC_ONLY >= 1``)
    or when no process group is initialized. With the NCCL backend the barrier
    is issued on the current CUDA device.
    """
    if MAIN_PROC_ONLY >= 1:
        return
    if not is_distributed_initialized():
        return

    dist = torch.distributed
    if dist.get_backend() != dist.Backend.NCCL:
        dist.barrier()
        return
    dist.barrier(device_ids=[torch.cuda.current_device()])
|
|
|
|
class Resample(torch.nn.Module):
    """Resample batched waveforms from ``orig_freq`` to ``new_freq``.

    Thin wrapper over ``torchaudio.transforms.Resample`` that accepts the
    2-D and 3-D layouts handled in ``forward``.

    Arguments
    ---------
    orig_freq : int
        Sampling rate of the input signal.
    new_freq : int
        Target sampling rate.
    *args, **kwargs
        Extra arguments forwarded to ``torchaudio.transforms.Resample`` after
        the two frequencies.
    """

    def __init__(self, orig_freq=16000, new_freq=16000, *args, **kwargs):
        super().__init__()
        self.orig_freq = orig_freq
        self.new_freq = new_freq
        # Pass the frequencies positionally. With keyword frequencies, any
        # extra positional ``*args`` would be bound to orig_freq/new_freq first
        # and raise "got multiple values for argument".
        self.resampler = torchaudio.transforms.Resample(
            orig_freq, new_freq, *args, **kwargs
        )

    def forward(self, waveforms):
        """Resample ``waveforms`` along the time axis.

        Arguments
        ---------
        waveforms : torch.Tensor
            ``[batch, time]`` or ``[batch, time, channels]``.

        Returns
        -------
        torch.Tensor
            Same layout as the input with the time axis resampled. When the
            two rates are equal, the input object itself is returned.

        Raises
        ------
        ValueError
            If ``waveforms`` is neither 2-D nor 3-D.
        """
        if self.orig_freq == self.new_freq:
            return waveforms

        ndim = waveforms.dim()
        if ndim == 2:
            # [batch, time] -> [batch, 1, time]
            waveforms = waveforms.unsqueeze(1)
        elif ndim == 3:
            # [batch, time, channels] -> [batch, channels, time]: the torchaudio
            # transform resamples the last dimension.
            waveforms = waveforms.transpose(1, 2)
        else:
            raise ValueError(
                "Resample expects a 2-D [batch, time] or 3-D "
                f"[batch, time, channels] tensor, got a {ndim}-D tensor"
            )

        # Move the transform (and any buffers it holds) to the input's device.
        self.resampler.to(waveforms.device)
        resampled = self.resampler(waveforms)

        # Undo the layout change made above.
        if ndim == 2:
            return resampled.squeeze(1)
        return resampled.transpose(1, 2)
|
|
|
|
class AudioNormalizer:
    """Resample a single audio signal to a target rate and optionally mix to mono.

    Arguments
    ---------
    sample_rate : int
        Target sampling rate of the normalized audio.
    mix : str
        ``"avg-to-mono"`` averages the channels into one signal; ``"keep"``
        leaves the channels untouched.

    Raises
    ------
    ValueError
        If ``mix`` is not one of the supported modes.
    """

    _MIX_MODES = ("avg-to-mono", "keep")

    def __init__(self, sample_rate=16000, mix="avg-to-mono"):
        self.sample_rate = sample_rate
        if mix not in self._MIX_MODES:
            raise ValueError(
                f"Unexpected mixing configuration {mix!r}; "
                f"expected one of {self._MIX_MODES}"
            )
        self.mix = mix
        # One Resample per input sampling rate, created lazily and reused.
        self._cached_resamplers = {}

    def __call__(self, audio, sample_rate):
        """Normalize one unbatched signal.

        Arguments
        ---------
        audio : torch.Tensor
            ``[time]`` or ``[time, channels]`` signal.
        sample_rate : int
            Sampling rate of ``audio``.

        Returns
        -------
        torch.Tensor
            ``audio`` resampled to ``self.sample_rate`` and mixed according to
            ``self.mix``.
        """
        if sample_rate not in self._cached_resamplers:
            self._cached_resamplers[sample_rate] = Resample(sample_rate, self.sample_rate)
        resampler = self._cached_resamplers[sample_rate]
        # Resample works on batches: add a batch dimension, then remove it.
        resampled = resampler(audio.unsqueeze(0)).squeeze(0)
        return self._mix(resampled)

    def _mix(self, audio):
        """Apply the configured channel mixing to ``[time]`` or ``[time, channels]`` audio.

        Raises
        ------
        ValueError
            If ``self.mix`` was changed to an unsupported mode after construction
            (previously this silently returned ``None``).
        """
        if self.mix == "avg-to-mono":
            if audio.dim() == 1:
                return audio
            return torch.mean(audio, 1)
        if self.mix == "keep":
            return audio
        raise ValueError(f"Unexpected mixing configuration {self.mix!r}")
|
|
|
|
class Pretrained(torch.nn.Module):
    """Base class for inference wrappers around a set of pretrained modules.

    Arguments
    ---------
    modules : dict, optional
        Mapping of names to ``torch.nn.Module`` objects (or ``None``), exposed
        as ``self.mods``. Every non-None module is moved to ``self.device``.
    hparams : dict, optional
        Hyperparameters, exposed with attribute access as ``self.hparams``.
        Also used as a fallback source for run options, and may provide an
        ``"audio_normalizer"`` entry.
    run_opts : dict, optional
        Run options (keys of ``_RUN_OPT_DEFAULTS``). Values here take
        precedence over the same keys in ``hparams``.
    freeze_params : bool
        If True, put all modules in eval mode and disable their gradients.

    Raises
    ------
    ValueError
        If a name listed in ``HPARAMS_NEEDED`` is missing, or the jit /
        torch.compile configuration names unknown modules or is not supported
        by the installed PyTorch.
    """

    HPARAMS_NEEDED, MODULES_NEEDED = [], []

    # Recognised run options and their defaults. Each becomes an instance
    # attribute, resolved from run_opts first, then hparams, then this table.
    _RUN_OPT_DEFAULTS = {
        "device": "cpu",
        "data_parallel_count": -1,
        "data_parallel_backend": False,
        "distributed_launch": False,
        "distributed_backend": "nccl",
        "jit": False,
        "jit_module_keys": None,
        "compile": False,
        "compile_module_keys": None,
        "compile_mode": "reduce-overhead",
        "compile_using_fullgraph": False,
        "compile_using_dynamic_shape_tracing": False,
    }

    def __init__(self, modules=None, hparams=None, run_opts=None, freeze_params=True):
        super().__init__()

        for arg, default in self._RUN_OPT_DEFAULTS.items():
            if run_opts is not None and arg in run_opts:
                setattr(self, arg, run_opts[arg])
            elif hparams is not None and arg in hparams:
                setattr(self, arg, hparams[arg])
            else:
                setattr(self, arg, default)

        self.mods = torch.nn.ModuleDict(modules)
        for module in self.mods.values():
            if module is not None:
                module.to(self.device)

        if self.HPARAMS_NEEDED and hparams is None:
            raise ValueError(
                f"{type(self).__name__} needs an hparams dict providing "
                f"{self.HPARAMS_NEEDED}"
            )
        if hparams is not None:
            for hp in self.HPARAMS_NEEDED:
                if hp not in hparams:
                    raise ValueError(f"Need hparams['{hp}']")

        # hparams is optional; previously hparams=None crashed on ``None.get``
        # below and left ``self.hparams`` unset.
        hparams_dict = {} if hparams is None else hparams
        self.hparams = SimpleNamespace(**hparams_dict)

        self._prepare_modules(freeze_params)

        # Only build the default normalizer when none is configured.
        if "audio_normalizer" in hparams_dict:
            self.audio_normalizer = hparams_dict["audio_normalizer"]
        else:
            self.audio_normalizer = AudioNormalizer()

    def _prepare_modules(self, freeze_params):
        """Compile, wrap for data/distributed parallelism, and optionally freeze ``self.mods``."""
        self._compile()
        self._wrap_distributed()

        if freeze_params:
            self.mods.eval()
            for p in self.mods.parameters():
                p.requires_grad = False

    def _compile(self):
        """Apply torch.compile and/or TorchScript to the configured modules.

        A module successfully passed through torch.compile is not also
        scripted. A torch.compile failure is reported as a warning and the
        module falls back to TorchScript (if requested) or stays eager.

        Raises
        ------
        ValueError
            If compile keys are given without torch.compile support, or a
            listed module name is not in ``self.mods``.
        """
        import warnings

        compile_available = hasattr(torch, "compile")
        if not compile_available and self.compile_module_keys is not None:
            raise ValueError(
                "'compile_module_keys' was specified, but this PyTorch "
                "installation does not provide torch.compile"
            )

        compile_module_keys = set()
        if self.compile:
            if self.compile_module_keys is None:
                compile_module_keys = set(self.mods)
            else:
                compile_module_keys = set(self.compile_module_keys)

        jit_module_keys = set()
        if self.jit:
            if self.jit_module_keys is None:
                jit_module_keys = set(self.mods)
            else:
                jit_module_keys = set(self.jit_module_keys)

        for name in compile_module_keys | jit_module_keys:
            if name not in self.mods:
                raise ValueError(f"module '{name}' is not defined in your hparams file")

        for name in compile_module_keys:
            try:
                module = torch.compile(
                    self.mods[name],
                    mode=self.compile_mode,
                    fullgraph=self.compile_using_fullgraph,
                    dynamic=self.compile_using_dynamic_shape_tracing,
                )
            except Exception as err:
                # Best effort: keep going, but do not hide the failure.
                warnings.warn(
                    f"'{name}' could not be compiled with torch.compile and is "
                    f"left uncompiled (or scripted if jit is enabled): {err}",
                    RuntimeWarning,
                )
                continue

            self.mods[name] = module.to(self.device)
            jit_module_keys.discard(name)

        for name in jit_module_keys:
            module = torch.jit.script(self.mods[name])
            self.mods[name] = module.to(self.device)

    def _compile_jit(self):
        """Backward-compatible alias for ``_compile``."""
        self._compile()

    def _wrap_distributed(self):
        """Wrap modules that have trainable parameters for parallel execution.

        With ``distributed_launch`` each such module gets SyncBatchNorm layers
        and a DDP wrapper. Otherwise, with ``data_parallel_backend``, it gets
        a DataParallel wrapper (over the first ``data_parallel_count`` devices
        when that is not -1). ``None`` modules are skipped.
        """
        if not self.distributed_launch and not self.data_parallel_backend:
            return

        for name, module in list(self.mods.items()):
            if module is None:
                continue
            if not any(p.requires_grad for p in module.parameters()):
                continue

            if self.distributed_launch:
                module = SyncBatchNorm.convert_sync_batchnorm(module)
                self.mods[name] = DDP(module, device_ids=[self.device])
            elif self.data_parallel_count == -1:
                self.mods[name] = DP(module)
            else:
                self.mods[name] = DP(module, list(range(self.data_parallel_count)))

    @classmethod
    def from_hparams(
        cls,
        source,
        hparams_file="hyperparams.yaml",
        overrides=None,
        download_only=False,
        overrides_must_match=True,
        **kwargs,
    ):
        """Build an instance from a hyperparameter YAML file inside ``source``.

        Arguments
        ---------
        source : str
            Local directory containing ``hparams_file``.
        hparams_file : str
            Name of the YAML file to load.
        overrides : dict, optional
            Overrides passed to ``load_hyperpyyaml`` (none by default).
        download_only : bool
            If True and the file defines a ``pretrainer``, only collect the
            files and return ``None``.
        overrides_must_match : bool
            Forwarded to ``load_hyperpyyaml``.
        **kwargs
            Forwarded to the class constructor (e.g. ``run_opts``).

        Returns
        -------
        cls or None
            The constructed system; ``None`` when ``download_only`` is True and
            a ``pretrainer`` is defined.
        """
        if overrides is None:
            overrides = {}

        hparams_path = fetch(filename=hparams_file, source=source)
        with open(hparams_path, encoding="utf-8") as fin:
            hparams = load_hyperpyyaml(fin, overrides, overrides_must_match=overrides_must_match)

        pretrainer = hparams.get("pretrainer", None)
        if pretrainer is not None:
            # Collection happens once on the main process, then all sync up.
            run_on_main(pretrainer.collect_files, kwargs={"default_source": source})
            if download_only:
                return None
            pretrainer.load_collected()

        return cls(hparams["modules"], hparams, **kwargs)
|
|
|
|
class EncoderClassifier(Pretrained):
    """Pretrained wrapper that embeds waveforms and classifies the embeddings.

    Uses the ``compute_features``, ``mean_var_norm``, ``embedding_model`` and
    ``classifier`` modules. ``classify_batch`` also needs
    ``hparams.label_encoder``, and ``encode_batch(normalize=True)`` needs
    ``hparams.mean_var_norm_emb``.
    """

    MODULES_NEEDED = [
        "compute_features",
        "mean_var_norm",
        "embedding_model",
        "classifier",
    ]

    def encode_batch(self, wavs, wav_lens=None, normalize=False):
        """Compute embeddings for a batch of waveforms.

        Arguments
        ---------
        wavs : torch.Tensor
            Batch of waveforms; a 1-D tensor is treated as a batch of one.
            Cast to float32 on ``self.device`` before use.
        wav_lens : torch.Tensor, optional
            Per-item lengths handed to ``mean_var_norm`` and ``embedding_model``.
            Defaults to all ones (presumably "full length" in relative units;
            confirm against the modules used).
        normalize : bool
            If True, pass the embeddings through ``hparams.mean_var_norm_emb``.

        Returns
        -------
        torch.Tensor
            The output of ``embedding_model`` (optionally normalized).
        """
        if wavs.dim() == 1:
            wavs = wavs.unsqueeze(0)
        if wav_lens is None:
            wav_lens = torch.ones(wavs.shape[0], device=self.device)

        wavs = wavs.to(self.device).float()
        wav_lens = wav_lens.to(self.device)

        feats = self.mods.compute_features(wavs)
        feats = self.mods.mean_var_norm(feats, wav_lens)
        embeddings = self.mods.embedding_model(feats, wav_lens)

        if not normalize:
            return embeddings
        unit_lens = torch.ones(embeddings.shape[0], device=self.device)
        return self.hparams.mean_var_norm_emb(embeddings, unit_lens)

    def classify_batch(self, wavs, wav_lens=None):
        """Classify a batch of waveforms.

        Returns
        -------
        tuple
            ``(out_prob, score, index, text_lab)``: the classifier output with
            dimension 1 squeezed, its max value and argmax over the last
            dimension, and the argmax decoded by
            ``hparams.label_encoder.decode_torch``.
        """
        embeddings = self.encode_batch(wavs, wav_lens)
        out_prob = self.mods.classifier(embeddings).squeeze(1)
        score, index = torch.max(out_prob, dim=-1)
        text_lab = self.hparams.label_encoder.decode_torch(index)
        return out_prob, score, index, text_lab

    def forward(self, wavs, wav_lens=None):
        """Same as ``classify_batch``."""
        return self.classify_batch(wavs, wav_lens)