|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from logging import getLogger |
|
import math |
|
import os |
|
from typing import Dict, List, Optional, Union, Tuple |
|
from types import MethodType |
|
|
|
import torch |
|
from torch import nn |
|
from torch.nn import functional as F |
|
from torch.nn.utils import parametrize |
|
|
|
|
|
|
|
class DAMP(nn.Identity):
    """Weight parametrization that carries a DAMP ``std`` value.

    Instances are registered on the ``weight`` tensor of ``nn.Linear`` layers
    through ``torch.nn.utils.parametrize``. ``forward`` is inherited from
    ``nn.Identity``, so the parametrized weight is returned unchanged.

    NOTE(review): ``std`` is only stored here; nothing in this class reads it.
    Presumably other code uses it (e.g. to perturb weights) — confirm.

    Args:
        std: Value exposed as ``self.std`` (presumably a noise standard
            deviation — confirm against its consumer).
    """

    def __init__(self, std: float):
        super().__init__()
        self.std = std

    def extra_repr(self) -> str:
        # Without this, print(model) shows "DAMP()" and hides the configured std.
        return f'std={self.std}'
|
|
|
|
|
def enable_damp(model: nn.Module, std: float):
    """Register a ``DAMP`` parametrization on the weight of every ``nn.Linear``.

    Args:
        model: Module to scan. Every ``nn.Linear`` reachable from it (including
            ``model`` itself if it is one) gets a ``DAMP(std)`` parametrization
            on its ``weight``. A list or tuple of modules is also accepted; each
            element is processed recursively.
        std: Value passed to each newly created ``DAMP`` instance.

    Linear layers whose ``weight`` already carries a ``DAMP`` parametrization
    are skipped. This keeps repeated calls, or a module reachable from several
    list entries, from stacking a second ``DAMP`` on the same weight.
    """
    if isinstance(model, (list, tuple)):
        for m in model:
            enable_damp(m, std)
        return

    # Snapshot the targets before mutating anything. register_parametrization
    # adds a `parametrizations` child to each layer and swaps its class, so we
    # avoid doing that while a lazy modules() walk is still in progress.
    linears = [m for m in model.modules() if isinstance(m, nn.Linear)]

    for module in linears:
        # A parametrized Linear is still an nn.Linear, so without this check a
        # second call would register another DAMP on top of the first.
        already_damped = parametrize.is_parametrized(module, 'weight') and any(
            isinstance(p, DAMP) for p in module.parametrizations['weight']
        )
        if already_damped:
            continue
        parametrize.register_parametrization(module, 'weight', DAMP(std))
|
|
|
|
|
def configure_damp_from_args(model: nn.Module, args):
    """Enable DAMP on ``model`` when ``args`` asks for it.

    Reads ``args.damp``, treating a missing attribute as ``None``. A truthy
    value is forwarded to ``enable_damp`` as the ``std``. A falsy value
    (missing, ``None``, ``0``, ``0.0``) leaves ``model`` untouched.

    Args:
        model: Module (or list/tuple of modules) to configure.
        args: Any object; only its optional ``damp`` attribute is read.
    """
    requested_std = getattr(args, 'damp', None)
    if not requested_std:
        # DAMP not requested (or explicitly zero): nothing to do.
        return
    enable_damp(model, requested_std)
|
|