pt-sk committed
Commit
18fc02b
1 Parent(s): b4bd23e

Upload 4 files

ppo/core_commented.py ADDED
@@ -0,0 +1,310 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import gc
15
+ import random
16
+ import warnings
17
+ from contextlib import contextmanager
18
+ from typing import Dict, List, Optional, Tuple, Union
19
+
20
+ import numpy as np
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ from torch.nn.utils.rnn import pad_sequence
25
+ from transformers import top_k_top_p_filtering
26
+
27
+ from .import_utils import is_npu_available, is_xpu_available
28
+
29
+
30
+ try:
31
+ from collections.abc import Mapping
32
+ except ImportError:
33
+ from collections.abc import Mapping
34
+
35
+
36
+ WANDB_PADDING = -1
37
+
38
+
39
+ def flatten_dict(nested: Dict, sep: str = "/") -> Dict:
40
+ """Flatten dictionary and concatenate nested keys with separator."""
41
+
42
+ def recurse(nest: Dict, prefix: str, into: Dict) -> None:
43
+ for k, v in nest.items():
44
+ if sep in k:
45
+ raise ValueError(f"separator '{sep}' not allowed to be in key '{k}'")
46
+ if isinstance(v, Mapping):
47
+ recurse(v, prefix + k + sep, into)
48
+ else:
49
+ into[prefix + k] = v
50
+
51
+ flat = {}
52
+ recurse(nested, "", flat)
53
+ return flat
54
+
55
+
56
+ def convert_to_scalar(stats: Dict) -> Dict:
57
+ """
58
+ Converts the stats from a flattened dict into a dict of single scalar values
59
+ """
60
+ tensorboard_stats = {}
61
+ for k, v in stats.items():
62
+ # for tensorboard compatibility - arrays and tensors are ignored with tensorboard
63
+ # therefore we convert single element tensors to scalars
64
+ if (isinstance(v, torch.Tensor) or isinstance(v, np.ndarray)) and (
65
+ len(v.shape) == 0 or (len(v.shape) == 1 and v.shape[0] == 1)
66
+ ):
67
+ v = v.item()
68
+ tensorboard_stats[k] = v
69
+ return tensorboard_stats
70
+
71
+
72
+ def stack_dicts(stats_dicts: List[Dict]) -> Dict:
73
+ """Stack the values of a dict."""
74
+ results = dict()
75
+ for k in stats_dicts[0]:
76
+ stats_list = [torch.flatten(d[k]) for d in stats_dicts]
77
+ results[k] = pad_sequence(stats_list, batch_first=True, padding_value=WANDB_PADDING)
78
+ return results
79
+
80
+
81
+ def add_suffix(input_dict: Dict, suffix: str) -> Dict:
82
+ """Add suffix to dict keys."""
83
+ return {k + suffix: v for k, v in input_dict.items()}
84
+
85
+
86
+ def pad_to_size(tensor: torch.Tensor, size: int, dim: int = 1, padding: int = 50256) -> torch.Tensor:
87
+ """Pad tensor to size."""
88
+ t_size = tensor.size()[dim]
89
+ if t_size == size:
90
+ return tensor
91
+ else:
92
+ return torch.nn.functional.pad(tensor, (0, size - t_size), "constant", padding)
93
+
94
+
95
+ def logprobs_from_logits(logits: torch.Tensor, labels: torch.Tensor, gather: bool = True) -> torch.Tensor:
96
+ """
97
+ See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591
98
+ """
99
+ logp = F.log_softmax(logits, dim=2)
100
+
101
+ if not gather:
102
+ return logp
103
+ logpy = torch.gather(logp, 2, labels.unsqueeze(2)).squeeze(-1)
104
+ return logpy
105
+
106
+
107
+ def whiten(values: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
108
+ """Whiten values."""
109
+ mean, var = torch.mean(values), torch.var(values)
110
+ whitened = (values - mean) * torch.rsqrt(var + 1e-8)
111
+ if not shift_mean:
112
+ whitened += mean
113
+ return whitened
114
+
115
+
116
+ def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[int] = None) -> torch.Tensor:
117
+ """Compute mean of tensor with a masked values."""
118
+ if axis is not None:
119
+ return (values * mask).sum(axis=axis) / mask.sum(axis=axis)
120
+ else:
121
+ return (values * mask).sum() / mask.sum()
122
+
123
+
124
+ def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor:
125
+ """Compute variance of tensor with masked values."""
126
+ mean = masked_mean(values, mask)
127
+ centered_values = values - mean
128
+ variance = masked_mean(centered_values**2, mask)
129
+ if unbiased:
130
+ mask_sum = mask.sum()
131
+ if mask_sum == 0:
132
+ raise ValueError(
133
+ "The sum of the mask is zero, which can happen when `mini_batch_size=1`;"
134
+ "try increase the `mini_batch_size` or `gradient_accumulation_steps`"
135
+ )
136
+ # note that if mask_sum == 1, then there is a division by zero issue
137
+ # to avoid it you just need to use a larger minibatch_size
138
+ bessel_correction = mask_sum / (mask_sum - 1)
139
+ variance = variance * bessel_correction
140
+ return variance
141
+
142
+
143
+ def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
144
+ """Whiten values with masked values."""
145
+ mean, var = masked_mean(values, mask), masked_var(values, mask)
146
+ whitened = (values - mean) * torch.rsqrt(var + 1e-8)
147
+ if not shift_mean:
148
+ whitened += mean
149
+ return whitened
150
+
151
+
152
+ def clip_by_value(x: torch.Tensor, tensor_min: float, tensor_max: float) -> torch.Tensor:
153
+ """
154
+ Tensor extension to torch.clamp
155
+ https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713
156
+ """
157
+ clipped = torch.max(torch.min(x, tensor_max), tensor_min)
158
+ return clipped
159
+
160
+
161
+ def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
162
+ """Calculate entropy from logits."""
163
+ # More info here:
164
+ # 1) Wikipedia: "The convex conjugate of LogSumExp is the negative entropy." - https://en.wikipedia.org/wiki/LogSumExp
165
+ # 2) https://math.stackexchange.com/questions/2614316/conjugate-function-of-log-sum-exp
166
+ # 3) The Log-Sum-Exp Trick - https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/
167
+ pd = torch.nn.functional.softmax(logits, dim=-1)
168
+ entropy = torch.logsumexp(logits, axis=-1) - torch.sum(pd * logits, axis=-1)
169
+ return entropy
170
+
171
+
172
+ def average_torch_dicts(list_of_dicts: List[Dict]) -> Dict:
173
+ """Average values of a list of dicts with torch tensors."""
174
+ average_dict = dict()
175
+ for key in list_of_dicts[0].keys():
176
+ average_dict[key] = torch.mean(torch.stack([d[key] for d in list_of_dicts]), axis=0)
177
+ return average_dict
178
+
179
+
180
+ def stats_to_np(stats_dict: Dict) -> Dict:
181
+ """Cast all torch.tensors in dict to numpy arrays."""
182
+ new_dict = dict()
183
+ for k, v in stats_dict.items():
184
+ if isinstance(v, torch.Tensor):
185
+ new_dict[k] = v.detach().cpu()
186
+ if new_dict[k].dtype == torch.bfloat16:
187
+ new_dict[k] = new_dict[k].float()
188
+ new_dict[k] = new_dict[k].numpy()
189
+ else:
190
+ new_dict[k] = v
191
+ if np.isscalar(new_dict[k]):
192
+ new_dict[k] = float(new_dict[k])
193
+ return new_dict
194
+
195
+
196
+ def respond_to_batch(
197
+ model: nn.Module, queries: List[torch.LongTensor], txt_len: int = 20, top_k: int = 0, top_p: float = 1.0
198
+ ) -> torch.LongTensor:
199
+ """Sample text from language model."""
200
+ input_ids = queries
201
+ for _i in range(txt_len):
202
+ # Get Logits
203
+ outputs = model(input_ids)
204
+ next_token_logits = outputs[0][:, -1, :]
205
+ next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
206
+ # Sample
207
+ probs = F.softmax(next_token_logits, dim=-1)
208
+ next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
209
+ input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)
210
+ return input_ids[:, -txt_len:]
211
+
212
+
213
+ def set_seed(seed: int) -> None:
214
+ """
215
+ Helper function for reproducible behavior to set the seed in `random`, `numpy`, and `torch`.
216
+
217
+ Args:
218
+ seed (`int`): The seed to set.
219
+ """
220
+ random.seed(seed)
221
+ np.random.seed(seed)
222
+ torch.manual_seed(seed)
223
+ if is_xpu_available():
224
+ torch.xpu.manual_seed_all(seed)
225
+ elif is_npu_available():
226
+ torch.npu.manual_seed_all(seed)
227
+ else:
228
+ torch.cuda.manual_seed_all(seed)
229
+
230
+
231
+ class LengthSampler:
232
+ """
233
+ Samples a length
234
+ """
235
+
236
+ def __init__(self, min_value: int, max_value: int):
237
+ self.values = list(range(min_value, max_value))
238
+
239
+ def __call__(self) -> int:
240
+ return np.random.choice(self.values)
241
+
242
+
243
+ class PPODecorators:
244
+ optimize_device_cache = False
245
+
246
+ @classmethod
247
+ @contextmanager
248
+ def empty_device_cache(cls):
249
+ yield
250
+ if cls.optimize_device_cache:
251
+ if is_xpu_available():
252
+ gc.collect()
253
+ torch.xpu.empty_cache()
254
+ gc.collect()
255
+ elif is_npu_available():
256
+ gc.collect()
257
+ torch.npu.empty_cache()
258
+ gc.collect()
259
+ elif torch.cuda.is_available():
260
+ gc.collect()
261
+ torch.cuda.empty_cache()
262
+ gc.collect()
263
+
264
+
265
+ def randn_tensor(
266
+ shape: Union[Tuple, List],
267
+ generator: Optional[Union[List[torch.Generator], torch.Generator]] = None,
268
+ device: Optional[torch.device] = None,
269
+ dtype: Optional[torch.dtype] = None,
270
+ layout: Optional[torch.layout] = None,
271
+ ) -> torch.Tensor:
272
+ """A helper function to create random tensors on the desired `device` with the desired `dtype`. When
273
+ passing a list of generators, you can seed each batch item individually. If CPU generators are passed, the tensor
274
+ is always created on the CPU.
275
+ """
276
+ # device on which tensor is created defaults to device
277
+ rand_device = device
278
+ batch_size = shape[0]
279
+
280
+ layout = layout or torch.strided
281
+ device = device or torch.device("cpu")
282
+
283
+ if generator is not None:
284
+ gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
285
+ if gen_device_type != device.type and gen_device_type == "cpu":
286
+ rand_device = "cpu"
287
+ if device != "mps":
288
+ warnings.warn(
289
+ f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
290
+ f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
291
+ f" slighly speed up this function by passing a generator that was created on the {device} device."
292
+ )
293
+ elif gen_device_type != device.type and gen_device_type == "cuda":
294
+ raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
295
+
296
+ # make sure generator list of length 1 is treated like a non-list
297
+ if isinstance(generator, list) and len(generator) == 1:
298
+ generator = generator[0]
299
+
300
+ if isinstance(generator, list):
301
+ shape = (1,) + shape[1:]
302
+ latents = [
303
+ torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
304
+ for i in range(batch_size)
305
+ ]
306
+ latents = torch.cat(latents, dim=0).to(device)
307
+ else:
308
+ latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)
309
+
310
+ return latents
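
The log-probability and masked-statistics helpers above are the building blocks the PPO trainer applies to (query + response) token batches. Below is a minimal self-contained sketch of how they behave on a toy batch; the two helpers are re-declared locally so the snippet runs without importing this repository, and all shapes are illustrative only.

```python
import torch
import torch.nn.functional as F


def logprobs_from_logits(logits, labels):
    # Same idea as in core_commented.py: log-softmax, then gather the label positions.
    logp = F.log_softmax(logits, dim=2)
    return torch.gather(logp, 2, labels.unsqueeze(2)).squeeze(-1)


def masked_mean(values, mask):
    # Mean over the positions where mask == 1.
    return (values * mask).sum() / mask.sum()


batch_size, seq_len, vocab_size = 2, 5, 10
logits = torch.randn(batch_size, seq_len, vocab_size)          # model outputs
labels = torch.randint(0, vocab_size, (batch_size, seq_len))   # sampled token ids

# Per-token log-probabilities of the sampled tokens: (batch_size, seq_len)
logprobs = logprobs_from_logits(logits, labels)

# Mask out the first two positions of each row, as if they were prompt tokens.
mask = torch.ones(batch_size, seq_len)
mask[:, :2] = 0

print(masked_mean(logprobs, mask))   # mean log-prob over response tokens only
```
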
ppo/core_original.py ADDED
@@ -0,0 +1,306 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import gc
15
+ import random
16
+ import warnings
17
+ from contextlib import contextmanager
18
+ from typing import Dict, List, Optional, Tuple, Union
19
+
20
+ import numpy as np
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ from torch.nn.utils.rnn import pad_sequence
25
+ from transformers import top_k_top_p_filtering
26
+
27
+ from .import_utils import is_npu_available, is_xpu_available
28
+
29
+
30
+ try:
31
+ from collections.abc import Mapping
32
+ except ImportError:
33
+ from collections.abc import Mapping
34
+
35
+
36
+ WANDB_PADDING = -1
37
+
38
+
39
+ def flatten_dict(nested: Dict, sep: str = "/") -> Dict:
40
+ """Flatten dictionary and concatenate nested keys with separator."""
41
+
42
+ def recurse(nest: Dict, prefix: str, into: Dict) -> None:
43
+ for k, v in nest.items():
44
+ if sep in k:
45
+ raise ValueError(f"separator '{sep}' not allowed to be in key '{k}'")
46
+ if isinstance(v, Mapping):
47
+ recurse(v, prefix + k + sep, into)
48
+ else:
49
+ into[prefix + k] = v
50
+
51
+ flat = {}
52
+ recurse(nested, "", flat)
53
+ return flat
54
+
55
+
56
+ def convert_to_scalar(stats: Dict) -> Dict:
57
+ """
58
+ Converts the stats from a flattened dict to single scalar dicts
59
+ """
60
+ tensorboard_stats = {}
61
+ for k, v in stats.items():
62
+ # for tensorboard compatibility - arrays and tensors are ignored with tensorboard
63
+ # therefore we convert single element tensors to scalars
64
+ if (isinstance(v, torch.Tensor) or isinstance(v, np.ndarray)) and (
65
+ len(v.shape) == 0 or (len(v.shape) == 1 and v.shape[0] == 1)
66
+ ):
67
+ v = v.item()
68
+ tensorboard_stats[k] = v
69
+ return tensorboard_stats
70
+
71
+
72
+ def stack_dicts(stats_dicts: List[Dict]) -> Dict:
73
+ """Stack the values of a dict."""
74
+ results = dict()
75
+ for k in stats_dicts[0]:
76
+ stats_list = [torch.flatten(d[k]) for d in stats_dicts]
77
+ results[k] = pad_sequence(stats_list, batch_first=True, padding_value=WANDB_PADDING)
78
+ return results
79
+
80
+
81
+ def add_suffix(input_dict: Dict, suffix: str) -> Dict:
82
+ """Add suffix to dict keys."""
83
+ return {k + suffix: v for k, v in input_dict.items()}
84
+
85
+
86
+ def pad_to_size(tensor: torch.Tensor, size: int, dim: int = 1, padding: int = 50256) -> torch.Tensor:
87
+ """Pad tensor to size."""
88
+ t_size = tensor.size()[dim]
89
+ if t_size == size:
90
+ return tensor
91
+ else:
92
+ return torch.nn.functional.pad(tensor, (0, size - t_size), "constant", padding)
93
+
94
+
95
+ def logprobs_from_logits(logits: torch.Tensor, labels: torch.Tensor, gather: bool = True) -> torch.Tensor:
96
+ """
97
+ See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591
98
+ """
99
+ logp = F.log_softmax(logits, dim=2)
100
+
101
+ if not gather:
102
+ return logp
103
+ logpy = torch.gather(logp, 2, labels.unsqueeze(2)).squeeze(-1)
104
+ return logpy
105
+
106
+
107
+ def whiten(values: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
108
+ """Whiten values."""
109
+ mean, var = torch.mean(values), torch.var(values)
110
+ whitened = (values - mean) * torch.rsqrt(var + 1e-8)
111
+ if not shift_mean:
112
+ whitened += mean
113
+ return whitened
114
+
115
+
116
+ def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[bool] = None) -> torch.Tensor:
117
+ """Compute mean of tensor with a masked values."""
118
+ if axis is not None:
119
+ return (values * mask).sum(axis=axis) / mask.sum(axis=axis)
120
+ else:
121
+ return (values * mask).sum() / mask.sum()
122
+
123
+
124
+ def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor:
125
+ """Compute variance of tensor with masked values."""
126
+ mean = masked_mean(values, mask)
127
+ centered_values = values - mean
128
+ variance = masked_mean(centered_values**2, mask)
129
+ if unbiased:
130
+ mask_sum = mask.sum()
131
+ if mask_sum == 0:
132
+ raise ValueError(
133
+ "The sum of the mask is zero, which can happen when `mini_batch_size=1`;"
134
+ "try increase the `mini_batch_size` or `gradient_accumulation_steps`"
135
+ )
136
+ # note that if mask_sum == 1, then there is a division by zero issue
137
+ # to avoid it you just need to use a larger minibatch_size
138
+ bessel_correction = mask_sum / (mask_sum - 1)
139
+ variance = variance * bessel_correction
140
+ return variance
141
+
142
+
143
+ def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
144
+ """Whiten values with masked values."""
145
+ mean, var = masked_mean(values, mask), masked_var(values, mask)
146
+ whitened = (values - mean) * torch.rsqrt(var + 1e-8)
147
+ if not shift_mean:
148
+ whitened += mean
149
+ return whitened
150
+
151
+
152
+ def clip_by_value(x: torch.Tensor, tensor_min: float, tensor_max: float) -> torch.Tensor:
153
+ """
154
+ Tensor extension to torch.clamp
155
+ https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713
156
+ """
157
+ clipped = torch.max(torch.min(x, tensor_max), tensor_min)
158
+ return clipped
159
+
160
+
161
+ def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
162
+ """Calculate entropy from logits."""
163
+ pd = torch.nn.functional.softmax(logits, dim=-1)
164
+ entropy = torch.logsumexp(logits, axis=-1) - torch.sum(pd * logits, axis=-1)
165
+ return entropy
166
+
167
+
168
+ def average_torch_dicts(list_of_dicts: List[Dict]) -> Dict:
169
+ """Average values of a list of dicts with torch tensors."""
170
+ average_dict = dict()
171
+ for key in list_of_dicts[0].keys():
172
+ average_dict[key] = torch.mean(torch.stack([d[key] for d in list_of_dicts]), axis=0)
173
+ return average_dict
174
+
175
+
176
+ def stats_to_np(stats_dict: Dict) -> Dict:
177
+ """Cast all torch.tensors in dict to numpy arrays."""
178
+ new_dict = dict()
179
+ for k, v in stats_dict.items():
180
+ if isinstance(v, torch.Tensor):
181
+ new_dict[k] = v.detach().cpu()
182
+ if new_dict[k].dtype == torch.bfloat16:
183
+ new_dict[k] = new_dict[k].float()
184
+ new_dict[k] = new_dict[k].numpy()
185
+ else:
186
+ new_dict[k] = v
187
+ if np.isscalar(new_dict[k]):
188
+ new_dict[k] = float(new_dict[k])
189
+ return new_dict
190
+
191
+
192
+ def respond_to_batch(
193
+ model: nn.Module, queries: List[torch.LongTensor], txt_len: int = 20, top_k: int = 0, top_p: float = 1.0
194
+ ) -> torch.LongTensor:
195
+ """Sample text from language model."""
196
+ input_ids = queries
197
+ for _i in range(txt_len):
198
+ # Get Logits
199
+ outputs = model(input_ids)
200
+ next_token_logits = outputs[0][:, -1, :]
201
+ next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
202
+ # Sample
203
+ probs = F.softmax(next_token_logits, dim=-1)
204
+ next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
205
+ input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)
206
+ return input_ids[:, -txt_len:]
207
+
208
+
209
+ def set_seed(seed: int) -> None:
210
+ """
211
+ Helper function for reproducible behavior to set the seed in `random`, `numpy`, and `torch`.
212
+
213
+ Args:
214
+ seed (`int`): The seed to set.
215
+ """
216
+ random.seed(seed)
217
+ np.random.seed(seed)
218
+ torch.manual_seed(seed)
219
+ if is_xpu_available():
220
+ torch.xpu.manual_seed_all(seed)
221
+ elif is_npu_available():
222
+ torch.npu.manual_seed_all(seed)
223
+ else:
224
+ torch.cuda.manual_seed_all(seed)
225
+
226
+
227
+ class LengthSampler:
228
+ """
229
+ Samples a length
230
+ """
231
+
232
+ def __init__(self, min_value: int, max_value: int):
233
+ self.values = list(range(min_value, max_value))
234
+
235
+ def __call__(self) -> int:
236
+ return np.random.choice(self.values)
237
+
238
+
239
+ class PPODecorators:
240
+ optimize_device_cache = False
241
+
242
+ @classmethod
243
+ @contextmanager
244
+ def empty_device_cache(cls):
245
+ yield
246
+ if cls.optimize_device_cache:
247
+ if is_xpu_available():
248
+ gc.collect()
249
+ torch.xpu.empty_cache()
250
+ gc.collect()
251
+ elif is_npu_available():
252
+ gc.collect()
253
+ torch.npu.empty_cache()
254
+ gc.collect()
255
+ elif torch.cuda.is_available():
256
+ gc.collect()
257
+ torch.cuda.empty_cache()
258
+ gc.collect()
259
+
260
+
261
+ def randn_tensor(
262
+ shape: Union[Tuple, List],
263
+ generator: Optional[Union[List[torch.Generator], torch.Generator]] = None,
264
+ device: Optional[torch.device] = None,
265
+ dtype: Optional[torch.dtype] = None,
266
+ layout: Optional[torch.layout] = None,
267
+ ) -> torch.Tensor:
268
+ """A helper function to create random tensors on the desired `device` with the desired `dtype`. When
269
+ passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor
270
+ is always created on the CPU.
271
+ """
272
+ # device on which tensor is created defaults to device
273
+ rand_device = device
274
+ batch_size = shape[0]
275
+
276
+ layout = layout or torch.strided
277
+ device = device or torch.device("cpu")
278
+
279
+ if generator is not None:
280
+ gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
281
+ if gen_device_type != device.type and gen_device_type == "cpu":
282
+ rand_device = "cpu"
283
+ if device != "mps":
284
+ warnings.warn(
285
+ f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
286
+ f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
287
+ f" slighly speed up this function by passing a generator that was created on the {device} device."
288
+ )
289
+ elif gen_device_type != device.type and gen_device_type == "cuda":
290
+ raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
291
+
292
+ # make sure generator list of length 1 is treated like a non-list
293
+ if isinstance(generator, list) and len(generator) == 1:
294
+ generator = generator[0]
295
+
296
+ if isinstance(generator, list):
297
+ shape = (1,) + shape[1:]
298
+ latents = [
299
+ torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
300
+ for i in range(batch_size)
301
+ ]
302
+ latents = torch.cat(latents, dim=0).to(device)
303
+ else:
304
+ latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)
305
+
306
+ return latents
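
Both copies of `entropy_from_logits` rely on the same identity: with p = softmax(z), the entropy is H(p) = logsumexp(z) - sum_i p_i * z_i, which equals the textbook -sum_i p_i * log p_i because log p_i = z_i - logsumexp(z). A self-contained sanity check (the function is redefined locally here, so nothing from this repository needs to be importable):

```python
import torch
import torch.nn.functional as F


def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
    # Same formula as in core_commented.py / core_original.py
    pd = F.softmax(logits, dim=-1)
    return torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1)


logits = torch.randn(4, 32)                      # (batch, vocab)
p = F.softmax(logits, dim=-1)
reference = -(p * torch.log(p)).sum(dim=-1)      # textbook entropy, -sum p log p
assert torch.allclose(entropy_from_logits(logits), reference, atol=1e-5)
```
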
ppo/ppo_trainer_commented.py ADDED
@@ -0,0 +1,1523 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import inspect
15
+ import math
16
+ import os
17
+ import time
18
+ import typing
19
+ import warnings
20
+ from contextlib import nullcontext
21
+ from typing import Callable, List, Optional, Union
22
+
23
+ import datasets
24
+ import numpy as np
25
+ import torch
26
+ import torch.nn.functional as F
27
+ from accelerate import Accelerator
28
+ from accelerate.utils import ProjectConfiguration, gather_object, is_deepspeed_available
29
+ from datasets import Dataset
30
+ from huggingface_hub import whoami
31
+ from packaging import version
32
+ from torch.optim import Adam
33
+ from transformers import (
34
+ DataCollatorForLanguageModeling,
35
+ PreTrainedTokenizer,
36
+ PreTrainedTokenizerBase,
37
+ PreTrainedTokenizerFast,
38
+ )
39
+
40
+ from ..core import (
41
+ WANDB_PADDING,
42
+ PPODecorators,
43
+ clip_by_value,
44
+ convert_to_scalar,
45
+ entropy_from_logits,
46
+ flatten_dict,
47
+ logprobs_from_logits,
48
+ masked_mean,
49
+ masked_var,
50
+ masked_whiten,
51
+ set_seed,
52
+ stack_dicts,
53
+ stats_to_np,
54
+ )
55
+ from ..import_utils import is_npu_available, is_torch_greater_2_0, is_xpu_available
56
+ from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper, create_reference_model
57
+ from . import AdaptiveKLController, BaseTrainer, FixedKLController, PPOConfig, RunningMoments
58
+
59
+
60
+ if is_deepspeed_available():
61
+ import deepspeed
62
+
63
+ MODEL_CARD_TEMPLATE = """---
64
+ license: apache-2.0
65
+ tags:
66
+ - trl
67
+ - ppo
68
+ - transformers
69
+ - reinforcement-learning
70
+ ---
71
+
72
+ # {model_name}
73
+
74
+ This is a [TRL language model](https://github.com/huggingface/trl) that has been fine-tuned with reinforcement learning to
75
+ guide the model outputs according to a value, function, or human feedback. The model can be used for text generation.
76
+
77
+ ## Usage
78
+
79
+ To use this model for inference, first install the TRL library:
80
+
81
+ ```bash
82
+ python -m pip install trl
83
+ ```
84
+
85
+ You can then generate text as follows:
86
+
87
+ ```python
88
+ from transformers import pipeline
89
+
90
+ generator = pipeline("text-generation", model="{model_id}")
91
+ outputs = generator("Hello, my llama is cute")
92
+ ```
93
+
94
+ If you want to use the model for training or to obtain the outputs from the value head, load the model as follows:
95
+
96
+ ```python
97
+ from transformers import AutoTokenizer
98
+ from trl import AutoModelForCausalLMWithValueHead
99
+
100
+ tokenizer = AutoTokenizer.from_pretrained("{model_id}")
101
+ model = AutoModelForCausalLMWithValueHead.from_pretrained("{model_id}")
102
+
103
+ inputs = tokenizer("Hello, my llama is cute", return_tensors="pt")
104
+ outputs = model(**inputs, labels=inputs["input_ids"])
105
+ ```
106
+ """
107
+
108
+
109
+ class PPOTrainer(BaseTrainer):
110
+ """
111
+ The PPOTrainer uses Proximal Policy Optimization to optimise language models.
112
+ Note, this trainer is heavily inspired by the original OpenAI learning to summarize work here:
113
+ https://github.com/openai/summarize-from-feedback
114
+
115
+ Attributes:
116
+ **config** (`PPOConfig`) -- Configuration object for PPOTrainer. Check the documentation of `PPOConfig` for more
117
+ details.
118
+ **model** (`PreTrainedModelWrapper`) -- Model to be optimized, Hugging Face transformer model with a value head.
119
+ Check the documentation of `PreTrainedModelWrapper` for more details.
120
+ **ref_model** (`PreTrainedModelWrapper`, *optional*) -- Reference model to be used for KL penalty, Hugging Face
121
+ transformer model with a causal language modelling head. Check the documentation of `PreTrainedModelWrapper`
122
+ for more details. If no reference model is provided, the trainer will create a reference model with the same
123
+ architecture as the model to be optimized with shared layers.
124
+ **tokenizer** (`PreTrainedTokenizerBase`) -- Tokenizer to be used for encoding the
125
+ data. Check the documentation of `transformers.PreTrainedTokenizer` and
126
+ `transformers.PreTrainedTokenizerFast` for more details.
127
+ **dataset** (Union[`torch.utils.data.Dataset`, `datasets.Dataset`], *optional*) -- PyTorch dataset or Hugging
128
+ Face dataset. This is used to create a PyTorch dataloader. If no dataset is provided, the dataloader must be
129
+ created outside the trainer; users need to design their own dataloader and make sure the batch
130
+ size that is used is the same as the one specified in the configuration object.
131
+ **optimizer** (`torch.optim.Optimizer`, *optional*) -- Optimizer to be used for training. If no optimizer is
132
+ provided, the trainer will create an Adam optimizer with the learning rate specified in the configuration
133
+ object.
134
+ **data_collator** (DataCollatorForLanguageModeling, *optional*) -- Data collator to be used for training and
135
+ passed along the dataloader
136
+ **num_shared_layers** (int, *optional*) -- Number of layers to be shared between the model and the reference
137
+ model, if no reference model is passed. If no number is provided, all the layers will be shared.
138
+ **lr_scheduler** (`torch.optim.lr_scheduler`, *optional*) -- Learning rate scheduler to be used for training.
139
+ """
140
+
141
+ _tag_names = ["trl", "ppo"]
142
+
143
+ def __init__(
144
+ self,
145
+ config: Optional[PPOConfig] = None,
146
+ model: Optional[PreTrainedModelWrapper] = None,
147
+ ref_model: Optional[PreTrainedModelWrapper] = None,
148
+ tokenizer: Optional[PreTrainedTokenizerBase] = None,
149
+ dataset: Optional[Union[torch.utils.data.Dataset, Dataset]] = None,
150
+ optimizer: Optional[torch.optim.Optimizer] = None,
151
+ data_collator: Optional[typing.Callable] = None,
152
+ num_shared_layers: Optional[int] = None,
153
+ lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
154
+ ):
155
+ """
156
+ Initialize PPOTrainer.
157
+
158
+ Args:
159
+ config (`PPOConfig`):
160
+ Configuration object for PPOTrainer. Check the documentation of `PPOConfig` for more details.
161
+ model (`PreTrainedModelWrapper`):
162
+ Hugging Face transformer model with a value head.
163
+ ref_model (`PreTrainedModelWrapper`):
164
+ Hugging Face transformer model with a causal language modelling head. Used for KL penalty
165
+ tokenizer (`transformers.PreTrainedTokenizerBase`):
166
+ Hugging Face tokenizer
167
+ dataset (Optional[Union[`torch.utils.data.Dataset`, `datasets.Dataset`]]):
168
+ PyTorch dataset or Hugging Face dataset. If a Hugging Face dataset is passed, the dataset
169
+ will be preprocessed by removing the columns that are not used by the model. If none is passed,
170
+ a warning will be raised in a multi-GPU setting.
171
+ optimizer (Optional[`torch.optim.Optimizer`]):
172
+ Optimizer used for training. If `None`, the `Adam` is used as default.
173
+ data_collator (Optional[function]):
174
+ Data collator function.
175
+ num_shared_layers (Optional[int]):
176
+ Number of shared layers between the model and the reference model. If `None`, all layers are shared.
177
+ used only if `ref_model` is `None`.
178
+ lr_scheduler (Optional[`torch.optim.lr_scheduler`]):
179
+ Learning rate scheduler used for training.
180
+ """
181
+ super().__init__(config)
182
+
183
+ # initial seed for reproducible experiments
184
+ set_seed(config.seed)
185
+
186
+ # Step 0: check positional arguments validity
187
+ if not isinstance(config, PPOConfig):
188
+ raise ValueError(f"config must be a PPOConfig, got {type(config)}")
189
+ if not isinstance(tokenizer, (PreTrainedTokenizerBase)):
190
+ raise ValueError(
191
+ f"tokenizer must be a PreTrainedTokenizerBase like a PreTrainedTokenizer or a PreTrainedTokenizerFast, got {type(tokenizer)}"
192
+ )
193
+ if not isinstance(model, (SUPPORTED_ARCHITECTURES)):
194
+ raise ValueError(
195
+ f"model must be a PreTrainedModelWrapper, got {type(model)} - supported architectures are: {SUPPORTED_ARCHITECTURES}"
196
+ )
197
+ # Step 1: Initialize Accelerator
198
+ self.accelerator = Accelerator(
199
+ log_with=config.log_with,
200
+ gradient_accumulation_steps=config.gradient_accumulation_steps,
201
+ project_config=ProjectConfiguration(**config.project_kwargs),
202
+ **config.accelerator_kwargs,
203
+ )
204
+
205
+ # Step 1.1 Runtime variables filled by the accelerator
206
+ config.world_size = self.accelerator.num_processes
207
+ config.global_backward_batch_size = config.backward_batch_size * config.world_size
208
+ config.global_batch_size = config.batch_size * config.world_size
209
+
210
+ self.model = model
211
+ self.model_params = filter(lambda p: p.requires_grad, self.model.parameters())
212
+ self.is_encoder_decoder = hasattr(self.model, "is_encoder_decoder")
213
+ self.is_peft_model = getattr(self.model, "is_peft_model", False)
214
+ config.is_encoder_decoder = self.is_encoder_decoder
215
+ config.is_peft_model = self.is_peft_model
216
+
217
+ is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
218
+ self.accelerator.init_trackers(
219
+ config.tracker_project_name,
220
+ config=dict(trl_ppo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(),
221
+ init_kwargs=config.tracker_kwargs,
222
+ )
223
+ self.is_using_text_environment = getattr(config, "use_text_environment", False)
224
+
225
+ if isinstance(ref_model, SUPPORTED_ARCHITECTURES):
226
+ self.ref_model = ref_model
227
+ if num_shared_layers is not None:
228
+ warnings.warn(
229
+ "num_shared_layers is ignored when ref_model is provided. Two different models are used for the "
230
+ "model and the reference model and no layers are shared.",
231
+ UserWarning,
232
+ )
233
+ elif ref_model is None and not self.is_peft_model:
234
+ self.ref_model = create_reference_model(self.model, num_shared_layers=num_shared_layers)
235
+ elif self.is_peft_model:
236
+ self.ref_model = None
237
+ else:
238
+ raise ValueError(
239
+ f"ref_model must be a PreTrainedModelWrapper or `None`, got {type(ref_model)} - supported "
240
+ f"architectures are: {SUPPORTED_ARCHITECTURES} "
241
+ )
242
+ self.optional_peft_ctx = (
243
+ self.accelerator.unwrap_model(self.model).pretrained_model.disable_adapter
244
+ if self.is_peft_model
245
+ else nullcontext
246
+ )
247
+
248
+ if not (isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTrainedTokenizerFast)):
249
+ raise ValueError(
250
+ "tokenizer must be a transformers.PreTrainedTokenizer or transformers.PreTrainedTokenizerFast"
251
+ )
252
+ self.tokenizer = tokenizer
253
+
254
+ if dataset is not None and not (isinstance(dataset, torch.utils.data.Dataset) or isinstance(dataset, Dataset)):
255
+ raise ValueError("dataset must be a torch.utils.data.Dataset or datasets.Dataset")
256
+ elif dataset is None:
257
+ warnings.warn(
258
+ "No dataset is provided. Make sure to set config.batch_size to the correct value before training.",
259
+ UserWarning,
260
+ )
261
+ self.dataset = dataset
262
+ self._signature_columns = None
263
+ if self.dataset is not None:
264
+ self.dataloader = self.prepare_dataloader(self.dataset, data_collator)
265
+ elif self.dataset is None and self.accelerator.num_processes > 1:
266
+ warnings.warn(
267
+ "No dataset is provided. In a multi-GPU setting, this will lead to an error. You should"
268
+ " prepare your dataloader yourself with `dataloader = ppo_trainer.accelerator.prepare(dataloader)`"
269
+ " and using `torch.utils.data.DataLoader`, or pass a dataset to the `PPOTrainer`. Please "
270
+ " refer to the documentation for more details.",
271
+ UserWarning,
272
+ )
273
+ self.dataloader = None
274
+ else:
275
+ self.dataloader = None
276
+
277
+ # Step 3: Initialize optimizer and data collator
278
+ self.data_collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False)
279
+ if optimizer is None:
280
+ self.optimizer = Adam(
281
+ filter(lambda p: p.requires_grad, self.model.parameters()),
282
+ lr=self.config.learning_rate,
283
+ )
284
+ else:
285
+ self.optimizer = optimizer
286
+
287
+ self.lr_scheduler = lr_scheduler
288
+ if self.lr_scheduler is not None:
289
+ lr_scheduler_class = (
290
+ torch.optim.lr_scheduler._LRScheduler
291
+ if not is_torch_greater_2_0()
292
+ else torch.optim.lr_scheduler.LRScheduler
293
+ )
294
+
295
+ if not isinstance(self.lr_scheduler, lr_scheduler_class):
296
+ raise ValueError(
297
+ "lr_scheduler must be a torch.optim.lr_scheduler._LRScheduler or torch.optim.lr_scheduler.LRScheduler (for torch >= 2.0)"
298
+ )
299
+
300
+ if self.config.adap_kl_ctrl:
301
+ self.kl_ctl = AdaptiveKLController(self.config.init_kl_coef, self.config.target, self.config.horizon)
302
+ else:
303
+ self.kl_ctl = FixedKLController(self.config.init_kl_coef)
304
+
305
+ # Safety checkers for DS integration
306
+ is_deepspeed_used = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(
307
+ self.accelerator.state, "deepspeed_plugin"
308
+ )
309
+
310
+ (
311
+ self.model,
312
+ self.optimizer,
313
+ self.data_collator,
314
+ self.dataloader,
315
+ self.lr_scheduler,
316
+ ) = self.accelerator.prepare(
317
+ self.model,
318
+ self.optimizer,
319
+ self.data_collator,
320
+ self.dataloader,
321
+ self.lr_scheduler,
322
+ )
323
+ if is_deepspeed_used:
324
+ # Quantized models are already set on the correct device
325
+ if not self.is_peft_model and not (
326
+ getattr(self.ref_model.pretrained_model, "is_loaded_in_8bit", False)
327
+ or getattr(self.ref_model.pretrained_model, "is_loaded_in_4bit", False)
328
+ ):
329
+ self.ref_model = self._prepare_deepspeed(self.ref_model)
330
+ else:
331
+ self.ref_model = self.accelerator.prepare(self.ref_model)
332
+
333
+ # In a distributed setup, only logging needs to be performed on the main process
334
+ # check: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
335
+ # or: https://discuss.pytorch.org/t/use-distributed-data-parallel-correctly/82500/11
336
+ self.is_distributed = self.accelerator.num_processes > 1
337
+
338
+ # init the current step
339
+ self.current_step = 0
340
+
341
+ # init variables for pushing model to hub
342
+ if config.push_to_hub_if_best_kwargs:
343
+ if "repo_id" not in config.push_to_hub_if_best_kwargs:
344
+ raise ValueError("You have to specify repo_id in order to push the model to the hub!")
345
+ self.push_to_hub_kwargs = config.push_to_hub_if_best_kwargs
346
+ self.compare_step = 0
347
+ self.highest_reward = torch.tensor(-float("inf"))
348
+
349
+ # post process for PP
350
+ if not getattr(self.model, "is_sequential_parallel", False):
351
+ self.current_device = self.accelerator.device
352
+ else:
353
+ if is_xpu_available():
354
+ self.current_device = torch.device("xpu:0")
355
+ elif is_npu_available():
356
+ self.current_device = torch.device("npu:0")
357
+ else:
358
+ self.current_device = torch.device("cuda:0")
359
+
360
+ PPODecorators.optimize_device_cache = self.config.optimize_device_cache
361
+
362
+ self.running = RunningMoments(self.accelerator)
363
+
364
+ def _filter_kwargs(self, kwargs, target_func):
365
+ """
366
+ filter the keyword arguments that are supported by the target function.
367
+
368
+ Args:
369
+ kwargs (dict):
370
+ Keyword arguments
371
+ target_func (function):
372
+ Target function
373
+ """
374
+ return {k: v for k, v in kwargs.items() if k in inspect.signature(target_func).parameters.keys()}
375
+
376
+ def prepare_dataloader(self, dataset: Union[torch.utils.data.Dataset, Dataset], data_collator=None):
377
+ """
378
+ Prepare the dataloader for training.
379
+
380
+ Args:
381
+ dataset (Union[`torch.utils.data.Dataset`, `datasets.Dataset`]):
382
+ PyTorch dataset or Hugging Face dataset. If a Hugging Face dataset is passed, the dataset
383
+ will be preprocessed by removing the columns that are not used by the model.
384
+ data_collator (Optional[function]):
385
+ Data collator function.
386
+
387
+ Returns:
388
+ `torch.utils.data.DataLoader`: PyTorch dataloader
389
+ """
390
+ if isinstance(dataset, Dataset):
391
+ dataset = self._remove_unused_columns(dataset)
392
+ dataloader = torch.utils.data.DataLoader(
393
+ dataset,
394
+ batch_size=self.config.batch_size,
395
+ collate_fn=data_collator,
396
+ shuffle=True,
397
+ drop_last=True,
398
+ )
399
+ return dataloader
400
+
401
+ # Adapted from transformers.Trainer._set_signature_columns_if_needed
402
+ def _set_signature_columns_if_needed(self):
403
+ if self._signature_columns is None:
404
+ # Inspect model forward signature to keep only the arguments it accepts.
405
+ signature = inspect.signature(self.model.forward)
406
+ self._signature_columns = list(signature.parameters.keys())
407
+ # label => sentiment | we need query and response for logging purposes
408
+ self._signature_columns += ["label", "query", "response"]
409
+
410
+ # Adapted from transformers.Trainer._remove_unused_columns
411
+ def _remove_unused_columns(self, dataset: "Dataset"):
412
+ if not self.config.remove_unused_columns:
413
+ return dataset
414
+ self._set_signature_columns_if_needed()
415
+ signature_columns = self._signature_columns
416
+
417
+ ignored_columns = list(set(dataset.column_names) - set(signature_columns))
418
+
419
+ columns = [k for k in signature_columns if k in dataset.column_names]
420
+
421
+ if version.parse(datasets.__version__) < version.parse("1.4.0"):
422
+ dataset.set_format(
423
+ type=dataset.format["type"],
424
+ columns=columns,
425
+ format_kwargs=dataset.format["format_kwargs"],
426
+ )
427
+ return dataset
428
+ else:
429
+ return dataset.remove_columns(ignored_columns)
430
+
431
+ def generate(
432
+ self,
433
+ query_tensor: Union[torch.Tensor, List[torch.Tensor]],
434
+ length_sampler: Optional[Callable] = None,
435
+ batch_size: int = 4,
436
+ return_prompt: bool = True,
437
+ generate_ref_response: bool = False,
438
+ **generation_kwargs,
439
+ ):
440
+ """
441
+ Generate response with the model given the query tensor.
442
+ call the `generate` method of the model.
443
+
444
+ Args:
445
+ query_tensor (`torch.LongTensor`):
446
+ A tensor of shape (`seq_len`) containing query tokens or a list of tensors of shape (`seq_len`).
447
+ length_sampler (`Callable`, *optional*):
448
+ Callable that returns the number of newly generated tokens.
449
+ batch_size (`int`, *optional):
450
+ Batch size used for generation, defaults to `4`.
451
+ return_prompt (`bool`, *optional*):
452
+ If set to `False` the prompt is not returned but only the newly generated tokens, defaults to `True`.
453
+ generate_ref_response (`bool`, *optional*):
454
+ If set to `True` the reference response is also generated, defaults to `False`.
455
+ generation_kwargs (dict[str, Any]):
456
+ Keyword arguments for generation.
457
+
458
+ Returns:
459
+ `torch.LongTensor`: A tensor of shape (`batch_size`, `gen_len`) containing response tokens.
460
+ """
461
+ if generate_ref_response:
462
+ ref_model = self.model if self.is_peft_model else self.ref_model
463
+ if isinstance(query_tensor, List):
464
+ response = self._generate_batched(
465
+ self.model,
466
+ query_tensor,
467
+ length_sampler=length_sampler,
468
+ batch_size=batch_size,
469
+ return_prompt=return_prompt,
470
+ **generation_kwargs,
471
+ )
472
+ if generate_ref_response:
473
+ with self.optional_peft_ctx():
474
+ ref_response = self._generate_batched(
475
+ ref_model,
476
+ query_tensor,
477
+ length_sampler=length_sampler,
478
+ batch_size=batch_size,
479
+ return_prompt=return_prompt,
480
+ **generation_kwargs,
481
+ )
482
+
483
+ else:
484
+ if len(query_tensor.shape) == 2:
485
+ raise ValueError(
486
+ "query_tensor must be a tensor of shape (`seq_len`) or a list of tensors of shape (`seq_len`)"
487
+ )
488
+
489
+ if length_sampler is not None:
490
+ generation_kwargs["max_new_tokens"] = length_sampler()
491
+ response = self.accelerator.unwrap_model(self.model).generate(
492
+ input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs
493
+ )
494
+ if generate_ref_response:
495
+ with self.optional_peft_ctx():
496
+ ref_response = ref_model.generate(input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs)
497
+
498
+ if not return_prompt and not self.is_encoder_decoder:
499
+ response = response[:, query_tensor.shape[0] :]
500
+ if generate_ref_response:
501
+ ref_response = ref_response[:, query_tensor.shape[0] :]
502
+
503
+ if generate_ref_response:
504
+ return response, ref_response
505
+ return response
506
+
507
+ def _generate_batched(
508
+ self,
509
+ model: PreTrainedModelWrapper,
510
+ query_tensors: List[torch.Tensor],
511
+ length_sampler: Optional[Callable] = None,
512
+ batch_size: int = 4,
513
+ return_prompt: bool = True,
514
+ pad_to_multiple_of: Optional[int] = None,
515
+ remove_padding: bool = True,
516
+ **generation_kwargs,
517
+ ):
518
+ outputs = []
519
+
520
+ padding_side_default = self.tokenizer.padding_side
521
+ if not self.is_encoder_decoder:
522
+ self.tokenizer.padding_side = "left"
523
+
524
+ # in case we have fewer examples than bs
525
+ batch_size = min(len(query_tensors), batch_size)
526
+
527
+ for i in range(0, len(query_tensors), batch_size):
528
+ if length_sampler is not None:
529
+ generation_kwargs["max_new_tokens"] = length_sampler()
530
+
531
+ # prevent overflow if query tensors are not even multiple of bs
532
+ end_index = min(len(query_tensors), i + batch_size)
533
+
534
+ batch = query_tensors[i:end_index]
535
+ batch_mask = [torch.ones_like(element) for element in batch]
536
+ inputs = {"input_ids": batch, "attention_mask": batch_mask}
537
+
538
+ padded_inputs = self.tokenizer.pad(
539
+ inputs,
540
+ padding=True,
541
+ max_length=None,
542
+ pad_to_multiple_of=pad_to_multiple_of,
543
+ return_tensors="pt",
544
+ ).to(self.current_device)
545
+
546
+ generations = self.accelerator.unwrap_model(model).generate(**padded_inputs, **generation_kwargs)
547
+
548
+ for generation, mask in zip(generations, padded_inputs["attention_mask"]):
549
+ if not self.is_encoder_decoder:
550
+ output = generation[(1 - mask).sum() :] # remove padding
551
+ else:
552
+ output = generation
553
+
554
+ if not return_prompt and not self.is_encoder_decoder:
555
+ output = output[(mask).sum() :] # remove prompt
556
+
557
+ if remove_padding and self.tokenizer.eos_token_id in output:
558
+ pad_mask = output == self.tokenizer.eos_token_id
559
+ pad_start = torch.nonzero(pad_mask, as_tuple=False)[0, 0].item()
560
+ output = output[: pad_start + 1] # keep the eos token at the end
561
+
562
+ outputs.append(output)
563
+
564
+ self.tokenizer.padding_side = padding_side_default
565
+ return outputs
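+ # Typical `generation_kwargs` passed to `generate()` in the TRL examples (a sketch,
+ # not a fixed recipe; it assumes `tokenizer` and `ppo_trainer` objects as in those examples):
+ # generation_kwargs = {
+ # "do_sample": True, # sample instead of greedy decoding
+ # "top_k": 0.0, # disable top-k filtering
+ # "top_p": 1.0, # disable nucleus filtering
+ # "pad_token_id": tokenizer.eos_token_id,
+ # "max_new_tokens": 32, # or let a LengthSampler choose this per batch
+ # }
+ # response_tensors = ppo_trainer.generate(query_tensors, **generation_kwargs)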
566
+
567
+ def _step_safety_checker(
568
+ self,
569
+ batch_size: int,
570
+ queries: List[torch.LongTensor],
571
+ responses: List[torch.LongTensor],
572
+ scores: List[torch.FloatTensor],
573
+ masks: Optional[List[torch.LongTensor]] = None,
574
+ ):
575
+ """
576
+ Check if the input data is valid for training.
577
+
578
+ Args:
579
+ batch_size (int):
580
+ Batch size from the config file.
581
+ queries (List[`torch.LongTensor`]):
582
+ List of tensors containing the encoded queries of shape (`query_length`)
583
+ responses (List[`torch.LongTensor`]):
584
+ List of tensors containing the encoded responses of shape (`response_length`)
585
+ scores (List[`torch.FloatTensor`]):
586
+ List of tensors containing the scores.
587
+ masks (List[`torch.LongTensor`], *optional*):
588
+ list of optional tensors containing the masks of shape (`query_length` + `response_length`)
589
+ Returns:
590
+ `tuple`: The input processed data.
591
+ """
592
+ for name, tensor_list in zip(["queries", "responses", "scores"], [queries, responses, scores]):
593
+ if not isinstance(tensor_list, list):
594
+ raise ValueError(f"{name} must be a list of tensors - got {type(tensor_list)}")
595
+ if not isinstance(tensor_list[0], torch.Tensor):
596
+ raise ValueError(f"Elements in {name} must be tensors - got {type(tensor_list[0])}")
597
+ if batch_size is not None and len(tensor_list) != batch_size:
598
+ raise ValueError(
599
+ f"Batch size ({batch_size}) does not match number of examples - but got {len(tensor_list)} for: {name}"
600
+ )
601
+
602
+ # add queries, scores and responses on the correct device
603
+ queries = [tensor.to(self.current_device) for tensor in queries]
604
+ responses = [tensor.to(self.current_device) for tensor in responses]
605
+ scores = [tensor.to(self.current_device) for tensor in scores]
606
+ masks = [tensor.to(self.current_device) for tensor in masks] if masks is not None else None
607
+
608
+ # squeeze scores if needed
609
+ for i, score in enumerate(scores):
610
+ if score.dim() > 1:
611
+ raise ValueError(f"Scores must be 1-dimensional - got {score.dim()} for {score}")
612
+ elif score.dim() == 1:
613
+ scores[i] = score.squeeze()
614
+
615
+ return queries, responses, scores, masks
616
+
617
+ @PPODecorators.empty_device_cache()
618
+ def step(
619
+ self,
620
+ queries: List[torch.LongTensor], # The list of prompts used to generate responses from the old model (offline policy)
621
+ responses: List[torch.LongTensor], # A list of responses generated by the old model (offline policy)
622
+ scores: List[torch.FloatTensor], # A list of rewards, one per response (NOT one per token of the response)
623
+ response_masks: Optional[List[torch.LongTensor]] = None,
624
+ ):
625
+ """
626
+ Run a PPO optimisation step given a list of queries, model responses, and rewards.
627
+
628
+ Args:
629
+ queries (List[`torch.LongTensor`]):
630
+ List of tensors containing the encoded queries of shape (`query_length`)
631
+ responses (List[`torch.LongTensor`]):
632
+ List of tensors containing the encoded responses of shape (`response_length`)
633
+ scores (List[`torch.FloatTensor`]):
634
+ List of tensors containing the scores.
635
+ response_masks (List[`torch.FloatTensor`], *optional*)):
636
+ List of tensors containing masks of the response tokens.
637
+
638
+ Returns:
639
+ `dict[str, Any]`: A summary of the training statistics
640
+ """
641
+ bs = self.config.batch_size
642
+
643
+ # queries: input_ids of the prompts;
644
+ # responses: input_ids of the responses;
645
+ # scores: score from reward model (one per response)
646
+ # Verify input tensors (check types, shapes, etc.)
647
+ queries, responses, scores, response_masks = self._step_safety_checker(
648
+ bs, queries, responses, scores, response_masks
649
+ )
650
+
651
+ # Indicates the rewards given to the responses. One scalar for each response.
652
+ # shape: (batch_size)
653
+ scores = torch.tensor(scores, device=self.current_device)
654
+
655
+ # if self.config.use_score_scaling:
656
+ # # Score scaling
657
+ # scores_mean, scores_std = self.running.update(scores)
658
+ # tensor_to_kwargs = dict(dtype=scores.dtype, device=scores.device)
659
+ # score_scaling_factor = self.running.std.to(**tensor_to_kwargs) + torch.finfo(scores.dtype).eps
660
+ # if self.config.use_score_norm:
661
+ # scores = (scores - self.running.mean.to(**tensor_to_kwargs)) / score_scaling_factor
662
+ # else:
663
+ # scores /= score_scaling_factor
664
+
665
+ # if self.config.score_clip is not None:
666
+ # # Score clipping
667
+ # scores_dtype = scores.dtype
668
+ # scores = torch.clip(scores.float(), -self.config.score_clip, self.config.score_clip).to(dtype=scores_dtype)
669
+
670
+ # # if we want to push best model to the hub
671
+ # if hasattr(self, "highest_reward"):
672
+ # if self.compare_step % self.config.compare_steps == 0:
673
+ # curr_mean_reward = scores.mean()
674
+ # # if the best reward ever seen
675
+ # if curr_mean_reward > self.highest_reward:
676
+ # self.highest_reward = curr_mean_reward
677
+ # # push model to hub
678
+ # self.push_to_hub(**self.push_to_hub_kwargs)
679
+ # self.compare_step += 1
680
+
681
+ timing = dict()
682
+ t0 = time.time()
683
+
684
+ t = time.time()
685
+
686
+ # Join the query and the response to create a input_ids tensor
687
+ # Also generate the attention masks (for padding). Padding is added so that all the query+response can be joined in the same tensor
688
+ # Dictionary with input_ids and attention_mask.
689
+ # Shape of input_ids: (batch_size, seq_len)
690
+ # Shape of attention_mask: (batch_size, seq_len). The attention mask just masks out the padding token.
691
+ model_inputs = self.prepare_model_inputs(queries, responses)
692
+
693
+ # if self.is_distributed:
694
+ # pad_first = self.tokenizer.padding_side == "left"
695
+
696
+ # model_inputs["input_ids"] = self.accelerator.pad_across_processes(
697
+ # model_inputs["input_ids"],
698
+ # dim=1,
699
+ # pad_index=self.tokenizer.pad_token_id,
700
+ # pad_first=pad_first,
701
+ # )
702
+ # model_inputs["attention_mask"] = self.accelerator.pad_across_processes(
703
+ # model_inputs["attention_mask"], dim=1, pad_index=0, pad_first=pad_first
704
+ # )
705
+ # if self.is_encoder_decoder:
706
+ # model_inputs["decoder_input_ids"] = self.accelerator.pad_across_processes(
707
+ # model_inputs["decoder_input_ids"],
708
+ # dim=1,
709
+ # pad_index=self.tokenizer.pad_token_id,
710
+ # pad_first=pad_first,
711
+ # )
712
+ # model_inputs["decoder_attention_mask"] = self.accelerator.pad_across_processes(
713
+ # model_inputs["decoder_attention_mask"],
714
+ # dim=1,
715
+ # pad_index=0,
716
+ # pad_first=pad_first,
717
+ # )
718
+
719
+ model_inputs_names = list(model_inputs.keys())
720
+
721
+ full_kl_penalty = self.config.kl_penalty == "full" # It is going to be False in our case.
722
+
723
+ # Since the given trajectories from the offline model do not have the logprobs and value estimations for each position (action), we need to calculate them.
724
+
725
+ with torch.no_grad():
726
+ # Calculate the log probabilities of all tokens of each sentence
727
+ # The masks indicate which log probabilities to use (exclude query tokens and padding tokens)
728
+ # all_logprobs: (Batch_Size, Seq_Len - 1) where Seq_Len is the maximum length of a query+response
729
+ # values: (Batch_Size, Seq_Len - 1), masks: (Batch_Size, Seq_Len - 1)
730
+ all_logprobs, logits_or_none, values, masks = self.batched_forward_pass(
731
+ self.model,
732
+ queries,
733
+ responses,
734
+ model_inputs,
735
+ response_masks=response_masks,
736
+ return_logits=full_kl_penalty,
737
+ )
738
+
739
+ with self.optional_peft_ctx():
740
+ # Get the log probabilities also w.r.t the reference model (frozen model)
741
+ ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass(
742
+ self.model if self.is_peft_model else self.ref_model,
743
+ queries,
744
+ responses,
745
+ model_inputs,
746
+ return_logits=full_kl_penalty,
747
+ )
748
+
749
+ timing["time/ppo/forward_pass"] = time.time() - t
750
+
751
+ with torch.no_grad():
752
+ t = time.time()
753
+ if full_kl_penalty:
754
+ # === NOT USED === #
755
+ active_full_logprobs = logprobs_from_logits(logits_or_none, None, gather=False)
756
+ ref_full_logprobs = logprobs_from_logits(ref_logits_or_none, None, gather=False)
757
+
758
+ rewards, non_score_reward, kls = self.compute_rewards(
759
+ scores, active_full_logprobs, ref_full_logprobs, masks
760
+ )
761
+ else:
762
+ # Use the scores (from reward model) and the log probabilities to generate the rewards.
763
+ # rewards: (Batch_Size, Seq_Len - 1)
764
+ rewards, non_score_reward, kls = self.compute_rewards(scores, all_logprobs, ref_logprobs, masks)
765
+ timing["time/ppo/compute_rewards"] = time.time() - t
766
+
767
+ t = time.time()
768
+ # Use the rewards and the values to compute the advantage using GAE.
769
+ # values: (Batch_Size, Seq_Len - 1)
770
+ # rewards: (Batch_Size, Seq_Len-1)
771
+ # returns (Q-values): (Batch_Size, Seq_Len-1)
772
+ values, advantages, returns = self.compute_advantages(values, rewards, masks)
773
+ timing["time/ppo/compute_advantages"] = time.time() - t
774
+
775
+ # This represents all the trajectories sampled (our storage of trajectories) using the old policy (offline).
776
+ # upcast to float32 to avoid dataset issues
777
+ batch_dict = {
778
+ "queries": queries,
779
+ "responses": responses,
780
+ "logprobs": all_logprobs.to(torch.float32),
781
+ "values": values.to(torch.float32),
782
+ "masks": masks,
783
+ "advantages": advantages,
784
+ "returns": returns,
785
+ }
786
+ batch_dict.update(model_inputs)
787
+
788
+ # ======================================
789
+ # PHASE 2: Optimize the model using PPO
790
+ # ======================================
791
+
792
+ t = time.time()
793
+ all_stats = []
794
+ early_stop = False
795
+ for _ in range(self.config.ppo_epochs):
796
+ if early_stop:
797
+ break
798
+ b_inds = np.random.permutation(bs) # Shuffle the trajectories
799
+ for backward_batch_start in range(0, bs, self.config.backward_batch_size):
800
+ backward_batch_end = backward_batch_start + self.config.backward_batch_size
801
+ # Get the items to retrieve from the trajectories storage
802
+ backward_batch_inds = b_inds[backward_batch_start:backward_batch_end]
803
+
804
+ # Extract a mini-batch from the macro-batch extracted from the trajectories
805
+ for mini_batch_start in range(0, self.config.backward_batch_size, self.config.mini_batch_size):
806
+ mini_batch_end = mini_batch_start + self.config.mini_batch_size
807
+ mini_batch_inds = backward_batch_inds[mini_batch_start:mini_batch_end]
808
+
809
+
810
+ # This is the sampled mini-batch that will be used to optimize the model
811
+ mini_batch_dict = {
812
+ "logprobs": batch_dict["logprobs"][mini_batch_inds],
813
+ "values": batch_dict["values"][mini_batch_inds],
814
+ "masks": batch_dict["masks"][mini_batch_inds],
815
+ # hacks: the queries and responses are ragged.
816
+ "queries": [batch_dict["queries"][i] for i in mini_batch_inds],
817
+ "responses": [batch_dict["responses"][i] for i in mini_batch_inds],
818
+ "advantages": batch_dict["advantages"][mini_batch_inds],
819
+ "returns": batch_dict["returns"][mini_batch_inds],
820
+ }
821
+
822
+ for k in model_inputs_names:
823
+ mini_batch_dict[k] = batch_dict[k][mini_batch_inds]
824
+ with self.accelerator.accumulate(self.model):
825
+ model_inputs = {k: mini_batch_dict[k] for k in model_inputs_names}
826
+
827
+ # Calculate the logprobs, logits and values of the online model (new policy)
828
+ logprobs, logits, vpreds, _ = self.batched_forward_pass(
829
+ self.model,
830
+ mini_batch_dict["queries"],
831
+ mini_batch_dict["responses"],
832
+ model_inputs,
833
+ return_logits=True,
834
+ )
835
+
836
+ # Perform a training step using the logprobs from the old policy and the logprobs from the new policy
837
+ train_stats = self.train_minibatch(
838
+ mini_batch_dict["logprobs"],
839
+ mini_batch_dict["values"],
840
+ logprobs,
841
+ logits,
842
+ vpreds,
843
+ mini_batch_dict["masks"],
844
+ mini_batch_dict["advantages"],
845
+ mini_batch_dict["returns"],
846
+ )
847
+ all_stats.append(train_stats)
848
+
849
+ # typically, early stopping is done at the epoch level
850
+ if self.config.early_stopping:
851
+ policykl = train_stats["policy/policykl"]
852
+ early_stop = self._early_stop(policykl)
853
+ if early_stop:
854
+ break
855
+
856
+ timing["time/ppo/optimize_step"] = time.time() - t
857
+
858
+ t = time.time()
859
+ train_stats = stack_dicts(all_stats)
860
+
861
+ # reshape advantages/ratios such that they are not averaged.
862
+ train_stats["policy/advantages"] = torch.flatten(train_stats["policy/advantages"]).unsqueeze(0)
863
+ train_stats["policy/advantages"] = torch.nan_to_num(train_stats["policy/advantages"], WANDB_PADDING)
864
+ train_stats["policy/ratio"] = torch.flatten(train_stats["policy/ratio"]).unsqueeze(0)
865
+
866
+ stats = self.record_step_stats(
867
+ scores=scores,
868
+ logprobs=all_logprobs,
869
+ ref_logprobs=ref_logprobs,
870
+ non_score_reward=non_score_reward,
871
+ train_stats=train_stats,
872
+ kl_coef=self.kl_ctl.value,
873
+ masks=masks,
874
+ queries=queries,
875
+ responses=responses,
876
+ kls=kls,
877
+ )
878
+ # Gather/Reduce stats from all processes
879
+ if self.is_distributed:
880
+ stats = self.gather_stats(stats)
881
+ stats = stats_to_np(stats)
882
+ timing["time/ppo/calc_stats"] = time.time() - t
883
+ stats["ppo/learning_rate"] = self.optimizer.param_groups[0]["lr"]
884
+
885
+ # Update the KL control - multiply the batch_size by the number of processes
886
+ self.kl_ctl.update(
887
+ stats["objective/kl"],
888
+ self.config.batch_size * self.accelerator.num_processes,
889
+ )
890
+
891
+ # Log the total ppo time
892
+ timing["time/ppo/total"] = time.time() - t0
893
+ stats.update(timing)
894
+
895
+ # post-process stats for tensorboard and other loggers
896
+ if self.config.log_with != "wandb":
897
+ stats = convert_to_scalar(stats)
898
+
899
+ if self.lr_scheduler is not None:
900
+ self.lr_scheduler.step()
901
+
902
+ return stats
903
+
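The nested loop in `step` above visits the collected trajectories at three granularities: the full `batch_size` is shuffled, split into chunks of `backward_batch_size` (roughly one accumulated optimizer update each), and each chunk is split into `mini_batch_size` forward/backward passes. A minimal sketch of that index bookkeeping, with hypothetical sizes not taken from any real `PPOConfig`:

```python
import numpy as np

# Hypothetical sizes, for illustration only.
batch_size = 8            # trajectories collected per PPO step
backward_batch_size = 4   # samples consumed per (accumulated) optimizer update
mini_batch_size = 2       # samples per forward/backward pass

b_inds = np.random.permutation(batch_size)  # shuffle the trajectories
for backward_start in range(0, batch_size, backward_batch_size):
    backward_inds = b_inds[backward_start : backward_start + backward_batch_size]
    for mini_start in range(0, backward_batch_size, mini_batch_size):
        mini_inds = backward_inds[mini_start : mini_start + mini_batch_size]
        print(f"update chunk at {backward_start}: mini-batch indices {mini_inds}")
```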
904
+ def _early_stop(self, policykl):
905
+ r"""
906
+ Handles the early stopping logic. If the policy KL is greater than 1.5 times the target KL, the gradient is zeroed and
907
+ the optimization step is skipped.
908
+ This also handles the multi-gpu case where the policy KL is averaged across all processes.
909
+
910
+ Args:
911
+ policykl (`torch.Tensor`):
912
+ the policy KL
913
+
914
+ Returns:
915
+ `bool`: whether to early stop or not
916
+ """
917
+ early_stop = False
918
+ if not self.config.early_stopping:
919
+ return early_stop
920
+
921
+ if not self.is_distributed and policykl > 1.5 * self.config.target_kl:
922
+ self.optimizer.zero_grad()
923
+ early_stop = True
924
+ elif self.is_distributed:
925
+ import torch.distributed as dist
926
+
927
+ # Wait for all processes to finish
928
+ dist.barrier()
929
+
930
+ # all gather the policykl
931
+ dist.all_reduce(policykl, dist.ReduceOp.SUM)
932
+ policykl /= self.accelerator.num_processes
933
+
934
+ if policykl > 1.5 * self.config.target_kl:
935
+ self.optimizer.zero_grad()
936
+ early_stop = True
937
+ return early_stop
938
+
939
+ def gather_stats(self, stats):
940
+ """
941
+ Gather stats from all processes. Useful in the context of distributed training.
942
+
943
+ Args:
944
+ stats (dict[str, Any]):
945
+ a dictionary of stats to be gathered. The stats should contain torch tensors.
946
+
947
+ Returns:
948
+ `dict[str, Any]`: A dictionary of stats with the tensors gathered.
949
+ """
950
+ import torch.distributed as dist
951
+
952
+ # Wait for all processes to finish
953
+ dist.barrier()
954
+
955
+ for k, v in stats.items():
956
+ if isinstance(v, torch.Tensor):
957
+ dist.all_reduce(v.to(self.accelerator.device), dist.ReduceOp.SUM)
958
+ v /= self.accelerator.num_processes
959
+ stats[k] = v
960
+ return stats
961
+
962
+ def prepare_model_inputs(self, queries: torch.Tensor, responses: torch.Tensor):
963
+ if self.is_encoder_decoder:
964
+ input_data = self.data_collator(
965
+ [{"input_ids": q, "attention_mask": torch.ones_like(q)} for q in queries]
966
+ ).to(self.current_device)
967
+
968
+ decoder_inputs = self.data_collator(
969
+ [{"input_ids": r, "attention_mask": torch.ones_like(r)} for r in responses]
970
+ ).to(self.current_device)
971
+
972
+ input_data["decoder_input_ids"] = decoder_inputs["input_ids"]
973
+ input_data["decoder_attention_mask"] = decoder_inputs["attention_mask"]
974
+ else:
975
+ input_ids = [torch.cat([q, r]) for q, r in zip(queries, responses)]
976
+ input_data = self.data_collator(
977
+ [{"input_ids": ids, "attention_mask": torch.ones_like(ids)} for ids in input_ids]
978
+ ).to(self.current_device)
979
+
980
+ input_data.pop("labels", None) # we don't want to compute LM losses
981
+ return input_data
982
+
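For the decoder-only branch above, each query is concatenated with its response and the ragged sequences are then padded into one batch. A toy illustration of the resulting layout (pad id and right-padding are assumptions here; the real code delegates padding to the tokenizer's data collator):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

pad_token_id = 0  # assumed pad id, for illustration only

queries = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
responses = [torch.tensor([11, 12]), torch.tensor([13, 14, 15])]

input_ids = [torch.cat([q, r]) for q, r in zip(queries, responses)]
attention_mask = [torch.ones_like(ids) for ids in input_ids]

batch = {
    "input_ids": pad_sequence(input_ids, batch_first=True, padding_value=pad_token_id),
    "attention_mask": pad_sequence(attention_mask, batch_first=True, padding_value=0),
}
print(batch["input_ids"])       # (2, 5): query tokens, then response tokens, then padding
print(batch["attention_mask"])  # 1 for real tokens, 0 for padding
```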
983
+ @PPODecorators.empty_device_cache()
984
+ def batched_forward_pass(
985
+ self,
986
+ model: PreTrainedModelWrapper,
987
+ queries: torch.Tensor,
988
+ responses: torch.Tensor,
989
+ model_inputs: dict,
990
+ return_logits: bool = False,
991
+ response_masks: Optional[torch.Tensor] = None,
992
+ ):
993
+ """
994
+ Calculate model outputs in multiple batches.
995
+
996
+ Args:
997
+ queries (`torch.LongTensor`):
998
+ List of tensors containing the encoded queries, shape (`batch_size`, `query_length`)
999
+ responses (`torch.LongTensor`):
1000
+ List of tensors containing the encoded responses, shape (`batch_size`, `response_length`)
1001
+ return_logits (`bool`, *optional*, defaults to `False`):
1002
+ Whether to return all_logits. Set to `False` if logits are not needed to reduce memory consumption.
1003
+ Returns:
1004
+ (tuple):
1005
+ - all_logprobs (`torch.FloatTensor`): Log probabilities of the responses,
1006
+ shape (`batch_size`, `response_length`)
1007
+ - all_logits (`torch.FloatTensor` or `None`): Logits of the responses (returned only when `return_logits=True`),
1008
+ shape (`batch_size`, `response_length`, `vocab_size`)
1009
+ - all_values (`torch.FloatTensor`): Values of the responses, shape (`batch_size`, `response_length`)
1010
+ """
1011
+ bs = len(queries)
1012
+ fbs = self.config.mini_batch_size
1013
+ all_logprobs = []
1014
+ all_logits = []
1015
+ all_masks = []
1016
+ all_values = []
1017
+
1018
+ model.eval()
1019
+
1020
+ # Since each batch can be big and may not fit in memory, we calculate the logits and log probabilities by splitting the batch into smaller batches of size `fbs`
1021
+
1022
+ for i in range(math.ceil(bs / fbs)):
1023
+ # Get the input tensors for the current mini batch (of size `fbs`)
1024
+
1025
+ input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
1026
+ query_batch = queries[i * fbs : (i + 1) * fbs]
1027
+ response_batch = responses[i * fbs : (i + 1) * fbs]
1028
+ if response_masks is not None:
1029
+ response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
1030
+
1031
+ # Obtain the logits corresponding to each token in the input and the corresponding value from the ValueHead.
1032
+ # The input is the concatenation of the query and the response.
1033
+ # logits: (Batch, Seq_Length, Vocab_Size),
1034
+ # values: (Batch, Seq_Length)
1035
+ logits, _, values = model(**input_kwargs)
1036
+
1037
+ if self.is_encoder_decoder:
1038
+ input_ids = input_kwargs["decoder_input_ids"]
1039
+ attention_mask = input_kwargs["decoder_attention_mask"]
1040
+ else:
1041
+ input_ids = input_kwargs["input_ids"]
1042
+ attention_mask = input_kwargs["attention_mask"]
1043
+
1044
+ # Calculate the log probabilities for each token.
1045
+ # These are obtained from the logits output by the model at each position (by applying a log-softmax and gathering the log probability of the actual next token).
1046
+ # logprobs: (Batch_Size, Seq_Length - 1)
1047
+ logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
1048
+
1049
+ masks = torch.zeros_like(attention_mask)
1050
+ masks[:, :-1] = attention_mask[:, 1:] # Indicates for which tokens we have the logprobs
1051
+
1052
+ for j in range(len(query_batch)):
1053
+ if self.is_encoder_decoder:
1054
+ # In encoder-decoder models, the decoder sentence always starts at index 1 (right after the decoder start token)
1055
+ start = 1
1056
+ end = attention_mask[j, :].sum() - 1
1057
+ else:
1058
+ # logprobs starts from the first response token
1059
+ start = len(query_batch[j]) - 1
1060
+ if attention_mask[j, 0] == 0: # offset left padding
1061
+ start += attention_mask[j, :].nonzero()[0]
1062
+ # The index corresponding to the end position in the entire (query+response) sequence
1063
+ end = start + len(response_batch[j])
1064
+ if response_masks is not None:
1065
+ response_masks_batch[j] = torch.cat(
1066
+ (torch.zeros_like(query_batch[j]), response_masks_batch[j])
1067
+ )[1:]
1068
+
1069
+ # All the tokens for which we don't have logprobs are masked out
1070
+ # Mask out any token before the first response token (so mask out the prompt tokens)
1071
+ masks[j, :start] = 0
1072
+ # Mask out any token that comes after the response tokens (so mask out any padding tokens)
1073
+ masks[j, end:] = 0
1074
+
1075
+ if response_masks is not None:
1076
+ masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]
1077
+
1078
+ if return_logits:
1079
+ all_logits.append(logits)
1080
+ else:
1081
+ del logits
1082
+ all_values.append(values)
1083
+ all_logprobs.append(logprobs)
1084
+ all_masks.append(masks)
1085
+
1086
+ return (
1087
+ torch.cat(all_logprobs),
1088
+ torch.cat(all_logits)[:, :-1] if return_logits else None,
1089
+ torch.cat(all_values)[:, :-1],
1090
+ torch.cat(all_masks)[:, :-1],
1091
+ )
1092
+
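The per-token log probabilities above come from a shift-by-one gather: the logits at position t score the token at position t+1, so the last logits are dropped and the targets are the input ids shifted left. A minimal re-implementation of that gather (the real `logprobs_from_logits` helper imported from the `core` module may differ in details):

```python
import torch
import torch.nn.functional as F

def logprobs_from_logits_sketch(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # logits: (batch, seq_len, vocab), labels: (batch, seq_len)
    logp = F.log_softmax(logits, dim=-1)
    return torch.gather(logp, 2, labels.unsqueeze(-1)).squeeze(-1)

batch, seq_len, vocab = 1, 5, 10
logits = torch.randn(batch, seq_len, vocab)            # stand-in for the model output
input_ids = torch.randint(0, vocab, (batch, seq_len))

# logits[:, t] predict input_ids[:, t + 1]  ->  shape (batch, seq_len - 1)
token_logprobs = logprobs_from_logits_sketch(logits[:, :-1, :], input_ids[:, 1:])
print(token_logprobs.shape)  # torch.Size([1, 4])
```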
1093
+ @PPODecorators.empty_device_cache()
1094
+ def train_minibatch(
1095
+ self,
1096
+ old_logprobs: torch.FloatTensor, # log probabilities under the OLD policy (offline)
1097
+ values: torch.FloatTensor, # values under the OLD policy (offline)
1098
+ logprobs: torch.FloatTensor, # log probabilities under the new policy (online)
1099
+ logits: torch.FloatTensor, # logits under the new policy (online)
1100
+ vpreds: torch.FloatTensor, # values under the new policy (online)
1101
+ mask: torch.LongTensor, # indicates for which tokens the log probabilities correspond to
1102
+ advantages: torch.FloatTensor, # advantages calculated under the OLD policy (offline)
1103
+ returns: torch.FloatTensor, # returns calculated under the OLD policy (offline)
1104
+ ):
1105
+ """
1106
+ Train one PPO minibatch
1107
+
1108
+ Args:
1109
+ logprobs (`torch.FloatTensor`):
1110
+ Log probabilities of the model, shape [mini_batch_size, response_length]
1111
+ values (`torch.FloatTensor`):
1112
+ Values of the value head, shape [mini_batch_size, response_length]
1113
+ query (`torch.LongTensor`):
1114
+ Encoded queries, shape [mini_batch_size, query_length]
1115
+ response (`torch.LongTensor`):
1116
+ Encoded responses, shape [mini_batch_size, response_length]
1117
+ model_input (`torch.LongTensor`):
1118
+ Concatenated queries and responses, shape [mini_batch_size, query_length+response_length]
1119
+
1120
+ Returns:
1121
+ train_stats (dict[str, `torch.Tensor`]):
1122
+ Dictionary of training statistics
1123
+ """
1124
+ self.model.train()
1125
+ loss_p, loss_v, train_stats = self.loss(
1126
+ old_logprobs, values, logits, vpreds, logprobs, mask, advantages, returns
1127
+ )
1128
+ loss = loss_p + loss_v # the loss is the sum of the policy_gradient loss and the values loss
1129
+ self.accelerator.backward(loss)
1130
+ if self.config.max_grad_norm is not None:
1131
+ if self.accelerator.sync_gradients:
1132
+ self.accelerator.clip_grad_norm_(self.model_params, self.config.max_grad_norm)
1133
+ self.optimizer.step()
1134
+ # we call optimizer.zero_grad() every time and let `accelerator` handle accumulation
1135
+ # see https://huggingface.co/docs/accelerate/usage_guides/gradient_accumulation#the-finished-code
1136
+ self.optimizer.zero_grad()
1137
+ return train_stats
1138
+
1139
+ def compute_rewards(
1140
+ self,
1141
+ scores: torch.FloatTensor,
1142
+ logprobs: torch.FloatTensor,
1143
+ ref_logprobs: torch.FloatTensor,
1144
+ masks: torch.LongTensor,
1145
+ ):
1146
+ """
1147
+ Compute per token rewards from scores and KL-penalty.
1148
+
1149
+ Args:
1150
+ scores (`torch.FloatTensor`):
1151
+ Scores from the reward model, shape (`batch_size`)
1152
+ logprobs (`torch.FloatTensor`):
1153
+ Log probabilities of the model, shape (`batch_size`, `response_length`)
1154
+ ref_logprobs (`torch.FloatTensor`):
1155
+ Log probabilities of the reference model, shape (`batch_size`, `response_length`)
1156
+
1157
+ Returns:
1158
+ `torch.FloatTensor`: Per token rewards, shape (`batch_size`, `response_length`)
1159
+ `torch.FloatTensor`: Non score rewards, shape (`batch_size`, `response_length`)
1160
+ `torch.FloatTensor`: KL penalty, shape (`batch_size`, `response_length`)
1161
+ """
1162
+ rewards, non_score_rewards, kls = [], [], []
1163
+ for score, logprob, ref_logprob, mask in zip(scores, logprobs, ref_logprobs, masks):
1164
+ # compute KL penalty (from difference in logprobs)
1165
+ # shape: (Seq_Len) - represents the difference in logprobs for each token (frozen model vs fine-tuned model)
1166
+ kl = self._kl_penalty(logprob, ref_logprob)
1167
+ kls.append(kl)
1168
+ non_score_reward = -self.kl_ctl.value * kl
1169
+ non_score_rewards.append(non_score_reward)
1170
+ reward = non_score_reward.clone()
1171
+ last_non_masked_index = mask.nonzero()[-1]
1172
+
1173
+ # The reward is initialized with the -KL penalty. Then the score given by the reward model is added only to the last (non-masked) generated token of the response.
1174
+ # Basically we are penalizing the reward given by the reward model by the KL penalty (how much the response differs from the frozen model)
1175
+ # shape: (Seq_Len)
1176
+ reward[last_non_masked_index] += score
1177
+ rewards.append(reward)
1178
+ return torch.stack(rewards), torch.stack(non_score_rewards), torch.stack(kls)
1179
+
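Concretely, every response token is rewarded with `-kl_coef * kl`, and the scalar score from the reward model is added only on the last non-masked token. A toy numeric example with made-up values:

```python
import torch

kl_coef = 0.2
score = torch.tensor(1.5)                       # reward-model score for one response
logprob = torch.tensor([-1.0, -0.5, -2.0])      # fine-tuned policy, per response token
ref_logprob = torch.tensor([-1.1, -0.7, -1.0])  # frozen reference model
mask = torch.tensor([1, 1, 1])

kl = logprob - ref_logprob        # the "kl" penalty estimator
reward = -kl_coef * kl
last = mask.nonzero()[-1]
reward[last] += score
print(reward)  # ≈ tensor([-0.0200, -0.0400,  1.7000])
```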
1180
+ def _kl_penalty(self, logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor) -> torch.FloatTensor:
1181
+ if self.config.kl_penalty == "kl":
1182
+ return logprob - ref_logprob
1183
+
1184
+ if self.config.kl_penalty == "abs":
1185
+ return (logprob - ref_logprob).abs()
1186
+
1187
+ if self.config.kl_penalty == "mse":
1188
+ return 0.5 * (logprob - ref_logprob).square()
1189
+
1190
+ if self.config.kl_penalty == "full":
1191
+ # Flipping the argument order is required due to this issue: https://github.com/pytorch/pytorch/issues/57459
1192
+ return F.kl_div(ref_logprob, logprob, log_target=True, reduction="none").sum(-1)
1193
+
1194
+ raise NotImplementedError
1195
+
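The first three estimators above only differ in how the per-token logprob gap is penalized, for example:

```python
import torch

diff = torch.tensor([0.1, -0.3, 0.0])  # logprob - ref_logprob, made-up values

kl_estimate = diff                  # "kl":  signed difference
abs_estimate = diff.abs()           # "abs": magnitude only
mse_estimate = 0.5 * diff.square()  # "mse": quadratic penalty
print(kl_estimate, abs_estimate, mse_estimate)
```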
1196
+ def compute_advantages(
1197
+ self,
1198
+ values: torch.FloatTensor,
1199
+ rewards: torch.FloatTensor,
1200
+ mask: torch.FloatTensor,
1201
+ ):
1202
+ lastgaelam = 0
1203
+ advantages_reversed = []
1204
+ gen_len = rewards.shape[-1]
1205
+
1206
+ values = values * mask
1207
+ rewards = rewards * mask
1208
+
1209
+ # if self.config.whiten_rewards:
1210
+ # rewards = masked_whiten(rewards, mask, shift_mean=False)
1211
+
1212
+ for t in reversed(range(gen_len)):
1213
+ nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0 # Value function evaluated at time (t+1)
1214
+ delta = rewards[:, t] + self.config.gamma * nextvalues - values[:, t] # From the formula of GAE: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
1215
+ lastgaelam = delta + self.config.gamma * self.config.lam * lastgaelam # Save the GAE for the next iteration
1216
+ advantages_reversed.append(lastgaelam)
1217
+ advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1) # Reverse the advantages and stack them
1218
+
1219
+ returns = advantages + values # Since Advantage = Q - V, we can calculate Q = Advantage + V. The Q values are necessary for training the value function estimation.
1220
+ advantages = masked_whiten(advantages, mask)
1221
+ advantages = advantages.detach()
1222
+ return values, advantages, returns
1223
+
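A self-contained numeric sketch of the same GAE recursion (delta_t = r_t + gamma * V(s_{t+1}) - V(s_t), accumulated backwards with a gamma * lam discount), using made-up rewards and values:

```python
import torch

gamma, lam = 0.99, 0.95                    # assumed hyperparameters
rewards = torch.tensor([[0.0, 0.0, 1.0]])  # (batch=1, gen_len=3), made-up
values = torch.tensor([[0.5, 0.6, 0.7]])

gen_len = rewards.shape[-1]
lastgaelam = 0.0
advantages_reversed = []
for t in reversed(range(gen_len)):
    nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
    delta = rewards[:, t] + gamma * nextvalues - values[:, t]
    lastgaelam = delta + gamma * lam * lastgaelam
    advantages_reversed.append(lastgaelam)

advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)  # (1, 3)
returns = advantages + values  # Q-value targets for the value head
print(advantages, returns)
```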
1224
+ def loss(
1225
+ self,
1226
+ old_logprobs: torch.FloatTensor, # log probabilities under the OLD policy (offline)
1227
+ values: torch.FloatTensor, # values under the OLD policy (offline)
1228
+ logits: torch.FloatTensor, # logits under the NEW policy (online)
1229
+ vpreds: torch.FloatTensor, # values under the NEW policy (online)
1230
+ logprobs: torch.FloatTensor, # log probabilities under the NEW policy (online)
1231
+ mask: torch.LongTensor, # which tokens the log probabilities correspond to
1232
+ advantages: torch.FloatTensor, # advantages calculated using the OLD policy (offline)
1233
+ returns: torch.FloatTensor, # state-actions (Q-values) calculated using the OLD policy (offline)
1234
+ ):
1235
+ """
1236
+ Calculate policy and value losses.
1237
+
1238
+ Args:
1239
+ old_logprobs (`torch.FloatTensor`):
1240
+ Log probabilities of the model, shape (`batch_size`, `response_length`)
1241
+ values (`torch.FloatTensor`):
1242
+ Values of the value head, shape (`batch_size`, `response_length`)
1243
+ rewards (`torch.FloatTensor`):
1244
+ Rewards from the reward model, shape (`batch_size`, `response_length`)
1245
+ logits (`torch.FloatTensor`):
1246
+ Logits of the model, shape (`batch_size`, `response_length`, `vocab_size`)
1247
+ v_pred (`torch.FloatTensor`):
1248
+ Values of the value head, shape (`batch_size`, `response_length`)
1249
+ logprobs (`torch.FloatTensor`):
1250
+ Log probabilities of the model, shape (`batch_size`, `response_length`)
1251
+ """
1252
+
1253
+ vpredclipped = clip_by_value(
1254
+ vpreds,
1255
+ values - self.config.cliprange_value,
1256
+ values + self.config.cliprange_value,
1257
+ )
1258
+
1259
+ # Loss for the value head
1260
+ vf_losses1 = (vpreds - returns) ** 2 # This is the loss according to the formula in the slides. (V(s) - Q(s, a))^2
1261
+ vf_losses2 = (vpredclipped - returns) ** 2
1262
+ vf_loss = 0.5 * masked_mean(torch.max(vf_losses1, vf_losses2), mask)
1263
+ vf_clipfrac = masked_mean(torch.gt(vf_losses2, vf_losses1).float(), mask)
1264
+
1265
+ # Ratio between the log probability of the new policy and the old policy
1266
+ ratio = torch.exp(logprobs - old_logprobs)
1267
+
1268
+ # The "minus" sign is because we want to maximize the objective function, but the optimizer minimizes the loss
1269
+ pg_losses = -advantages * ratio # as per the PPO objective: the probability ratio (new policy / old policy) multiplied by the advantage
1270
+ pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - self.config.cliprange, 1.0 + self.config.cliprange)
1271
+
1272
+ # "max" instead of "min" because we want to maximize the objective function, but the optimizer minimizes the loss
1273
+ pg_loss = masked_mean(torch.max(pg_losses, pg_losses2), mask) # policy gradient loss
1274
+ pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).float(), mask)
1275
+
1276
+ loss = pg_loss + self.config.vf_coef * vf_loss
1277
+
1278
+ avg_ratio = masked_mean(ratio, mask).item()
1279
+ if avg_ratio > self.config.ratio_threshold:
1280
+ warnings.warn(
1281
+ f"The average ratio of batch ({avg_ratio:.2f}) exceeds threshold {self.config.ratio_threshold:.2f}. Skipping batch."
1282
+ )
1283
+ pg_loss = pg_loss * 0.0
1284
+ vf_loss = vf_loss * 0.0
1285
+ loss = loss * 0.0
1286
+ # The entropy to force the model to explore
1287
+ entropy = masked_mean(entropy_from_logits(logits), mask)
1288
+
1289
+ approxkl = 0.5 * masked_mean((logprobs - old_logprobs) ** 2, mask)
1290
+ policykl = masked_mean(old_logprobs - logprobs, mask)
1291
+
1292
+ return_mean, return_var = masked_mean(returns, mask), masked_var(returns, mask)
1293
+ value_mean, value_var = masked_mean(values, mask), masked_var(values, mask)
1294
+
1295
+ stats = dict(
1296
+ loss=dict(policy=pg_loss.detach(), value=vf_loss.detach(), total=loss.detach()),
1297
+ policy=dict(
1298
+ entropy=entropy.detach(),
1299
+ approxkl=approxkl.detach(),
1300
+ policykl=policykl.detach(),
1301
+ clipfrac=pg_clipfrac.detach(),
1302
+ advantages=advantages.detach(),
1303
+ advantages_mean=masked_mean(advantages, mask).detach(),
1304
+ ratio=ratio.detach(),
1305
+ ),
1306
+ returns=dict(mean=return_mean.detach(), var=return_var.detach()),
1307
+ val=dict(
1308
+ vpred=masked_mean(vpreds, mask).detach(),
1309
+ error=masked_mean((vpreds - returns) ** 2, mask).detach(),
1310
+ clipfrac=vf_clipfrac.detach(),
1311
+ mean=value_mean.detach(),
1312
+ var=value_var.detach(),
1313
+ ),
1314
+ )
1315
+ return pg_loss, self.config.vf_coef * vf_loss, flatten_dict(stats)
1316
+
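Stripped of masking and statistics, the policy part of the loss is the standard PPO clipped surrogate objective. A minimal sketch on toy tensors (the clip range is an assumed value, not read from any config):

```python
import torch

cliprange = 0.2  # assumed clip range, for illustration
advantages = torch.tensor([0.5, -1.0, 2.0])
old_logprobs = torch.tensor([-1.0, -0.5, -2.0])
logprobs = torch.tensor([-0.8, -0.6, -1.5])  # new policy

ratio = torch.exp(logprobs - old_logprobs)  # pi_new / pi_old per token
pg_losses = -advantages * ratio
pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)

# "max" of the negated terms == "min" of the original objective terms
pg_loss = torch.max(pg_losses, pg_losses2).mean()
print(ratio, pg_loss)
```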
1317
+ def record_step_stats(self, kl_coef: float, **data):
1318
+ """
1319
+ Record training step statistics.
1320
+
1321
+
1322
+ Args:
1323
+ kl_coef (`float`):
1324
+ KL coefficient
1325
+ data (`dict`):
1326
+ Dictionary of training step data
1327
+
1328
+ Returns:
1329
+ stats (`dict`):
1330
+ Dictionary of training step statistics
1331
+ """
1332
+ mask = data.pop("masks")
1333
+
1334
+ kls = data.pop("kls")
1335
+ kl_list = ((kls) * mask).sum(axis=-1)
1336
+ mean_kl = kl_list.mean()
1337
+ mean_entropy = (-data["logprobs"] * mask).sum(axis=-1).mean()
1338
+
1339
+ mean_non_score_reward = masked_mean(
1340
+ data["non_score_reward"], mask
1341
+ ) # non_score_reward is size `batch_size`, `response_length`
1342
+ mean_scores = data["scores"].mean() # scores is size `batch_size`
1343
+ std_scores = data["scores"].std()
1344
+
1345
+ if mean_kl.item() < -1.0:
1346
+ # warn users
1347
+ warnings.warn(
1348
+ f"KL divergence is starting to become negative: {mean_kl.item():.2f} - this might be a precursor for failed training."
1349
+ " sometimes this happens because the generation kwargs are not correctly set. Please make sure"
1350
+ " that the generation kwargs are set correctly, or review your training hyperparameters."
1351
+ )
1352
+
1353
+ stats = {
1354
+ "objective/kl": mean_kl,
1355
+ "objective/kl_dist": kl_list,
1356
+ "objective/logprobs": data["logprobs"],
1357
+ "objective/ref_logprobs": data["ref_logprobs"],
1358
+ "objective/kl_coef": kl_coef,
1359
+ "objective/entropy": mean_entropy,
1360
+ "ppo/mean_non_score_reward": mean_non_score_reward,
1361
+ "ppo/mean_scores": mean_scores,
1362
+ "ppo/std_scores": std_scores,
1363
+ }
1364
+
1365
+ # Log text properties
1366
+ query_lens = torch.tensor([len(query) for query in data["queries"]], dtype=torch.float)
1367
+ response_lens = torch.tensor([len(response) for response in data["responses"]], dtype=torch.float)
1368
+
1369
+ stats["tokens/queries_len_mean"] = torch.mean(query_lens).cpu().numpy().item()
1370
+ stats["tokens/queries_len_std"] = torch.std(query_lens).cpu().numpy().item()
1371
+ stats["tokens/queries_dist"] = query_lens.cpu().numpy()
1372
+ stats["tokens/responses_len_mean"] = torch.mean(response_lens).cpu().numpy().item()
1373
+ stats["tokens/responses_len_std"] = torch.std(response_lens).cpu().numpy().item()
1374
+ stats["tokens/responses_dist"] = response_lens.cpu().numpy()
1375
+
1376
+ for k, v in data["train_stats"].items():
1377
+ stats[f"ppo/{k}"] = torch.mean(v, axis=0)
1378
+ stats["ppo/val/var_explained"] = 1 - stats["ppo/val/error"] / stats["ppo/returns/var"]
1379
+ return stats
1380
+
1381
+ def log_stats(
1382
+ self,
1383
+ stats: dict,
1384
+ batch: dict,
1385
+ rewards: List[torch.FloatTensor],
1386
+ columns_to_log: typing.Iterable[str] = ("query", "response"),
1387
+ ):
1388
+ """
1389
+ A function that logs all the training stats. Call it at the end of each epoch.
1390
+
1391
+ Args:
1392
+ stats (dict[str, Any]):
1393
+ A dictionary of training stats.
1394
+ batch (dict[str, Any]):
1395
+ A dictionary of batch data, this contains the queries and responses.
1396
+ rewards (`List[torch.FloatTensor]`):
1397
+ A tensor of rewards.
1398
+ """
1399
+
1400
+ # all gather stats
1401
+ if not isinstance(rewards, torch.Tensor):
1402
+ rewards = torch.tensor(rewards).to(self.current_device)
1403
+ rewards = self.accelerator.gather(rewards).flatten()
1404
+
1405
+ if self.config.log_with == "wandb":
1406
+ import wandb
1407
+
1408
+ if any(column_to_log not in batch.keys() for column_to_log in columns_to_log):
1409
+ raise ValueError(f"Columns to log {columns_to_log} are not present in the batch {batch.keys()}.")
1410
+
1411
+ batch_list = [batch[column_to_log] for column_to_log in columns_to_log]
1412
+ if self.is_distributed:
1413
+ gathered_batch_list = []
1414
+ for b in batch_list:
1415
+ flattened = gather_object(b)
1416
+ gathered_batch_list.append(flattened)
1417
+ batch_list = gathered_batch_list
1418
+
1419
+ # Log only if we are in the main process
1420
+ if self.accelerator.is_main_process:
1421
+ logs = {}
1422
+
1423
+ # Log stats
1424
+ if "query" not in batch.keys() and "response" not in batch.keys():
1425
+ # warn the user that the game logs will not be logged
1426
+ warnings.warn(
1427
+ "The game logs will not be logged because the batch does not contain the keys 'query' and "
1428
+ "'response'. "
1429
+ )
1430
+ elif self.config.log_with == "wandb":
1431
+ table_rows = [list(r) for r in zip(*batch_list, rewards.cpu().tolist())]
1432
+ logs.update({"game_log": wandb.Table(columns=[*columns_to_log, "reward"], rows=table_rows)})
1433
+
1434
+ logs.update(stats)
1435
+
1436
+ # manually cast in fp32 for bf16 torch tensors
1437
+ for k, v in logs.items():
1438
+ if isinstance(v, torch.Tensor) and v.dtype == torch.bfloat16:
1439
+ logs[k] = v.float()
1440
+
1441
+ logs["env/reward_mean"] = torch.mean(rewards).cpu().numpy().item()
1442
+ logs["env/reward_std"] = torch.std(rewards).cpu().numpy().item()
1443
+ logs["env/reward_dist"] = rewards.cpu().numpy()
1444
+
1445
+ if self.config.log_with == "tensorboard":
1446
+ # update the current step
1447
+ self.current_step += 1
1448
+
1449
+ self.accelerator.log(
1450
+ logs,
1451
+ step=self.current_step if self.config.log_with == "tensorboard" else None,
1452
+ )
1453
+
1454
+ def create_model_card(self, path: str, model_name: Optional[str] = "TRL Model") -> None:
1455
+ """Creates and saves a model card for a TRL model.
1456
+
1457
+ Args:
1458
+ path (`str`): The path to save the model card to.
1459
+ model_name (`str`, *optional*): The name of the model, defaults to `TRL Model`.
1460
+ """
1461
+ try:
1462
+ user = whoami()["name"]
1463
+ # handle the offline case
1464
+ except Exception:
1465
+ warnings.warn("Cannot retrieve user information assuming you are running in offline mode.")
1466
+ return
1467
+
1468
+ if not os.path.exists(path):
1469
+ os.makedirs(path)
1470
+
1471
+ model_card_content = MODEL_CARD_TEMPLATE.format(model_name=model_name, model_id=f"{user}/{path}")
1472
+ with open(os.path.join(path, "README.md"), "w", encoding="utf-8") as f:
1473
+ f.write(model_card_content)
1474
+
1475
+ def _save_pretrained(self, save_directory: str) -> None:
1476
+ self.accelerator.unwrap_model(self.model).save_pretrained(save_directory)
1477
+ self.tokenizer.save_pretrained(save_directory)
1478
+ self.create_model_card(save_directory)
1479
+
1480
+ def _show_tokens(self, tokens, masks):
1481
+ from rich import print
1482
+ from rich.text import Text
1483
+
1484
+ text = Text()
1485
+
1486
+ for _i, (token, mask) in enumerate(zip(tokens, masks)):
1487
+ if mask == 1:
1488
+ text.append(self.tokenizer.decode(token.item()), style="black on deep_sky_blue1")
1489
+ text.append(" ")
1490
+ else:
1491
+ text.append(self.tokenizer.decode(token.item()), style="black on cyan3")
1492
+ text.append(" ")
1493
+ print(text)
1494
+
1495
+ def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
1496
+ # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
1497
+ deepspeed_plugin = self.accelerator.state.deepspeed_plugin
1498
+ config_kwargs = deepspeed_plugin.deepspeed_config
1499
+ if model is not None:
1500
+ if hasattr(model, "config"):
1501
+ hidden_size = (
1502
+ max(model.config.hidden_sizes)
1503
+ if getattr(model.config, "hidden_sizes", None)
1504
+ else getattr(model.config, "hidden_size", None)
1505
+ )
1506
+ if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
1507
+ # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
1508
+ # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
1509
+ config_kwargs.update(
1510
+ {
1511
+ "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
1512
+ "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
1513
+ "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
1514
+ }
1515
+ )
1516
+
1517
+ # If ZeRO-3 is used, we shard both the active and reference model.
1518
+ # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
1519
+ if config_kwargs["zero_optimization"]["stage"] != 3:
1520
+ config_kwargs["zero_optimization"]["stage"] = 0
1521
+ model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
1522
+ model.eval()
1523
+ return model
ppo/ppo_trainer_original.py ADDED
@@ -0,0 +1,1455 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import inspect
15
+ import math
16
+ import os
17
+ import time
18
+ import typing
19
+ import warnings
20
+ from contextlib import nullcontext
21
+ from typing import Callable, List, Optional, Union
22
+
23
+ import datasets
24
+ import numpy as np
25
+ import torch
26
+ import torch.nn.functional as F
27
+ from accelerate import Accelerator
28
+ from accelerate.utils import ProjectConfiguration, gather_object, is_deepspeed_available
29
+ from datasets import Dataset
30
+ from huggingface_hub import whoami
31
+ from packaging import version
32
+ from torch.optim import Adam
33
+ from transformers import (
34
+ DataCollatorForLanguageModeling,
35
+ PreTrainedTokenizer,
36
+ PreTrainedTokenizerBase,
37
+ PreTrainedTokenizerFast,
38
+ )
39
+
40
+ from ..core import (
41
+ WANDB_PADDING,
42
+ PPODecorators,
43
+ clip_by_value,
44
+ convert_to_scalar,
45
+ entropy_from_logits,
46
+ flatten_dict,
47
+ logprobs_from_logits,
48
+ masked_mean,
49
+ masked_var,
50
+ masked_whiten,
51
+ set_seed,
52
+ stack_dicts,
53
+ stats_to_np,
54
+ )
55
+ from ..import_utils import is_npu_available, is_torch_greater_2_0, is_xpu_available
56
+ from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper, create_reference_model
57
+ from . import AdaptiveKLController, BaseTrainer, FixedKLController, PPOConfig, RunningMoments
58
+
59
+
60
+ if is_deepspeed_available():
61
+ import deepspeed
62
+
63
+ MODEL_CARD_TEMPLATE = """---
64
+ license: apache-2.0
65
+ tags:
66
+ - trl
67
+ - ppo
68
+ - transformers
69
+ - reinforcement-learning
70
+ ---
71
+
72
+ # {model_name}
73
+
74
+ This is a [TRL language model](https://github.com/huggingface/trl) that has been fine-tuned with reinforcement learning to
75
+ guide the model outputs according to a value, function, or human feedback. The model can be used for text generation.
76
+
77
+ ## Usage
78
+
79
+ To use this model for inference, first install the TRL library:
80
+
81
+ ```bash
82
+ python -m pip install trl
83
+ ```
84
+
85
+ You can then generate text as follows:
86
+
87
+ ```python
88
+ from transformers import pipeline
89
+
90
+ generator = pipeline("text-generation", model="{model_id}")
91
+ outputs = generator("Hello, my llama is cute")
92
+ ```
93
+
94
+ If you want to use the model for training or to obtain the outputs from the value head, load the model as follows:
95
+
96
+ ```python
97
+ from transformers import AutoTokenizer
98
+ from trl import AutoModelForCausalLMWithValueHead
99
+
100
+ tokenizer = AutoTokenizer.from_pretrained("{model_id}")
101
+ model = AutoModelForCausalLMWithValueHead.from_pretrained("{model_id}")
102
+
103
+ inputs = tokenizer("Hello, my llama is cute", return_tensors="pt")
104
+ outputs = model(**inputs, labels=inputs["input_ids"])
105
+ ```
106
+ """
107
+
108
+
109
+ class PPOTrainer(BaseTrainer):
110
+ """
111
+ The PPOTrainer uses Proximal Policy Optimization to optimise language models.
112
+ Note, this trainer is heavily inspired by the original OpenAI learning to summarize work here:
113
+ https://github.com/openai/summarize-from-feedback
114
+
115
+ Attributes:
116
+ **config** (`PPOConfig`) -- Configuration object for PPOTrainer. Check the documentation of `PPOConfig` for more
117
+ details.
118
+ **model** (`PreTrainedModelWrapper`) -- Model to be optimized, Hugging Face transformer model with a value head.
119
+ Check the documentation of `PreTrainedModelWrapper` for more details.
120
+ **ref_model** (`PreTrainedModelWrapper`, *optional*) -- Reference model to be used for KL penalty, Hugging Face
121
+ transformer model with a causal language modelling head. Check the documentation of `PreTrainedModelWrapper`
122
+ for more details. If no reference model is provided, the trainer will create a reference model with the same
123
+ architecture as the model to be optimized with shared layers.
124
+ **tokenizer** (`PreTrainedTokenizerBase`) -- Tokenizer to be used for encoding the
125
+ data. Check the documentation of `transformers.PreTrainedTokenizer` and
126
+ `transformers.PreTrainedTokenizerFast` for more details.
127
+ **dataset** (Union[`torch.utils.data.Dataset`, `datasets.Dataset`], *optional*) -- PyTorch dataset or Hugging
128
+ Face dataset. This is used to create a PyTorch dataloader. If no dataset is provided, the dataloader must be
129
+ created outside the trainer; users need to design their own dataloader and make sure the batch
130
+ size that is used is the same as the one specified in the configuration object.
131
+ **optimizer** (`torch.optim.Optimizer`, *optional*) -- Optimizer to be used for training. If no optimizer is
132
+ provided, the trainer will create an Adam optimizer with the learning rate specified in the configuration
133
+ object.
134
+ **data_collator** (DataCollatorForLanguageModeling, *optional*) -- Data collator to be used for training and
135
+ passed along the dataloader
136
+ **num_shared_layers** (int, *optional*) -- Number of layers to be shared between the model and the reference
137
+ model, if no reference model is passed. If no number is provided, all the layers will be shared.
138
+ **lr_scheduler** (`torch.optim.lr_scheduler`, *optional*) -- Learning rate scheduler to be used for training.
139
+ """
140
+
141
+ _tag_names = ["trl", "ppo"]
142
+
143
+ def __init__(
144
+ self,
145
+ config: Optional[PPOConfig] = None,
146
+ model: Optional[PreTrainedModelWrapper] = None,
147
+ ref_model: Optional[PreTrainedModelWrapper] = None,
148
+ tokenizer: Optional[PreTrainedTokenizerBase] = None,
149
+ dataset: Optional[Union[torch.utils.data.Dataset, Dataset]] = None,
150
+ optimizer: Optional[torch.optim.Optimizer] = None,
151
+ data_collator: Optional[typing.Callable] = None,
152
+ num_shared_layers: Optional[int] = None,
153
+ lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
154
+ ):
155
+ """
156
+ Initialize PPOTrainer.
157
+
158
+ Args:
159
+ config (`PPOConfig`):
160
+ Configuration object for PPOTrainer. Check the documentation of `PPOConfig` for more details.
161
+ model (`PreTrainedModelWrapper`):
162
+ Hugging Face transformer model with a value head.
163
+ ref_model (`PreTrainedModelWrapper`):
164
+ Hugging Face transformer model with a causal language modelling head. Used for KL penalty
165
+ tokenizer (`transformers.PreTrainedTokenizerBase`):
166
+ Hugging Face tokenizer
167
+ dataset (Optional[Union[`torch.utils.data.Dataset`, `datasets.Dataset`]]):
168
+ PyTorch dataset or Hugging Face dataset. If a Hugging Face dataset is passed, the dataset
169
+ will be preprocessed by removing the columns that are not used by the model. If none is passed,
170
+ a warning will be raised in a multi-GPU setting.
171
+ optimizer (Optional[`torch.optim.Optimizer`]):
172
+ Optimizer used for training. If `None`, the `Adam` is used as default.
173
+ data_collator (Optional[function]):
174
+ Data collator function.
175
+ num_shared_layers (Optional[int]):
176
+ Number of shared layers between the model and the reference model. If `None`, all layers are shared.
177
+ Used only if `ref_model` is `None`.
178
+ lr_scheduler (Optional[`torch.optim.lr_scheduler`]):
179
+ Learning rate scheduler used for training.
180
+ """
181
+ super().__init__(config)
182
+
183
+ # initial seed for reproducible experiments
184
+ set_seed(config.seed)
185
+
186
+ # Step 0: check positional arguments validity
187
+ if not isinstance(config, PPOConfig):
188
+ raise ValueError(f"config must be a PPOConfig, got {type(config)}")
189
+ if not isinstance(tokenizer, (PreTrainedTokenizerBase)):
190
+ raise ValueError(
191
+ f"tokenizer must be a PreTrainedTokenizerBase like a PreTrainedTokenizer or a PreTrainedTokenizerFast, got {type(tokenizer)}"
192
+ )
193
+ if not isinstance(model, (SUPPORTED_ARCHITECTURES)):
194
+ raise ValueError(
195
+ f"model must be a PreTrainedModelWrapper, got {type(model)} - supported architectures are: {SUPPORTED_ARCHITECTURES}"
196
+ )
197
+ # Step 1: Initialize Accelerator
198
+ self.accelerator = Accelerator(
199
+ log_with=config.log_with,
200
+ gradient_accumulation_steps=config.gradient_accumulation_steps,
201
+ project_config=ProjectConfiguration(**config.project_kwargs),
202
+ **config.accelerator_kwargs,
203
+ )
204
+
205
+ # Step 1.1 Runtime variables filled by the accelerator
206
+ config.world_size = self.accelerator.num_processes
207
+ config.global_backward_batch_size = config.backward_batch_size * config.world_size
208
+ config.global_batch_size = config.batch_size * config.world_size
209
+
210
+ self.model = model
211
+ self.model_params = filter(lambda p: p.requires_grad, self.model.parameters())
212
+ self.is_encoder_decoder = hasattr(self.model, "is_encoder_decoder")
213
+ self.is_peft_model = getattr(self.model, "is_peft_model", False)
214
+ config.is_encoder_decoder = self.is_encoder_decoder
215
+ config.is_peft_model = self.is_peft_model
216
+
217
+ is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
218
+ self.accelerator.init_trackers(
219
+ config.tracker_project_name,
220
+ config=dict(trl_ppo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(),
221
+ init_kwargs=config.tracker_kwargs,
222
+ )
223
+ self.is_using_text_environment = getattr(config, "use_text_environment", False)
224
+
225
+ if isinstance(ref_model, SUPPORTED_ARCHITECTURES):
226
+ self.ref_model = ref_model
227
+ if num_shared_layers is not None:
228
+ warnings.warn(
229
+ "num_shared_layers is ignored when ref_model is provided. Two different models are used for the "
230
+ "model and the reference model and no layers are shared.",
231
+ UserWarning,
232
+ )
233
+ elif ref_model is None and not self.is_peft_model:
234
+ self.ref_model = create_reference_model(self.model, num_shared_layers=num_shared_layers)
235
+ elif self.is_peft_model:
236
+ self.ref_model = None
237
+ else:
238
+ raise ValueError(
239
+ f"ref_model must be a PreTrainedModelWrapper or `None`, got {type(ref_model)} - supported "
240
+ f"architectures are: {SUPPORTED_ARCHITECTURES} "
241
+ )
242
+ self.optional_peft_ctx = (
243
+ self.accelerator.unwrap_model(self.model).pretrained_model.disable_adapter
244
+ if self.is_peft_model
245
+ else nullcontext
246
+ )
247
+
248
+ if not (isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTrainedTokenizerFast)):
249
+ raise ValueError(
250
+ "tokenizer must be a transformers.PreTrainedTokenizer or transformers.PreTrainedTokenizerFast"
251
+ )
252
+ self.tokenizer = tokenizer
253
+
254
+ if dataset is not None and not (isinstance(dataset, torch.utils.data.Dataset) or isinstance(dataset, Dataset)):
255
+ raise ValueError("dataset must be a torch.utils.data.Dataset or datasets.Dataset")
256
+ elif dataset is None:
257
+ warnings.warn(
258
+ "No dataset is provided. Make sure to set config.batch_size to the correct value before training.",
259
+ UserWarning,
260
+ )
261
+ self.dataset = dataset
262
+ self._signature_columns = None
263
+ if self.dataset is not None:
264
+ self.dataloader = self.prepare_dataloader(self.dataset, data_collator)
265
+ elif self.dataset is None and self.accelerator.num_processes > 1:
266
+ warnings.warn(
267
+ "No dataset is provided. In a multi-GPU setting, this will lead to an error. You should"
268
+ " prepare your dataloader yourself with `dataloader = ppo_trainer.accelerator.prepare(dataloader)`"
269
+ " and using `torch.utils.data.DataLoader`, or pass a dataset to the `PPOTrainer`. Please "
270
+ " refer to the documentation for more details.",
271
+ UserWarning,
272
+ )
273
+ self.dataloader = None
274
+ else:
275
+ self.dataloader = None
276
+
277
+ # Step 3: Initialize optimizer and data collator
278
+ self.data_collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False)
279
+ if optimizer is None:
280
+ self.optimizer = Adam(
281
+ filter(lambda p: p.requires_grad, self.model.parameters()),
282
+ lr=self.config.learning_rate,
283
+ )
284
+ else:
285
+ self.optimizer = optimizer
286
+
287
+ self.lr_scheduler = lr_scheduler
288
+ if self.lr_scheduler is not None:
289
+ lr_scheduler_class = (
290
+ torch.optim.lr_scheduler._LRScheduler
291
+ if not is_torch_greater_2_0()
292
+ else torch.optim.lr_scheduler.LRScheduler
293
+ )
294
+
295
+ if not isinstance(self.lr_scheduler, lr_scheduler_class):
296
+ raise ValueError(
297
+ "lr_scheduler must be a torch.optim.lr_scheduler._LRScheduler or torch.optim.lr_scheduler.LRScheduler (for torch >= 2.0)"
298
+ )
299
+
300
+ if self.config.adap_kl_ctrl:
301
+ self.kl_ctl = AdaptiveKLController(self.config.init_kl_coef, self.config.target, self.config.horizon)
302
+ else:
303
+ self.kl_ctl = FixedKLController(self.config.init_kl_coef)
304
+
305
+ # Safety checkers for DS integration
306
+ is_deepspeed_used = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(
307
+ self.accelerator.state, "deepspeed_plugin"
308
+ )
309
+
310
+ (
311
+ self.model,
312
+ self.optimizer,
313
+ self.data_collator,
314
+ self.dataloader,
315
+ self.lr_scheduler,
316
+ ) = self.accelerator.prepare(
317
+ self.model,
318
+ self.optimizer,
319
+ self.data_collator,
320
+ self.dataloader,
321
+ self.lr_scheduler,
322
+ )
323
+ if is_deepspeed_used:
324
+ # Quantized models are already set on the correct device
325
+ if not self.is_peft_model and not (
326
+ getattr(self.ref_model.pretrained_model, "is_loaded_in_8bit", False)
327
+ or getattr(self.ref_model.pretrained_model, "is_loaded_in_4bit", False)
328
+ ):
329
+ self.ref_model = self._prepare_deepspeed(self.ref_model)
330
+ else:
331
+ self.ref_model = self.accelerator.prepare(self.ref_model)
332
+
333
+ # In a distributed setup, only logging needs to be performed on the main process
334
+ # check: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
335
+ # or: https://discuss.pytorch.org/t/use-distributed-data-parallel-correctly/82500/11
336
+ self.is_distributed = self.accelerator.num_processes > 1
337
+
338
+ # init the current step
339
+ self.current_step = 0
340
+
341
+ # init variables for pushing model to hub
342
+ if config.push_to_hub_if_best_kwargs:
343
+ if "repo_id" not in config.push_to_hub_if_best_kwargs:
344
+ raise ValueError("You have to specify repo_id in order to push the model to the hub!")
345
+ self.push_to_hub_kwargs = config.push_to_hub_if_best_kwargs
346
+ self.compare_step = 0
347
+ self.highest_reward = torch.tensor(-float("inf"))
348
+
349
+ # post process for PP
350
+ if not getattr(self.model, "is_sequential_parallel", False):
351
+ self.current_device = self.accelerator.device
352
+ else:
353
+ if is_xpu_available():
354
+ self.current_device = torch.device("xpu:0")
355
+ elif is_npu_available():
356
+ self.current_device = torch.device("npu:0")
357
+ else:
358
+ self.current_device = torch.device("cuda:0")
359
+
360
+ PPODecorators.optimize_device_cache = self.config.optimize_device_cache
361
+
362
+ self.running = RunningMoments(self.accelerator)
363
+
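The `AdaptiveKLController` selected above follows the Ziegler et al. (2019) recipe: nudge the KL coefficient up when the measured KL overshoots the target and down when it undershoots. A minimal sketch of that update rule (the actual class is defined elsewhere in TRL and may differ in details):

```python
import numpy as np

class AdaptiveKLControllerSketch:
    def __init__(self, init_kl_coef: float, target: float, horizon: int):
        self.value = init_kl_coef
        self.target = target
        self.horizon = horizon

    def update(self, current_kl: float, n_steps: int) -> None:
        # proportional error, clipped to keep each adjustment gentle
        proportional_error = np.clip(current_kl / self.target - 1, -0.2, 0.2)
        mult = 1 + proportional_error * n_steps / self.horizon
        self.value *= mult

ctl = AdaptiveKLControllerSketch(init_kl_coef=0.2, target=6.0, horizon=10000)
ctl.update(current_kl=8.0, n_steps=256)
print(ctl.value)  # coefficient grows because the KL overshot the target
```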
364
+ def _filter_kwargs(self, kwargs, target_func):
365
+ """
366
+ filter the keyword arguments that are supported by the target function.
367
+
368
+ Args:
369
+ kwargs (dict):
370
+ Keyword arguments
371
+ target_func (function):
372
+ Target function
373
+ """
374
+ return {k: v for k, v in kwargs.items() if k in inspect.signature(target_func).parameters.keys()}
375
+
376
+ def prepare_dataloader(self, dataset: Union[torch.utils.data.Dataset, Dataset], data_collator=None):
377
+ """
378
+ Prepare the dataloader for training.
379
+
380
+ Args:
381
+ dataset (Union[`torch.utils.data.Dataset`, `datasets.Dataset`]):
382
+ PyTorch dataset or Hugging Face dataset. If a Hugging Face dataset is passed, the dataset
383
+ will be preprocessed by removing the columns that are not used by the model.
384
+ data_collator (Optional[function]):
385
+ Data collator function.
386
+
387
+ Returns:
388
+ `torch.utils.data.DataLoader`: PyTorch dataloader
389
+ """
390
+ if isinstance(dataset, Dataset):
391
+ dataset = self._remove_unused_columns(dataset)
392
+ dataloader = torch.utils.data.DataLoader(
393
+ dataset,
394
+ batch_size=self.config.batch_size,
395
+ collate_fn=data_collator,
396
+ shuffle=True,
397
+ drop_last=True,
398
+ )
399
+ return dataloader
400
+
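In typical TRL examples, the collator passed here simply keeps each column as a Python list, so ragged query tensors are not padded prematurely. A sketch of that pattern (the dataset columns named below are assumptions):

```python
# Keep each column as a list instead of stacking tensors, so queries of
# different lengths survive batching untouched.
def collator(data):
    return {key: [d[key] for d in data] for key in data[0]}

# Hypothetical usage, assuming a dataset whose rows contain "input_ids" and "query":
# ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator)
```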
401
+ # Adapted from transformers.Trainer._set_signature_columns_if_needed
402
+ def _set_signature_columns_if_needed(self):
403
+ if self._signature_columns is None:
404
+ # Inspect model forward signature to keep only the arguments it accepts.
405
+ signature = inspect.signature(self.model.forward)
406
+ self._signature_columns = list(signature.parameters.keys())
407
+ # label => sentiment | we need query and response for logging purpose
408
+ self._signature_columns += ["label", "query", "response"]
409
+
410
+ # Adapted from transformers.Trainer._remove_unused_columns
411
+ def _remove_unused_columns(self, dataset: "Dataset"):
412
+ if not self.config.remove_unused_columns:
413
+ return dataset
414
+ self._set_signature_columns_if_needed()
415
+ signature_columns = self._signature_columns
416
+
417
+ ignored_columns = list(set(dataset.column_names) - set(signature_columns))
418
+
419
+ columns = [k for k in signature_columns if k in dataset.column_names]
420
+
421
+ if version.parse(datasets.__version__) < version.parse("1.4.0"):
422
+ dataset.set_format(
423
+ type=dataset.format["type"],
424
+ columns=columns,
425
+ format_kwargs=dataset.format["format_kwargs"],
426
+ )
427
+ return dataset
428
+ else:
429
+ return dataset.remove_columns(ignored_columns)
430
+
431
+ def generate(
432
+ self,
433
+ query_tensor: Union[torch.Tensor, List[torch.Tensor]],
434
+ length_sampler: Optional[Callable] = None,
435
+ batch_size: int = 4,
436
+ return_prompt: bool = True,
437
+ generate_ref_response: bool = False,
438
+ **generation_kwargs,
439
+ ):
440
+ """
441
+ Generate response with the model given the query tensor.
442
+ This calls the `generate` method of the model.
443
+
444
+ Args:
445
+ query_tensor (`torch.LongTensor`):
446
+ A tensor of shape (`seq_len`) containing query tokens or a list of tensors of shape (`seq_len`).
447
+ length_sampler (`Callable`, *optional*):
448
+ Callable that returns the number of newly generated tokens.
449
+ batch_size (`int`, *optional*):
450
+ Batch size used for generation, defaults to `4`.
451
+ return_prompt (`bool`, *optional*):
452
+ If set to `False` the prompt is not returned but only the newly generated tokens, defaults to `True`.
453
+ generate_ref_response (`bool`, *optional*):
454
+ If set to `True` the reference response is also generated, defaults to `False`.
455
+ generation_kwargs (dict[str, Any]):
456
+ Keyword arguments for generation.
457
+
458
+ Returns:
459
+ `torch.LongTensor`: A tensor of shape (`batch_size`, `gen_len`) containing response tokens.
460
+ """
461
+ if generate_ref_response:
462
+ ref_model = self.model if self.is_peft_model else self.ref_model
463
+ if isinstance(query_tensor, List):
464
+ response = self._generate_batched(
465
+ self.model,
466
+ query_tensor,
467
+ length_sampler=length_sampler,
468
+ batch_size=batch_size,
469
+ return_prompt=return_prompt,
470
+ **generation_kwargs,
471
+ )
472
+ if generate_ref_response:
473
+ with self.optional_peft_ctx():
474
+ ref_response = self._generate_batched(
475
+ ref_model,
476
+ query_tensor,
477
+ length_sampler=length_sampler,
478
+ batch_size=batch_size,
479
+ return_prompt=return_prompt,
480
+ **generation_kwargs,
481
+ )
482
+
483
+ else:
484
+ if len(query_tensor.shape) == 2:
485
+ raise ValueError(
486
+ "query_tensor must be a tensor of shape (`seq_len`) or a list of tensors of shape (`seq_len`)"
487
+ )
488
+
489
+ if length_sampler is not None:
490
+ generation_kwargs["max_new_tokens"] = length_sampler()
491
+ response = self.accelerator.unwrap_model(self.model).generate(
492
+ input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs
493
+ )
494
+ if generate_ref_response:
495
+ with self.optional_peft_ctx():
496
+ ref_response = ref_model.generate(input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs)
497
+
498
+ if not return_prompt and not self.is_encoder_decoder:
499
+ response = response[:, query_tensor.shape[0] :]
500
+ if generate_ref_response:
501
+ ref_response = ref_response[:, query_tensor.shape[0] :]
502
+
503
+ if generate_ref_response:
504
+ return response, ref_response
505
+ return response
506
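+ # Usage sketch (illustrative): generating responses for a list of query tensors. The
+ # `ppo_trainer` / `tokenizer` names and the sampling parameters are assumptions, not part of this file.
+ #
+ # generation_kwargs = {
+ #     "do_sample": True,
+ #     "top_k": 0,
+ #     "top_p": 1.0,
+ #     "pad_token_id": tokenizer.eos_token_id,
+ #     "max_new_tokens": 32,
+ # }
+ # response_tensors = ppo_trainer.generate(
+ #     query_tensors, return_prompt=False, **generation_kwargs
+ # )
+ # batch["response"] = tokenizer.batch_decode(response_tensors)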
+
507
+ def _generate_batched(
508
+ self,
509
+ model: PreTrainedModelWrapper,
510
+ query_tensors: List[torch.Tensor],
511
+ length_sampler: Optional[Callable] = None,
512
+ batch_size: int = 4,
513
+ return_prompt: bool = True,
514
+ pad_to_multiple_of: Optional[int] = None,
515
+ remove_padding: bool = True,
516
+ **generation_kwargs,
517
+ ):
518
+ outputs = []
519
+
520
+ padding_side_default = self.tokenizer.padding_side
521
+ if not self.is_encoder_decoder:
522
+ self.tokenizer.padding_side = "left"
523
+
524
+ # in case we have fewer examples than bs
525
+ batch_size = min(len(query_tensors), batch_size)
526
+
527
+ for i in range(0, len(query_tensors), batch_size):
528
+ if length_sampler is not None:
529
+ generation_kwargs["max_new_tokens"] = length_sampler()
530
+
531
+ # prevent overflow if the number of query tensors is not an even multiple of the batch size
532
+ end_index = min(len(query_tensors), i + batch_size)
533
+
534
+ batch = query_tensors[i:end_index]
535
+ batch_mask = [torch.ones_like(element) for element in batch]
536
+ inputs = {"input_ids": batch, "attention_mask": batch_mask}
537
+
538
+ padded_inputs = self.tokenizer.pad(
539
+ inputs,
540
+ padding=True,
541
+ max_length=None,
542
+ pad_to_multiple_of=pad_to_multiple_of,
543
+ return_tensors="pt",
544
+ ).to(self.current_device)
545
+
546
+ generations = self.accelerator.unwrap_model(model).generate(**padded_inputs, **generation_kwargs)
547
+
548
+ for generation, mask in zip(generations, padded_inputs["attention_mask"]):
549
+ if not self.is_encoder_decoder:
550
+ output = generation[(1 - mask).sum() :] # remove padding
551
+ else:
552
+ output = generation
553
+
554
+ if not return_prompt and not self.is_encoder_decoder:
555
+ output = output[(mask).sum() :] # remove prompt
556
+
557
+ if remove_padding and self.tokenizer.eos_token_id in output:
558
+ pad_mask = output == self.tokenizer.eos_token_id
559
+ pad_start = torch.nonzero(pad_mask, as_tuple=False)[0, 0].item()
560
+ output = output[: pad_start + 1] # keep the eos token at the end
561
+
562
+ outputs.append(output)
563
+
564
+ self.tokenizer.padding_side = padding_side_default
565
+ return outputs
566
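+ # Why left padding above (illustrative, hypothetical pad id 0): for decoder-only models the prompt
+ # must end exactly where generation starts, so padding goes on the left:
+ #
+ #   right padded: [q1, q2, 0, 0]  -> generation continues after the pad tokens (wrong)
+ #   left padded:  [0, 0, q1, q2]  -> generation continues directly after q2 (correct)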
+
567
+ def _step_safety_checker(
568
+ self,
569
+ batch_size: int,
570
+ queries: List[torch.LongTensor],
571
+ responses: List[torch.LongTensor],
572
+ scores: List[torch.FloatTensor],
573
+ masks: Optional[List[torch.LongTensor]] = None,
574
+ ):
575
+ """
576
+ Check if the input data is valid for training.
577
+
578
+ Args:
579
+ batch_size (int):
580
+ Batch size from the config file.
581
+ queries (List[`torch.LongTensor`]):
582
+ List of tensors containing the encoded queries of shape (`query_length`)
583
+ responses (List[`torch.LongTensor`]):
584
+ List of tensors containing the encoded responses of shape (`response_length`)
585
+ scores (List[`torch.FloatTensor`]):
586
+ List of tensors containing the scores.
587
+ masks (List[`torch.LongTensor`], *optional*):
588
+ Optional list of tensors containing the masks of shape (`query_length` + `response_length`)
589
+ Returns:
590
+ `tuple`: The input processed data.
591
+ """
592
+ for name, tensor_list in zip(["queries", "responses", "scores"], [queries, responses, scores]):
593
+ if not isinstance(tensor_list, list):
594
+ raise ValueError(f"{name} must be a list of tensors - got {type(tensor_list)}")
595
+ if not isinstance(tensor_list[0], torch.Tensor):
596
+ raise ValueError(f"Elements in {name} must be tensors - got {type(tensor_list[0])}")
597
+ if batch_size is not None and len(tensor_list) != batch_size:
598
+ raise ValueError(
599
+ f"Batch size ({batch_size}) does not match number of examples - but got {len(tensor_list)} for: {name}"
600
+ )
601
+
602
+ # add queries, scores and responses on the correct device
603
+ queries = [tensor.to(self.current_device) for tensor in queries]
604
+ responses = [tensor.to(self.current_device) for tensor in responses]
605
+ scores = [tensor.to(self.current_device) for tensor in scores]
606
+ masks = [tensor.to(self.current_device) for tensor in masks] if masks is not None else None
607
+
608
+ # squeeze scores if needed
609
+ for i, score in enumerate(scores):
610
+ if score.dim() > 1:
611
+ raise ValueError(f"Scores must be 1-dimensional - got {score.dim()} for {score}")
612
+ elif score.dim() == 1:
613
+ scores[i] = score.squeeze()
614
+
615
+ return queries, responses, scores, masks
616
+
617
+ @PPODecorators.empty_device_cache()
618
+ def step(
619
+ self,
620
+ queries: List[torch.LongTensor],
621
+ responses: List[torch.LongTensor],
622
+ scores: List[torch.FloatTensor],
623
+ response_masks: Optional[List[torch.LongTensor]] = None,
624
+ ):
625
+ """
626
+ Run a PPO optimisation step given a list of queries, model responses, and rewards.
627
+
628
+ Args:
629
+ queries (List[`torch.LongTensor`]):
630
+ List of tensors containing the encoded queries of shape (`query_length`)
631
+ responses (List[`torch.LongTensor`]):
632
+ List of tensors containing the encoded responses of shape (`response_length`)
633
+ scores (List[`torch.FloatTensor`]):
634
+ List of tensors containing the scores.
635
+ response_masks (List[`torch.LongTensor`], *optional*):
636
+ List of tensors containing masks of the response tokens.
637
+
638
+ Returns:
639
+ `dict[str, Any]`: A summary of the training statistics
640
+ """
641
+ bs = self.config.batch_size
642
+
643
+ queries, responses, scores, response_masks = self._step_safety_checker(
644
+ bs, queries, responses, scores, response_masks
645
+ )
646
+ scores = torch.tensor(scores, device=self.current_device)
647
+ if self.config.use_score_scaling:
648
+ # Score scaling
649
+ scores_mean, scores_std = self.running.update(scores)
650
+ tensor_to_kwargs = dict(dtype=scores.dtype, device=scores.device)
651
+ score_scaling_factor = self.running.std.to(**tensor_to_kwargs) + torch.finfo(scores.dtype).eps
652
+ if self.config.use_score_norm:
653
+ scores = (scores - self.running.mean.to(**tensor_to_kwargs)) / score_scaling_factor
654
+ else:
655
+ scores /= score_scaling_factor
656
+
657
+ if self.config.score_clip is not None:
658
+ # Score clipping
659
+ scores_dtype = scores.dtype
660
+ scores = torch.clip(scores.float(), -self.config.score_clip, self.config.score_clip).to(dtype=scores_dtype)
661
+
662
+ # if we want to push best model to the hub
663
+ if hasattr(self, "highest_reward"):
664
+ if self.compare_step % self.config.compare_steps == 0:
665
+ curr_mean_reward = scores.mean()
666
+ # if the best reward ever seen
667
+ if curr_mean_reward > self.highest_reward:
668
+ self.highest_reward = curr_mean_reward
669
+ # push model to hub
670
+ self.push_to_hub(**self.push_to_hub_kwargs)
671
+ self.compare_step += 1
672
+
673
+ timing = dict()
674
+ t0 = time.time()
675
+
676
+ t = time.time()
677
+
678
+ model_inputs = self.prepare_model_inputs(queries, responses)
679
+
680
+ if self.is_distributed:
681
+ pad_first = self.tokenizer.padding_side == "left"
682
+
683
+ model_inputs["input_ids"] = self.accelerator.pad_across_processes(
684
+ model_inputs["input_ids"],
685
+ dim=1,
686
+ pad_index=self.tokenizer.pad_token_id,
687
+ pad_first=pad_first,
688
+ )
689
+ model_inputs["attention_mask"] = self.accelerator.pad_across_processes(
690
+ model_inputs["attention_mask"], dim=1, pad_index=0, pad_first=pad_first
691
+ )
692
+ if self.is_encoder_decoder:
693
+ model_inputs["decoder_input_ids"] = self.accelerator.pad_across_processes(
694
+ model_inputs["decoder_input_ids"],
695
+ dim=1,
696
+ pad_index=self.tokenizer.pad_token_id,
697
+ pad_first=pad_first,
698
+ )
699
+ model_inputs["decoder_attention_mask"] = self.accelerator.pad_across_processes(
700
+ model_inputs["decoder_attention_mask"],
701
+ dim=1,
702
+ pad_index=0,
703
+ pad_first=pad_first,
704
+ )
705
+
706
+ model_inputs_names = list(model_inputs.keys())
707
+
708
+ full_kl_penalty = self.config.kl_penalty == "full"
709
+
710
+ with torch.no_grad():
711
+ all_logprobs, logits_or_none, values, masks = self.batched_forward_pass(
712
+ self.model,
713
+ queries,
714
+ responses,
715
+ model_inputs,
716
+ response_masks=response_masks,
717
+ return_logits=full_kl_penalty,
718
+ )
719
+ with self.optional_peft_ctx():
720
+ ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass(
721
+ self.model if self.is_peft_model else self.ref_model,
722
+ queries,
723
+ responses,
724
+ model_inputs,
725
+ return_logits=full_kl_penalty,
726
+ )
727
+
728
+ timing["time/ppo/forward_pass"] = time.time() - t
729
+
730
+ with torch.no_grad():
731
+ t = time.time()
732
+ if full_kl_penalty:
733
+ active_full_logprobs = logprobs_from_logits(logits_or_none, None, gather=False)
734
+ ref_full_logprobs = logprobs_from_logits(ref_logits_or_none, None, gather=False)
735
+
736
+ rewards, non_score_reward, kls = self.compute_rewards(
737
+ scores, active_full_logprobs, ref_full_logprobs, masks
738
+ )
739
+ else:
740
+ rewards, non_score_reward, kls = self.compute_rewards(scores, all_logprobs, ref_logprobs, masks)
741
+ timing["time/ppo/compute_rewards"] = time.time() - t
742
+
743
+ t = time.time()
744
+ values, advantages, returns = self.compute_advantages(values, rewards, masks)
745
+ timing["time/ppo/compute_advantages"] = time.time() - t
746
+
747
+ # upcast to float32 to avoid dataset issues
748
+ batch_dict = {
749
+ "queries": queries,
750
+ "responses": responses,
751
+ "logprobs": all_logprobs.to(torch.float32),
752
+ "values": values.to(torch.float32),
753
+ "masks": masks,
754
+ "advantages": advantages,
755
+ "returns": returns,
756
+ }
757
+ batch_dict.update(model_inputs)
758
+
759
+ t = time.time()
760
+ all_stats = []
761
+ early_stop = False
762
+ for _ in range(self.config.ppo_epochs):
763
+ if early_stop:
764
+ break
765
+ b_inds = np.random.permutation(bs)
766
+ for backward_batch_start in range(0, bs, self.config.backward_batch_size):
767
+ backward_batch_end = backward_batch_start + self.config.backward_batch_size
768
+ backward_batch_inds = b_inds[backward_batch_start:backward_batch_end]
769
+
770
+ for mini_batch_start in range(0, self.config.backward_batch_size, self.config.mini_batch_size):
771
+ mini_batch_end = mini_batch_start + self.config.mini_batch_size
772
+ mini_batch_inds = backward_batch_inds[mini_batch_start:mini_batch_end]
773
+ mini_batch_dict = {
774
+ "logprobs": batch_dict["logprobs"][mini_batch_inds],
775
+ "values": batch_dict["values"][mini_batch_inds],
776
+ "masks": batch_dict["masks"][mini_batch_inds],
777
+ # note: the queries and responses are ragged (variable length), so they are indexed as Python lists.
778
+ "queries": [batch_dict["queries"][i] for i in mini_batch_inds],
779
+ "responses": [batch_dict["responses"][i] for i in mini_batch_inds],
780
+ "advantages": batch_dict["advantages"][mini_batch_inds],
781
+ "returns": batch_dict["returns"][mini_batch_inds],
782
+ }
783
+ for k in model_inputs_names:
784
+ mini_batch_dict[k] = batch_dict[k][mini_batch_inds]
785
+ with self.accelerator.accumulate(self.model):
786
+ model_inputs = {k: mini_batch_dict[k] for k in model_inputs_names}
787
+
788
+ logprobs, logits, vpreds, _ = self.batched_forward_pass(
789
+ self.model,
790
+ mini_batch_dict["queries"],
791
+ mini_batch_dict["responses"],
792
+ model_inputs,
793
+ return_logits=True,
794
+ )
795
+ train_stats = self.train_minibatch(
796
+ mini_batch_dict["logprobs"],
797
+ mini_batch_dict["values"],
798
+ logprobs,
799
+ logits,
800
+ vpreds,
801
+ mini_batch_dict["masks"],
802
+ mini_batch_dict["advantages"],
803
+ mini_batch_dict["returns"],
804
+ )
805
+ all_stats.append(train_stats)
806
+
807
+ # typically, early stopping is done at the epoch level
808
+ if self.config.early_stopping:
809
+ policykl = train_stats["policy/policykl"]
810
+ early_stop = self._early_stop(policykl)
811
+ if early_stop:
812
+ break
813
+
814
+ timing["time/ppo/optimize_step"] = time.time() - t
815
+
816
+ t = time.time()
817
+ train_stats = stack_dicts(all_stats)
818
+
819
+ # reshape advantages/ratios such that they are not averaged.
820
+ train_stats["policy/advantages"] = torch.flatten(train_stats["policy/advantages"]).unsqueeze(0)
821
+ train_stats["policy/advantages"] = torch.nan_to_num(train_stats["policy/advantages"], WANDB_PADDING)
822
+ train_stats["policy/ratio"] = torch.flatten(train_stats["policy/ratio"]).unsqueeze(0)
823
+
824
+ stats = self.record_step_stats(
825
+ scores=scores,
826
+ logprobs=all_logprobs,
827
+ ref_logprobs=ref_logprobs,
828
+ non_score_reward=non_score_reward,
829
+ train_stats=train_stats,
830
+ kl_coef=self.kl_ctl.value,
831
+ masks=masks,
832
+ queries=queries,
833
+ responses=responses,
834
+ kls=kls,
835
+ )
836
+ # Gather/Reduce stats from all processes
837
+ if self.is_distributed:
838
+ stats = self.gather_stats(stats)
839
+ stats = stats_to_np(stats)
840
+ timing["time/ppo/calc_stats"] = time.time() - t
841
+ stats["ppo/learning_rate"] = self.optimizer.param_groups[0]["lr"]
842
+
843
+ # Update the KL control - multiply the batch_size by the number of processes
844
+ self.kl_ctl.update(
845
+ stats["objective/kl"],
846
+ self.config.batch_size * self.accelerator.num_processes,
847
+ )
848
+
849
+ # Log the total ppo time
850
+ timing["time/ppo/total"] = time.time() - t0
851
+ stats.update(timing)
852
+
853
+ # post-process stats for tensorboard and other loggers
854
+ if self.config.log_with != "wandb":
855
+ stats = convert_to_scalar(stats)
856
+
857
+ if self.lr_scheduler is not None:
858
+ self.lr_scheduler.step()
859
+
860
+ return stats
861
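+ # Usage sketch (illustrative): one PPO iteration combines `generate`, a reward signal and `step`.
+ # The constant reward below is a placeholder for a real reward model and is not from this file.
+ #
+ # for batch in loader:
+ #     query_tensors = batch["input_ids"]
+ #     response_tensors = ppo_trainer.generate(query_tensors, return_prompt=False, **generation_kwargs)
+ #     rewards = [torch.tensor(1.0) for _ in response_tensors]  # one scalar score per example
+ #     stats = ppo_trainer.step(query_tensors, response_tensors, rewards)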
+
862
+ def _early_stop(self, policykl):
863
+ r"""
864
+ Handles the early stopping logic. If the policy KL exceeds 1.5 times the target KL, the gradients are zeroed and
865
+ the optimization step is skipped.
866
+ This also handles the multi-gpu case where the policy KL is averaged across all processes.
867
+
868
+ Args:
869
+ policykl (torch.Tensor):
870
+ the policy KL
871
+
872
+ Returns:
873
+ `bool`: whether to early stop or not
874
+ """
875
+ early_stop = False
876
+ if not self.config.early_stopping:
877
+ return early_stop
878
+
879
+ if not self.is_distributed and policykl > 1.5 * self.config.target_kl:
880
+ self.optimizer.zero_grad()
881
+ early_stop = True
882
+ elif self.is_distributed:
883
+ import torch.distributed as dist
884
+
885
+ # Wait for all processes to finish
886
+ dist.barrier()
887
+
888
+ # all gather the policykl
889
+ dist.all_reduce(policykl, dist.ReduceOp.SUM)
890
+ policykl /= self.accelerator.num_processes
891
+
892
+ if policykl > 1.5 * self.config.target_kl:
893
+ self.optimizer.zero_grad()
894
+ early_stop = True
895
+ return early_stop
896
+
897
+ def gather_stats(self, stats):
898
+ """
899
+ Gather stats from all processes. Useful in the context of distributed training.
900
+
901
+ Args:
902
+ stats (dict[str, Any]):
903
+ a dictionary of stats to be gathered. The stats should contain torch tensors.
904
+
905
+ Returns:
906
+ `dict[str, Any]`: A dictionary of stats with the tensors gathered.
907
+ """
908
+ import torch.distributed as dist
909
+
910
+ # Wait for all processes to finish
911
+ dist.barrier()
912
+
913
+ for k, v in stats.items():
914
+ if isinstance(v, torch.Tensor):
915
+ dist.all_reduce(v.to(self.accelerator.device), dist.ReduceOp.SUM)
916
+ v /= self.accelerator.num_processes
917
+ stats[k] = v
918
+ return stats
919
+
920
+ def prepare_model_inputs(self, queries: torch.Tensor, responses: torch.Tensor):
921
+ if self.is_encoder_decoder:
922
+ input_data = self.data_collator(
923
+ [{"input_ids": q, "attention_mask": torch.ones_like(q)} for q in queries]
924
+ ).to(self.current_device)
925
+
926
+ decoder_inputs = self.data_collator(
927
+ [{"input_ids": r, "attention_mask": torch.ones_like(r)} for r in responses]
928
+ ).to(self.current_device)
929
+
930
+ input_data["decoder_input_ids"] = decoder_inputs["input_ids"]
931
+ input_data["decoder_attention_mask"] = decoder_inputs["attention_mask"]
932
+ else:
933
+ input_ids = [torch.cat([q, r]) for q, r in zip(queries, responses)]
934
+ input_data = self.data_collator(
935
+ [{"input_ids": ids, "attention_mask": torch.ones_like(ids)} for ids in input_ids]
936
+ ).to(self.current_device)
937
+
938
+ input_data.pop("labels", None) # we don't want to compute LM losses
939
+ return input_data
940
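+ # Shape sketch for the decoder-only branch above (hypothetical token ids):
+ #
+ #   query    = tensor([15, 22, 8])
+ #   response = tensor([4, 19])
+ #   -> input_ids = tensor([15, 22, 8, 4, 19]), attention_mask = tensor([1, 1, 1, 1, 1])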
+
941
+ @PPODecorators.empty_device_cache()
942
+ def batched_forward_pass(
943
+ self,
944
+ model: PreTrainedModelWrapper,
945
+ queries: torch.Tensor,
946
+ responses: torch.Tensor,
947
+ model_inputs: dict,
948
+ return_logits: bool = False,
949
+ response_masks: Optional[torch.Tensor] = None,
950
+ ):
951
+ """
952
+ Calculate model outputs in multiple batches.
953
+
954
+ Args:
955
+ queries (`torch.LongTensor`):
956
+ List of tensors containing the encoded queries, shape (`batch_size`, `query_length`)
957
+ responses (`torch.LongTensor`):
958
+ List of tensors containing the encoded responses, shape (`batch_size`, `response_length`)
959
+ return_logits (`bool`, *optional*, defaults to `False`):
960
+ Whether to return all_logits. Set to `False` if logits are not needed to reduce memory consumption.
961
+ Returns:
962
+ (tuple):
963
+ - all_logprobs (`torch.FloatTensor`): Log probabilities of the responses,
964
+ shape (`batch_size`, `response_length`)
965
+ - all_logits (`torch.FloatTensor` or `None`): Logits of the model if `return_logits=True`, else `None`,
966
+ shape (`batch_size`, `response_length`, `vocab_size`)
967
+ - all_values (`torch.FloatTensor`): Values of the responses, shape (`batch_size`, `response_length`)
+ - all_masks (`torch.LongTensor`): Mask over the response tokens, shape (`batch_size`, `response_length`)
968
+ """
969
+ bs = len(queries)
970
+ fbs = self.config.mini_batch_size
971
+ all_logprobs = []
972
+ all_logits = []
973
+ all_masks = []
974
+ all_values = []
975
+
976
+ model.eval()
977
+
978
+ for i in range(math.ceil(bs / fbs)):
979
+ input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
980
+ query_batch = queries[i * fbs : (i + 1) * fbs]
981
+ response_batch = responses[i * fbs : (i + 1) * fbs]
982
+ if response_masks is not None:
983
+ response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
984
+ logits, _, values = model(**input_kwargs)
985
+
986
+ if self.is_encoder_decoder:
987
+ input_ids = input_kwargs["decoder_input_ids"]
988
+ attention_mask = input_kwargs["decoder_attention_mask"]
989
+ else:
990
+ input_ids = input_kwargs["input_ids"]
991
+ attention_mask = input_kwargs["attention_mask"]
992
+
993
+ logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
994
+ masks = torch.zeros_like(attention_mask)
995
+ masks[:, :-1] = attention_mask[:, 1:]
996
+
997
+ for j in range(len(query_batch)):
998
+ if self.is_encoder_decoder:
999
+ # In encoder-decoder models, the decoder sequence always starts at index 1 (right after the padding token)
1000
+ start = 1
1001
+ end = attention_mask[j, :].sum() - 1
1002
+ else:
1003
+ start = len(query_batch[j]) - 1 # logprobs start from the second query token
1004
+ if attention_mask[j, 0] == 0: # offset left padding
1005
+ start += attention_mask[j, :].nonzero()[0]
1006
+ end = start + len(response_batch[j])
1007
+ if response_masks is not None:
1008
+ response_masks_batch[j] = torch.cat(
1009
+ (torch.zeros_like(query_batch[j]), response_masks_batch[j])
1010
+ )[1:]
1011
+
1012
+ masks[j, :start] = 0
1013
+ masks[j, end:] = 0
1014
+ if response_masks is not None:
1015
+ masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]
1016
+
1017
+ if return_logits:
1018
+ all_logits.append(logits)
1019
+ else:
1020
+ del logits
1021
+ all_values.append(values)
1022
+ all_logprobs.append(logprobs)
1023
+ all_masks.append(masks)
1024
+
1025
+ return (
1026
+ torch.cat(all_logprobs),
1027
+ torch.cat(all_logits)[:, :-1] if return_logits else None,
1028
+ torch.cat(all_values)[:, :-1],
1029
+ torch.cat(all_masks)[:, :-1],
1030
+ )
1031
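+ # Alignment sketch for the forward pass above: logits[:, :-1, :] predict the *next* token, so they
+ # are scored against input_ids[:, 1:]. A minimal stand-in for the `logprobs_from_logits` helper used
+ # here could look like this (a sketch, not the library implementation):
+ #
+ # import torch
+ # import torch.nn.functional as F
+ #
+ # def logprobs_from_logits_sketch(logits, labels):
+ #     logp = F.log_softmax(logits, dim=-1)                           # (batch, seq, vocab)
+ #     return torch.gather(logp, 2, labels.unsqueeze(2)).squeeze(-1)  # (batch, seq)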
+
1032
+ @PPODecorators.empty_device_cache()
1033
+ def train_minibatch(
1034
+ self,
1035
+ old_logprobs: torch.FloatTensor,
1036
+ values: torch.FloatTensor,
1037
+ logprobs: torch.FloatTensor,
1038
+ logits: torch.FloatTensor,
1039
+ vpreds: torch.FloatTensor,
1040
+ mask: torch.LongTensor,
1041
+ advantages: torch.FloatTensor,
1042
+ returns: torch.FloatTensor,
1043
+ ):
1044
+ """
1045
+ Train one PPO minibatch
1046
+
1047
+ Args:
1048
+ old_logprobs (`torch.FloatTensor`):
1049
+ Log probabilities of the model at rollout time, shape [mini_batch_size, response_length]
1050
+ values (`torch.FloatTensor`):
1051
+ Values of the value head at rollout time, shape [mini_batch_size, response_length]
1052
+ logprobs (`torch.FloatTensor`), logits (`torch.FloatTensor`), vpreds (`torch.FloatTensor`):
1053
+ Current log probabilities, logits, and value predictions of the model, shape [mini_batch_size, response_length] (logits also have a trailing `vocab_size` dimension)
1054
+ mask (`torch.LongTensor`):
1055
+ Mask over the response tokens, shape [mini_batch_size, response_length]
1056
+ advantages (`torch.FloatTensor`), returns (`torch.FloatTensor`):
1057
+ Advantages and returns computed from the rollout, shape [mini_batch_size, response_length]
1058
+
1059
+ Returns:
1060
+ train_stats (dict[str, `torch.Tensor`]):
1061
+ Dictionary of training statistics
1062
+ """
1063
+ self.model.train()
1064
+ loss_p, loss_v, train_stats = self.loss(
1065
+ old_logprobs, values, logits, vpreds, logprobs, mask, advantages, returns
1066
+ )
1067
+ loss = loss_p + loss_v
1068
+ self.accelerator.backward(loss)
1069
+ if self.config.max_grad_norm is not None:
1070
+ if self.accelerator.sync_gradients:
1071
+ self.accelerator.clip_grad_norm_(self.model_params, self.config.max_grad_norm)
1072
+ self.optimizer.step()
1073
+ # we call optimizer.zero_grad() every time and let `accelerator` handle accumulation
1074
+ # see https://huggingface.co/docs/accelerate/usage_guides/gradient_accumulation#the-finished-code
1075
+ self.optimizer.zero_grad()
1076
+ return train_stats
1077
+
1078
+ def compute_rewards(
1079
+ self,
1080
+ scores: torch.FloatTensor,
1081
+ logprobs: torch.FloatTensor,
1082
+ ref_logprobs: torch.FloatTensor,
1083
+ masks: torch.LongTensor,
1084
+ ):
1085
+ """
1086
+ Compute per token rewards from scores and KL-penalty.
1087
+
1088
+ Args:
1089
+ scores (`torch.FloatTensor`):
1090
+ Scores from the reward model, shape (`batch_size`)
1091
+ logprobs (`torch.FloatTensor`):
1092
+ Log probabilities of the model, shape (`batch_size`, `response_length`)
1093
+ ref_logprobs (`torch.FloatTensor`):
1094
+ Log probabilities of the reference model, shape (`batch_size`, `response_length`)
+ masks (`torch.LongTensor`):
+ Mask over the response tokens, shape (`batch_size`, `response_length`)
1095
+
1096
+ Returns:
1097
+ `torch.FloatTensor`: Per token rewards, shape (`batch_size`, `response_length`)
1098
+ `torch.FloatTensor`: Non score rewards, shape (`batch_size`, `response_length`)
1099
+ `torch.FloatTensor`: KL penalty, shape (`batch_size`, `response_length`)
1100
+ """
1101
+ rewards, non_score_rewards, kls = [], [], []
1102
+ for score, logprob, ref_logprob, mask in zip(scores, logprobs, ref_logprobs, masks):
1103
+ # compute KL penalty (from difference in logprobs)
1104
+ kl = self._kl_penalty(logprob, ref_logprob)
1105
+ kls.append(kl)
1106
+ non_score_reward = -self.kl_ctl.value * kl
1107
+ non_score_rewards.append(non_score_reward)
1108
+ reward = non_score_reward.clone()
1109
+ last_non_masked_index = mask.nonzero()[-1]
1110
+
1111
+ # reward is preference model score + KL penalty
1112
+ reward[last_non_masked_index] += score
1113
+ rewards.append(reward)
1114
+ return torch.stack(rewards), torch.stack(non_score_rewards), torch.stack(kls)
1115
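+ # Worked example (illustrative numbers): with kl_ctl.value = 0.2, per-token kl = [0.1, 0.3, 0.2],
+ # mask = [1, 1, 1] and a sequence-level score of 1.0, the non-score reward is
+ # [-0.02, -0.06, -0.04] and the final per-token reward is [-0.02, -0.06, 0.96]:
+ # the score is added only at the last non-masked token.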
+
1116
+ def _kl_penalty(self, logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor) -> torch.FloatTensor:
1117
+ if self.config.kl_penalty == "kl":
1118
+ return logprob - ref_logprob
1119
+
1120
+ if self.config.kl_penalty == "abs":
1121
+ return (logprob - ref_logprob).abs()
1122
+
1123
+ if self.config.kl_penalty == "mse":
1124
+ return 0.5 * (logprob - ref_logprob).square()
1125
+
1126
+ if self.config.kl_penalty == "full":
1127
+ # The argument order is flipped here because of this issue: https://github.com/pytorch/pytorch/issues/57459
1128
+ return F.kl_div(ref_logprob, logprob, log_target=True, reduction="none").sum(-1)
1129
+
1130
+ raise NotImplementedError
1131
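+ # Quick comparison of the estimators above (illustrative numbers): with logprob = -1.0 and
+ # ref_logprob = -1.5 for a single sampled token,
+ #   "kl"  -> -1.0 - (-1.5)  =  0.5
+ #   "abs" -> |0.5|          =  0.5
+ #   "mse" -> 0.5 * 0.5**2   =  0.125
+ # "full" differs in kind: it consumes full log-distributions over the vocabulary and sums the exact
+ # KL over the vocab dimension instead of using a per-sampled-token estimate.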
+
1132
+ def compute_advantages(
1133
+ self,
1134
+ values: torch.FloatTensor,
1135
+ rewards: torch.FloatTensor,
1136
+ mask: torch.FloatTensor,
1137
+ ):
1138
+ lastgaelam = 0
1139
+ advantages_reversed = []
1140
+ gen_len = rewards.shape[-1]
1141
+
1142
+ values = values * mask
1143
+ rewards = rewards * mask
1144
+
1145
+ if self.config.whiten_rewards:
1146
+ rewards = masked_whiten(rewards, mask, shift_mean=False)
1147
+
1148
+ for t in reversed(range(gen_len)):
1149
+ nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
1150
+ delta = rewards[:, t] + self.config.gamma * nextvalues - values[:, t]
1151
+ lastgaelam = delta + self.config.gamma * self.config.lam * lastgaelam
1152
+ advantages_reversed.append(lastgaelam)
1153
+ advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)
1154
+
1155
+ returns = advantages + values
1156
+ advantages = masked_whiten(advantages, mask)
1157
+ advantages = advantages.detach()
1158
+ return values, advantages, returns
1159
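+ # The loop above is Generalized Advantage Estimation (GAE). In equation form, with
+ # delta_t = r_t + gamma * V(s_{t+1}) - V(s_t), it computes
+ #   A_t = delta_t + gamma * lam * A_{t+1}
+ # walking backwards over the generated tokens, and returns_t = A_t + V(s_t).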
+
1160
+ def loss(
1161
+ self,
1162
+ old_logprobs: torch.FloatTensor,
1163
+ values: torch.FloatTensor,
1164
+ logits: torch.FloatTensor,
1165
+ vpreds: torch.FloatTensor,
1166
+ logprobs: torch.FloatTensor,
1167
+ mask: torch.LongTensor,
1168
+ advantages: torch.FloatTensor,
1169
+ returns: torch.FloatTensor,
1170
+ ):
1171
+ """
1172
+ Calculate policy and value losses.
1173
+
1174
+ Args:
1175
+ old_logprobs (`torch.FloatTensor`):
1176
+ Log probabilities of the model at rollout time, shape (`batch_size`, `response_length`)
1177
+ values (`torch.FloatTensor`):
1178
+ Values of the value head at rollout time, shape (`batch_size`, `response_length`)
1179
+ logits (`torch.FloatTensor`):
1180
+ Current logits of the model, shape (`batch_size`, `response_length`, `vocab_size`)
1181
+ vpreds (`torch.FloatTensor`):
1182
+ Current values of the value head, shape (`batch_size`, `response_length`)
1183
+ logprobs (`torch.FloatTensor`):
1184
+ Current log probabilities of the model, shape (`batch_size`, `response_length`)
1185
+ mask (`torch.LongTensor`), advantages (`torch.FloatTensor`), returns (`torch.FloatTensor`):
1186
+ Response mask, advantages, and returns, each of shape (`batch_size`, `response_length`)
1187
+ """
1188
+
1189
+ vpredclipped = clip_by_value(
1190
+ vpreds,
1191
+ values - self.config.cliprange_value,
1192
+ values + self.config.cliprange_value,
1193
+ )
1194
+
1195
+ vf_losses1 = (vpreds - returns) ** 2
1196
+ vf_losses2 = (vpredclipped - returns) ** 2
1197
+ vf_loss = 0.5 * masked_mean(torch.max(vf_losses1, vf_losses2), mask)
1198
+ vf_clipfrac = masked_mean(torch.gt(vf_losses2, vf_losses1).float(), mask)
1199
+
1200
+ ratio = torch.exp(logprobs - old_logprobs)
1201
+
1202
+ pg_losses = -advantages * ratio
1203
+ pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - self.config.cliprange, 1.0 + self.config.cliprange)
1204
+
1205
+ pg_loss = masked_mean(torch.max(pg_losses, pg_losses2), mask)
1206
+ pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).float(), mask)
1207
+
1208
+ loss = pg_loss + self.config.vf_coef * vf_loss
1209
+
1210
+ avg_ratio = masked_mean(ratio, mask).item()
1211
+ if avg_ratio > self.config.ratio_threshold:
1212
+ warnings.warn(
1213
+ f"The average ratio of batch ({avg_ratio:.2f}) exceeds threshold {self.config.ratio_threshold:.2f}. Skipping batch."
1214
+ )
1215
+ pg_loss = pg_loss * 0.0
1216
+ vf_loss = vf_loss * 0.0
1217
+ loss = loss * 0.0
1218
+
1219
+ entropy = masked_mean(entropy_from_logits(logits), mask)
1220
+
1221
+ approxkl = 0.5 * masked_mean((logprobs - old_logprobs) ** 2, mask)
1222
+ policykl = masked_mean(old_logprobs - logprobs, mask)
1223
+
1224
+ return_mean, return_var = masked_mean(returns, mask), masked_var(returns, mask)
1225
+ value_mean, value_var = masked_mean(values, mask), masked_var(values, mask)
1226
+
1227
+ stats = dict(
1228
+ loss=dict(policy=pg_loss.detach(), value=vf_loss.detach(), total=loss.detach()),
1229
+ policy=dict(
1230
+ entropy=entropy.detach(),
1231
+ approxkl=approxkl.detach(),
1232
+ policykl=policykl.detach(),
1233
+ clipfrac=pg_clipfrac.detach(),
1234
+ advantages=advantages.detach(),
1235
+ advantages_mean=masked_mean(advantages, mask).detach(),
1236
+ ratio=ratio.detach(),
1237
+ ),
1238
+ returns=dict(mean=return_mean.detach(), var=return_var.detach()),
1239
+ val=dict(
1240
+ vpred=masked_mean(vpreds, mask).detach(),
1241
+ error=masked_mean((vpreds - returns) ** 2, mask).detach(),
1242
+ clipfrac=vf_clipfrac.detach(),
1243
+ mean=value_mean.detach(),
1244
+ var=value_var.detach(),
1245
+ ),
1246
+ )
1247
+ return pg_loss, self.config.vf_coef * vf_loss, flatten_dict(stats)
1248
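+ # The policy term above is the standard PPO clipped surrogate objective:
+ #   L_policy = E[ max( -A * ratio, -A * clip(ratio, 1 - eps, 1 + eps) ) ]
+ # with ratio = exp(logprobs - old_logprobs) and eps = config.cliprange; the value term is an
+ # analogous clipped squared error weighted by config.vf_coef.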
+
1249
+ def record_step_stats(self, kl_coef: float, **data):
1250
+ """
1251
+ Record training step statistics.
1252
+
1253
+
1254
+ Args:
1255
+ kl_coef (`float`):
1256
+ KL coefficient
1257
+ data (`dict`):
1258
+ Dictionary of training step data
1259
+
1260
+ Returns:
1261
+ stats (`dict`):
1262
+ Dictionary of training step statistics
1263
+ """
1264
+ mask = data.pop("masks")
1265
+
1266
+ kls = data.pop("kls")
1267
+ kl_list = (kls * mask).sum(axis=-1)
1268
+ mean_kl = kl_list.mean()
1269
+ mean_entropy = (-data["logprobs"] * mask).sum(axis=-1).mean()
1270
+
1271
+ mean_non_score_reward = masked_mean(
1272
+ data["non_score_reward"], mask
1273
+ ) # non_score_reward is size `batch_size`, `response_length`
1274
+ mean_scores = data["scores"].mean() # scores is size `batch_size`
1275
+ std_scores = data["scores"].std()
1276
+
1277
+ if mean_kl.item() < -1.0:
1278
+ # warn users
1279
+ warnings.warn(
1280
+ f"KL divergence is starting to become negative: {mean_kl.item():.2f} - this might be a precursor for failed training."
1281
+ " sometimes this happens because the generation kwargs are not correctly set. Please make sure"
1282
+ " that the generation kwargs are set correctly, or review your training hyperparameters."
1283
+ )
1284
+
1285
+ stats = {
1286
+ "objective/kl": mean_kl,
1287
+ "objective/kl_dist": kl_list,
1288
+ "objective/logprobs": data["logprobs"],
1289
+ "objective/ref_logprobs": data["ref_logprobs"],
1290
+ "objective/kl_coef": kl_coef,
1291
+ "objective/entropy": mean_entropy,
1292
+ "ppo/mean_non_score_reward": mean_non_score_reward,
1293
+ "ppo/mean_scores": mean_scores,
1294
+ "ppo/std_scores": std_scores,
1295
+ }
1296
+
1297
+ # Log text properties
1298
+ query_lens = torch.tensor([len(query) for query in data["queries"]], dtype=torch.float)
1299
+ response_lens = torch.tensor([len(response) for response in data["responses"]], dtype=torch.float)
1300
+
1301
+ stats["tokens/queries_len_mean"] = torch.mean(query_lens).cpu().numpy().item()
1302
+ stats["tokens/queries_len_std"] = torch.std(query_lens).cpu().numpy().item()
1303
+ stats["tokens/queries_dist"] = query_lens.cpu().numpy()
1304
+ stats["tokens/responses_len_mean"] = torch.mean(response_lens).cpu().numpy().item()
1305
+ stats["tokens/responses_len_std"] = torch.std(response_lens).cpu().numpy().item()
1306
+ stats["tokens/responses_dist"] = response_lens.cpu().numpy()
1307
+
1308
+ for k, v in data["train_stats"].items():
1309
+ stats[f"ppo/{k}"] = torch.mean(v, axis=0)
1310
+ stats["ppo/val/var_explained"] = 1 - stats["ppo/val/error"] / stats["ppo/returns/var"]
1311
+ return stats
1312
+
1313
+ def log_stats(
1314
+ self,
1315
+ stats: dict,
1316
+ batch: dict,
1317
+ rewards: List[torch.FloatTensor],
1318
+ columns_to_log: typing.Iterable[str] = ("query", "response"),
1319
+ ):
1320
+ """
1321
+ A function that logs all the training stats. Call it at the end of each epoch.
1322
+
1323
+ Args:
1324
+ stats (dict[str, Any]):
1325
+ A dictionary of training stats.
1326
+ batch (dict[str, Any]):
1327
+ A dictionary of batch data, this contains the queries and responses.
1328
+ rewards (`List[torch.FloatTensor]`):
1329
+ A tensor of rewards.
1330
+ """
1331
+
1332
+ # all gather stats
1333
+ if not isinstance(rewards, torch.Tensor):
1334
+ rewards = torch.tensor(rewards).to(self.current_device)
1335
+ rewards = self.accelerator.gather(rewards).flatten()
1336
+
1337
+ if self.config.log_with == "wandb":
1338
+ import wandb
1339
+
1340
+ if any(column_to_log not in batch.keys() for column_to_log in columns_to_log):
1341
+ raise ValueError(f"Columns to log {columns_to_log} are not present in the batch {batch.keys()}.")
1342
+
1343
+ batch_list = [batch[column_to_log] for column_to_log in columns_to_log]
1344
+ if self.is_distributed:
1345
+ gathered_batch_list = []
1346
+ for b in batch_list:
1347
+ flattened = gather_object(b)
1348
+ gathered_batch_list.append(flattened)
1349
+ batch_list = gathered_batch_list
1350
+
1351
+ # Log only if we are in the main process
1352
+ if self.accelerator.is_main_process:
1353
+ logs = {}
1354
+
1355
+ # Log stats
1356
+ if "query" not in batch.keys() and "response" not in batch.keys():
1357
+ # warn the user that the game logs will not be logged
1358
+ warnings.warn(
1359
+ "The game logs will not be logged because the batch does not contain the keys 'query' and "
1360
+ "'response'. "
1361
+ )
1362
+ elif self.config.log_with == "wandb":
1363
+ table_rows = [list(r) for r in zip(*batch_list, rewards.cpu().tolist())]
1364
+ logs.update({"game_log": wandb.Table(columns=[*columns_to_log, "reward"], rows=table_rows)})
1365
+
1366
+ logs.update(stats)
1367
+
1368
+ # manually cast in fp32 for bf16 torch tensors
1369
+ for k, v in logs.items():
1370
+ if isinstance(v, torch.Tensor) and v.dtype == torch.bfloat16:
1371
+ logs[k] = v.float()
1372
+
1373
+ logs["env/reward_mean"] = torch.mean(rewards).cpu().numpy().item()
1374
+ logs["env/reward_std"] = torch.std(rewards).cpu().numpy().item()
1375
+ logs["env/reward_dist"] = rewards.cpu().numpy()
1376
+
1377
+ if self.config.log_with == "tensorboard":
1378
+ # update the current step
1379
+ self.current_step += 1
1380
+
1381
+ self.accelerator.log(
1382
+ logs,
1383
+ step=self.current_step if self.config.log_with == "tensorboard" else None,
1384
+ )
1385
+
1386
+ def create_model_card(self, path: str, model_name: Optional[str] = "TRL Model") -> None:
1387
+ """Creates and saves a model card for a TRL model.
1388
+
1389
+ Args:
1390
+ path (`str`): The path to save the model card to.
1391
+ model_name (`str`, *optional*): The name of the model, defaults to `TRL Model`.
1392
+ """
1393
+ try:
1394
+ user = whoami()["name"]
1395
+ # handle the offline case
1396
+ except Exception:
1397
+ warnings.warn("Cannot retrieve user information assuming you are running in offline mode.")
1398
+ return
1399
+
1400
+ if not os.path.exists(path):
1401
+ os.makedirs(path)
1402
+
1403
+ model_card_content = MODEL_CARD_TEMPLATE.format(model_name=model_name, model_id=f"{user}/{path}")
1404
+ with open(os.path.join(path, "README.md"), "w", encoding="utf-8") as f:
1405
+ f.write(model_card_content)
1406
+
1407
+ def _save_pretrained(self, save_directory: str) -> None:
1408
+ self.accelerator.unwrap_model(self.model).save_pretrained(save_directory)
1409
+ self.tokenizer.save_pretrained(save_directory)
1410
+ self.create_model_card(save_directory)
1411
+
1412
+ def _show_tokens(self, tokens, masks):
1413
+ from rich import print
1414
+ from rich.text import Text
1415
+
1416
+ text = Text()
1417
+
1418
+ for _i, (token, mask) in enumerate(zip(tokens, masks)):
1419
+ if mask == 1:
1420
+ text.append(self.tokenizer.decode(token.item()), style="black on deep_sky_blue1")
1421
+ text.append(" ")
1422
+ else:
1423
+ text.append(self.tokenizer.decode(token.item()), style="black on cyan3")
1424
+ text.append(" ")
1425
+ print(text)
1426
+
1427
+ def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
1428
+ # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
1429
+ deepspeed_plugin = self.accelerator.state.deepspeed_plugin
1430
+ config_kwargs = deepspeed_plugin.deepspeed_config
1431
+ if model is not None:
1432
+ if hasattr(model, "config"):
1433
+ hidden_size = (
1434
+ max(model.config.hidden_sizes)
1435
+ if getattr(model.config, "hidden_sizes", None)
1436
+ else getattr(model.config, "hidden_size", None)
1437
+ )
1438
+ if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
1439
+ # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
1440
+ # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
1441
+ config_kwargs.update(
1442
+ {
1443
+ "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
1444
+ "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
1445
+ "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
1446
+ }
1447
+ )
1448
+
1449
+ # If ZeRO-3 is used, we shard both the active and reference model.
1450
+ # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
1451
+ if config_kwargs["zero_optimization"]["stage"] != 3:
1452
+ config_kwargs["zero_optimization"]["stage"] = 0
1453
+ model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
1454
+ model.eval()
1455
+ return model