| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import logging |
| import random |
| from dataclasses import dataclass |
| from typing import Optional, Tuple |
|
|
| import torch |
| from torch import Tensor, nn |
|
|
|
|
class TensorDiagnosticOptions:
    """Options object for tensor diagnostics.

    Args:
      max_eig_dim:
        The maximum dimension for which we print out eigenvalues
        (limited for speed reasons).
    """

    def __init__(self, max_eig_dim: int = 512):
        self.max_eig_dim = max_eig_dim

    def dim_is_summarized(self, size: int):
        """Return True if a dimension of this size should be printed as
        percentiles rather than element by element.  Size 31 is special-cased
        to always print in full (reason not evident here; presumably a
        dimension of particular interest -- confirm with maintainers)."""
        return not (size <= 10 or size == 31)
|
|
|
|
def get_tensor_stats(
    x: Tensor,
    dim: int,
    stats_type: str,
) -> Tuple[Tensor, int]:
    """
    Returns the specified transformation of the Tensor (either x or x.abs()
    or (x > 0)), summed over all but the index `dim`.

    Args:
      x:
        Tensor, tensor to be analyzed
      dim:
        Dimension with 0 <= dim < x.ndim
      stats_type:
        The stats_type includes several types:
          "abs" -> take abs() before summing
          "positive" -> take (x > 0) before summing
          "rms" -> square before summing, we'll take sqrt later
          "value" -> just sum x itself
          "max", "min" -> take the maximum or minimum [over all other dims but dim] instead of summing
          "eigs" -> return x^T x (with x reshaped to 2-D keeping `dim` last),
             i.e. the uncentered covariance, for later eigen-analysis
          "rms-sort" -> this is a bit different than the others, it's based on computing the
             rms over the specified dim and returning percentiles of the result (11 of them).
    Returns:
      stats: a Tensor of shape (x.shape[dim],), except for "eigs" (a square
        matrix) and "rms-sort" (an 11-element percentile vector).
      count: an integer saying how many items were counted in each element
        of stats.
    """

    if stats_type == "rms-sort":
        # RMS taken over `dim` (one value per position in the other dims),
        # then 11 evenly spaced percentiles of the sorted flattened result.
        rms = (x**2).mean(dim=dim).sqrt()
        rms = rms.flatten()
        rms = rms.sort()[0]
        rms = rms[(torch.arange(11) * rms.numel() // 10).clamp(max=rms.numel() - 1)]
        count = 1  # was 1.0; use an int to match the declared return type
        return rms, count

    # Number of elements contributing to each output entry.
    count = x.numel() // x.shape[dim]

    if stats_type == "eigs":
        # Sum of outer products, i.e. x^T x with x viewed as 2-D; the caller
        # computes eigenvalues later in print_diagnostics().
        x = x.transpose(dim, -1)
        x = x.reshape(-1, x.shape[-1])
        return torch.matmul(x.transpose(0, 1), x), count
    elif stats_type == "abs":
        x = x.abs()
    elif stats_type == "rms":
        x = x**2
    elif stats_type == "positive":
        x = (x > 0).to(dtype=torch.float)
    else:
        assert stats_type in ["value", "max", "min"]

    sum_dims = [d for d in range(x.ndim) if d != dim]
    if len(sum_dims) > 0:
        if stats_type == "max":
            # Reduce the highest dims first so the remaining dim indexes stay
            # valid.  (Use a fresh loop variable; the original shadowed the
            # `dim` argument here.)
            for d in reversed(sum_dims):
                x = torch.max(x, dim=d)[0]
        elif stats_type == "min":
            for d in reversed(sum_dims):
                x = torch.min(x, dim=d)[0]
        else:
            x = torch.sum(x, dim=sum_dims)
    x = x.flatten().clone()
    return x, count
|
|
|
|
@dataclass
class TensorAndCount:
    """A statistics tensor paired with the number of items that were
    accumulated into each of its elements (used to turn sums into means
    when printing)."""

    tensor: Tensor
    count: int
|
|
|
|
class TensorDiagnostic(object):
    """This class is not directly used by the user, it is responsible for
    collecting diagnostics for a module or parameter tensor of a torch.nn.Module.

    Args:
      opts:
        Options object.
      name:
        The name associated with this diagnostics object, will probably be {module_name}.X
        where X is "output" or "grad", or {parameter_name}.Y where Y is param_value or param_grad.
    """

    def __init__(self, opts: TensorDiagnosticOptions, name: str):
        self.opts = opts
        self.name = name
        # Class name of the module this diagnostic belongs to, filled in
        # lazily by accumulate().
        self.class_name = None

        # self.stats is created on the first accumulate() call.  It is a list
        # indexed by dim, of dicts from stats_type (e.g. "abs", "eigs") to
        # either a list of TensorAndCount (one per distinct stats shape
        # seen), or None, which is a sentinel meaning "gave up" (used for
        # "eigs" when the dim's size varied between calls).
        self.stats = None

        # Not used within this class; presumably kept for external
        # compatibility -- TODO confirm.
        self.scalar_stats = None

    def accumulate(self, x, class_name: Optional[str] = None):
        """
        Accumulate tensors.  `x` may be a Tensor or a 1-tuple containing one;
        non-Tensor or empty inputs are silently ignored.
        """
        if class_name is not None:
            self.class_name = class_name
        # Was isinstance(x, typing.Tuple); use the builtin type instead.
        if isinstance(x, tuple):
            x = x[0]
        if not isinstance(x, Tensor):
            return
        if x.numel() == 0:
            return
        x = x.detach().clone()
        if x.ndim == 0:
            # Treat a scalar as a 1-element 1-D tensor.
            x = x.unsqueeze(0)
        ndim = x.ndim
        if self.stats is None:
            self.stats = [dict() for _ in range(ndim)]

        for dim in range(ndim):
            this_dim_stats = self.stats[dim]
            if ndim > 1:
                stats_types = [
                    "abs",
                    "max",
                    "min",
                    "positive",
                    "value",
                    "rms",
                    "rms-sort",
                ]
                # Eigenvalue stats are expensive; only compute them for
                # small enough dims.
                if x.shape[dim] <= self.opts.max_eig_dim:
                    stats_types.append("eigs")
            else:
                stats_types = ["value", "abs", "max", "min"]

            for stats_type in stats_types:
                stats, count = get_tensor_stats(x, dim, stats_type)
                if stats_type not in this_dim_stats:
                    this_dim_stats[stats_type] = []  # list of TensorAndCount

                done = False
                if this_dim_stats[stats_type] is None:
                    # We previously gave up on this stats type for this dim
                    # (varying sizes for "eigs"); skip it.
                    continue
                for s in this_dim_stats[stats_type]:
                    if s.tensor.shape == stats.shape:
                        if stats_type == "max":
                            s.tensor = torch.maximum(s.tensor, stats)
                        elif stats_type == "min":
                            s.tensor = torch.minimum(s.tensor, stats)
                        else:
                            # (removed a dead `assert stats_type != "max"`
                            # that was trivially true in this branch)
                            s.tensor += stats
                        s.count += count
                        done = True
                        break
                if not done:
                    if this_dim_stats[stats_type] != [] and stats_type == "eigs":
                        # "eigs" stats of different sizes cannot usefully be
                        # combined or concatenated; give up on this type.
                        this_dim_stats[stats_type] = None
                    else:
                        this_dim_stats[stats_type].append(TensorAndCount(stats, count))

    def print_diagnostics(self):
        """Print diagnostics for each dimension of the tensor."""
        if self.stats is None:
            print(f"Warning: the stats of {self.name} is None.")
            return
        for dim, this_dim_stats in enumerate(self.stats):
            if "rms" in this_dim_stats and "value" in this_dim_stats:
                # Derive "stddev" stats from sum(x^2) and sum(x):
                # sum(x^2) - sum(x)^2/count == count * variance; the shared
                # count is divided out below before the sqrt is taken.
                rms_stats_list = this_dim_stats["rms"]
                value_stats_list = this_dim_stats["value"]
                if len(rms_stats_list) == len(value_stats_list):
                    stddev_stats_list = []
                    for r, v in zip(rms_stats_list, value_stats_list):
                        stddev_stats_list.append(
                            # r.count is the same as v.count here.
                            TensorAndCount(
                                r.tensor - v.tensor * v.tensor / (v.count + 1.0e-20),
                                r.count,
                            )
                        )
                    this_dim_stats["stddev"] = stddev_stats_list

            for stats_type, stats_list in this_dim_stats.items():
                # stats_list could be None if this dim's size varied between
                # calls and we gave up on "eigs" stats in accumulate().
                if stats_list is None:
                    assert stats_type == "eigs"
                    continue

                def get_count(count):
                    # "max"/"min" stats are element-wise extrema, not sums,
                    # so they must not be divided by the count.
                    return 1 if stats_type in ["max", "min"] else count

                if len(stats_list) == 1:
                    stats = stats_list[0].tensor / get_count(stats_list[0].count)
                else:
                    # The dim's size varied between calls: concatenate the
                    # per-size averages.
                    stats = torch.cat(
                        [x.tensor / get_count(x.count) for x in stats_list], dim=0
                    )

                if stats_type == "eigs":
                    # `stats` is the averaged x^T x; report sqrt of its
                    # eigenvalue magnitudes so values are on an rms-like
                    # scale.  Fall back through older torch APIs as needed.
                    try:
                        if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
                            eigs, _ = torch.linalg.eigh(stats)
                        else:
                            eigs, _ = torch.symeig(stats)  # older torch only
                        stats = eigs.abs().sqrt()
                    except Exception:  # was a bare except
                        print("Error getting eigenvalues, trying another method.")
                        if hasattr(torch, "linalg") and hasattr(torch.linalg, "eig"):
                            eigs, _ = torch.linalg.eig(stats)
                            eigs = eigs.abs()
                        else:
                            # torch.eig returns (n, 2) real/imag pairs.
                            eigs, _ = torch.eig(stats)
                            eigs = eigs.norm(dim=1)
                        stats = eigs.sqrt()

                if stats_type in ["rms", "stddev"]:
                    # We stored the x^2 (or variance-like) quantity; take
                    # sqrt to get rms / stddev.
                    stats = stats.sqrt()

                # If `summarize`, print percentiles of the stats instead of
                # the individual values.
                summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized(
                    stats.numel()
                )
                if summarize:
                    # 11 evenly spaced percentiles of the sorted stats.
                    stats = stats.sort()[0]
                    num_percentiles = 10
                    size = stats.numel()
                    percentiles = []
                    for i in range(num_percentiles + 1):
                        index = (i * (size - 1)) // num_percentiles
                        percentiles.append(stats[index].item())
                    percentiles = ["%.2g" % x for x in percentiles]
                    percentiles = " ".join(percentiles)
                    ans = f"percentiles: [{percentiles}]"
                else:
                    ans = stats.tolist()
                    ans = ["%.2g" % x for x in ans]
                    ans = "[" + " ".join(ans) + "]"
                if stats_type in ["value", "rms", "stddev", "eigs"]:
                    # Also print the Euclidean norm of the printed stats for
                    # these magnitude-like types.
                    norm = (stats**2).sum().sqrt().item()
                    ans += f", norm={norm:.2g}"
                mean = stats.mean().item()
                rms = (stats**2).mean().sqrt().item()
                ans += f", mean={mean:.3g}, rms={rms:.3g}"

                # Sizes seen for this dim, e.g. "50" or "200..300".
                sizes = [x.tensor.shape[0] for x in stats_list]
                size_str = (
                    f"{sizes[0]}" if len(sizes) == 1 else f"{min(sizes)}..{max(sizes)}"
                )
                maybe_class_name = (
                    f" type={self.class_name}," if self.class_name is not None else ""
                )
                print(
                    f"module={self.name},{maybe_class_name} dim={dim}, size={size_str}, {stats_type} {ans}"
                )
|
|
|
|
class ScalarDiagnostic(object):
    """This class is not directly used by the user, it is responsible for
    collecting diagnostics for a single module (subclass of torch.nn.Module) that
    represents some kind of nonlinearity, e.g. ReLU, sigmoid, etc.

    It histograms the element-wise (input value, output grad) pairs of the
    nonlinearity into 512 buckets of input value and prints per-bin gradient
    statistics.
    """

    def __init__(self, opts: TensorDiagnosticOptions, name: str):
        self.opts = opts
        self.name = name
        self.class_name = None
        # True while we are in a forward pass (i.e. have seen no backward
        # call since the last forward call); used to detect pass boundaries.
        self.is_forward_pass = True

        # Bucket width for the input-value histogram; set from the 98th
        # percentile of |x| on the first forward pass.
        self.tick_scale = None

        # Inputs saved by the forward hook, popped by the backward hook.
        self.saved_inputs = []
        # Set to False (permanently) if something inconsistent is detected.
        self.is_ok = True

        # Histogram accumulators, created lazily in process_input_and_grad().
        self.counts = None
        self.sum_grad = None
        self.sum_gradsq = None
        self.sum_abs_grad = None

    def accumulate_input(self, x: Tensor, class_name: Optional[str] = None):
        """
        Called in forward pass: saves the input of the nonlinearity so that
        accumulate_output_grad() can later pair it with the output gradient.
        """
        if not self.is_forward_pass:
            # A new forward pass has started: drop any inputs the previous
            # backward pass did not consume.
            self.saved_inputs = []
            self.is_forward_pass = True

        if class_name is not None:
            self.class_name = class_name
        if not self.is_ok:
            return

        # Guard against unbounded growth if backward is never called.
        limit = 10
        if len(self.saved_inputs) > limit:
            print(
                f"ERROR: forward pass called for this module over {limit} times with no backward pass. "
                f" Will not accumulate scalar stats."
            )
            self.is_ok = False
            return
        self.saved_inputs.append(x)

    def accumulate_output_grad(self, grad: Tensor):
        """Called in backward pass with the grad w.r.t. the module output;
        pairs it with the most recently saved input."""
        if not self.is_ok:
            return
        if self.is_forward_pass:
            self.is_forward_pass = False

        last_shape = (
            "n/a" if len(self.saved_inputs) == 0 else self.saved_inputs[-1].shape
        )
        if len(self.saved_inputs) == 0 or grad.shape != last_shape:
            print(
                f"ERROR: shape mismatch or no forward activation present when backward "
                f"pass called: grad shape ={tuple(grad.shape)}, num-saved-inputs={len(self.saved_inputs)}"
                f", shape-of-last-saved-input={last_shape}"
            )
            self.is_ok = False
            return

        x = self.saved_inputs.pop()
        self.process_input_and_grad(x, grad)

    def process_input_and_grad(self, x: Tensor, grad: Tensor):
        """Histogram the flattened (x, grad) pairs by the value of x."""
        assert x.shape == grad.shape
        x = x.flatten()
        grad = grad.flatten()

        num_ticks_per_side = 256

        if self.tick_scale is None:
            x_abs_sorted = x.abs().sort()[0]
            # Choose the scale so the 98th percentile of |x| maps to the
            # outermost bucket of the histogram.
            index = int(x.numel() * 0.98)
            self.tick_scale = float(x_abs_sorted[index] / num_ticks_per_side)

            # Allocate accumulators now that we know the device.
            self.counts = torch.zeros(
                2 * num_ticks_per_side, dtype=torch.long, device=x.device
            )
            self.sum_grad = torch.zeros(
                2 * num_ticks_per_side, dtype=torch.double, device=x.device
            )
            # `sum_gradsq` is for getting error bars on the gradient.
            self.sum_gradsq = torch.zeros(
                2 * num_ticks_per_side, dtype=torch.double, device=x.device
            )
            self.sum_abs_grad = torch.zeros(
                2 * num_ticks_per_side, dtype=torch.double, device=x.device
            )

        # Map x values to bucket indexes in [0, 2 * num_ticks_per_side).
        x = (x / self.tick_scale).to(torch.long)
        x = x.clamp_(min=-num_ticks_per_side, max=num_ticks_per_side - 1)
        x = x + num_ticks_per_side

        self.counts.index_add_(dim=0, index=x, source=torch.ones_like(x))
        self.sum_grad.index_add_(dim=0, index=x, source=grad.to(torch.double))
        self.sum_gradsq.index_add_(
            dim=0, index=x, source=(grad * grad).to(torch.double)
        )
        self.sum_abs_grad.index_add_(dim=0, index=x, source=grad.abs().to(torch.double))

    def print_diagnostics(self):
        """Print diagnostics."""
        if self.is_ok is False or self.counts is None:
            print(f"Warning: no stats accumulated for {self.name}, is_ok={self.is_ok}")
            return

        counts = self.counts.to("cpu")
        sum_grad = self.sum_grad.to(device="cpu", dtype=torch.float32)
        sum_gradsq = self.sum_gradsq.to(device="cpu", dtype=torch.float32)
        sum_abs_grad = self.sum_abs_grad.to(device="cpu", dtype=torch.float32)

        counts_cumsum = counts.cumsum(dim=0)
        counts_tot = counts_cumsum[-1]

        # Merge the 512 ticks into `num_bins` bins holding roughly equal
        # numbers of samples.
        num_bins = 20

        # The +1 guarantees bin indexes stay strictly below num_bins.
        counts_per_bin = (counts_tot // num_bins) + 1
        bin_indexes = counts_cumsum // counts_per_bin
        bin_indexes = bin_indexes.clamp(min=0, max=num_bins).to(torch.long)

        bin_counts = torch.zeros(num_bins, dtype=torch.long)
        bin_counts.index_add_(dim=0, index=bin_indexes, source=counts)
        bin_grad = torch.zeros(num_bins)
        bin_grad.index_add_(dim=0, index=bin_indexes, source=sum_grad)
        bin_gradsq = torch.zeros(num_bins)
        bin_gradsq.index_add_(dim=0, index=bin_indexes, source=sum_gradsq)
        bin_abs_grad = torch.zeros(num_bins)
        bin_abs_grad.index_add_(dim=0, index=bin_indexes, source=sum_abs_grad)

        # (Removed two dead locals, avg_grad and avg_grad_stddev, that were
        # computed here but never printed.)

        bin_boundary_counts = (
            torch.arange(num_bins + 1, dtype=torch.long) * counts_per_bin
        )
        bin_tick_indexes = torch.searchsorted(counts_cumsum, bin_boundary_counts)
        # Bin boundaries expressed in terms of the original input values of x.
        num_ticks_per_side = counts.numel() // 2
        bin_boundaries = (bin_tick_indexes - num_ticks_per_side) * self.tick_scale

        # The +1 in the denominators avoids division by zero for empty bins.
        bin_grad = bin_grad / (bin_counts + 1)
        bin_conf_interval = bin_gradsq.sqrt() / (
            bin_counts + 1
        )
        bin_abs_grad = bin_abs_grad / (bin_counts + 1)

        # rel_grad: mean grad relative to mean |grad|; grad_conf: mean grad
        # relative to its rough error bar.
        bin_rel_grad = bin_grad / (bin_abs_grad + 1.0e-20)
        bin_conf = bin_grad / (bin_conf_interval + 1.0e-20)

        def tensor_to_str(x: Tensor):
            x = ["%.2g" % f for f in x]
            x = "[" + " ".join(x) + "]"
            return x

        maybe_class_name = (
            f" type={self.class_name}," if self.class_name is not None else ""
        )

        print(
            f"module={self.name},{maybe_class_name} bin-boundaries={tensor_to_str(bin_boundaries)}, "
            f"rel_grad={tensor_to_str(bin_rel_grad)}, grad_conf={tensor_to_str(bin_conf)}"
        )
|
|
|
|
class ModelDiagnostic:
    """This class stores diagnostics for all tensors in the torch.nn.Module.

    Args:
      opts:
        Options object; a default-constructed TensorDiagnosticOptions is
        used when none is supplied.
    """

    def __init__(self, opts: Optional[TensorDiagnosticOptions] = None):
        self.opts = TensorDiagnosticOptions() if opts is None else opts
        # Maps diagnostic name -> TensorDiagnostic or ScalarDiagnostic,
        # created lazily on first access.
        self.diagnostics = dict()

    def __getitem__(self, name: str):
        """Return (creating on first use) the diagnostic object for `name`;
        names ending in ".scalar" get a ScalarDiagnostic."""
        if name not in self.diagnostics:
            diag_class = (
                ScalarDiagnostic if name[-7:] == ".scalar" else TensorDiagnostic
            )
            self.diagnostics[name] = diag_class(self.opts, name)
        return self.diagnostics[name]

    def print_diagnostics(self):
        """Print diagnostics for each tensor, in sorted name order."""
        for key in sorted(self.diagnostics):
            self.diagnostics[key].print_diagnostics()
|
|
|
|
def get_class_name(module: nn.Module):
    """Return the class name of `module`, with the configuration values of
    certain known module types appended for more informative logs."""
    ans = type(module).__name__
    # Best-effort: attribute sets differ between versions of these classes,
    # so missing attributes just leave the plain class name.
    if ans == "Balancer" or ans == "ActivationBalancer":
        try:
            ans += f"[{float(module.min_positive)},{float(module.max_positive)},{float(module.min_abs)},{float(module.max_abs)}]"
        except Exception:  # was a bare except; keep KeyboardInterrupt etc. raisable
            pass
    elif ans == "AbsValuePenalizer":
        try:
            ans += f"[{module.limit}]"
        except Exception:  # was a bare except
            pass
    return ans
|
|
|
|
def attach_diagnostics(
    model: nn.Module, opts: Optional[TensorDiagnosticOptions] = None
) -> ModelDiagnostic:
    """Attach a ModelDiagnostic object to the model by
    1) registering forward hook and backward hook on each module, to accumulate
    its output tensors and gradient tensors, respectively;
    2) registering backward hook on each module parameter, to accumulate its
    values and gradients.

    Args:
      model:
        the model to be analyzed.
      opts:
        Options object.

    Returns:
      The ModelDiagnostic object attached to the model.
    """

    ans = ModelDiagnostic(opts)
    for name, module in model.named_modules():
        if name == "":
            name = "<top-level>"

        # `ans` and `name` are bound via keyword defaults so each closure
        # keeps its own values (avoids the late-binding-closure pitfall).
        def forward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
            if isinstance(_output, tuple) and len(_output) == 1:
                _output = _output[0]

            # Only accumulate floating-point outputs.
            if isinstance(_output, Tensor) and _output.dtype in (
                torch.float32,
                torch.float16,
                torch.float64,
            ):
                _model_diagnostic[f"{_name}.output"].accumulate(
                    _output, class_name=get_class_name(_module)
                )
            elif isinstance(_output, tuple):
                for i, o in enumerate(_output):
                    if isinstance(o, Tensor) and o.dtype in (
                        torch.float32,
                        torch.float16,
                        torch.float64,
                    ):
                        _model_diagnostic[f"{_name}.output[{i}]"].accumulate(
                            o, class_name=get_class_name(_module)
                        )

        def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
            if isinstance(_output, tuple) and len(_output) == 1:
                _output = _output[0]
            if isinstance(_output, Tensor) and _output.dtype in (
                torch.float32,
                torch.float16,
                torch.float64,
            ):
                _model_diagnostic[f"{_name}.grad"].accumulate(
                    _output, class_name=get_class_name(_module)
                )
            elif isinstance(_output, tuple):
                for i, o in enumerate(_output):
                    if isinstance(o, Tensor) and o.dtype in (
                        torch.float32,
                        torch.float16,
                        torch.float64,
                    ):
                        _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(
                            o, class_name=get_class_name(_module)
                        )

        module.register_forward_hook(forward_hook)
        # register_full_backward_hook only exists in newer torch versions.
        if hasattr(module, "register_full_backward_hook"):
            module.register_full_backward_hook(backward_hook)
        else:
            module.register_backward_hook(backward_hook)

        if type(module).__name__ in [
            "Sigmoid",
            "Tanh",
            "ReLU",
            "TanSwish",
            "Swish",
            "DoubleSwish",
            "Swoosh",
        ]:
            # For these element-wise nonlinearities, additionally collect
            # (input value, output grad) histogram stats via a
            # ScalarDiagnostic (name ending ".scalar" selects that class in
            # ModelDiagnostic.__getitem__).
            def scalar_forward_hook(
                _module, _input, _output, _model_diagnostic=ans, _name=name
            ):
                if isinstance(_input, tuple):
                    # These modules are expected to take exactly one input.
                    (_input,) = _input
                assert isinstance(_input, Tensor)
                _model_diagnostic[f"{_name}.scalar"].accumulate_input(
                    _input, class_name=get_class_name(_module)
                )

            def scalar_backward_hook(
                _module, _input, _output, _model_diagnostic=ans, _name=name
            ):
                if isinstance(_output, tuple):
                    (_output,) = _output
                assert isinstance(_output, Tensor)
                _model_diagnostic[f"{_name}.scalar"].accumulate_output_grad(_output)

            module.register_forward_hook(scalar_forward_hook)
            if hasattr(module, "register_full_backward_hook"):
                module.register_full_backward_hook(scalar_backward_hook)
            else:
                module.register_backward_hook(scalar_backward_hook)

    for name, parameter in model.named_parameters():

        def param_backward_hook(
            grad, _parameter=parameter, _model_diagnostic=ans, _name=name
        ):
            _model_diagnostic[f"{_name}.param_value"].accumulate(_parameter)
            _model_diagnostic[f"{_name}.param_grad"].accumulate(grad)

        try:
            parameter.register_hook(param_backward_hook)
        except Exception:  # was a bare except
            logging.warning(
                f"Warning: could not register backward hook for parameter {name}, "
                f"it might not be differentiable."
            )

    return ans
|
|
|
|
def _test_tensor_diagnostic():
    """Smoke test: accumulate random tensors into a TensorDiagnostic
    directly, then via hooks attached to a small model, printing the
    diagnostics both times."""
    opts = TensorDiagnosticOptions(512)

    direct = TensorDiagnostic(opts, "foo")
    for _ in range(10):
        direct.accumulate(torch.randn(50, 100) * 10.0)
    direct.print_diagnostics()

    model = nn.Sequential(nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 80))
    hooked = attach_diagnostics(model, opts)
    for _ in range(10):
        num_frames = random.randint(200, 300)
        inputs = torch.randn(num_frames, 100)
        model(inputs).sum().backward()
    hooked.print_diagnostics()
|
|
|
|
# Run the smoke test when this module is executed as a script.
if __name__ == "__main__":
    _test_tensor_diagnostic()
|
|