| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from logging import getLogger |
| | import math |
| | import os |
| | from typing import Dict, List, Optional, Union, Tuple |
| | from types import MethodType |
| |
|
| | import torch |
| | from torch import nn |
| | from torch.nn import functional as F |
| | from torch.nn.utils import parametrize |
| |
|
| |
|
| | |
class DAMP(nn.Identity):
    """Parametrization module that carries a ``std`` value for DAMP.

    It is instantiated by :func:`enable_damp` and registered on the ``weight``
    of ``nn.Linear`` layers. ``forward`` is inherited unchanged from
    ``nn.Identity``, so the parametrized weight equals the raw weight.

    NOTE(review): ``std`` is stored but not read anywhere in this class; as
    written the parametrization leaves weights unchanged. Confirm whether a
    noise-injecting ``forward`` is expected to be supplied elsewhere.
    """

    def __init__(self, std: float):
        super(DAMP, self).__init__()
        # Kept as a public attribute so callers can inspect the configured value.
        self.std = std
| |
|
| |
|
def enable_damp(
    model: Union[nn.Module, List[nn.Module], Tuple[nn.Module, ...]],
    std: float,
):
    """Register a ``DAMP`` parametrization on the weight of every ``nn.Linear``.

    Args:
        model: A module, or a list/tuple of modules (each handled recursively).
        std: Value handed to each newly created ``DAMP`` parametrization.

    Weights that already carry a ``DAMP`` parametrization are skipped, so
    calling this repeatedly on the same model does not stack parametrizations.
    """
    if isinstance(model, (list, tuple)):
        for m in model:
            enable_damp(m, std)
        return

    # Snapshot the targets before mutating: register_parametrization inserts a
    # ``parametrizations`` submodule into each layer, and we must not walk the
    # module tree while it is being modified.
    linears = [m for m in model.modules() if isinstance(m, nn.Linear)]

    for linear in linears:
        already_damped = parametrize.is_parametrized(linear, 'weight') and any(
            isinstance(p, DAMP) for p in linear.parametrizations['weight']
        )
        if already_damped:
            continue
        parametrize.register_parametrization(linear, 'weight', DAMP(std))
| |
|
| |
|
def configure_damp_from_args(model: nn.Module, args):
    """Enable DAMP on ``model`` if ``args.damp`` is set to a truthy value.

    A missing ``damp`` attribute, ``None``, or ``0`` leaves the model untouched.
    Otherwise the value is forwarded to ``enable_damp`` as ``std``.
    """
    std = getattr(args, 'damp', None)
    if not std:
        return
    enable_damp(model, std)
| |
|