import torch
import torch.nn
from typing import Dict, Any, Optional
| |
|
| |
|
class LayerWithVisualization(torch.nn.Module):
    """Base class for layers that can produce visualizations on demand.

    Subclasses override ``plot`` (and optionally ``prepare``). The
    ``visualization_enabled`` flag starts out False; external code toggles it.
    """

    def __init__(self):
        super().__init__()
        # Off by default; switched on/off by whoever drives visualization.
        self.visualization_enabled = False

    def prepare(self):
        """Hook invoked before a visualization pass; the default does nothing."""

    def plot(self, options: Dict[str, Any]) -> Dict[str, Any]:
        """Return this layer's plots as a dict; subclasses must implement it."""
        raise NotImplementedError
| |
|
| |
|
class LayerVisualizer:
    """Drive every ``LayerWithVisualization`` found inside a module tree.

    Usage is a two-step cycle: ``prepare()`` merges options and enables
    visualization on all collected layers; ``plot()`` gathers their plots and
    disables visualization again.
    """

    def __init__(self, module: torch.nn.Module, options: Optional[Dict[str, Any]] = None):
        """
        Args:
            module: root module; every entry yielded by
                ``module.named_modules()`` that is a ``LayerWithVisualization``
                is collected.
            options: base plot options used by every ``prepare()`` call.
                Defaults to a fresh empty dict. A passed dict is stored by
                reference (not copied).
        """
        # Fresh dict per instance: a ``= {}`` default would be one object
        # shared by every LayerVisualizer constructed without options.
        self.options = {} if options is None else options
        # (name, layer) pairs, in named_modules() order.
        self.modules = [
            (name, layer)
            for name, layer in module.named_modules()
            if isinstance(layer, LayerWithVisualization)
        ]
        # Merged options for the current prepare()/plot() cycle; None when idle.
        self.curr_options = None

    def plot(self) -> Dict[str, Any]:
        """Collect plots from every layer and end the visualization cycle.

        Each layer's ``plot()`` receives the options merged by ``prepare()``
        (``None`` if ``prepare()`` was not called first). Result keys are
        namespaced as ``"<module name>/<plot key>"``.

        Visualization is disabled on all layers and ``curr_options`` is
        cleared even if a layer's ``plot()`` raises; the exception propagates.

        Returns:
            Dict mapping namespaced keys to the values the layers returned.
        """
        res = {}
        try:
            for name, layer in self.modules:
                res.update({f"{name}/{k}": v for k, v in layer.plot(self.curr_options).items()})
                layer.visualization_enabled = False
        finally:
            # Without this, an exception in one layer would leave the remaining
            # layers (and curr_options) stuck in the "enabled" state.
            for _, layer in self.modules:
                layer.visualization_enabled = False
            self.curr_options = None
        return res

    def prepare(self, options: Optional[Dict[str, Any]] = None):
        """Start a visualization cycle.

        Args:
            options: per-cycle overrides, shallow-merged on top of the base
                options. The base ``self.options`` dict is not modified.
        """
        # Shallow copy so per-cycle overrides never leak into the base options.
        self.curr_options = self.options.copy()
        if options is not None:
            self.curr_options.update(options)

        for _, layer in self.modules:
            layer.prepare()
            layer.visualization_enabled = True
| |
|