unpairedelectron07 committed on
Commit 8788873
Parent: 652ff96

Upload 7 files

audiocraft/optim/cosine_lr_scheduler.py ADDED
@@ -0,0 +1,48 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import math
+
+ from torch.optim import Optimizer
+ from torch.optim.lr_scheduler import _LRScheduler
+
+
+ class CosineLRScheduler(_LRScheduler):
+     """Cosine LR scheduler.
+
+     Args:
+         optimizer (Optimizer): Torch optimizer.
+         warmup_steps (int): Number of warmup steps.
+         total_steps (int): Total number of steps.
+         lr_min_ratio (float): Minimum learning rate, as a ratio of the base learning rate.
+         cycle_length (float): Cycle length.
+     """
+     def __init__(self, optimizer: Optimizer, total_steps: int, warmup_steps: int,
+                  lr_min_ratio: float = 0.0, cycle_length: float = 1.0):
+         self.warmup_steps = warmup_steps
+         assert self.warmup_steps >= 0
+         self.total_steps = total_steps
+         assert self.total_steps >= 0
+         self.lr_min_ratio = lr_min_ratio
+         self.cycle_length = cycle_length
+         super().__init__(optimizer)
+
+     def _get_sched_lr(self, lr: float, step: int):
+         if step < self.warmup_steps:
+             lr_ratio = step / self.warmup_steps
+             lr = lr_ratio * lr
+         elif step <= self.total_steps:
+             s = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
+             lr_ratio = self.lr_min_ratio + 0.5 * (1 - self.lr_min_ratio) * \
+                 (1. + math.cos(math.pi * s / self.cycle_length))
+             lr = lr_ratio * lr
+         else:
+             lr_ratio = self.lr_min_ratio
+             lr = lr_ratio * lr
+         return lr
+
+     def get_lr(self):
+         return [self._get_sched_lr(lr, self.last_epoch) for lr in self.base_lrs]
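
A minimal usage sketch for the scheduler above, assuming the audiocraft package is importable and using a toy model and purely illustrative step counts; the scheduler is advanced once per optimizer step and last_epoch serves as the step index:

import torch
from audiocraft.optim.cosine_lr_scheduler import CosineLRScheduler

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# 100 linear warmup steps, then cosine decay towards 10% of the base LR by step 1000.
scheduler = CosineLRScheduler(optimizer, total_steps=1000, warmup_steps=100, lr_min_ratio=0.1)

for step in range(1000):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()
    scheduler.step()  # advances last_epoch, which _get_sched_lr uses as the step index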
audiocraft/optim/dadam.py ADDED
@@ -0,0 +1,248 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import logging
+ from typing import Any
+
+ import torch
+ import torch.optim
+ import torch.distributed as dist
+
+
+ logger = logging.getLogger(__name__)
+ _params_t = Any
+
+
+ def to_real(x):
+     if torch.is_complex(x):
+         return x.real
+     else:
+         return x
+
+
+ class DAdaptAdam(torch.optim.Optimizer):
+     """Adam with D-Adaptation automatic step-sizes.
+     Leave LR set to 1 unless you encounter instability.
+
+     Args:
+         params (iterable):
+             Iterable of parameters to optimize or dicts defining parameter groups.
+         lr (float):
+             Learning rate adjustment parameter. Increases or decreases the D-adapted learning rate.
+         betas (tuple[float, float], optional): coefficients used for computing
+             running averages of gradient and its square (default: (0.9, 0.999)).
+         momentum (float):
+             Momentum value in the range [0,1) (default: 0.9).
+         eps (float):
+             Term added to the denominator outside of the root operation to improve numerical stability (default: 1e-8).
+         weight_decay (float):
+             Weight decay, i.e. a L2 penalty (default: 0).
+         log_every (int):
+             Log every k steps, default 0 (no logging).
+         decouple (boolean):
+             Use AdamW style decoupled weight decay.
+         d0 (float):
+             Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
+         growth_rate (float):
+             Prevent the D estimate from growing faster than this multiplicative rate.
+             Default is inf, for unrestricted. Values like 1.02 give a kind of learning
+             rate warmup effect.
+         fsdp_in_use (bool):
+             If you're using sharded parameters, this should be set to True. The optimizer
+             will attempt to auto-detect this, but if you're using an implementation other
+             than PyTorch's builtin version, the auto-detection won't work.
+     """
+     def __init__(self, params, lr=1.0,
+                  betas=(0.9, 0.999),
+                  eps=1e-8,
+                  weight_decay=0,
+                  log_every=0,
+                  decouple=True,
+                  d0=1e-6,
+                  growth_rate=float('inf')):
+         if not 0.0 < d0:
+             raise ValueError("Invalid d0 value: {}".format(d0))
+         if not 0.0 < lr:
+             raise ValueError("Invalid learning rate: {}".format(lr))
+         if not 0.0 < eps:
+             raise ValueError("Invalid epsilon value: {}".format(eps))
+         if not 0.0 <= betas[0] < 1.0:
+             raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+         if not 0.0 <= betas[1] < 1.0:
+             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+
+         if decouple:
+             logger.info("Using decoupled weight decay")
+
+         from .fsdp import is_fsdp_used
+         fsdp_in_use = is_fsdp_used()
+         defaults = dict(lr=lr, betas=betas, eps=eps,
+                         weight_decay=weight_decay,
+                         d=d0,
+                         k=0,
+                         gsq_weighted=0.0,
+                         log_every=log_every,
+                         decouple=decouple,
+                         growth_rate=growth_rate,
+                         fsdp_in_use=fsdp_in_use)
+
+         super().__init__(params, defaults)
+
+     @property
+     def supports_memory_efficient_fp16(self):
+         return False
+
+     @property
+     def supports_flat_params(self):
+         return True
+
+     def step(self, closure=None):
+         """Performs a single optimization step.
+
+         Args:
+             closure (callable, optional): A closure that reevaluates the model
+                 and returns the loss.
+         """
+         loss = None
+         if closure is not None:
+             loss = closure()
+
+         g_sq = 0.0
+         sksq_weighted = 0.0
+         sk_l1 = 0.0
+
+         lr = max(group['lr'] for group in self.param_groups)
+
+         group = self.param_groups[0]
+         gsq_weighted = group['gsq_weighted']
+         d = group['d']
+         dlr = d*lr
+
+         growth_rate = group['growth_rate']
+         decouple = group['decouple']
+         fsdp_in_use = group['fsdp_in_use']
+         log_every = group['log_every']
+
+         beta1, beta2 = group['betas']
+
+         for group in self.param_groups:
+             group_lr = group['lr']
+             decay = group['weight_decay']
+             k = group['k']
+             eps = group['eps']
+
+             if group_lr not in [lr, 0.0]:
+                 raise RuntimeError("Setting different lr values in different parameter "
+                                    "groups is only supported for values of 0")
+
+             for p in group['params']:
+                 if p.grad is None:
+                     continue
+                 if hasattr(p, "_fsdp_flattened"):
+                     fsdp_in_use = True
+                 grad = p.grad.data
+
+                 # Apply weight decay (coupled variant)
+                 if decay != 0 and not decouple:
+                     grad.add_(p.data, alpha=decay)
+
+                 state = self.state[p]
+
+                 # State initialization
+                 if 'step' not in state:
+                     state['step'] = 0
+                     state['s'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach()
+                     # Exponential moving average of gradient values
+                     state['exp_avg'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach()
+                     # Exponential moving average of squared gradient values
+                     state['exp_avg_sq'] = torch.zeros_like(
+                         to_real(p.data), memory_format=torch.preserve_format).detach()
+
+                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+
+                 grad_grad = to_real(grad * grad.conj())
+
+                 # Adam EMA updates
+                 if group_lr > 0:
+                     exp_avg.mul_(beta1).add_(grad, alpha=dlr*(1-beta1))
+                     exp_avg_sq.mul_(beta2).add_(grad_grad, alpha=1-beta2)
+
+                     denom = exp_avg_sq.sqrt().add_(eps)
+
+                     g_sq += grad_grad.div_(denom).sum().item()
+
+                     s = state['s']
+                     s.mul_(beta2).add_(grad, alpha=dlr*(1-beta2))
+                     sksq_weighted += to_real(s * s.conj()).div_(denom).sum().item()
+                     sk_l1 += s.abs().sum().item()
+
+         ######
+
+         gsq_weighted = beta2*gsq_weighted + g_sq*(dlr**2)*(1-beta2)
+         d_hat = d
+
+         # If we have not made any progress, return.
+         # If we have any gradients available, we will have sk_l1 > 0 (unless \|g\|=0).
+         if sk_l1 == 0:
+             return loss
+
+         if lr > 0.0:
+             if fsdp_in_use:
+                 dist_tensor = torch.zeros(3, device='cuda')
+                 dist_tensor[0] = sksq_weighted
+                 dist_tensor[1] = gsq_weighted
+                 dist_tensor[2] = sk_l1
+                 dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
+                 global_sksq_weighted = dist_tensor[0]
+                 global_gsq_weighted = dist_tensor[1]
+                 global_sk_l1 = dist_tensor[2]
+             else:
+                 global_sksq_weighted = sksq_weighted
+                 global_gsq_weighted = gsq_weighted
+                 global_sk_l1 = sk_l1
+
+             d_hat = (global_sksq_weighted/(1-beta2) - global_gsq_weighted)/global_sk_l1
+             d = max(d, min(d_hat, d*growth_rate))
+
+             if log_every > 0 and k % log_every == 0:
+                 logger.info(
+                     f"(k={k}) dlr: {dlr:1.1e} d_hat: {d_hat:1.1e}, d: {d:1.8}. "
+                     f"sksq_weighted={global_sksq_weighted:1.1e} gsq_weighted={global_gsq_weighted:1.1e} "
+                     f"sk_l1={global_sk_l1:1.1e}{' (FSDP)' if fsdp_in_use else ''}")
+
+         for group in self.param_groups:
+             group['gsq_weighted'] = gsq_weighted
+             group['d'] = d
+
+             group_lr = group['lr']
+             decay = group['weight_decay']
+             k = group['k']
+             eps = group['eps']
+
+             for p in group['params']:
+                 if p.grad is None:
+                     continue
+                 grad = p.grad.data
+
+                 state = self.state[p]
+
+                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+
+                 state['step'] += 1
+
+                 denom = exp_avg_sq.sqrt().add_(eps)
+                 denom = denom.type(p.type())
+
+                 # Apply weight decay (decoupled variant)
+                 if decay != 0 and decouple and group_lr > 0:
+                     p.data.add_(p.data, alpha=-decay * dlr)
+
+                 # Take step
+                 p.data.addcdiv_(exp_avg, denom, value=-1)
+
+             group['k'] = k + 1
+
+         return loss
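
A minimal usage sketch for DAdaptAdam as defined above; the toy model and data are illustrative, and per the docstring the learning rate is left at 1.0 so that D-Adaptation picks the effective step size:

import torch
from audiocraft.optim.dadam import DAdaptAdam

model = torch.nn.Linear(16, 1)
# lr stays at 1.0; log_every=100 reports the evolving D estimate via the module logger.
optimizer = DAdaptAdam(model.parameters(), lr=1.0, weight_decay=0.01, log_every=100)

for _ in range(200):
    optimizer.zero_grad()
    x = torch.randn(32, 16)
    loss = (model(x) - x.sum(dim=1, keepdim=True)).pow(2).mean()
    loss.backward()
    optimizer.step()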
audiocraft/optim/ema.py ADDED
@@ -0,0 +1,85 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # ModelEMA implementation is taken from
+ # https://github.com/facebookresearch/demucs
+
+ from collections import defaultdict
+ import typing as tp
+
+ import torch
+ import torch.nn as nn
+
+
+ def _get_all_non_persistent_buffers_set(module: nn.Module, root: str = "") -> set:
+     names: set = set()
+     for (name, sub_module) in module.named_modules():
+         if name == '':
+             buffer_names = module._non_persistent_buffers_set
+             buffer_names = {f"{root}.{buff_name}" if len(root) > 0 else buff_name
+                             for buff_name in buffer_names}
+             names.update(buffer_names)
+         else:
+             sub_name = f"{root}.{name}" if len(root) > 0 else name
+             sub_buffer_names = _get_all_non_persistent_buffers_set(sub_module, sub_name)
+             names.update(sub_buffer_names)
+     return names
+
+
+ def _get_named_tensors(module: nn.Module):
+     non_persistent_buffers_set = _get_all_non_persistent_buffers_set(module)
+     named_buffers = [(name, buffer) for (name, buffer) in module.named_buffers()
+                      if name not in non_persistent_buffers_set]
+     named_parameters = list(module.named_parameters())
+     return named_parameters + named_buffers
+
+
+ class ModuleDictEMA:
+     """Exponential Moving Average over a nn.ModuleDict.
+
+     You can switch to the EMA weights temporarily.
+     """
+     def __init__(self, module_dict: nn.ModuleDict, decay: float = 0.999,
+                  unbias: bool = True, device: tp.Union[torch.device, str] = 'cpu'):
+         self.decay = decay
+         self.module_dict = module_dict
+         self.state: dict = defaultdict(dict)
+         self.count = 0
+         self.device = device
+         self.unbias = unbias
+         self._init()
+
+     def _init(self):
+         for module_name, module in self.module_dict.items():
+             for key, val in _get_named_tensors(module):
+                 if not val.is_floating_point():
+                     continue
+                 device = self.device or val.device
+                 if key not in self.state[module_name]:
+                     self.state[module_name][key] = val.detach().to(device, copy=True)
+
+     def step(self):
+         if self.unbias:
+             self.count = self.count * self.decay + 1
+             w = 1 / self.count
+         else:
+             w = 1 - self.decay
+         for module_name, module in self.module_dict.items():
+             for key, val in _get_named_tensors(module):
+                 if not val.is_floating_point():
+                     continue
+                 device = self.device or val.device
+                 self.state[module_name][key].mul_(1 - w)
+                 self.state[module_name][key].add_(val.detach().to(device), alpha=w)
+
+     def state_dict(self):
+         return {'state': self.state, 'count': self.count}
+
+     def load_state_dict(self, state):
+         self.count = state['count']
+         for module_name, module in state['state'].items():
+             for key, val in module.items():
+                 self.state[module_name][key].copy_(val)
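
A small usage sketch for ModuleDictEMA above, with a toy nn.ModuleDict standing in for the real model dict; the EMA is stepped after each optimizer update and its state can be checkpointed via state_dict():

import torch
import torch.nn as nn
from audiocraft.optim.ema import ModuleDictEMA

modules = nn.ModuleDict({'model': nn.Linear(8, 8)})
ema = ModuleDictEMA(modules, decay=0.999, unbias=True, device='cpu')
optimizer = torch.optim.Adam(modules.parameters(), lr=1e-3)

for _ in range(10):
    optimizer.zero_grad()
    loss = modules['model'](torch.randn(4, 8)).pow(2).mean()
    loss.backward()
    optimizer.step()
    ema.step()  # fold the current parameters and persistent buffers into the EMA state

checkpoint = {'ema': ema.state_dict()}  # can be restored later with ema.load_state_dict(...)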
audiocraft/optim/fsdp.py ADDED
@@ -0,0 +1,195 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ Wrapper around FSDP for more convenient use in the training loops.
+ """
+
+ from contextlib import contextmanager
+ import typing as tp
+ import dora
+ import torch
+
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+ from torch.distributed.fsdp import (
+     MixedPrecision, ShardingStrategy, FullStateDictConfig, StateDictType)
+ from torch.distributed._shard.sharded_tensor.api import ShardedTensor
+
+
+ def is_fsdp_used() -> bool:
+     """Return whether we are using FSDP."""
+     # A bit of a hack, but should work from anywhere.
+     if dora.is_xp():
+         cfg = dora.get_xp().cfg
+         if hasattr(cfg, 'fsdp'):
+             return cfg.fsdp.use
+     return False
+
+
+ def is_sharded_tensor(x: tp.Any) -> bool:
+     return isinstance(x, ShardedTensor)
+
+
+ @contextmanager
+ def switch_to_full_state_dict(models: tp.List[FSDP]):
+     # Another bug in FSDP means that we cannot use the `state_dict_type` API,
+     # so let's do things manually.
+     for model in models:
+         FSDP.set_state_dict_type(  # type: ignore
+             model, StateDictType.FULL_STATE_DICT,
+             FullStateDictConfig(offload_to_cpu=True, rank0_only=True))
+     try:
+         yield
+     finally:
+         for model in models:
+             FSDP.set_state_dict_type(model, StateDictType.LOCAL_STATE_DICT)  # type: ignore
+
+
+ def wrap_with_fsdp(cfg, model: torch.nn.Module,
+                    block_classes: tp.Optional[tp.Set[tp.Type]] = None) -> FSDP:
+     """Wraps a model with FSDP."""
+     # Some of the typing is disabled until this gets integrated
+     # into the stable version of PyTorch.
+     from torch.distributed.fsdp.wrap import ModuleWrapPolicy  # type: ignore
+
+     # We import these here to prevent a circular import.
+     from ..modules.transformer import StreamingTransformerLayer
+     from ..modules.conditioners import ConditioningProvider
+
+     _fix_post_backward_hook()
+
+     assert cfg.use
+     sharding_strategy_dict = {
+         "no_shard": ShardingStrategy.NO_SHARD,
+         "shard_grad_op": ShardingStrategy.SHARD_GRAD_OP,
+         "full_shard": ShardingStrategy.FULL_SHARD,
+     }
+
+     dtype_dict = {
+         "float32": torch.float32,
+         "float16": torch.float16,
+         "bfloat16": torch.bfloat16,
+     }
+
+     mixed_precision_config = MixedPrecision(
+         param_dtype=dtype_dict[cfg.param_dtype],
+         reduce_dtype=dtype_dict[cfg.reduce_dtype],
+         buffer_dtype=dtype_dict[cfg.buffer_dtype],
+     )
+
+     sharding_strategy_config = sharding_strategy_dict[cfg.sharding_strategy]
+     # The following is going to require being a bit smart
+     # when doing LM, because this would flush the weights for every time step
+     # during generation. One possibility is to use hybrid sharding:
+     # See: https://pytorch.org/docs/master/fsdp.html#torch.distributed.fsdp.ShardingStrategy
+     assert sharding_strategy_config != ShardingStrategy.FULL_SHARD, \
+         "Not supported at the moment, requires a bit more work."
+
+     local_rank = dora.distrib.get_distrib_spec().local_rank
+     assert local_rank < torch.cuda.device_count(), "Please upgrade Dora!"
+
+     auto_wrap_policy = None
+     if block_classes is None:
+         block_classes = {StreamingTransformerLayer, ConditioningProvider}
+     if cfg.per_block:
+         auto_wrap_policy = ModuleWrapPolicy(block_classes)
+     wrapped = _FSDPFixStateDict(
+         model,
+         sharding_strategy=sharding_strategy_config,
+         mixed_precision=mixed_precision_config,
+         device_id=local_rank,
+         sync_module_states=True,
+         use_orig_params=True,
+         auto_wrap_policy=auto_wrap_policy,
+     )  # type: ignore
+     FSDP.set_state_dict_type(wrapped, StateDictType.LOCAL_STATE_DICT)  # type: ignore
+
+     # Let the wrapped model know about the wrapping!
+     # We use __dict__ to avoid it going into the state dict.
+     # This is a bit dirty, but needed during generation, as otherwise
+     # the wrapped model would call itself and bypass FSDP.
+     for module in FSDP.fsdp_modules(wrapped):
+         original = module._fsdp_wrapped_module
+         original.__dict__['_fsdp'] = module
+     return wrapped
+
+
+ def purge_fsdp(model: FSDP):
+     """Purge the FSDP cached shard inside the model. This should
+     allow setting the best state or switching to the EMA.
+     """
+     from torch.distributed.fsdp._runtime_utils import _reshard  # type: ignore
+     for module in FSDP.fsdp_modules(model):
+         handles = module._handles
+         if not handles:
+             continue
+         handle = handles[0]
+         unsharded_flat_param = handle._get_padded_unsharded_flat_param()
+         storage_size: int = unsharded_flat_param._typed_storage()._size()  # type: ignore
+         if storage_size == 0:
+             continue
+         true_list = [True for h in handles]
+         _reshard(module, handles, true_list)
+
+
+ class _FSDPFixStateDict(FSDP):
+     @staticmethod
+     def _name_without_fsdp_prefix(name: str) -> str:
+         from torch.distributed.fsdp._common_utils import FSDP_WRAPPED_MODULE  # type: ignore
+         parts = name.split('.')
+         new_parts = [part for part in parts if part != FSDP_WRAPPED_MODULE]
+         return '.'.join(new_parts)
+
+     def state_dict(self, *args, **kwargs) -> tp.Dict[str, tp.Any]:  # type: ignore
+         state = dict(super().state_dict(*args, **kwargs))
+         for key, value in list(state.items()):
+             if is_sharded_tensor(value):
+                 del state[key]
+         return state
+
+     def load_state_dict(self, state: tp.Dict[str, tp.Any]):  # type: ignore
+         if self._state_dict_type is StateDictType.FULL_STATE_DICT:
+             super().load_state_dict(state)
+             purge_fsdp(self)
+             return
+         # Fix FSDP load_state_dict in all situations.
+         # Use this only with LOCAL_STATE_DICT!
+         current_state = dict(super().state_dict())
+         for key, value in state.items():
+             key = _FSDPFixStateDict._name_without_fsdp_prefix(key)
+             if key not in current_state:
+                 # Emulate strict loading manually.
+                 raise RuntimeError(f"Unknown state key {key}")
+             current_state[key].copy_(value)
+
+         # Purge cached weights from the previous forward.
+         purge_fsdp(self)
+
+
+ _hook_fixed = False
+
+
+ def _fix_post_backward_hook():
+     global _hook_fixed
+     if _hook_fixed:
+         return
+     _hook_fixed = True
+
+     from torch.distributed.fsdp import _runtime_utils
+     from torch.distributed.fsdp._common_utils import TrainingState, HandleTrainingState
+     old_hook = _runtime_utils._post_backward_hook
+
+     def _post_backward_hook(state, handle, *args, **kwargs):
+         checkpointed = getattr(state._fsdp_wrapped_module, '_audiocraft_checkpointed', False)
+         if checkpointed:
+             # With activation checkpointing there is one more forward pass during the
+             # backward, which would massively confuse FSDP, so we have to make it think
+             # everything is going according to plan.
+             state.training_state = TrainingState.FORWARD_BACKWARD
+             handle._training_state = HandleTrainingState.BACKWARD_PRE
+         old_hook(state, handle, *args, **kwargs)
+
+     _runtime_utils._post_backward_hook = _post_backward_hook
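
A rough sketch of how the wrapper above might be invoked. This is illustrative only: it assumes a distributed environment already initialized through Dora (GPUs, process group, ranks), uses a toy module in place of the real audiocraft model, and builds a hypothetical SimpleNamespace mirroring only the config fields that wrap_with_fsdp reads (in real runs these come from the experiment's cfg.fsdp node):

from types import SimpleNamespace

import torch
from audiocraft.optim.fsdp import wrap_with_fsdp, switch_to_full_state_dict

# Hypothetical stand-in for cfg.fsdp; only the fields read by wrap_with_fsdp are set.
fsdp_cfg = SimpleNamespace(
    use=True,
    param_dtype='float16',
    reduce_dtype='float32',
    buffer_dtype='float32',
    sharding_strategy='shard_grad_op',  # 'full_shard' is rejected by the assert above
    per_block=True,
)

model = torch.nn.Linear(8, 8)  # toy placeholder for the actual model
wrapped = wrap_with_fsdp(fsdp_cfg, model)

# To checkpoint a consolidated state dict (offloaded to CPU, rank 0 only):
with switch_to_full_state_dict([wrapped]):
    full_state = wrapped.state_dict()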
audiocraft/optim/inverse_sqrt_lr_scheduler.py ADDED
@@ -0,0 +1,38 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import typing as tp
+
+ from torch.optim import Optimizer
+ from torch.optim.lr_scheduler import _LRScheduler
+
+
+ class InverseSquareRootLRScheduler(_LRScheduler):
+     """Inverse square root LR scheduler.
+
+     Args:
+         optimizer (Optimizer): Torch optimizer.
+         warmup_steps (int): Number of warmup steps.
+         warmup_init_lr (tp.Optional[float]): Initial learning rate
+             during warmup phase. When not set, use the provided learning rate.
+     """
+     def __init__(self, optimizer: Optimizer, warmup_steps: int, warmup_init_lr: tp.Optional[float] = 0):
+         self.warmup_steps = warmup_steps
+         self.warmup_init_lr = warmup_init_lr
+         super().__init__(optimizer)
+
+     def _get_sched_lr(self, lr: float, step: int):
+         if step < self.warmup_steps:
+             warmup_init_lr = self.warmup_init_lr or 0
+             lr_step = (lr - warmup_init_lr) / self.warmup_steps
+             lr = warmup_init_lr + step * lr_step
+         else:
+             decay_factor = lr * self.warmup_steps**0.5
+             lr = decay_factor * step**-0.5
+         return lr
+
+     def get_lr(self):
+         return [self._get_sched_lr(base_lr, self._step_count) for base_lr in self.base_lrs]
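
A small sketch of the schedule shape, using a throwaway parameter and optimizer for illustration: the LR ramps linearly to the base value over warmup_steps, then decays as base_lr * sqrt(warmup_steps / step):

import torch
from audiocraft.optim.inverse_sqrt_lr_scheduler import InverseSquareRootLRScheduler

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=1e-3)
scheduler = InverseSquareRootLRScheduler(optimizer, warmup_steps=100, warmup_init_lr=0)

for step in range(400):
    optimizer.step()
    scheduler.step()
print(scheduler.get_last_lr())  # roughly 1e-3 * sqrt(100 / 400) = 5e-4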
audiocraft/optim/linear_warmup_lr_scheduler.py ADDED
@@ -0,0 +1,35 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import typing as tp
+
+ from torch.optim import Optimizer
+ from torch.optim.lr_scheduler import _LRScheduler
+
+
+ class LinearWarmupLRScheduler(_LRScheduler):
+     """Linear warmup LR scheduler.
+
+     Args:
+         optimizer (Optimizer): Torch optimizer.
+         warmup_steps (int): Number of warmup steps.
+         warmup_init_lr (tp.Optional[float]): Initial learning rate
+             during warmup phase. When not set, use the provided learning rate.
+     """
+     def __init__(self, optimizer: Optimizer, warmup_steps: int, warmup_init_lr: tp.Optional[float] = 0):
+         self.warmup_steps = warmup_steps
+         self.warmup_init_lr = warmup_init_lr
+         super().__init__(optimizer)
+
+     def _get_sched_lr(self, lr: float, step: int):
+         if step < self.warmup_steps:
+             warmup_init_lr = self.warmup_init_lr or 0
+             lr_step = (lr - warmup_init_lr) / self.warmup_steps
+             lr = warmup_init_lr + step * lr_step
+         return lr
+
+     def get_lr(self):
+         return [self._get_sched_lr(base_lr, self.last_epoch) for base_lr in self.base_lrs]
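
A small sketch of the linear ramp, again with a throwaway optimizer: once step reaches warmup_steps the scheduler simply returns the base learning rate unchanged:

import torch
from audiocraft.optim.linear_warmup_lr_scheduler import LinearWarmupLRScheduler

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=1e-3)
scheduler = LinearWarmupLRScheduler(optimizer, warmup_steps=50, warmup_init_lr=1e-5)

for _ in range(60):
    optimizer.step()
    scheduler.step()
print(scheduler.get_last_lr())  # [0.001]: the warmup finished after 50 steps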
audiocraft/optim/polynomial_decay_lr_scheduler.py ADDED
@@ -0,0 +1,47 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from torch.optim import Optimizer
+ from torch.optim.lr_scheduler import _LRScheduler
+
+
+ class PolynomialDecayLRScheduler(_LRScheduler):
+     """Polynomial decay LR scheduler.
+
+     Args:
+         optimizer (Optimizer): Torch optimizer.
+         warmup_steps (int): Number of warmup steps.
+         total_steps (int): Total number of steps.
+         end_lr (float): Final learning rate to achieve over total number of steps.
+         zero_lr_warmup_steps (int): Number of steps with a learning rate of value 0.
+         power (float): Decay exponent.
+     """
+     def __init__(self, optimizer: Optimizer, warmup_steps: int, total_steps: int,
+                  end_lr: float = 0., zero_lr_warmup_steps: int = 0, power: float = 1.):
+         self.warmup_steps = warmup_steps
+         self.total_steps = total_steps
+         self.end_lr = end_lr
+         self.zero_lr_warmup_steps = zero_lr_warmup_steps
+         self.power = power
+         super().__init__(optimizer)
+
+     def _get_sched_lr(self, lr: float, step: int):
+         if self.zero_lr_warmup_steps > 0 and step <= self.zero_lr_warmup_steps:
+             lr = 0
+         elif self.warmup_steps > 0 and step <= self.warmup_steps + self.zero_lr_warmup_steps:
+             lr_ratio = (step - self.zero_lr_warmup_steps) / float(self.warmup_steps)
+             lr = lr_ratio * lr
+         elif step >= self.total_steps:
+             lr = self.end_lr
+         else:
+             total_warmup_steps = self.warmup_steps + self.zero_lr_warmup_steps
+             lr_range = lr - self.end_lr
+             pct_remaining = 1 - (step - total_warmup_steps) / (self.total_steps - total_warmup_steps)
+             lr = lr_range * pct_remaining ** self.power + self.end_lr
+         return lr
+
+     def get_lr(self):
+         return [self._get_sched_lr(base_lr, self.last_epoch) for base_lr in self.base_lrs]
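
A small sketch of the schedule above with a throwaway optimizer: the LR is held at 0 for zero_lr_warmup_steps, warmed up linearly over warmup_steps, then decays polynomially (linearly here, since power=1.0) towards end_lr by total_steps:

import torch
from audiocraft.optim.polynomial_decay_lr_scheduler import PolynomialDecayLRScheduler

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=1e-3)
scheduler = PolynomialDecayLRScheduler(optimizer, warmup_steps=100, total_steps=1000,
                                       end_lr=1e-5, zero_lr_warmup_steps=10, power=1.0)

for _ in range(1000):
    optimizer.step()
    scheduler.step()
print(scheduler.get_last_lr())  # [1e-05]: the schedule has reached end_lr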