emozilla committed
Commit: 2b1c7b3
1 Parent(s): 2010c83

update inference code

Files changed (3):
  1. configuration_olmo.py +2 -0
  2. optim.py +769 -0
  3. safetensors_util.py +81 -0
configuration_olmo.py CHANGED
@@ -10,7 +10,9 @@ from .aliases import PathOrStr
 from .beam_search import Sampler
 from .exceptions import OLMoError
 from .initialization import ModuleType
+from .optim import Optimizer
 from .util import StrEnum
+from .safetensors_util import STKey
 from .torch_util import seed_all
 
 logger = logging.get_logger(__name__)
optim.py ADDED
@@ -0,0 +1,769 @@
+import logging
+from abc import ABCMeta, abstractmethod
+from dataclasses import dataclass, replace
+from math import cos, pi, sqrt
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.distributed.fsdp import FullyShardedDataParallel
+from torch.optim.optimizer import Optimizer as OptimizerBase
+
+from . import LayerNormBase, BitLinear158
+from .config import OptimizerType, SchedulerConfig, SchedulerType, TrainConfig
+from .torch_util import get_default_device, is_distributed
+
+__all__ = [
+    "Optimizer",
+    "LionW",
+    "AdamW",
+    "Scheduler",
+    "CosWithWarmup",
+    "LinearWithWarmup",
+    "InvSqrtWithWarmup",
+    "MaxScheduler",
+    "ConstantScheduler",
+    "BoltOnWarmupScheduler",
+    "build_optimizer",
+    "build_scheduler",
+]
+
+
+log = logging.getLogger(__name__)
+
+
+class Optimizer(OptimizerBase):
+    def _clean_param_name(self, name: str) -> str:
+        return name.replace("_fsdp_wrapped_module.", "")
+
+    @torch.no_grad()
+    def clip_grads_and_collect_metrics(
+        self, global_step: int, collect_param_metrics: bool = True
+    ) -> Dict[str, torch.Tensor]:
+        """
+        Clips gradients for every group that has the field `max_grad_norm`.
+        At the same time collect metrics for each parameter and its gradient.
+        """
+        device = get_default_device()
+
+        # NOTE (epwalsh): during distributed training we're making an assumption that the order of
+        # the param groups and the params within each group are the same across all ranks.
+        # This is justified since we initialize the parameter groups in every rank by iterating over
+        # `module.parameters()` or `module.named_modules()` / `module.named_parameters()`, each of which
+        # provides a consistent order.
+        # For each parameter (with a gradient) we'll collect:
+        # - min, max, avg, norm of the param itself
+        # - min, max, avg, norm of the param's gradient
+        # - min, max, avg, norm of any additional per-parameter optimizer state metrics returned from
+        #   `self.get_state_for_param()`.
+        # Afterwards we'll reduce these all over all ranks.
+        per_param_min_metrics: List[torch.Tensor] = []
+        per_param_max_metrics: List[torch.Tensor] = []
+        per_param_sum_metrics: List[torch.Tensor] = []
+        per_param_norm_metrics: List[torch.Tensor] = []
+        per_param_numel_metrics: List[torch.Tensor] = []
+
+        per_param_min_metric_names: List[str] = []
+        per_param_max_metric_names: List[str] = []
+        per_param_avg_metric_names: List[str] = []
+        per_param_norm_metric_names: List[str] = []
+
+        # Collect metrics locally.
+        for group in self.param_groups:
+            if is_distributed():
+                # TODO (epwalsh): handle non-sharded params. We don't have any right now but we would
+                # with ReLoRa, for example.
+                assert group.get("sharded", True) is True
+
+            for name, p in zip(group["param_names"], group["params"]):
+                name = self._clean_param_name(name)
+                # Always need to collect the norm of gradients for clipping, even if we're not collecting
+                # other metrics.
+                tensors: List[Optional[torch.Tensor]] = [p.grad]
+                prefixes: List[str] = [f"grad/{name}"]
+                if collect_param_metrics:
+                    state = self.get_state_for_param(p)
+                    sorted_state_keys = sorted([k for k in state.keys()])
+                    tensors.extend([p] + [state[key] for key in sorted_state_keys])
+                    prefixes.extend([f"param/{name}"] + [f"{key}/{name}" for key in sorted_state_keys])
+                assert len(tensors) == len(prefixes)
+
+                # Get min, max, avg, and norm for all `tensors` associated with the parameter.
+                for x, prefix in zip(tensors, prefixes):
+                    # grad or state tensors could be none for params that have their shards completely on
+                    # other ranks.
+                    if x is not None and x.numel() > 0:
+                        if collect_param_metrics:
+                            x_abs = x.abs()
+                            per_param_min_metrics.append(x_abs.min().unsqueeze(0).to(dtype=torch.float32))
+                            per_param_max_metrics.append(x_abs.max().unsqueeze(0).to(dtype=torch.float32))
+                            per_param_sum_metrics.append(x.sum().unsqueeze(0).to(dtype=torch.float32))
+                            per_param_numel_metrics.append(
+                                torch.tensor([x.numel()], device=device, dtype=torch.float32)
+                            )
+                        per_param_norm_metrics.append(
+                            torch.linalg.vector_norm(x, 2.0, dtype=torch.float32).unsqueeze(0)
+                        )
+                    else:
+                        if collect_param_metrics:
+                            per_param_min_metrics.append(
+                                torch.tensor([float("inf")], device=device, dtype=torch.float32)
+                            )
+                            per_param_max_metrics.append(torch.tensor([0.0], device=device, dtype=torch.float32))
+                            per_param_sum_metrics.append(torch.tensor([0.0], device=device, dtype=torch.float32))
+                            per_param_numel_metrics.append(torch.tensor([0.0], device=device, dtype=torch.float32))
+                        per_param_norm_metrics.append(torch.tensor([0.0], device=device, dtype=torch.float32))
+                    if collect_param_metrics:
+                        per_param_min_metric_names.append(f"{prefix}.min")
+                        per_param_max_metric_names.append(f"{prefix}.max")
+                        per_param_avg_metric_names.append(f"{prefix}.avg")
+                    per_param_norm_metric_names.append(f"{prefix}.norm")
+
+        assert (
+            len(per_param_min_metrics)
+            == len(per_param_min_metric_names)
+            == len(per_param_max_metrics)
+            == len(per_param_max_metric_names)
+            == len(per_param_sum_metrics)
+            == len(per_param_numel_metrics)
+            == len(per_param_avg_metric_names)
+        )
+        assert len(per_param_norm_metrics) == len(per_param_norm_metric_names)
+
+        def is_grad_norm_metric(metric_name: str) -> bool:
+            return metric_name.startswith("grad/") and metric_name.endswith(".norm")
+
+        # Now reduce metrics over all ranks.
+        total_grad_norm: torch.Tensor
+        per_param_avg_metrics: List[torch.Tensor] = []
+        if is_distributed():  # TODO (epwalsh): skip for non-sharded params
+            # Reduce metrics across all ranks. Note that we can use a `reduce` for most cases
+            # instead of an `all_reduce`, but we need `all_reduce` for norms so that all ranks
+            # get the right value for gradient norms so they can clip correctly.
+            # Reduce mins.
+            if per_param_min_metrics:
+                all_mins = torch.cat(per_param_min_metrics).to(device)
+                dist.reduce(all_mins, 0, op=dist.ReduceOp.MIN)
+                per_param_min_metrics = all_mins.split(1)
+            # Reduce maxs.
+            if per_param_max_metrics:
+                all_maxs = torch.cat(per_param_max_metrics).to(device)
+                dist.reduce(all_maxs, 0, op=dist.ReduceOp.MAX)
+                per_param_max_metrics = all_maxs.split(1)
+            # Reduce sums or just norms.
+            all_norms = torch.cat(per_param_norm_metrics).to(device) ** 2.0
+            if per_param_sum_metrics and per_param_numel_metrics:
+                all_sums = torch.cat(per_param_sum_metrics).to(device)
+                all_numels = torch.cat(per_param_numel_metrics).to(device)
+                all_sums_norms_numels = torch.cat(
+                    [all_sums.unsqueeze(0), all_norms.unsqueeze(0), all_numels.unsqueeze(0)], dim=0
+                )
+                dist.all_reduce(all_sums_norms_numels, op=dist.ReduceOp.SUM)
+                all_sums, all_norms, all_numels = all_sums_norms_numels.split(1)
+                # Get averages.
+                # NOTE: could get infs for non-rank0 processes but that's okay.
+                per_param_avg_metrics = (all_sums / all_numels).squeeze(0).split(1)
+            else:
+                dist.all_reduce(all_norms, op=dist.ReduceOp.SUM)
+            grad_norm_metric_mask = torch.tensor(
+                [float(is_grad_norm_metric(n)) for n in per_param_norm_metric_names], device=all_norms.device
+            )
+            total_grad_norm = (all_norms * grad_norm_metric_mask).sum() ** 0.5
+            per_param_norm_metrics = (all_norms ** (0.5)).squeeze(0).split(1)
+        else:
+            total_grad_norm = (
+                torch.cat(
+                    [
+                        m
+                        for m, n in zip(per_param_norm_metrics, per_param_norm_metric_names)
+                        if is_grad_norm_metric(n)
+                    ]
+                )
+                ** 2.0
+            ).sum() ** 0.5
+            per_param_avg_metrics = [x / n for x, n in zip(per_param_sum_metrics, per_param_numel_metrics)]
+
+        assert len(per_param_avg_metrics) == len(per_param_avg_metric_names)
+
+        # Collect all metrics into a single dict.
+        all_metrics: Dict[str, torch.Tensor] = {}
+        for metric_name, metric in zip(per_param_min_metric_names, per_param_min_metrics):
+            all_metrics[metric_name] = metric.squeeze(0)
+        for metric_name, metric in zip(per_param_max_metric_names, per_param_max_metrics):
+            all_metrics[metric_name] = metric.squeeze(0)
+        for metric_name, metric in zip(per_param_avg_metric_names, per_param_avg_metrics):
+            all_metrics[metric_name] = metric.squeeze(0)
+        for metric_name, metric in zip(per_param_norm_metric_names, per_param_norm_metrics):
+            all_metrics[metric_name] = metric.squeeze(0)
+        all_metrics["total_grad_norm"] = total_grad_norm
+
+        # Clip gradients.
+        num_grads_clipped = 0
+        num_eligible_grads = 0
+        for group in self.param_groups:
+            if (max_norm_ratio := group.get("max_grad_norm_ratio")) is not None:
+                num_clipped = self._do_adaptive_clipping(
+                    group, max_norm_ratio, global_step, all_metrics, collect_param_metrics=collect_param_metrics
+                )
+            elif (max_norm := group.get("max_grad_norm")) is not None:
+                num_clipped = self._do_global_fixed_clipping(
+                    group, max_norm, all_metrics, collect_param_metrics=collect_param_metrics
+                )
+            else:
+                # No clipping needed.
+                continue
+            num_eligible_grads += len(group["params"])
+            if num_clipped is not None:
+                num_grads_clipped += num_clipped
+
+        if collect_param_metrics:
+            if num_eligible_grads > 0:
+                clipping_rate = torch.tensor(num_grads_clipped / num_eligible_grads, device="cpu")
+            else:
+                clipping_rate = torch.tensor(0.0, device="cpu")
+            all_metrics["clipping_rate"] = clipping_rate
+            return all_metrics
+        else:
+            return {}
+
+    @torch.no_grad()
+    def _do_adaptive_clipping(
+        self,
+        group: Dict[str, Any],
+        max_norm_ratio: float,
+        global_step: int,
+        all_metrics: Dict[str, torch.Tensor],
+        collect_param_metrics: bool = True,
+    ) -> Optional[int]:
+        """
+        Do adaptive gradient clipping on a param group.
+
+        If ``collect_param_metrics`` is ``True`` this will return the total number of gradients clipped.
+        """
+        device = get_default_device()
+        num_grads_clipped = 0
+        # We'll use the bigger of beta1 and beta2 to update the exponential average of the norm of
+        # the gradient (a scalar), not to be confused with the exponential average of the gradient.
+        # TODO (epwalsh): handle optimizers that don't have betas.
+        beta1, beta2 = group["betas"]
+        beta = max(beta1, beta2)
+        for name, p in zip(group["param_names"], group["params"]):
+            name = self._clean_param_name(name)
+            grad_norm = all_metrics.get(f"grad/{name}.norm")
+            if grad_norm is None:
+                continue
+
+            # Get or initialize the exponential average of grad norm.
+            # TODO: The way we have it right now, every rank tracks the `grad_norm_exp_avg` of every parameter,
+            # even parameters for which the corresponding local shard is empty. This has the potential to
+            # cause some issues with the optimizer, as we ran into with https://github.com/allenai/LLM/pull/372.
+            # So we should consider changing how we do this at some point so that we don't add any state
+            # to parameters for which the local shard is empty. That would probably add extra distributed
+            # communication, at least on steps where we have to log (i.e. when `collect_param_metrics=True`).
+            state = self.state[p]
+            grad_norm_exp_avg = state.get("grad_norm_exp_avg")
+            if grad_norm_exp_avg is None:
+                grad_norm_exp_avg = grad_norm.clone().to(device)
+                # We don't want to add anything to `state` until `state` has been initialized, otherwise
+                # this will crash some optimizers which rely on checking `len(state)`. The downside here
+                # is that we won't start tracking `grad_norm_exp_avg` until the 2nd training step.
+                if global_step > 1:
+                    state["grad_norm_exp_avg"] = grad_norm_exp_avg
+
+            max_allowed_norm = max_norm_ratio * grad_norm_exp_avg
+            clip_coef = max_allowed_norm / (grad_norm + 1e-6)
+
+            # Clip the gradients and update the exponential average.
+            # Note that multiplying by the clamped coefficient is meaningless when it is
+            # equal to 1, but it avoids the host-device sync that would result from `if clip_coef_clamped < 1`.
+            clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
+            if p.grad is not None:
+                # p.grad could be none for some ranks when using FSDP.
+                p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device, p.grad.dtype))
+
+            # Update the exponential average of the norm of the gradient with the clipped norm of the gradient.
+            grad_norm_exp_avg.lerp_((grad_norm * clip_coef_clamped).to(grad_norm_exp_avg.device), 1 - beta)
+            # Alternative: update with the *unclipped* norm of the gradient.
+            # grad_norm_exp_avg.lerp_(grad_norm.to(grad_norm_exp_avg.device), 1 - beta)
+
+            if collect_param_metrics:
+                # Can't avoid host-device sync here.
+                if clip_coef_clamped < 1.0:
+                    num_grads_clipped += 1
+                all_metrics[f"grad_norm_exp_avg/{name}"] = grad_norm_exp_avg
+        return num_grads_clipped if collect_param_metrics else None
+
+    @torch.no_grad()
+    def _do_global_fixed_clipping(
+        self,
+        group: Dict[str, Any],
+        max_norm: float,
+        all_metrics: Dict[str, torch.Tensor],
+        collect_param_metrics: bool = True,
+    ) -> Optional[int]:
+        """
+        Do global fixed gradient clipping on a param group.
+
+        If ``collect_param_metrics`` is ``True`` this will return the total number of gradients clipped.
+        """
+        device = get_default_device()
+        total_grad_norm = all_metrics["total_grad_norm"]
+        clip_coef = max_norm / (total_grad_norm.to(device) + 1e-6)
+        clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
+        num_grads_clipped: Optional[int] = None
+        if collect_param_metrics:
+            # Can't avoid host-device sync here.
+            if clip_coef_clamped < 1.0:
+                num_grads_clipped = len(group["params"])
+        for p in group["params"]:
+            # Clip the gradients.
+            # Note that multiplying by the clamped coefficient is meaningless when it is
+            # equal to 1, but it avoids the host-device sync that would result from `if clip_coef_clamped < 1`.
+            if p.grad is not None:
+                # p.grad could be none for some ranks when using FSDP.
+                p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device, p.grad.dtype))
+        return num_grads_clipped
+
+    def get_post_step_metrics(self, module: nn.Module) -> Dict[str, torch.Tensor]:
+        del module
+        return {}
+
+    def get_state_for_param(self, param: nn.Parameter) -> Dict[str, Optional[torch.Tensor]]:
+        del param
+        return {}
+
+
+class LionW(Optimizer):
+    """
+    Adapted from https://github.com/google/automl/blob/master/lion/lion_pytorch.py
+    """
+
+    def __init__(
+        self,
+        params,
+        lr: float = 1e-4,
+        betas: Tuple[float, float] = (0.9, 0.99),
+        weight_decay: float = 0.0,
+    ):
+        assert lr > 0.0
+        assert all([0.0 <= beta <= 1.0 for beta in betas])
+        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
+        super().__init__(params, defaults)
+        for group in self.param_groups:
+            group["initial_lr"] = group["lr"]
+        self._update_total_dot_prod: Optional[torch.Tensor] = None
+        self._update_total_norm: Optional[torch.Tensor] = None
+        self._signed_update_total_norm: Optional[torch.Tensor] = None
+
+    def get_post_step_metrics(self, module: nn.Module) -> Dict[str, torch.Tensor]:
+        update_total_dot_prod = self._update_total_dot_prod
+        update_total_norm = self._update_total_norm
+        signed_update_total_norm = self._signed_update_total_norm
+        if update_total_dot_prod is None or update_total_norm is None or signed_update_total_norm is None:
+            return {}
+
+        if is_distributed() and isinstance(module, FullyShardedDataParallel):
+            # Reduce total dot prod and norms across all ranks.
+            update_total_norm = update_total_norm**2.0
+            signed_update_total_norm = signed_update_total_norm**2.0
+            # Reduce all together to avoid multiple communication calls.
+            all_together = torch.stack([update_total_dot_prod, update_total_norm, signed_update_total_norm])
+            # Only need the final result on rank0, since that's where we log from.
+            dist.reduce(all_together, 0)
+            update_total_dot_prod, update_total_norm, signed_update_total_norm = all_together
+            update_total_norm = update_total_norm**0.5
+            signed_update_total_norm = signed_update_total_norm**0.5
+
+        update_cos_sim = update_total_dot_prod / torch.max(
+            update_total_norm * signed_update_total_norm, torch.tensor(1e-8, device=get_default_device())
+        )
+        return {"update_cos_sim": update_cos_sim}
+
+    @torch.no_grad()
+    def step(self, closure=None) -> None:
+        if closure is not None:
+            with torch.enable_grad():
+                closure()
+
+        update_total_dot_prod = torch.tensor(0.0, dtype=torch.float32)
+        update_norms = []
+        signed_update_norms = []
+
+        for group in self.param_groups:
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+
+                # Perform step weight decay
+                p.data.mul_(1 - group["lr"] * group["weight_decay"])
+
+                grad = p.grad
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    # Exponential moving average of gradient values
+                    state["exp_avg"] = torch.zeros_like(p)
+
+                exp_avg = state["exp_avg"]
+                beta1, beta2 = group["betas"]
+
+                # Weight update
+                update = exp_avg * beta1 + grad * (1 - beta1)
+                signed_update = torch.sign(update)
+                p.add_(signed_update, alpha=-group["lr"])
+
+                # Decay the momentum running average coefficient
+                exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)
+
+                # Track dot product and norms of update vs signed update in order to calculate
+                # their cosine similarity.
+                update_total_dot_prod = update_total_dot_prod.to(update.device)
+                update_total_dot_prod += torch.tensordot(update, signed_update, dims=len(update.shape))
+                update_norms.append(torch.linalg.vector_norm(update, 2.0, dtype=torch.float32))
+                signed_update_norms.append(torch.linalg.vector_norm(signed_update, 2.0, dtype=torch.float32))
+
+        # Compute cosine similarity between update and signed update.
+        self._update_total_dot_prod = update_total_dot_prod.to(get_default_device())
+        self._update_total_norm = torch.linalg.vector_norm(
+            torch.stack(update_norms),
+            2.0,
+            dtype=torch.float32,
+        ).to(get_default_device())
+        self._signed_update_total_norm = torch.linalg.vector_norm(
+            torch.stack(signed_update_norms),
+            2.0,
+            dtype=torch.float32,
+        ).to(get_default_device())
+
+
+class AdamW(torch.optim.AdamW, Optimizer):
+    def get_state_for_param(self, param: nn.Parameter) -> Dict[str, Optional[torch.Tensor]]:
+        return {key: self.state[param].get(key) for key in ("exp_avg", "exp_avg_sq")}  # type: ignore
+
+
+@dataclass
+class Scheduler(metaclass=ABCMeta):
+    # NOTE: these fields are not given default values because otherwise dataclasses complains
+    # about how the scheduler subclasses are defined.
+    grad_clip_warmup_steps: Optional[int]
+    grad_clip_warmup_factor: Optional[float]
+
+    @abstractmethod
+    def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
+        raise NotImplementedError
+
+    def _get_max_grad_norm_coeff(
+        self, initial_value: Optional[float], step: int, max_steps: int
+    ) -> Optional[float]:
+        del max_steps  # might need this in the future, but for now I just wanted to match the API of `get_lr()`.
+        if initial_value is None:
+            return None
+        elif (
+            self.grad_clip_warmup_steps is None
+            or self.grad_clip_warmup_factor is None
+            or step > self.grad_clip_warmup_steps
+        ):
+            return initial_value
+        else:
+            return self.grad_clip_warmup_factor * initial_value
+
+    def get_max_grad_norm(
+        self, initial_max_grad_norm: Optional[float], step: int, max_steps: int
+    ) -> Optional[float]:
+        return self._get_max_grad_norm_coeff(initial_max_grad_norm, step, max_steps)
+
+    def get_max_grad_norm_ratio(
+        self, initial_max_grad_norm_ratio: Optional[float], step: int, max_steps: int
+    ) -> Optional[float]:
+        return self._get_max_grad_norm_coeff(initial_max_grad_norm_ratio, step, max_steps)
+
+    def _linear_warmup(self, initial_lr: float, step: int, warmup_steps: int = 2000) -> float:
+        return initial_lr * (0.1 + 0.9 * min(step, warmup_steps) / warmup_steps)
+
+
+@dataclass
+class CosWithWarmup(Scheduler):
+    warmup_steps: int
+    alpha_f: float = 0.1
+    t_max: Optional[int] = None
+
+    def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
+        max_steps = max_steps if self.t_max is None else self.t_max
+        eta_min = initial_lr * self.alpha_f
+        if step < self.warmup_steps:
+            return self._linear_warmup(initial_lr, step, self.warmup_steps)
+        elif step >= max_steps:
+            return eta_min
+        else:
+            step = step - self.warmup_steps
+            max_steps = max_steps - self.warmup_steps
+            return eta_min + (initial_lr - eta_min) * (1 + cos(pi * step / max_steps)) / 2
+
+
+@dataclass
+class LinearWithWarmup(Scheduler):
+    warmup_steps: int
+    alpha_f: float = 0.1
+    t_max: Optional[int] = None
+
+    def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
+        max_steps = max_steps if self.t_max is None else self.t_max
+        eta_min = initial_lr * self.alpha_f
+        if step < self.warmup_steps:
+            return self._linear_warmup(initial_lr, step, self.warmup_steps)
+        elif step >= max_steps:
+            return eta_min
+        else:
+            step = step - self.warmup_steps
+            max_steps = max_steps - self.warmup_steps
+            return initial_lr - (initial_lr - eta_min) * (step / max_steps)
+
+
+@dataclass
+class InvSqrtWithWarmup(Scheduler):
+    warmup_steps: int
+
+    def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
+        if step < self.warmup_steps:
+            return self._linear_warmup(initial_lr, step, self.warmup_steps)
+        del max_steps
+        return initial_lr * sqrt(self.warmup_steps / max(self.warmup_steps, step))
+
+
+@dataclass
+class MaxScheduler(Scheduler):
+    sched1: Scheduler
+    sched2: Scheduler
+
+    def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
+        return max(
+            self.sched1.get_lr(initial_lr, step, max_steps), self.sched2.get_lr(initial_lr, step, max_steps)
+        )
+
+
+@dataclass
+class BoltOnWarmupScheduler(Scheduler):
+    inner: Scheduler
+    warmup_start: int
+    warmup_end: int
+
+    @classmethod
+    def wrap(cls, scheduler: Scheduler, warmup_start: int, warmup_end: int) -> "BoltOnWarmupScheduler":
+        return cls(
+            grad_clip_warmup_steps=None,
+            grad_clip_warmup_factor=None,
+            inner=scheduler,
+            warmup_start=warmup_start,
+            warmup_end=warmup_end,
+        )
+
+    def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
+        if step < self.warmup_start:
+            return 0.0
+        if step < self.warmup_end:
+            lr_at_intercept = self.inner.get_lr(initial_lr, self.warmup_end, max_steps)
+            return lr_at_intercept * (step - self.warmup_start) / (self.warmup_end - self.warmup_start)
+        else:
+            return self.inner.get_lr(initial_lr, step, max_steps)
+
+    def _get_max_grad_norm_coeff(
+        self, initial_value: Optional[float], step: int, max_steps: int
+    ) -> Optional[float]:
+        return self.inner._get_max_grad_norm_coeff(initial_value, step, max_steps)
+
+
+@dataclass
+class ConstantScheduler(Scheduler):
+    def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
+        del step, max_steps
+        return initial_lr
+
+
+PARAM_GROUP_FIELDS = ("sharded", "max_grad_norm", "max_grad_norm_ratio", "param_names")
+
+
+def get_param_groups(cfg: TrainConfig, model: nn.Module) -> List[Dict[str, Any]]:
+    """
+    Separate parameters into weight decay and non weight decay groups.
+    """
+    param_groups: List[Dict[str, Any]]
+    param_group_defaults = {
+        "sharded": isinstance(model, FullyShardedDataParallel),
+        "max_grad_norm": cfg.max_grad_norm,
+        "max_grad_norm_ratio": cfg.max_grad_norm_ratio,
+    }
+
+    # Separate out parameters that we don't want to apply weight decay to, like norms and biases.
+    decay = set()
+    no_decay = set()
+    all_params = {}
+    for mn, m in model.named_modules():
+        for pn, p in m.named_parameters():
+            # NOTE: because named_modules and named_parameters are recursive
+            # we will see the same tensors p many many times, but doing it this way
+            # allows us to know which parent module any tensor p belongs to...
+            if not p.requires_grad:
+                continue
+
+            fpn = f"{mn}.{pn}" if mn else pn
+            all_params[fpn] = p
+
+            if pn.endswith("bias"):
+                if cfg.optimizer.decay_norm_and_bias:
+                    decay.add(fpn)
+                else:
+                    no_decay.add(fpn)
+            elif pn.endswith("weight") and (isinstance(m, nn.Linear) or isinstance(m, BitLinear158)):
+                decay.add(fpn)
+            elif pn.endswith("weight") and isinstance(m, (LayerNormBase, nn.LayerNorm)):
+                if cfg.optimizer.decay_norm_and_bias:
+                    decay.add(fpn)
+                else:
+                    no_decay.add(fpn)
+            elif pn.endswith("weight") and isinstance(m, nn.Embedding):
+                if cfg.optimizer.decay_embeddings:
+                    decay.add(fpn)
+                else:
+                    no_decay.add(fpn)
+
+    # Validate that we've considered every parameter
+    inter_params = decay & no_decay
+    union_params = decay | no_decay
+    assert len(inter_params) == 0, f"parameters {inter_params} made it into both decay/no_decay sets!"
+    assert (
+        len(all_params.keys() - union_params) == 0
+    ), f"parameters {all_params.keys() - union_params} were not separated into either decay/no_decay set!"
+
+    # Create the pytorch optimizer groups.
+    decay_sorted = sorted(list(decay))
+    no_decay_sorted = sorted(list(no_decay))
+    param_groups = []
+    if len(decay_sorted) > 0:
+        param_groups.append(
+            {
+                "params": [all_params[pn] for pn in decay_sorted],
+                "param_names": decay_sorted,
+                **param_group_defaults,
+            }
+        )
+    if len(no_decay_sorted) > 0:
+        param_groups.append(
+            {
+                "params": [all_params[pn] for pn in no_decay_sorted],
+                "param_names": no_decay_sorted,
+                "weight_decay": 0.0,
+                **param_group_defaults,
+            }
+        )
+
+    # Validate fields.
+    for group in param_groups:
+        for key in PARAM_GROUP_FIELDS:
+            assert key in group
+
+    return param_groups
+
+
+def fix_optim_state_dict(optimizer: Optimizer, state_dict: Dict[str, Any]) -> Dict[str, Any]:
+    """
+    Make sure old optim state dicts are compatible with new versions.
+    """
+    if len(state_dict["param_groups"]) == 1 and len(optimizer.param_groups) == 2:
+        assert optimizer.param_groups[1]["weight_decay"] == 0.0
+
+        # Decay
+        decay_param_group = {k: v for k, v in state_dict["param_groups"][0].items() if k != "params"}
+        decay_param_group["params"] = optimizer.state_dict()["param_groups"][0]["params"]
+
+        # No decay.
+        no_decay_param_group = {k: v for k, v in state_dict["param_groups"][0].items() if k != "params"}
+        no_decay_param_group["weight_decay"] = 0.0
+        no_decay_param_group["params"] = optimizer.state_dict()["param_groups"][1]["params"]
+
+        state_dict["param_groups"] = [decay_param_group, no_decay_param_group]
+
+    assert len(optimizer.param_groups) == len(state_dict["param_groups"])
+
+    # Make sure:
+    # - All required fields are included in the state dict,
+    # - And that the values of those fields doesn't change from what's currently set in the optimizer,
+    #   since we might have changed those fields on purpose after a restart.
+    for group, sd_group in zip(optimizer.param_groups, state_dict["param_groups"]):
+        for key in PARAM_GROUP_FIELDS:
+            sd_group[key] = group[key]
+
+    return state_dict
+
+
+def build_optimizer(cfg: TrainConfig, model: nn.Module) -> Optimizer:
+    param_groups = get_param_groups(cfg, model)
+    log.info(f"Constructing optimizer with {len(param_groups)} param groups")
+    if cfg.optimizer.name == OptimizerType.lionw:
+        return LionW(
+            param_groups,
+            lr=cfg.optimizer.learning_rate,
+            betas=cfg.optimizer.betas,
+            weight_decay=cfg.optimizer.weight_decay,
+        )
+    elif cfg.optimizer.name == OptimizerType.adamw:
+        return AdamW(
+            param_groups,
+            lr=cfg.optimizer.learning_rate,
+            betas=cfg.optimizer.betas,
+            weight_decay=cfg.optimizer.weight_decay,
+            eps=1e-5,
+        )
+    else:
+        raise NotImplementedError
+
+
+def build_scheduler(cfg: TrainConfig, sched_cfg: Optional[SchedulerConfig] = None) -> Scheduler:
+    sched_cfg = sched_cfg if sched_cfg is not None else cfg.scheduler
+    if sched_cfg.name == SchedulerType.cosine_with_warmup:
+        return CosWithWarmup(
+            grad_clip_warmup_steps=None
+            if sched_cfg.grad_clip_warmup_steps is None
+            else int(sched_cfg.grad_clip_warmup_steps),
+            grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
+            warmup_steps=int(sched_cfg.t_warmup),
+            alpha_f=sched_cfg.alpha_f,
+            t_max=None if sched_cfg.t_max is None else int(sched_cfg.t_max),
+        )
+    elif sched_cfg.name == SchedulerType.linear_with_warmup:
+        return LinearWithWarmup(
+            grad_clip_warmup_steps=None
+            if sched_cfg.grad_clip_warmup_steps is None
+            else int(sched_cfg.grad_clip_warmup_steps),
+            grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
+            warmup_steps=int(sched_cfg.t_warmup),
+            alpha_f=sched_cfg.alpha_f,
+            t_max=None if sched_cfg.t_max is None else int(sched_cfg.t_max),
+        )
+    elif sched_cfg.name == SchedulerType.inverse_sqrt_with_warmup:
+        return InvSqrtWithWarmup(
+            grad_clip_warmup_steps=None
+            if sched_cfg.grad_clip_warmup_steps is None
+            else int(sched_cfg.grad_clip_warmup_steps),
+            grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
+            warmup_steps=int(sched_cfg.t_warmup),
+        )
+    elif sched_cfg.name == SchedulerType.max_scheduler:
+        return MaxScheduler(
+            grad_clip_warmup_steps=None
+            if sched_cfg.grad_clip_warmup_steps is None
+            else int(sched_cfg.grad_clip_warmup_steps),
+            grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
+            sched1=build_scheduler(cfg, replace(sched_cfg, name=SchedulerType.cosine_with_warmup)),
+            sched2=build_scheduler(cfg, replace(sched_cfg, name=SchedulerType.inverse_sqrt_with_warmup)),
+        )
+    elif sched_cfg.name == SchedulerType.constant:
+        return ConstantScheduler(
+            grad_clip_warmup_steps=None
+            if sched_cfg.grad_clip_warmup_steps is None
+            else int(sched_cfg.grad_clip_warmup_steps),
+            grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
+        )
+    else:
+        raise NotImplementedError
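
For context, a minimal sketch of how the pieces above are typically wired together in a training step. The trainer itself is not part of this commit, so `cfg`, `model`, `train_loader`, `max_steps`, and `compute_loss` are assumed placeholders, and the per-step updating of `group["lr"]` and the clipping fields mirrors how the upstream OLMo trainer drives these APIs rather than anything defined here:

    # Hypothetical usage sketch -- not part of the commit.
    optim = build_optimizer(cfg, model)      # builds LionW or AdamW from cfg.optimizer
    scheduler = build_scheduler(cfg)         # builds e.g. CosWithWarmup from cfg.scheduler

    for step, batch in enumerate(train_loader, start=1):
        optim.zero_grad(set_to_none=True)
        loss = compute_loss(model, batch)    # stand-in for the real forward/loss computation
        loss.backward()

        # Feed the scheduler's outputs back into each param group before clipping and stepping.
        for group in optim.param_groups:
            group["lr"] = scheduler.get_lr(cfg.optimizer.learning_rate, step, max_steps)
            group["max_grad_norm"] = scheduler.get_max_grad_norm(cfg.max_grad_norm, step, max_steps)
            group["max_grad_norm_ratio"] = scheduler.get_max_grad_norm_ratio(
                cfg.max_grad_norm_ratio, step, max_steps
            )

        # Clips according to each group's `max_grad_norm` / `max_grad_norm_ratio` and returns metrics.
        metrics = optim.clip_grads_and_collect_metrics(step, collect_param_metrics=(step % 10 == 0))
        optim.step()
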
safetensors_util.py ADDED
@@ -0,0 +1,81 @@
+import base64
+import pickle
+from dataclasses import dataclass
+from typing import Dict, Optional, Tuple
+
+import safetensors.torch
+import torch
+
+from .aliases import PathOrStr
+
+__all__ = [
+    "state_dict_to_safetensors_file",
+    "safetensors_file_to_state_dict",
+]
+
+
+@dataclass(eq=True, frozen=True)
+class STKey:
+    keys: Tuple
+    value_is_pickled: bool
+
+
+def encode_key(key: STKey) -> str:
+    b = pickle.dumps((key.keys, key.value_is_pickled))
+    b = base64.urlsafe_b64encode(b)
+    return str(b, "ASCII")
+
+
+def decode_key(key: str) -> STKey:
+    b = base64.urlsafe_b64decode(key)
+    keys, value_is_pickled = pickle.loads(b)
+    return STKey(keys, value_is_pickled)
+
+
+def flatten_dict(d: Dict) -> Dict[STKey, torch.Tensor]:
+    result = {}
+    for key, value in d.items():
+        if isinstance(value, torch.Tensor):
+            result[STKey((key,), False)] = value
+        elif isinstance(value, dict):
+            value = flatten_dict(value)
+            for inner_key, inner_value in value.items():
+                result[STKey((key,) + inner_key.keys, inner_key.value_is_pickled)] = inner_value
+        else:
+            pickled = bytearray(pickle.dumps(value))
+            pickled_tensor = torch.frombuffer(pickled, dtype=torch.uint8)
+            result[STKey((key,), True)] = pickled_tensor
+    return result
+
+
+def unflatten_dict(d: Dict[STKey, torch.Tensor]) -> Dict:
+    result: Dict = {}
+
+    for key, value in d.items():
+        if key.value_is_pickled:
+            value = pickle.loads(value.numpy().data)
+
+        target_dict = result
+        for k in key.keys[:-1]:
+            new_target_dict = target_dict.get(k)
+            if new_target_dict is None:
+                new_target_dict = {}
+                target_dict[k] = new_target_dict
+            target_dict = new_target_dict
+        target_dict[key.keys[-1]] = value
+
+    return result
+
+
+def state_dict_to_safetensors_file(state_dict: Dict, filename: PathOrStr):
+    state_dict = flatten_dict(state_dict)
+    state_dict = {encode_key(k): v for k, v in state_dict.items()}
+    safetensors.torch.save_file(state_dict, filename)
+
+
+def safetensors_file_to_state_dict(filename: PathOrStr, map_location: Optional[str] = None) -> Dict:
+    if map_location is None:
+        map_location = "cpu"
+    state_dict = safetensors.torch.load_file(filename, device=map_location)
+    state_dict = {decode_key(k): v for k, v in state_dict.items()}
+    return unflatten_dict(state_dict)
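
A quick round-trip sketch of the helpers above; the file name and the dictionary contents are invented for illustration. Nested dictionaries are flattened into `STKey` paths, tensors are stored directly, and any non-tensor value is pickled into a `uint8` tensor and restored on load:

    # Hypothetical usage sketch -- not part of the commit.
    import torch

    state = {
        "model": {"w": torch.zeros(2, 2)},   # nested dict of tensors
        "step": 1000,                        # non-tensor value, gets pickled
    }
    state_dict_to_safetensors_file(state, "checkpoint.safetensors")

    restored = safetensors_file_to_state_dict("checkpoint.safetensors")
    assert restored["step"] == 1000
    assert torch.equal(restored["model"]["w"], torch.zeros(2, 2))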