Spaces:
Runtime error
Runtime error
import torch | |
# Module-level slots that data_parallel_workaround() overwrites after each
# successful call with its per-device outputs and the model replicas it used.
# NOTE(review): presumably held so those objects stay referenced after the
# call returns — confirm the intent before removing.
_output_ref = _replicas_ref = None
def data_parallel_workaround(model, *input): | |
global _output_ref | |
global _replicas_ref | |
device_ids = list(range(torch.cuda.device_count())) | |
output_device = device_ids[0] | |
replicas = torch.nn.parallel.replicate(model, device_ids) | |
# input.shape = (num_args, batch, ...) | |
inputs = torch.nn.parallel.scatter(input, device_ids) | |
# inputs.shape = (num_gpus, num_args, batch/num_gpus, ...) | |
replicas = replicas[:len(inputs)] | |
outputs = torch.nn.parallel.parallel_apply(replicas, inputs) | |
y_hat = torch.nn.parallel.gather(outputs, output_device) | |
_output_ref = outputs | |
_replicas_ref = replicas | |
return y_hat | |
class ValueWindow:
    """Sliding window over the most recently appended values.

    At most ``window_size`` values are retained; appending to a full window
    discards the oldest value. The window reports the sum, count and mean of
    the values it currently holds.
    """

    def __init__(self, window_size=100):
        """Create an empty window.

        Args:
            window_size: Maximum number of values retained. Must be a
                non-negative int (``collections.deque`` rejects negatives).
        """
        from collections import deque  # local import keeps this block self-contained

        self._window_size = window_size
        # A bounded deque evicts the oldest entry in O(1) on append. The former
        # slice ``[-(window_size - 1):]`` became ``[-0:]`` for window_size == 1
        # and kept the whole list, so the window never shrank.
        self._values = deque(maxlen=window_size)

    def append(self, x):
        """Add ``x``, dropping the oldest value if the window is full."""
        self._values.append(x)

    def sum(self):
        """Return the sum of the values currently in the window (0 if empty)."""
        # Inside this method, ``sum`` resolves to the builtin, not this method.
        return sum(self._values)

    def count(self):
        """Return how many values are currently in the window."""
        return len(self._values)

    def average(self):
        """Return the mean of the values in the window.

        An empty window yields 0.0: the divisor is clamped to at least 1.
        """
        # sum/count are methods, so they must be called. The original divided
        # the bound methods themselves and always raised TypeError.
        return self.sum() / max(1, self.count())

    def reset(self):
        """Remove all values from the window."""
        self._values.clear()