# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import errno
import functools
import hashlib
import inspect
import io
import os
import random
import socket
import tempfile
import warnings
import zlib
from contextlib import contextmanager

from diffq import UniformQuantizer, DiffQuantizer
import torch as th
import tqdm
from torch import distributed
from torch.nn import functional as F


def center_trim(tensor, reference):
    """
    Center trim `tensor` with respect to `reference`, along the last dimension.
    `reference` can also be a number, representing the length to trim to.
    If the size difference != 0 mod 2, the extra sample is removed on the right side.

    Raises:
        ValueError: if `tensor` is shorter than `reference` along the last dimension.
    """
    if hasattr(reference, "size"):
        reference = reference.size(-1)
    delta = tensor.size(-1) - reference
    if delta < 0:
        raise ValueError("tensor must be larger than reference. " f"Delta is {delta}.")
    if delta:
        # Left side loses `delta // 2` samples, right side loses the remainder,
        # so an odd `delta` removes the extra sample on the right.
        tensor = tensor[..., delta // 2:-(delta - delta // 2)]
    return tensor


def average_metric(metric, count=1.):
    """
    Average `metric` which should be a float across all hosts. `count` should be
    the weight for this particular host (i.e. number of examples).

    The reduction tensor is created on the `'cuda'` device and summed with
    `distributed.all_reduce`, so this requires CUDA and an initialized
    process group.
    """
    # Pack (weight, weighted value) so that a single SUM all-reduce yields both
    # the numerator and the denominator of the weighted average.
    metric = th.tensor([count, count * metric], dtype=th.float32, device='cuda')
    distributed.all_reduce(metric, op=distributed.ReduceOp.SUM)
    return metric[1].item() / metric[0].item()


def free_port(host='', low=20000, high=40000):
    """
    Return a port number that is most likely free.
    This could suffer from a race condition although
    it should be quite rare.

    Ports are drawn at random in `[low, high]` until one can be bound on `host`.
    Any `OSError` other than `EADDRINUSE` is propagated.
    """
    # The probing socket is closed on exit (it used to be leaked), which also
    # releases the port for the caller to bind.
    with socket.socket() as sock:
        while True:
            port = random.randint(low, high)
            try:
                sock.bind((host, port))
            except OSError as error:
                if error.errno == errno.EADDRINUSE:
                    continue
                raise
            return port


def sizeof_fmt(num, suffix='B'):
    """
    Given `num` bytes, return human readable size.
    Taken from https://stackoverflow.com/a/1094933
    """
    for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:
        if abs(num) < 1024.0:
            return "%3.1f%s%s" % (num, unit, suffix)
        num /= 1024.0
    return "%.1f%s%s" % (num, 'Yi', suffix)


def human_seconds(seconds, display='.2f'):
    """
    Given `seconds` seconds, return human readable duration.

    `display` is the format spec applied to the numeric value. The unit is
    promoted (us -> ms -> s -> min -> hrs -> days) as long as the value
    expressed in the next unit is at least 0.3.
    """
    value = seconds * 1e6
    ratios = [1e3, 1e3, 60, 60, 24]
    names = ['us', 'ms', 's', 'min', 'hrs', 'days']
    last = names.pop(0)
    for name, ratio in zip(names, ratios):
        if value / ratio < 0.3:
            break
        value /= ratio
        last = name
    return f"{format(value, display)} {last}"


class TensorChunk:
    """
    Lightweight view over the last dimension of `tensor`, starting at `offset`
    and spanning at most `length` entries (clipped to the end of the tensor).
    No data is copied until `padded` is called.
    """

    def __init__(self, tensor, offset=0, length=None):
        total_length = tensor.shape[-1]
        assert offset >= 0
        assert offset < total_length

        if length is None:
            length = total_length - offset
        else:
            # Never extend past the end of the underlying tensor.
            length = min(total_length - offset, length)

        self.tensor = tensor
        self.offset = offset
        self.length = length
        self.device = tensor.device

    @property
    def shape(self):
        """Shape of the chunk, as a list: the tensor shape with the last dim set to `length`."""
        shape = list(self.tensor.shape)
        shape[-1] = self.length
        return shape

    def padded(self, target_length):
        """
        Return a tensor whose last dimension is `target_length`, centered on this chunk.

        The extra `target_length - length` entries are taken from the underlying
        tensor around the chunk where available; whatever falls outside of the
        tensor bounds is filled by `F.pad` (default mode and value).
        """
        delta = target_length - self.length
        total_length = self.tensor.shape[-1]
        assert delta >= 0

        # Desired window in the coordinates of the underlying tensor; it may
        # extend before 0 or past `total_length`.
        start = self.offset - delta // 2
        end = start + target_length

        correct_start = max(0, start)
        correct_end = min(total_length, end)

        pad_left = correct_start - start
        pad_right = end - correct_end

        out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right))
        assert out.shape[-1] == target_length
        return out


def tensor_chunk(tensor_or_chunk):
    """Return the argument as a `TensorChunk`, wrapping it if it is a plain `th.Tensor`."""
    if isinstance(tensor_or_chunk, TensorChunk):
        return tensor_or_chunk
    else:
        assert isinstance(tensor_or_chunk, th.Tensor)
        return TensorChunk(tensor_or_chunk)


def apply_model(model, mix, shifts=None, split=False,
                overlap=0.25, transition_power=1., progress=False):
    """
    Apply model to a given mixture.

    Args:
        model: the model to apply. The attributes `sources`, `segment_length`,
            `samplerate` and the method `valid_length` are used depending on
            the branch taken.
        mix (th.Tensor or TensorChunk): mixture, with shape `(channels, length)`.
        shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
            and apply the opposite shift to the output. This is repeated `shifts` time and
            all predictions are averaged. This effectively makes the model time equivariant
            and improves SDR by up to 0.2 points.
        split (bool): if True, the input will be broken down in extracts of
            `model.segment_length` samples and predictions will be performed individually
            on each and combined with a weighted overlap-add. Useful for model with
            large memory footprint like Tasnet.
        overlap (float): fraction of a segment shared by two consecutive extracts
            (only used when `split=True`).
        transition_power (float): exponent applied to the triangular blending weight
            (only used when `split=True`), must be >= 1. Larger values lead to
            sharper transitions between extracts.
        progress (bool): if True, show a progress bar (requires split=True)
    """
    assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
    device = mix.device
    channels, length = mix.shape
    if split:
        out = th.zeros(len(model.sources), channels, length, device=device)
        sum_weight = th.zeros(length, device=device)
        segment = model.segment_length
        stride = int((1 - overlap) * segment)
        offsets = range(0, length, stride)
        scale = stride / model.samplerate
        if progress:
            offsets = tqdm.tqdm(offsets, unit_scale=scale, ncols=120, unit='seconds')
        # We start from a triangle shaped weight, with maximal weight in the middle
        # of the segment. Then we normalize and take to the power `transition_power`.
        # Large values of transition power will lead to sharper transitions.
        weight = th.cat([th.arange(1, segment // 2 + 1),
                         th.arange(segment - segment // 2, 0, -1)]).to(device)
        assert len(weight) == segment
        # If the overlap < 50%, this will translate to linear transition when
        # transition_power is 1.
        weight = (weight / weight.max())**transition_power
        for offset in offsets:
            chunk = TensorChunk(mix, offset, segment)
            chunk_out = apply_model(model, chunk, shifts=shifts)
            # The last chunk may be shorter than `segment`.
            chunk_length = chunk_out.shape[-1]
            out[..., offset:offset + segment] += weight[:chunk_length] * chunk_out
            sum_weight[offset:offset + segment] += weight[:chunk_length]
        assert sum_weight.min() > 0
        # Normalize by the accumulated weights to complete the weighted average.
        out /= sum_weight
        return out
    elif shifts:
        max_shift = int(0.5 * model.samplerate)
        mix = tensor_chunk(mix)
        # Pad once by `max_shift` on each side so every shifted window is a
        # plain slice of `padded_mix`.
        padded_mix = mix.padded(length + 2 * max_shift)
        out = 0
        for _ in range(shifts):
            offset = random.randint(0, max_shift)
            shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
            shifted_out = apply_model(model, shifted)
            # Undo the shift so that all predictions are aligned before averaging.
            out += shifted_out[..., max_shift - offset:]
        out /= shifts
        return out
    else:
        valid_length = model.valid_length(length)
        mix = tensor_chunk(mix)
        padded_mix = mix.padded(valid_length)
        with th.no_grad():
            # Add a batch dimension for the model, and remove it from the output.
            out = model(padded_mix.unsqueeze(0))[0]
        return center_trim(out, length)


@contextmanager
def temp_filenames(count, delete=True):
    """
    Context manager yielding a list of `count` freshly created temporary file names.

    The files are created on disk (empty, and closed) before the list is yielded.
    If `delete` is True, they are removed with `os.unlink` when the context exits.
    """
    names = []
    try:
        for _ in range(count):
            # Close the handle right away: only the name is needed, and an open
            # handle would be leaked and can prevent deletion on some platforms.
            with tempfile.NamedTemporaryFile(delete=False) as tmp:
                names.append(tmp.name)
        yield names
    finally:
        if delete:
            for name in names:
                os.unlink(name)


def get_quantizer(model, args, optimizer=None):
    """
    Build the quantizer selected by `args` for `model`.

    Returns a `DiffQuantizer` if `args.diffq` is truthy (registering `optimizer`
    with it when provided), else a `UniformQuantizer` with `bits=args.qat` if
    `args.qat` is truthy, else `None`. `args.q_min_size` is passed as `min_size`.
    """
    quantizer = None
    if args.diffq:
        quantizer = DiffQuantizer(
            model, min_size=args.q_min_size, group_size=8)
        if optimizer is not None:
            quantizer.setup_optimizer(optimizer)
    elif args.qat:
        quantizer = UniformQuantizer(
            model, bits=args.qat, min_size=args.q_min_size)
    return quantizer


def load_model(path, strict=False):
    """
    Load a model from a package written by `save_model`.

    Args:
        path: location of the package, passed as is to `th.load`.
        strict (bool): if False, keyword arguments stored in the package that
            are not in the signature of the model class are dropped with
            a warning. If True, they are passed unchanged.

    Returns:
        the instantiated model with its state restored.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        load_from = path
        # SECURITY NOTE: `th.load` unpickles arbitrary Python objects (the
        # package stores the model class itself). Only load trusted files.
        package = th.load(load_from, 'cpu')

    klass = package["klass"]
    args = package["args"]
    kwargs = package["kwargs"]

    if strict:
        model = klass(*args, **kwargs)
    else:
        sig = inspect.signature(klass)
        # Iterate over a copy of the keys since entries are deleted on the fly.
        for key in list(kwargs):
            if key not in sig.parameters:
                warnings.warn("Dropping inexistant parameter " + key)
                del kwargs[key]
        model = klass(*args, **kwargs)

    state = package["state"]
    training_args = package["training_args"]
    quantizer = get_quantizer(model, training_args)

    set_state(model, quantizer, state)
    return model


def get_state(model, quantizer):
    """
    Return the state of `model`, ready to be serialized.

    Without quantizer, this is the model state dict with every entry moved to CPU.
    With a quantizer, this is `{'compressed': ...}` holding the zlib-compressed,
    `th.save`-serialized quantized state.
    """
    if quantizer is None:
        state = {k: p.data.to('cpu') for k, p in model.state_dict().items()}
    else:
        state = quantizer.get_quantized_state()
        buf = io.BytesIO()
        th.save(state, buf)
        state = {'compressed': zlib.compress(buf.getvalue())}
    return state


def set_state(model, quantizer, state):
    """
    Restore into `model` a state produced by `get_state`.

    Returns the state that was applied: `state` itself without quantizer,
    or the decompressed quantized state otherwise.
    """
    if quantizer is None:
        model.load_state_dict(state)
    else:
        buf = io.BytesIO(zlib.decompress(state["compressed"]))
        state = th.load(buf, "cpu")
        quantizer.restore_quantized_state(state)

    return state


def save_state(state, path):
    """
    Serialize `state` with `th.save` and write it next to `path`.

    The first 8 hex digits of the SHA256 of the serialized bytes are inserted
    between the stem and the suffix of the file name. `path` must provide
    `parent`, `stem`, `suffix` and the result of `/` must provide `write_bytes`
    (e.g. a `pathlib.Path`).
    """
    buf = io.BytesIO()
    th.save(state, buf)
    sig = hashlib.sha256(buf.getvalue()).hexdigest()[:8]

    path = path.parent / (path.stem + "-" + sig + path.suffix)
    path.write_bytes(buf.getvalue())


def save_model(model, quantizer, training_args, path):
    """
    Save `model` to `path` as a package that `load_model` can read.

    The package holds the model class, the constructor arguments recorded in
    `model._init_args_kwargs` (see `capture_init`), the state returned by
    `get_state` and `training_args`.
    """
    args, kwargs = model._init_args_kwargs
    klass = model.__class__

    state = get_state(model, quantizer)

    save_to = path
    package = {
        'klass': klass,
        'args': args,
        'kwargs': kwargs,
        'state': state,
        'training_args': training_args,
    }
    th.save(package, save_to)


def capture_init(init):
    """
    Decorator for `__init__` that records the call arguments on the instance,
    as `self._init_args_kwargs = (args, kwargs)`, before running `init`.
    """
    @functools.wraps(init)
    def __init__(self, *args, **kwargs):
        self._init_args_kwargs = (args, kwargs)
        init(self, *args, **kwargs)

    return __init__