|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import contextlib
import functools
import os
import re
import sys
import types
import warnings

import numpy as np
import torch
|
|
|
sys.path.insert(1, os.path.join(sys.path[0], "..")) |
|
import stylegan2_ada_pytorch.dnnlib as dnnlib |
|
|
|
|
|
|
|
|
|
|
|
# Module-wide memo: (value shape, value dtype, value bytes, shape, dtype,
# device, memory_format) -> tensor previously built by ``constant``.
_constant_cache = {}


def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    """Return a (cached) tensor holding ``value``.

    Args:
        value: Anything accepted by ``np.asarray``.
        shape: Optional target shape; ``value`` is broadcast against an
            empty tensor of this shape.
        dtype: Tensor dtype; defaults to ``torch.get_default_dtype()``.
        device: Target device; defaults to CPU.
        memory_format: Layout passed to ``Tensor.contiguous``; defaults to
            ``torch.contiguous_format``.

    Returns:
        A tensor. Identical requests return the very same tensor object from
        ``_constant_cache``, so callers must not modify it in place.
    """
    arr = np.asarray(value)
    shape = None if shape is None else tuple(shape)
    dtype = torch.get_default_dtype() if dtype is None else dtype
    device = torch.device("cpu") if device is None else device
    memory_format = torch.contiguous_format if memory_format is None else memory_format

    cache_key = (arr.shape, arr.dtype, arr.tobytes(), shape, dtype, device, memory_format)
    cached = _constant_cache.get(cache_key)
    if cached is not None:
        return cached

    # Copy so the tensor never aliases the caller's numpy buffer.
    result = torch.as_tensor(arr.copy(), dtype=dtype, device=device)
    if shape is not None:
        result = torch.broadcast_tensors(result, torch.empty(shape))[0]
    result = result.contiguous(memory_format=memory_format)
    _constant_cache[cache_key] = result
    return result
|
|
|
|
|
|
|
|
|
|
|
# Prefer the built-in ``torch.nan_to_num``; otherwise fall back to a limited
# re-implementation that only supports ``nan == 0``.
if hasattr(torch, "nan_to_num"):
    nan_to_num = torch.nan_to_num
else:

    def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None):
        """Replace NaN with 0 and clamp +/-inf to finite bounds.

        ``posinf``/``neginf`` default to the largest/smallest finite value of
        ``input.dtype`` (via ``torch.finfo``). Only ``nan == 0`` is supported.
        """
        assert isinstance(input, torch.Tensor)
        upper = torch.finfo(input.dtype).max if posinf is None else posinf
        lower = torch.finfo(input.dtype).min if neginf is None else neginf
        assert nan == 0
        # nansum over a singleton leading axis copies ``input`` with NaN -> 0.
        zeroed = input.unsqueeze(0).nansum(0)
        return torch.clamp(zeroed, min=lower, max=upper, out=out)
|
|
|
|
|
|
|
|
|
|
|
# Trace-friendly assertion: newer torch exposes ``torch._assert``; older
# versions only provide ``torch.Assert``.
if hasattr(torch, "_assert"):
    symbolic_assert = torch._assert
else:
    symbolic_assert = torch.Assert
|
|
|
|
|
|
|
|
|
|
|
class suppress_tracer_warnings(warnings.catch_warnings):
    """Context manager that ignores ``torch.jit.TracerWarning`` inside its body.

    As a ``warnings.catch_warnings`` subclass, it restores the previous
    warning filters on exit. Entering yields the context manager itself.
    """

    def __enter__(self):
        super().__enter__()
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        return self
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def assert_shape(tensor, ref_shape):
    """Check that ``tensor`` has the shape described by ``ref_shape``.

    Args:
        tensor: Tensor to check.
        ref_shape: Sequence with one entry per dimension. ``None`` accepts any
            size. A tensor entry, or a tensor-valued size from ``tensor.shape``,
            is compared via ``symbolic_assert`` under
            ``suppress_tracer_warnings``. Any other entry must equal the size.

    Raises:
        AssertionError: On a rank mismatch or a concrete size mismatch.
    """
    if tensor.ndim != len(ref_shape):
        raise AssertionError(
            f"Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}"
        )
    for dim, (actual, expected) in enumerate(zip(tensor.shape, ref_shape)):
        if expected is None:
            continue
        if isinstance(expected, torch.Tensor):
            with suppress_tracer_warnings():
                symbolic_assert(
                    torch.equal(torch.as_tensor(actual), expected),
                    f"Wrong size for dimension {dim}",
                )
            continue
        if isinstance(actual, torch.Tensor):
            with suppress_tracer_warnings():
                symbolic_assert(
                    torch.equal(actual, torch.as_tensor(expected)),
                    f"Wrong size for dimension {dim}: expected {expected}",
                )
            continue
        if actual != expected:
            raise AssertionError(
                f"Wrong size for dimension {dim}: got {actual}, expected {expected}"
            )
|
|
|
|
|
|
|
|
|
|
|
|
|
def profiled_function(fn):
    """Decorate ``fn`` so every call runs inside a named profiler range.

    Each call is wrapped in ``torch.autograd.profiler.record_function`` and
    labelled with ``fn.__name__``.

    Args:
        fn: The function to wrap.

    Returns:
        A wrapper with the same call signature as ``fn``. ``functools.wraps``
        copies ``__name__``, ``__qualname__``, ``__doc__``, ``__module__`` and
        ``__dict__`` from ``fn`` and sets ``__wrapped__`` to ``fn``, so the
        wrapper stays introspectable.
    """

    @functools.wraps(fn)
    def decorator(*args, **kwargs):
        with torch.autograd.profiler.record_function(fn.__name__):
            return fn(*args, **kwargs)

    return decorator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class InfiniteSampler(torch.utils.data.Sampler):
    """Sampler that cycles over dataset indices forever.

    Global position ``idx`` (0, 1, 2, ...) maps to slot ``idx % len(dataset)``
    of an index order. Only positions with ``idx % num_replicas == rank`` are
    yielded by this replica. With ``shuffle``, the order is first shuffled
    using ``np.random.RandomState(seed)``. After every position, the current
    slot is swapped with one chosen uniformly up to ``window - 1`` slots
    behind it (wrapping around). Here
    ``window = round(len(dataset) * window_size)``, and the swap happens only
    when ``window >= 2``. Yielded indices are NumPy integer scalars.
    """

    def __init__(
        self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5
    ):
        assert len(dataset) > 0
        assert num_replicas > 0
        assert 0 <= rank < num_replicas
        assert 0 <= window_size <= 1
        super().__init__(dataset)
        self.dataset = dataset
        self.rank, self.num_replicas = rank, num_replicas
        self.shuffle, self.seed = shuffle, seed
        self.window_size = window_size

    def __iter__(self):
        indices = np.arange(len(self.dataset))
        rng = None
        span = 0
        if self.shuffle:
            rng = np.random.RandomState(self.seed)
            rng.shuffle(indices)
            span = int(np.rint(indices.size * self.window_size))

        count = indices.size
        step = 0
        while True:
            slot = step % count
            if step % self.num_replicas == self.rank:
                yield indices[slot]
            # The RNG advances on every step (not only on this replica's steps),
            # so all replicas evolve the same order in lockstep.
            if span >= 2:
                other = (slot - rng.randint(span)) % count
                indices[slot], indices[other] = indices[other], indices[slot]
            step += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
def params_and_buffers(module):
    """Return all parameters of ``module`` followed by all of its buffers."""
    assert isinstance(module, torch.nn.Module)
    return [*module.parameters(), *module.buffers()]
|
|
|
|
|
def named_params_and_buffers(module):
    """Return ``(name, tensor)`` pairs for all parameters, then all buffers."""
    assert isinstance(module, torch.nn.Module)
    pairs = list(module.named_parameters())
    pairs.extend(module.named_buffers())
    return pairs
|
|
|
|
|
def copy_params_and_buffers(src_module, dst_module, require_all=False):
    """Copy parameter and buffer values from ``src_module`` into ``dst_module`` by name.

    Tensors are matched by the names returned by ``named_params_and_buffers``.
    Values are written in place, so ``dst_module`` keeps its own tensor objects
    and their ``requires_grad`` flags. Destination names with no source
    counterpart are left untouched unless ``require_all`` is set.

    Args:
        src_module: Module to read values from.
        dst_module: Module whose parameters and buffers are overwritten.
        require_all: If true, every destination tensor must have a same-named
            source tensor.

    Raises:
        AssertionError: If ``require_all`` is true and a destination name is
            missing from ``src_module`` (the message is that name).
    """
    assert isinstance(src_module, torch.nn.Module)
    assert isinstance(dst_module, torch.nn.Module)
    src_tensors = dict(named_params_and_buffers(src_module))
    # Autograd rejects in-place writes into leaf tensors that require grad,
    # so the copy must run with gradient tracking disabled.
    with torch.no_grad():
        for name, tensor in named_params_and_buffers(dst_module):
            assert (name in src_tensors) or (not require_all), name
            if name in src_tensors:
                tensor.copy_(src_tensors[name].detach())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@contextlib.contextmanager
def ddp_sync(module, sync):
    """Context manager that optionally enters ``module.no_sync()``.

    ``module.no_sync()`` is entered only when ``sync`` is falsy and ``module``
    is a ``DistributedDataParallel`` instance. Otherwise the body runs
    unchanged.
    """
    assert isinstance(module, torch.nn.Module)
    skip_sync = (not sync) and isinstance(
        module, torch.nn.parallel.DistributedDataParallel
    )
    with (module.no_sync() if skip_sync else contextlib.nullcontext()):
        yield
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_ddp_consistency(module, ignore_regex=None):
    """Assert that every parameter/buffer matches rank 0's copy.

    Each tensor is labelled ``"<ModuleClass>.<name>"``. Tensors whose label
    fully matches ``ignore_regex`` are skipped. The others are compared with
    rank 0's value, fetched via ``torch.distributed.broadcast``, after passing
    both through ``nan_to_num`` (so NaN/inf are mapped to finite values before
    comparing). ``broadcast`` is a collective call, so presumably every rank
    must call this with the same module structure — confirm at call sites.

    Raises:
        AssertionError: With the tensor's label, on the first mismatch.
    """
    assert isinstance(module, torch.nn.Module)
    class_name = type(module).__name__
    for tensor_name, value in named_params_and_buffers(module):
        label = f"{class_name}.{tensor_name}"
        if ignore_regex is not None and re.fullmatch(ignore_regex, label):
            continue
        local = value.detach()
        reference = local.clone()
        torch.distributed.broadcast(tensor=reference, src=0)
        assert (nan_to_num(local) == nan_to_num(reference)).all(), label
|
|
|
|
|
|
|
|
|
|
|
|
|
def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
    """Run ``module(*inputs)`` once and print a per-submodule summary table.

    Forward pre- and post-hooks are registered on every module in
    ``module.modules()`` for the duration of the call. They are removed
    afterwards, even if the forward pass raises. Each recorded call
    contributes a row with the submodule's name, its parameter and buffer
    element counts, and the shape and dtype of its first tensor output. Any
    further tensor outputs get one extra row each. A parameter, buffer or
    output tensor is attributed only to the first recorded call containing it,
    so shared tensors are not double-counted. Calls are recorded in the order
    their forward finishes.

    Args:
        module: ``torch.nn.Module`` to run; TorchScript modules are rejected.
        inputs: Tuple or list of positional arguments for ``module``.
        max_nesting: Calls nested deeper than this are not recorded. The
            top-level call is depth 0.
        skip_redundant: If true, drop rows that contribute no new parameters,
            buffers or outputs.

    Returns:
        Whatever ``module(*inputs)`` returned.
    """
    assert isinstance(module, torch.nn.Module)
    assert not isinstance(module, torch.jit.ScriptModule)
    assert isinstance(inputs, (tuple, list))

    # Run the forward pass with hooks that record each call's tensor outputs.
    entries = []
    nesting = 0  # number of forward calls currently in progress

    def pre_hook(_mod, _inputs):
        nonlocal nesting
        nesting += 1

    def post_hook(mod, _inputs, outputs):
        nonlocal nesting
        nesting -= 1
        if nesting <= max_nesting:
            outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
            outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
            entries.append(types.SimpleNamespace(mod=mod, outputs=outputs))

    submodules = list(module.modules())
    hooks = [mod.register_forward_pre_hook(pre_hook) for mod in submodules]
    hooks += [mod.register_forward_hook(post_hook) for mod in submodules]
    try:
        outputs = module(*inputs)
    finally:
        # Never leave hooks installed, even when the forward pass fails.
        for hook in hooks:
            hook.remove()

    # Attribute each tensor (by identity) to the first entry that contains it.
    tensors_seen = set()
    for e in entries:
        e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
        e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
        e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
        tensors_seen |= {
            id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs
        }

    if skip_redundant:
        entries = [
            e
            for e in entries
            if e.unique_params or e.unique_buffers or e.unique_outputs
        ]

    # Build the table rows.
    rows = [[type(module).__name__, "Parameters", "Buffers", "Output shape", "Datatype"]]
    rows += [["---"] * len(rows[0])]
    param_total = 0
    buffer_total = 0
    submodule_names = {mod: name for name, mod in module.named_modules()}
    for e in entries:
        name = "<top-level>" if e.mod is module else submodule_names[e.mod]
        param_size = sum(t.numel() for t in e.unique_params)
        buffer_size = sum(t.numel() for t in e.unique_buffers)
        # Each output row reports that output's own shape and dtype.
        output_shapes = [str(list(t.shape)) for t in e.outputs]
        output_dtypes = [str(t.dtype).split(".")[-1] for t in e.outputs]
        rows += [
            [
                name + (":0" if len(e.outputs) >= 2 else ""),
                str(param_size) if param_size else "-",
                str(buffer_size) if buffer_size else "-",
                (output_shapes + ["-"])[0],
                (output_dtypes + ["-"])[0],
            ]
        ]
        for idx in range(1, len(e.outputs)):
            rows += [[name + f":{idx}", "-", "-", output_shapes[idx], output_dtypes[idx]]]
        param_total += param_size
        buffer_total += buffer_size
    rows += [["---"] * len(rows[0])]
    rows += [["Total", str(param_total), str(buffer_total), "-", "-"]]

    # Print the rows as left-aligned, padded columns.
    widths = [max(len(cell) for cell in column) for column in zip(*rows)]
    print()
    for row in rows:
        print(" ".join(cell + " " * (width - len(cell)) for cell, width in zip(row, widths)))
    print()
    return outputs
|
|
|
|
|
|
|
|