import torch


_output_ref = None
_replicas_ref = None


def data_parallel_workaround(model, *input):
    """Functional equivalent of torch.nn.DataParallel: replicate the model
    across all visible GPUs, scatter the positional inputs, apply the
    replicas in parallel, and gather the outputs on device 0.

    References to the outputs and replicas are stashed in module-level
    globals so they persist after the call returns.
    """
    global _output_ref
    global _replicas_ref
    device_ids = list(range(torch.cuda.device_count()))
    output_device = device_ids[0]
    replicas = torch.nn.parallel.replicate(model, device_ids)
    # input.shape = (num_args, batch, ...)
    inputs = torch.nn.parallel.scatter(input, device_ids)
    # inputs.shape = (num_gpus, num_args, batch/num_gpus, ...)
    replicas = replicas[:len(inputs)]
    outputs = torch.nn.parallel.parallel_apply(replicas, inputs)
    y_hat = torch.nn.parallel.gather(outputs, output_device)
    _output_ref = outputs
    _replicas_ref = replicas
    return y_hat


class ValueWindow:
    """Fixed-size sliding window over the most recently appended values."""

    def __init__(self, window_size=100):
        self._window_size = window_size
        self._values = []

    def append(self, x):
        # Keep only the most recent `window_size` values, dropping the
        # oldest once the window is full (correct for window_size == 1 too).
        self._values = (self._values + [x])[-self._window_size:]

    @property
    def sum(self):
        return sum(self._values)

    @property
    def count(self):
        return len(self._values)

    @property
    def average(self):
        # Guard against division by zero when the window is empty.
        return self.sum / max(1, self.count)

    def reset(self):
        self._values = []
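
# --- Illustrative usage (a hypothetical sketch, not part of the module) ---
# The names below (`model`, `x`, `m`, `y_hat`) are assumptions used only to
# show the intended call patterns; only the ValueWindow demo actually runs.
if __name__ == "__main__":
    # ValueWindow smooths a noisy per-step metric (e.g. training loss) by
    # averaging over the most recent `window_size` values.
    loss_window = ValueWindow(window_size=100)
    for step in range(250):
        loss_window.append(1.0 / (step + 1))  # stand-in for a real loss
    print("steps in window:", loss_window.count)        # -> 100
    print("smoothed loss:  ", round(loss_window.average, 4))

    # data_parallel_workaround is a drop-in replacement for a forward pass
    # when more than one GPU is visible. A typical training-loop pattern:
    #
    #     if torch.cuda.device_count() > 1:
    #         y_hat = data_parallel_workaround(model, x, m)
    #     else:
    #         y_hat = model(x, m)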