postitive666 committed
Commit 3c7b14a
1 Parent(s): a8f2b16

orpo chinese phi3 4K

README.md CHANGED
@@ -1,3 +1,62 @@
  ---
- license: mit
+ license: other
+ base_model: /data/user/chengrui/project/mergekit/Phi-3-mini-128k-instruct
+ tags:
+ - llama-factory
+ - full
+ - generated_from_trainer
+ model-index:
+ - name: phi3-chinese-orpo
+   results: []
  ---
+
+ <!-- This model card has been generated automatically according to the information the Trainer had access to. You
+ should probably proofread and complete it, then remove this comment. -->
+
+ # phi3-chinese-orpo
+
+ This model is a fine-tuned version of [/data/user/chengrui/project/mergekit/Phi-3-mini-128k-instruct](https://huggingface.co//data/user/chengrui/project/mergekit/Phi-3-mini-128k-instruct) on the dpo_mix_en and the dpo_mix_zh datasets.
+
+ ## Model description
+
+ More information needed
+
+ ## Intended uses & limitations
+
+ More information needed
+
+ ## Training and evaluation data
+
+ More information needed
+
+ ## Training procedure
+
+ ### Training hyperparameters
+
+ The following hyperparameters were used during training:
+ - learning_rate: 5e-06
+ - train_batch_size: 1
+ - eval_batch_size: 1
+ - seed: 42
+ - distributed_type: multi-GPU
+ - num_devices: 6
+ - gradient_accumulation_steps: 8
+ - total_train_batch_size: 48
+ - total_eval_batch_size: 6
+ - optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
+ - lr_scheduler_type: cosine
+ - lr_scheduler_warmup_ratio: 0.1
+ - lr_scheduler_warmup_steps: 20
+ - num_epochs: 3.0
+ - mixed_precision_training: Native AMP
+
+ ### Training results
+
+
+
+ ### Framework versions
+
+ - Transformers 4.40.0
+ - Pytorch 2.1.0+cu121
+ - Datasets 2.15.0
+ - Tokenizers 0.19.1
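A minimal usage sketch for the card above. The Hub repo id below is a placeholder (the card only records a local base-model path), and `trust_remote_code=True` is assumed because this repository ships its own `configuration_phi3.py` / `modeling_phi3.py`. Note that the effective batch size of 48 in the hyperparameters is simply 1 (per-device) × 8 (gradient accumulation) × 6 (GPUs).

```python
# Hedged sketch: "your-namespace/phi3-chinese-orpo" is a placeholder repo id, not a
# confirmed path. trust_remote_code=True is assumed because this repository ships
# its own configuration_phi3.py / modeling_phi3.py.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "your-namespace/phi3-chinese-orpo"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype=torch.float16,  # matches "torch_dtype": "float16" in config.json
    trust_remote_code=True,
    device_map="auto",
)

# Assumes the tokenizer ships a chat template for the Phi-3 <|user|> ... <|end|><|assistant|> format.
messages = [{"role": "user", "content": "用中文介绍一下你自己。"}]
input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
outputs = model.generate(input_ids, max_new_tokens=256)
print(tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True))
```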
added_tokens.json ADDED
@@ -0,0 +1,40 @@
+ {
+   "<|/code|>": 32014,
+   "<|/data|>": 32033,
+   "<|/inst|>": 32037,
+   "<|/query|>": 32031,
+   "<|/sys|>": 32035,
+   "<|assistant_mask|>": 32017,
+   "<|assistant|>": 32001,
+   "<|calc|>": 32012,
+   "<|code|>": 32013,
+   "<|continue|>": 32009,
+   "<|data|>": 32032,
+   "<|diff_marker|>": 32025,
+   "<|disc_sep|>": 32029,
+   "<|disc_start|>": 32028,
+   "<|disc_thread|><|query|>": 32030,
+   "<|endoftext|>": 32000,
+   "<|end|>": 32007,
+   "<|fim_middle|>": 32021,
+   "<|fim_prefix|>": 32020,
+   "<|fim_suffix|>": 32022,
+   "<|function_call|>": 32005,
+   "<|function_list|>": 32011,
+   "<|function_output|>": 32003,
+   "<|ghissue|>": 32026,
+   "<|ghreview|>": 32027,
+   "<|inst|>": 32036,
+   "<|ipynb_marker|>": 32024,
+   "<|message|>": 32019,
+   "<|meta_start|>": 32023,
+   "<|raw|>": 32008,
+   "<|resource|>": 32016,
+   "<|start|>": 32018,
+   "<|step|>": 32002,
+   "<|summary|>": 32015,
+   "<|system|>": 32006,
+   "<|sys|>": 32034,
+   "<|tag|>": 32004,
+   "<|user|>": 32010
+ }
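The map above extends the base 32,000-token vocabulary with Phi-3's special tokens (ids 32000-32037), leaving headroom up to the configured `vocab_size` of 32064. A quick, hedged sanity check (placeholder repo id) that the shipped tokenizer resolves a few of them:

```python
# Hedged sketch (placeholder repo id): the special tokens listed above should map
# to the same ids in the shipped tokenizer.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("your-namespace/phi3-chinese-orpo", trust_remote_code=True)
for token, expected_id in [("<|user|>", 32010), ("<|assistant|>", 32001), ("<|end|>", 32007)]:
    assert tokenizer.convert_tokens_to_ids(token) == expected_id, token
```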
all_results.json ADDED
@@ -0,0 +1,8 @@
+ {
+   "epoch": 2.994601079784043,
+   "total_flos": 132590267662336.0,
+   "train_loss": 0.7937506708579186,
+   "train_runtime": 49781.9259,
+   "train_samples_per_second": 1.205,
+   "train_steps_per_second": 0.025
+ }
config.json ADDED
@@ -0,0 +1,138 @@
+ {
+   "_name_or_path": "/data/user/chengrui/project/mergekit/Phi-3-mini-128k-instruct",
+   "architectures": [
+     "Phi3ForCausalLM"
+   ],
+   "attention_dropout": 0.0,
+   "auto_map": {
+     "AutoConfig": "configuration_phi3.Phi3Config",
+     "AutoModel": "modeling_phi3.Phi3ForCausalLM",
+     "AutoModelForCausalLM": "modeling_phi3.Phi3ForCausalLM"
+   },
+   "bos_token_id": 1,
+   "embd_pdrop": 0.0,
+   "eos_token_id": 32000,
+   "hidden_act": "silu",
+   "hidden_size": 3072,
+   "initializer_range": 0.02,
+   "intermediate_size": 8192,
+   "max_position_embeddings": 131072,
+   "model_type": "phi3",
+   "num_attention_heads": 32,
+   "num_hidden_layers": 32,
+   "num_key_value_heads": 32,
+   "original_max_position_embeddings": 4096,
+   "pad_token_id": 32000,
+   "resid_pdrop": 0.0,
+   "rms_norm_eps": 1e-05,
+   "rope_scaling": {
+     "long_factor": [
+       1.0299999713897705,
+       1.0499999523162842,
+       1.0499999523162842,
+       1.0799999237060547,
+       1.2299998998641968,
+       1.2299998998641968,
+       1.2999999523162842,
+       1.4499999284744263,
+       1.5999999046325684,
+       1.6499998569488525,
+       1.8999998569488525,
+       2.859999895095825,
+       3.68999981880188,
+       5.419999599456787,
+       5.489999771118164,
+       5.489999771118164,
+       9.09000015258789,
+       11.579999923706055,
+       15.65999984741211,
+       15.769999504089355,
+       15.789999961853027,
+       18.360000610351562,
+       21.989999771118164,
+       23.079999923706055,
+       30.009998321533203,
+       32.35000228881836,
+       32.590003967285156,
+       35.56000518798828,
+       39.95000457763672,
+       53.840003967285156,
+       56.20000457763672,
+       57.95000457763672,
+       59.29000473022461,
+       59.77000427246094,
+       59.920005798339844,
+       61.190006256103516,
+       61.96000671386719,
+       62.50000762939453,
+       63.3700065612793,
+       63.48000717163086,
+       63.48000717163086,
+       63.66000747680664,
+       63.850006103515625,
+       64.08000946044922,
+       64.760009765625,
+       64.80001068115234,
+       64.81001281738281,
+       64.81001281738281
+     ],
+     "short_factor": [
+       1.05,
+       1.05,
+       1.05,
+       1.1,
+       1.1,
+       1.1500000000000001,
+       1.2000000000000002,
+       1.2500000000000002,
+       1.3000000000000003,
+       1.3500000000000003,
+       1.5000000000000004,
+       2.000000000000001,
+       2.000000000000001,
+       2.000000000000001,
+       2.000000000000001,
+       2.000000000000001,
+       2.000000000000001,
+       2.000000000000001,
+       2.000000000000001,
+       2.000000000000001,
+       2.000000000000001,
+       2.000000000000001,
+       2.000000000000001,
+       2.000000000000001,
+       2.000000000000001,
+       2.000000000000001,
+       2.000000000000001,
+       2.000000000000001,
+       2.000000000000001,
+       2.000000000000001,
+       2.000000000000001,
+       2.000000000000001,
+       2.000000000000001,
+       2.0500000000000007,
+       2.0500000000000007,
+       2.0500000000000007,
+       2.1000000000000005,
+       2.1000000000000005,
+       2.1000000000000005,
+       2.1500000000000004,
+       2.1500000000000004,
+       2.3499999999999996,
+       2.549999999999999,
+       2.5999999999999988,
+       2.5999999999999988,
+       2.7499999999999982,
+       2.849999999999998,
+       2.849999999999998,
+       2.9499999999999975
+     ],
+     "type": "su"
+   },
+   "rope_theta": 10000.0,
+   "sliding_window": 262144,
+   "tie_word_embeddings": false,
+   "torch_dtype": "float16",
+   "transformers_version": "4.40.0",
+   "use_cache": false,
+   "vocab_size": 32064
+ }
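The `rope_scaling` block above carries one `short_factor` and one `long_factor` entry per rotary frequency, i.e. `hidden_size // num_attention_heads // 2 = 3072 // 32 // 2 = 48` values each, which is exactly the length that `configuration_phi3.py` below validates. A small sketch that re-checks this from the shipped file:

```python
# Hedged sketch: checks the rope_scaling factor lengths against the rule enforced
# in configuration_phi3.py (length == hidden_size // num_attention_heads // 2).
import json

with open("config.json") as f:
    cfg = json.load(f)

half_head_dim = cfg["hidden_size"] // cfg["num_attention_heads"] // 2  # 3072 // 32 // 2 == 48
assert len(cfg["rope_scaling"]["short_factor"]) == half_head_dim
assert len(cfg["rope_scaling"]["long_factor"]) == half_head_dim
```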
configuration_phi3.py ADDED
@@ -0,0 +1,213 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """ Phi-3 model configuration"""
17
+
18
+
19
+ from transformers.configuration_utils import PretrainedConfig
20
+ from transformers.utils import logging
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+ PHI3_PRETRAINED_CONFIG_ARCHIVE_MAP = {
26
+ "microsoft/Phi-3-mini-4k-instruct": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/resolve/main/config.json",
27
+ "microsoft/Phi-3-mini-128k-instruct": "https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/resolve/main/config.json",
28
+ }
29
+
30
+
31
+ class Phi3Config(PretrainedConfig):
32
+ r"""
33
+ This is the configuration class to store the configuration of a [`Phi3Model`]. It is used to instantiate a Phi-3
34
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
35
+ defaults will yield a similar configuration to that of the
36
+ [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct).
37
+
38
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
39
+ documentation from [`PretrainedConfig`] for more information.
40
+
41
+ Args:
42
+ vocab_size (`int`, *optional*, defaults to 32064):
43
+ Vocabulary size of the Phi-3 model. Defines the number of different tokens that can be represented by the
44
+ `inputs_ids` passed when calling [`Phi3Model`].
45
+ hidden_size (`int`, *optional*, defaults to 3072):
46
+ Dimension of the hidden representations.
47
+ intermediate_size (`int`, *optional*, defaults to 8192):
48
+ Dimension of the MLP representations.
49
+ num_hidden_layers (`int`, *optional*, defaults to 32):
50
+ Number of hidden layers in the Transformer decoder.
51
+ num_attention_heads (`int`, *optional*, defaults to 32):
52
+ Number of attention heads for each attention layer in the Transformer decoder.
53
+ num_key_value_heads (`int`, *optional*):
54
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
55
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
56
+ `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
57
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
58
+ by meanpooling all the original heads within that group. For more details, check out [this
59
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
60
+ `num_attention_heads`.
61
+ resid_pdrop (`float`, *optional*, defaults to 0.0):
62
+ Dropout probability for mlp outputs.
63
+ embd_pdrop (`int`, *optional*, defaults to 0.0):
64
+ The dropout ratio for the embeddings.
65
+ attention_dropout (`float`, *optional*, defaults to 0.0):
66
+ The dropout ratio after computing the attention scores.
67
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
68
+ The non-linear activation function (function or string) in the decoder.
69
+ max_position_embeddings (`int`, *optional*, defaults to 4096):
70
+ The maximum sequence length that this model might ever be used with.
71
+ original_max_position_embeddings (`int`, *optional*, defaults to 4096):
72
+ The maximum sequence length that this model was trained with. This is used to determine the size of the
73
+ original RoPE embeddings when using long scaling.
74
+ initializer_range (`float`, *optional*, defaults to 0.02):
75
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
76
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
77
+ The epsilon value used for the RMSNorm.
78
+ use_cache (`bool`, *optional*, defaults to `True`):
79
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
80
+ relevant if `config.is_decoder=True`.
81
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
82
+ Whether to tie weight embeddings
83
+ rope_theta (`float`, *optional*, defaults to 10000.0):
84
+ The base period of the RoPE embeddings.
85
+ rope_scaling (`dict`, *optional*):
86
+ The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must
87
+ contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be either `su` or `yarn` and
88
+ the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size
89
+ divided by the number of attention heads divided by 2.
90
+ bos_token_id (`int`, *optional*, defaults to 1):
91
+ The id of the "beginning-of-sequence" token.
92
+ eos_token_id (`int`, *optional*, defaults to 32000):
93
+ The id of the "end-of-sequence" token.
94
+ pad_token_id (`int`, *optional*, defaults to 32000):
95
+ The id of the padding token.
96
+ sliding_window (`int`, *optional*):
97
+ Sliding window attention window size. If `None`, no sliding window is applied.
98
+
99
+ Example:
100
+
101
+ ```python
102
+ >>> from transformers import Phi3Model, Phi3Config
103
+
104
+ >>> # Initializing a Phi-3 style configuration
105
+ >>> configuration = Phi3Config.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
106
+
107
+ >>> # Initializing a model from the configuration
108
+ >>> model = Phi3Model(configuration)
109
+
110
+ >>> # Accessing the model configuration
111
+ >>> configuration = model.config
112
+ ```"""
113
+
114
+ model_type = "phi3"
115
+ keys_to_ignore_at_inference = ["past_key_values"]
116
+
117
+ def __init__(
118
+ self,
119
+ vocab_size=32064,
120
+ hidden_size=3072,
121
+ intermediate_size=8192,
122
+ num_hidden_layers=32,
123
+ num_attention_heads=32,
124
+ num_key_value_heads=None,
125
+ resid_pdrop=0.0,
126
+ embd_pdrop=0.0,
127
+ attention_dropout=0.0,
128
+ hidden_act="silu",
129
+ max_position_embeddings=4096,
130
+ original_max_position_embeddings=4096,
131
+ initializer_range=0.02,
132
+ rms_norm_eps=1e-5,
133
+ use_cache=True,
134
+ tie_word_embeddings=False,
135
+ rope_theta=10000.0,
136
+ rope_scaling=None,
137
+ bos_token_id=1,
138
+ eos_token_id=32000,
139
+ pad_token_id=32000,
140
+ sliding_window=None,
141
+ **kwargs,
142
+ ):
143
+ self.vocab_size = vocab_size
144
+ self.hidden_size = hidden_size
145
+ self.intermediate_size = intermediate_size
146
+ self.num_hidden_layers = num_hidden_layers
147
+ self.num_attention_heads = num_attention_heads
148
+
149
+ if num_key_value_heads is None:
150
+ num_key_value_heads = num_attention_heads
151
+
152
+ self.num_key_value_heads = num_key_value_heads
153
+ self.resid_pdrop = resid_pdrop
154
+ self.embd_pdrop = embd_pdrop
155
+ self.attention_dropout = attention_dropout
156
+ self.hidden_act = hidden_act
157
+ self.max_position_embeddings = max_position_embeddings
158
+ self.original_max_position_embeddings = original_max_position_embeddings
159
+ self.initializer_range = initializer_range
160
+ self.rms_norm_eps = rms_norm_eps
161
+ self.use_cache = use_cache
162
+ self.rope_theta = rope_theta
163
+ self.rope_scaling = rope_scaling
164
+ self._rope_scaling_validation()
165
+ self.sliding_window = sliding_window
166
+
167
+ super().__init__(
168
+ bos_token_id=bos_token_id,
169
+ eos_token_id=eos_token_id,
170
+ pad_token_id=pad_token_id,
171
+ tie_word_embeddings=tie_word_embeddings,
172
+ **kwargs,
173
+ )
174
+
175
+ def _rope_scaling_validation(self):
176
+ """
177
+ Validate the `rope_scaling` configuration.
178
+ """
179
+ if self.rope_scaling is None:
180
+ return
181
+
182
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3:
183
+ raise ValueError(
184
+ "`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, "
185
+ f"got {self.rope_scaling}"
186
+ )
187
+ rope_scaling_type = self.rope_scaling.get("type", None)
188
+ rope_scaling_short_factor = self.rope_scaling.get("short_factor", None)
189
+ rope_scaling_long_factor = self.rope_scaling.get("long_factor", None)
190
+ if rope_scaling_type is None or rope_scaling_type not in ["su", "yarn"]:
191
+ raise ValueError(f"`rope_scaling`'s type field must be one of ['su', 'yarn'], got {rope_scaling_type}")
192
+ if not (
193
+ isinstance(rope_scaling_short_factor, list)
194
+ and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor)
195
+ ):
196
+ raise ValueError(
197
+ f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}"
198
+ )
199
+ if not len(rope_scaling_short_factor) == self.hidden_size // self.num_attention_heads // 2:
200
+ raise ValueError(
201
+ f"`rope_scaling`'s short_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_short_factor)}"
202
+ )
203
+ if not (
204
+ isinstance(rope_scaling_long_factor, list)
205
+ and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor)
206
+ ):
207
+ raise ValueError(
208
+ f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}"
209
+ )
210
+ if not len(rope_scaling_long_factor) == self.hidden_size // self.num_attention_heads // 2:
211
+ raise ValueError(
212
+ f"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}"
213
+ )
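`_rope_scaling_validation` above rejects a `rope_scaling` dict that is missing one of its three fields, uses a `type` other than `su`/`yarn`, or whose factor lists do not have length `hidden_size // num_attention_heads // 2`. A hedged sketch of both outcomes, assuming `configuration_phi3.py` is importable from the working directory:

```python
# Hedged sketch: exercises Phi3Config._rope_scaling_validation, assuming
# configuration_phi3.py is importable from the current working directory.
from configuration_phi3 import Phi3Config

half_head_dim = 3072 // 32 // 2  # 48, matching config.json

# Valid: "su" scaling with 48-entry factor lists passes validation.
Phi3Config(rope_scaling={"type": "su",
                         "short_factor": [1.0] * half_head_dim,
                         "long_factor": [1.0] * half_head_dim})

# Invalid: a wrong-length short_factor list raises ValueError.
try:
    Phi3Config(rope_scaling={"type": "su",
                             "short_factor": [1.0] * 10,
                             "long_factor": [1.0] * half_head_dim})
except ValueError as err:
    print(err)
```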
generation_config.json ADDED
@@ -0,0 +1,11 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 1,
+   "eos_token_id": [
+     32000,
+     32001,
+     32007
+   ],
+   "pad_token_id": 32000,
+   "transformers_version": "4.40.0"
+ }
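Per the file above, generation stops at whichever of the three end ids (32000 `<|endoftext|>`, 32001 `<|assistant|>`, 32007 `<|end|>`) appears first. `generation_config.json` is picked up automatically by `generate()`; continuing the loading sketch from the README section, the list can also be passed explicitly (recent `transformers` versions accept a list of eos ids):

```python
# Hedged sketch, continuing the loading example above: generation_config.json is
# applied automatically by generate(), but the eos ids can also be passed explicitly.
outputs = model.generate(
    input_ids,
    max_new_tokens=128,
    eos_token_id=[32000, 32001, 32007],  # <|endoftext|>, <|assistant|>, <|end|>
    pad_token_id=32000,
)
```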
model-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5a5de46e0a3c8c853a9fb520b403c406cb1254fb754022ed56ef38ada6613888
+ size 4972489200
model-00002-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7d42e717d30a6f95a5312fb8d8ebe283253417fac308597c8673bcb1678ac959
+ size 2669692488
model.safetensors.index.json ADDED
@@ -0,0 +1,202 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 7642159104
4
+ },
5
+ "weight_map": {
6
+ "lm_head.weight": "model-00002-of-00002.safetensors",
7
+ "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
8
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
9
+ "model.layers.0.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
10
+ "model.layers.0.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
11
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
12
+ "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
13
+ "model.layers.0.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
14
+ "model.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
15
+ "model.layers.1.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
16
+ "model.layers.1.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
17
+ "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
18
+ "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
19
+ "model.layers.1.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
20
+ "model.layers.10.input_layernorm.weight": "model-00001-of-00002.safetensors",
21
+ "model.layers.10.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
22
+ "model.layers.10.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
23
+ "model.layers.10.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
24
+ "model.layers.10.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
25
+ "model.layers.10.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
26
+ "model.layers.11.input_layernorm.weight": "model-00001-of-00002.safetensors",
27
+ "model.layers.11.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
28
+ "model.layers.11.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
29
+ "model.layers.11.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
30
+ "model.layers.11.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
31
+ "model.layers.11.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
32
+ "model.layers.12.input_layernorm.weight": "model-00001-of-00002.safetensors",
33
+ "model.layers.12.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
34
+ "model.layers.12.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
35
+ "model.layers.12.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
36
+ "model.layers.12.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
37
+ "model.layers.12.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
38
+ "model.layers.13.input_layernorm.weight": "model-00001-of-00002.safetensors",
39
+ "model.layers.13.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
40
+ "model.layers.13.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
41
+ "model.layers.13.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
42
+ "model.layers.13.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
43
+ "model.layers.13.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
44
+ "model.layers.14.input_layernorm.weight": "model-00001-of-00002.safetensors",
45
+ "model.layers.14.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
46
+ "model.layers.14.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
47
+ "model.layers.14.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
48
+ "model.layers.14.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
49
+ "model.layers.14.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
50
+ "model.layers.15.input_layernorm.weight": "model-00001-of-00002.safetensors",
51
+ "model.layers.15.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
52
+ "model.layers.15.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
53
+ "model.layers.15.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
54
+ "model.layers.15.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
55
+ "model.layers.15.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
56
+ "model.layers.16.input_layernorm.weight": "model-00001-of-00002.safetensors",
57
+ "model.layers.16.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
58
+ "model.layers.16.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
59
+ "model.layers.16.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
60
+ "model.layers.16.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
61
+ "model.layers.16.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
62
+ "model.layers.17.input_layernorm.weight": "model-00001-of-00002.safetensors",
63
+ "model.layers.17.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
64
+ "model.layers.17.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
65
+ "model.layers.17.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
66
+ "model.layers.17.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
67
+ "model.layers.17.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
68
+ "model.layers.18.input_layernorm.weight": "model-00001-of-00002.safetensors",
69
+ "model.layers.18.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
70
+ "model.layers.18.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
71
+ "model.layers.18.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
72
+ "model.layers.18.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
73
+ "model.layers.18.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
74
+ "model.layers.19.input_layernorm.weight": "model-00001-of-00002.safetensors",
75
+ "model.layers.19.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
76
+ "model.layers.19.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
77
+ "model.layers.19.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
78
+ "model.layers.19.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
79
+ "model.layers.19.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
80
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
81
+ "model.layers.2.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
82
+ "model.layers.2.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
83
+ "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
84
+ "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
85
+ "model.layers.2.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
86
+ "model.layers.20.input_layernorm.weight": "model-00001-of-00002.safetensors",
87
+ "model.layers.20.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
88
+ "model.layers.20.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
89
+ "model.layers.20.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
90
+ "model.layers.20.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
91
+ "model.layers.20.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
92
+ "model.layers.21.input_layernorm.weight": "model-00002-of-00002.safetensors",
93
+ "model.layers.21.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
94
+ "model.layers.21.mlp.gate_up_proj.weight": "model-00002-of-00002.safetensors",
95
+ "model.layers.21.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
96
+ "model.layers.21.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
97
+ "model.layers.21.self_attn.qkv_proj.weight": "model-00002-of-00002.safetensors",
98
+ "model.layers.22.input_layernorm.weight": "model-00002-of-00002.safetensors",
99
+ "model.layers.22.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
100
+ "model.layers.22.mlp.gate_up_proj.weight": "model-00002-of-00002.safetensors",
101
+ "model.layers.22.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
102
+ "model.layers.22.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
103
+ "model.layers.22.self_attn.qkv_proj.weight": "model-00002-of-00002.safetensors",
104
+ "model.layers.23.input_layernorm.weight": "model-00002-of-00002.safetensors",
105
+ "model.layers.23.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
106
+ "model.layers.23.mlp.gate_up_proj.weight": "model-00002-of-00002.safetensors",
107
+ "model.layers.23.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
108
+ "model.layers.23.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
109
+ "model.layers.23.self_attn.qkv_proj.weight": "model-00002-of-00002.safetensors",
110
+ "model.layers.24.input_layernorm.weight": "model-00002-of-00002.safetensors",
111
+ "model.layers.24.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
112
+ "model.layers.24.mlp.gate_up_proj.weight": "model-00002-of-00002.safetensors",
113
+ "model.layers.24.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
114
+ "model.layers.24.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
115
+ "model.layers.24.self_attn.qkv_proj.weight": "model-00002-of-00002.safetensors",
116
+ "model.layers.25.input_layernorm.weight": "model-00002-of-00002.safetensors",
117
+ "model.layers.25.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
118
+ "model.layers.25.mlp.gate_up_proj.weight": "model-00002-of-00002.safetensors",
119
+ "model.layers.25.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
120
+ "model.layers.25.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
121
+ "model.layers.25.self_attn.qkv_proj.weight": "model-00002-of-00002.safetensors",
122
+ "model.layers.26.input_layernorm.weight": "model-00002-of-00002.safetensors",
123
+ "model.layers.26.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
124
+ "model.layers.26.mlp.gate_up_proj.weight": "model-00002-of-00002.safetensors",
125
+ "model.layers.26.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
126
+ "model.layers.26.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
127
+ "model.layers.26.self_attn.qkv_proj.weight": "model-00002-of-00002.safetensors",
128
+ "model.layers.27.input_layernorm.weight": "model-00002-of-00002.safetensors",
129
+ "model.layers.27.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
130
+ "model.layers.27.mlp.gate_up_proj.weight": "model-00002-of-00002.safetensors",
131
+ "model.layers.27.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
132
+ "model.layers.27.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
133
+ "model.layers.27.self_attn.qkv_proj.weight": "model-00002-of-00002.safetensors",
134
+ "model.layers.28.input_layernorm.weight": "model-00002-of-00002.safetensors",
135
+ "model.layers.28.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
136
+ "model.layers.28.mlp.gate_up_proj.weight": "model-00002-of-00002.safetensors",
137
+ "model.layers.28.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
138
+ "model.layers.28.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
139
+ "model.layers.28.self_attn.qkv_proj.weight": "model-00002-of-00002.safetensors",
140
+ "model.layers.29.input_layernorm.weight": "model-00002-of-00002.safetensors",
141
+ "model.layers.29.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
142
+ "model.layers.29.mlp.gate_up_proj.weight": "model-00002-of-00002.safetensors",
143
+ "model.layers.29.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
144
+ "model.layers.29.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
145
+ "model.layers.29.self_attn.qkv_proj.weight": "model-00002-of-00002.safetensors",
146
+ "model.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
147
+ "model.layers.3.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
148
+ "model.layers.3.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
149
+ "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
150
+ "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
151
+ "model.layers.3.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
152
+ "model.layers.30.input_layernorm.weight": "model-00002-of-00002.safetensors",
153
+ "model.layers.30.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
154
+ "model.layers.30.mlp.gate_up_proj.weight": "model-00002-of-00002.safetensors",
155
+ "model.layers.30.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
156
+ "model.layers.30.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
157
+ "model.layers.30.self_attn.qkv_proj.weight": "model-00002-of-00002.safetensors",
158
+ "model.layers.31.input_layernorm.weight": "model-00002-of-00002.safetensors",
159
+ "model.layers.31.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
160
+ "model.layers.31.mlp.gate_up_proj.weight": "model-00002-of-00002.safetensors",
161
+ "model.layers.31.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
162
+ "model.layers.31.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
163
+ "model.layers.31.self_attn.qkv_proj.weight": "model-00002-of-00002.safetensors",
164
+ "model.layers.4.input_layernorm.weight": "model-00001-of-00002.safetensors",
165
+ "model.layers.4.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
166
+ "model.layers.4.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
167
+ "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
168
+ "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
169
+ "model.layers.4.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
170
+ "model.layers.5.input_layernorm.weight": "model-00001-of-00002.safetensors",
171
+ "model.layers.5.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
172
+ "model.layers.5.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
173
+ "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
174
+ "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
175
+ "model.layers.5.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
176
+ "model.layers.6.input_layernorm.weight": "model-00001-of-00002.safetensors",
177
+ "model.layers.6.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
178
+ "model.layers.6.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
179
+ "model.layers.6.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
180
+ "model.layers.6.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
181
+ "model.layers.6.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
182
+ "model.layers.7.input_layernorm.weight": "model-00001-of-00002.safetensors",
183
+ "model.layers.7.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
184
+ "model.layers.7.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
185
+ "model.layers.7.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
186
+ "model.layers.7.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
187
+ "model.layers.7.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
188
+ "model.layers.8.input_layernorm.weight": "model-00001-of-00002.safetensors",
189
+ "model.layers.8.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
190
+ "model.layers.8.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
191
+ "model.layers.8.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
192
+ "model.layers.8.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
193
+ "model.layers.8.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
194
+ "model.layers.9.input_layernorm.weight": "model-00001-of-00002.safetensors",
195
+ "model.layers.9.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
196
+ "model.layers.9.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
197
+ "model.layers.9.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
198
+ "model.layers.9.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
199
+ "model.layers.9.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
200
+ "model.norm.weight": "model-00002-of-00002.safetensors"
201
+ }
202
+ }
modeling_phi3.py ADDED
@@ -0,0 +1,1606 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """ PyTorch Phi-3 model."""
17
+
18
+ import inspect
19
+ import math
20
+ import warnings
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ import torch.nn.functional as F
25
+ import torch.utils.checkpoint
26
+ from torch import nn
27
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
28
+
29
+ from transformers.activations import ACT2FN
30
+ from transformers.cache_utils import Cache, DynamicCache
31
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
32
+ from transformers.modeling_outputs import (
33
+ BaseModelOutputWithPast,
34
+ CausalLMOutputWithPast,
35
+ SequenceClassifierOutputWithPast,
36
+ TokenClassifierOutput,
37
+ )
38
+ from transformers.modeling_utils import PreTrainedModel
39
+ from transformers.utils import (
40
+ add_code_sample_docstrings,
41
+ add_start_docstrings,
42
+ add_start_docstrings_to_model_forward,
43
+ is_flash_attn_2_available,
44
+ is_flash_attn_greater_or_equal_2_10,
45
+ logging,
46
+ replace_return_docstrings,
47
+ )
48
+ from .configuration_phi3 import Phi3Config
49
+
50
+
51
+ logger = logging.get_logger(__name__)
52
+
53
+ # Transformers scans dependencies in the modeling file, causing issues on conditional loading. The regex only ignores try/catch blocks, but not if statements
54
+ # if is_flash_attn_2_available():
55
+ _flash_supports_window_size = False
56
+ try:
57
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
58
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
59
+
60
+ _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
61
+ except ImportError as error:
62
+ logger.warning(
63
+ f"`flash-attention` package not found, consider installing for better performance: {error}."
64
+ )
65
+ if not _flash_supports_window_size:
66
+ logger.warning(
67
+ "Current `flash-attenton` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`."
68
+ )
69
+
70
+ _CHECKPOINT_FOR_DOC = "microsoft/Phi-3-mini-4k-instruct"
71
+ _CONFIG_FOR_DOC = "Phi3Config"
72
+
73
+ PHI3_PRETRAINED_MODEL_ARCHIVE_LIST = [
74
+ "microsoft/Phi-3-mini-4k-instruct",
75
+ "microsoft/Phi-3-mini-128k-instruct",
76
+ # See all Phi-3 models at https://huggingface.co/models?filter=Phi-3
77
+ ]
78
+
79
+
80
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Phi3
81
+ class Phi3RMSNorm(nn.Module):
82
+ def __init__(self, hidden_size, eps=1e-6):
83
+ """
84
+ Phi3RMSNorm is equivalent to T5LayerNorm
85
+ """
86
+ super().__init__()
87
+ self.weight = nn.Parameter(torch.ones(hidden_size))
88
+ self.variance_epsilon = eps
89
+
90
+ def forward(self, hidden_states):
91
+ input_dtype = hidden_states.dtype
92
+ hidden_states = hidden_states.to(torch.float32)
93
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
94
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
95
+ return self.weight * hidden_states.to(input_dtype)
96
+
97
+
98
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
99
+ def _get_unpad_data(attention_mask):
100
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
101
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
102
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
103
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
104
+ return (
105
+ indices,
106
+ cu_seqlens,
107
+ max_seqlen_in_batch,
108
+ )
109
+
110
+
111
+ # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with gemma->phi3, Gemma->Phi3
112
+ class Phi3RotaryEmbedding(nn.Module):
113
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
114
+ super().__init__()
115
+
116
+ self.dim = dim
117
+ self.max_position_embeddings = max_position_embeddings
118
+ self.base = base
119
+ self.register_buffer("inv_freq", None, persistent=False)
120
+
121
+ @torch.no_grad()
122
+ def forward(self, x, position_ids, seq_len=None):
123
+ # x: [bs, num_attention_heads, seq_len, head_size]
124
+ if self.inv_freq is None:
125
+ self.inv_freq = 1.0 / (
126
+ self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
127
+ )
128
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
129
+ position_ids_expanded = position_ids[:, None, :].float()
130
+ # Force float32 since bfloat16 loses precision on long contexts
131
+ # See https://github.com/huggingface/transformers/pull/29285
132
+ device_type = x.device.type
133
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
134
+ with torch.autocast(device_type=device_type, enabled=False):
135
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
136
+ emb = torch.cat((freqs, freqs), dim=-1)
137
+ cos = emb.cos()
138
+ sin = emb.sin()
139
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
140
+
141
+
142
+ class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
143
+ def __init__(self, dim, config, device=None):
144
+ super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
145
+
146
+ self.short_factor = config.rope_scaling["short_factor"]
147
+ self.long_factor = config.rope_scaling["long_factor"]
148
+ self.original_max_position_embeddings = config.original_max_position_embeddings
149
+
150
+ @torch.no_grad()
151
+ def forward(self, x, position_ids, seq_len=None):
152
+ seq_len = torch.max(position_ids) + 1
153
+ if seq_len > self.original_max_position_embeddings:
154
+ ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
155
+ else:
156
+ ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
157
+
158
+ inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
159
+ self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
160
+
161
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
162
+ position_ids_expanded = position_ids[:, None, :].float()
163
+
164
+ # Force float32 since bfloat16 loses precision on long contexts
165
+ # See https://github.com/huggingface/transformers/pull/29285
166
+ device_type = x.device.type
167
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
168
+ with torch.autocast(device_type=device_type, enabled=False):
169
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
170
+ emb = torch.cat((freqs, freqs), dim=-1)
171
+
172
+ scale = self.max_position_embeddings / self.original_max_position_embeddings
173
+ if scale <= 1.0:
174
+ scaling_factor = 1.0
175
+ else:
176
+ scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
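+ # Added annotation (not in the upstream file): for this checkpoint, scale = 131072 / 4096 = 32, so scaling_factor = sqrt(1 + ln(32) / ln(4096)) ≈ 1.19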
177
+
178
+ cos = emb.cos() * scaling_factor
179
+ sin = emb.sin() * scaling_factor
180
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
181
+
182
+
183
+ class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding):
184
+ def __init__(self, dim, config, device=None):
185
+ super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
186
+
187
+ self.short_factor = config.rope_scaling["short_factor"]
188
+ self.long_factor = config.rope_scaling["long_factor"]
189
+ self.original_max_position_embeddings = config.original_max_position_embeddings
190
+
191
+ @torch.no_grad()
192
+ def forward(self, x, position_ids, seq_len=None):
193
+ seq_len = torch.max(position_ids) + 1
194
+ if seq_len > self.original_max_position_embeddings:
195
+ ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
196
+ else:
197
+ ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
198
+
199
+ inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
200
+ self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
201
+
202
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
203
+ position_ids_expanded = position_ids[:, None, :].float()
204
+
205
+ # Force float32 since bfloat16 loses precision on long contexts
206
+ # See https://github.com/huggingface/transformers/pull/29285
207
+ device_type = x.device.type
208
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
209
+ with torch.autocast(device_type=device_type, enabled=False):
210
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
211
+ emb = torch.cat((freqs, freqs), dim=-1)
212
+
213
+ scale = self.max_position_embeddings / self.original_max_position_embeddings
214
+ if scale <= 1.0:
215
+ scaling_factor = 1.0
216
+ else:
217
+ scaling_factor = 0.1 * math.log(scale) + 1.0
218
+
219
+ cos = emb.cos() * scaling_factor
220
+ sin = emb.sin() * scaling_factor
221
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
222
+
223
+
224
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
225
+ def rotate_half(x):
226
+ """Rotates half the hidden dims of the input."""
227
+ x1 = x[..., : x.shape[-1] // 2]
228
+ x2 = x[..., x.shape[-1] // 2 :]
229
+ return torch.cat((-x2, x1), dim=-1)
230
+
231
+
232
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
233
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
234
+ """Applies Rotary Position Embedding to the query and key tensors.
235
+
236
+ Args:
237
+ q (`torch.Tensor`): The query tensor.
238
+ k (`torch.Tensor`): The key tensor.
239
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
240
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
241
+ position_ids (`torch.Tensor`, *optional*):
242
+ Deprecated and unused.
243
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
244
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
245
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
246
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
247
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
248
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
249
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
250
+ Returns:
251
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
252
+ """
253
+ cos = cos.unsqueeze(unsqueeze_dim)
254
+ sin = sin.unsqueeze(unsqueeze_dim)
255
+ q_embed = (q * cos) + (rotate_half(q) * sin)
256
+ k_embed = (k * cos) + (rotate_half(k) * sin)
257
+ return q_embed, k_embed
258
+
259
+
260
+ class Phi3MLP(nn.Module):
261
+ def __init__(self, config):
262
+ super().__init__()
263
+
264
+ self.config = config
265
+ self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
266
+ self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
267
+
268
+ self.activation_fn = ACT2FN[config.hidden_act]
269
+
270
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
271
+ up_states = self.gate_up_proj(hidden_states)
272
+
273
+ gate, up_states = up_states.chunk(2, dim=-1)
274
+ up_states = up_states * self.activation_fn(gate)
275
+
276
+ return self.down_proj(up_states)
277
+
278
+
279
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
280
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
281
+ """
282
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
283
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
284
+ """
285
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
286
+ if n_rep == 1:
287
+ return hidden_states
288
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
289
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
290
+
291
+
292
+ class Phi3Attention(nn.Module):
293
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
294
+
295
+ def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None):
296
+ super().__init__()
297
+ self.config = config
298
+ self.layer_idx = layer_idx
299
+ if layer_idx is None:
300
+ logger.warning_once(
301
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
302
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
303
+ "when creating this class."
304
+ )
305
+
306
+ self.attention_dropout = config.attention_dropout
307
+ self.hidden_size = config.hidden_size
308
+ self.num_heads = config.num_attention_heads
309
+ self.head_dim = self.hidden_size // self.num_heads
310
+ self.num_key_value_heads = config.num_key_value_heads
311
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
312
+ self.max_position_embeddings = config.max_position_embeddings
313
+ self.original_max_position_embeddings = config.original_max_position_embeddings
314
+ self.rope_theta = config.rope_theta
315
+ self.rope_scaling = config.rope_scaling
316
+ self.is_causal = True
317
+
318
+ if (self.head_dim * self.num_heads) != self.hidden_size:
319
+ raise ValueError(
320
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
321
+ f" and `num_heads`: {self.num_heads})."
322
+ )
323
+
324
+ op_size = self.num_heads * self.head_dim + 2 * (self.num_key_value_heads * self.head_dim)
325
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
326
+ self.qkv_proj = nn.Linear(self.hidden_size, op_size, bias=False)
327
+ self._init_rope()
328
+
329
+ def _init_rope(self):
330
+ if self.rope_scaling is None:
331
+ self.rotary_emb = Phi3RotaryEmbedding(
332
+ self.head_dim,
333
+ max_position_embeddings=self.max_position_embeddings,
334
+ base=self.rope_theta,
335
+ )
336
+ else:
337
+ scaling_type = self.config.rope_scaling["type"]
338
+ if scaling_type == "su":
339
+ self.rotary_emb = Phi3SuScaledRotaryEmbedding(self.head_dim, self.config)
340
+ elif scaling_type == "yarn":
341
+ self.rotary_emb = Phi3YarnScaledRotaryEmbedding(self.head_dim, self.config)
342
+ else:
343
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
344
+
345
+ def forward(
346
+ self,
347
+ hidden_states: torch.Tensor,
348
+ attention_mask: Optional[torch.Tensor] = None,
349
+ position_ids: Optional[torch.LongTensor] = None,
350
+ past_key_value: Optional[Cache] = None,
351
+ output_attentions: bool = False,
352
+ use_cache: bool = False,
353
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
354
+ logger.warning_once("You are not running the flash-attention implementation, expect numerical differences.")
355
+
356
+ bsz, q_len, _ = hidden_states.size()
357
+
358
+ qkv = self.qkv_proj(hidden_states)
359
+ query_pos = self.num_heads * self.head_dim
360
+ query_states = qkv[..., :query_pos]
361
+ key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
362
+ value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
363
+
364
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
365
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
366
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
367
+
368
+ kv_seq_len = key_states.shape[-2]
369
+ if past_key_value is not None:
370
+ if self.layer_idx is None:
371
+ raise ValueError(
372
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
373
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
374
+ "with a layer index."
375
+ )
376
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
377
+ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
378
+
379
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
380
+
381
+ if past_key_value is not None:
382
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
383
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
384
+
385
+ # repeat k/v heads if n_kv_heads < n_heads
386
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
387
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
388
+
389
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
390
+
391
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
392
+ raise ValueError(
393
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
394
+ f" {attn_weights.size()}"
395
+ )
396
+
397
+ if attention_mask is not None:
398
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
399
+ raise ValueError(
400
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
401
+ )
402
+ attn_weights = attn_weights + attention_mask
403
+
404
+ # upcast attention to fp32
405
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
406
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
407
+
408
+ attn_output = torch.matmul(attn_weights, value_states)
409
+
410
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
411
+ raise ValueError(
412
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
413
+ f" {attn_output.size()}"
414
+ )
415
+
416
+ attn_output = attn_output.transpose(1, 2).contiguous()
417
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
418
+
419
+ attn_output = self.o_proj(attn_output)
420
+
421
+ if not output_attentions:
422
+ attn_weights = None
423
+
424
+ return attn_output, attn_weights, past_key_value
425
+
426
+
427
+ class Phi3FlashAttention2(Phi3Attention):
428
+ """
429
+ Phi-3 flash attention module. This module inherits from `Phi3Attention`, as the weights of the module stay
430
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
431
+ flash attention and deal with padding tokens in case the input contains any of them.
432
+ """
433
+
434
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
435
+ def __init__(self, *args, **kwargs):
436
+ super().__init__(*args, **kwargs)
437
+
438
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
439
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
440
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
441
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
442
+
443
+ def forward(
444
+ self,
445
+ hidden_states: torch.Tensor,
446
+ attention_mask: Optional[torch.LongTensor] = None,
447
+ position_ids: Optional[torch.LongTensor] = None,
448
+ past_key_value: Optional[Cache] = None,
449
+ output_attentions: bool = False,
450
+ use_cache: bool = False,
451
+ **kwargs,
452
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
453
+ # Phi3FlashAttention2 attention does not support output_attentions
454
+
455
+ if not _flash_supports_window_size:
456
+ logger.warning_once(
457
+ "The current flash attention version does not support sliding window attention. Please use `attn_implementation='eager'` or upgrade flash-attn library."
458
+ )
459
+ raise ValueError("The current flash attention version does not support sliding window attention.")
460
+
461
+ output_attentions = False
462
+
463
+ if "padding_mask" in kwargs:
464
+ warnings.warn(
465
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
466
+ )
467
+
468
+ # overwrite attention_mask with padding_mask
469
+ attention_mask = kwargs.pop("padding_mask")
470
+
471
+ bsz, q_len, _ = hidden_states.size()
472
+
473
+ qkv = self.qkv_proj(hidden_states)
474
+ query_pos = self.num_heads * self.head_dim
475
+ query_states = qkv[..., :query_pos]
476
+ key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
477
+ value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
478
+
479
+ # Flash attention requires the input to have the shape
480
+ # batch_size x seq_length x num_heads x head_dim
481
+ # therefore we just need to keep the original shape
482
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
483
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
484
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
485
+
486
+ kv_seq_len = key_states.shape[-2]
487
+ if past_key_value is not None:
488
+ if self.layer_idx is None:
489
+ raise ValueError(
490
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
491
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
492
+ "with a layer index."
493
+ )
494
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
495
+
496
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
497
+ rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
498
+ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len)
499
+
500
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
501
+
502
+ use_sliding_windows = (
503
+ _flash_supports_window_size
504
+ and getattr(self.config, "sliding_window", None) is not None
505
+ and kv_seq_len > self.config.sliding_window
506
+ )
507
+
508
+ if past_key_value is not None:
509
+ # Activate slicing of the cache only if the config has a `sliding_window` attribute
510
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
511
+ if (
512
+ getattr(self.config, "sliding_window", None) is not None
513
+ and kv_seq_len > self.config.sliding_window
514
+ and cache_has_contents
515
+ ):
516
+ slicing_tokens = 1 - self.config.sliding_window
517
+
518
+ past_key = past_key_value[self.layer_idx][0]
519
+ past_value = past_key_value[self.layer_idx][1]
520
+
521
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
522
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
523
+
524
+ if past_key.shape[-2] != self.config.sliding_window - 1:
525
+ raise ValueError(
526
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
527
+ f" {past_key.shape}"
528
+ )
529
+
530
+ if attention_mask is not None:
531
+ attention_mask = attention_mask[:, slicing_tokens:]
532
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
533
+
534
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
535
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
536
+
537
+ # repeat k/v heads if n_kv_heads < n_heads
538
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
539
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
540
+
541
+ attn_dropout = self.attention_dropout if self.training else 0.0
542
+
543
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
544
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
545
+ # cast them back to the correct dtype just to be sure everything works as expected.
546
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
547
+ # in fp32.
548
+
549
+ if query_states.dtype == torch.float32:
550
+ if torch.is_autocast_enabled():
551
+ target_dtype = torch.get_autocast_gpu_dtype()
552
+ # Handle the case where the model is quantized
553
+ elif hasattr(self.config, "_pre_quantization_dtype"):
554
+ target_dtype = self.config._pre_quantization_dtype
555
+ else:
556
+ target_dtype = self.qkv_proj.weight.dtype
557
+
558
+ logger.warning_once(
559
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
560
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
561
+ f" {target_dtype}."
562
+ )
563
+
564
+ query_states = query_states.to(target_dtype)
565
+ key_states = key_states.to(target_dtype)
566
+ value_states = value_states.to(target_dtype)
567
+
568
+ # Reshape to the expected shape for Flash Attention
569
+ query_states = query_states.transpose(1, 2)
570
+ key_states = key_states.transpose(1, 2)
571
+ value_states = value_states.transpose(1, 2)
572
+
573
+ attn_output = self._flash_attention_forward(
574
+ query_states,
575
+ key_states,
576
+ value_states,
577
+ attention_mask,
578
+ q_len,
579
+ dropout=attn_dropout,
580
+ use_sliding_windows=use_sliding_windows,
581
+ )
582
+
583
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
584
+ attn_output = self.o_proj(attn_output)
585
+
586
+ if not output_attentions:
587
+ attn_weights = None
588
+
589
+ return attn_output, attn_weights, past_key_value
590
+
591
+ # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._flash_attention_forward
592
+ def _flash_attention_forward(
593
+ self,
594
+ query_states,
595
+ key_states,
596
+ value_states,
597
+ attention_mask,
598
+ query_length,
599
+ dropout=0.0,
600
+ softmax_scale=None,
601
+ use_sliding_windows=False,
602
+ ):
603
+ """
604
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
605
+ first unpad the input, then compute the attention scores and pad the final attention scores.
606
+
607
+ Args:
608
+ query_states (`torch.Tensor`):
609
+ Input query states to be passed to Flash Attention API
610
+ key_states (`torch.Tensor`):
611
+ Input key states to be passed to Flash Attention API
612
+ value_states (`torch.Tensor`):
613
+ Input value states to be passed to Flash Attention API
614
+ attention_mask (`torch.Tensor`):
615
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
616
+ position of padding tokens and 1 for the position of non-padding tokens.
617
+ dropout (`float`):
618
+ Attention dropout
619
+ softmax_scale (`float`, *optional*):
620
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
621
+ use_sliding_windows (`bool`, *optional*):
622
+ Whether to activate sliding window attention.
623
+ """
624
+ if not self._flash_attn_uses_top_left_mask:
625
+ causal = self.is_causal
626
+ else:
627
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
628
+ causal = self.is_causal and query_length != 1
629
+
630
+ # Contains at least one padding token in the sequence
631
+ if attention_mask is not None:
632
+ batch_size = query_states.shape[0]
633
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
634
+ query_states, key_states, value_states, attention_mask, query_length
635
+ )
636
+
637
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
638
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
639
+
640
+ if not use_sliding_windows:
641
+ attn_output_unpad = flash_attn_varlen_func(
642
+ query_states,
643
+ key_states,
644
+ value_states,
645
+ cu_seqlens_q=cu_seqlens_q,
646
+ cu_seqlens_k=cu_seqlens_k,
647
+ max_seqlen_q=max_seqlen_in_batch_q,
648
+ max_seqlen_k=max_seqlen_in_batch_k,
649
+ dropout_p=dropout,
650
+ softmax_scale=softmax_scale,
651
+ causal=causal,
652
+ )
653
+ else:
654
+ attn_output_unpad = flash_attn_varlen_func(
655
+ query_states,
656
+ key_states,
657
+ value_states,
658
+ cu_seqlens_q=cu_seqlens_q,
659
+ cu_seqlens_k=cu_seqlens_k,
660
+ max_seqlen_q=max_seqlen_in_batch_q,
661
+ max_seqlen_k=max_seqlen_in_batch_k,
662
+ dropout_p=dropout,
663
+ softmax_scale=softmax_scale,
664
+ causal=causal,
665
+ window_size=(self.config.sliding_window, self.config.sliding_window),
666
+ )
667
+
668
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
669
+ else:
670
+ if not use_sliding_windows:
671
+ attn_output = flash_attn_func(
672
+ query_states,
673
+ key_states,
674
+ value_states,
675
+ dropout,
676
+ softmax_scale=softmax_scale,
677
+ causal=causal,
678
+ )
679
+ else:
680
+ attn_output = flash_attn_func(
681
+ query_states,
682
+ key_states,
683
+ value_states,
684
+ dropout,
685
+ softmax_scale=softmax_scale,
686
+ causal=causal,
687
+ window_size=(self.config.sliding_window, self.config.sliding_window),
688
+ )
689
+
690
+ return attn_output
691
+
692
+ # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input
693
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
694
+ batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
695
+
696
+ # On the first iteration we need to properly re-create the padding mask
697
+ # by slicing it at the proper place
698
+ if kv_seq_len != attention_mask.shape[-1]:
699
+ attention_mask_num_tokens = attention_mask.shape[-1]
700
+ attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
701
+
702
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
703
+
704
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
705
+ value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
706
+
707
+ if query_length == kv_seq_len:
708
+ query_layer = index_first_axis(
709
+ query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
710
+ )
711
+ cu_seqlens_q = cu_seqlens_k
712
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
713
+ indices_q = indices_k
714
+ elif query_length == 1:
715
+ max_seqlen_in_batch_q = 1
716
+ cu_seqlens_q = torch.arange(
717
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
718
+ ) # There is a memcpy here, which is very bad.
719
+ indices_q = cu_seqlens_q[:-1]
720
+ query_layer = query_layer.squeeze(1)
721
+ else:
722
+ # The -q_len: slice assumes left padding.
723
+ attention_mask = attention_mask[:, -query_length:]
724
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
725
+
726
+ return (
727
+ query_layer,
728
+ key_layer,
729
+ value_layer,
730
+ indices_q,
731
+ (cu_seqlens_q, cu_seqlens_k),
732
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
733
+ )
734
+
735
+
736
+ # copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Phi3
737
+ # TODO @Arthur no longer copied from LLama after static cache
738
+ class Phi3SdpaAttention(Phi3Attention):
739
+ """
740
+ Phi3 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
741
+ `Phi3Attention`, as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
742
+ the SDPA API.
743
+ """
744
+
745
+ # Adapted from Phi3Attention.forward
746
+ def forward(
747
+ self,
748
+ hidden_states: torch.Tensor,
749
+ attention_mask: Optional[torch.Tensor] = None,
750
+ position_ids: Optional[torch.LongTensor] = None,
751
+ past_key_value: Optional[Cache] = None,
752
+ output_attentions: bool = False,
753
+ use_cache: bool = False,
754
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
755
+ if output_attentions:
756
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
757
+ logger.warning_once(
758
+ "Phi3Model is using Phi3SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
759
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
760
+ )
761
+ return super().forward(
762
+ hidden_states=hidden_states,
763
+ attention_mask=attention_mask,
764
+ position_ids=position_ids,
765
+ past_key_value=past_key_value,
766
+ output_attentions=output_attentions,
767
+ use_cache=use_cache,
768
+ )
769
+
770
+ bsz, q_len, _ = hidden_states.size()
771
+
772
+ qkv = self.qkv_proj(hidden_states)
773
+ query_pos = self.num_heads * self.head_dim
774
+ query_states = qkv[..., :query_pos]
775
+ key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
776
+ value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
777
+
778
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
779
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
780
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
781
+
782
+ kv_seq_len = key_states.shape[-2]
783
+ if past_key_value is not None:
784
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
785
+ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
786
+
787
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
788
+
789
+ if past_key_value is not None:
790
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
791
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
792
+
793
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
794
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
795
+
796
+ if attention_mask is not None:
797
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
798
+ raise ValueError(
799
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
800
+ )
801
+
802
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
803
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
804
+ if query_states.device.type == "cuda" and attention_mask is not None:
805
+ query_states = query_states.contiguous()
806
+ key_states = key_states.contiguous()
807
+ value_states = value_states.contiguous()
808
+
809
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
810
+ query_states,
811
+ key_states,
812
+ value_states,
813
+ attn_mask=attention_mask,
814
+ dropout_p=self.attention_dropout if self.training else 0.0,
815
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
816
+ is_causal=self.is_causal and attention_mask is None and q_len > 1,
817
+ )
818
+
819
+ attn_output = attn_output.transpose(1, 2).contiguous()
820
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
821
+
822
+ attn_output = self.o_proj(attn_output)
823
+
824
+ return attn_output, None, past_key_value
825
+
826
+
827
+ PHI3_ATTENTION_CLASSES = {
828
+ "eager": Phi3Attention,
829
+ "flash_attention_2": Phi3FlashAttention2,
830
+ "sdpa": Phi3SdpaAttention,
831
+ }
832
+
833
+
834
+ class Phi3DecoderLayer(nn.Module):
835
+ def __init__(self, config: Phi3Config, layer_idx: int):
836
+ super().__init__()
837
+
838
+ self.config = config
839
+ self.self_attn = PHI3_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
840
+
841
+ self.mlp = Phi3MLP(config)
842
+ self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
843
+
844
+ self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
845
+ self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)
846
+ self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
847
+
848
+ def forward(
849
+ self,
850
+ hidden_states: torch.Tensor,
851
+ attention_mask: Optional[torch.Tensor] = None,
852
+ position_ids: Optional[torch.LongTensor] = None,
853
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
854
+ output_attentions: Optional[bool] = False,
855
+ use_cache: Optional[bool] = False,
856
+ **kwargs,
857
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
858
+ if "padding_mask" in kwargs:
859
+ warnings.warn(
860
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
861
+ )
862
+ """
863
+ Args:
864
+ hidden_states (`torch.FloatTensor`):
865
+ input to the layer of shape `(batch, seq_len, embed_dim)`
866
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
867
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
868
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
869
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
870
+ `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
871
+ output_attentions (`bool`, *optional*):
872
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
873
+ returned tensors for more detail.
874
+ use_cache (`bool`, *optional*):
875
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
876
+ (see `past_key_values`).
877
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
878
+ """
879
+
880
+ residual = hidden_states
881
+
882
+ hidden_states = self.input_layernorm(hidden_states)
883
+
884
+ # Self Attention
885
+ attn_outputs, self_attn_weights, present_key_value = self.self_attn(
886
+ hidden_states=hidden_states,
887
+ attention_mask=attention_mask,
888
+ position_ids=position_ids,
889
+ past_key_value=past_key_value,
890
+ output_attentions=output_attentions,
891
+ use_cache=use_cache,
892
+ )
893
+
894
+ hidden_states = residual + self.resid_attn_dropout(attn_outputs)
895
+
896
+ residual = hidden_states
897
+ hidden_states = self.post_attention_layernorm(hidden_states)
898
+ hidden_states = self.mlp(hidden_states)
899
+ hidden_states = residual + self.resid_mlp_dropout(hidden_states)
900
+
901
+ outputs = (hidden_states,)
902
+
903
+ if output_attentions:
904
+ outputs += (self_attn_weights,)
905
+
906
+ if use_cache:
907
+ outputs += (present_key_value,)
908
+
909
+ return outputs
910
+
911
+
912
+ PHI3_START_DOCSTRING = r"""
913
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
914
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
915
+ etc.)
916
+
917
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
918
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
919
+ and behavior.
920
+
921
+ Parameters:
922
+ config ([`Phi3Config`]):
923
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
924
+ load the weights associated with the model, only the configuration. Check out the
925
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
926
+ """
927
+
928
+
929
+ @add_start_docstrings(
930
+ "The bare Phi-3 model outputting raw hidden-states without any specific head on top.",
931
+ PHI3_START_DOCSTRING,
932
+ )
933
+ class Phi3PreTrainedModel(PreTrainedModel):
934
+ config_class = Phi3Config
935
+ base_model_prefix = "model"
936
+ supports_gradient_checkpointing = True
937
+ _no_split_modules = ["Phi3DecoderLayer"]
938
+ _skip_keys_device_placement = "past_key_values"
939
+ _supports_flash_attn_2 = True
940
+ _supports_sdpa = False
941
+ _supports_cache_class = True
942
+
943
+ _version = "0.0.5"
944
+
945
+ def _init_weights(self, module):
946
+ std = self.config.initializer_range
947
+ if isinstance(module, nn.Linear):
948
+ module.weight.data.normal_(mean=0.0, std=std)
949
+ if module.bias is not None:
950
+ module.bias.data.zero_()
951
+ elif isinstance(module, nn.Embedding):
952
+ module.weight.data.normal_(mean=0.0, std=std)
953
+ if module.padding_idx is not None:
954
+ module.weight.data[module.padding_idx].zero_()
955
+
956
+
957
+ PHI3_INPUTS_DOCSTRING = r"""
958
+ Args:
959
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
960
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
961
+ it.
962
+
963
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
964
+ [`PreTrainedTokenizer.__call__`] for details.
965
+
966
+ [What are input IDs?](../glossary#input-ids)
967
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
968
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
969
+
970
+ - 1 for tokens that are **not masked**,
971
+ - 0 for tokens that are **masked**.
972
+
973
+ [What are attention masks?](../glossary#attention-mask)
974
+
975
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
976
+ [`PreTrainedTokenizer.__call__`] for details.
977
+
978
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
979
+ `past_key_values`).
980
+
981
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
982
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
983
+ information on the default strategy.
984
+
985
+ - 1 indicates the head is **not masked**,
986
+ - 0 indicates the head is **masked**.
987
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
988
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
989
+ config.n_positions - 1]`.
990
+
991
+ [What are position IDs?](../glossary#position-ids)
992
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
993
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
994
+ blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
995
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
996
+
997
+ Two formats are allowed:
998
+ - a [`~cache_utils.Cache`] instance;
999
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1000
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
1001
+ cache format.
1002
+
1003
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
1004
+ legacy cache format will be returned.
1005
+
1006
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
1007
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
1008
+ of shape `(batch_size, sequence_length)`.
1009
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1010
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1011
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1012
+ model's internal embedding lookup matrix.
1013
+ use_cache (`bool`, *optional*):
1014
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1015
+ `past_key_values`).
1016
+ output_attentions (`bool`, *optional*):
1017
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1018
+ tensors for more detail.
1019
+ output_hidden_states (`bool`, *optional*):
1020
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1021
+ more detail.
1022
+ return_dict (`bool`, *optional*):
1023
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1024
+ """
1025
+
1026
+
1027
+ @add_start_docstrings(
1028
+ "The bare Phi-3 model outputting raw hidden-states without any specific head on top.",
1029
+ PHI3_START_DOCSTRING,
1030
+ )
1031
+ class Phi3Model(Phi3PreTrainedModel):
1032
+ """
1033
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi3DecoderLayer`]
1034
+
1035
+ Args:
1036
+ config: Phi3Config
1037
+ """
1038
+
1039
+ def __init__(self, config: Phi3Config):
1040
+ super().__init__(config)
1041
+ self.padding_idx = config.pad_token_id
1042
+ self.vocab_size = config.vocab_size
1043
+
1044
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1045
+ self.embed_dropout = nn.Dropout(config.embd_pdrop)
1046
+ self.layers = nn.ModuleList(
1047
+ [Phi3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1048
+ )
1049
+ self._attn_implementation = config._attn_implementation
1050
+ self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1051
+
1052
+ self.gradient_checkpointing = False
1053
+ # Initialize weights and apply final processing
1054
+ self.post_init()
1055
+
1056
+ def get_input_embeddings(self):
1057
+ return self.embed_tokens
1058
+
1059
+ def set_input_embeddings(self, value):
1060
+ self.embed_tokens = value
1061
+
1062
+ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
1063
+ def forward(
1064
+ self,
1065
+ input_ids: torch.LongTensor = None,
1066
+ attention_mask: Optional[torch.Tensor] = None,
1067
+ position_ids: Optional[torch.LongTensor] = None,
1068
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1069
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1070
+ use_cache: Optional[bool] = None,
1071
+ output_attentions: Optional[bool] = None,
1072
+ output_hidden_states: Optional[bool] = None,
1073
+ return_dict: Optional[bool] = None,
1074
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1075
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1076
+ output_hidden_states = (
1077
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1078
+ )
1079
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1080
+
1081
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1082
+
1083
+ # retrieve input_ids and inputs_embeds
1084
+ if input_ids is not None and inputs_embeds is not None:
1085
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1086
+ elif input_ids is not None:
1087
+ batch_size, seq_length = input_ids.shape[:2]
1088
+ elif inputs_embeds is not None:
1089
+ batch_size, seq_length = inputs_embeds.shape[:2]
1090
+ else:
1091
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
1092
+
1093
+ past_key_values_length = 0
1094
+
1095
+ if self.gradient_checkpointing and self.training:
1096
+ if use_cache:
1097
+ logger.warning_once(
1098
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1099
+ )
1100
+ use_cache = False
1101
+
1102
+ if use_cache:
1103
+ use_legacy_cache = not isinstance(past_key_values, Cache)
1104
+ if use_legacy_cache:
1105
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1106
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
1107
+
1108
+ if position_ids is None:
1109
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1110
+ position_ids = torch.arange(
1111
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1112
+ )
1113
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
1114
+ else:
1115
+ position_ids = position_ids.view(-1, seq_length).long()
1116
+
1117
+ if inputs_embeds is None:
1118
+ inputs_embeds = self.embed_tokens(input_ids)
1119
+
1120
+ if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
1121
+ is_padding_right = attention_mask[:, -1].sum().item() != batch_size
1122
+ if is_padding_right:
1123
+ raise ValueError(
1124
+ "You are attempting to perform batched generation with padding_side='right'"
1125
+ " this may lead to unexpected behaviour for Flash Attention version of Phi3. Make sure to "
1126
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
1127
+ )
1128
+
1129
+ if self._attn_implementation == "flash_attention_2":
1130
+ # 2d mask is passed through the layers
1131
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1132
+ else:
1133
+ # 4d mask is passed through the layers
1134
+ attention_mask = _prepare_4d_causal_attention_mask(
1135
+ attention_mask,
1136
+ (batch_size, seq_length),
1137
+ inputs_embeds,
1138
+ past_key_values_length,
1139
+ sliding_window=self.config.sliding_window,
1140
+ )
1141
+
1142
+ hidden_states = inputs_embeds
1143
+
1144
+ # decoder layers
1145
+ all_hidden_states = () if output_hidden_states else None
1146
+ all_self_attns = () if output_attentions else None
1147
+ next_decoder_cache = None
1148
+
1149
+ for decoder_layer in self.layers:
1150
+ if output_hidden_states:
1151
+ all_hidden_states += (hidden_states,)
1152
+
1153
+ if self.gradient_checkpointing and self.training:
1154
+ layer_outputs = self._gradient_checkpointing_func(
1155
+ decoder_layer.__call__,
1156
+ hidden_states,
1157
+ attention_mask,
1158
+ position_ids,
1159
+ past_key_values,
1160
+ output_attentions,
1161
+ use_cache,
1162
+ )
1163
+ else:
1164
+ layer_outputs = decoder_layer(
1165
+ hidden_states,
1166
+ attention_mask=attention_mask,
1167
+ position_ids=position_ids,
1168
+ past_key_value=past_key_values,
1169
+ output_attentions=output_attentions,
1170
+ use_cache=use_cache,
1171
+ )
1172
+
1173
+ hidden_states = layer_outputs[0]
1174
+
1175
+ if use_cache:
1176
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1177
+
1178
+ if output_attentions:
1179
+ all_self_attns += (layer_outputs[1],)
1180
+
1181
+ hidden_states = self.norm(hidden_states)
1182
+
1183
+ # add hidden states from the last decoder layer
1184
+ if output_hidden_states:
1185
+ all_hidden_states += (hidden_states,)
1186
+
1187
+ next_cache = None
1188
+ if use_cache:
1189
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1190
+ if not return_dict:
1191
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1192
+ return BaseModelOutputWithPast(
1193
+ last_hidden_state=hidden_states,
1194
+ past_key_values=next_cache,
1195
+ hidden_states=all_hidden_states,
1196
+ attentions=all_self_attns,
1197
+ )
1198
+
1199
+
1200
+ class Phi3ForCausalLM(Phi3PreTrainedModel):
1201
+ _tied_weights_keys = ["lm_head.weight"]
1202
+
1203
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi3
1204
+ def __init__(self, config):
1205
+ super().__init__(config)
1206
+ self.model = Phi3Model(config)
1207
+ self.vocab_size = config.vocab_size
1208
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1209
+
1210
+ # Initialize weights and apply final processing
1211
+ self.post_init()
1212
+
1213
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
1214
+ def get_input_embeddings(self):
1215
+ return self.model.embed_tokens
1216
+
1217
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
1218
+ def set_input_embeddings(self, value):
1219
+ self.model.embed_tokens = value
1220
+
1221
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
1222
+ def get_output_embeddings(self):
1223
+ return self.lm_head
1224
+
1225
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
1226
+ def set_output_embeddings(self, new_embeddings):
1227
+ self.lm_head = new_embeddings
1228
+
1229
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
1230
+ def set_decoder(self, decoder):
1231
+ self.model = decoder
1232
+
1233
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
1234
+ def get_decoder(self):
1235
+ return self.model
1236
+
1237
+ # Ignore copy
1238
+ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
1239
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1240
+ def forward(
1241
+ self,
1242
+ input_ids: torch.LongTensor = None,
1243
+ attention_mask: Optional[torch.Tensor] = None,
1244
+ position_ids: Optional[torch.LongTensor] = None,
1245
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1246
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1247
+ labels: Optional[torch.LongTensor] = None,
1248
+ use_cache: Optional[bool] = None,
1249
+ output_attentions: Optional[bool] = None,
1250
+ output_hidden_states: Optional[bool] = None,
1251
+ return_dict: Optional[bool] = None,
1252
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1253
+ r"""
1254
+ Args:
1255
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1256
+ Labels for computing the language modeling loss. Indices should either be in `[0, ...,
1257
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1258
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1259
+
1260
+ Returns:
1261
+
1262
+ Example:
1263
+
1264
+ ```python
1265
+ >>> from transformers import AutoTokenizer, Phi3ForCausalLM
1266
+
1267
+ >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
1268
+ >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")
1269
+
1270
+ >>> prompt = "This is an example script ."
1271
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1272
+
1273
+ >>> # Generate
1274
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1275
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1276
+ 'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
1277
+ ```"""
1278
+
1279
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1280
+ output_hidden_states = (
1281
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1282
+ )
1283
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1284
+
1285
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
1286
+ outputs = self.model(
1287
+ input_ids=input_ids,
1288
+ attention_mask=attention_mask,
1289
+ position_ids=position_ids,
1290
+ past_key_values=past_key_values,
1291
+ inputs_embeds=inputs_embeds,
1292
+ use_cache=use_cache,
1293
+ output_attentions=output_attentions,
1294
+ output_hidden_states=output_hidden_states,
1295
+ return_dict=return_dict,
1296
+ )
1297
+
1298
+ hidden_states = outputs[0]
1299
+ logits = self.lm_head(hidden_states)
1300
+ logits = logits.float()
1301
+
1302
+ loss = None
1303
+ if labels is not None:
1304
+ # Shift so that tokens < n predict n
1305
+ shift_logits = logits[..., :-1, :].contiguous()
1306
+ shift_labels = labels[..., 1:].contiguous()
1307
+ # Flatten the tokens
1308
+ loss_fct = CrossEntropyLoss()
1309
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1310
+ shift_labels = shift_labels.view(-1)
1311
+ # Enable model parallelism
1312
+ shift_labels = shift_labels.to(shift_logits.device)
1313
+ loss = loss_fct(shift_logits, shift_labels)
1314
+
1315
+ if not return_dict:
1316
+ output = (logits,) + outputs[1:]
1317
+ return (loss,) + output if loss is not None else output
1318
+
1319
+ return CausalLMOutputWithPast(
1320
+ loss=loss,
1321
+ logits=logits,
1322
+ past_key_values=outputs.past_key_values,
1323
+ hidden_states=outputs.hidden_states,
1324
+ attentions=outputs.attentions,
1325
+ )
1326
+
1327
+ # Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM.prepare_inputs_for_generation
1328
+ def prepare_inputs_for_generation(
1329
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1330
+ ):
1331
+ if past_key_values is not None:
1332
+ if isinstance(past_key_values, Cache):
1333
+ cache_length = past_key_values.get_seq_length()
1334
+ past_length = past_key_values.seen_tokens
1335
+ max_cache_length = past_key_values.get_max_length()
1336
+ else:
1337
+ cache_length = past_length = past_key_values[0][0].shape[2]
1338
+ max_cache_length = None
1339
+
1340
+ # Keep only the unprocessed tokens:
1341
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1342
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as
1343
+ # input)
1344
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1345
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1346
+ # 2 - If past_length is smaller than the length of input_ids, then input_ids holds all input tokens. We can discard
1347
+ # input_ids based on the past_length.
1348
+ elif past_length < input_ids.shape[1]:
1349
+ input_ids = input_ids[:, past_length:]
1350
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1351
+
1352
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1353
+ if (
1354
+ max_cache_length is not None
1355
+ and attention_mask is not None
1356
+ and cache_length + input_ids.shape[1] > max_cache_length
1357
+ ):
1358
+ attention_mask = attention_mask[:, -max_cache_length:]
1359
+
1360
+ position_ids = kwargs.get("position_ids", None)
1361
+ if attention_mask is not None and position_ids is None:
1362
+ # create position_ids on the fly for batch generation
1363
+ position_ids = attention_mask.long().cumsum(-1) - 1
1364
+ position_ids.masked_fill_(attention_mask == 0, 1)
1365
+ if past_key_values:
1366
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1367
+
1368
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1369
+ if inputs_embeds is not None and past_key_values is None:
1370
+ model_inputs = {"inputs_embeds": inputs_embeds}
1371
+ else:
1372
+ model_inputs = {"input_ids": input_ids}
1373
+
1374
+ model_inputs.update(
1375
+ {
1376
+ "position_ids": position_ids,
1377
+ "past_key_values": past_key_values,
1378
+ "use_cache": kwargs.get("use_cache"),
1379
+ "attention_mask": attention_mask,
1380
+ }
1381
+ )
1382
+ return model_inputs
1383
+
1384
+ @staticmethod
1385
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache
1386
+ def _reorder_cache(past_key_values, beam_idx):
1387
+ reordered_past = ()
1388
+ for layer_past in past_key_values:
1389
+ reordered_past += (
1390
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1391
+ )
1392
+ return reordered_past
1393
+
1394
+
1395
+ @add_start_docstrings(
1396
+ """
1397
+ The [`Phi3Model`] with a sequence classification head on top (linear layer).
1398
+
1399
+ [`Phi3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1400
+ (e.g. GPT-2) do.
1401
+
1402
+ Since it does classification on the last token, it needs to know the position of the last token. If a
1403
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1404
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1405
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1406
+ each row of the batch).
1407
+ """,
1408
+ PHI3_START_DOCSTRING,
1409
+ )
1410
+ # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Phi3, LLAMA->PHI3, self.transformer->self.model, transformer_outputs->model_outputs
1411
+ class Phi3ForSequenceClassification(Phi3PreTrainedModel):
1412
+ def __init__(self, config):
1413
+ super().__init__(config)
1414
+ self.num_labels = config.num_labels
1415
+ self.model = Phi3Model(config)
1416
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1417
+
1418
+ # Initialize weights and apply final processing
1419
+ self.post_init()
1420
+
1421
+ def get_input_embeddings(self):
1422
+ return self.model.embed_tokens
1423
+
1424
+ def set_input_embeddings(self, value):
1425
+ self.model.embed_tokens = value
1426
+
1427
+ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
1428
+ def forward(
1429
+ self,
1430
+ input_ids: torch.LongTensor = None,
1431
+ attention_mask: Optional[torch.Tensor] = None,
1432
+ position_ids: Optional[torch.LongTensor] = None,
1433
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1434
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1435
+ labels: Optional[torch.LongTensor] = None,
1436
+ use_cache: Optional[bool] = None,
1437
+ output_attentions: Optional[bool] = None,
1438
+ output_hidden_states: Optional[bool] = None,
1439
+ return_dict: Optional[bool] = None,
1440
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1441
+ r"""
1442
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1443
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1444
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1445
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1446
+ """
1447
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1448
+
1449
+ model_outputs = self.model(
1450
+ input_ids,
1451
+ attention_mask=attention_mask,
1452
+ position_ids=position_ids,
1453
+ past_key_values=past_key_values,
1454
+ inputs_embeds=inputs_embeds,
1455
+ use_cache=use_cache,
1456
+ output_attentions=output_attentions,
1457
+ output_hidden_states=output_hidden_states,
1458
+ return_dict=return_dict,
1459
+ )
1460
+ hidden_states = model_outputs[0]
1461
+ logits = self.score(hidden_states)
1462
+
1463
+ if input_ids is not None:
1464
+ batch_size = input_ids.shape[0]
1465
+ else:
1466
+ batch_size = inputs_embeds.shape[0]
1467
+
1468
+ if self.config.pad_token_id is None and batch_size != 1:
1469
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1470
+ if self.config.pad_token_id is None:
1471
+ sequence_lengths = -1
1472
+ else:
1473
+ if input_ids is not None:
1474
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1475
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1476
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1477
+ sequence_lengths = sequence_lengths.to(logits.device)
1478
+ else:
1479
+ sequence_lengths = -1
1480
+
1481
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1482
+
1483
+ loss = None
1484
+ if labels is not None:
1485
+ labels = labels.to(logits.device)
1486
+ if self.config.problem_type is None:
1487
+ if self.num_labels == 1:
1488
+ self.config.problem_type = "regression"
1489
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1490
+ self.config.problem_type = "single_label_classification"
1491
+ else:
1492
+ self.config.problem_type = "multi_label_classification"
1493
+
1494
+ if self.config.problem_type == "regression":
1495
+ loss_fct = MSELoss()
1496
+ if self.num_labels == 1:
1497
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1498
+ else:
1499
+ loss = loss_fct(pooled_logits, labels)
1500
+ elif self.config.problem_type == "single_label_classification":
1501
+ loss_fct = CrossEntropyLoss()
1502
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1503
+ elif self.config.problem_type == "multi_label_classification":
1504
+ loss_fct = BCEWithLogitsLoss()
1505
+ loss = loss_fct(pooled_logits, labels)
1506
+ if not return_dict:
1507
+ output = (pooled_logits,) + model_outputs[1:]
1508
+ return ((loss,) + output) if loss is not None else output
1509
+
1510
+ return SequenceClassifierOutputWithPast(
1511
+ loss=loss,
1512
+ logits=pooled_logits,
1513
+ past_key_values=model_outputs.past_key_values,
1514
+ hidden_states=model_outputs.hidden_states,
1515
+ attentions=model_outputs.attentions,
1516
+ )
1517
+
1518
+
1519
+ @add_start_docstrings(
1520
+ """
1521
+ [`Phi3Model`] with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1522
+ Named-Entity-Recognition (NER) tasks.
1523
+ """,
1524
+ PHI3_START_DOCSTRING,
1525
+ )
1526
+ # Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with Mpt->Phi3,MPT->PHI3,self.transformer->self.model,transformer_outputs->model_outputs
1527
+ class Phi3ForTokenClassification(Phi3PreTrainedModel):
1528
+ def __init__(self, config: Phi3Config):
1529
+ super().__init__(config)
1530
+ self.num_labels = config.num_labels
1531
+
1532
+ self.model = Phi3Model(config)
1533
+ if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
1534
+ classifier_dropout = config.classifier_dropout
1535
+ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
1536
+ classifier_dropout = config.hidden_dropout
1537
+ else:
1538
+ classifier_dropout = 0.1
1539
+ self.dropout = nn.Dropout(classifier_dropout)
1540
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1541
+
1542
+ # Initialize weights and apply final processing
1543
+ self.post_init()
1544
+
1545
+ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
1546
+ @add_code_sample_docstrings(
1547
+ checkpoint=_CHECKPOINT_FOR_DOC,
1548
+ output_type=TokenClassifierOutput,
1549
+ config_class=_CONFIG_FOR_DOC,
1550
+ )
1551
+ def forward(
1552
+ self,
1553
+ input_ids: Optional[torch.LongTensor] = None,
1554
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1555
+ attention_mask: Optional[torch.Tensor] = None,
1556
+ inputs_embeds: Optional[torch.Tensor] = None,
1557
+ labels: Optional[torch.Tensor] = None,
1558
+ use_cache: Optional[bool] = None,
1559
+ output_attentions: Optional[bool] = None,
1560
+ output_hidden_states: Optional[bool] = None,
1561
+ return_dict: Optional[bool] = None,
1562
+ **deprecated_arguments,
1563
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1564
+ r"""
1565
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1566
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1567
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1568
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1569
+ """
1570
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1571
+
1572
+ model_outputs = self.model(
1573
+ input_ids,
1574
+ past_key_values=past_key_values,
1575
+ attention_mask=attention_mask,
1576
+ inputs_embeds=inputs_embeds,
1577
+ use_cache=use_cache,
1578
+ output_attentions=output_attentions,
1579
+ output_hidden_states=output_hidden_states,
1580
+ return_dict=return_dict,
1581
+ )
1582
+
1583
+ hidden_states = model_outputs[0]
1584
+ hidden_states = self.dropout(hidden_states)
1585
+ logits = self.classifier(hidden_states)
1586
+
1587
+ loss = None
1588
+ if labels is not None:
1589
+ # move labels to correct device to enable model parallelism
1590
+ labels = labels.to(logits.device)
1591
+ batch_size, seq_length = labels.shape
1592
+ loss_fct = CrossEntropyLoss()
1593
+ loss = loss_fct(
1594
+ logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
1595
+ )
1596
+
1597
+ if not return_dict:
1598
+ output = (logits,) + model_outputs[2:]
1599
+ return ((loss,) + output) if loss is not None else output
1600
+
1601
+ return TokenClassifierOutput(
1602
+ loss=loss,
1603
+ logits=logits,
1604
+ hidden_states=model_outputs.hidden_states,
1605
+ attentions=model_outputs.attentions,
1606
+ )
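
For reference, here is a minimal usage sketch for the modeling code added above. It is not part of this commit: the repository id and prompt are placeholders, and it assumes the checkpoint is loaded with `trust_remote_code=True` so that this `modeling_phi3.py` is picked up. The `attn_implementation` argument selects one of the backends registered in `PHI3_ATTENTION_CLASSES`; since `_supports_sdpa` is set to `False` above, "eager" (the default) or "flash_attention_2" are the realistic choices.

```python
# Minimal sketch, not part of this commit. The repository id below is an
# assumption used purely for illustration.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "postitive666/phi3-chinese-orpo"  # assumed repo id

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,        # use the modeling_phi3.py shipped with the repo
    attn_implementation="eager",   # "flash_attention_2" also works if flash-attn is installed
)

prompt = "你好,请用中文介绍一下你自己。"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```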
special_tokens_map.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|/inst|>"
4
+ ],
5
+ "bos_token": {
6
+ "content": "<s>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false
11
+ },
12
+ "eos_token": {
13
+ "content": "<|endoftext|>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false
18
+ },
19
+ "pad_token": {
20
+ "content": "<|endoftext|>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false
25
+ },
26
+ "unk_token": {
27
+ "content": "<unk>",
28
+ "lstrip": false,
29
+ "normalized": false,
30
+ "rstrip": false,
31
+ "single_word": false
32
+ }
33
+ }
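
The map above only names the control tokens; how they are laid out in a prompt is defined by the chat template in `tokenizer_config.json`. As a hedged illustration (the exact layout is an assumption, not taken from this commit), a Phi-3-style turn is commonly assembled like this:

```python
# Hedged sketch: one common Phi-3-style prompt layout built from the control
# tokens declared above (<|user|>, <|end|>, <|assistant|>). The authoritative
# layout is the chat template shipped in tokenizer_config.json, if present.
def build_prompt(user_message: str) -> str:
    return f"<|user|>\n{user_message}<|end|>\n<|assistant|>\n"

print(build_prompt("用中文介绍一下 ORPO 训练的基本思路。"))
```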
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
3
+ size 499723
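
The three lines above are a Git LFS pointer, not the tokenizer itself; the actual `tokenizer.model` is stored in LFS and identified by the sha256 oid and size recorded here. A small sketch (illustrative only, assuming the file has already been downloaded locally) to check a copy against the pointer:

```python
# Verify a downloaded tokenizer.model against the LFS pointer above.
import hashlib
import os

path = "tokenizer.model"  # local path after download; assumption
expected_sha256 = "9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347"
expected_size = 499723  # bytes, from the pointer

with open(path, "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()

assert os.path.getsize(path) == expected_size, "unexpected file size"
assert digest == expected_sha256, "sha256 mismatch with the LFS pointer"
print("tokenizer.model matches the LFS pointer")
```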
tokenizer_config.json ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "added_tokens_decoder": {
5
+ "0": {
6
+ "content": "<unk>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "1": {
14
+ "content": "<s>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "2": {
22
+ "content": "</s>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": true,
26
+ "single_word": false,
27
+ "special": false
28
+ },
29
+ "32000": {
30
+ "content": "<|endoftext|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "32001": {
38
+ "content": "<|assistant|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": true,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "32002": {
46
+ "content": "<|step|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": true,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "32003": {
54
+ "content": "<|function_output|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": true,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "32004": {
62
+ "content": "<|tag|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": true,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "32005": {
70
+ "content": "<|function_call|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": true,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "32006": {
78
+ "content": "<|system|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": true,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "32007": {
86
+ "content": "<|end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": true,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "32008": {
94
+ "content": "<|raw|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": true,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "32009": {
102
+ "content": "<|continue|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": true,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "32010": {
110
+ "content": "<|user|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": true,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "32011": {
118
+ "content": "<|function_list|>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": true,
122
+ "single_word": false,
123
+ "special": true
124
+ },
125
+ "32012": {
126
+ "content": "<|calc|>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": true,
130
+ "single_word": false,
131
+ "special": true
132
+ },
133
+ "32013": {
134
+ "content": "<|code|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": true,
138
+ "single_word": false,
139
+ "special": true
140
+ },
141
+ "32014": {
142
+ "content": "<|/code|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": true,
146
+ "single_word": false,
147
+ "special": true
148
+ },
149
+ "32015": {
150
+ "content": "<|summary|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": true,
154
+ "single_word": false,
155
+ "special": true
156
+ },
157
+ "32016": {
158
+ "content": "<|resource|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": true,
162
+ "single_word": false,
163
+ "special": true
164
+ },
165
+ "32017": {
166
+ "content": "<|assistant_mask|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": true,
170
+ "single_word": false,
171
+ "special": true
172
+ },
173
+ "32018": {
174
+ "content": "<|start|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": true,
178
+ "single_word": false,
179
+ "special": true
180
+ },
181
+ "32019": {
182
+ "content": "<|message|>",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": true,
186
+ "single_word": false,
187
+ "special": true
188
+ },
189
+ "32020": {
190
+ "content": "<|fim_prefix|>",
191
+ "lstrip": false,
192
+ "normalized": false,
193
+ "rstrip": true,
194
+ "single_word": false,
195
+ "special": true
196
+ },
197
+ "32021": {
198
+ "content": "<|fim_middle|>",
199
+ "lstrip": false,
200
+ "normalized": false,
201
+ "rstrip": true,
202
+ "single_word": false,
203
+ "special": true
204
+ },
205
+ "32022": {
206
+ "content": "<|fim_suffix|>",
207
+ "lstrip": false,
208
+ "normalized": false,
209
+ "rstrip": true,
210
+ "single_word": false,
211
+ "special": true
212
+ },
213
+ "32023": {
214
+ "content": "<|meta_start|>",
215
+ "lstrip": false,
216
+ "normalized": false,
217
+ "rstrip": true,
218
+ "single_word": false,
219
+ "special": true
220
+ },
221
+ "32024": {
222
+ "content": "<|ipynb_marker|>",
223
+ "lstrip": false,
224
+ "normalized": false,
225
+ "rstrip": true,
226
+ "single_word": false,
227
+ "special": true
228
+ },
229
+ "32025": {
230
+ "content": "<|diff_marker|>",
231
+ "lstrip": false,
232
+ "normalized": false,
233
+ "rstrip": true,
234
+ "single_word": false,
235
+ "special": true
236
+ },
237
+ "32026": {
238
+ "content": "<|ghissue|>",
239
+ "lstrip": false,
240
+ "normalized": false,
241
+ "rstrip": true,
242
+ "single_word": false,
243
+ "special": true
244
+ },
245
+ "32027": {
246
+ "content": "<|ghreview|>",
247
+ "lstrip": false,
248
+ "normalized": false,
249
+ "rstrip": true,
250
+ "single_word": false,
251
+ "special": true
252
+ },
253
+ "32028": {
254
+ "content": "<|disc_start|>",
255
+ "lstrip": false,
256
+ "normalized": false,
257
+ "rstrip": true,
258
+ "single_word": false,
259
+ "special": true
260
+ },
261
+ "32029": {
262
+ "content": "<|disc_sep|>",
263
+ "lstrip": false,
264
+ "normalized": false,
265
+ "rstrip": true,
266
+ "single_word": false,
267
+ "special": true
268
+ },
269
+ "32030": {
270
+ "content": "<|disc_thread|><|query|>",
271
+ "lstrip": false,
272
+ "normalized": false,
273
+ "rstrip": true,
274
+ "single_word": false,
275
+ "special": true
276
+ },
277
+ "32031": {
278
+ "content": "<|/query|>",
279
+ "lstrip": false,
280
+ "normalized": false,
281
+ "rstrip": true,
282
+ "single_word": false,
283
+ "special": true
284
+ },
285
+ "32032": {
286
+ "content": "<|data|>",
287
+ "lstrip": false,
288
+ "normalized": false,
289
+ "rstrip": true,
290
+ "single_word": false,
291
+ "special": true
292
+ },
293
+ "32033": {
294
+ "content": "<|/data|>",
295
+ "lstrip": false,
296
+ "normalized": false,
297
+ "rstrip": true,
298
+ "single_word": false,
299
+ "special": true
300
+ },
301
+ "32034": {
302
+ "content": "<|sys|>",
303
+ "lstrip": false,
304
+ "normalized": false,
305
+ "rstrip": true,
306
+ "single_word": false,
307
+ "special": true
308
+ },
309
+ "32035": {
310
+ "content": "<|/sys|>",
311
+ "lstrip": false,
312
+ "normalized": false,
313
+ "rstrip": true,
314
+ "single_word": false,
315
+ "special": true
316
+ },
317
+ "32036": {
318
+ "content": "<|inst|>",
319
+ "lstrip": false,
320
+ "normalized": false,
321
+ "rstrip": true,
322
+ "single_word": false,
323
+ "special": true
324
+ },
325
+ "32037": {
326
+ "content": "<|/inst|>",
327
+ "lstrip": false,
328
+ "normalized": false,
329
+ "rstrip": true,
330
+ "single_word": false,
331
+ "special": true
332
+ }
333
+ },
334
+ "additional_special_tokens": [
335
+ "<|/inst|>"
336
+ ],
337
+ "bos_token": "<s>",
338
+ "chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ system_message + '\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ 'Human: ' + content + '\\nAssistant: ' }}{% elif message['role'] == 'assistant' %}{{ content + '<|endoftext|>' + '\\n' }}{% endif %}{% endfor %}",
339
+ "clean_up_tokenization_spaces": false,
340
+ "eos_token": "<|endoftext|>",
341
+ "legacy": false,
342
+ "model_max_length": 131072,
343
+ "pad_token": "<|endoftext|>",
344
+ "padding_side": "right",
345
+ "sp_model_kwargs": {},
346
+ "split_special_tokens": false,
347
+ "tokenizer_class": "LlamaTokenizer",
348
+ "unk_token": "<unk>",
349
+ "use_default_system_prompt": false
350
+ }
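For reference, a minimal usage sketch of the tokenizer configured above: eos and pad are both `<|endoftext|>` and the `chat_template` renders an optional system line followed by `Human: ... \nAssistant: `. The hub id `postitive666/phi3-chinese-orpo` is an assumption inferred from the commit metadata; a local checkout path works the same way.

```python
# Render a conversation with the chat_template defined in tokenizer_config.json.
# The repo id below is assumed; replace with the actual hub path or a local directory.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("postitive666/phi3-chinese-orpo")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "你好,请用中文介绍一下你自己。"},
]

# tokenize=False returns the rendered prompt string. The template emits the system
# message, then "Human: <user>\nAssistant: ", and closes assistant turns with <|endoftext|>.
prompt = tokenizer.apply_chat_template(messages, tokenize=False)
print(prompt)
```

Because the template already appends "Assistant: " after each user turn, generation can start directly from the rendered prompt, stopping on `<|endoftext|>`.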
train_results.json ADDED
@@ -0,0 +1,8 @@
1
+ {
2
+ "epoch": 2.994601079784043,
3
+ "total_flos": 132590267662336.0,
4
+ "train_loss": 0.7937506708579186,
5
+ "train_runtime": 49781.9259,
6
+ "train_samples_per_second": 1.205,
7
+ "train_steps_per_second": 0.025
8
+ }
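For reference, the per-step entries in trainer_state.json below log both an `sft_loss` and an `odds_ratio_loss`, which is the ORPO decomposition; as a sketch of how these relate, the ORPO objective from Hong et al. (2024) is shown below, where y_w and y_l are the chosen and rejected responses. The weighting coefficient λ used for this run is not recorded in these files.

```latex
\mathcal{L}_{\mathrm{ORPO}}
  = \mathcal{L}_{\mathrm{SFT}} + \lambda \, \mathcal{L}_{\mathrm{OR}},
\qquad
\mathcal{L}_{\mathrm{OR}}
  = -\log \sigma\!\left(\log \frac{\mathrm{odds}_\theta(y_w \mid x)}{\mathrm{odds}_\theta(y_l \mid x)}\right),
\qquad
\mathrm{odds}_\theta(y \mid x) = \frac{P_\theta(y \mid x)}{1 - P_\theta(y \mid x)}
```

The `logps/chosen` and `logps/rejected` fields track the log-probabilities of the two responses, and the `rewards/*` fields appear to be scaled versions of them used for the accuracy and margin diagnostics.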
trainer_state.json ADDED
@@ -0,0 +1,2138 @@
1
+ {
2
+ "best_metric": null,
3
+ "best_model_checkpoint": null,
4
+ "epoch": 2.994601079784043,
5
+ "eval_steps": 500,
6
+ "global_step": 1248,
7
+ "is_hyper_param_search": false,
8
+ "is_local_process_zero": true,
9
+ "is_world_process_zero": true,
10
+ "log_history": [
11
+ {
12
+ "epoch": 0.02399520095980804,
13
+ "grad_norm": 24.58741331565172,
14
+ "learning_rate": 1.0000000000000002e-06,
15
+ "logits/chosen": -0.5075146555900574,
16
+ "logits/rejected": -0.31934085488319397,
17
+ "logps/chosen": -1.394007921218872,
18
+ "logps/rejected": -1.3630257844924927,
19
+ "loss": 1.3501,
20
+ "odds_ratio_loss": 0.8239962458610535,
21
+ "rewards/accuracies": 0.5874999761581421,
22
+ "rewards/chosen": -0.06970040500164032,
23
+ "rewards/margins": -0.0015491036465391517,
24
+ "rewards/rejected": -0.06815129518508911,
25
+ "sft_loss": 1.394007921218872,
26
+ "step": 10
27
+ },
28
+ {
29
+ "epoch": 0.04799040191961608,
30
+ "grad_norm": 4.281683015852783,
31
+ "learning_rate": 3.5e-06,
32
+ "logits/chosen": 0.08614908158779144,
33
+ "logits/rejected": 0.3013238310813904,
34
+ "logps/chosen": -1.3080074787139893,
35
+ "logps/rejected": -1.334457278251648,
36
+ "loss": 1.2858,
37
+ "odds_ratio_loss": 0.7804475426673889,
38
+ "rewards/accuracies": 0.5249999761581421,
39
+ "rewards/chosen": -0.0654003769159317,
40
+ "rewards/margins": 0.0013224859721958637,
41
+ "rewards/rejected": -0.06672286242246628,
42
+ "sft_loss": 1.3080074787139893,
43
+ "step": 20
44
+ },
45
+ {
46
+ "epoch": 0.07198560287942411,
47
+ "grad_norm": 3.830958349381369,
48
+ "learning_rate": 4.99986910314335e-06,
49
+ "logits/chosen": 0.3485943675041199,
50
+ "logits/rejected": 0.6042150855064392,
51
+ "logps/chosen": -0.9540683627128601,
52
+ "logps/rejected": -1.1750730276107788,
53
+ "loss": 0.9904,
54
+ "odds_ratio_loss": 0.6533687710762024,
55
+ "rewards/accuracies": 0.6000000238418579,
56
+ "rewards/chosen": -0.047703422605991364,
57
+ "rewards/margins": 0.011050237342715263,
58
+ "rewards/rejected": -0.05875365808606148,
59
+ "sft_loss": 0.9540683627128601,
60
+ "step": 30
61
+ },
62
+ {
63
+ "epoch": 0.09598080383923216,
64
+ "grad_norm": 3.6776666943951675,
65
+ "learning_rate": 4.998396670920005e-06,
66
+ "logits/chosen": 0.17601105570793152,
67
+ "logits/rejected": 0.5272272229194641,
68
+ "logps/chosen": -0.898045539855957,
69
+ "logps/rejected": -1.0136868953704834,
70
+ "loss": 0.9614,
71
+ "odds_ratio_loss": 0.6860688328742981,
72
+ "rewards/accuracies": 0.5375000238418579,
73
+ "rewards/chosen": -0.04490227997303009,
74
+ "rewards/margins": 0.005782057531177998,
75
+ "rewards/rejected": -0.05068434029817581,
76
+ "sft_loss": 0.898045539855957,
77
+ "step": 40
78
+ },
79
+ {
80
+ "epoch": 0.11997600479904019,
81
+ "grad_norm": 2.636908991979515,
82
+ "learning_rate": 4.995289152254744e-06,
83
+ "logits/chosen": 0.2309066355228424,
84
+ "logits/rejected": 0.22152824699878693,
85
+ "logps/chosen": -0.9074997901916504,
86
+ "logps/rejected": -1.0551084280014038,
87
+ "loss": 0.9374,
88
+ "odds_ratio_loss": 0.663613498210907,
89
+ "rewards/accuracies": 0.48750001192092896,
90
+ "rewards/chosen": -0.04537498578429222,
91
+ "rewards/margins": 0.007380434311926365,
92
+ "rewards/rejected": -0.05275542289018631,
93
+ "sft_loss": 0.9074997901916504,
94
+ "step": 50
95
+ },
96
+ {
97
+ "epoch": 0.14397120575884823,
98
+ "grad_norm": 1.8300107701302537,
99
+ "learning_rate": 4.990548580876516e-06,
100
+ "logits/chosen": 0.307407021522522,
101
+ "logits/rejected": 0.37507694959640503,
102
+ "logps/chosen": -0.9279610514640808,
103
+ "logps/rejected": -0.986476719379425,
104
+ "loss": 0.9464,
105
+ "odds_ratio_loss": 0.7063499093055725,
106
+ "rewards/accuracies": 0.6499999761581421,
107
+ "rewards/chosen": -0.04639805108308792,
108
+ "rewards/margins": 0.00292578199878335,
109
+ "rewards/rejected": -0.04932383447885513,
110
+ "sft_loss": 0.9279610514640808,
111
+ "step": 60
112
+ },
113
+ {
114
+ "epoch": 0.16796640671865626,
115
+ "grad_norm": 3.8157191209486507,
116
+ "learning_rate": 4.9841780592726385e-06,
117
+ "logits/chosen": 0.19509825110435486,
118
+ "logits/rejected": 0.2650177776813507,
119
+ "logps/chosen": -0.9848098754882812,
120
+ "logps/rejected": -1.0149097442626953,
121
+ "loss": 0.9578,
122
+ "odds_ratio_loss": 0.726799488067627,
123
+ "rewards/accuracies": 0.5625,
124
+ "rewards/chosen": -0.04924049228429794,
125
+ "rewards/margins": 0.0015049913199618459,
126
+ "rewards/rejected": -0.050745487213134766,
127
+ "sft_loss": 0.9848098754882812,
128
+ "step": 70
129
+ },
130
+ {
131
+ "epoch": 0.19196160767846432,
132
+ "grad_norm": 4.078587531391316,
133
+ "learning_rate": 4.976181756658363e-06,
134
+ "logits/chosen": 0.061622969806194305,
135
+ "logits/rejected": 0.2444450408220291,
136
+ "logps/chosen": -0.8894473910331726,
137
+ "logps/rejected": -1.0614734888076782,
138
+ "loss": 0.9675,
139
+ "odds_ratio_loss": 0.6382969617843628,
140
+ "rewards/accuracies": 0.550000011920929,
141
+ "rewards/chosen": -0.04447237029671669,
142
+ "rewards/margins": 0.008601305074989796,
143
+ "rewards/rejected": -0.05307367444038391,
144
+ "sft_loss": 0.8894473910331726,
145
+ "step": 80
146
+ },
147
+ {
148
+ "epoch": 0.21595680863827235,
149
+ "grad_norm": 2.9874023740770363,
150
+ "learning_rate": 4.9665649062483115e-06,
151
+ "logits/chosen": 0.6337467432022095,
152
+ "logits/rejected": 0.7902036905288696,
153
+ "logps/chosen": -0.9439412951469421,
154
+ "logps/rejected": -0.9588793516159058,
155
+ "loss": 0.9635,
156
+ "odds_ratio_loss": 0.7716476917266846,
157
+ "rewards/accuracies": 0.44999998807907104,
158
+ "rewards/chosen": -0.047197069972753525,
159
+ "rewards/margins": 0.0007468975381925702,
160
+ "rewards/rejected": -0.047943972051143646,
161
+ "sft_loss": 0.9439412951469421,
162
+ "step": 90
163
+ },
164
+ {
165
+ "epoch": 0.23995200959808038,
166
+ "grad_norm": 2.3029148332001745,
167
+ "learning_rate": 4.955333801831578e-06,
168
+ "logits/chosen": 0.49920982122421265,
169
+ "logits/rejected": 0.6337569355964661,
170
+ "logps/chosen": -0.8333128094673157,
171
+ "logps/rejected": -1.059599757194519,
172
+ "loss": 0.9453,
173
+ "odds_ratio_loss": 0.6517213582992554,
174
+ "rewards/accuracies": 0.5625,
175
+ "rewards/chosen": -0.041665639728307724,
176
+ "rewards/margins": 0.011314347386360168,
177
+ "rewards/rejected": -0.05297998711466789,
178
+ "sft_loss": 0.8333128094673157,
179
+ "step": 100
180
+ },
181
+ {
182
+ "epoch": 0.26394721055788845,
183
+ "grad_norm": 2.8766587489414395,
184
+ "learning_rate": 4.9424957936527295e-06,
185
+ "logits/chosen": -0.28645992279052734,
186
+ "logits/rejected": 0.04107431694865227,
187
+ "logps/chosen": -0.9429195523262024,
188
+ "logps/rejected": -0.9936224222183228,
189
+ "loss": 0.9526,
190
+ "odds_ratio_loss": 0.705885112285614,
191
+ "rewards/accuracies": 0.5249999761581421,
192
+ "rewards/chosen": -0.04714598134160042,
193
+ "rewards/margins": 0.002535139676183462,
194
+ "rewards/rejected": -0.04968111589550972,
195
+ "sft_loss": 0.9429195523262024,
196
+ "step": 110
197
+ },
198
+ {
199
+ "epoch": 0.28794241151769645,
200
+ "grad_norm": 2.1411106644617703,
201
+ "learning_rate": 4.92805928360141e-06,
202
+ "logits/chosen": -0.29608479142189026,
203
+ "logits/rejected": -0.21111997961997986,
204
+ "logps/chosen": -0.888851523399353,
205
+ "logps/rejected": -1.0842912197113037,
206
+ "loss": 0.8904,
207
+ "odds_ratio_loss": 0.5968859195709229,
208
+ "rewards/accuracies": 0.6875,
209
+ "rewards/chosen": -0.04444257169961929,
210
+ "rewards/margins": 0.009771987795829773,
211
+ "rewards/rejected": -0.054214559495449066,
212
+ "sft_loss": 0.888851523399353,
213
+ "step": 120
214
+ },
215
+ {
216
+ "epoch": 0.3119376124775045,
217
+ "grad_norm": 2.1891227152981347,
218
+ "learning_rate": 4.912033719713687e-06,
219
+ "logits/chosen": 0.49228960275650024,
220
+ "logits/rejected": 0.5680336952209473,
221
+ "logps/chosen": -0.9152839779853821,
222
+ "logps/rejected": -1.0058788061141968,
223
+ "loss": 0.9427,
224
+ "odds_ratio_loss": 0.6943625807762146,
225
+ "rewards/accuracies": 0.574999988079071,
226
+ "rewards/chosen": -0.04576420038938522,
227
+ "rewards/margins": 0.004529745317995548,
228
+ "rewards/rejected": -0.0502939410507679,
229
+ "sft_loss": 0.9152839779853821,
230
+ "step": 130
231
+ },
232
+ {
233
+ "epoch": 0.3359328134373125,
234
+ "grad_norm": 2.5131225459939,
235
+ "learning_rate": 4.894429589988739e-06,
236
+ "logits/chosen": -1.2468726634979248,
237
+ "logits/rejected": -1.0485397577285767,
238
+ "logps/chosen": -1.0104249715805054,
239
+ "logps/rejected": -1.0477244853973389,
240
+ "loss": 0.949,
241
+ "odds_ratio_loss": 0.7160865068435669,
242
+ "rewards/accuracies": 0.512499988079071,
243
+ "rewards/chosen": -0.05052124708890915,
244
+ "rewards/margins": 0.0018649749690666795,
245
+ "rewards/rejected": -0.05238622426986694,
246
+ "sft_loss": 1.0104249715805054,
247
+ "step": 140
248
+ },
249
+ {
250
+ "epoch": 0.3599280143971206,
251
+ "grad_norm": 2.696319834123575,
252
+ "learning_rate": 4.875258415524945e-06,
253
+ "logits/chosen": 0.039508234709501266,
254
+ "logits/rejected": 0.23594827950000763,
255
+ "logps/chosen": -0.904223620891571,
256
+ "logps/rejected": -1.032157063484192,
257
+ "loss": 0.9533,
258
+ "odds_ratio_loss": 0.6739581823348999,
259
+ "rewards/accuracies": 0.5625,
260
+ "rewards/chosen": -0.04521118476986885,
261
+ "rewards/margins": 0.0063966671004891396,
262
+ "rewards/rejected": -0.051607854664325714,
263
+ "sft_loss": 0.904223620891571,
264
+ "step": 150
265
+ },
266
+ {
267
+ "epoch": 0.38392321535692864,
268
+ "grad_norm": 2.241170193835809,
269
+ "learning_rate": 4.85453274297985e-06,
270
+ "logits/chosen": 0.4507044851779938,
271
+ "logits/rejected": 0.7088828682899475,
272
+ "logps/chosen": -0.9252007603645325,
273
+ "logps/rejected": -1.0105345249176025,
274
+ "loss": 0.9187,
275
+ "odds_ratio_loss": 0.6664329171180725,
276
+ "rewards/accuracies": 0.5625,
277
+ "rewards/chosen": -0.0462600402534008,
278
+ "rewards/margins": 0.004266692791134119,
279
+ "rewards/rejected": -0.050526730716228485,
280
+ "sft_loss": 0.9252007603645325,
281
+ "step": 160
282
+ },
283
+ {
284
+ "epoch": 0.40791841631673664,
285
+ "grad_norm": 1.759854296483571,
286
+ "learning_rate": 4.832266136358951e-06,
287
+ "logits/chosen": -0.12876208126544952,
288
+ "logits/rejected": 0.014335835352540016,
289
+ "logps/chosen": -0.8540490865707397,
290
+ "logps/rejected": -0.9863293766975403,
291
+ "loss": 0.926,
292
+ "odds_ratio_loss": 0.6714656352996826,
293
+ "rewards/accuracies": 0.6000000238418579,
294
+ "rewards/chosen": -0.04270245134830475,
295
+ "rewards/margins": 0.006614011712372303,
296
+ "rewards/rejected": -0.04931646212935448,
297
+ "sft_loss": 0.8540490865707397,
298
+ "step": 170
299
+ },
300
+ {
301
+ "epoch": 0.4319136172765447,
302
+ "grad_norm": 2.793191882203603,
303
+ "learning_rate": 4.808473168138675e-06,
304
+ "logits/chosen": 0.3617595136165619,
305
+ "logits/rejected": 0.3396950364112854,
306
+ "logps/chosen": -0.8613064885139465,
307
+ "logps/rejected": -1.0067331790924072,
308
+ "loss": 0.9162,
309
+ "odds_ratio_loss": 0.6582903861999512,
310
+ "rewards/accuracies": 0.5874999761581421,
311
+ "rewards/chosen": -0.04306532442569733,
312
+ "rewards/margins": 0.007271329872310162,
313
+ "rewards/rejected": -0.050336651504039764,
314
+ "sft_loss": 0.8613064885139465,
315
+ "step": 180
316
+ },
317
+ {
318
+ "epoch": 0.4559088182363527,
319
+ "grad_norm": 1.7774141067161418,
320
+ "learning_rate": 4.783169409729363e-06,
321
+ "logits/chosen": 0.9685203433036804,
322
+ "logits/rejected": 1.1009634733200073,
323
+ "logps/chosen": -0.8521540760993958,
324
+ "logps/rejected": -0.9150575399398804,
325
+ "loss": 0.9004,
326
+ "odds_ratio_loss": 0.7224193811416626,
327
+ "rewards/accuracies": 0.5375000238418579,
328
+ "rewards/chosen": -0.04260770231485367,
329
+ "rewards/margins": 0.0031451724935323,
330
+ "rewards/rejected": -0.0457528755068779,
331
+ "sft_loss": 0.8521540760993958,
332
+ "step": 190
333
+ },
334
+ {
335
+ "epoch": 0.47990401919616077,
336
+ "grad_norm": 2.052107783396207,
337
+ "learning_rate": 4.756371421284482e-06,
338
+ "logits/chosen": 0.33597105741500854,
339
+ "logits/rejected": 0.44187426567077637,
340
+ "logps/chosen": -0.8725342750549316,
341
+ "logps/rejected": -0.9003400802612305,
342
+ "loss": 0.919,
343
+ "odds_ratio_loss": 0.7135496735572815,
344
+ "rewards/accuracies": 0.574999988079071,
345
+ "rewards/chosen": -0.04362671449780464,
346
+ "rewards/margins": 0.0013902939390391111,
347
+ "rewards/rejected": -0.04501700773835182,
348
+ "sft_loss": 0.8725342750549316,
349
+ "step": 200
350
+ },
351
+ {
352
+ "epoch": 0.5038992201559688,
353
+ "grad_norm": 2.3000145040966973,
354
+ "learning_rate": 4.728096740862778e-06,
355
+ "logits/chosen": 0.16287042200565338,
356
+ "logits/rejected": 0.35098087787628174,
357
+ "logps/chosen": -0.8514264822006226,
358
+ "logps/rejected": -0.9913795590400696,
359
+ "loss": 0.9096,
360
+ "odds_ratio_loss": 0.6634506583213806,
361
+ "rewards/accuracies": 0.5874999761581421,
362
+ "rewards/chosen": -0.042571328580379486,
363
+ "rewards/margins": 0.006997650023549795,
364
+ "rewards/rejected": -0.04956897348165512,
365
+ "sft_loss": 0.8514264822006226,
366
+ "step": 210
367
+ },
368
+ {
369
+ "epoch": 0.5278944211157769,
370
+ "grad_norm": 1.581079267248328,
371
+ "learning_rate": 4.698363872950406e-06,
372
+ "logits/chosen": 0.298981636762619,
373
+ "logits/rejected": 0.49268895387649536,
374
+ "logps/chosen": -0.8895601034164429,
375
+ "logps/rejected": -1.026539921760559,
376
+ "loss": 0.8744,
377
+ "odds_ratio_loss": 0.6685082316398621,
378
+ "rewards/accuracies": 0.612500011920929,
379
+ "rewards/chosen": -0.04447800666093826,
380
+ "rewards/margins": 0.0068489923141896725,
381
+ "rewards/rejected": -0.051326997578144073,
382
+ "sft_loss": 0.8895601034164429,
383
+ "step": 220
384
+ },
385
+ {
386
+ "epoch": 0.5518896220755849,
387
+ "grad_norm": 1.7094822098553022,
388
+ "learning_rate": 4.6671922763505915e-06,
389
+ "logits/chosen": 0.34609514474868774,
390
+ "logits/rejected": 0.5052930116653442,
391
+ "logps/chosen": -0.863084614276886,
392
+ "logps/rejected": -0.9836879968643188,
393
+ "loss": 0.8905,
394
+ "odds_ratio_loss": 0.6813028454780579,
395
+ "rewards/accuracies": 0.550000011920929,
396
+ "rewards/chosen": -0.043154239654541016,
397
+ "rewards/margins": 0.006030158139765263,
398
+ "rewards/rejected": -0.049184400588274,
399
+ "sft_loss": 0.863084614276886,
400
+ "step": 230
401
+ },
402
+ {
403
+ "epoch": 0.5758848230353929,
404
+ "grad_norm": 1.9367159826113498,
405
+ "learning_rate": 4.634602351448738e-06,
406
+ "logits/chosen": 0.286350816488266,
407
+ "logits/rejected": 0.3788919448852539,
408
+ "logps/chosen": -0.8919585943222046,
409
+ "logps/rejected": -0.9452742338180542,
410
+ "loss": 0.9133,
411
+ "odds_ratio_loss": 0.6905114650726318,
412
+ "rewards/accuracies": 0.612500011920929,
413
+ "rewards/chosen": -0.04459793120622635,
414
+ "rewards/margins": 0.0026657807175070047,
415
+ "rewards/rejected": -0.04726371169090271,
416
+ "sft_loss": 0.8919585943222046,
417
+ "step": 240
418
+ },
419
+ {
420
+ "epoch": 0.5998800239952009,
421
+ "grad_norm": 2.0772847936555636,
422
+ "learning_rate": 4.6006154268613015e-06,
423
+ "logits/chosen": 0.4635019898414612,
424
+ "logits/rejected": 0.5444530248641968,
425
+ "logps/chosen": -0.8181222081184387,
426
+ "logps/rejected": -0.9908831715583801,
427
+ "loss": 0.8927,
428
+ "odds_ratio_loss": 0.6295598149299622,
429
+ "rewards/accuracies": 0.637499988079071,
430
+ "rewards/chosen": -0.04090610891580582,
431
+ "rewards/margins": 0.008638045750558376,
432
+ "rewards/rejected": -0.04954415559768677,
433
+ "sft_loss": 0.8181222081184387,
434
+ "step": 250
435
+ },
436
+ {
437
+ "epoch": 0.623875224955009,
438
+ "grad_norm": 2.084215689408855,
439
+ "learning_rate": 4.565253745477187e-06,
440
+ "logits/chosen": 0.40253886580467224,
441
+ "logits/rejected": 0.4625183045864105,
442
+ "logps/chosen": -0.9301355481147766,
443
+ "logps/rejected": -1.0306508541107178,
444
+ "loss": 0.9162,
445
+ "odds_ratio_loss": 0.6872043609619141,
446
+ "rewards/accuracies": 0.5249999761581421,
447
+ "rewards/chosen": -0.04650677740573883,
448
+ "rewards/margins": 0.005025765858590603,
449
+ "rewards/rejected": -0.05153254419565201,
450
+ "sft_loss": 0.9301355481147766,
451
+ "step": 260
452
+ },
453
+ {
454
+ "epoch": 0.647870425914817,
455
+ "grad_norm": 1.9031984888179019,
456
+ "learning_rate": 4.528540449900799e-06,
457
+ "logits/chosen": 0.4078219532966614,
458
+ "logits/rejected": 0.6789823174476624,
459
+ "logps/chosen": -0.8785255551338196,
460
+ "logps/rejected": -0.9139087796211243,
461
+ "loss": 0.9176,
462
+ "odds_ratio_loss": 0.7333613038063049,
463
+ "rewards/accuracies": 0.550000011920929,
464
+ "rewards/chosen": -0.04392627626657486,
465
+ "rewards/margins": 0.0017691642278805375,
466
+ "rewards/rejected": -0.04569543898105621,
467
+ "sft_loss": 0.8785255551338196,
468
+ "step": 270
469
+ },
470
+ {
471
+ "epoch": 0.671865626874625,
472
+ "grad_norm": 2.3067419173621113,
473
+ "learning_rate": 4.490499567306256e-06,
474
+ "logits/chosen": 0.304252564907074,
475
+ "logits/rejected": 0.5160123109817505,
476
+ "logps/chosen": -0.8951358795166016,
477
+ "logps/rejected": -0.9636558294296265,
478
+ "loss": 0.8917,
479
+ "odds_ratio_loss": 0.69621342420578,
480
+ "rewards/accuracies": 0.512499988079071,
481
+ "rewards/chosen": -0.04475679248571396,
482
+ "rewards/margins": 0.0034259993117302656,
483
+ "rewards/rejected": -0.04818279296159744,
484
+ "sft_loss": 0.8951358795166016,
485
+ "step": 280
486
+ },
487
+ {
488
+ "epoch": 0.6958608278344331,
489
+ "grad_norm": 3.1297290877323003,
490
+ "learning_rate": 4.451155993712711e-06,
491
+ "logits/chosen": 0.25184166431427,
492
+ "logits/rejected": 0.43299436569213867,
493
+ "logps/chosen": -0.808620810508728,
494
+ "logps/rejected": -0.9780584573745728,
495
+ "loss": 0.9379,
496
+ "odds_ratio_loss": 0.6151310205459595,
497
+ "rewards/accuracies": 0.675000011920929,
498
+ "rewards/chosen": -0.04043104499578476,
499
+ "rewards/margins": 0.008471880108118057,
500
+ "rewards/rejected": -0.048902928829193115,
501
+ "sft_loss": 0.808620810508728,
502
+ "step": 290
503
+ },
504
+ {
505
+ "epoch": 0.7198560287942412,
506
+ "grad_norm": 2.001570442654457,
507
+ "learning_rate": 4.410535477691041e-06,
508
+ "logits/chosen": 0.6736063957214355,
509
+ "logits/rejected": 0.8922637104988098,
510
+ "logps/chosen": -0.8743098974227905,
511
+ "logps/rejected": -1.0198915004730225,
512
+ "loss": 0.8962,
513
+ "odds_ratio_loss": 0.6545746326446533,
514
+ "rewards/accuracies": 0.625,
515
+ "rewards/chosen": -0.043715499341487885,
516
+ "rewards/margins": 0.0072790831327438354,
517
+ "rewards/rejected": -0.05099458247423172,
518
+ "sft_loss": 0.8743098974227905,
519
+ "step": 300
520
+ },
521
+ {
522
+ "epoch": 0.7438512297540492,
523
+ "grad_norm": 3.088640251108737,
524
+ "learning_rate": 4.368664603512586e-06,
525
+ "logits/chosen": -0.10074709355831146,
526
+ "logits/rejected": 0.08682968467473984,
527
+ "logps/chosen": -0.7929955720901489,
528
+ "logps/rejected": -0.9449365735054016,
529
+ "loss": 0.8789,
530
+ "odds_ratio_loss": 0.6474851369857788,
531
+ "rewards/accuracies": 0.5874999761581421,
532
+ "rewards/chosen": -0.03964977711439133,
533
+ "rewards/margins": 0.007597046438604593,
534
+ "rewards/rejected": -0.047246821224689484,
535
+ "sft_loss": 0.7929955720901489,
536
+ "step": 310
537
+ },
538
+ {
539
+ "epoch": 0.7678464307138573,
540
+ "grad_norm": 2.278875813822025,
541
+ "learning_rate": 4.325570773750952e-06,
542
+ "logits/chosen": -0.22130906581878662,
543
+ "logits/rejected": -0.028980206698179245,
544
+ "logps/chosen": -0.8826779127120972,
545
+ "logps/rejected": -1.0213041305541992,
546
+ "loss": 0.9204,
547
+ "odds_ratio_loss": 0.6443883180618286,
548
+ "rewards/accuracies": 0.612500011920929,
549
+ "rewards/chosen": -0.04413389414548874,
550
+ "rewards/margins": 0.006931307725608349,
551
+ "rewards/rejected": -0.05106520652770996,
552
+ "sft_loss": 0.8826779127120972,
553
+ "step": 320
554
+ },
555
+ {
556
+ "epoch": 0.7918416316736653,
557
+ "grad_norm": 1.6952516043840655,
558
+ "learning_rate": 4.281282191348289e-06,
559
+ "logits/chosen": 0.45927032828330994,
560
+ "logits/rejected": 0.6593443751335144,
561
+ "logps/chosen": -0.8378440141677856,
562
+ "logps/rejected": -0.9682254791259766,
563
+ "loss": 0.8995,
564
+ "odds_ratio_loss": 0.6620376110076904,
565
+ "rewards/accuracies": 0.625,
566
+ "rewards/chosen": -0.04189220070838928,
567
+ "rewards/margins": 0.006519075483083725,
568
+ "rewards/rejected": -0.04841126874089241,
569
+ "sft_loss": 0.8378440141677856,
570
+ "step": 330
571
+ },
572
+ {
573
+ "epoch": 0.8158368326334733,
574
+ "grad_norm": 2.4806806819218794,
575
+ "learning_rate": 4.235827841157748e-06,
576
+ "logits/chosen": 0.01970214769244194,
577
+ "logits/rejected": 0.11670324951410294,
578
+ "logps/chosen": -0.8856766819953918,
579
+ "logps/rejected": -1.0817759037017822,
580
+ "loss": 0.8834,
581
+ "odds_ratio_loss": 0.6194185018539429,
582
+ "rewards/accuracies": 0.637499988079071,
583
+ "rewards/chosen": -0.04428383335471153,
584
+ "rewards/margins": 0.009804959408938885,
585
+ "rewards/rejected": -0.054088789969682693,
586
+ "sft_loss": 0.8856766819953918,
587
+ "step": 340
588
+ },
589
+ {
590
+ "epoch": 0.8398320335932813,
591
+ "grad_norm": 1.5265892877639438,
592
+ "learning_rate": 4.1892374709742186e-06,
593
+ "logits/chosen": -0.7483745813369751,
594
+ "logits/rejected": -0.42045336961746216,
595
+ "logps/chosen": -0.7948485016822815,
596
+ "logps/rejected": -0.9918915033340454,
597
+ "loss": 0.9474,
598
+ "odds_ratio_loss": 0.5842909812927246,
599
+ "rewards/accuracies": 0.637499988079071,
600
+ "rewards/chosen": -0.03974242880940437,
601
+ "rewards/margins": 0.009852146729826927,
602
+ "rewards/rejected": -0.04959457367658615,
603
+ "sft_loss": 0.7948485016822815,
604
+ "step": 350
605
+ },
606
+ {
607
+ "epoch": 0.8638272345530894,
608
+ "grad_norm": 2.1051154185205543,
609
+ "learning_rate": 4.141541572065762e-06,
610
+ "logits/chosen": 0.41192498803138733,
611
+ "logits/rejected": 0.5341157913208008,
612
+ "logps/chosen": -0.7971394658088684,
613
+ "logps/rejected": -0.9216561317443848,
614
+ "loss": 0.8881,
615
+ "odds_ratio_loss": 0.69920814037323,
616
+ "rewards/accuracies": 0.5249999761581421,
617
+ "rewards/chosen": -0.03985697776079178,
618
+ "rewards/margins": 0.0062258280813694,
619
+ "rewards/rejected": -0.04608280584216118,
620
+ "sft_loss": 0.7971394658088684,
621
+ "step": 360
622
+ },
623
+ {
624
+ "epoch": 0.8878224355128974,
625
+ "grad_norm": 2.049071087536336,
626
+ "learning_rate": 4.092771359218462e-06,
627
+ "logits/chosen": 0.2649831771850586,
628
+ "logits/rejected": 0.45568495988845825,
629
+ "logps/chosen": -0.8466150164604187,
630
+ "logps/rejected": -1.0025365352630615,
631
+ "loss": 0.9065,
632
+ "odds_ratio_loss": 0.629971444606781,
633
+ "rewards/accuracies": 0.625,
634
+ "rewards/chosen": -0.042330749332904816,
635
+ "rewards/margins": 0.007796071469783783,
636
+ "rewards/rejected": -0.0501268208026886,
637
+ "sft_loss": 0.8466150164604187,
638
+ "step": 370
639
+ },
640
+ {
641
+ "epoch": 0.9118176364727054,
642
+ "grad_norm": 3.597524104140319,
643
+ "learning_rate": 4.04295875030778e-06,
644
+ "logits/chosen": -0.18752217292785645,
645
+ "logits/rejected": 0.15378537774085999,
646
+ "logps/chosen": -0.8704308271408081,
647
+ "logps/rejected": -0.9513336420059204,
648
+ "loss": 0.9014,
649
+ "odds_ratio_loss": 0.6948253512382507,
650
+ "rewards/accuracies": 0.574999988079071,
651
+ "rewards/chosen": -0.043521542102098465,
652
+ "rewards/margins": 0.004045139066874981,
653
+ "rewards/rejected": -0.04756668210029602,
654
+ "sft_loss": 0.8704308271408081,
655
+ "step": 380
656
+ },
657
+ {
658
+ "epoch": 0.9358128374325135,
659
+ "grad_norm": 3.1405630532603395,
660
+ "learning_rate": 3.992136345409765e-06,
661
+ "logits/chosen": -0.1735876053571701,
662
+ "logits/rejected": -0.20124337077140808,
663
+ "logps/chosen": -0.9253339767456055,
664
+ "logps/rejected": -1.0305973291397095,
665
+ "loss": 0.9111,
666
+ "odds_ratio_loss": 0.6636070013046265,
667
+ "rewards/accuracies": 0.6000000238418579,
668
+ "rewards/chosen": -0.04626670479774475,
669
+ "rewards/margins": 0.005263164173811674,
670
+ "rewards/rejected": -0.051529865711927414,
671
+ "sft_loss": 0.9253339767456055,
672
+ "step": 390
673
+ },
674
+ {
675
+ "epoch": 0.9598080383923215,
676
+ "grad_norm": 2.4716790122788983,
677
+ "learning_rate": 3.940337405465786e-06,
678
+ "logits/chosen": 0.26361703872680664,
679
+ "logits/rejected": 0.44345617294311523,
680
+ "logps/chosen": -0.8355854153633118,
681
+ "logps/rejected": -1.0225704908370972,
682
+ "loss": 0.9062,
683
+ "odds_ratio_loss": 0.6545855402946472,
684
+ "rewards/accuracies": 0.5874999761581421,
685
+ "rewards/chosen": -0.04177927225828171,
686
+ "rewards/margins": 0.009349259547889233,
687
+ "rewards/rejected": -0.05112852901220322,
688
+ "sft_loss": 0.8355854153633118,
689
+ "step": 400
690
+ },
691
+ {
692
+ "epoch": 0.9838032393521295,
693
+ "grad_norm": 2.3985102639359406,
694
+ "learning_rate": 3.887595830514775e-06,
695
+ "logits/chosen": 0.21671700477600098,
696
+ "logits/rejected": 0.29912179708480835,
697
+ "logps/chosen": -0.809670090675354,
698
+ "logps/rejected": -1.0107569694519043,
699
+ "loss": 0.9029,
700
+ "odds_ratio_loss": 0.6326887011528015,
701
+ "rewards/accuracies": 0.612500011920929,
702
+ "rewards/chosen": -0.0404835119843483,
703
+ "rewards/margins": 0.010054344311356544,
704
+ "rewards/rejected": -0.05053785443305969,
705
+ "sft_loss": 0.809670090675354,
706
+ "step": 410
707
+ },
708
+ {
709
+ "epoch": 1.0077984403119375,
710
+ "grad_norm": 1.6971594247197401,
711
+ "learning_rate": 3.833946137507195e-06,
712
+ "logits/chosen": 0.4990086555480957,
713
+ "logits/rejected": 0.616361141204834,
714
+ "logps/chosen": -0.8005359768867493,
715
+ "logps/rejected": -0.9603840708732605,
716
+ "loss": 0.8398,
717
+ "odds_ratio_loss": 0.6354148387908936,
718
+ "rewards/accuracies": 0.5249999761581421,
719
+ "rewards/chosen": -0.040026795119047165,
720
+ "rewards/margins": 0.007992411032319069,
721
+ "rewards/rejected": -0.04801920801401138,
722
+ "sft_loss": 0.8005359768867493,
723
+ "step": 420
724
+ },
725
+ {
726
+ "epoch": 1.0317936412717457,
727
+ "grad_norm": 2.2002987962167904,
728
+ "learning_rate": 3.779423437715274e-06,
729
+ "logits/chosen": 0.7601526975631714,
730
+ "logits/rejected": 0.8180352449417114,
731
+ "logps/chosen": -0.6671024560928345,
732
+ "logps/rejected": -0.9577730298042297,
733
+ "loss": 0.7742,
734
+ "odds_ratio_loss": 0.5807942152023315,
735
+ "rewards/accuracies": 0.675000011920929,
736
+ "rewards/chosen": -0.03335512429475784,
737
+ "rewards/margins": 0.014533529989421368,
738
+ "rewards/rejected": -0.047888655215501785,
739
+ "sft_loss": 0.6671024560928345,
740
+ "step": 430
741
+ },
742
+ {
743
+ "epoch": 1.0557888422315538,
744
+ "grad_norm": 1.5148819350515028,
745
+ "learning_rate": 3.7240634137542864e-06,
746
+ "logits/chosen": 0.19566980004310608,
747
+ "logits/rejected": 0.3528198003768921,
748
+ "logps/chosen": -0.6874720454216003,
749
+ "logps/rejected": -1.0558958053588867,
750
+ "loss": 0.7663,
751
+ "odds_ratio_loss": 0.48211669921875,
752
+ "rewards/accuracies": 0.800000011920929,
753
+ "rewards/chosen": -0.034373603761196136,
754
+ "rewards/margins": 0.01842118799686432,
755
+ "rewards/rejected": -0.052794791758060455,
756
+ "sft_loss": 0.6874720454216003,
757
+ "step": 440
758
+ },
759
+ {
760
+ "epoch": 1.0797840431913617,
761
+ "grad_norm": 1.6130353172110996,
762
+ "learning_rate": 3.6679022962299054e-06,
763
+ "logits/chosen": 0.8750432133674622,
764
+ "logits/rejected": 0.8553866147994995,
765
+ "logps/chosen": -0.7515122890472412,
766
+ "logps/rejected": -0.9563247561454773,
767
+ "loss": 0.7745,
768
+ "odds_ratio_loss": 0.5920617580413818,
769
+ "rewards/accuracies": 0.6499999761581421,
770
+ "rewards/chosen": -0.037575613707304,
771
+ "rewards/margins": 0.010240620002150536,
772
+ "rewards/rejected": -0.047816235572099686,
773
+ "sft_loss": 0.7515122890472412,
774
+ "step": 450
775
+ },
776
+ {
777
+ "epoch": 1.1037792441511698,
778
+ "grad_norm": 1.8444047185661667,
779
+ "learning_rate": 3.6109768400269336e-06,
780
+ "logits/chosen": 0.21664266288280487,
781
+ "logits/rejected": 0.3455556333065033,
782
+ "logps/chosen": -0.7820109128952026,
783
+ "logps/rejected": -1.1722263097763062,
784
+ "loss": 0.7949,
785
+ "odds_ratio_loss": 0.5249099731445312,
786
+ "rewards/accuracies": 0.762499988079071,
787
+ "rewards/chosen": -0.03910055011510849,
788
+ "rewards/margins": 0.019510772079229355,
789
+ "rewards/rejected": -0.05861131474375725,
790
+ "sft_loss": 0.7820109128952026,
791
+ "step": 460
792
+ },
793
+ {
794
+ "epoch": 1.127774445110978,
795
+ "grad_norm": 1.923809039800638,
796
+ "learning_rate": 3.5533243002549044e-06,
797
+ "logits/chosen": -0.051299355924129486,
798
+ "logits/rejected": 0.12599964439868927,
799
+ "logps/chosen": -0.6766480803489685,
800
+ "logps/rejected": -0.9556339979171753,
801
+ "loss": 0.769,
802
+ "odds_ratio_loss": 0.5771059989929199,
803
+ "rewards/accuracies": 0.6499999761581421,
804
+ "rewards/chosen": -0.03383240848779678,
805
+ "rewards/margins": 0.013949294574558735,
806
+ "rewards/rejected": -0.047781698405742645,
807
+ "sft_loss": 0.6766480803489685,
808
+ "step": 470
809
+ },
810
+ {
811
+ "epoch": 1.1517696460707858,
812
+ "grad_norm": 2.0416324249302593,
813
+ "learning_rate": 3.4949824078663214e-06,
814
+ "logits/chosen": 0.3260158598423004,
815
+ "logits/rejected": 0.4627075791358948,
816
+ "logps/chosen": -0.6955934762954712,
817
+ "logps/rejected": -1.0405316352844238,
818
+ "loss": 0.7744,
819
+ "odds_ratio_loss": 0.5207543969154358,
820
+ "rewards/accuracies": 0.7124999761581421,
821
+ "rewards/chosen": -0.03477967530488968,
822
+ "rewards/margins": 0.017246905714273453,
823
+ "rewards/rejected": -0.05202658101916313,
824
+ "sft_loss": 0.6955934762954712,
825
+ "step": 480
826
+ },
827
+ {
828
+ "epoch": 1.175764847030594,
829
+ "grad_norm": 2.159701142475688,
830
+ "learning_rate": 3.4359893449634713e-06,
831
+ "logits/chosen": 0.10285909473896027,
832
+ "logits/rejected": 0.18586108088493347,
833
+ "logps/chosen": -0.7835036516189575,
834
+ "logps/rejected": -0.9662873148918152,
835
+ "loss": 0.7699,
836
+ "odds_ratio_loss": 0.6257883310317993,
837
+ "rewards/accuracies": 0.612500011920929,
838
+ "rewards/chosen": -0.03917517885565758,
839
+ "rewards/margins": 0.009139184840023518,
840
+ "rewards/rejected": -0.04831436648964882,
841
+ "sft_loss": 0.7835036516189575,
842
+ "step": 490
843
+ },
844
+ {
845
+ "epoch": 1.1997600479904018,
846
+ "grad_norm": 1.905386181833648,
847
+ "learning_rate": 3.3763837198099807e-06,
848
+ "logits/chosen": 0.2618166208267212,
849
+ "logits/rejected": 0.403994083404541,
850
+ "logps/chosen": -0.7472913861274719,
851
+ "logps/rejected": -0.9723391532897949,
852
+ "loss": 0.8034,
853
+ "odds_ratio_loss": 0.5758217573165894,
854
+ "rewards/accuracies": 0.7250000238418579,
855
+ "rewards/chosen": -0.03736456483602524,
856
+ "rewards/margins": 0.011252395808696747,
857
+ "rewards/rejected": -0.048616960644721985,
858
+ "sft_loss": 0.7472913861274719,
859
+ "step": 500
860
+ },
861
+ {
862
+ "epoch": 1.22375524895021,
863
+ "grad_norm": 1.8483335773730425,
864
+ "learning_rate": 3.3162045415634793e-06,
865
+ "logits/chosen": -0.06936601549386978,
866
+ "logits/rejected": 0.15932008624076843,
867
+ "logps/chosen": -0.7298214435577393,
868
+ "logps/rejected": -0.989848792552948,
869
+ "loss": 0.764,
870
+ "odds_ratio_loss": 0.5586143136024475,
871
+ "rewards/accuracies": 0.6875,
872
+ "rewards/chosen": -0.036491066217422485,
873
+ "rewards/margins": 0.013001373037695885,
874
+ "rewards/rejected": -0.04949244111776352,
875
+ "sft_loss": 0.7298214435577393,
876
+ "step": 510
877
+ },
878
+ {
879
+ "epoch": 1.247750449910018,
880
+ "grad_norm": 1.4105189905656275,
881
+ "learning_rate": 3.255491194745878e-06,
882
+ "logits/chosen": -0.0699717178940773,
883
+ "logits/rejected": 0.11926586925983429,
884
+ "logps/chosen": -0.7712666988372803,
885
+ "logps/rejected": -1.0007984638214111,
886
+ "loss": 0.7514,
887
+ "odds_ratio_loss": 0.576269805431366,
888
+ "rewards/accuracies": 0.762499988079071,
889
+ "rewards/chosen": -0.03856333717703819,
890
+ "rewards/margins": 0.011476586572825909,
891
+ "rewards/rejected": -0.050039924681186676,
892
+ "sft_loss": 0.7712666988372803,
893
+ "step": 520
894
+ },
895
+ {
896
+ "epoch": 1.2717456508698262,
897
+ "grad_norm": 1.5086406745902339,
898
+ "learning_rate": 3.1942834134680123e-06,
899
+ "logits/chosen": -0.4110763669013977,
900
+ "logits/rejected": -0.197097510099411,
901
+ "logps/chosen": -0.7337836027145386,
902
+ "logps/rejected": -1.0581499338150024,
903
+ "loss": 0.747,
904
+ "odds_ratio_loss": 0.5731949806213379,
905
+ "rewards/accuracies": 0.612500011920929,
906
+ "rewards/chosen": -0.03668918460607529,
907
+ "rewards/margins": 0.016218315809965134,
908
+ "rewards/rejected": -0.05290750414133072,
909
+ "sft_loss": 0.7337836027145386,
910
+ "step": 530
911
+ },
912
+ {
913
+ "epoch": 1.295740851829634,
914
+ "grad_norm": 2.007767969966132,
915
+ "learning_rate": 3.13262125542547e-06,
916
+ "logits/chosen": 0.24464428424835205,
917
+ "logits/rejected": 0.42607539892196655,
918
+ "logps/chosen": -0.8008230328559875,
919
+ "logps/rejected": -1.019913911819458,
920
+ "loss": 0.7839,
921
+ "odds_ratio_loss": 0.5772299766540527,
922
+ "rewards/accuracies": 0.6875,
923
+ "rewards/chosen": -0.04004114866256714,
924
+ "rewards/margins": 0.010954543016850948,
925
+ "rewards/rejected": -0.05099569633603096,
926
+ "sft_loss": 0.8008230328559875,
927
+ "step": 540
928
+ },
929
+ {
930
+ "epoch": 1.3197360527894422,
931
+ "grad_norm": 2.031522996603775,
932
+ "learning_rate": 3.0705450756826707e-06,
933
+ "logits/chosen": -0.6761570572853088,
934
+ "logits/rejected": -0.5336428880691528,
935
+ "logps/chosen": -0.7791737914085388,
936
+ "logps/rejected": -0.9758432507514954,
937
+ "loss": 0.7734,
938
+ "odds_ratio_loss": 0.5955380201339722,
939
+ "rewards/accuracies": 0.675000011920929,
940
+ "rewards/chosen": -0.03895869478583336,
941
+ "rewards/margins": 0.009833470918238163,
942
+ "rewards/rejected": -0.04879216477274895,
943
+ "sft_loss": 0.7791737914085388,
944
+ "step": 550
945
+ },
946
+ {
947
+ "epoch": 1.34373125374925,
948
+ "grad_norm": 1.8127230145286217,
949
+ "learning_rate": 3.00809550026231e-06,
950
+ "logits/chosen": 0.7122937440872192,
951
+ "logits/rejected": 0.8374090194702148,
952
+ "logps/chosen": -0.7448546290397644,
953
+ "logps/rejected": -1.0183660984039307,
954
+ "loss": 0.7313,
955
+ "odds_ratio_loss": 0.5605376362800598,
956
+ "rewards/accuracies": 0.699999988079071,
957
+ "rewards/chosen": -0.03724273294210434,
958
+ "rewards/margins": 0.01367556769400835,
959
+ "rewards/rejected": -0.050918303430080414,
960
+ "sft_loss": 0.7448546290397644,
961
+ "step": 560
962
+ },
963
+ {
964
+ "epoch": 1.3677264547090582,
965
+ "grad_norm": 1.6102410365866324,
966
+ "learning_rate": 2.9453133995574955e-06,
967
+ "logits/chosen": 0.1695878505706787,
968
+ "logits/rejected": 0.34987810254096985,
969
+ "logps/chosen": -0.7041548490524292,
970
+ "logps/rejected": -1.1295292377471924,
971
+ "loss": 0.7529,
972
+ "odds_ratio_loss": 0.5541011095046997,
973
+ "rewards/accuracies": 0.675000011920929,
974
+ "rewards/chosen": -0.03520774096250534,
975
+ "rewards/margins": 0.02126871421933174,
976
+ "rewards/rejected": -0.05647646263241768,
977
+ "sft_loss": 0.7041548490524292,
978
+ "step": 570
979
+ },
980
+ {
981
+ "epoch": 1.3917216556688663,
982
+ "grad_norm": 2.0516481147792964,
983
+ "learning_rate": 2.8822398615839337e-06,
984
+ "logits/chosen": -0.15236589312553406,
985
+ "logits/rejected": 0.005555987358093262,
986
+ "logps/chosen": -0.7019264698028564,
987
+ "logps/rejected": -0.9463084936141968,
988
+ "loss": 0.7377,
989
+ "odds_ratio_loss": 0.5546727180480957,
990
+ "rewards/accuracies": 0.7250000238418579,
991
+ "rewards/chosen": -0.03509632498025894,
992
+ "rewards/margins": 0.012219103053212166,
993
+ "rewards/rejected": -0.04731542617082596,
994
+ "sft_loss": 0.7019264698028564,
995
+ "step": 580
996
+ },
997
+ {
998
+ "epoch": 1.4157168566286742,
999
+ "grad_norm": 2.5703275268486463,
1000
+ "learning_rate": 2.8189161650897045e-06,
1001
+ "logits/chosen": 0.09915417432785034,
1002
+ "logits/rejected": 0.2876579761505127,
1003
+ "logps/chosen": -0.7416352033615112,
1004
+ "logps/rejected": -0.9542354345321655,
1005
+ "loss": 0.7748,
1006
+ "odds_ratio_loss": 0.5765627026557922,
1007
+ "rewards/accuracies": 0.625,
1008
+ "rewards/chosen": -0.0370817631483078,
1009
+ "rewards/margins": 0.010630009695887566,
1010
+ "rewards/rejected": -0.04771176725625992,
1011
+ "sft_loss": 0.7416352033615112,
1012
+ "step": 590
1013
+ },
1014
+ {
1015
+ "epoch": 1.4397120575884823,
1016
+ "grad_norm": 1.6574957139548097,
1017
+ "learning_rate": 2.7553837525402095e-06,
1018
+ "logits/chosen": 0.14950448274612427,
1019
+ "logits/rejected": 0.14670611917972565,
1020
+ "logps/chosen": -0.7459922432899475,
1021
+ "logps/rejected": -0.9438718557357788,
1022
+ "loss": 0.764,
1023
+ "odds_ratio_loss": 0.6029990911483765,
1024
+ "rewards/accuracies": 0.6000000238418579,
1025
+ "rewards/chosen": -0.037299610674381256,
1026
+ "rewards/margins": 0.009893985465168953,
1027
+ "rewards/rejected": -0.04719359427690506,
1028
+ "sft_loss": 0.7459922432899475,
1029
+ "step": 600
1030
+ },
1031
+ {
1032
+ "epoch": 1.4637072585482904,
1033
+ "grad_norm": 1.5955732799355493,
1034
+ "learning_rate": 2.691684202995966e-06,
1035
+ "logits/chosen": 0.43530672788619995,
1036
+ "logits/rejected": 0.4994083344936371,
1037
+ "logps/chosen": -0.8142836689949036,
1038
+ "logps/rejected": -0.9706009030342102,
1039
+ "loss": 0.7559,
1040
+ "odds_ratio_loss": 0.7006958723068237,
1041
+ "rewards/accuracies": 0.574999988079071,
1042
+ "rewards/chosen": -0.04071418568491936,
1043
+ "rewards/margins": 0.007815859280526638,
1044
+ "rewards/rejected": -0.04853004962205887,
1045
+ "sft_loss": 0.8142836689949036,
1046
+ "step": 610
1047
+ },
1048
+ {
1049
+ "epoch": 1.4877024595080983,
1050
+ "grad_norm": 1.9589861397245603,
1051
+ "learning_rate": 2.6278592049010204e-06,
1052
+ "logits/chosen": -0.19675548374652863,
1053
+ "logits/rejected": -0.004504656884819269,
1054
+ "logps/chosen": -0.7537368535995483,
1055
+ "logps/rejected": -1.0135046243667603,
1056
+ "loss": 0.7741,
1057
+ "odds_ratio_loss": 0.5691729187965393,
1058
+ "rewards/accuracies": 0.6625000238418579,
1059
+ "rewards/chosen": -0.03768684342503548,
1060
+ "rewards/margins": 0.012988388538360596,
1061
+ "rewards/rejected": -0.050675224512815475,
1062
+ "sft_loss": 0.7537368535995483,
1063
+ "step": 620
1064
+ },
1065
+ {
1066
+ "epoch": 1.5116976604679064,
1067
+ "grad_norm": 1.7255875955000524,
1068
+ "learning_rate": 2.5639505287997584e-06,
1069
+ "logits/chosen": 0.3145737051963806,
1070
+ "logits/rejected": 0.47394928336143494,
1071
+ "logps/chosen": -0.7314926385879517,
1072
+ "logps/rejected": -1.001952886581421,
1073
+ "loss": 0.7829,
1074
+ "odds_ratio_loss": 0.5629433393478394,
1075
+ "rewards/accuracies": 0.675000011920929,
1076
+ "rewards/chosen": -0.03657463565468788,
1077
+ "rewards/margins": 0.013523015193641186,
1078
+ "rewards/rejected": -0.050097644329071045,
1079
+ "sft_loss": 0.7314926385879517,
1080
+ "step": 630
1081
+ },
1082
+ {
1083
+ "epoch": 1.5356928614277146,
1084
+ "grad_norm": 2.504847023988975,
1085
+ "learning_rate": 2.5e-06,
1086
+ "logits/chosen": 0.2320265769958496,
1087
+ "logits/rejected": 0.3284027874469757,
1088
+ "logps/chosen": -0.7656562924385071,
1089
+ "logps/rejected": -1.076923131942749,
1090
+ "loss": 0.7503,
1091
+ "odds_ratio_loss": 0.584337592124939,
1092
+ "rewards/accuracies": 0.7250000238418579,
1093
+ "rewards/chosen": -0.038282815366983414,
1094
+ "rewards/margins": 0.015563338994979858,
1095
+ "rewards/rejected": -0.053846150636672974,
1096
+ "sft_loss": 0.7656562924385071,
1097
+ "step": 640
1098
+ },
1099
+ {
1100
+ "epoch": 1.5596880623875224,
1101
+ "grad_norm": 1.4394266237384084,
1102
+ "learning_rate": 2.436049471200242e-06,
1103
+ "logits/chosen": -0.5206400156021118,
1104
+ "logits/rejected": -0.38631540536880493,
1105
+ "logps/chosen": -0.8094362020492554,
1106
+ "logps/rejected": -0.9923938512802124,
1107
+ "loss": 0.7752,
1108
+ "odds_ratio_loss": 0.5967071056365967,
1109
+ "rewards/accuracies": 0.6499999761581421,
1110
+ "rewards/chosen": -0.04047181457281113,
1111
+ "rewards/margins": 0.00914788618683815,
1112
+ "rewards/rejected": -0.04961969703435898,
1113
+ "sft_loss": 0.8094362020492554,
1114
+ "step": 650
1115
+ },
1116
+ {
1117
+ "epoch": 1.5836832633473306,
1118
+ "grad_norm": 1.7625452374002906,
1119
+ "learning_rate": 2.3721407950989804e-06,
1120
+ "logits/chosen": -0.24351301789283752,
1121
+ "logits/rejected": -0.07003232091665268,
1122
+ "logps/chosen": -0.6876959800720215,
1123
+ "logps/rejected": -0.9035342335700989,
1124
+ "loss": 0.7734,
1125
+ "odds_ratio_loss": 0.5917103290557861,
1126
+ "rewards/accuracies": 0.637499988079071,
1127
+ "rewards/chosen": -0.034384798258543015,
1128
+ "rewards/margins": 0.010791914537549019,
1129
+ "rewards/rejected": -0.045176707208156586,
1130
+ "sft_loss": 0.6876959800720215,
1131
+ "step": 660
1132
+ },
1133
+ {
1134
+ "epoch": 1.6076784643071385,
1135
+ "grad_norm": 1.6046093499190943,
1136
+ "learning_rate": 2.3083157970040344e-06,
1137
+ "logits/chosen": 0.5633162260055542,
1138
+ "logits/rejected": 0.6462755799293518,
1139
+ "logps/chosen": -0.7524802684783936,
1140
+ "logps/rejected": -1.0558850765228271,
1141
+ "loss": 0.7563,
1142
+ "odds_ratio_loss": 0.552274227142334,
1143
+ "rewards/accuracies": 0.699999988079071,
1144
+ "rewards/chosen": -0.03762401267886162,
1145
+ "rewards/margins": 0.015170246362686157,
1146
+ "rewards/rejected": -0.05279426649212837,
1147
+ "sft_loss": 0.7524802684783936,
1148
+ "step": 670
1149
+ },
1150
+ {
1151
+ "epoch": 1.6316736652669466,
1152
+ "grad_norm": 2.117352018263469,
1153
+ "learning_rate": 2.2446162474597913e-06,
1154
+ "logits/chosen": 0.43944865465164185,
1155
+ "logits/rejected": 0.5002392530441284,
1156
+ "logps/chosen": -0.7501770257949829,
1157
+ "logps/rejected": -0.9691005945205688,
1158
+ "loss": 0.7699,
1159
+ "odds_ratio_loss": 0.5791727304458618,
1160
+ "rewards/accuracies": 0.6625000238418579,
1161
+ "rewards/chosen": -0.037508852779865265,
1162
+ "rewards/margins": 0.010946177877485752,
1163
+ "rewards/rejected": -0.04845503345131874,
1164
+ "sft_loss": 0.7501770257949829,
1165
+ "step": 680
1166
+ },
1167
+ {
1168
+ "epoch": 1.6556688662267547,
1169
+ "grad_norm": 1.6685249776962552,
1170
+ "learning_rate": 2.1810838349102963e-06,
1171
+ "logits/chosen": 0.16153453290462494,
1172
+ "logits/rejected": 0.20878514647483826,
1173
+ "logps/chosen": -0.7516240477561951,
1174
+ "logps/rejected": -1.0250643491744995,
1175
+ "loss": 0.7666,
1176
+ "odds_ratio_loss": 0.5872852206230164,
1177
+ "rewards/accuracies": 0.6875,
1178
+ "rewards/chosen": -0.03758120536804199,
1179
+ "rewards/margins": 0.013672016561031342,
1180
+ "rewards/rejected": -0.051253218203783035,
1181
+ "sft_loss": 0.7516240477561951,
1182
+ "step": 690
1183
+ },
1184
+ {
1185
+ "epoch": 1.6796640671865628,
1186
+ "grad_norm": 2.782782057649718,
1187
+ "learning_rate": 2.117760138416067e-06,
1188
+ "logits/chosen": 0.24376201629638672,
1189
+ "logits/rejected": 0.44258540868759155,
1190
+ "logps/chosen": -0.6985687017440796,
1191
+ "logps/rejected": -1.0050299167633057,
1192
+ "loss": 0.7614,
1193
+ "odds_ratio_loss": 0.543103814125061,
1194
+ "rewards/accuracies": 0.737500011920929,
1195
+ "rewards/chosen": -0.03492843732237816,
1196
+ "rewards/margins": 0.015323063358664513,
1197
+ "rewards/rejected": -0.05025150254368782,
1198
+ "sft_loss": 0.6985687017440796,
1199
+ "step": 700
1200
+ },
1201
+ {
1202
+ "epoch": 1.7036592681463707,
1203
+ "grad_norm": 1.5369658154698735,
1204
+ "learning_rate": 2.0546866004425053e-06,
1205
+ "logits/chosen": 0.3964254558086395,
1206
+ "logits/rejected": 0.4900701642036438,
1207
+ "logps/chosen": -0.7590494155883789,
1208
+ "logps/rejected": -1.2440413236618042,
1209
+ "loss": 0.7652,
1210
+ "odds_ratio_loss": 0.5372438430786133,
1211
+ "rewards/accuracies": 0.699999988079071,
1212
+ "rewards/chosen": -0.037952471524477005,
1213
+ "rewards/margins": 0.024249596521258354,
1214
+ "rewards/rejected": -0.06220207363367081,
1215
+ "sft_loss": 0.7590494155883789,
1216
+ "step": 710
1217
+ },
1218
+ {
1219
+ "epoch": 1.7276544691061788,
1220
+ "grad_norm": 1.9970193945029362,
1221
+ "learning_rate": 1.9919044997376906e-06,
1222
+ "logits/chosen": 0.6031176447868347,
1223
+ "logits/rejected": 0.7783833742141724,
1224
+ "logps/chosen": -0.7290822267532349,
1225
+ "logps/rejected": -1.021554946899414,
1226
+ "loss": 0.7176,
1227
+ "odds_ratio_loss": 0.557815432548523,
1228
+ "rewards/accuracies": 0.6625000238418579,
1229
+ "rewards/chosen": -0.03645411133766174,
1230
+ "rewards/margins": 0.014623639173805714,
1231
+ "rewards/rejected": -0.051077745854854584,
1232
+ "sft_loss": 0.7290822267532349,
1233
+ "step": 720
1234
+ },
1235
+ {
1236
+ "epoch": 1.7516496700659867,
1237
+ "grad_norm": 2.558147455560064,
1238
+ "learning_rate": 1.9294549243173306e-06,
1239
+ "logits/chosen": -0.027294237166643143,
1240
+ "logits/rejected": 0.11035363376140594,
1241
+ "logps/chosen": -0.7765438556671143,
1242
+ "logps/rejected": -1.0300321578979492,
1243
+ "loss": 0.7771,
1244
+ "odds_ratio_loss": 0.5954040884971619,
1245
+ "rewards/accuracies": 0.637499988079071,
1246
+ "rewards/chosen": -0.03882719203829765,
1247
+ "rewards/margins": 0.012674416415393353,
1248
+ "rewards/rejected": -0.05150160938501358,
1249
+ "sft_loss": 0.7765438556671143,
1250
+ "step": 730
1251
+ },
1252
+ {
1253
+ "epoch": 1.7756448710257948,
1254
+ "grad_norm": 2.346615273317464,
1255
+ "learning_rate": 1.8673787445745298e-06,
1256
+ "logits/chosen": -0.449845552444458,
1257
+ "logits/rejected": -0.3746832311153412,
1258
+ "logps/chosen": -0.7114017605781555,
1259
+ "logps/rejected": -0.928491473197937,
1260
+ "loss": 0.7699,
1261
+ "odds_ratio_loss": 0.5795110464096069,
1262
+ "rewards/accuracies": 0.6625000238418579,
1263
+ "rewards/chosen": -0.035570088773965836,
1264
+ "rewards/margins": 0.010854486376047134,
1265
+ "rewards/rejected": -0.04642457515001297,
1266
+ "sft_loss": 0.7114017605781555,
1267
+ "step": 740
1268
+ },
1269
+ {
1270
+ "epoch": 1.799640071985603,
1271
+ "grad_norm": 1.995371230537378,
1272
+ "learning_rate": 1.805716586531988e-06,
1273
+ "logits/chosen": -0.13443303108215332,
1274
+ "logits/rejected": 0.014731263741850853,
1275
+ "logps/chosen": -0.8079891204833984,
1276
+ "logps/rejected": -1.0810317993164062,
1277
+ "loss": 0.7825,
1278
+ "odds_ratio_loss": 0.6112096309661865,
1279
+ "rewards/accuracies": 0.6875,
1280
+ "rewards/chosen": -0.0403994545340538,
1281
+ "rewards/margins": 0.013652140274643898,
1282
+ "rewards/rejected": -0.05405158922076225,
1283
+ "sft_loss": 0.8079891204833984,
1284
+ "step": 750
1285
+ },
1286
+ {
1287
+ "epoch": 1.823635272945411,
1288
+ "grad_norm": 1.8742057389590454,
1289
+ "learning_rate": 1.7445088052541218e-06,
1290
+ "logits/chosen": 0.046121031045913696,
1291
+ "logits/rejected": 0.1955467015504837,
1292
+ "logps/chosen": -0.7093559503555298,
1293
+ "logps/rejected": -1.0484099388122559,
1294
+ "loss": 0.7617,
1295
+ "odds_ratio_loss": 0.5657014846801758,
1296
+ "rewards/accuracies": 0.6499999761581421,
1297
+ "rewards/chosen": -0.03546779602766037,
1298
+ "rewards/margins": 0.016952697187662125,
1299
+ "rewards/rejected": -0.05242049694061279,
1300
+ "sft_loss": 0.7093559503555298,
1301
+ "step": 760
1302
+ },
1303
+ {
1304
+ "epoch": 1.847630473905219,
1305
+ "grad_norm": 1.2680203881504901,
1306
+ "learning_rate": 1.6837954584365217e-06,
1307
+ "logits/chosen": 0.4459083080291748,
1308
+ "logits/rejected": 0.5636454224586487,
1309
+ "logps/chosen": -0.7526987195014954,
1310
+ "logps/rejected": -1.009804606437683,
1311
+ "loss": 0.7871,
1312
+ "odds_ratio_loss": 0.5556772947311401,
1313
+ "rewards/accuracies": 0.737500011920929,
1314
+ "rewards/chosen": -0.03763493150472641,
1315
+ "rewards/margins": 0.012855296023190022,
1316
+ "rewards/rejected": -0.050490230321884155,
1317
+ "sft_loss": 0.7526987195014954,
1318
+ "step": 770
1319
+ },
1320
+ {
1321
+ "epoch": 1.8716256748650268,
1322
+ "grad_norm": 1.9254646582677224,
1323
+ "learning_rate": 1.6236162801900191e-06,
1324
+ "logits/chosen": -0.10451897233724594,
1325
+ "logits/rejected": 0.3060254156589508,
1326
+ "logps/chosen": -0.6585639715194702,
1327
+ "logps/rejected": -0.9869001507759094,
1328
+ "loss": 0.71,
1329
+ "odds_ratio_loss": 0.4942260682582855,
1330
+ "rewards/accuracies": 0.762499988079071,
1331
+ "rewards/chosen": -0.03292820230126381,
1332
+ "rewards/margins": 0.016416804865002632,
1333
+ "rewards/rejected": -0.04934500530362129,
1334
+ "sft_loss": 0.6585639715194702,
1335
+ "step": 780
1336
+ },
1337
+ {
1338
+ "epoch": 1.895620875824835,
1339
+ "grad_norm": 1.9904836511656812,
1340
+ "learning_rate": 1.5640106550365298e-06,
1341
+ "logits/chosen": 0.11656351387500763,
1342
+ "logits/rejected": 0.29824742674827576,
1343
+ "logps/chosen": -0.7831540703773499,
1344
+ "logps/rejected": -1.0284688472747803,
1345
+ "loss": 0.7758,
1346
+ "odds_ratio_loss": 0.5839165449142456,
1347
+ "rewards/accuracies": 0.6875,
1348
+ "rewards/chosen": -0.03915770351886749,
1349
+ "rewards/margins": 0.01226573996245861,
1350
+ "rewards/rejected": -0.051423441618680954,
1351
+ "sft_loss": 0.7831540703773499,
1352
+ "step": 790
1353
+ },
1354
+ {
1355
+ "epoch": 1.919616076784643,
1356
+ "grad_norm": 1.7061927534288226,
1357
+ "learning_rate": 1.5050175921336797e-06,
1358
+ "logits/chosen": 0.14354857802391052,
1359
+ "logits/rejected": 0.27334246039390564,
1360
+ "logps/chosen": -0.7474446892738342,
1361
+ "logps/rejected": -0.9480558633804321,
1362
+ "loss": 0.7575,
1363
+ "odds_ratio_loss": 0.6441240310668945,
1364
+ "rewards/accuracies": 0.637499988079071,
1365
+ "rewards/chosen": -0.03737223893404007,
1366
+ "rewards/margins": 0.010030550882220268,
1367
+ "rewards/rejected": -0.04740279167890549,
1368
+ "sft_loss": 0.7474446892738342,
1369
+ "step": 800
1370
+ },
1371
+ {
1372
+ "epoch": 1.9436112777444512,
1373
+ "grad_norm": 2.251879648695612,
1374
+ "learning_rate": 1.446675699745097e-06,
1375
+ "logits/chosen": 0.25183239579200745,
1376
+ "logits/rejected": 0.38326969742774963,
1377
+ "logps/chosen": -0.7823570966720581,
1378
+ "logps/rejected": -0.9946805238723755,
1379
+ "loss": 0.8037,
1380
+ "odds_ratio_loss": 0.6080455183982849,
1381
+ "rewards/accuracies": 0.6499999761581421,
1382
+ "rewards/chosen": -0.03911786153912544,
1383
+ "rewards/margins": 0.010616169311106205,
1384
+ "rewards/rejected": -0.049734026193618774,
1385
+ "sft_loss": 0.7823570966720581,
1386
+ "step": 810
1387
+ },
1388
+ {
1389
+ "epoch": 1.9676064787042593,
1390
+ "grad_norm": 1.9391362449031262,
1391
+ "learning_rate": 1.3890231599730674e-06,
1392
+ "logits/chosen": 0.31725913286209106,
1393
+ "logits/rejected": 0.5106421709060669,
1394
+ "logps/chosen": -0.7221857309341431,
1395
+ "logps/rejected": -0.9829575419425964,
1396
+ "loss": 0.7904,
1397
+ "odds_ratio_loss": 0.5538625121116638,
1398
+ "rewards/accuracies": 0.737500011920929,
1399
+ "rewards/chosen": -0.03610928729176521,
1400
+ "rewards/margins": 0.013038587756454945,
1401
+ "rewards/rejected": -0.049147870391607285,
1402
+ "sft_loss": 0.7221857309341431,
1403
+ "step": 820
1404
+ },
1405
+ {
1406
+ "epoch": 1.9916016796640672,
1407
+ "grad_norm": 1.5457295502049215,
1408
+ "learning_rate": 1.3320977037700952e-06,
1409
+ "logits/chosen": 0.8291665315628052,
1410
+ "logits/rejected": 1.1122350692749023,
1411
+ "logps/chosen": -0.6864774227142334,
1412
+ "logps/rejected": -1.0247427225112915,
1413
+ "loss": 0.7452,
1414
+ "odds_ratio_loss": 0.49447354674339294,
1415
+ "rewards/accuracies": 0.75,
1416
+ "rewards/chosen": -0.03432386741042137,
1417
+ "rewards/margins": 0.016913266852498055,
1418
+ "rewards/rejected": -0.051237136125564575,
1419
+ "sft_loss": 0.6864774227142334,
1420
+ "step": 830
1421
+ },
1422
+ {
1423
+ "epoch": 2.015596880623875,
1424
+ "grad_norm": 1.5016852289986733,
1425
+ "learning_rate": 1.2759365862457148e-06,
1426
+ "logits/chosen": -0.4956502318382263,
1427
+ "logits/rejected": -0.1621031016111374,
1428
+ "logps/chosen": -0.7308815717697144,
1429
+ "logps/rejected": -0.9828909039497375,
1430
+ "loss": 0.7173,
1431
+ "odds_ratio_loss": 0.5487710237503052,
1432
+ "rewards/accuracies": 0.675000011920929,
1433
+ "rewards/chosen": -0.0365440808236599,
1434
+ "rewards/margins": 0.012600463815033436,
1435
+ "rewards/rejected": -0.049144547432661057,
1436
+ "sft_loss": 0.7308815717697144,
1437
+ "step": 840
1438
+ },
1439
+ {
1440
+ "epoch": 2.039592081583683,
1441
+ "grad_norm": 1.622924065562837,
1442
+ "learning_rate": 1.2205765622847273e-06,
1443
+ "logits/chosen": -0.12397761642932892,
1444
+ "logits/rejected": 0.08023932576179504,
1445
+ "logps/chosen": -0.6277745962142944,
1446
+ "logps/rejected": -1.0955206155776978,
1447
+ "loss": 0.6995,
1448
+ "odds_ratio_loss": 0.4475070536136627,
1449
+ "rewards/accuracies": 0.824999988079071,
1450
+ "rewards/chosen": -0.03138873726129532,
1451
+ "rewards/margins": 0.023387301713228226,
1452
+ "rewards/rejected": -0.054776035249233246,
1453
+ "sft_loss": 0.6277745962142944,
1454
+ "step": 850
1455
+ },
1456
+ {
1457
+ "epoch": 2.0635872825434913,
1458
+ "grad_norm": 1.4741935497367946,
1459
+ "learning_rate": 1.1660538624928062e-06,
1460
+ "logits/chosen": -0.3639386296272278,
1461
+ "logits/rejected": -0.2011258602142334,
1462
+ "logps/chosen": -0.6642920970916748,
1463
+ "logps/rejected": -1.0270217657089233,
1464
+ "loss": 0.7019,
1465
+ "odds_ratio_loss": 0.4971997141838074,
1466
+ "rewards/accuracies": 0.7250000238418579,
1467
+ "rewards/chosen": -0.03321460261940956,
1468
+ "rewards/margins": 0.018136484548449516,
1469
+ "rewards/rejected": -0.05135108903050423,
1470
+ "sft_loss": 0.6642920970916748,
1471
+ "step": 860
1472
+ },
1473
+ {
1474
+ "epoch": 2.0875824835032994,
1475
+ "grad_norm": 1.7172174730539993,
1476
+ "learning_rate": 1.112404169485226e-06,
1477
+ "logits/chosen": -0.3923923075199127,
1478
+ "logits/rejected": -0.10327514261007309,
1479
+ "logps/chosen": -0.5645719766616821,
1480
+ "logps/rejected": -1.071115255355835,
1481
+ "loss": 0.6681,
1482
+ "odds_ratio_loss": 0.42052555084228516,
1483
+ "rewards/accuracies": 0.800000011920929,
1484
+ "rewards/chosen": -0.028228599578142166,
1485
+ "rewards/margins": 0.025327179580926895,
1486
+ "rewards/rejected": -0.053555767983198166,
1487
+ "sft_loss": 0.5645719766616821,
1488
+ "step": 870
1489
+ },
1490
+ {
1491
+ "epoch": 2.1115776844631076,
1492
+ "grad_norm": 1.1474314844125568,
1493
+ "learning_rate": 1.0596625945342148e-06,
1494
+ "logits/chosen": -0.008033117279410362,
1495
+ "logits/rejected": 0.16419892013072968,
1496
+ "logps/chosen": -0.7100299000740051,
1497
+ "logps/rejected": -0.9733055233955383,
1498
+ "loss": 0.6813,
1499
+ "odds_ratio_loss": 0.5328400731086731,
1500
+ "rewards/accuracies": 0.737500011920929,
1501
+ "rewards/chosen": -0.03550150245428085,
1502
+ "rewards/margins": 0.013163777068257332,
1503
+ "rewards/rejected": -0.048665277659893036,
1504
+ "sft_loss": 0.7100299000740051,
1505
+ "step": 880
1506
+ },
1507
+ {
1508
+ "epoch": 2.1355728854229152,
1509
+ "grad_norm": 2.1383619388719515,
1510
+ "learning_rate": 1.0078636545902363e-06,
1511
+ "logits/chosen": -0.4247666001319885,
1512
+ "logits/rejected": -0.17631380259990692,
1513
+ "logps/chosen": -0.6582883596420288,
1514
+ "logps/rejected": -1.0547147989273071,
1515
+ "loss": 0.6895,
1516
+ "odds_ratio_loss": 0.47398701310157776,
1517
+ "rewards/accuracies": 0.7749999761581421,
1518
+ "rewards/chosen": -0.0329144187271595,
1519
+ "rewards/margins": 0.019821325317025185,
1520
+ "rewards/rejected": -0.05273573845624924,
1521
+ "sft_loss": 0.6582883596420288,
1522
+ "step": 890
1523
+ },
1524
+ {
1525
+ "epoch": 2.1595680863827234,
1526
+ "grad_norm": 1.5320300236939732,
1527
+ "learning_rate": 9.570412496922198e-07,
1528
+ "logits/chosen": -0.27953624725341797,
1529
+ "logits/rejected": -0.08715387433767319,
1530
+ "logps/chosen": -0.5965186357498169,
1531
+ "logps/rejected": -1.154284119606018,
1532
+ "loss": 0.6738,
1533
+ "odds_ratio_loss": 0.4240815043449402,
1534
+ "rewards/accuracies": 0.8500000238418579,
1535
+ "rewards/chosen": -0.029825935140252113,
1536
+ "rewards/margins": 0.02788827195763588,
1537
+ "rewards/rejected": -0.05771421268582344,
1538
+ "sft_loss": 0.5965186357498169,
1539
+ "step": 900
1540
+ },
1541
+ {
1542
+ "epoch": 2.1835632873425315,
1543
+ "grad_norm": 1.6204787225170885,
1544
+ "learning_rate": 9.07228640781539e-07,
1545
+ "logits/chosen": 0.368365079164505,
1546
+ "logits/rejected": 0.6101259589195251,
1547
+ "logps/chosen": -0.6893322467803955,
1548
+ "logps/rejected": -1.0903311967849731,
1549
+ "loss": 0.6791,
1550
+ "odds_ratio_loss": 0.4818887710571289,
1551
+ "rewards/accuracies": 0.7875000238418579,
1552
+ "rewards/chosen": -0.03446660935878754,
1553
+ "rewards/margins": 0.02004995197057724,
1554
+ "rewards/rejected": -0.054516565054655075,
1555
+ "sft_loss": 0.6893322467803955,
1556
+ "step": 910
1557
+ },
1558
+ {
1559
+ "epoch": 2.2075584883023396,
1560
+ "grad_norm": 1.290844558254926,
1561
+ "learning_rate": 8.584584279342392e-07,
1562
+ "logits/chosen": -0.16083380579948425,
1563
+ "logits/rejected": -0.10739579051733017,
1564
+ "logps/chosen": -0.6938862800598145,
1565
+ "logps/rejected": -0.9513536691665649,
1566
+ "loss": 0.6888,
1567
+ "odds_ratio_loss": 0.5428452491760254,
1568
+ "rewards/accuracies": 0.7124999761581421,
1569
+ "rewards/chosen": -0.034694310277700424,
1570
+ "rewards/margins": 0.012873371131718159,
1571
+ "rewards/rejected": -0.047567687928676605,
1572
+ "sft_loss": 0.6938862800598145,
1573
+ "step": 920
1574
+ },
1575
+ {
1576
+ "epoch": 2.2315536892621477,
1577
+ "grad_norm": 1.5229766148545818,
1578
+ "learning_rate": 8.10762529025782e-07,
1579
+ "logits/chosen": -0.4659739136695862,
1580
+ "logits/rejected": -0.4786594808101654,
1581
+ "logps/chosen": -0.6584521532058716,
1582
+ "logps/rejected": -0.8917843699455261,
1583
+ "loss": 0.65,
1584
+ "odds_ratio_loss": 0.5486137866973877,
1585
+ "rewards/accuracies": 0.6875,
1586
+ "rewards/chosen": -0.03292260691523552,
1587
+ "rewards/margins": 0.011666612699627876,
1588
+ "rewards/rejected": -0.044589221477508545,
1589
+ "sft_loss": 0.6584521532058716,
1590
+ "step": 930
1591
+ },
1592
+ {
1593
+ "epoch": 2.255548890221956,
1594
+ "grad_norm": 1.7015940933867517,
1595
+ "learning_rate": 7.641721588422526e-07,
1596
+ "logits/chosen": -0.009342163801193237,
1597
+ "logits/rejected": 0.1280032843351364,
1598
+ "logps/chosen": -0.6387184262275696,
1599
+ "logps/rejected": -1.049140453338623,
1600
+ "loss": 0.687,
1601
+ "odds_ratio_loss": 0.4773840010166168,
1602
+ "rewards/accuracies": 0.75,
1603
+ "rewards/chosen": -0.0319359228014946,
1604
+ "rewards/margins": 0.020521100610494614,
1605
+ "rewards/rejected": -0.05245702341198921,
1606
+ "sft_loss": 0.6387184262275696,
1607
+ "step": 940
1608
+ },
1609
+ {
1610
+ "epoch": 2.2795440911817635,
1611
+ "grad_norm": 1.4203319350991257,
1612
+ "learning_rate": 7.187178086517116e-07,
1613
+ "logits/chosen": 0.14468683302402496,
1614
+ "logits/rejected": 0.2608656883239746,
1615
+ "logps/chosen": -0.6514204144477844,
1616
+ "logps/rejected": -1.2591578960418701,
1617
+ "loss": 0.6695,
1618
+ "odds_ratio_loss": 0.455849826335907,
1619
+ "rewards/accuracies": 0.7749999761581421,
1620
+ "rewards/chosen": -0.03257102146744728,
1621
+ "rewards/margins": 0.03038688376545906,
1622
+ "rewards/rejected": -0.06295789778232574,
1623
+ "sft_loss": 0.6514204144477844,
1624
+ "step": 950
1625
+ },
1626
+ {
1627
+ "epoch": 2.3035392921415716,
1628
+ "grad_norm": 1.7783791010197938,
1629
+ "learning_rate": 6.74429226249049e-07,
1630
+ "logits/chosen": 0.09898465871810913,
1631
+ "logits/rejected": 0.21373791992664337,
1632
+ "logps/chosen": -0.6381307244300842,
1633
+ "logps/rejected": -0.9742431640625,
1634
+ "loss": 0.6712,
1635
+ "odds_ratio_loss": 0.49530988931655884,
1636
+ "rewards/accuracies": 0.800000011920929,
1637
+ "rewards/chosen": -0.03190653771162033,
1638
+ "rewards/margins": 0.016805628314614296,
1639
+ "rewards/rejected": -0.04871216416358948,
1640
+ "sft_loss": 0.6381307244300842,
1641
+ "step": 960
1642
+ },
1643
+ {
1644
+ "epoch": 2.3275344931013797,
1645
+ "grad_norm": 1.6090454208525553,
1646
+ "learning_rate": 6.313353964874155e-07,
1647
+ "logits/chosen": 0.1333683431148529,
1648
+ "logits/rejected": 0.3417516350746155,
1649
+ "logps/chosen": -0.6887052655220032,
1650
+ "logps/rejected": -1.0016798973083496,
1651
+ "loss": 0.6673,
1652
+ "odds_ratio_loss": 0.5059822797775269,
1653
+ "rewards/accuracies": 0.7749999761581421,
1654
+ "rewards/chosen": -0.03443526476621628,
1655
+ "rewards/margins": 0.01564873196184635,
1656
+ "rewards/rejected": -0.05008399486541748,
1657
+ "sft_loss": 0.6887052655220032,
1658
+ "step": 970
1659
+ },
1660
+ {
1661
+ "epoch": 2.351529694061188,
1662
+ "grad_norm": 1.6382111002720514,
1663
+ "learning_rate": 5.894645223089584e-07,
1664
+ "logits/chosen": 0.7236309051513672,
1665
+ "logits/rejected": 0.8550646901130676,
1666
+ "logps/chosen": -0.6779772639274597,
1667
+ "logps/rejected": -1.2183148860931396,
1668
+ "loss": 0.6958,
1669
+ "odds_ratio_loss": 0.448292076587677,
1670
+ "rewards/accuracies": 0.7875000238418579,
1671
+ "rewards/chosen": -0.033898863941431046,
1672
+ "rewards/margins": 0.027016881853342056,
1673
+ "rewards/rejected": -0.0609157457947731,
1674
+ "sft_loss": 0.6779772639274597,
1675
+ "step": 980
1676
+ },
1677
+ {
1678
+ "epoch": 2.375524895020996,
1679
+ "grad_norm": 1.680992010239421,
1680
+ "learning_rate": 5.48844006287289e-07,
1681
+ "logits/chosen": 0.12925365567207336,
1682
+ "logits/rejected": 0.3167954981327057,
1683
+ "logps/chosen": -0.6692675352096558,
1684
+ "logps/rejected": -1.0140740871429443,
1685
+ "loss": 0.6691,
1686
+ "odds_ratio_loss": 0.4763975143432617,
1687
+ "rewards/accuracies": 0.75,
1688
+ "rewards/chosen": -0.033463381230831146,
1689
+ "rewards/margins": 0.01724032498896122,
1690
+ "rewards/rejected": -0.050703711807727814,
1691
+ "sft_loss": 0.6692675352096558,
1692
+ "step": 990
1693
+ },
1694
+ {
1695
+ "epoch": 2.3995200959808036,
1696
+ "grad_norm": 1.544720546176764,
1697
+ "learning_rate": 5.095004326937445e-07,
1698
+ "logits/chosen": -0.4231066107749939,
1699
+ "logits/rejected": -0.20230142772197723,
1700
+ "logps/chosen": -0.6737790107727051,
1701
+ "logps/rejected": -1.0810075998306274,
1702
+ "loss": 0.6744,
1703
+ "odds_ratio_loss": 0.4769432544708252,
1704
+ "rewards/accuracies": 0.75,
1705
+ "rewards/chosen": -0.033688947558403015,
1706
+ "rewards/margins": 0.02036142908036709,
1707
+ "rewards/rejected": -0.05405038595199585,
1708
+ "sft_loss": 0.6737790107727051,
1709
+ "step": 1000
1710
+ },
1711
+ {
1712
+ "epoch": 2.4235152969406117,
1713
+ "grad_norm": 1.7400382431256138,
1714
+ "learning_rate": 4.71459550099202e-07,
1715
+ "logits/chosen": 0.2943962812423706,
1716
+ "logits/rejected": 0.5343393087387085,
1717
+ "logps/chosen": -0.6686779856681824,
1718
+ "logps/rejected": -1.0820672512054443,
1719
+ "loss": 0.7078,
1720
+ "odds_ratio_loss": 0.5010559558868408,
1721
+ "rewards/accuracies": 0.7124999761581421,
1722
+ "rewards/chosen": -0.03343390300869942,
1723
+ "rewards/margins": 0.020669464021921158,
1724
+ "rewards/rejected": -0.054103363305330276,
1725
+ "sft_loss": 0.6686779856681824,
1726
+ "step": 1010
1727
+ },
1728
+ {
1729
+ "epoch": 2.44751049790042,
1730
+ "grad_norm": 1.548219424075948,
1731
+ "learning_rate": 4.347462545228134e-07,
1732
+ "logits/chosen": 0.13567771017551422,
1733
+ "logits/rejected": 0.31968480348587036,
1734
+ "logps/chosen": -0.6244124174118042,
1735
+ "logps/rejected": -1.05476975440979,
1736
+ "loss": 0.6563,
1737
+ "odds_ratio_loss": 0.4984089732170105,
1738
+ "rewards/accuracies": 0.7250000238418579,
1739
+ "rewards/chosen": -0.03122062422335148,
1740
+ "rewards/margins": 0.021517863497138023,
1741
+ "rewards/rejected": -0.052738480269908905,
1742
+ "sft_loss": 0.6244124174118042,
1743
+ "step": 1020
1744
+ },
1745
+ {
1746
+ "epoch": 2.471505698860228,
1747
+ "grad_norm": 1.4610216249122747,
1748
+ "learning_rate": 3.9938457313869914e-07,
1749
+ "logits/chosen": -0.08544759452342987,
1750
+ "logits/rejected": 0.07162941992282867,
1751
+ "logps/chosen": -0.7579829096794128,
1752
+ "logps/rejected": -1.1255767345428467,
1753
+ "loss": 0.6864,
1754
+ "odds_ratio_loss": 0.547897458076477,
1755
+ "rewards/accuracies": 0.762499988079071,
1756
+ "rewards/chosen": -0.03789914771914482,
1757
+ "rewards/margins": 0.01837969198822975,
1758
+ "rewards/rejected": -0.05627884343266487,
1759
+ "sft_loss": 0.7579829096794128,
1760
+ "step": 1030
1761
+ },
1762
+ {
1763
+ "epoch": 2.495500899820036,
1764
+ "grad_norm": 1.6006797776983446,
1765
+ "learning_rate": 3.6539764855126224e-07,
1766
+ "logits/chosen": -0.23340921103954315,
1767
+ "logits/rejected": -0.1814245879650116,
1768
+ "logps/chosen": -0.6439553499221802,
1769
+ "logps/rejected": -1.0276587009429932,
1770
+ "loss": 0.6617,
1771
+ "odds_ratio_loss": 0.5049816370010376,
1772
+ "rewards/accuracies": 0.800000011920929,
1773
+ "rewards/chosen": -0.03219776228070259,
1774
+ "rewards/margins": 0.019185172393918037,
1775
+ "rewards/rejected": -0.05138293653726578,
1776
+ "sft_loss": 0.6439553499221802,
1777
+ "step": 1040
1778
+ },
1779
+ {
1780
+ "epoch": 2.519496100779844,
1781
+ "grad_norm": 2.318524117790848,
1782
+ "learning_rate": 3.328077236494087e-07,
1783
+ "logits/chosen": -0.12850667536258698,
1784
+ "logits/rejected": 0.07032374292612076,
1785
+ "logps/chosen": -0.5922039747238159,
1786
+ "logps/rejected": -1.0730435848236084,
1787
+ "loss": 0.6694,
1788
+ "odds_ratio_loss": 0.43941235542297363,
1789
+ "rewards/accuracies": 0.8500000238418579,
1790
+ "rewards/chosen": -0.029610196128487587,
1791
+ "rewards/margins": 0.024041980504989624,
1792
+ "rewards/rejected": -0.05365217477083206,
1793
+ "sft_loss": 0.5922039747238159,
1794
+ "step": 1050
1795
+ },
1796
+ {
1797
+ "epoch": 2.5434913017396523,
1798
+ "grad_norm": 1.8087989245838814,
1799
+ "learning_rate": 3.0163612704959486e-07,
1800
+ "logits/chosen": -0.6611061692237854,
1801
+ "logits/rejected": -0.5293869376182556,
1802
+ "logps/chosen": -0.6281863451004028,
1803
+ "logps/rejected": -0.9944284558296204,
1804
+ "loss": 0.6705,
1805
+ "odds_ratio_loss": 0.47698038816452026,
1806
+ "rewards/accuracies": 0.75,
1807
+ "rewards/chosen": -0.03140931576490402,
1808
+ "rewards/margins": 0.018312102183699608,
1809
+ "rewards/rejected": -0.04972142353653908,
1810
+ "sft_loss": 0.6281863451004028,
1811
+ "step": 1060
1812
+ },
1813
+ {
1814
+ "epoch": 2.56748650269946,
1815
+ "grad_norm": 1.5444353690364836,
1816
+ "learning_rate": 2.71903259137222e-07,
1817
+ "logits/chosen": 0.411745548248291,
1818
+ "logits/rejected": 0.4236873686313629,
1819
+ "logps/chosen": -0.611006498336792,
1820
+ "logps/rejected": -1.0047032833099365,
1821
+ "loss": 0.672,
1822
+ "odds_ratio_loss": 0.48614612221717834,
1823
+ "rewards/accuracies": 0.7749999761581421,
1824
+ "rewards/chosen": -0.03055032715201378,
1825
+ "rewards/margins": 0.019684839993715286,
1826
+ "rewards/rejected": -0.050235163420438766,
1827
+ "sft_loss": 0.611006498336792,
1828
+ "step": 1070
1829
+ },
1830
+ {
1831
+ "epoch": 2.591481703659268,
1832
+ "grad_norm": 2.593043127599419,
1833
+ "learning_rate": 2.436285787155185e-07,
1834
+ "logits/chosen": 0.316955029964447,
1835
+ "logits/rejected": 0.47285112738609314,
1836
+ "logps/chosen": -0.6786519885063171,
1837
+ "logps/rejected": -1.2019875049591064,
1838
+ "loss": 0.6881,
1839
+ "odds_ratio_loss": 0.4908427298069,
1840
+ "rewards/accuracies": 0.7250000238418579,
1841
+ "rewards/chosen": -0.03393259644508362,
1842
+ "rewards/margins": 0.026166772469878197,
1843
+ "rewards/rejected": -0.060099370777606964,
1844
+ "sft_loss": 0.6786519885063171,
1845
+ "step": 1080
1846
+ },
1847
+ {
1848
+ "epoch": 2.6154769046190762,
1849
+ "grad_norm": 2.2050381193088207,
1850
+ "learning_rate": 2.168305902706383e-07,
1851
+ "logits/chosen": -0.4541945457458496,
1852
+ "logits/rejected": -0.18702273070812225,
1853
+ "logps/chosen": -0.7026795148849487,
1854
+ "logps/rejected": -0.962356448173523,
1855
+ "loss": 0.6583,
1856
+ "odds_ratio_loss": 0.5365189909934998,
1857
+ "rewards/accuracies": 0.75,
1858
+ "rewards/chosen": -0.035133976489305496,
1859
+ "rewards/margins": 0.012983846478164196,
1860
+ "rewards/rejected": -0.04811782017350197,
1861
+ "sft_loss": 0.7026795148849487,
1862
+ "step": 1090
1863
+ },
1864
+ {
1865
+ "epoch": 2.6394721055788843,
1866
+ "grad_norm": 1.6921175899136245,
1867
+ "learning_rate": 1.9152683186132476e-07,
1868
+ "logits/chosen": -0.4067768156528473,
1869
+ "logits/rejected": -0.3039708137512207,
1870
+ "logps/chosen": -0.6328436136245728,
1871
+ "logps/rejected": -1.12655770778656,
1872
+ "loss": 0.6919,
1873
+ "odds_ratio_loss": 0.4709090292453766,
1874
+ "rewards/accuracies": 0.762499988079071,
1875
+ "rewards/chosen": -0.031642183661460876,
1876
+ "rewards/margins": 0.024685706943273544,
1877
+ "rewards/rejected": -0.05632789060473442,
1878
+ "sft_loss": 0.6328436136245728,
1879
+ "step": 1100
1880
+ },
1881
+ {
1882
+ "epoch": 2.663467306538692,
1883
+ "grad_norm": 1.5594348597838832,
1884
+ "learning_rate": 1.6773386364104972e-07,
1885
+ "logits/chosen": -0.1575368195772171,
1886
+ "logits/rejected": -0.003553843591362238,
1887
+ "logps/chosen": -0.6768941879272461,
1888
+ "logps/rejected": -1.032041072845459,
1889
+ "loss": 0.6913,
1890
+ "odds_ratio_loss": 0.50171959400177,
1891
+ "rewards/accuracies": 0.75,
1892
+ "rewards/chosen": -0.033844709396362305,
1893
+ "rewards/margins": 0.017757344990968704,
1894
+ "rewards/rejected": -0.05160205811262131,
1895
+ "sft_loss": 0.6768941879272461,
1896
+ "step": 1110
1897
+ },
1898
+ {
1899
+ "epoch": 2.6874625074985,
1900
+ "grad_norm": 1.2735811398241894,
1901
+ "learning_rate": 1.4546725702015096e-07,
1902
+ "logits/chosen": 0.004650235176086426,
1903
+ "logits/rejected": 0.1661575585603714,
1904
+ "logps/chosen": -0.6541981101036072,
1905
+ "logps/rejected": -1.1094247102737427,
1906
+ "loss": 0.6669,
1907
+ "odds_ratio_loss": 0.4492813050746918,
1908
+ "rewards/accuracies": 0.7875000238418579,
1909
+ "rewards/chosen": -0.03270990774035454,
1910
+ "rewards/margins": 0.022761326283216476,
1911
+ "rewards/rejected": -0.055471230298280716,
1912
+ "sft_loss": 0.6541981101036072,
1913
+ "step": 1120
1914
+ },
1915
+ {
1916
+ "epoch": 2.7114577084583082,
1917
+ "grad_norm": 2.2135398834819715,
1918
+ "learning_rate": 1.24741584475056e-07,
1919
+ "logits/chosen": -0.07907108962535858,
1920
+ "logits/rejected": 0.08474680036306381,
1921
+ "logps/chosen": -0.6154497861862183,
1922
+ "logps/rejected": -1.0710924863815308,
1923
+ "loss": 0.6491,
1924
+ "odds_ratio_loss": 0.4509805142879486,
1925
+ "rewards/accuracies": 0.8125,
1926
+ "rewards/chosen": -0.030772492289543152,
1927
+ "rewards/margins": 0.022782133892178535,
1928
+ "rewards/rejected": -0.05355461686849594,
1929
+ "sft_loss": 0.6154497861862183,
1930
+ "step": 1130
1931
+ },
1932
+ {
1933
+ "epoch": 2.7354529094181164,
1934
+ "grad_norm": 1.5137426741255027,
1935
+ "learning_rate": 1.0557041001126145e-07,
1936
+ "logits/chosen": 0.3702402710914612,
1937
+ "logits/rejected": 0.6300150156021118,
1938
+ "logps/chosen": -0.5984182357788086,
1939
+ "logps/rejected": -1.115179419517517,
1940
+ "loss": 0.6191,
1941
+ "odds_ratio_loss": 0.41762223839759827,
1942
+ "rewards/accuracies": 0.8500000238418579,
1943
+ "rewards/chosen": -0.0299209114164114,
1944
+ "rewards/margins": 0.025838062167167664,
1945
+ "rewards/rejected": -0.05575897544622421,
1946
+ "sft_loss": 0.5984182357788086,
1947
+ "step": 1140
1948
+ },
1949
+ {
1950
+ "epoch": 2.7594481103779245,
1951
+ "grad_norm": 1.565522436867544,
1952
+ "learning_rate": 8.796628028631321e-08,
1953
+ "logits/chosen": 0.17880654335021973,
1954
+ "logits/rejected": 0.1116660013794899,
1955
+ "logps/chosen": -0.6091745495796204,
1956
+ "logps/rejected": -1.0210378170013428,
1957
+ "loss": 0.6583,
1958
+ "odds_ratio_loss": 0.4544963836669922,
1959
+ "rewards/accuracies": 0.800000011920929,
1960
+ "rewards/chosen": -0.030458729714155197,
1961
+ "rewards/margins": 0.02059316076338291,
1962
+ "rewards/rejected": -0.05105189234018326,
1963
+ "sft_loss": 0.6091745495796204,
1964
+ "step": 1150
1965
+ },
1966
+ {
1967
+ "epoch": 2.7834433113377326,
1968
+ "grad_norm": 1.604017358081912,
1969
+ "learning_rate": 7.19407163985894e-08,
1970
+ "logits/chosen": -0.04378344863653183,
1971
+ "logits/rejected": 0.18321049213409424,
1972
+ "logps/chosen": -0.6626521348953247,
1973
+ "logps/rejected": -1.1215763092041016,
1974
+ "loss": 0.666,
1975
+ "odds_ratio_loss": 0.4741577208042145,
1976
+ "rewards/accuracies": 0.75,
1977
+ "rewards/chosen": -0.033132605254650116,
1978
+ "rewards/margins": 0.022946210578083992,
1979
+ "rewards/rejected": -0.05607881397008896,
1980
+ "sft_loss": 0.6626521348953247,
1981
+ "step": 1160
1982
+ },
1983
+ {
1984
+ "epoch": 2.8074385122975407,
1985
+ "grad_norm": 1.4084206676302562,
1986
+ "learning_rate": 5.750420634727083e-08,
1987
+ "logits/chosen": -0.45710262656211853,
1988
+ "logits/rejected": -0.3050076961517334,
1989
+ "logps/chosen": -0.671418309211731,
1990
+ "logps/rejected": -1.1854102611541748,
1991
+ "loss": 0.6842,
1992
+ "odds_ratio_loss": 0.4368383288383484,
1993
+ "rewards/accuracies": 0.800000011920929,
1994
+ "rewards/chosen": -0.03357091173529625,
1995
+ "rewards/margins": 0.02569960430264473,
1996
+ "rewards/rejected": -0.05927051231265068,
1997
+ "sft_loss": 0.671418309211731,
1998
+ "step": 1170
1999
+ },
2000
+ {
2001
+ "epoch": 2.8314337132573484,
2002
+ "grad_norm": 1.3507137389822068,
2003
+ "learning_rate": 4.4666198168422656e-08,
2004
+ "logits/chosen": 0.33376216888427734,
2005
+ "logits/rejected": 0.41172194480895996,
2006
+ "logps/chosen": -0.6510582566261292,
2007
+ "logps/rejected": -1.0800405740737915,
2008
+ "loss": 0.6747,
2009
+ "odds_ratio_loss": 0.5277644395828247,
2010
+ "rewards/accuracies": 0.675000011920929,
2011
+ "rewards/chosen": -0.032552916556596756,
2012
+ "rewards/margins": 0.021449116989970207,
2013
+ "rewards/rejected": -0.054002027958631516,
2014
+ "sft_loss": 0.6510582566261292,
2015
+ "step": 1180
2016
+ },
2017
+ {
2018
+ "epoch": 2.8554289142171565,
2019
+ "grad_norm": 1.6874037821147798,
2020
+ "learning_rate": 3.343509375168863e-08,
2021
+ "logits/chosen": 0.20301933586597443,
2022
+ "logits/rejected": 0.32382094860076904,
2023
+ "logps/chosen": -0.6405006647109985,
2024
+ "logps/rejected": -1.0241023302078247,
2025
+ "loss": 0.6718,
2026
+ "odds_ratio_loss": 0.48166948556900024,
2027
+ "rewards/accuracies": 0.75,
2028
+ "rewards/chosen": -0.03202503174543381,
2029
+ "rewards/margins": 0.019180091097950935,
2030
+ "rewards/rejected": -0.051205117255449295,
2031
+ "sft_loss": 0.6405006647109985,
2032
+ "step": 1190
2033
+ },
2034
+ {
2035
+ "epoch": 2.8794241151769646,
2036
+ "grad_norm": 1.6417139708130921,
2037
+ "learning_rate": 2.3818243341637293e-08,
2038
+ "logits/chosen": -0.3619822859764099,
2039
+ "logits/rejected": -0.15361133217811584,
2040
+ "logps/chosen": -0.6599988341331482,
2041
+ "logps/rejected": -1.098881483078003,
2042
+ "loss": 0.6565,
2043
+ "odds_ratio_loss": 0.456063449382782,
2044
+ "rewards/accuracies": 0.8125,
2045
+ "rewards/chosen": -0.03299994021654129,
2046
+ "rewards/margins": 0.021944135427474976,
2047
+ "rewards/rejected": -0.054944075644016266,
2048
+ "sft_loss": 0.6599988341331482,
2049
+ "step": 1200
2050
+ },
2051
+ {
2052
+ "epoch": 2.9034193161367727,
2053
+ "grad_norm": 1.648932215503252,
2054
+ "learning_rate": 1.5821940727361874e-08,
2055
+ "logits/chosen": -0.7362561821937561,
2056
+ "logits/rejected": -0.4996170997619629,
2057
+ "logps/chosen": -0.6824958920478821,
2058
+ "logps/rejected": -0.9969790577888489,
2059
+ "loss": 0.7067,
2060
+ "odds_ratio_loss": 0.5307115316390991,
2061
+ "rewards/accuracies": 0.7124999761581421,
2062
+ "rewards/chosen": -0.034124795347452164,
2063
+ "rewards/margins": 0.01572415977716446,
2064
+ "rewards/rejected": -0.049848951399326324,
2065
+ "sft_loss": 0.6824958920478821,
2066
+ "step": 1210
2067
+ },
2068
+ {
2069
+ "epoch": 2.927414517096581,
2070
+ "grad_norm": 1.7678674281978446,
2071
+ "learning_rate": 9.451419123484573e-09,
2072
+ "logits/chosen": -0.15318191051483154,
2073
+ "logits/rejected": 0.047946538776159286,
2074
+ "logps/chosen": -0.6560810804367065,
2075
+ "logps/rejected": -1.0658347606658936,
2076
+ "loss": 0.6692,
2077
+ "odds_ratio_loss": 0.5046226382255554,
2078
+ "rewards/accuracies": 0.75,
2079
+ "rewards/chosen": -0.032804060727357864,
2080
+ "rewards/margins": 0.02048768661916256,
2081
+ "rewards/rejected": -0.053291745483875275,
2082
+ "sft_loss": 0.6560810804367065,
2083
+ "step": 1220
2084
+ },
2085
+ {
2086
+ "epoch": 2.9514097180563885,
2087
+ "grad_norm": 1.4413325593301094,
2088
+ "learning_rate": 4.710847745256209e-09,
2089
+ "logits/chosen": 0.12647075951099396,
2090
+ "logits/rejected": 0.2795228958129883,
2091
+ "logps/chosen": -0.6180914640426636,
2092
+ "logps/rejected": -1.0847346782684326,
2093
+ "loss": 0.6722,
2094
+ "odds_ratio_loss": 0.41623228788375854,
2095
+ "rewards/accuracies": 0.8374999761581421,
2096
+ "rewards/chosen": -0.030904576182365417,
2097
+ "rewards/margins": 0.02333216182887554,
2098
+ "rewards/rejected": -0.05423673242330551,
2099
+ "sft_loss": 0.6180914640426636,
2100
+ "step": 1230
2101
+ },
2102
+ {
2103
+ "epoch": 2.9754049190161966,
2104
+ "grad_norm": 1.5296676400661524,
2105
+ "learning_rate": 1.603329079994942e-09,
2106
+ "logits/chosen": -0.3425149619579315,
2107
+ "logits/rejected": -0.06856220215559006,
2108
+ "logps/chosen": -0.6569226980209351,
2109
+ "logps/rejected": -1.1020539999008179,
2110
+ "loss": 0.6649,
2111
+ "odds_ratio_loss": 0.4642546772956848,
2112
+ "rewards/accuracies": 0.762499988079071,
2113
+ "rewards/chosen": -0.03284613788127899,
2114
+ "rewards/margins": 0.02225656434893608,
2115
+ "rewards/rejected": -0.055102698504924774,
2116
+ "sft_loss": 0.6569226980209351,
2117
+ "step": 1240
2118
+ },
2119
+ {
2120
+ "epoch": 2.994601079784043,
2121
+ "step": 1248,
2122
+ "total_flos": 132590267662336.0,
2123
+ "train_loss": 0.7937506708579186,
2124
+ "train_runtime": 49781.9259,
2125
+ "train_samples_per_second": 1.205,
2126
+ "train_steps_per_second": 0.025
2127
+ }
2128
+ ],
2129
+ "logging_steps": 10,
2130
+ "max_steps": 1248,
2131
+ "num_input_tokens_seen": 0,
2132
+ "num_train_epochs": 3,
2133
+ "save_steps": 100.0,
2134
+ "total_flos": 132590267662336.0,
2135
+ "train_batch_size": 1,
2136
+ "trial_name": null,
2137
+ "trial_params": null
2138
+ }
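The entries above are the `log_history` list of `trainer_state.json`, so every logged ORPO metric (`loss`, `sft_loss`, `odds_ratio_loss`, `rewards/*`, `learning_rate`, `grad_norm`) can be read back from that file. Below is a minimal sketch for plotting the loss curves; it assumes the file has been downloaded locally as `trainer_state.json` and relies only on the key names visible in the entries above.

```python
import json

import matplotlib.pyplot as plt

# Path is an assumption: wherever trainer_state.json was downloaded to.
with open("trainer_state.json") as f:
    state = json.load(f)

# Keep the periodic logging entries; the final summary entry has no "loss" key.
logs = [entry for entry in state["log_history"] if "loss" in entry]
steps = [entry["step"] for entry in logs]

for key in ("loss", "sft_loss", "odds_ratio_loss"):
    plt.plot(steps, [entry[key] for entry in logs], label=key)

plt.xlabel("step")
plt.ylabel("loss")
plt.legend()
plt.title("ORPO training curves from trainer_state.json")
plt.show()
```

Two relationships are directly visible in the entries above: `sft_loss` equals `-logps/chosen`, and the reward fields are the log-probabilities scaled by 0.05 (`rewards/chosen = 0.05 * logps/chosen`, with `rewards/margins` their difference up to float rounding), so the reward curves are rescaled views of the same quantities.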
training_args.bin ADDED
@@ -0,0 +1,3 @@
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4a5252fd1d5c3ae6e8eedab656003266bd9a9302edb91e20004f9582cf004a79
3
+ size 7032
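
The three lines recorded for `training_args.bin` are a Git LFS pointer (spec version, SHA-256 object id, and size in bytes), not the binary itself. The actual file is a pickled `transformers` `TrainingArguments` object saved by the `Trainer`; if needed, it can be fetched from the Hub and inspected as sketched below. The repo id is a placeholder, since the repository name is not part of this diff, and `transformers` must be installed for unpickling to succeed.

```python
import torch
from huggingface_hub import hf_hub_download

# Placeholder repo id; substitute the actual namespace/name of this repository.
path = hf_hub_download(repo_id="<namespace>/phi3-chinese-orpo", filename="training_args.bin")

# training_args.bin is a pickled TrainingArguments object, so full unpickling is required.
args = torch.load(path, weights_only=False)
print(args.learning_rate, args.num_train_epochs, args.lr_scheduler_type)
```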