---
language: en
tags:
- jax
- flax
- text-generation
- transformers
- meta-llama/Llama-3.2-3B
---
# meta-llama/Llama-3.2-3B - JAX/Flax
This repository contains the JAX/Flax version of the meta-llama/Llama-3.2-3B model, originally a PyTorch model from meta-llama. This conversion enables efficient inference and training on TPUs and GPUs using the JAX/Flax framework.
## Model Description
meta-llama/Llama-3.2-3B is a transformer-based language model developed by meta-llama.
## Conversion Details
This model was converted from the original PyTorch implementation to JAX/Flax. The conversion involved the following steps (a minimal code sketch follows the list):
1. **Loading the PyTorch model and configuration:** The pretrained PyTorch model and its configuration were loaded using the Hugging Face Transformers library.
2. **Creating an equivalent Flax model architecture:** A Flax model with the same architecture as the original PyTorch model was created.
3. **Converting the PyTorch weights to Flax format:** The weights from the PyTorch model were converted to the Flax format using the `convert_pytorch_state_dict_to_flax` utility function provided by Hugging Face.
4. **Verifying the converted weights:** The converted Flax weights were compared against the original PyTorch weights to ensure that the conversion process was performed accurately.
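For reference, a minimal sketch of steps 1–3 using the standard Transformers API is shown below. The exact conversion script is not included in this repository, and the output directory name is illustrative:
```python
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

# Steps 1-3: load the PyTorch checkpoint and convert its weights to Flax.
# `from_pt=True` converts the weights via `convert_pytorch_state_dict_to_flax` internally.
pt_checkpoint = "meta-llama/Llama-3.2-3B"
flax_model = FlaxAutoModelForCausalLM.from_pretrained(pt_checkpoint, from_pt=True)

# Save the converted weights and the tokenizer for upload to the Hub.
flax_model.save_pretrained("./llama-3.2-3b-flax")
AutoTokenizer.from_pretrained(pt_checkpoint).save_pretrained("./llama-3.2-3b-flax")
```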
### Important Note about `max_position_embeddings`
During the conversion process, it was necessary to modify the `max_position_embeddings` parameter in the model's configuration. The original value of 131072 led to out-of-memory (OOM) errors on the hardware used for conversion. To resolve this, `max_position_embeddings` was adjusted to 16384.
**Implications of this change:**
* The model may not be able to handle sequences longer than 16384 tokens without truncation or other modifications.
* If you fine-tune this model, keep the revised `max_position_embeddings` in mind when preparing your training data (a truncation sketch follows this list).
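In practice, the simplest way to respect this limit is to truncate inputs at tokenization time. The snippet below sketches that approach (the input string is a placeholder):
```python
from transformers import AutoConfig, AutoTokenizer

repo_id = "Erland/Llama-3.2-3B-JAX"
config = AutoConfig.from_pretrained(repo_id)
tokenizer = AutoTokenizer.from_pretrained(repo_id)

# Truncate inputs to the reduced context window so position indices stay in range.
inputs = tokenizer(
    "a very long document ...",  # placeholder text
    truncation=True,
    max_length=config.max_position_embeddings,  # 16384 in this repository
    return_tensors="np",
)
```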
## Weight Comparison
The following table summarizes the comparison between the weights of the original PyTorch model and the converted JAX/Flax model. This detailed verification confirms that the conversion was accurate and that both models should produce (approximately) the same outputs given the same inputs.
| Layer | PyTorch Shape | Flax Shape | Allclose | Max Diff | Mean Diff | Std Diff |
| :---- | :------------ | :--------- | :------- | :------- | :-------- | :------- |
| model.embed_tokens.weight | (128256, 3072) | (128256, 3072) | True | 0 | 0 | 0 |
| model.layers.0.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.0.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.0.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.0.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.0.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.0.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.0.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 |
| model.layers.0.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.0.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.1.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.1.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.1.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.1.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.1.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.1.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.1.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 |
| model.layers.1.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.1.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.2.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.2.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.2.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.2.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.2.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.2.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.2.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 |
| model.layers.2.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.2.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.3.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.3.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.3.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.3.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.3.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.3.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.3.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 |
| model.layers.3.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.3.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.4.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.4.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.4.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.4.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.4.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.4.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.4.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 |
| model.layers.4.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.4.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.5.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.5.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.5.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.5.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.5.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.5.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.5.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 |
| model.layers.5.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.5.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.6.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.6.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.6.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.6.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.6.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.6.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.6.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 |
| model.layers.6.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.6.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.7.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.7.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.7.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.7.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.7.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.7.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.7.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 |
| model.layers.7.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.7.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.8.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.8.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.8.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.8.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.8.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.8.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.8.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 |
| model.layers.8.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.8.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.9.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.9.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.9.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.9.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.9.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.9.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.9.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 |
| model.layers.9.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.9.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.10.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.10.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.10.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.10.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.10.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.10.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.10.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 |
| model.layers.10.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.10.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.11.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.11.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.11.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.11.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.11.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.11.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.11.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 |
| model.layers.11.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.11.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.12.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.12.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.12.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.12.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.12.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.12.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.12.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 |
| model.layers.12.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.12.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.13.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.13.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.13.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.13.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.13.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.13.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.13.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 |
| model.layers.13.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.13.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.14.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.14.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.14.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.14.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.14.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.14.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.14.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 |
| model.layers.14.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.14.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.15.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.15.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.15.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.15.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.15.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.15.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.15.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 |
| model.layers.15.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.15.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.16.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.16.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.16.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.16.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.16.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.16.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.16.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 |
| model.layers.16.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.16.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.17.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.17.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.17.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.17.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.17.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.17.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.17.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 |
| model.layers.17.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.17.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.18.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.18.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.18.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.18.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.18.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.18.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.18.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 |
| model.layers.18.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.18.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.19.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.19.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.19.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.19.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.19.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.19.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.19.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 |
| model.layers.19.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.19.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.20.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.20.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.20.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.20.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.20.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.20.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.20.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 |
| model.layers.20.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.20.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.21.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.21.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.21.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.21.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.21.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.21.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.21.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 |
| model.layers.21.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.21.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.22.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.22.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.22.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.22.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.22.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.22.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.22.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 |
| model.layers.22.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.22.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.23.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.23.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.23.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.23.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.23.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.23.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.23.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 |
| model.layers.23.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.23.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.24.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.24.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.24.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.24.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.24.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.24.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.24.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 |
| model.layers.24.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.24.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.25.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.25.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.25.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.25.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.25.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.25.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.25.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 |
| model.layers.25.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.25.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.26.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.26.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.26.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.26.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.26.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.26.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.26.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 |
| model.layers.26.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.26.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.27.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.27.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.27.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 |
| model.layers.27.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 |
| model.layers.27.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.27.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 |
| model.layers.27.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 |
| model.layers.27.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.layers.27.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| model.norm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 |
| lm_head.weight | (3072, 128256) | (3072, 128256) | True | 0 | 0 | 0 |
**Note:**
* `Allclose` indicates whether the weights are approximately equal within the specified relative (`rtol=1e-5`) and absolute (`atol=1e-3`) tolerances using `jnp.allclose()`.
* `Max Diff`, `Mean Diff`, and `Std Diff` provide further detail on any differences when `Allclose` is `False`; small discrepancies can be expected for some layers due to numerical-precision differences between frameworks. A sketch of such a comparison routine is shown below.
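For reference, a per-tensor comparison along these lines could look like the following sketch. It is not the exact verification script used to produce the table; in particular, the `kernel` → `weight` renaming and the transpose handling are assumptions about how the Flax parameters are laid out:
```python
import numpy as np
import jax.numpy as jnp
from flax.traverse_util import flatten_dict

def compare_weights(pt_state_dict, flax_params, rtol=1e-5, atol=1e-3):
    """Per-tensor comparison in the spirit of the table above (sketch only)."""
    # Flax stores linear weights under "kernel" (and transposed); map the names
    # back to the PyTorch convention for lookup.
    flax_flat = {
        ".".join(k).replace(".kernel", ".weight"): np.asarray(v)
        for k, v in flatten_dict(flax_params).items()
    }
    rows = []
    for name, pt_tensor in pt_state_dict.items():
        pt = pt_tensor.detach().cpu().numpy()
        fx = flax_flat.get(name)
        if fx is None:
            continue  # e.g. buffers with no Flax counterpart
        if pt.shape != fx.shape:
            pt = pt.T  # linear kernels are transposed between the two frameworks
        diff = np.abs(pt.astype(np.float32) - fx.astype(np.float32))
        rows.append({
            "layer": name,
            "allclose": bool(jnp.allclose(pt, fx, rtol=rtol, atol=atol)),
            "max_diff": float(diff.max()),
            "mean_diff": float(diff.mean()),
            "std_diff": float(diff.std()),
        })
    return rows
```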
## Hardware Used for Conversion
The conversion process was performed on the following hardware configuration:
* **CPU:**
* **RAM:** 251.67 GB
* **OS:** Linux-5.15.0-107-generic-x86_64-with-glibc2.36
* **JAX version:** 0.3.22
* **Flax version:** 0.6.2
* **Transformers version:** 4.47.0
* **GPU:** NVIDIA A100-SXM4-40GB
This conversion took approximately 81.05 seconds to complete.
## Usage
Here's how you can use the converted model in JAX/Flax for text generation:
```python
from transformers import FlaxAutoModelForCausalLM, AutoTokenizer

model_name = "Erland/Llama-3.2-3B-JAX"  # this repository
tokenizer = AutoTokenizer.from_pretrained(model_name)
# The weights are already stored in Flax format, so no PyTorch conversion is needed.
model = FlaxAutoModelForCausalLM.from_pretrained(model_name)

# Example prompt
prompt = "The quick brown fox"

# Tokenize the prompt (Flax models take NumPy arrays as input)
tokenized_prompt = tokenizer(prompt, return_tensors="np")

# Generate text (greedy decoding by default)
outputs = model.generate(tokenized_prompt.input_ids, max_length=50)

# Flax `generate` returns an output object whose `sequences` field holds the token ids
generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
print(generated_text)
```
## Limitations
* **Sequence Length:** As noted above, `max_position_embeddings` has been reduced to 16384. Be mindful of this limit when working with long sequences.
* **Numerical Precision:** Minor differences in outputs relative to the original PyTorch model may occur due to numerical-precision variations between PyTorch and JAX/Flax, particularly on different hardware. A rough way to check this on your own setup is sketched below.
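The sketch below compares logits from the two implementations directly. It assumes you also have access to the original (gated) `meta-llama/Llama-3.2-3B` PyTorch checkpoint and enough memory for both copies of the weights:
```python
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, FlaxAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("Erland/Llama-3.2-3B-JAX")
flax_model = FlaxAutoModelForCausalLM.from_pretrained("Erland/Llama-3.2-3B-JAX")
pt_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B")

text = "The quick brown fox"
flax_logits = np.asarray(flax_model(**tokenizer(text, return_tensors="np")).logits)
with torch.no_grad():
    pt_logits = pt_model(**tokenizer(text, return_tensors="pt")).logits.float().numpy()

# Small differences are expected; large gaps would indicate a problem.
print("max abs logit difference:", np.abs(flax_logits - pt_logits).max())
```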
## Acknowledgements
We thank the original authors at `meta-llama` for developing and releasing Llama-3.2-3B.
We acknowledge the Hugging Face Transformers library for providing the essential tools and infrastructure that made this conversion possible.
Thanks to the JAX and Flax teams for developing such performant and flexible frameworks for numerical computation and deep learning.
## License
This JAX/Flax model is released under the original model license.