unpairedelectron07 committed
Commit 8788873
Parent(s): 652ff96
Upload 7 files

Files changed:
- audiocraft/optim/cosine_lr_scheduler.py +48 -0
- audiocraft/optim/dadam.py +248 -0
- audiocraft/optim/ema.py +85 -0
- audiocraft/optim/fsdp.py +195 -0
- audiocraft/optim/inverse_sqrt_lr_scheduler.py +38 -0
- audiocraft/optim/linear_warmup_lr_scheduler.py +35 -0
- audiocraft/optim/polynomial_decay_lr_scheduler.py +47 -0
audiocraft/optim/cosine_lr_scheduler.py
ADDED
@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math

from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler


class CosineLRScheduler(_LRScheduler):
    """Cosine LR scheduler.

    Args:
        optimizer (Optimizer): Torch optimizer.
        warmup_steps (int): Number of warmup steps.
        total_steps (int): Total number of steps.
        lr_min_ratio (float): Minimum learning rate.
        cycle_length (float): Cycle length.
    """
    def __init__(self, optimizer: Optimizer, total_steps: int, warmup_steps: int,
                 lr_min_ratio: float = 0.0, cycle_length: float = 1.0):
        self.warmup_steps = warmup_steps
        assert self.warmup_steps >= 0
        self.total_steps = total_steps
        assert self.total_steps >= 0
        self.lr_min_ratio = lr_min_ratio
        self.cycle_length = cycle_length
        super().__init__(optimizer)

    def _get_sched_lr(self, lr: float, step: int):
        if step < self.warmup_steps:
            lr_ratio = step / self.warmup_steps
            lr = lr_ratio * lr
        elif step <= self.total_steps:
            s = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
            lr_ratio = self.lr_min_ratio + 0.5 * (1 - self.lr_min_ratio) * \
                (1. + math.cos(math.pi * s / self.cycle_length))
            lr = lr_ratio * lr
        else:
            lr_ratio = self.lr_min_ratio
            lr = lr_ratio * lr
        return lr

    def get_lr(self):
        return [self._get_sched_lr(lr, self.last_epoch) for lr in self.base_lrs]
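A minimal usage sketch, not part of this commit, showing how such a scheduler is typically driven from a training loop; the model, optimizer, and step counts below are illustrative assumptions.

# Illustrative only: any optimizer works; the scheduler scales each group's base lr.
import torch
from audiocraft.optim.cosine_lr_scheduler import CosineLRScheduler

model = torch.nn.Linear(16, 16)  # placeholder model (assumption)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = CosineLRScheduler(optimizer, total_steps=1000, warmup_steps=100, lr_min_ratio=0.1)

for step in range(1000):
    optimizer.step()   # forward/backward elided in this sketch
    scheduler.step()   # advances last_epoch and recomputes lr per param group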
audiocraft/optim/dadam.py
ADDED
@@ -0,0 +1,248 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Any

import torch
import torch.optim
import torch.distributed as dist


logger = logging.getLogger(__name__)
_params_t = Any


def to_real(x):
    if torch.is_complex(x):
        return x.real
    else:
        return x


class DAdaptAdam(torch.optim.Optimizer):
    """Adam with D-Adaptation automatic step-sizes.
    Leave LR set to 1 unless you encounter instability.

    Args:
        params (iterable):
            Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float):
            Learning rate adjustment parameter. Increases or decreases the D-adapted learning rate.
        betas (tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        momentum (float):
            Momentum value in the range [0,1) (default: 0.9).
        eps (float):
            Term added to the denominator outside of the root operation to improve numerical stability. (default: 1e-8).
        weight_decay (float):
            Weight decay, i.e. a L2 penalty (default: 0).
        log_every (int):
            Log using print every k steps, default 0 (no logging).
        decouple (boolean):
            Use AdamW style decoupled weight decay
        d0 (float):
            Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
        growth_rate (float):
            prevent the D estimate from growing faster than this multiplicative rate.
            Default is inf, for unrestricted. Values like 1.02 give a kind of learning
            rate warmup effect.
        fsdp_in_use (bool):
            If you're using sharded parameters, this should be set to True. The optimizer
            will attempt to auto-detect this, but if you're using an implementation other
            than PyTorch's builtin version, the auto-detection won't work.
    """
    def __init__(self, params, lr=1.0,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 weight_decay=0,
                 log_every=0,
                 decouple=True,
                 d0=1e-6,
                 growth_rate=float('inf')):
        if not 0.0 < d0:
            raise ValueError("Invalid d0 value: {}".format(d0))
        if not 0.0 < lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 < eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))

        if decouple:
            logger.info("Using decoupled weight decay")

        from .fsdp import is_fsdp_used
        fsdp_in_use = is_fsdp_used()
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay,
                        d=d0,
                        k=0,
                        gsq_weighted=0.0,
                        log_every=log_every,
                        decouple=decouple,
                        growth_rate=growth_rate,
                        fsdp_in_use=fsdp_in_use)

        super().__init__(params, defaults)

    @property
    def supports_memory_efficient_fp16(self):
        return False

    @property
    def supports_flat_params(self):
        return True

    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        g_sq = 0.0
        sksq_weighted = 0.0
        sk_l1 = 0.0

        lr = max(group['lr'] for group in self.param_groups)

        group = self.param_groups[0]
        gsq_weighted = group['gsq_weighted']
        d = group['d']
        dlr = d*lr

        growth_rate = group['growth_rate']
        decouple = group['decouple']
        fsdp_in_use = group['fsdp_in_use']
        log_every = group['log_every']

        beta1, beta2 = group['betas']

        for group in self.param_groups:
            group_lr = group['lr']
            decay = group['weight_decay']
            k = group['k']
            eps = group['eps']

            if group_lr not in [lr, 0.0]:
                raise RuntimeError("Setting different lr values in different parameter "
                                   "groups is only supported for values of 0")

            for p in group['params']:
                if p.grad is None:
                    continue
                if hasattr(p, "_fsdp_flattened"):
                    fsdp_in_use = True
                grad = p.grad.data

                # Apply weight decay (coupled variant)
                if decay != 0 and not decouple:
                    grad.add_(p.data, alpha=decay)

                state = self.state[p]

                # State initialization
                if 'step' not in state:
                    state['step'] = 0
                    state['s'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach()
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach()
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(
                        to_real(p.data), memory_format=torch.preserve_format).detach()

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                grad_grad = to_real(grad * grad.conj())

                # Adam EMA updates
                if group_lr > 0:
                    exp_avg.mul_(beta1).add_(grad, alpha=dlr*(1-beta1))
                    exp_avg_sq.mul_(beta2).add_(grad_grad, alpha=1-beta2)

                    denom = exp_avg_sq.sqrt().add_(eps)

                    g_sq += grad_grad.div_(denom).sum().item()

                    s = state['s']
                    s.mul_(beta2).add_(grad, alpha=dlr*(1-beta2))
                    sksq_weighted += to_real(s * s.conj()).div_(denom).sum().item()
                    sk_l1 += s.abs().sum().item()

        ######

        gsq_weighted = beta2*gsq_weighted + g_sq*(dlr**2)*(1-beta2)
        d_hat = d

        # if we have not done any progress, return
        # if we have any gradients available, will have sk_l1 > 0 (unless \|g\|=0)
        if sk_l1 == 0:
            return loss

        if lr > 0.0:
            if fsdp_in_use:
                dist_tensor = torch.zeros(3, device='cuda')
                dist_tensor[0] = sksq_weighted
                dist_tensor[1] = gsq_weighted
                dist_tensor[2] = sk_l1
                dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
                global_sksq_weighted = dist_tensor[0]
                global_gsq_weighted = dist_tensor[1]
                global_sk_l1 = dist_tensor[2]
            else:
                global_sksq_weighted = sksq_weighted
                global_gsq_weighted = gsq_weighted
                global_sk_l1 = sk_l1

            d_hat = (global_sksq_weighted/(1-beta2) - global_gsq_weighted)/global_sk_l1
            d = max(d, min(d_hat, d*growth_rate))

        if log_every > 0 and k % log_every == 0:
            logger.info(
                f"(k={k}) dlr: {dlr:1.1e} d_hat: {d_hat:1.1e}, d: {d:1.8}. "
                f"sksq_weighted={global_sksq_weighted:1.1e} gsq_weighted={global_gsq_weighted:1.1e} "
                f"sk_l1={global_sk_l1:1.1e}{' (FSDP)' if fsdp_in_use else ''}")

        for group in self.param_groups:
            group['gsq_weighted'] = gsq_weighted
            group['d'] = d

            group_lr = group['lr']
            decay = group['weight_decay']
            k = group['k']
            eps = group['eps']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data

                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                state['step'] += 1

                denom = exp_avg_sq.sqrt().add_(eps)
                denom = denom.type(p.type())

                # Apply weight decay (decoupled variant)
                if decay != 0 and decouple and group_lr > 0:
                    p.data.add_(p.data, alpha=-decay * dlr)

                # Take step
                p.data.addcdiv_(exp_avg, denom, value=-1)

            group['k'] = k + 1

        return loss
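A minimal usage sketch, not part of this commit, assuming the audiocraft package (and its dora dependency, pulled in via the fsdp auto-detection) is importable; per the docstring above, lr is left at 1.0 so D-Adaptation controls the effective step size. The model and data are placeholder assumptions.

# Illustrative only: tiny regression loop driving DAdaptAdam.
import torch
from audiocraft.optim.dadam import DAdaptAdam

model = torch.nn.Linear(8, 1)
opt = DAdaptAdam(model.parameters(), lr=1.0, weight_decay=0.01, log_every=100)

x, y = torch.randn(32, 8), torch.randn(32, 1)
for _ in range(5):
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step()  # updates the D estimate, then applies the Adam-style update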
audiocraft/optim/ema.py
ADDED
@@ -0,0 +1,85 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# ModelEMA implementation is taken from
# https://github.com/facebookresearch/demucs

from collections import defaultdict
import typing as tp

import torch
import torch.nn as nn


def _get_all_non_persistent_buffers_set(module: nn.Module, root: str = "") -> set:
    names: set = set()
    for (name, sub_module) in module.named_modules():
        if name == '':
            buffer_names = module._non_persistent_buffers_set
            buffer_names = {f"{root}.{buff_name}" if len(root) > 0 else buff_name
                            for buff_name in buffer_names}
            names.update(buffer_names)
        else:
            sub_name = f"{root}.{name}" if len(root) > 0 else name
            sub_buffer_names = _get_all_non_persistent_buffers_set(sub_module, sub_name)
            names.update(sub_buffer_names)
    return names


def _get_named_tensors(module: nn.Module):
    non_persistent_buffers_set = _get_all_non_persistent_buffers_set(module)
    named_buffers = [(name, buffer) for (name, buffer) in module.named_buffers()
                     if name not in non_persistent_buffers_set]
    named_parameters = list(module.named_parameters())
    return named_parameters + named_buffers


class ModuleDictEMA:
    """Exponential Moving Average over a nn.ModuleDict.

    You can switch to the EMA weights temporarily.
    """
    def __init__(self, module_dict: nn.ModuleDict, decay: float = 0.999,
                 unbias: bool = True, device: tp.Union[torch.device, str] = 'cpu'):
        self.decay = decay
        self.module_dict = module_dict
        self.state: dict = defaultdict(dict)
        self.count = 0
        self.device = device
        self.unbias = unbias
        self._init()

    def _init(self):
        for module_name, module in self.module_dict.items():
            for key, val in _get_named_tensors(module):
                if not val.is_floating_point():
                    continue
                device = self.device or val.device
                if key not in self.state[module_name]:
                    self.state[module_name][key] = val.detach().to(device, copy=True)

    def step(self):
        if self.unbias:
            self.count = self.count * self.decay + 1
            w = 1 / self.count
        else:
            w = 1 - self.decay
        for module_name, module in self.module_dict.items():
            for key, val in _get_named_tensors(module):
                if not val.is_floating_point():
                    continue
                device = self.device or val.device
                self.state[module_name][key].mul_(1 - w)
                self.state[module_name][key].add_(val.detach().to(device), alpha=w)

    def state_dict(self):
        return {'state': self.state, 'count': self.count}

    def load_state_dict(self, state):
        self.count = state['count']
        for module_name, module in state['state'].items():
            for key, val in module.items():
                self.state[module_name][key].copy_(val)
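A minimal usage sketch, not part of this commit; the module names and shapes are illustrative assumptions. The EMA tensors live in ema.state[name] and can be loaded into a copy of the module for evaluation.

# Illustrative only: keep an EMA copy of the weights and load it for evaluation.
import torch.nn as nn
from audiocraft.optim.ema import ModuleDictEMA

modules = nn.ModuleDict({'model': nn.Linear(4, 4)})
ema = ModuleDictEMA(modules, decay=0.999, device='cpu')

# ... after each optimizer update:
ema.step()  # state <- (1 - w) * state + w * current weights, bias-corrected when unbias=True

# Load the averaged weights into a separate copy before evaluation.
eval_model = nn.Linear(4, 4)
eval_model.load_state_dict(ema.state['model'])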
audiocraft/optim/fsdp.py
ADDED
@@ -0,0 +1,195 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Wrapper around FSDP for more convenient use in the training loops.
"""

from contextlib import contextmanager
import typing as tp
import dora
import torch

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import (
    MixedPrecision, ShardingStrategy, FullStateDictConfig, StateDictType)
from torch.distributed._shard.sharded_tensor.api import ShardedTensor


def is_fsdp_used() -> bool:
    """Return whether we are using FSDP."""
    # A bit of a hack but should work from anywhere.
    if dora.is_xp():
        cfg = dora.get_xp().cfg
        if hasattr(cfg, 'fsdp'):
            return cfg.fsdp.use
    return False


def is_sharded_tensor(x: tp.Any) -> bool:
    return isinstance(x, ShardedTensor)


@contextmanager
def switch_to_full_state_dict(models: tp.List[FSDP]):
    # Another bug in FSDP makes it that we cannot use the `state_dict_type` API,
    # so let's do things manually.
    for model in models:
        FSDP.set_state_dict_type(  # type: ignore
            model, StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True))
    try:
        yield
    finally:
        for model in models:
            FSDP.set_state_dict_type(model, StateDictType.LOCAL_STATE_DICT)  # type: ignore


def wrap_with_fsdp(cfg, model: torch.nn.Module,
                   block_classes: tp.Optional[tp.Set[tp.Type]] = None) -> FSDP:
    """Wraps a model with FSDP."""
    # Some of the typing is disabled until this gets integrated
    # into the stable version of PyTorch.
    from torch.distributed.fsdp.wrap import ModuleWrapPolicy  # type: ignore

    # we import this here to prevent circular import.
    from ..modules.transformer import StreamingTransformerLayer
    from ..modules.conditioners import ConditioningProvider

    _fix_post_backward_hook()

    assert cfg.use
    sharding_strategy_dict = {
        "no_shard": ShardingStrategy.NO_SHARD,
        "shard_grad_op": ShardingStrategy.SHARD_GRAD_OP,
        "full_shard": ShardingStrategy.FULL_SHARD,
    }

    dtype_dict = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }

    mixed_precision_config = MixedPrecision(
        param_dtype=dtype_dict[cfg.param_dtype],
        reduce_dtype=dtype_dict[cfg.reduce_dtype],
        buffer_dtype=dtype_dict[cfg.buffer_dtype],
    )

    sharding_strategy_config = sharding_strategy_dict[cfg.sharding_strategy]
    # The following is going to require being a bit smart
    # when doing LM, because this would flush the weights for every time step
    # during generation. One possibility is to use hybrid sharding:
    # See: https://pytorch.org/docs/master/fsdp.html#torch.distributed.fsdp.ShardingStrategy
    assert sharding_strategy_config != ShardingStrategy.FULL_SHARD, \
        "Not supported at the moment, requires a bit more work."

    local_rank = dora.distrib.get_distrib_spec().local_rank
    assert local_rank < torch.cuda.device_count(), "Please upgrade Dora!"

    auto_wrap_policy = None
    if block_classes is None:
        block_classes = {StreamingTransformerLayer, ConditioningProvider}
    if cfg.per_block:
        auto_wrap_policy = ModuleWrapPolicy(block_classes)
    wrapped = _FSDPFixStateDict(
        model,
        sharding_strategy=sharding_strategy_config,
        mixed_precision=mixed_precision_config,
        device_id=local_rank,
        sync_module_states=True,
        use_orig_params=True,
        auto_wrap_policy=auto_wrap_policy,
    )  # type: ignore
    FSDP.set_state_dict_type(wrapped, StateDictType.LOCAL_STATE_DICT)  # type: ignore

    # Let the wrapped model know about the wrapping!
    # We use __dict__ to avoid it going into the state dict.
    # This is a bit dirty, but needed during generation, as otherwise
    # the wrapped model would call itself and bypass FSDP.
    for module in FSDP.fsdp_modules(wrapped):
        original = module._fsdp_wrapped_module
        original.__dict__['_fsdp'] = module
    return wrapped


def purge_fsdp(model: FSDP):
    """Purge the FSDP cached shard inside the model. This should
    allow setting the best state or switching to the EMA.
    """
    from torch.distributed.fsdp._runtime_utils import _reshard  # type: ignore
    for module in FSDP.fsdp_modules(model):
        handles = module._handles
        if not handles:
            continue
        handle = handles[0]
        unsharded_flat_param = handle._get_padded_unsharded_flat_param()
        storage_size: int = unsharded_flat_param._typed_storage()._size()  # type: ignore
        if storage_size == 0:
            continue
        true_list = [True for h in handles]
        _reshard(module, handles, true_list)


class _FSDPFixStateDict(FSDP):
    @staticmethod
    def _name_without_fsdp_prefix(name: str) -> str:
        from torch.distributed.fsdp._common_utils import FSDP_WRAPPED_MODULE  # type: ignore
        parts = name.split('.')
        new_parts = [part for part in parts if part != FSDP_WRAPPED_MODULE]
        return '.'.join(new_parts)

    def state_dict(self, *args, **kwargs) -> tp.Dict[str, tp.Any]:  # type: ignore
        state = dict(super().state_dict(*args, **kwargs))
        for key, value in list(state.items()):
            if is_sharded_tensor(value):
                del state[key]
        return state

    def load_state_dict(self, state: tp.Dict[str, tp.Any]):  # type: ignore
        if self._state_dict_type is StateDictType.FULL_STATE_DICT:
            super().load_state_dict(state)
            purge_fsdp(self)
            return
        # Fix FSDP load state dict in all situation.
        # Use this only with LOCAL_STATE_DICT !!!
        current_state = dict(super().state_dict())
        for key, value in state.items():
            key = _FSDPFixStateDict._name_without_fsdp_prefix(key)
            if key not in current_state:
                # Emulate strict loading manually.
                raise RuntimeError(f"Unknown state key {key}")
            current_state[key].copy_(value)

        # Purging cached weights from previous forward.
        purge_fsdp(self)


_hook_fixed = False


def _fix_post_backward_hook():
    global _hook_fixed
    if _hook_fixed:
        return
    _hook_fixed = True

    from torch.distributed.fsdp import _runtime_utils
    from torch.distributed.fsdp._common_utils import TrainingState, HandleTrainingState
    old_hook = _runtime_utils._post_backward_hook

    def _post_backward_hook(state, handle, *args, **kwargs):
        checkpointed = getattr(state._fsdp_wrapped_module, '_audiocraft_checkpointed', False)
        if checkpointed:
            # there will be one more forward in the backward with checkpointing and that will
            # massively confuse FSDP, so we have to make it think everything
            # is going according to the plan.
            state.training_state = TrainingState.FORWARD_BACKWARD
            handle._training_state = HandleTrainingState.BACKWARD_PRE
        old_hook(state, handle, *args, **kwargs)

    _runtime_utils._post_backward_hook = _post_backward_hook
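A hedged sketch, not part of this commit, of exporting a consolidated checkpoint with switch_to_full_state_dict. It assumes an initialized torch.distributed job and an `fsdp_model` previously produced by wrap_with_fsdp; neither is defined here.

# Illustrative only: gather full (unsharded) weights on rank 0 and save them.
import torch
import torch.distributed as dist
from audiocraft.optim.fsdp import switch_to_full_state_dict

with switch_to_full_state_dict([fsdp_model]):  # fsdp_model: hypothetical FSDP-wrapped model
    # Inside the context, state_dict() returns full weights, offloaded to CPU,
    # materialized only on rank 0 per FullStateDictConfig above.
    full_state = fsdp_model.state_dict()
if dist.get_rank() == 0:
    torch.save(full_state, 'consolidated.pt')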
audiocraft/optim/inverse_sqrt_lr_scheduler.py
ADDED
@@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp

from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler


class InverseSquareRootLRScheduler(_LRScheduler):
    """Inverse square root LR scheduler.

    Args:
        optimizer (Optimizer): Torch optimizer.
        warmup_steps (int): Number of warmup steps.
        warmup_init_lr (tp.Optional[float]): Initial learning rate
            during warmup phase. When not set, use the provided learning rate.
    """
    def __init__(self, optimizer: Optimizer, warmup_steps: int, warmup_init_lr: tp.Optional[float] = 0):
        self.warmup_steps = warmup_steps
        self.warmup_init_lr = warmup_init_lr
        super().__init__(optimizer)

    def _get_sched_lr(self, lr: float, step: int):
        if step < self.warmup_steps:
            warmup_init_lr = self.warmup_init_lr or 0
            lr_step = (lr - warmup_init_lr) / self.warmup_steps
            lr = warmup_init_lr + step * lr_step
        else:
            decay_factor = lr * self.warmup_steps**0.5
            lr = decay_factor * step**-0.5
        return lr

    def get_lr(self):
        return [self._get_sched_lr(base_lr, self._step_count) for base_lr in self.base_lrs]
audiocraft/optim/linear_warmup_lr_scheduler.py
ADDED
@@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp

from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler


class LinearWarmupLRScheduler(_LRScheduler):
    """Linear warmup LR scheduler.

    Args:
        optimizer (Optimizer): Torch optimizer.
        warmup_steps (int): Number of warmup steps.
        warmup_init_lr (tp.Optional[float]): Initial learning rate
            during warmup phase. When not set, use the provided learning rate.
    """
    def __init__(self, optimizer: Optimizer, warmup_steps: int, warmup_init_lr: tp.Optional[float] = 0):
        self.warmup_steps = warmup_steps
        self.warmup_init_lr = warmup_init_lr
        super().__init__(optimizer)

    def _get_sched_lr(self, lr: float, step: int):
        if step < self.warmup_steps:
            warmup_init_lr = self.warmup_init_lr or 0
            lr_step = (lr - warmup_init_lr) / self.warmup_steps
            lr = warmup_init_lr + step * lr_step
        return lr

    def get_lr(self):
        return [self._get_sched_lr(base_lr, self.last_epoch) for base_lr in self.base_lrs]
audiocraft/optim/polynomial_decay_lr_scheduler.py
ADDED
@@ -0,0 +1,47 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler


class PolynomialDecayLRScheduler(_LRScheduler):
    """Polynomial decay LR scheduler.

    Args:
        optimizer (Optimizer): Torch optimizer.
        warmup_steps (int): Number of warmup steps.
        total_steps (int): Total number of steps.
        end_lr (float): Final learning rate to achieve over total number of steps.
        zero_lr_warmup_steps (int): Number of steps with a learning rate of value 0.
        power (float): Decay exponent.
    """
    def __init__(self, optimizer: Optimizer, warmup_steps: int, total_steps: int,
                 end_lr: float = 0., zero_lr_warmup_steps: int = 0, power: float = 1.):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.end_lr = end_lr
        self.zero_lr_warmup_steps = zero_lr_warmup_steps
        self.power = power
        super().__init__(optimizer)

    def _get_sched_lr(self, lr: float, step: int):
        if self.zero_lr_warmup_steps > 0 and step <= self.zero_lr_warmup_steps:
            lr = 0
        elif self.warmup_steps > 0 and step <= self.warmup_steps + self.zero_lr_warmup_steps:
            lr_ratio = (step - self.zero_lr_warmup_steps) / float(self.warmup_steps)
            lr = lr_ratio * lr
        elif step >= self.total_steps:
            lr = self.end_lr
        else:
            total_warmup_steps = self.warmup_steps + self.zero_lr_warmup_steps
            lr_range = lr - self.end_lr
            pct_remaining = 1 - (step - total_warmup_steps) / (self.total_steps - total_warmup_steps)
            lr = lr_range * pct_remaining ** self.power + self.end_lr
        return lr

    def get_lr(self):
        return [self._get_sched_lr(base_lr, self.last_epoch) for base_lr in self.base_lrs]
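A minimal sketch, not part of this commit, of the three phases this scheduler produces (zero-lr hold, linear warmup, polynomial decay); the step counts and base lr below are illustrative assumptions.

# Illustrative only: print the lr as it moves through the hold, warmup and decay phases.
import torch
from audiocraft.optim.polynomial_decay_lr_scheduler import PolynomialDecayLRScheduler

opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
sched = PolynomialDecayLRScheduler(opt, warmup_steps=10, total_steps=100,
                                   end_lr=1e-5, zero_lr_warmup_steps=5, power=1.0)
for step in range(100):
    opt.step()   # training step elided
    sched.step()
    if step % 20 == 0:
        print(step, opt.param_groups[0]['lr'])  # 0.0 during the hold, then ramps to 1e-3, then decays toward end_lr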