ThomasTheMaker committed on
Commit 4e09a9f · verified · 1 Parent(s): 952b8db

Upload folder using huggingface_hub
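The commit message suggests the checkpoints were pushed with the huggingface_hub client. A minimal sketch of such an upload is shown below; the repository id and local folder path are placeholders and are not taken from this commit.

```python
# Hedged sketch of an upload_folder call similar to what likely produced this commit.
# repo_id and folder_path are hypothetical placeholders, not values from the commit itself.
from huggingface_hub import HfApi

api = HfApi()
api.upload_folder(
    folder_path="pico-decoder-tiny-dolma20M-v1",   # local directory to push
    path_in_repo="pico-decoder-tiny-dolma20M-v1",  # keep the same layout inside the repo
    repo_id="ThomasTheMaker/<repo-name>",          # hypothetical repository id
    repo_type="model",
    commit_message="Upload folder using huggingface_hub",
)
```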

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +480 -0
  2. pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/config.json +22 -0
  3. pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/fabric_state/checkpoint.pt +3 -0
  4. pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/generation_config.json +4 -0
  5. pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/learning_dynamics/train_activations.pt +0 -0
  6. pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/learning_dynamics/train_data/data-00000-of-00001.arrow +3 -0
  7. pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/learning_dynamics/train_data/dataset_info.json +19 -0
  8. pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/learning_dynamics/train_data/state.json +13 -0
  9. pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/learning_dynamics/train_gradients.pt +3 -0
  10. pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/learning_dynamics/train_weights.pt +3 -0
  11. pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/model.safetensors +3 -0
  12. pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/pico_decoder.py +911 -0
  13. pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/special_tokens_map.json +16 -0
  14. pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/tokenizer.json +0 -0
  15. pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/tokenizer_config.json +239 -0
  16. pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/config.json +22 -0
  17. pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/fabric_state/checkpoint.pt +3 -0
  18. pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/generation_config.json +4 -0
  19. pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/learning_dynamics/train_activations.pt +0 -0
  20. pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/learning_dynamics/train_data/data-00000-of-00001.arrow +3 -0
  21. pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/learning_dynamics/train_data/dataset_info.json +19 -0
  22. pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/learning_dynamics/train_data/state.json +13 -0
  23. pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/learning_dynamics/train_gradients.pt +3 -0
  24. pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/learning_dynamics/train_weights.pt +3 -0
  25. pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/model.safetensors +3 -0
  26. pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/pico_decoder.py +911 -0
  27. pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/special_tokens_map.json +16 -0
  28. pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/tokenizer.json +0 -0
  29. pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/tokenizer_config.json +239 -0
  30. pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/config.json +22 -0
  31. pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/fabric_state/checkpoint.pt +3 -0
  32. pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/generation_config.json +4 -0
  33. pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/learning_dynamics/train_activations.pt +0 -0
  34. pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/learning_dynamics/train_data/data-00000-of-00001.arrow +3 -0
  35. pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/learning_dynamics/train_data/dataset_info.json +19 -0
  36. pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/learning_dynamics/train_data/state.json +13 -0
  37. pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/learning_dynamics/train_gradients.pt +3 -0
  38. pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/learning_dynamics/train_weights.pt +3 -0
  39. pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/model.safetensors +3 -0
  40. pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/pico_decoder.py +911 -0
  41. pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/special_tokens_map.json +16 -0
  42. pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/tokenizer.json +0 -0
  43. pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/tokenizer_config.json +239 -0
  44. pico-decoder-tiny-dolma20M-v1/checkpoints/step_11000/config.json +22 -0
  45. pico-decoder-tiny-dolma20M-v1/checkpoints/step_11000/fabric_state/checkpoint.pt +3 -0
  46. pico-decoder-tiny-dolma20M-v1/checkpoints/step_11000/generation_config.json +4 -0
  47. pico-decoder-tiny-dolma20M-v1/checkpoints/step_11000/learning_dynamics/train_activations.pt +0 -0
  48. pico-decoder-tiny-dolma20M-v1/checkpoints/step_11000/learning_dynamics/train_data/data-00000-of-00001.arrow +3 -0
  49. pico-decoder-tiny-dolma20M-v1/checkpoints/step_11000/learning_dynamics/train_data/dataset_info.json +19 -0
  50. pico-decoder-tiny-dolma20M-v1/checkpoints/step_11000/learning_dynamics/train_data/state.json +13 -0
.gitattributes ADDED
@@ -0,0 +1,480 @@
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_11000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_11000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_11000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_11000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_11000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_12000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_12000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_12000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_12000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_12000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_13000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_13000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_13000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_13000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_13000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_14000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_14000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_14000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_14000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_14000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_15000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_15000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_15000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_15000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_15000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_16000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_16000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_16000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_16000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_16000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_17000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_17000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_17000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_17000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_17000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_18000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_18000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_18000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_18000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_18000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_19000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_19000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_19000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_19000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_19000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_2000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_2000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_2000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_2000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_2000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_20000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_20000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_20000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_20000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_20000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_21000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_21000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_21000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_21000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_21000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_22000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_22000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_22000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_22000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_22000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_23000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_23000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_23000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_23000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_23000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_24000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_24000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_24000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_24000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_24000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_25000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_25000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_25000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_25000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_25000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_26000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_26000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_26000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_26000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_26000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_27000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_27000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_27000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_27000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_27000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_28000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_28000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_28000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_28000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_28000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_29000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_29000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_29000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_29000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_29000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_3000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_3000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_3000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_3000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_3000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_30000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_30000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_30000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_30000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_30000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_31000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_31000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_31000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_31000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_31000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_32000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_32000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_32000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_32000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_32000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_33000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_33000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_33000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_33000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_33000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_34000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_34000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_34000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_34000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_34000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_35000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_35000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_35000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_35000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_35000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_36000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_36000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_36000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_36000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_36000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_37000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_37000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_37000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_37000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_37000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_38000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_38000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_38000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_38000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_38000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_39000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_39000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_39000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_39000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_39000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_4000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_4000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_4000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_4000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_4000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_40000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_40000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_40000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_40000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_40000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_41000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_41000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_41000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_41000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_41000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_42000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_42000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_42000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_42000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_42000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_43000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_43000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_43000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_43000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_43000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_44000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_44000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_44000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_44000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_44000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_45000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_45000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_45000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_45000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_45000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_46000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_46000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_46000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_46000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_46000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_47000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_47000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_47000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_47000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_47000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_48000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_48000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_48000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_48000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_48000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_49000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_49000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_49000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_49000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_49000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_5000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_5000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_5000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_5000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_5000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_50000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_50000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_50000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_50000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_50000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_51000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_51000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_51000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_51000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_51000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_52000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_52000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_52000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_52000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_52000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_53000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_53000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_53000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_53000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_53000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_54000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_54000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_54000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_54000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_54000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_55000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_55000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_55000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_55000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_55000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_56000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_56000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_56000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_56000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_56000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_57000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_57000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_57000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_57000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_57000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_58000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_58000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_58000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_58000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_58000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_59000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_59000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_59000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_59000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_59000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_6000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_6000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_6000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_6000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_6000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_60000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_60000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_60000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_60000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_60000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_61000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_61000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_61000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_61000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_61000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_62000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_62000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_62000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_62000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_62000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_63000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_63000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_63000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_63000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_63000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_64000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_64000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_64000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_64000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_64000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_65000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_65000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_65000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_65000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_65000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_66000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_66000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_66000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_66000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_66000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_67000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_67000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_67000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_67000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_67000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_68000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_68000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_68000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_68000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_68000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_69000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_69000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_69000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_69000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_69000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_7000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_7000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_7000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_7000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_7000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_70000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_70000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_70000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_70000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_70000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_71000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_71000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_71000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_71000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_71000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_72000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_72000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_72000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_72000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_72000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_73000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_73000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_73000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_73000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_73000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_74000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_74000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_74000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_74000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_74000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_75000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_75000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_75000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_75000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_75000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_76000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_76000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_76000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_76000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_76000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_77000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_77000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_77000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_77000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_77000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_78000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_78000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_78000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_78000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_78000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_79000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_79000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_79000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_79000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_79000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_8000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_8000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_8000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_8000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_8000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_80000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_80000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_80000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_80000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_80000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_81000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_81000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_81000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_81000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_81000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_82000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_82000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_82000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_82000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_82000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_83000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_83000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_83000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_83000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_83000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_84000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_84000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_84000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_84000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_84000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_85000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_85000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_85000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_85000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
425
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_85000/model.safetensors filter=lfs diff=lfs merge=lfs -text
426
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_86000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
427
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_86000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
428
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_86000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
429
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_86000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
430
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_86000/model.safetensors filter=lfs diff=lfs merge=lfs -text
431
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_87000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
432
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_87000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
433
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_87000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
434
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_87000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
435
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_87000/model.safetensors filter=lfs diff=lfs merge=lfs -text
436
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_88000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
437
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_88000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
438
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_88000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
439
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_88000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
440
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_88000/model.safetensors filter=lfs diff=lfs merge=lfs -text
441
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_89000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
442
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_89000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
443
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_89000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
444
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_89000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
445
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_89000/model.safetensors filter=lfs diff=lfs merge=lfs -text
446
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_9000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
447
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_9000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
448
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_9000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
449
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_9000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
450
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_9000/model.safetensors filter=lfs diff=lfs merge=lfs -text
451
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_90000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
452
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_90000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
453
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_90000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
454
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_90000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
455
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_90000/model.safetensors filter=lfs diff=lfs merge=lfs -text
456
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_91000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
457
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_91000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
458
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_91000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
459
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_91000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
460
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_91000/model.safetensors filter=lfs diff=lfs merge=lfs -text
461
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_92000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
462
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_92000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
463
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_92000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
464
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_92000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
465
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_92000/model.safetensors filter=lfs diff=lfs merge=lfs -text
466
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_93000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
467
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_93000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
468
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_93000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
469
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_93000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
470
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_93000/model.safetensors filter=lfs diff=lfs merge=lfs -text
471
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_94000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
472
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_94000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
473
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_94000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
474
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_94000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
475
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_94000/model.safetensors filter=lfs diff=lfs merge=lfs -text
476
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_95000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
477
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_95000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
478
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_95000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
479
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_95000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
480
+ pico-decoder-tiny-dolma20M-v1/checkpoints/step_95000/model.safetensors filter=lfs diff=lfs merge=lfs -text
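All of the per-step artifacts listed above are tracked through Git LFS, so cloning the whole repository pulls every checkpoint. A minimal sketch, assuming you only want a single step, of fetching it with huggingface_hub (the repo_id below is a placeholder, not taken from this commit):

from huggingface_hub import snapshot_download

# Download one checkpoint step instead of the full repository.
local_dir = snapshot_download(
    repo_id="<user>/<repo>",  # placeholder -- substitute the actual Hub repository id
    allow_patterns=["pico-decoder-tiny-dolma20M-v1/checkpoints/step_95000/*"],
)
print(local_dir)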
pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/config.json ADDED
@@ -0,0 +1,22 @@
1
+ {
2
+ "activation_hidden_dim": 384,
3
+ "architectures": [
4
+ "PicoDecoderHF"
5
+ ],
6
+ "attention_n_heads": 12,
7
+ "attention_n_kv_heads": 4,
8
+ "auto_map": {
9
+ "AutoConfig": "pico_decoder.PicoDecoderHFConfig",
10
+ "AutoModelForCausalLM": "pico_decoder.PicoDecoderHF"
11
+ },
12
+ "batch_size": 1024,
13
+ "d_model": 96,
14
+ "max_seq_len": 2048,
15
+ "model_type": "pico_decoder",
16
+ "n_layers": 12,
17
+ "norm_eps": 1e-06,
18
+ "position_emb_theta": 10000.0,
19
+ "torch_dtype": "float32",
20
+ "transformers_version": "4.48.3",
21
+ "vocab_size": 50304
22
+ }
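The config above fixes the architecture: 12 layers, d_model 96, 12 query heads sharing 4 KV heads, SwiGLU hidden size 384, and a 50304-token vocabulary. As a sanity check (an editor's back-of-the-envelope sketch, not code from the repository), the implied parameter count matches the ~45 MB float32 model.safetensors files in these checkpoints:

# Rough parameter count for the config above (no biases, untied input/output embeddings).
d_model, n_layers, n_heads, n_kv_heads = 96, 12, 12, 4
hidden, vocab = 384, 50304
head_dim = d_model // n_heads                      # 8

attn = 2 * d_model * d_model                       # q_proj + o_proj
attn += 2 * d_model * (n_kv_heads * head_dim)      # k_proj + v_proj
swiglu = 3 * d_model * hidden                      # w_0, w_1, w_2
norms = 2 * d_model                                # two RMSNorm weights per block
per_layer = attn + swiglu + norms

total = 2 * vocab * d_model + n_layers * per_layer + d_model   # embeddings + blocks + final norm
print(total, f"{total * 4 / 1e6:.1f} MB in float32")           # ~11.3M parameters, ~45 MB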
pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/fabric_state/checkpoint.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d01a6a79f53f412afc600ef5825ba1ce606eacf5d8808aa0c83de62b2b42ef28
3
+ size 45187997
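What is actually stored in the repository for each large file is a Git LFS pointer like the three lines above (spec version, sha256 object id, payload size in bytes); the binary data itself lives in LFS storage. A tiny sketch of parsing such a pointer:

def read_lfs_pointer(path):
    # Returns e.g. {"version": "https://git-lfs.github.com/spec/v1", "oid": "sha256:d01a...", "size": "45187997"}
    fields = {}
    with open(path) as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            fields[key] = value
    return fields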
pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/generation_config.json ADDED
@@ -0,0 +1,4 @@
1
+ {
2
+ "transformers_version": "4.48.3",
3
+ "vocab_size": 50304
4
+ }
pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/learning_dynamics/train_activations.pt ADDED
Binary file (98.3 kB).
pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/learning_dynamics/train_data/data-00000-of-00001.arrow ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:71dea5221a5b809d03b575f9a437c3772951dc4d8c202e5af09d005a23791b3a
3
+ size 271568
pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/learning_dynamics/train_data/dataset_info.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "citation": "",
3
+ "description": "",
4
+ "features": {
5
+ "input_ids": {
6
+ "feature": {
7
+ "dtype": "int32",
8
+ "_type": "Value"
9
+ },
10
+ "_type": "Sequence"
11
+ },
12
+ "text": {
13
+ "dtype": "string",
14
+ "_type": "Value"
15
+ }
16
+ },
17
+ "homepage": "",
18
+ "license": ""
19
+ }
pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/learning_dynamics/train_data/state.json ADDED
@@ -0,0 +1,13 @@
1
+ {
2
+ "_data_files": [
3
+ {
4
+ "filename": "data-00000-of-00001.arrow"
5
+ }
6
+ ],
7
+ "_fingerprint": "6cc12b19e292c1f8",
8
+ "_format_columns": null,
9
+ "_format_kwargs": {},
10
+ "_format_type": null,
11
+ "_output_all_columns": false,
12
+ "_split": null
13
+ }
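The train_data directory (Arrow shard plus dataset_info.json and the state.json above) matches the on-disk layout produced by datasets.Dataset.save_to_disk, with input_ids stored as sequences of int32 alongside the raw text. A short sketch of reading it back from a local copy of the repo:

from datasets import load_from_disk

# Load the learning-dynamics batch stored next to the step_0 checkpoint.
ds = load_from_disk(
    "pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/learning_dynamics/train_data"
)
print(ds)                          # features: input_ids (Sequence of int32), text (string)
print(len(ds), len(ds[0]["input_ids"]))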
pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/learning_dynamics/train_gradients.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9cd1f0ba01d7b6d8f8c470ba7065f7ba7251409f02235127fb5952480aec233a
3
+ size 2371527
pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/learning_dynamics/train_weights.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c029ef92a6494ae121c847e432e52e6a8ff3bf7d9fef3e61bef871c1e9a9aa02
3
+ size 2371443
pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1852515eb5c8556533445f22edf523884b9f8cc44812379a6a951668a4ffa3a3
3
+ size 45143592
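Each step's weights are a single ~45 MB float32 safetensors file. If you only want to inspect the raw tensors rather than instantiate the HF model class, a minimal sketch (the path assumes a local copy of the checkpoint):

from safetensors.torch import load_file

state_dict = load_file(
    "pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/model.safetensors"
)
for name, tensor in list(state_dict.items())[:5]:
    print(name, tuple(tensor.shape), tensor.dtype)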
pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/pico_decoder.py ADDED
@@ -0,0 +1,911 @@
1
+ """
2
+ Pico Decoder: A Lightweight Causal Transformer Language Model
3
+
4
+ Pico Decoder uses a simple LLAMA-style transformer architecture, written for clarity and educational purposes.
5
+
6
+ Everything is written with a modular design for easy modification and experimentation.
7
+
8
+ Key features:
9
+ - RMSNorm for layer normalization
10
+ - Rotary Positional Embeddings (RoPE)
11
+ - Multi-head attention with KV-cache support
12
+ - SwiGLU activation function
13
+ - Residual connections throughout
14
+
15
+ - KV-cache for faster autoregressive generation
16
+
17
+ References:
18
+ - RoPE: https://arxiv.org/abs/2104.09864
19
+ - SwiGLU: https://arxiv.org/abs/2002.05202
20
+ - LLAMA: https://arxiv.org/abs/2302.13971
21
+
22
+ Adapted from:
23
+ - OLMO: https://github.com/allenai/OLMo
24
+ - LLAMA: https://github.com/meta/llama
25
+ """
26
+
27
+ from dataclasses import asdict
28
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
29
+
30
+ import torch
31
+ import torch.nn as nn
32
+ import torch.nn.functional as F
33
+
34
+ # Handle PyTorch version compatibility for attention backend
35
+ try:
36
+ from torch.nn.attention import SDPBackend, sdpa_kernel
37
+
38
+ HAS_TORCH_ATTENTION = True
39
+ except ImportError:
40
+ # Fallback for older PyTorch versions
41
+ HAS_TORCH_ATTENTION = False
42
+ SDPBackend = None
43
+ sdpa_kernel = None
44
+
45
+ from transformers import GenerationMixin, PretrainedConfig, PreTrainedModel
46
+ from transformers.generation import GenerationConfig
47
+ from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast
48
+
49
+ try:
50
+ if TYPE_CHECKING:
51
+ # We need to do this to avoid importing these when creating the HF-compatible models
52
+ from src.config import ModelConfig
53
+ except ImportError:
54
+ pass
55
+
56
+ ########################################################
57
+ #
58
+ # Layer Normalization
59
+ #
60
+ ########################################################
61
+
62
+
63
+ class RMSNorm(torch.nn.Module):
64
+ """Root Mean Square Layer Normalization.
65
+
66
+ A variant of Layer Normalization that uses RMS statistics instead of mean/variance,
67
+ resulting in improved stability and performance.
68
+
69
+ Args:
70
+ config (Union[ModelConfig, PicoHFConfig]): Configuration object containing normalization parameters
71
+ - config.norm_eps: Small constant for numerical stability
72
+ - config.d_model: Model dimension for the weight parameter
73
+
74
+ References:
75
+ https://arxiv.org/abs/1910.07467
76
+ """
77
+
78
+ def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
79
+ super().__init__()
80
+ self.eps = config.norm_eps
81
+ self.weight = nn.Parameter(torch.ones(config.d_model))
82
+
83
+ def _norm(self, x: torch.Tensor) -> torch.Tensor:
84
+ """
85
+ Normalizes the input tensor by its RMS value.
86
+ """
87
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
88
+
89
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
90
+ """
91
+ Applies RMS normalization to the input tensor and scales it by the weight parameter.
92
+ """
93
+ output = self._norm(x.float()).type_as(x)
94
+ return output * self.weight
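# Editor's sketch (not part of pico_decoder.py): numeric check of the RMSNorm formula above.
# d_model=96 and eps=1e-6 mirror this checkpoint's config.json.
import torch
x = torch.randn(2, 5, 96)                                         # (batch, seq_len, d_model)
manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
# With the learned weight at its all-ones initialization, RMSNorm(config)(x) equals `manual`.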
95
+
96
+
97
+ ########################################################
98
+ #
99
+ # Positional Embedding
100
+ #
101
+ ########################################################
102
+
103
+
104
+ class RoPE(nn.Module):
105
+ """Rotary Positional Embeddings (RoPE).
106
+
107
+ Implements position-dependent rotation of keys and queries in attention mechanism,
108
+ allowing better modeling of relative positions in sequences. Uses complex number
109
+ operations for efficient rotation.
110
+
111
+ Args:
112
+ config (Union[ModelConfig, PicoHFConfig]): Model configuration containing:
113
+ - config.position_emb_theta: Base for frequency computation
114
+ - config.d_model: Model dimension
115
+ - config.attention_n_heads: Number of attention heads
116
+ - config.max_seq_len: Maximum sequence length
117
+
118
+ References:
119
+ https://arxiv.org/abs/2104.09864
120
+ """
121
+
122
+ _freqs_cis_tensor: torch.Tensor | None = None
123
+
124
+ def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
125
+ super().__init__()
126
+
127
+ self.theta = config.position_emb_theta
128
+ self.dim = config.d_model // config.attention_n_heads
129
+
130
+ max_seq_len = config.max_seq_len
131
+
132
+ # only gets set once, and then reused for all RoPE instances
133
+ if RoPE._freqs_cis_tensor is None:
134
+ RoPE._freqs_cis_tensor = self._setup_freqs_cis(
135
+ max_seq_len, self.theta, self.dim
136
+ )
137
+
138
+ # register _freqs_cis buffer
139
+ # can be easily recomputed so persistent=False
140
+ self.register_buffer("_freqs_cis", self._freqs_cis_tensor, persistent=False)
141
+
142
+ @classmethod
143
+ def _setup_freqs_cis(cls, seq_len: int, theta: float, dim: int) -> torch.Tensor:
144
+ """Setup Frequency Tensor for RoPE Embeddings
145
+
146
+ Initializes the complex frequency tensor that is used to compute the RoPE embeddings.
147
+
148
+ Note other implementations will use cos and sin directly, but using the complex
149
+ number representation is (probably) more efficient:
150
+
151
+ e^(theta * i * t) = cos(theta * t) + i * sin(theta * t) [Euler's formula]
152
+ """
153
+ _freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
154
+ positions = torch.arange(seq_len)
155
+ freqs = torch.outer(positions, _freqs)
156
+ return torch.polar(torch.ones_like(freqs), freqs) # complex64
157
+
158
+ def get_freqs_cis(
159
+ self, input_shape: torch.Size, start_pos: int, end_pos: int
160
+ ) -> torch.Tensor:
161
+ """Reshape Frequency Tensor for RoPE Embeddings
162
+
163
+ Makes the frequency tensor broadcastable with the input tensor.
164
+ """
165
+ _freqs_cis = self._freqs_cis[start_pos:end_pos]
166
+ ndim = len(input_shape)
167
+ assert 0 <= 1 < ndim
168
+ assert _freqs_cis.shape == (input_shape[1], input_shape[-1])
169
+
170
+ # TODO: Check whether this is correct (might be able to remove this)
171
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(input_shape)]
172
+ return _freqs_cis.view(*shape)
173
+
174
+ def forward(
175
+ self,
176
+ queries: torch.Tensor,
177
+ keys: torch.Tensor,
178
+ start_pos: int = 0,
179
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
180
+ """Apply RoPE Embeddings to Queries and Keys
181
+
182
+ Applies the rotary positional embeddings to the input tensors via complex-number multiplication
183
+
184
+ NOTE: The start_pos is used if we want to use the kv_cache in the attention mechanism.
185
+ """
186
+ queries_ = torch.view_as_complex(
187
+ queries.float().reshape(*queries.shape[:-1], -1, 2)
188
+ )
189
+ keys_ = torch.view_as_complex(keys.float().reshape(*keys.shape[:-1], -1, 2))
190
+
191
+ input_shape = (
192
+ queries_.shape
193
+ ) # same as keys: (batch_size, seq_len, n_heads, head_dim/2)
194
+ freqs_start_pos = start_pos
195
+ freqs_end_pos = freqs_start_pos + queries_.shape[1]
196
+
197
+ freqs_cis = self.get_freqs_cis(input_shape, freqs_start_pos, freqs_end_pos)
198
+
199
+ queries_rotated = torch.view_as_real(queries_ * freqs_cis).flatten(3)
200
+ keys_rotated = torch.view_as_real(keys_ * freqs_cis).flatten(3)
201
+ return queries_rotated.type_as(queries), keys_rotated.type_as(keys)
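# Editor's sketch (not part of pico_decoder.py): the complex frequency table used above is just
# cos + i*sin of the per-position, per-frequency angles (Euler's formula).
import torch
dim, theta, seq_len = 8, 10000.0, 4
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
angles = torch.outer(torch.arange(seq_len).float(), freqs)
cis = torch.polar(torch.ones_like(angles), angles)                # complex64
assert torch.allclose(cis.real, torch.cos(angles))
assert torch.allclose(cis.imag, torch.sin(angles))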
202
+
203
+
204
+ ########################################################
205
+ #
206
+ # Attention
207
+ #
208
+ ########################################################
209
+
210
+
211
+ class Attention(nn.Module):
212
+ """Multi-head Attention with Group Query Attention support.
213
+
214
+ Implements scaled dot-product attention and supports:
215
+ - Grouped Query Attention (GQA)
216
+ - Key-Value caching for efficient inference
217
+ - RoPE integration
218
+
219
+ Args:
220
+ config (Union[ModelConfig, PretrainedConfig]): Configuration containing:
221
+ - config.attention_n_heads: Number of attention heads
222
+ - config.attention_n_kv_heads: Number of key/value heads
223
+ - config.d_model: Model dimension
224
+ - config.batch_size: Maximum batch size
225
+ - config.max_seq_len: Maximum sequence length
226
+
227
+ Shape:
228
+ - Input: (batch_size, seq_len, d_model)
229
+ - Output: (batch_size, seq_len, d_model)
230
+ """
231
+
232
+ def __init__(
233
+ self,
234
+ config: Union["ModelConfig", "PicoDecoderHFConfig"],
235
+ ):
236
+ super().__init__()
237
+
238
+ self.n_heads = config.attention_n_heads
239
+ self.n_kv_heads = config.attention_n_kv_heads
240
+
241
+ self.batch_size = config.batch_size
242
+ self.max_seq_len = config.max_seq_len
243
+
244
+ d_model = config.d_model
245
+ self.head_dim = d_model // self.n_heads
246
+
247
+ self.n_rep = self.n_heads // self.n_kv_heads
248
+
249
+ self.q_proj = nn.Linear(d_model, self.n_heads * self.head_dim, bias=False)
250
+ self.k_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
251
+ self.v_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
252
+ self.o_proj = nn.Linear(self.n_heads * self.head_dim, d_model, bias=False)
253
+
254
+ self.rope = RoPE(config)
255
+
256
+ def forward(
257
+ self,
258
+ input: torch.Tensor,
259
+ mask: Optional[torch.Tensor] = None,
260
+ past_key_values: Optional[Tuple[torch.Tensor, ...]] = None,
261
+ use_cache: bool = False,
262
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
263
+ """Forward pass for the attention mechanism.
264
+
265
+ Computes queries, keys, and values for the attention mechanism. Applies rotary positional
266
+ embeddings to the queries and keys, and then computes attention scores and outputs.
267
+
268
+ For an introduction to the attention mechanism, see:
269
+ https://arxiv.org/abs/1706.03762
270
+
271
+ A few things to note:
272
+ - The past_key_values is used to implement the KV cache, which is used to speed up
273
+ generation by caching the KV pairs from previous forward passes. This is useful when doing
274
+ tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
275
+ modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
276
+ its own KV cache - this KV cache is implemented as a tuple.
277
+ """
278
+ bsz, seq_len, _ = input.shape
279
+ _queries, _keys, _values = (
280
+ self.q_proj(input),
281
+ self.k_proj(input),
282
+ self.v_proj(input),
283
+ )
284
+
285
+ # Reshaping for multi-head attention
286
+ queries = _queries.view(bsz, seq_len, self.n_heads, self.head_dim)
287
+ keys = _keys.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
288
+ values = _values.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
289
+
290
+ # The start position is used to apply the RoPE embeddings to only the new tokens
291
+ # when using the kv_cache in the attention mechanism.
292
+ # We want to start from the last position in the cache.
293
+ start_pos = 0
294
+ if past_key_values is not None and past_key_values[0] is not None:
295
+ start_pos = past_key_values[0].shape[1]
296
+
297
+ # apply rotary positional embeddings
298
+ queries, keys = self.rope(queries, keys, start_pos)
299
+
300
+ if (
301
+ past_key_values is not None
302
+ and past_key_values[0] is not None
303
+ and past_key_values[1] is not None
304
+ ):
305
+ keys = torch.cat([past_key_values[0], keys], dim=1)
306
+ values = torch.cat([past_key_values[1], values], dim=1)
307
+
308
+ if use_cache:
309
+ cached_keys = keys
310
+ cached_values = values
311
+ else:
312
+ cached_keys = None
313
+ cached_values = None
314
+
315
+ queries = queries.transpose(1, 2)
316
+ keys = keys.transpose(1, 2)
317
+ values = values.transpose(1, 2)
318
+
319
+ apply_gqa = self.n_rep > 1
320
+ if apply_gqa and queries.device.type == "mps":
321
+ # NOTE: MPS does not support GQA in the SDPA kernel, but we can repeat the keys and values
322
+ # outside of the kernel to get the same effect.
323
+ # See: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
324
+ keys = keys.repeat_interleave(self.n_rep, dim=-3)
325
+ values = values.repeat_interleave(self.n_rep, dim=-3)
326
+ apply_gqa = False
327
+
328
+ if HAS_TORCH_ATTENTION:
329
+ backends = [SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]
330
+ with sdpa_kernel(backends=backends):
331
+ attn_output = F.scaled_dot_product_attention(
332
+ queries.contiguous(),
333
+ keys.contiguous(),
334
+ values.contiguous(),
335
+ attn_mask=mask.to(queries.dtype) if mask is not None else None,
336
+ enable_gqa=apply_gqa,
337
+ )
338
+ else:
339
+ # Fallback for older PyTorch versions - use default backend
340
+ attn_output = F.scaled_dot_product_attention(
341
+ queries.contiguous(),
342
+ keys.contiguous(),
343
+ values.contiguous(),
344
+ attn_mask=mask.to(queries.dtype) if mask is not None else None,
345
+ enable_gqa=apply_gqa,
346
+ )
347
+
348
+ attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
349
+ output = self.o_proj(attn_output)
350
+
351
+ return output, (cached_keys, cached_values)
352
+
353
+
354
+ ########################################################
355
+ #
356
+ # SwiGLU (Combines MLP and Activation)
357
+ #
358
+ ########################################################
359
+
360
+
361
+ class SwiGLU(nn.Module):
362
+ """SwiGLU Activation Function with Linear Projections.
363
+
364
+ Implements the SwiGLU activation function combined with linear transformations,
365
+ serving as the feed-forward network in transformer blocks.
366
+
367
+ Args:
368
+ config (Union[ModelConfig, PicoDecoderHFConfig]): Configuration containing:
369
+ - config.d_model: Model dimension
370
+ - config.activation_hidden_dim: Hidden dimension (typically 4 * d_model)
371
+
372
+ References:
373
+ https://arxiv.org/abs/2002.05202
374
+ """
375
+
376
+ def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
377
+ super().__init__()
378
+
379
+ model_dim = config.d_model
380
+ act_hidden_dim = config.activation_hidden_dim # usually 4 * d_model
381
+
382
+ self.w_0 = nn.Linear(model_dim, act_hidden_dim, bias=False)
383
+ self.w_1 = nn.Linear(model_dim, act_hidden_dim, bias=False)
384
+ self.w_2 = nn.Linear(act_hidden_dim, model_dim, bias=False)
385
+
386
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
387
+ return self.w_2(F.silu(self.w_0(x)) * self.w_1(x))
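# Editor's sketch (not part of pico_decoder.py): SwiGLU is a gated MLP,
#   out = w_2( silu(w_0 x) * w_1 x ),
# shown here with d_model=96 and activation_hidden_dim=384 as in config.json.
import torch
import torch.nn.functional as F
x = torch.randn(2, 5, 96)
w0, w1, w2 = torch.randn(384, 96), torch.randn(384, 96), torch.randn(96, 384)
out = F.linear(F.silu(F.linear(x, w0)) * F.linear(x, w1), w2)
print(out.shape)                                                  # torch.Size([2, 5, 96])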
388
+
389
+
390
+ ########################################################
391
+ #
392
+ # PicoDecoderBlock
393
+ #
394
+ ########################################################
395
+
396
+
397
+ class PicoDecoderBlock(nn.Module):
398
+ """Single Transformer Block with Attention and Feed-forward layers.
399
+
400
+ Implements a standard transformer block with:
401
+ - Multi-head attention with normalization and residual connection
402
+ - SwiGLU feed-forward network with normalization and residual connection
403
+
404
+ Args:
405
+ config (Union[ModelConfig, PicoDecoderHFConfig]): Model configuration; either a dataclass or
406
+ a HuggingFace PicoDecoderHFConfig
407
+ """
408
+
409
+ def __init__(
410
+ self,
411
+ config: Union["ModelConfig", "PicoDecoderHFConfig"],
412
+ ):
413
+ super().__init__()
414
+
415
+ self.attention = Attention(config)
416
+ self.swiglu = SwiGLU(config)
417
+ self.attention_norm = RMSNorm(config)
418
+ self.swiglu_norm = RMSNorm(config)
419
+
420
+ def forward(
421
+ self,
422
+ input: torch.Tensor,
423
+ mask: Optional[torch.Tensor] = None,
424
+ past_key_values: Optional[Tuple[torch.Tensor]] = None,
425
+ use_cache: bool = False,
426
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
427
+ attention_output, cached_key_values = self.attention(
428
+ self.attention_norm(input),
429
+ mask=mask,
430
+ past_key_values=past_key_values,
431
+ use_cache=use_cache,
432
+ )
433
+ # NOTE: cached_key_values is None if use_cache is False
434
+
435
+ h = input + attention_output
436
+ out = h + self.swiglu(self.swiglu_norm(h))
437
+ return out, cached_key_values
438
+
439
+
440
+ ########################################################
441
+ #
442
+ # Pico Decoder (Causal Transformer Model)
443
+ #
444
+ ########################################################
445
+
446
+
447
+ class PicoDecoder(nn.Module):
448
+ """
449
+ Pico Decoder: combines the embedding, causal decoder blocks, and output projection into a
450
+ single autoregressive model.
451
+
452
+ For more information on the model, see the classes for the modules that make up the model.
453
+ """
454
+
455
+ def __init__(
456
+ self,
457
+ model_config: Union["ModelConfig", "PicoDecoderHFConfig"],
458
+ ):
459
+ super().__init__()
460
+ self.config = model_config
461
+
462
+ self.embedding_proj = nn.Embedding(self.config.vocab_size, self.config.d_model)
463
+ self.layers = nn.ModuleList(
464
+ [PicoDecoderBlock(self.config) for _ in range(self.config.n_layers)]
465
+ )
466
+ self.output_norm = RMSNorm(self.config)
467
+ self.de_embedding_proj = nn.Linear(
468
+ self.config.d_model, self.config.vocab_size, bias=False
469
+ )
470
+
471
+ def convert_to_hf_model(self) -> "PicoDecoderHF":
472
+ """Convert the Lightning model to a HuggingFace model."""
473
+ # Create HF config without fabric-specific settings
474
+ hf_config = PicoDecoderHFConfig.from_dataclass(self.config)
475
+
476
+ # Create new HF model
477
+ hf_model = PicoDecoderHF(hf_config)
478
+
479
+ # Copy state dict, excluding fabric-specific keys
480
+ hf_model.load_state_dict(self.state_dict(prefix="pico_decoder."))
481
+
482
+ return hf_model
483
+
484
+ def forward(
485
+ self,
486
+ input_ids: torch.Tensor,
487
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
488
+ use_cache: bool = False,
489
+ ) -> Tuple[torch.Tensor, Optional[Tuple[Tuple[torch.Tensor, torch.Tensor]]]]:
490
+ """
491
+ This is the forward pass for the entire Pico model. It boils down to:
492
+ - Embedding the input ids
493
+ - Creating a causal mask
494
+ - Processing through the pico layers
495
+ - Projecting the output to logits
496
+
497
+ NOTE: One feature that might be confusing is the KV cache. The KV cache is used to speed up
498
+ generation by caching the KV pairs from previous forward passes. This is useful when doing
499
+ tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
500
+ modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
501
+ its own KV cache which is stored as a tuple. The whole model then stores a tuple of these
502
+ KV caches (so a tuple of tuples).
503
+ """
504
+
505
+ seq_len = input_ids.shape[-1]
506
+ h = self.embedding_proj(input_ids)
507
+
508
+ # Calculate start position from past cached KV pairs. Remember that each layer has its
509
+ # own KV Cache. So when we index past_key_values, we need to index into the KV pairs for the
510
+ # correct layer and then for either the keys or values.
511
+ start_pos = 0
512
+ if (
513
+ past_key_values is not None
514
+ and past_key_values[0] is not None
515
+ and past_key_values[0][0] is not None
516
+ ):
517
+ start_pos = past_key_values[0][0].shape[1]
518
+
519
+ # Create causal mask for current sequence
520
+ mask = None
521
+ if seq_len > 1:
522
+ mask = torch.full((seq_len, seq_len), float("-inf"))
523
+ mask = torch.triu(mask, diagonal=1)
524
+
525
+ # If using KV cache, extend mask to cover cached sequence length
526
+ if past_key_values is not None:
527
+ # Add zeros for cached tokens (we can attend to all of them)
528
+ mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask])
529
+
530
+ mask = mask.to(h.device)
531
+
532
+ # NOTE: If we are using the cache, we need to store the cached KV pairs for each layer
533
+ # in a tuple. Each layer will have its own cached KV pair which we aggregate in a tuple.
534
+ cached_key_values = () if use_cache else None
535
+
536
+ # Process through transformer blocks
537
+ for idx, layer in enumerate(self.layers):
538
+ layer_past_key_values = None
539
+ if past_key_values is not None:
540
+ try:
541
+ # Handle both tuple-based cache and HuggingFace cache objects
542
+ if hasattr(past_key_values, "__getitem__") and idx < len(
543
+ past_key_values
544
+ ):
545
+ layer_past_key_values = past_key_values[idx]
546
+ except (KeyError, IndexError, TypeError):
547
+ # If we can't access the cache properly, just skip it
548
+ layer_past_key_values = None
549
+
550
+ h, layer_cached_key_values = layer(
551
+ h, mask=mask, past_key_values=layer_past_key_values, use_cache=use_cache
552
+ )
553
+
554
+ if use_cache:
555
+ cached_key_values += (layer_cached_key_values,)
556
+
557
+ # Final norm and projection
558
+ h = self.output_norm(h)
559
+ logits = self.de_embedding_proj(h).float()
560
+
561
+ return logits, cached_key_values
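# Editor's sketch (not part of pico_decoder.py): the causal mask built above, for the case of
# 2 cached tokens (start_pos=2) and 3 new tokens (seq_len=3). Cached positions get 0 (always
# visible); future positions within the new block get -inf.
import torch
seq_len, start_pos = 3, 2
mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask])
# Row i may attend to the 2 cached tokens and to new tokens 0..i:
# [[0., 0., 0., -inf, -inf],
#  [0., 0., 0.,   0., -inf],
#  [0., 0., 0.,   0.,   0.]]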
562
+
563
+
564
+ ########################################################
565
+ #
566
+ # HuggingFace Wrapper for the Pico Decoder model.
567
+ #
568
+ ########################################################
569
+
570
+
571
+ class PicoDecoderHFConfig(PretrainedConfig):
572
+ """Config class for the Pico Decoder HuggingFace wrapper."""
573
+
574
+ model_type = "pico_decoder"
575
+
576
+ @classmethod
577
+ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PicoDecoderHFConfig":
578
+ """
579
+ Initialize config from a dictionary. Note that no kwargs are passed to the constructor --
580
+ this is because with some kwargs special handling is required and can make this class
581
+ brittle.
582
+ """
583
+ pico_config = cls(**config_dict)
584
+
585
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
586
+ unused_kwargs = {
587
+ key: value for key, value in kwargs.items() if not hasattr(pico_config, key)
588
+ }
589
+
590
+ if return_unused_kwargs:
591
+ return pico_config, unused_kwargs
592
+ return pico_config
593
+
594
+ @classmethod
595
+ def from_dataclass(cls, model_config: "ModelConfig"):
596
+ """Initialise from our custom config dataclass."""
597
+ return cls.from_dict(asdict(model_config))
598
+
599
+
600
+ class PicoDecoderHF(PreTrainedModel, GenerationMixin):
601
+ """
602
+ HuggingFace wrapper for the Pico model with generation support.
603
+
604
+ Many evaluation frameworks require a model be setup as a HuggingFace model, so we provide a simple
605
+ wrapper that does just that. When we save checkpoints of the Pico model, we save both the normal
606
+ Pico model as well as the model wrapped in this HuggingFace class.
607
+
608
+ This also lets you do cool things like:
609
+
610
+ `model = AutoModelForCausalLM.from_pretrained("path/to/checkpoint")`
611
+ """
612
+
613
+ config_class = PicoDecoderHFConfig
614
+ _no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
615
+ main_input_name = "input_ids"
616
+
617
+ def __init__(self, config: PicoDecoderHFConfig):
618
+ super().__init__(config)
619
+ self.pico_decoder = PicoDecoder(config)
620
+ # Initialize generation config with defaults
621
+ self.generation_config = GenerationConfig()
622
+ # Set some reasonable defaults for the model
623
+ if hasattr(config, "max_position_embeddings"):
624
+ self.generation_config.max_length = config.max_position_embeddings
625
+ if hasattr(config, "vocab_size"):
626
+ self.generation_config.vocab_size = config.vocab_size
627
+
628
+ def forward(
629
+ self,
630
+ input_ids: torch.Tensor,
631
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
632
+ use_cache: bool = False,
633
+ **kwargs,
634
+ ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
635
+ """HuggingFace forward pass wrapper.
636
+
637
+ Forward pass for the HuggingFace version of the Pico Model. Basic wrapper around the
638
+ Pico model's forward pass, and returns the output as a HuggingFace CausalLMOutput.
639
+ """
640
+ logits, past_key_values = self.pico_decoder(
641
+ input_ids, past_key_values, use_cache
642
+ )
643
+ if use_cache:
644
+ return CausalLMOutputWithPast(
645
+ logits=logits,
646
+ past_key_values=past_key_values,
647
+ )
648
+ else:
649
+ return CausalLMOutput(
650
+ logits=logits,
651
+ )
652
+
653
+ def prepare_inputs_for_generation(
654
+ self,
655
+ input_ids: torch.LongTensor,
656
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
657
+ attention_mask: Optional[torch.LongTensor] = None,
658
+ **kwargs,
659
+ ) -> Dict[str, Any]:
660
+ """
661
+ Prepare inputs for generation.
662
+
663
+ Args:
664
+ input_ids: Input token IDs
665
+ past_key_values: Cached key-value pairs from previous forward passes
666
+ attention_mask: Attention mask for the input
667
+ **kwargs: Additional arguments
668
+
669
+ Returns:
670
+ Dictionary containing prepared inputs
671
+ """
672
+ # If we have past_key_values, we only need the last token
673
+ if past_key_values is not None:
674
+ input_ids = input_ids[:, -1:]
675
+
676
+ return {
677
+ "input_ids": input_ids,
678
+ "past_key_values": past_key_values,
679
+ "use_cache": True,
680
+ }
681
+
682
+ def get_input_embeddings(self):
683
+ """Get the input embeddings layer."""
684
+ return self.pico_decoder.embedding_proj
685
+
686
+ def set_input_embeddings(self, value):
687
+ """Set the input embeddings layer."""
688
+ self.pico_decoder.embedding_proj = value
689
+
690
+ def get_output_embeddings(self):
691
+ """Get the output embeddings layer."""
692
+ return self.pico_decoder.de_embedding_proj
693
+
694
+ def set_output_embeddings(self, value):
695
+ """Set the output embeddings layer."""
696
+ self.pico_decoder.de_embedding_proj = value
697
+
698
+ def get_lm_head(self):
699
+ """Get the language model head."""
700
+ return self.pico_decoder.de_embedding_proj
701
+
702
+ def can_generate(self) -> bool:
703
+ """Check if the model can generate text."""
704
+ return True
705
+
706
+ @property
707
+ def is_encoder_decoder(self) -> bool:
708
+ """Check if the model is an encoder-decoder model."""
709
+ return False
710
+
711
+ @property
712
+ def can_use_cache(self) -> bool:
713
+ """Check if the model can use KV cache."""
714
+ return True
715
+
716
+ def resize_token_embeddings(
717
+ self, new_num_tokens: Optional[int] = None
718
+ ) -> torch.nn.Embedding:
719
+ """Resize token embeddings."""
720
+ old_embeddings = self.get_input_embeddings()
721
+ if new_num_tokens is None:
722
+ new_num_tokens = old_embeddings.num_embeddings
723
+
724
+ new_embeddings = torch.nn.Embedding(
725
+ new_num_tokens, old_embeddings.embedding_dim
726
+ )
727
+ new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
728
+ old_embeddings.weight.data
729
+ )
730
+
731
+ self.pico_decoder.embedding_proj = new_embeddings
732
+ self.pico_decoder.de_embedding_proj = torch.nn.Linear(
733
+ old_embeddings.embedding_dim, new_num_tokens, bias=False
734
+ )
735
+
736
+ return new_embeddings
737
+
738
+
739
+ # Register for auto classes
740
+ PicoDecoderHFConfig.register_for_auto_class()
741
+ PicoDecoderHF.register_for_auto_class("AutoModel")
742
+ PicoDecoderHF.register_for_auto_class("AutoModelForCausalLM")
743
+
744
+
745
+ ########################################################
746
+ #
747
+ # New PicoDecoderForCausalLM class for generation support
748
+ #
749
+ ########################################################
750
+
751
+
752
+ class PicoDecoderForCausalLM(PreTrainedModel, GenerationMixin):
753
+ """
754
+ PicoDecoderForCausalLM: A HuggingFace-compatible model that properly supports generation.
755
+
756
+ This class is designed to work with existing checkpoints and provides full generation support.
757
+ It inherits from the right base classes that HuggingFace expects for text generation.
758
+ """
759
+
760
+ config_class = PicoDecoderHFConfig
761
+ _no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
762
+ main_input_name = "input_ids"
763
+
764
+ def __init__(self, config: PicoDecoderHFConfig):
765
+ super().__init__(config)
766
+ self.pico_decoder = PicoDecoder(config)
767
+ # Initialize generation config with defaults
768
+ self.generation_config = GenerationConfig()
769
+ # Set some reasonable defaults for the model
770
+ if hasattr(config, "max_position_embeddings"):
771
+ self.generation_config.max_length = config.max_position_embeddings
772
+ if hasattr(config, "vocab_size"):
773
+ self.generation_config.vocab_size = config.vocab_size
774
+
775
+ def forward(
776
+ self,
777
+ input_ids: torch.Tensor,
778
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
779
+ use_cache: bool = False,
780
+ **kwargs,
781
+ ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
782
+ """Forward pass for text generation."""
783
+ logits, past_key_values = self.pico_decoder(
784
+ input_ids, past_key_values, use_cache
785
+ )
786
+ if use_cache:
787
+ return CausalLMOutputWithPast(
788
+ logits=logits,
789
+ past_key_values=past_key_values,
790
+ )
791
+ else:
792
+ return CausalLMOutput(
793
+ logits=logits,
794
+ )
795
+
796
+ def prepare_inputs_for_generation(
797
+ self,
798
+ input_ids: torch.LongTensor,
799
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
800
+ attention_mask: Optional[torch.LongTensor] = None,
801
+ **kwargs,
802
+ ) -> Dict[str, Any]:
803
+ """Prepare inputs for generation."""
804
+ # If we have past_key_values, we only need the last token
805
+ if past_key_values is not None:
806
+ input_ids = input_ids[:, -1:]
807
+
808
+ return {
809
+ "input_ids": input_ids,
810
+ "past_key_values": past_key_values,
811
+ "use_cache": True,
812
+ }
813
+
814
+ def get_input_embeddings(self):
815
+ """Get the input embeddings layer."""
816
+ return self.pico_decoder.embedding_proj
817
+
818
+ def set_input_embeddings(self, value):
819
+ """Set the input embeddings layer."""
820
+ self.pico_decoder.embedding_proj = value
821
+
822
+ def get_output_embeddings(self):
823
+ """Get the output embeddings layer."""
824
+ return self.pico_decoder.de_embedding_proj
825
+
826
+ def set_output_embeddings(self, value):
827
+ """Set the output embeddings layer."""
828
+ self.pico_decoder.de_embedding_proj = value
829
+
830
+ def get_lm_head(self):
831
+ """Get the language model head."""
832
+ return self.pico_decoder.de_embedding_proj
833
+
834
+ def can_generate(self) -> bool:
835
+ """Check if the model can generate text."""
836
+ return True
837
+
838
+ @property
839
+ def is_encoder_decoder(self) -> bool:
840
+ """Check if the model is an encoder-decoder model."""
841
+ return False
842
+
843
+ @property
844
+ def can_use_cache(self) -> bool:
845
+ """Check if the model can use KV cache."""
846
+ return True
847
+
848
+ def resize_token_embeddings(
849
+ self, new_num_tokens: Optional[int] = None
850
+ ) -> torch.nn.Embedding:
851
+ """Resize token embeddings."""
852
+ old_embeddings = self.get_input_embeddings()
853
+ if new_num_tokens is None:
854
+ new_num_tokens = old_embeddings.num_embeddings
855
+
856
+ new_embeddings = torch.nn.Embedding(
857
+ new_num_tokens, old_embeddings.embedding_dim
858
+ )
859
+ new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
860
+ old_embeddings.weight.data
861
+ )
862
+
863
+ self.pico_decoder.embedding_proj = new_embeddings
864
+ self.pico_decoder.de_embedding_proj = torch.nn.Linear(
865
+ old_embeddings.embedding_dim, new_num_tokens, bias=False
866
+ )
867
+
868
+ return new_embeddings
869
+
870
+ @classmethod
871
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
872
+ """
873
+ Load a pretrained model from a checkpoint.
874
+
875
+ This method handles loading from both the old PicoDecoderHF format and the new format.
876
+ """
877
+ # First try to load with the new class
878
+ try:
879
+ return super().from_pretrained(
880
+ pretrained_model_name_or_path, *model_args, **kwargs
881
+ )
882
+ except Exception as e:
883
+ print(f"Failed to load with new class: {e}")
884
+ print("Attempting to load with legacy class and convert...")
885
+
886
+ # Try to load with the old class and convert
887
+ try:
888
+ from transformers import AutoModel
889
+
890
+ old_model = AutoModel.from_pretrained(
891
+ pretrained_model_name_or_path,
892
+ trust_remote_code=True,
893
+ *model_args,
894
+ **kwargs,
895
+ )
896
+
897
+ # Create new model instance
898
+ new_model = cls(old_model.config)
899
+
900
+ # Copy state dict
901
+ new_model.load_state_dict(old_model.state_dict(), strict=False)
902
+
903
+ return new_model
904
+
905
+ except Exception as e2:
906
+ print(f"Failed to convert from legacy format: {e2}")
907
+ raise e
908
+
909
+
910
+ # Register the new class
911
+ PicoDecoderForCausalLM.register_for_auto_class("AutoModelForCausalLM")
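Since every checkpoint directory ships this pico_decoder.py next to config.json with the auto_map shown earlier, the intended entry point is the AutoModelForCausalLM path mentioned in the module docstring. A hedged end-to-end sketch (the checkpoint path is assumed to be a local copy; trust_remote_code=True is needed so this file gets imported):

from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt = "pico-decoder-tiny-dolma20M-v1/checkpoints/step_95000"     # any step directory
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt, trust_remote_code=True)

inputs = tokenizer("The tiny decoder", return_tensors="pt")
output_ids = model.generate(inputs.input_ids, max_new_tokens=20, do_sample=False)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))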
pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/special_tokens_map.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "eos_token": {
3
+ "content": "<|endoftext|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "pad_token": {
10
+ "content": "<|padding|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ }
16
+ }
pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/tokenizer.json ADDED
The diff for this file is too large to render.
pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/tokenizer_config.json ADDED
@@ -0,0 +1,239 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": false,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "|||IP_ADDRESS|||",
8
+ "lstrip": false,
9
+ "normalized": true,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": false
13
+ },
14
+ "1": {
15
+ "content": "<|padding|>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "50254": {
23
+ "content": " ",
24
+ "lstrip": false,
25
+ "normalized": true,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": false
29
+ },
30
+ "50255": {
31
+ "content": " ",
32
+ "lstrip": false,
33
+ "normalized": true,
34
+ "rstrip": false,
35
+ "single_word": false,
36
+ "special": false
37
+ },
38
+ "50256": {
39
+ "content": " ",
40
+ "lstrip": false,
41
+ "normalized": true,
42
+ "rstrip": false,
43
+ "single_word": false,
44
+ "special": false
45
+ },
46
+ "50257": {
47
+ "content": " ",
48
+ "lstrip": false,
49
+ "normalized": true,
50
+ "rstrip": false,
51
+ "single_word": false,
52
+ "special": false
53
+ },
54
+ "50258": {
55
+ "content": " ",
56
+ "lstrip": false,
57
+ "normalized": true,
58
+ "rstrip": false,
59
+ "single_word": false,
60
+ "special": false
61
+ },
62
+ "50259": {
63
+ "content": " ",
64
+ "lstrip": false,
65
+ "normalized": true,
66
+ "rstrip": false,
67
+ "single_word": false,
68
+ "special": false
69
+ },
70
+ "50260": {
71
+ "content": " ",
72
+ "lstrip": false,
73
+ "normalized": true,
74
+ "rstrip": false,
75
+ "single_word": false,
76
+ "special": false
77
+ },
78
+ "50261": {
79
+ "content": " ",
80
+ "lstrip": false,
81
+ "normalized": true,
82
+ "rstrip": false,
83
+ "single_word": false,
84
+ "special": false
85
+ },
86
+ "50262": {
87
+ "content": " ",
88
+ "lstrip": false,
89
+ "normalized": true,
90
+ "rstrip": false,
91
+ "single_word": false,
92
+ "special": false
93
+ },
94
+ "50263": {
95
+ "content": " ",
96
+ "lstrip": false,
97
+ "normalized": true,
98
+ "rstrip": false,
99
+ "single_word": false,
100
+ "special": false
101
+ },
102
+ "50264": {
103
+ "content": " ",
104
+ "lstrip": false,
105
+ "normalized": true,
106
+ "rstrip": false,
107
+ "single_word": false,
108
+ "special": false
109
+ },
110
+ "50265": {
111
+ "content": " ",
112
+ "lstrip": false,
113
+ "normalized": true,
114
+ "rstrip": false,
115
+ "single_word": false,
116
+ "special": false
117
+ },
118
+ "50266": {
119
+ "content": " ",
120
+ "lstrip": false,
121
+ "normalized": true,
122
+ "rstrip": false,
123
+ "single_word": false,
124
+ "special": false
125
+ },
126
+ "50267": {
127
+ "content": " ",
128
+ "lstrip": false,
129
+ "normalized": true,
130
+ "rstrip": false,
131
+ "single_word": false,
132
+ "special": false
133
+ },
134
+ "50268": {
135
+ "content": " ",
136
+ "lstrip": false,
137
+ "normalized": true,
138
+ "rstrip": false,
139
+ "single_word": false,
140
+ "special": false
141
+ },
142
+ "50269": {
143
+ "content": " ",
144
+ "lstrip": false,
145
+ "normalized": true,
146
+ "rstrip": false,
147
+ "single_word": false,
148
+ "special": false
149
+ },
150
+ "50270": {
151
+ "content": " ",
152
+ "lstrip": false,
153
+ "normalized": true,
154
+ "rstrip": false,
155
+ "single_word": false,
156
+ "special": false
157
+ },
158
+ "50271": {
159
+ "content": " ",
160
+ "lstrip": false,
161
+ "normalized": true,
162
+ "rstrip": false,
163
+ "single_word": false,
164
+ "special": false
165
+ },
166
+ "50272": {
167
+ "content": " ",
168
+ "lstrip": false,
169
+ "normalized": true,
170
+ "rstrip": false,
171
+ "single_word": false,
172
+ "special": false
173
+ },
174
+ "50273": {
175
+ "content": " ",
176
+ "lstrip": false,
177
+ "normalized": true,
178
+ "rstrip": false,
179
+ "single_word": false,
180
+ "special": false
181
+ },
182
+ "50274": {
183
+ "content": " ",
184
+ "lstrip": false,
185
+ "normalized": true,
186
+ "rstrip": false,
187
+ "single_word": false,
188
+ "special": false
189
+ },
190
+ "50275": {
191
+ "content": " ",
192
+ "lstrip": false,
193
+ "normalized": true,
194
+ "rstrip": false,
195
+ "single_word": false,
196
+ "special": false
197
+ },
198
+ "50276": {
199
+ "content": " ",
200
+ "lstrip": false,
201
+ "normalized": true,
202
+ "rstrip": false,
203
+ "single_word": false,
204
+ "special": false
205
+ },
206
+ "50277": {
207
+ "content": "|||EMAIL_ADDRESS|||",
208
+ "lstrip": false,
209
+ "normalized": true,
210
+ "rstrip": false,
211
+ "single_word": false,
212
+ "special": false
213
+ },
214
+ "50278": {
215
+ "content": "|||PHONE_NUMBER|||",
216
+ "lstrip": false,
217
+ "normalized": true,
218
+ "rstrip": false,
219
+ "single_word": false,
220
+ "special": false
221
+ },
222
+ "50279": {
223
+ "content": "<|endoftext|>",
224
+ "lstrip": false,
225
+ "normalized": false,
226
+ "rstrip": false,
227
+ "single_word": false,
228
+ "special": true
229
+ }
230
+ },
231
+ "bos_token": null,
232
+ "clean_up_tokenization_spaces": true,
233
+ "eos_token": "<|endoftext|>",
234
+ "extra_special_tokens": {},
235
+ "model_max_length": 1000000000000000019884624838656,
236
+ "pad_token": "<|padding|>",
237
+ "tokenizer_class": "GPTNeoXTokenizer",
238
+ "unk_token": null
239
+ }
pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/config.json ADDED
@@ -0,0 +1,22 @@
1
+ {
2
+ "activation_hidden_dim": 384,
3
+ "architectures": [
4
+ "PicoDecoderHF"
5
+ ],
6
+ "attention_n_heads": 12,
7
+ "attention_n_kv_heads": 4,
8
+ "auto_map": {
9
+ "AutoConfig": "pico_decoder.PicoDecoderHFConfig",
10
+ "AutoModelForCausalLM": "pico_decoder.PicoDecoderHF"
11
+ },
12
+ "batch_size": 1024,
13
+ "d_model": 96,
14
+ "max_seq_len": 2048,
15
+ "model_type": "pico_decoder",
16
+ "n_layers": 12,
17
+ "norm_eps": 1e-06,
18
+ "position_emb_theta": 10000.0,
19
+ "torch_dtype": "float32",
20
+ "transformers_version": "4.48.3",
21
+ "vocab_size": 50304
22
+ }
pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/fabric_state/checkpoint.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:92fa51a0afd806b08b0d199e7d2ff4555923904d7ef132046182de3335c38e8e
3
+ size 135543171
pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/generation_config.json ADDED
@@ -0,0 +1,4 @@
1
+ {
2
+ "transformers_version": "4.48.3",
3
+ "vocab_size": 50304
4
+ }
pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/learning_dynamics/train_activations.pt ADDED
Binary file (98.3 kB).
pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/learning_dynamics/train_data/data-00000-of-00001.arrow ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a7dc168d01315589299bb5c8857c28a085e3fef703290a2b99d551bc33a6fdf0
3
+ size 277160
pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/learning_dynamics/train_data/dataset_info.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "citation": "",
3
+ "description": "",
4
+ "features": {
5
+ "input_ids": {
6
+ "feature": {
7
+ "dtype": "int32",
8
+ "_type": "Value"
9
+ },
10
+ "_type": "Sequence"
11
+ },
12
+ "text": {
13
+ "dtype": "string",
14
+ "_type": "Value"
15
+ }
16
+ },
17
+ "homepage": "",
18
+ "license": ""
19
+ }
pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/learning_dynamics/train_data/state.json ADDED
@@ -0,0 +1,13 @@
1
+ {
2
+ "_data_files": [
3
+ {
4
+ "filename": "data-00000-of-00001.arrow"
5
+ }
6
+ ],
7
+ "_fingerprint": "1d11f8d9010f1e26",
8
+ "_format_columns": null,
9
+ "_format_kwargs": {},
10
+ "_format_type": null,
11
+ "_output_all_columns": false,
12
+ "_split": null
13
+ }
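Note: dataset_info.json and state.json above describe a saved Hugging Face `datasets` directory with an input_ids column (sequence of int32) and a text column. A minimal sketch of reading it back, assuming the directory layout shown here (path illustrative):

from datasets import Dataset

train_data = Dataset.load_from_disk(
    "pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/learning_dynamics/train_data"  # illustrative path
)
print(train_data.features)        # input_ids: Sequence(Value('int32')), text: Value('string')
print(train_data[0]["text"][:80])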
pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/learning_dynamics/train_gradients.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:067570bee041e27414f53e1579b8269c0122124602a4b61263453baed7b22cb9
3
+ size 2371527
pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/learning_dynamics/train_weights.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ca9e9ed20e2ee9b41c6999b5300990a19c499db05b6dcf0de03c17627480f2b5
3
+ size 2371443
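Note: the learning_dynamics tensors (train_activations.pt, train_gradients.pt, train_weights.pt) are PyTorch serializations. A hedged sketch of inspecting them, assuming they are ordinary torch.save artifacts — their internal layout is not documented in this diff, and the paths are illustrative:

import torch

base = "pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/learning_dynamics"  # illustrative path
grads = torch.load(f"{base}/train_gradients.pt", map_location="cpu")
weights = torch.load(f"{base}/train_weights.pt", map_location="cpu")
print(type(grads), type(weights))  # inspect the top-level structure before relying on any layout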
pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:594904d9e9d616c61f90057c2e32dfd5323b1994996434891eacb57abf9193f1
3
+ size 45143592
pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/pico_decoder.py ADDED
@@ -0,0 +1,911 @@
1
+ """
2
+ Pico Decoder: A Lightweight Causal Transformer Language Model
3
+
4
+ Pico Decoder uses a simple LLAMA-style transformer architecture, written for clarity and educational purposes.
5
+
6
+ Everything is written with a modular design for easy modification and experimentation.
7
+
8
+ Key features:
9
+ - RMSNorm for layer normalization
10
+ - Rotary Positional Embeddings (RoPE)
11
+ - Multi-head attention with KV-cache support
12
+ - SwiGLU activation function
13
+ - Residual connections throughout
14
+
15
+ - KV-cache for faster autoregressive generation
16
+
17
+ References:
18
+ - RoPE: https://arxiv.org/abs/2104.09864
19
+ - SwiGLU: https://arxiv.org/abs/2002.05202
20
+ - LLAMA: https://arxiv.org/abs/2302.13971
21
+
22
+ Adapted from:
23
+ - OLMO: https://github.com/allenai/OLMo
24
+ - LLAMA: https://github.com/meta/llama
25
+ """
26
+
27
+ from dataclasses import asdict
28
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
29
+
30
+ import torch
31
+ import torch.nn as nn
32
+ import torch.nn.functional as F
33
+
34
+ # Handle PyTorch version compatibility for attention backend
35
+ try:
36
+ from torch.nn.attention import SDPBackend, sdpa_kernel
37
+
38
+ HAS_TORCH_ATTENTION = True
39
+ except ImportError:
40
+ # Fallback for older PyTorch versions
41
+ HAS_TORCH_ATTENTION = False
42
+ SDPBackend = None
43
+ sdpa_kernel = None
44
+
45
+ from transformers import GenerationMixin, PretrainedConfig, PreTrainedModel
46
+ from transformers.generation import GenerationConfig
47
+ from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast
48
+
49
+ try:
50
+ if TYPE_CHECKING:
51
+ # We need to do this to avoid importing these when creating the HF-compatible models
52
+ from src.config import ModelConfig
53
+ except ImportError:
54
+ pass
55
+
56
+ ########################################################
57
+ #
58
+ # Layer Normalization
59
+ #
60
+ ########################################################
61
+
62
+
63
+ class RMSNorm(torch.nn.Module):
64
+ """Root Mean Square Layer Normalization.
65
+
66
+ A variant of Layer Normalization that uses RMS statistics instead of mean/variance,
67
+ resulting in improved stability and performance.
68
+
69
+ Args:
70
+ config (Union[ModelConfig, PicoHFConfig]): Configuration object containing normalization parameters
71
+ - config.norm_eps: Small constant for numerical stability
72
+ - config.d_model: Model dimension for the weight parameter
73
+
74
+ References:
75
+ https://arxiv.org/abs/1910.07467
76
+ """
77
+
78
+ def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
79
+ super().__init__()
80
+ self.eps = config.norm_eps
81
+ self.weight = nn.Parameter(torch.ones(config.d_model))
82
+
83
+ def _norm(self, x: torch.Tensor) -> torch.Tensor:
84
+ """
85
+ Normalizes the input tensor by its RMS value.
86
+ """
87
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
88
+
89
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
90
+ """
91
+ Applies RMS normalization to the input tensor and scales it by the weight parameter.
92
+ """
93
+ output = self._norm(x.float()).type_as(x)
94
+ return output * self.weight
95
+
96
+
97
+ ########################################################
98
+ #
99
+ # Positional Embedding
100
+ #
101
+ ########################################################
102
+
103
+
104
+ class RoPE(nn.Module):
105
+ """Rotary Positional Embeddings (RoPE).
106
+
107
+ Implements position-dependent rotation of keys and queries in attention mechanism,
108
+ allowing better modeling of relative positions in sequences. Uses complex number
109
+ operations for efficient rotation.
110
+
111
+ Args:
112
+ config (Union[ModelConfig, PicoHFConfig]): Model configuration containing:
113
+ - config.position_emb_theta: Base for frequency computation
114
+ - config.d_model: Model dimension
115
+ - config.attention_n_heads: Number of attention heads
116
+ - config.max_seq_len: Maximum sequence length
117
+
118
+ References:
119
+ https://arxiv.org/abs/2104.09864
120
+ """
121
+
122
+ _freqs_cis_tensor: torch.Tensor | None = None
123
+
124
+ def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
125
+ super().__init__()
126
+
127
+ self.theta = config.position_emb_theta
128
+ self.dim = config.d_model // config.attention_n_heads
129
+
130
+ max_seq_len = config.max_seq_len
131
+
132
+ # only gets set once, and then reused for all RoPE instances
133
+ if RoPE._freqs_cis_tensor is None:
134
+ RoPE._freqs_cis_tensor = self._setup_freqs_cis(
135
+ max_seq_len, self.theta, self.dim
136
+ )
137
+
138
+ # register _freqs_cis buffer
139
+ # can be easily recomputed so persistent=False
140
+ self.register_buffer("_freqs_cis", self._freqs_cis_tensor, persistent=False)
141
+
142
+ @classmethod
143
+ def _setup_freqs_cis(cls, seq_len: int, theta: float, dim: int) -> torch.Tensor:
144
+ """Setup Frequency Tensor for RoPE Embeddings
145
+
146
+ Initializes the complex frequency tensor that is used to compute the RoPE embeddings.
147
+
148
+ Note other implementations will use cos and sin directly, but using the complex
149
+ number representation is (probably) more efficient:
150
+
151
+ e^(theta * i * t) = cos(theta * t) + i * sin(theta * t) [Euler's formula]
152
+ """
153
+ _freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
154
+ positions = torch.arange(seq_len)
155
+ freqs = torch.outer(positions, _freqs)
156
+ return torch.polar(torch.ones_like(freqs), freqs) # complex64
157
+
158
+ def get_freqs_cis(
159
+ self, input_shape: torch.Size, start_pos: int, end_pos: int
160
+ ) -> torch.Tensor:
161
+ """Reshape Frequency Tensor for RoPE Embeddings
162
+
163
+ Makes the frequency tensor broadcastable with the input tensor.
164
+ """
165
+ _freqs_cis = self._freqs_cis[start_pos:end_pos]
166
+ ndim = len(input_shape)
167
+ assert 0 <= 1 < ndim
168
+ assert _freqs_cis.shape == (input_shape[1], input_shape[-1])
169
+
170
+ # TODO: Check whether this is correct (might be able to remove this)
171
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(input_shape)]
172
+ return _freqs_cis.view(*shape)
173
+
174
+ def forward(
175
+ self,
176
+ queries: torch.Tensor,
177
+ keys: torch.Tensor,
178
+ start_pos: int = 0,
179
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
180
+ """Apply RoPE Embeddings to Queries and Keys
181
+
182
+ Applies the rotary positional embeddings to the input tensors via complex number multiplication.
183
+
184
+ NOTE: The start_pos is used if we want to use the kv_cache in the attention mechanism.
185
+ """
186
+ queries_ = torch.view_as_complex(
187
+ queries.float().reshape(*queries.shape[:-1], -1, 2)
188
+ )
189
+ keys_ = torch.view_as_complex(keys.float().reshape(*keys.shape[:-1], -1, 2))
190
+
191
+ input_shape = (
192
+ queries_.shape
193
+ ) # same as keys: (batch_size, seq_len, n_heads, head_dim/2)
194
+ freqs_start_pos = start_pos
195
+ freqs_end_pos = freqs_start_pos + queries_.shape[1]
196
+
197
+ freqs_cis = self.get_freqs_cis(input_shape, freqs_start_pos, freqs_end_pos)
198
+
199
+ queries_rotated = torch.view_as_real(queries_ * freqs_cis).flatten(3)
200
+ keys_rotated = torch.view_as_real(keys_ * freqs_cis).flatten(3)
201
+ return queries_rotated.type_as(queries), keys_rotated.type_as(keys)
202
+
203
+
204
+ ########################################################
205
+ #
206
+ # Attention
207
+ #
208
+ ########################################################
209
+
210
+
211
+ class Attention(nn.Module):
212
+ """Multi-head Attention with Group Query Attention support.
213
+
214
+ Implements scaled dot-product attention and supports:
215
+ - Grouped Query Attention (GQA)
216
+ - Key-Value caching for efficient inference
217
+ - RoPE integration
218
+
219
+ Args:
220
+ config (Union[ModelConfig, PretrainedConfig]): Configuration containing:
221
+ - config.attention_n_heads: Number of attention heads
222
+ - config.attention_n_kv_heads: Number of key/value heads
223
+ - config.d_model: Model dimension
224
+ - config.batch_size: Maximum batch size
225
+ - config.max_seq_len: Maximum sequence length
226
+
227
+ Shape:
228
+ - Input: (batch_size, seq_len, d_model)
229
+ - Output: (batch_size, seq_len, d_model)
230
+ """
231
+
232
+ def __init__(
233
+ self,
234
+ config: Union["ModelConfig", "PicoDecoderHFConfig"],
235
+ ):
236
+ super().__init__()
237
+
238
+ self.n_heads = config.attention_n_heads
239
+ self.n_kv_heads = config.attention_n_kv_heads
240
+
241
+ self.batch_size = config.batch_size
242
+ self.max_seq_len = config.max_seq_len
243
+
244
+ d_model = config.d_model
245
+ self.head_dim = d_model // self.n_heads
246
+
247
+ self.n_rep = self.n_heads // self.n_kv_heads
248
+
249
+ self.q_proj = nn.Linear(d_model, self.n_heads * self.head_dim, bias=False)
250
+ self.k_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
251
+ self.v_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
252
+ self.o_proj = nn.Linear(self.n_heads * self.head_dim, d_model, bias=False)
253
+
254
+ self.rope = RoPE(config)
255
+
256
+ def forward(
257
+ self,
258
+ input: torch.Tensor,
259
+ mask: Optional[torch.Tensor] = None,
260
+ past_key_values: Optional[Tuple[torch.Tensor, ...]] = None,
261
+ use_cache: bool = False,
262
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
263
+ """Forward pass for the attention mechanism.
264
+
265
+ Computes queries, keys, and values for the attention mechanism. Applies rotary positional
266
+ embeddings to the queries and keys, and then computes attention scores and outputs.
267
+
268
+ For an introduction to the attention mechanism, see:
269
+ https://arxiv.org/abs/1706.03762
270
+
271
+ A few things to note:
272
+ - The past_key_values is used to implement the KV cache, which is used to speed up
273
+ generation by caching the KV pairs from previous forward passes. This is useful when doing
274
+ tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
275
+ modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
276
+ its own KV cache - this KV cache is implemented as a tuple.
277
+ """
278
+ bsz, seq_len, _ = input.shape
279
+ _queries, _keys, _values = (
280
+ self.q_proj(input),
281
+ self.k_proj(input),
282
+ self.v_proj(input),
283
+ )
284
+
285
+ # Reshaping for multi-head attention
286
+ queries = _queries.view(bsz, seq_len, self.n_heads, self.head_dim)
287
+ keys = _keys.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
288
+ values = _values.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
289
+
290
+ # The start position is used to apply the RoPE embeddings to only the new tokens
291
+ # when using the kv_cache in the attention mechanism.
292
+ # We want to start from the last position in the cache.
293
+ start_pos = 0
294
+ if past_key_values is not None and past_key_values[0] is not None:
295
+ start_pos = past_key_values[0].shape[1]
296
+
297
+ # apply rotary positional embeddings
298
+ queries, keys = self.rope(queries, keys, start_pos)
299
+
300
+ if (
301
+ past_key_values is not None
302
+ and past_key_values[0] is not None
303
+ and past_key_values[1] is not None
304
+ ):
305
+ keys = torch.cat([past_key_values[0], keys], dim=1)
306
+ values = torch.cat([past_key_values[1], values], dim=1)
307
+
308
+ if use_cache:
309
+ cached_keys = keys
310
+ cached_values = values
311
+ else:
312
+ cached_keys = None
313
+ cached_values = None
314
+
315
+ queries = queries.transpose(1, 2)
316
+ keys = keys.transpose(1, 2)
317
+ values = values.transpose(1, 2)
318
+
319
+ apply_gqa = self.n_rep > 1
320
+ if apply_gqa and queries.device.type == "mps":
321
+ # NOTE: MPS does not support GQA in the SDPA kernel, but we can repeat the keys and values
322
+ # outside of the kernel to get the same effect.
323
+ # See: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
324
+ keys = keys.repeat_interleave(self.n_rep, dim=-3)
325
+ values = values.repeat_interleave(self.n_rep, dim=-3)
326
+ apply_gqa = False
327
+
328
+ if HAS_TORCH_ATTENTION:
329
+ backends = [SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]
330
+ with sdpa_kernel(backends=backends):
331
+ attn_output = F.scaled_dot_product_attention(
332
+ queries.contiguous(),
333
+ keys.contiguous(),
334
+ values.contiguous(),
335
+ attn_mask=mask.to(queries.dtype) if mask is not None else None,
336
+ enable_gqa=apply_gqa,
337
+ )
338
+ else:
339
+ # Fallback for older PyTorch versions - use default backend
340
+ attn_output = F.scaled_dot_product_attention(
341
+ queries.contiguous(),
342
+ keys.contiguous(),
343
+ values.contiguous(),
344
+ attn_mask=mask.to(queries.dtype) if mask is not None else None,
345
+ enable_gqa=apply_gqa,
346
+ )
347
+
348
+ attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
349
+ output = self.o_proj(attn_output)
350
+
351
+ return output, (cached_keys, cached_values)
352
+
353
+
354
+ ########################################################
355
+ #
356
+ # SwiGLU (Combines MLP and Activation)
357
+ #
358
+ ########################################################
359
+
360
+
361
+ class SwiGLU(nn.Module):
362
+ """SwiGLU Activation Function with Linear Projections.
363
+
364
+ Implements the SwiGLU activation function combined with linear transformations,
365
+ serving as the feed-forward network in transformer blocks.
366
+
367
+ Args:
368
+ config (Union[ModelConfig, PicoDecoderHFConfig]): Configuration containing:
369
+ - config.d_model: Model dimension
370
+ - config.activation_hidden_dim: Hidden dimension (typically 4 * d_model)
371
+
372
+ References:
373
+ https://arxiv.org/abs/2002.05202
374
+ """
375
+
376
+ def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
377
+ super().__init__()
378
+
379
+ model_dim = config.d_model
380
+ act_hidden_dim = config.activation_hidden_dim # usually 4 * d_model
381
+
382
+ self.w_0 = nn.Linear(model_dim, act_hidden_dim, bias=False)
383
+ self.w_1 = nn.Linear(model_dim, act_hidden_dim, bias=False)
384
+ self.w_2 = nn.Linear(act_hidden_dim, model_dim, bias=False)
385
+
386
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
387
+ return self.w_2(F.silu(self.w_0(x)) * self.w_1(x))
388
+
389
+
390
+ ########################################################
391
+ #
392
+ # PicoDecoderBlock
393
+ #
394
+ ########################################################
395
+
396
+
397
+ class PicoDecoderBlock(nn.Module):
398
+ """Single Transformer Block with Attention and Feed-forward layers.
399
+
400
+ Implements a standard transformer block with:
401
+ - Multi-head attention with normalization and residual connection
402
+ - SwiGLU feed-forward network with normalization and residual connection
403
+
404
+ Args:
405
+ config (Union[ModelConfig, PicoDecoderHFConfig]): Model configuration; either a dataclass or
406
+ a HuggingFace PicoDecoderHFConfig
407
+ """
408
+
409
+ def __init__(
410
+ self,
411
+ config: Union["ModelConfig", "PicoDecoderHFConfig"],
412
+ ):
413
+ super().__init__()
414
+
415
+ self.attention = Attention(config)
416
+ self.swiglu = SwiGLU(config)
417
+ self.attention_norm = RMSNorm(config)
418
+ self.swiglu_norm = RMSNorm(config)
419
+
420
+ def forward(
421
+ self,
422
+ input: torch.Tensor,
423
+ mask: Optional[torch.Tensor] = None,
424
+ past_key_values: Optional[Tuple[torch.Tensor]] = None,
425
+ use_cache: bool = False,
426
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
427
+ attention_output, cached_key_values = self.attention(
428
+ self.attention_norm(input),
429
+ mask=mask,
430
+ past_key_values=past_key_values,
431
+ use_cache=use_cache,
432
+ )
433
+ # NOTE: cached_key_values is None if use_cache is False
434
+
435
+ h = input + attention_output
436
+ out = h + self.swiglu(self.swiglu_norm(h))
437
+ return out, cached_key_values
438
+
439
+
440
+ ########################################################
441
+ #
442
+ # Pico Decoder (Causal Transformer Model)
443
+ #
444
+ ########################################################
445
+
446
+
447
+ class PicoDecoder(nn.Module):
448
+ """
449
+ Pico Decoder: combines the embedding, causal decoder blocks, and output projection into a
450
+ single autoregressive model.
451
+
452
+ For more information on the model, see the classes for the modules that make up the model.
453
+ """
454
+
455
+ def __init__(
456
+ self,
457
+ model_config: Union["ModelConfig", "PicoDecoderHFConfig"],
458
+ ):
459
+ super().__init__()
460
+ self.config = model_config
461
+
462
+ self.embedding_proj = nn.Embedding(self.config.vocab_size, self.config.d_model)
463
+ self.layers = nn.ModuleList(
464
+ [PicoDecoderBlock(self.config) for _ in range(self.config.n_layers)]
465
+ )
466
+ self.output_norm = RMSNorm(self.config)
467
+ self.de_embedding_proj = nn.Linear(
468
+ self.config.d_model, self.config.vocab_size, bias=False
469
+ )
470
+
471
+ def convert_to_hf_model(self) -> "PicoDecoderHF":
472
+ """Convert the Lightning model to a HuggingFace model."""
473
+ # Create HF config without fabric-specific settings
474
+ hf_config = PicoDecoderHFConfig.from_dataclass(self.config)
475
+
476
+ # Create new HF model
477
+ hf_model = PicoDecoderHF(hf_config)
478
+
479
+ # Copy state dict, excluding fabric-specific keys
480
+ hf_model.load_state_dict(self.state_dict(prefix="pico_decoder."))
481
+
482
+ return hf_model
483
+
484
+ def forward(
485
+ self,
486
+ input_ids: torch.Tensor,
487
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
488
+ use_cache: bool = False,
489
+ ) -> Tuple[torch.Tensor, Optional[Tuple[Tuple[torch.Tensor, torch.Tensor]]]]:
490
+ """
491
+ This is the forward pass for the entire Pico model. It boils down to:
492
+ - Embedding the input ids
493
+ - Creating a causal mask
494
+ - Processing through the pico layers
495
+ - Projecting the output to logits
496
+
497
+ NOTE: One feature that might be confusing is the KV cache. The KV cache is used to speed up
498
+ generation by caching the KV pairs from previous forward passes. This is useful when doing
499
+ tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
500
+ modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
501
+ its own KV cache which is stored as a tuple. The whole model then stores a tuple of these
502
+ KV caches (so a tuple of tuples).
503
+ """
504
+
505
+ seq_len = input_ids.shape[-1]
506
+ h = self.embedding_proj(input_ids)
507
+
508
+ # Calculate start position from past cached KV pairs. Remember that each layer has its
509
+ # own KV Cache. So when we index past_key_values, we need to index into the KV pairs for the
510
+ # correct layer and then for either the keys or values.
511
+ start_pos = 0
512
+ if (
513
+ past_key_values is not None
514
+ and past_key_values[0] is not None
515
+ and past_key_values[0][0] is not None
516
+ ):
517
+ start_pos = past_key_values[0][0].shape[1]
518
+
519
+ # Create causal mask for current sequence
520
+ mask = None
521
+ if seq_len > 1:
522
+ mask = torch.full((seq_len, seq_len), float("-inf"))
523
+ mask = torch.triu(mask, diagonal=1)
524
+
525
+ # If using KV cache, extend mask to cover cached sequence length
526
+ if past_key_values is not None:
527
+ # Add zeros for cached tokens (we can attend to all of them)
528
+ mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask])
529
+
530
+ mask = mask.to(h.device)
531
+
532
+ # NOTE: If we are using the cache, we need to store the cached KV pairs for each layer
533
+ # in a tuple. Each layer will have its own cached KV pair which we aggregate in a tuple.
534
+ cached_key_values = () if use_cache else None
535
+
536
+ # Process through transformer blocks
537
+ for idx, layer in enumerate(self.layers):
538
+ layer_past_key_values = None
539
+ if past_key_values is not None:
540
+ try:
541
+ # Handle both tuple-based cache and HuggingFace cache objects
542
+ if hasattr(past_key_values, "__getitem__") and idx < len(
543
+ past_key_values
544
+ ):
545
+ layer_past_key_values = past_key_values[idx]
546
+ except (KeyError, IndexError, TypeError):
547
+ # If we can't access the cache properly, just skip it
548
+ layer_past_key_values = None
549
+
550
+ h, layer_cached_key_values = layer(
551
+ h, mask=mask, past_key_values=layer_past_key_values, use_cache=use_cache
552
+ )
553
+
554
+ if use_cache:
555
+ cached_key_values += (layer_cached_key_values,)
556
+
557
+ # Final norm and projection
558
+ h = self.output_norm(h)
559
+ logits = self.de_embedding_proj(h).float()
560
+
561
+ return logits, cached_key_values
562
+
563
+
564
+ ########################################################
565
+ #
566
+ # HuggingFace Wrapper for the Pico Decoder model.
567
+ #
568
+ ########################################################
569
+
570
+
571
+ class PicoDecoderHFConfig(PretrainedConfig):
572
+ """Config class for the Pico Decoder HuggingFace wrapper."""
573
+
574
+ model_type = "pico_decoder"
575
+
576
+ @classmethod
577
+ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PicoDecoderHFConfig":
578
+ """
579
+ Initialize config from a dictionary. Note that no kwargs are passed to the constructor --
580
+ this is because with some kwargs special handling is required and can make this class
581
+ brittle.
582
+ """
583
+ pico_config = cls(**config_dict)
584
+
585
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
586
+ unused_kwargs = {
587
+ key: value for key, value in kwargs.items() if not hasattr(pico_config, key)
588
+ }
589
+
590
+ if return_unused_kwargs:
591
+ return pico_config, unused_kwargs
592
+ return pico_config
593
+
594
+ @classmethod
595
+ def from_dataclass(cls, model_config: "ModelConfig"):
596
+ """Initialise from our custom config dataclass."""
597
+ return cls.from_dict(asdict(model_config))
598
+
599
+
600
+ class PicoDecoderHF(PreTrainedModel, GenerationMixin):
601
+ """
602
+ HuggingFace wrapper for the Pico model with generation support.
603
+
604
+ Many evaluation frameworks require a model be setup as a HuggingFace model, so we provide a simple
605
+ wrapper that does just that. When we save checkpoints of the Pico model, we save both the normal
606
+ Pico model as well as the model wrapped in this HuggingFace class.
607
+
608
+ This also lets you do cool things like:
609
+
610
+ `model = AutoModelForCausalLM.from_pretrained("path/to/checkpoint")`
611
+ """
612
+
613
+ config_class = PicoDecoderHFConfig
614
+ _no_split_modules = ["PicoDecoderBlock", "Attention", "SwiGLU", "RMSNorm"]
615
+ main_input_name = "input_ids"
616
+
617
+ def __init__(self, config: PicoDecoderHFConfig):
618
+ super().__init__(config)
619
+ self.pico_decoder = PicoDecoder(config)
620
+ # Initialize generation config with defaults
621
+ self.generation_config = GenerationConfig()
622
+ # Set some reasonable defaults for the model
623
+ if hasattr(config, "max_position_embeddings"):
624
+ self.generation_config.max_length = config.max_position_embeddings
625
+ if hasattr(config, "vocab_size"):
626
+ self.generation_config.vocab_size = config.vocab_size
627
+
628
+ def forward(
629
+ self,
630
+ input_ids: torch.Tensor,
631
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
632
+ use_cache: bool = False,
633
+ **kwargs,
634
+ ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
635
+ """HuggingFace forward pass wrapper.
636
+
637
+ Forward pass for the HuggingFace version of the Pico Model. A thin wrapper around the
638
+ Pico model's forward pass that returns the output as a HuggingFace CausalLMOutput.
639
+ """
640
+ logits, past_key_values = self.pico_decoder(
641
+ input_ids, past_key_values, use_cache
642
+ )
643
+ if use_cache:
644
+ return CausalLMOutputWithPast(
645
+ logits=logits,
646
+ past_key_values=past_key_values,
647
+ )
648
+ else:
649
+ return CausalLMOutput(
650
+ logits=logits,
651
+ )
652
+
653
+ def prepare_inputs_for_generation(
654
+ self,
655
+ input_ids: torch.LongTensor,
656
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
657
+ attention_mask: Optional[torch.LongTensor] = None,
658
+ **kwargs,
659
+ ) -> Dict[str, Any]:
660
+ """
661
+ Prepare inputs for generation.
662
+
663
+ Args:
664
+ input_ids: Input token IDs
665
+ past_key_values: Cached key-value pairs from previous forward passes
666
+ attention_mask: Attention mask for the input
667
+ **kwargs: Additional arguments
668
+
669
+ Returns:
670
+ Dictionary containing prepared inputs
671
+ """
672
+ # If we have past_key_values, we only need the last token
673
+ if past_key_values is not None:
674
+ input_ids = input_ids[:, -1:]
675
+
676
+ return {
677
+ "input_ids": input_ids,
678
+ "past_key_values": past_key_values,
679
+ "use_cache": True,
680
+ }
681
+
682
+ def get_input_embeddings(self):
683
+ """Get the input embeddings layer."""
684
+ return self.pico_decoder.embedding_proj
685
+
686
+ def set_input_embeddings(self, value):
687
+ """Set the input embeddings layer."""
688
+ self.pico_decoder.embedding_proj = value
689
+
690
+ def get_output_embeddings(self):
691
+ """Get the output embeddings layer."""
692
+ return self.pico_decoder.de_embedding_proj
693
+
694
+ def set_output_embeddings(self, value):
695
+ """Set the output embeddings layer."""
696
+ self.pico_decoder.de_embedding_proj = value
697
+
698
+ def get_lm_head(self):
699
+ """Get the language model head."""
700
+ return self.pico_decoder.de_embedding_proj
701
+
702
+ def can_generate(self) -> bool:
703
+ """Check if the model can generate text."""
704
+ return True
705
+
706
+ @property
707
+ def is_encoder_decoder(self) -> bool:
708
+ """Check if the model is an encoder-decoder model."""
709
+ return False
710
+
711
+ @property
712
+ def can_use_cache(self) -> bool:
713
+ """Check if the model can use KV cache."""
714
+ return True
715
+
716
+ def resize_token_embeddings(
717
+ self, new_num_tokens: Optional[int] = None
718
+ ) -> torch.nn.Embedding:
719
+ """Resize token embeddings."""
720
+ old_embeddings = self.get_input_embeddings()
721
+ if new_num_tokens is None:
722
+ new_num_tokens = old_embeddings.num_embeddings
723
+
724
+ new_embeddings = torch.nn.Embedding(
725
+ new_num_tokens, old_embeddings.embedding_dim
726
+ )
727
+ new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
728
+ old_embeddings.weight.data
729
+ )
730
+
731
+ self.pico_decoder.embedding_proj = new_embeddings
732
+ self.pico_decoder.de_embedding_proj = torch.nn.Linear(
733
+ old_embeddings.embedding_dim, new_num_tokens, bias=False
734
+ )
735
+
736
+ return new_embeddings
737
+
738
+
739
+ # Register for auto classes
740
+ PicoDecoderHFConfig.register_for_auto_class()
741
+ PicoDecoderHF.register_for_auto_class("AutoModel")
742
+ PicoDecoderHF.register_for_auto_class("AutoModelForCausalLM")
743
+
744
+
745
+ ########################################################
746
+ #
747
+ # New PicoDecoderForCausalLM class for generation support
748
+ #
749
+ ########################################################
750
+
751
+
752
+ class PicoDecoderForCausalLM(PreTrainedModel, GenerationMixin):
753
+ """
754
+ PicoDecoderForCausalLM: A HuggingFace-compatible model that properly supports generation.
755
+
756
+ This class is designed to work with existing checkpoints and provides full generation support.
757
+ It inherits from the right base classes that HuggingFace expects for text generation.
758
+ """
759
+
760
+ config_class = PicoDecoderHFConfig
761
+ _no_split_modules = ["PicoDecoderBlock", "Attention", "SwiGLU", "RMSNorm"]
762
+ main_input_name = "input_ids"
763
+
764
+ def __init__(self, config: PicoDecoderHFConfig):
765
+ super().__init__(config)
766
+ self.pico_decoder = PicoDecoder(config)
767
+ # Initialize generation config with defaults
768
+ self.generation_config = GenerationConfig()
769
+ # Set some reasonable defaults for the model
770
+ if hasattr(config, "max_position_embeddings"):
771
+ self.generation_config.max_length = config.max_position_embeddings
772
+ if hasattr(config, "vocab_size"):
773
+ self.generation_config.vocab_size = config.vocab_size
774
+
775
+ def forward(
776
+ self,
777
+ input_ids: torch.Tensor,
778
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
779
+ use_cache: bool = False,
780
+ **kwargs,
781
+ ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
782
+ """Forward pass for text generation."""
783
+ logits, past_key_values = self.pico_decoder(
784
+ input_ids, past_key_values, use_cache
785
+ )
786
+ if use_cache:
787
+ return CausalLMOutputWithPast(
788
+ logits=logits,
789
+ past_key_values=past_key_values,
790
+ )
791
+ else:
792
+ return CausalLMOutput(
793
+ logits=logits,
794
+ )
795
+
796
+ def prepare_inputs_for_generation(
797
+ self,
798
+ input_ids: torch.LongTensor,
799
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
800
+ attention_mask: Optional[torch.LongTensor] = None,
801
+ **kwargs,
802
+ ) -> Dict[str, Any]:
803
+ """Prepare inputs for generation."""
804
+ # If we have past_key_values, we only need the last token
805
+ if past_key_values is not None:
806
+ input_ids = input_ids[:, -1:]
807
+
808
+ return {
809
+ "input_ids": input_ids,
810
+ "past_key_values": past_key_values,
811
+ "use_cache": True,
812
+ }
813
+
814
+ def get_input_embeddings(self):
815
+ """Get the input embeddings layer."""
816
+ return self.pico_decoder.embedding_proj
817
+
818
+ def set_input_embeddings(self, value):
819
+ """Set the input embeddings layer."""
820
+ self.pico_decoder.embedding_proj = value
821
+
822
+ def get_output_embeddings(self):
823
+ """Get the output embeddings layer."""
824
+ return self.pico_decoder.de_embedding_proj
825
+
826
+ def set_output_embeddings(self, value):
827
+ """Set the output embeddings layer."""
828
+ self.pico_decoder.de_embedding_proj = value
829
+
830
+ def get_lm_head(self):
831
+ """Get the language model head."""
832
+ return self.pico_decoder.de_embedding_proj
833
+
834
+ def can_generate(self) -> bool:
835
+ """Check if the model can generate text."""
836
+ return True
837
+
838
+ @property
839
+ def is_encoder_decoder(self) -> bool:
840
+ """Check if the model is an encoder-decoder model."""
841
+ return False
842
+
843
+ @property
844
+ def can_use_cache(self) -> bool:
845
+ """Check if the model can use KV cache."""
846
+ return True
847
+
848
+ def resize_token_embeddings(
849
+ self, new_num_tokens: Optional[int] = None
850
+ ) -> torch.nn.Embedding:
851
+ """Resize token embeddings."""
852
+ old_embeddings = self.get_input_embeddings()
853
+ if new_num_tokens is None:
854
+ new_num_tokens = old_embeddings.num_embeddings
855
+
856
+ new_embeddings = torch.nn.Embedding(
857
+ new_num_tokens, old_embeddings.embedding_dim
858
+ )
859
+ new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
860
+ old_embeddings.weight.data
861
+ )
862
+
863
+ self.pico_decoder.embedding_proj = new_embeddings
864
+ self.pico_decoder.de_embedding_proj = torch.nn.Linear(
865
+ old_embeddings.embedding_dim, new_num_tokens, bias=False
866
+ )
867
+
868
+ return new_embeddings
869
+
870
+ @classmethod
871
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
872
+ """
873
+ Load a pretrained model from a checkpoint.
874
+
875
+ This method handles loading from both the old PicoDecoderHF format and the new format.
876
+ """
877
+ # First try to load with the new class
878
+ try:
879
+ return super().from_pretrained(
880
+ pretrained_model_name_or_path, *model_args, **kwargs
881
+ )
882
+ except Exception as e:
883
+ print(f"Failed to load with new class: {e}")
884
+ print("Attempting to load with legacy class and convert...")
885
+
886
+ # Try to load with the old class and convert
887
+ try:
888
+ from transformers import AutoModel
889
+
890
+ old_model = AutoModel.from_pretrained(
891
+ pretrained_model_name_or_path,
892
+ trust_remote_code=True,
893
+ *model_args,
894
+ **kwargs,
895
+ )
896
+
897
+ # Create new model instance
898
+ new_model = cls(old_model.config)
899
+
900
+ # Copy state dict
901
+ new_model.load_state_dict(old_model.state_dict(), strict=False)
902
+
903
+ return new_model
904
+
905
+ except Exception as e2:
906
+ print(f"Failed to convert from legacy format: {e2}")
907
+ raise e
908
+
909
+
910
+ # Register the new class
911
+ PicoDecoderForCausalLM.register_for_auto_class("AutoModelForCausalLM")
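Note: pico_decoder.py above registers PicoDecoderHF and PicoDecoderForCausalLM for the transformers auto classes and implements a per-layer (keys, values) KV cache, so the checkpoint should support standard greedy decoding through generate(). A minimal sketch, assuming a local copy of the checkpoint (path and prompt illustrative):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt_dir = "pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000"  # illustrative local path
tok = AutoTokenizer.from_pretrained(ckpt_dir)
model = AutoModelForCausalLM.from_pretrained(ckpt_dir, trust_remote_code=True).eval()
inputs = tok("The tiny decoder", return_tensors="pt")
with torch.no_grad():
    # use_cache=True exercises the tuple-of-tuples KV cache described in the module docstrings
    out = model.generate(**inputs, max_new_tokens=16, do_sample=False, use_cache=True)
print(tok.decode(out[0], skip_special_tokens=True))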
pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/special_tokens_map.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "eos_token": {
3
+ "content": "<|endoftext|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "pad_token": {
10
+ "content": "<|padding|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ }
16
+ }
pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/tokenizer_config.json ADDED
@@ -0,0 +1,239 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": false,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "|||IP_ADDRESS|||",
8
+ "lstrip": false,
9
+ "normalized": true,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": false
13
+ },
14
+ "1": {
15
+ "content": "<|padding|>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "50254": {
23
+ "content": " ",
24
+ "lstrip": false,
25
+ "normalized": true,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": false
29
+ },
30
+ "50255": {
31
+ "content": " ",
32
+ "lstrip": false,
33
+ "normalized": true,
34
+ "rstrip": false,
35
+ "single_word": false,
36
+ "special": false
37
+ },
38
+ "50256": {
39
+ "content": " ",
40
+ "lstrip": false,
41
+ "normalized": true,
42
+ "rstrip": false,
43
+ "single_word": false,
44
+ "special": false
45
+ },
46
+ "50257": {
47
+ "content": " ",
48
+ "lstrip": false,
49
+ "normalized": true,
50
+ "rstrip": false,
51
+ "single_word": false,
52
+ "special": false
53
+ },
54
+ "50258": {
55
+ "content": " ",
56
+ "lstrip": false,
57
+ "normalized": true,
58
+ "rstrip": false,
59
+ "single_word": false,
60
+ "special": false
61
+ },
62
+ "50259": {
63
+ "content": " ",
64
+ "lstrip": false,
65
+ "normalized": true,
66
+ "rstrip": false,
67
+ "single_word": false,
68
+ "special": false
69
+ },
70
+ "50260": {
71
+ "content": " ",
72
+ "lstrip": false,
73
+ "normalized": true,
74
+ "rstrip": false,
75
+ "single_word": false,
76
+ "special": false
77
+ },
78
+ "50261": {
79
+ "content": " ",
80
+ "lstrip": false,
81
+ "normalized": true,
82
+ "rstrip": false,
83
+ "single_word": false,
84
+ "special": false
85
+ },
86
+ "50262": {
87
+ "content": " ",
88
+ "lstrip": false,
89
+ "normalized": true,
90
+ "rstrip": false,
91
+ "single_word": false,
92
+ "special": false
93
+ },
94
+ "50263": {
95
+ "content": " ",
96
+ "lstrip": false,
97
+ "normalized": true,
98
+ "rstrip": false,
99
+ "single_word": false,
100
+ "special": false
101
+ },
102
+ "50264": {
103
+ "content": " ",
104
+ "lstrip": false,
105
+ "normalized": true,
106
+ "rstrip": false,
107
+ "single_word": false,
108
+ "special": false
109
+ },
110
+ "50265": {
111
+ "content": " ",
112
+ "lstrip": false,
113
+ "normalized": true,
114
+ "rstrip": false,
115
+ "single_word": false,
116
+ "special": false
117
+ },
118
+ "50266": {
119
+ "content": " ",
120
+ "lstrip": false,
121
+ "normalized": true,
122
+ "rstrip": false,
123
+ "single_word": false,
124
+ "special": false
125
+ },
126
+ "50267": {
127
+ "content": " ",
128
+ "lstrip": false,
129
+ "normalized": true,
130
+ "rstrip": false,
131
+ "single_word": false,
132
+ "special": false
133
+ },
134
+ "50268": {
135
+ "content": " ",
136
+ "lstrip": false,
137
+ "normalized": true,
138
+ "rstrip": false,
139
+ "single_word": false,
140
+ "special": false
141
+ },
142
+ "50269": {
143
+ "content": " ",
144
+ "lstrip": false,
145
+ "normalized": true,
146
+ "rstrip": false,
147
+ "single_word": false,
148
+ "special": false
149
+ },
150
+ "50270": {
151
+ "content": " ",
152
+ "lstrip": false,
153
+ "normalized": true,
154
+ "rstrip": false,
155
+ "single_word": false,
156
+ "special": false
157
+ },
158
+ "50271": {
159
+ "content": " ",
160
+ "lstrip": false,
161
+ "normalized": true,
162
+ "rstrip": false,
163
+ "single_word": false,
164
+ "special": false
165
+ },
166
+ "50272": {
167
+ "content": " ",
168
+ "lstrip": false,
169
+ "normalized": true,
170
+ "rstrip": false,
171
+ "single_word": false,
172
+ "special": false
173
+ },
174
+ "50273": {
175
+ "content": " ",
176
+ "lstrip": false,
177
+ "normalized": true,
178
+ "rstrip": false,
179
+ "single_word": false,
180
+ "special": false
181
+ },
182
+ "50274": {
183
+ "content": " ",
184
+ "lstrip": false,
185
+ "normalized": true,
186
+ "rstrip": false,
187
+ "single_word": false,
188
+ "special": false
189
+ },
190
+ "50275": {
191
+ "content": " ",
192
+ "lstrip": false,
193
+ "normalized": true,
194
+ "rstrip": false,
195
+ "single_word": false,
196
+ "special": false
197
+ },
198
+ "50276": {
199
+ "content": " ",
200
+ "lstrip": false,
201
+ "normalized": true,
202
+ "rstrip": false,
203
+ "single_word": false,
204
+ "special": false
205
+ },
206
+ "50277": {
207
+ "content": "|||EMAIL_ADDRESS|||",
208
+ "lstrip": false,
209
+ "normalized": true,
210
+ "rstrip": false,
211
+ "single_word": false,
212
+ "special": false
213
+ },
214
+ "50278": {
215
+ "content": "|||PHONE_NUMBER|||",
216
+ "lstrip": false,
217
+ "normalized": true,
218
+ "rstrip": false,
219
+ "single_word": false,
220
+ "special": false
221
+ },
222
+ "50279": {
223
+ "content": "<|endoftext|>",
224
+ "lstrip": false,
225
+ "normalized": false,
226
+ "rstrip": false,
227
+ "single_word": false,
228
+ "special": true
229
+ }
230
+ },
231
+ "bos_token": null,
232
+ "clean_up_tokenization_spaces": true,
233
+ "eos_token": "<|endoftext|>",
234
+ "extra_special_tokens": {},
235
+ "model_max_length": 1000000000000000019884624838656,
236
+ "pad_token": "<|padding|>",
237
+ "tokenizer_class": "GPTNeoXTokenizer",
238
+ "unk_token": null
239
+ }
pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/config.json ADDED
@@ -0,0 +1,22 @@
1
+ {
2
+ "activation_hidden_dim": 384,
3
+ "architectures": [
4
+ "PicoDecoderHF"
5
+ ],
6
+ "attention_n_heads": 12,
7
+ "attention_n_kv_heads": 4,
8
+ "auto_map": {
9
+ "AutoConfig": "pico_decoder.PicoDecoderHFConfig",
10
+ "AutoModelForCausalLM": "pico_decoder.PicoDecoderHF"
11
+ },
12
+ "batch_size": 1024,
13
+ "d_model": 96,
14
+ "max_seq_len": 2048,
15
+ "model_type": "pico_decoder",
16
+ "n_layers": 12,
17
+ "norm_eps": 1e-06,
18
+ "position_emb_theta": 10000.0,
19
+ "torch_dtype": "float32",
20
+ "transformers_version": "4.48.3",
21
+ "vocab_size": 50304
22
+ }
pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/fabric_state/checkpoint.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f63cfb2d211ae93a79dd41503954f66c62499535e964b2024491642521ef8c55
3
+ size 135543171
pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/generation_config.json ADDED
@@ -0,0 +1,4 @@
1
+ {
2
+ "transformers_version": "4.48.3",
3
+ "vocab_size": 50304
4
+ }
pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/learning_dynamics/train_activations.pt ADDED
Binary file (98.3 kB). View file
 
pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/learning_dynamics/train_data/data-00000-of-00001.arrow ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e9565d0b7f27a7a6e121ff10214d8530cf1999016e90a567eac83d2a26bdeb3e
3
+ size 274672
pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/learning_dynamics/train_data/dataset_info.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "citation": "",
3
+ "description": "",
4
+ "features": {
5
+ "input_ids": {
6
+ "feature": {
7
+ "dtype": "int32",
8
+ "_type": "Value"
9
+ },
10
+ "_type": "Sequence"
11
+ },
12
+ "text": {
13
+ "dtype": "string",
14
+ "_type": "Value"
15
+ }
16
+ },
17
+ "homepage": "",
18
+ "license": ""
19
+ }
pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/learning_dynamics/train_data/state.json ADDED
@@ -0,0 +1,13 @@
1
+ {
2
+ "_data_files": [
3
+ {
4
+ "filename": "data-00000-of-00001.arrow"
5
+ }
6
+ ],
7
+ "_fingerprint": "80201725dca773a1",
8
+ "_format_columns": null,
9
+ "_format_kwargs": {},
10
+ "_format_type": null,
11
+ "_output_all_columns": false,
12
+ "_split": null
13
+ }
pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/learning_dynamics/train_gradients.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2e58cd384c1152b182faa36e56cee0ca5f28b1ddf786b8e1b68d90bba8539e9f
3
+ size 2371527
pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/learning_dynamics/train_weights.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e4de550e5abc314db55e3074d9218d203a86b523c616ed13157d48954a7fac76
3
+ size 2371443
pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d849318b797c81059d055fe1e1bf7dd20699b24fa9038c056724c49de447915a
3
+ size 45143592
pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/pico_decoder.py ADDED
@@ -0,0 +1,911 @@
1
+ """
2
+ Pico Decoder: A Lightweight Causal Transformer Language Model
3
+
4
+ Pico Decoder uses a simple LLAMA-style transformer architecture, written for clarity and educational purposes.
5
+
6
+ Everything is written with a modular design for easy modification and experimentation.
7
+
8
+ Key features:
9
+ - RMSNorm for layer normalization
10
+ - Rotary Positional Embeddings (RoPE)
11
+ - Multi-head attention with KV-cache support
12
+ - SwiGLU activation function
13
+ - Residual connections throughout
14
+
15
+ - KV-cache for faster autoregressive generation
16
+
17
+ References:
18
+ - RoPE: https://arxiv.org/abs/2104.09864
19
+ - SwiGLU: https://arxiv.org/abs/2002.05202
20
+ - LLAMA: https://arxiv.org/abs/2302.13971
21
+
22
+ Adapted from:
23
+ - OLMO: https://github.com/allenai/OLMo
24
+ - LLAMA: https://github.com/meta/llama
25
+ """
26
+
27
+ from dataclasses import asdict
28
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
29
+
30
+ import torch
31
+ import torch.nn as nn
32
+ import torch.nn.functional as F
33
+
34
+ # Handle PyTorch version compatibility for attention backend
35
+ try:
36
+ from torch.nn.attention import SDPBackend, sdpa_kernel
37
+
38
+ HAS_TORCH_ATTENTION = True
39
+ except ImportError:
40
+ # Fallback for older PyTorch versions
41
+ HAS_TORCH_ATTENTION = False
42
+ SDPBackend = None
43
+ sdpa_kernel = None
44
+
45
+ from transformers import GenerationMixin, PretrainedConfig, PreTrainedModel
46
+ from transformers.generation import GenerationConfig
47
+ from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast
48
+
49
+ try:
50
+ if TYPE_CHECKING:
51
+ # We need to do this to avoid importing these when creating the HF-compatible models
52
+ from src.config import ModelConfig
53
+ except ImportError:
54
+ pass
55
+
56
+ ########################################################
57
+ #
58
+ # Layer Normalization
59
+ #
60
+ ########################################################
61
+
62
+
63
+ class RMSNorm(torch.nn.Module):
64
+ """Root Mean Square Layer Normalization.
65
+
66
+ A variant of Layer Normalization that uses RMS statistics instead of mean/variance,
67
+ resulting in improved stability and performance.
68
+
69
+ Args:
70
+ config (Union[ModelConfig, PicoHFConfig]): Configuration object containing normalization parameters
71
+ - config.norm_eps: Small constant for numerical stability
72
+ - config.d_model: Model dimension for the weight parameter
73
+
74
+ References:
75
+ https://arxiv.org/abs/1910.07467
76
+ """
77
+
78
+ def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
79
+ super().__init__()
80
+ self.eps = config.norm_eps
81
+ self.weight = nn.Parameter(torch.ones(config.d_model))
82
+
83
+ def _norm(self, x: torch.Tensor) -> torch.Tensor:
84
+ """
85
+ Normalizes the input tensor by its RMS value.
86
+ """
87
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
88
+
89
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
90
+ """
91
+ Applies RMS normalization to the input tensor and scales it by the weight parameter.
92
+ """
93
+ output = self._norm(x.float()).type_as(x)
94
+ return output * self.weight
95
+
96
+
97
+ ########################################################
98
+ #
99
+ # Positional Embedding
100
+ #
101
+ ########################################################
102
+
103
+
104
+ class RoPE(nn.Module):
105
+ """Rotary Positional Embeddings (RoPE).
106
+
107
+ Implements position-dependent rotation of keys and queries in attention mechanism,
108
+ allowing better modeling of relative positions in sequences. Uses complex number
109
+ operations for efficient rotation.
110
+
111
+ Args:
112
+ config (Union[ModelConfig, PicoHFConfig]): Model configuration containing:
113
+ - config.position_emb_theta: Base for frequency computation
114
+ - config.d_model: Model dimension
115
+ - config.attention_n_heads: Number of attention heads
116
+ - config.max_seq_len: Maximum sequence length
117
+
118
+ References:
119
+ https://arxiv.org/abs/2104.09864
120
+ """
121
+
122
+ _freqs_cis_tensor: torch.Tensor | None = None
123
+
124
+ def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
125
+ super().__init__()
126
+
127
+ self.theta = config.position_emb_theta
128
+ self.dim = config.d_model // config.attention_n_heads
129
+
130
+ max_seq_len = config.max_seq_len
131
+
132
+ # only gets set once, and then reused for all RoPE instances
133
+ if RoPE._freqs_cis_tensor is None:
134
+ RoPE._freqs_cis_tensor = self._setup_freqs_cis(
135
+ max_seq_len, self.theta, self.dim
136
+ )
137
+
138
+ # register _freqs_cis buffer
139
+ # can be easily recomputed so persistent=False
140
+ self.register_buffer("_freqs_cis", self._freqs_cis_tensor, persistent=False)
141
+
142
+ @classmethod
143
+ def _setup_freqs_cis(cls, seq_len: int, theta: float, dim: int) -> torch.Tensor:
144
+ """Setup Frequency Tensor for RoPE Embeddings
145
+
146
+ Initializes the complex frequency tensor that is used to compute the RoPE embeddings.
147
+
148
+ Note other implementations will use cos and sin directly, but using the complex
149
+ number representation is (probably) more efficient:
150
+
151
+ e^(theta * i * t) = cos(theta * t) + i * sin(theta * t) [Euler's formula]
152
+ """
153
+ _freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
154
+ positions = torch.arange(seq_len)
155
+ freqs = torch.outer(positions, _freqs)
156
+ return torch.polar(torch.ones_like(freqs), freqs) # complex64
157
+
158
+ def get_freqs_cis(
159
+ self, input_shape: torch.Size, start_pos: int, end_pos: int
160
+ ) -> torch.Tensor:
161
+ """Reshape Frequency Tensor for RoPE Embeddings
162
+
163
+ Makes the frequency tensor broadcastable with the input tensor.
164
+ """
165
+ _freqs_cis = self._freqs_cis[start_pos:end_pos]
166
+ ndim = len(input_shape)
167
+ assert 0 <= 1 < ndim
168
+ assert _freqs_cis.shape == (input_shape[1], input_shape[-1])
169
+
170
+ # TODO: Check whether this is correct (might be able to remove this)
171
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(input_shape)]
172
+ return _freqs_cis.view(*shape)
173
+
174
+ def forward(
175
+ self,
176
+ queries: torch.Tensor,
177
+ keys: torch.Tensor,
178
+ start_pos: int = 0,
179
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
180
+ """Apply RoPE Embeddings to Queries and Keys
181
+
182
+ Applies the rotary positional embeddings to the input tensors via complex number multiplication.
183
+
184
+ NOTE: The start_pos is used if we want to use the kv_cache in the attention mechanism.
185
+ """
186
+ queries_ = torch.view_as_complex(
187
+ queries.float().reshape(*queries.shape[:-1], -1, 2)
188
+ )
189
+ keys_ = torch.view_as_complex(keys.float().reshape(*keys.shape[:-1], -1, 2))
190
+
191
+ input_shape = (
192
+ queries_.shape
193
+ ) # same as keys: (batch_size, seq_len, n_heads, head_dim/2)
194
+ freqs_start_pos = start_pos
195
+ freqs_end_pos = freqs_start_pos + queries_.shape[1]
196
+
197
+ freqs_cis = self.get_freqs_cis(input_shape, freqs_start_pos, freqs_end_pos)
198
+
199
+ queries_rotated = torch.view_as_real(queries_ * freqs_cis).flatten(3)
200
+ keys_rotated = torch.view_as_real(keys_ * freqs_cis).flatten(3)
201
+ return queries_rotated.type_as(queries), keys_rotated.type_as(keys)
202
+
203
+
204
+ ########################################################
205
+ #
206
+ # Attention
207
+ #
208
+ ########################################################
209
+
210
+
211
+ class Attention(nn.Module):
212
+ """Multi-head Attention with Group Query Attention support.
213
+
214
+ Implements scaled dot-product attention and supports:
215
+ - Grouped Query Attention (GQA)
216
+ - Key-Value caching for efficient inference
217
+ - RoPE integration
218
+
219
+ Args:
220
+ config (Union[ModelConfig, PretrainedConfig]): Configuration containing:
221
+ - config.attention_n_heads: Number of attention heads
222
+ - config.attention_n_kv_heads: Number of key/value heads
223
+ - config.d_model: Model dimension
224
+ - config.batch_size: Maximum batch size
225
+ - config.max_seq_len: Maximum sequence length
226
+
227
+ Shape:
228
+ - Input: (batch_size, seq_len, d_model)
229
+ - Output: (batch_size, seq_len, d_model)
230
+ """
231
+
232
+ def __init__(
233
+ self,
234
+ config: Union["ModelConfig", "PicoDecoderHFConfig"],
235
+ ):
236
+ super().__init__()
237
+
238
+ self.n_heads = config.attention_n_heads
239
+ self.n_kv_heads = config.attention_n_kv_heads
240
+
241
+ self.batch_size = config.batch_size
242
+ self.max_seq_len = config.max_seq_len
243
+
244
+ d_model = config.d_model
245
+ self.head_dim = d_model // self.n_heads
246
+
247
+ self.n_rep = self.n_heads // self.n_kv_heads
248
+
249
+ self.q_proj = nn.Linear(d_model, self.n_heads * self.head_dim, bias=False)
250
+ self.k_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
251
+ self.v_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
252
+ self.o_proj = nn.Linear(self.n_heads * self.head_dim, d_model, bias=False)
253
+
254
+ self.rope = RoPE(config)
255
+
256
+ def forward(
257
+ self,
258
+ input: torch.Tensor,
259
+ mask: Optional[torch.Tensor] = None,
260
+ past_key_values: Optional[Tuple[torch.Tensor, ...]] = None,
261
+ use_cache: bool = False,
262
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
263
+ """Forward pass for the attention mechanism.
264
+
265
+ Computes queries, keys, and values for the attention mechanism. Applies rotary positional
266
+ embeddings to the queries and keys, and then computes attention scores and outputs.
267
+
268
+ For an introduction to the attention mechanism, see:
269
+ https://arxiv.org/abs/1706.03762
270
+
271
+ A few things to note:
272
+ - The past_key_values is used to implement the KV cache, which is used to speed up
273
+ generation by caching the KV pairs from previous forward passes. This is useful when doing
274
+ tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
275
+ modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
276
+ its own KV cache, stored as a (keys, values) tuple.
277
+ """
278
+ bsz, seq_len, _ = input.shape
279
+ _queries, _keys, _values = (
280
+ self.q_proj(input),
281
+ self.k_proj(input),
282
+ self.v_proj(input),
283
+ )
284
+
285
+ # Reshaping for multi-head attention
286
+ queries = _queries.view(bsz, seq_len, self.n_heads, self.head_dim)
287
+ keys = _keys.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
288
+ values = _values.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
289
+
290
+ # The start position is used to apply the RoPE embeddings to only the new tokens
291
+ # when using the kv_cache in the attention mechanism.
292
+ # The new tokens start right after the last cached position.
293
+ start_pos = 0
294
+ if past_key_values is not None and past_key_values[0] is not None:
295
+ start_pos = past_key_values[0].shape[1]
296
+
297
+ # apply rotary positional embeddings
298
+ queries, keys = self.rope(queries, keys, start_pos)
299
+
300
+ if (
301
+ past_key_values is not None
302
+ and past_key_values[0] is not None
303
+ and past_key_values[1] is not None
304
+ ):
305
+ keys = torch.cat([past_key_values[0], keys], dim=1)
306
+ values = torch.cat([past_key_values[1], values], dim=1)
307
+
308
+ if use_cache:
309
+ cached_keys = keys
310
+ cached_values = values
311
+ else:
312
+ cached_keys = None
313
+ cached_values = None
314
+
315
+ queries = queries.transpose(1, 2)
316
+ keys = keys.transpose(1, 2)
317
+ values = values.transpose(1, 2)
318
+
319
+ apply_gqa = self.n_rep > 1
320
+ if apply_gqa and queries.device.type == "mps":
321
+ # NOTE: MPS does not support GQA in the SDPA kernel, but we can repeat the keys and values
322
+ # outside of the kernel to get the same effect.
323
+ # See: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
324
+ keys = keys.repeat_interleave(self.n_rep, dim=-3)
325
+ values = values.repeat_interleave(self.n_rep, dim=-3)
326
+ apply_gqa = False
327
+
328
+ if HAS_TORCH_ATTENTION:
329
+ backends = [SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]
330
+ with sdpa_kernel(backends=backends):
331
+ attn_output = F.scaled_dot_product_attention(
332
+ queries.contiguous(),
333
+ keys.contiguous(),
334
+ values.contiguous(),
335
+ attn_mask=mask.to(queries.dtype) if mask is not None else None,
336
+ enable_gqa=apply_gqa,
337
+ )
338
+ else:
339
+ # Fallback for older PyTorch versions - use default backend
340
+ attn_output = F.scaled_dot_product_attention(
341
+ queries.contiguous(),
342
+ keys.contiguous(),
343
+ values.contiguous(),
344
+ attn_mask=mask.to(queries.dtype) if mask is not None else None,
345
+ enable_gqa=apply_gqa,
346
+ )
347
+
348
+ attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
349
+ output = self.o_proj(attn_output)
350
+
351
+ return output, (cached_keys, cached_values)
352
+
353
+
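# --- Editorial sketch (hypothetical; not part of pico_decoder.py) ----------------
# enable_gqa lets n_heads query heads share n_kv_heads key/value heads inside the SDPA
# kernel; the MPS fallback above repeats K/V instead. Both give the same result
# (assumes a PyTorch version that supports the enable_gqa flag, i.e. >= 2.5):
import torch
import torch.nn.functional as F

bsz, seq_len, n_heads, n_kv_heads, head_dim = 2, 8, 12, 4, 8
n_rep = n_heads // n_kv_heads                      # 3 query heads per KV head

q = torch.randn(bsz, n_heads, seq_len, head_dim)
k = torch.randn(bsz, n_kv_heads, seq_len, head_dim)
v = torch.randn(bsz, n_kv_heads, seq_len, head_dim)

out_gqa = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
out_rep = F.scaled_dot_product_attention(
    q,
    k.repeat_interleave(n_rep, dim=1),             # expand KV heads to match query heads
    v.repeat_interleave(n_rep, dim=1),
    is_causal=True,
)
assert torch.allclose(out_gqa, out_rep, atol=1e-5)
# ----------------------------------------------------------------------------------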
354
+ ########################################################
355
+ #
356
+ # SwiGLU (Combines MLP and Activation)
357
+ #
358
+ ########################################################
359
+
360
+
361
+ class SwiGLU(nn.Module):
362
+ """SwiGLU Activation Function with Linear Projections.
363
+
364
+ Implements the SwiGLU activation function combined with linear transformations,
365
+ serving as the feed-forward network in transformer blocks.
366
+
367
+ Args:
368
+ config (Union[ModelConfig, PicoDecoderHFConfig]): Configuration containing:
369
+ - config.d_model: Model dimension
370
+ - config.activation_hidden_dim: Hidden dimension (typically 4 * d_model)
371
+
372
+ References:
373
+ https://arxiv.org/abs/2002.05202
374
+ """
375
+
376
+ def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
377
+ super().__init__()
378
+
379
+ model_dim = config.d_model
380
+ act_hidden_dim = config.activation_hidden_dim # usually 4 * d_model
381
+
382
+ self.w_0 = nn.Linear(model_dim, act_hidden_dim, bias=False)
383
+ self.w_1 = nn.Linear(model_dim, act_hidden_dim, bias=False)
384
+ self.w_2 = nn.Linear(act_hidden_dim, model_dim, bias=False)
385
+
386
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
387
+ return self.w_2(F.silu(self.w_0(x)) * self.w_1(x))
388
+
389
+
390
+ ########################################################
391
+ #
392
+ # PicoDecoderBlock
393
+ #
394
+ ########################################################
395
+
396
+
397
+ class PicoDecoderBlock(nn.Module):
398
+ """Single Transformer Block with Attention and Feed-forward layers.
399
+
400
+ Implements a standard transformer block with:
401
+ - Multi-head attention with normalization and residual connection
402
+ - SwiGLU feed-forward network with normalization and residual connection
403
+
404
+ Args:
405
+ config (Union[ModelConfig, PicoDecoderHFConfig]): Model configuration; either a dataclass or
406
+ a HuggingFace PicoDecoderHFConfig
407
+ """
408
+
409
+ def __init__(
410
+ self,
411
+ config: Union["ModelConfig", "PicoDecoderHFConfig"],
412
+ ):
413
+ super().__init__()
414
+
415
+ self.attention = Attention(config)
416
+ self.swiglu = SwiGLU(config)
417
+ self.attention_norm = RMSNorm(config)
418
+ self.swiglu_norm = RMSNorm(config)
419
+
420
+ def forward(
421
+ self,
422
+ input: torch.Tensor,
423
+ mask: Optional[torch.Tensor] = None,
424
+ past_key_values: Optional[Tuple[torch.Tensor]] = None,
425
+ use_cache: bool = False,
426
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
427
+ attention_output, cached_key_values = self.attention(
428
+ self.attention_norm(input),
429
+ mask=mask,
430
+ past_key_values=past_key_values,
431
+ use_cache=use_cache,
432
+ )
433
+ # NOTE: cached_key_values is (None, None) if use_cache is False
434
+
435
+ h = input + attention_output
436
+ out = h + self.swiglu(self.swiglu_norm(h))
437
+ return out, cached_key_values
438
+
439
+
440
+ ########################################################
441
+ #
442
+ # Pico Decoder (Causal Transformer Model)
443
+ #
444
+ ########################################################
445
+
446
+
447
+ class PicoDecoder(nn.Module):
448
+ """
449
+ Pico Decoder: combines the embedding, causal decoder blocks, and output projection into a
450
+ single autoregressive model.
451
+
452
+ For more information on the model, see the classes for the modules that make up the model.
453
+ """
454
+
455
+ def __init__(
456
+ self,
457
+ model_config: Union["ModelConfig", "PicoDecoderHFConfig"],
458
+ ):
459
+ super().__init__()
460
+ self.config = model_config
461
+
462
+ self.embedding_proj = nn.Embedding(self.config.vocab_size, self.config.d_model)
463
+ self.layers = nn.ModuleList(
464
+ [PicoDecoderBlock(self.config) for _ in range(self.config.n_layers)]
465
+ )
466
+ self.output_norm = RMSNorm(self.config)
467
+ self.de_embedding_proj = nn.Linear(
468
+ self.config.d_model, self.config.vocab_size, bias=False
469
+ )
470
+
471
+ def convert_to_hf_model(self) -> "PicoDecoderHF":
472
+ """Convert the Lightning model to a HuggingFace model."""
473
+ # Create HF config without fabric-specific settings
474
+ hf_config = PicoDecoderHFConfig.from_dataclass(self.config)
475
+
476
+ # Create new HF model
477
+ hf_model = PicoDecoderHF(hf_config)
478
+
479
+ # Copy the state dict, prefixing keys to match the HF wrapper's submodule name
480
+ hf_model.load_state_dict(self.state_dict(prefix="pico_decoder."))
481
+
482
+ return hf_model
483
+
484
+ def forward(
485
+ self,
486
+ input_ids: torch.Tensor,
487
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
488
+ use_cache: bool = False,
489
+ ) -> Tuple[torch.Tensor, Optional[Tuple[Tuple[torch.Tensor, torch.Tensor]]]]:
490
+ """
491
+ This is the forward pass for the entire Pico model. It boils down to:
492
+ - Embedding the input ids
493
+ - Creating a causal mask
494
+ - Processing through the pico layers
495
+ - Projecting the output to logits
496
+
497
+ NOTE: One feature that might be confusing is the KV cache. The KV cache is used to speed up
498
+ generation by caching the KV pairs from previous forward passes. This is useful when doing
499
+ tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
500
+ modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
501
+ its own KV cache which is stored as a tuple. The whole model then stores a tuple of these
502
+ KV caches (so a tuple of tuples).
503
+ """
504
+
505
+ seq_len = input_ids.shape[-1]
506
+ h = self.embedding_proj(input_ids)
507
+
508
+ # Calculate start position from past cached KV pairs. Remember that each layer has its
509
+ # own KV Cache. So when we index past_key_values, we need to index into the KV pairs for the
510
+ # correct layer and then for either the keys or values.
511
+ start_pos = 0
512
+ if (
513
+ past_key_values is not None
514
+ and past_key_values[0] is not None
515
+ and past_key_values[0][0] is not None
516
+ ):
517
+ start_pos = past_key_values[0][0].shape[1]
518
+
519
+ # Create causal mask for current sequence
520
+ mask = None
521
+ if seq_len > 1:
522
+ mask = torch.full((seq_len, seq_len), float("-inf"))
523
+ mask = torch.triu(mask, diagonal=1)
524
+
525
+ # If using KV cache, extend mask to cover cached sequence length
526
+ if past_key_values is not None:
527
+ # Add zeros for cached tokens (we can attend to all of them)
528
+ mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask])
529
+
530
+ mask = mask.to(h.device)
531
+
532
+ # NOTE: If we are using the cache, we need to store the cached KV pairs for each layer
533
+ # in a tuple. Each layer will have its own cached KV pair which we aggregate in a tuple.
534
+ cached_key_values = () if use_cache else None
535
+
536
+ # Process through transformer blocks
537
+ for idx, layer in enumerate(self.layers):
538
+ layer_past_key_values = None
539
+ if past_key_values is not None:
540
+ try:
541
+ # Handle both tuple-based cache and HuggingFace cache objects
542
+ if hasattr(past_key_values, "__getitem__") and idx < len(
543
+ past_key_values
544
+ ):
545
+ layer_past_key_values = past_key_values[idx]
546
+ except (KeyError, IndexError, TypeError):
547
+ # If we can't access the cache properly, just skip it
548
+ layer_past_key_values = None
549
+
550
+ h, layer_cached_key_values = layer(
551
+ h, mask=mask, past_key_values=layer_past_key_values, use_cache=use_cache
552
+ )
553
+
554
+ if use_cache:
555
+ cached_key_values += (layer_cached_key_values,)
556
+
557
+ # Final norm and projection
558
+ h = self.output_norm(h)
559
+ logits = self.de_embedding_proj(h).float()
560
+
561
+ return logits, cached_key_values
562
+
563
+
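# --- Editorial sketch (hypothetical; not part of pico_decoder.py) ----------------
# Incremental decoding with the tuple-of-tuples KV cache described in the docstring.
# Assumes `config` is a ModelConfig carrying the fields used above (vocab_size, d_model,
# n_layers, attention_n_heads, attention_n_kv_heads, activation_hidden_dim, ...).
model = PicoDecoder(config)
prompt = torch.randint(0, config.vocab_size, (1, 16))

# Prefill: run the whole prompt once and keep each layer's (keys, values) pair.
logits, kv_cache = model(prompt, use_cache=True)
next_token = logits[:, -1].argmax(dim=-1, keepdim=True)

# Decode: feed only the new token; start_pos and the causal mask are derived
# from the cached sequence length inside forward().
for _ in range(8):
    logits, kv_cache = model(next_token, past_key_values=kv_cache, use_cache=True)
    next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
# ----------------------------------------------------------------------------------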
564
+ ########################################################
565
+ #
566
+ # HuggingFace Wrapper for the Pico Decoder model.
567
+ #
568
+ ########################################################
569
+
570
+
571
+ class PicoDecoderHFConfig(PretrainedConfig):
572
+ """Config class for the Pico Decoder HuggingFace wrapper."""
573
+
574
+ model_type = "pico_decoder"
575
+
576
+ @classmethod
577
+ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PicoDecoderHFConfig":
578
+ """
579
+ Initialize config from a dictionary. Note that no kwargs are passed to the constructor --
580
+ this is because with some kwargs special handling is required and can make this class
581
+ brittle.
582
+ """
583
+ pico_config = cls(**config_dict)
584
+
585
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
586
+ unused_kwargs = {
587
+ key: value for key, value in kwargs.items() if not hasattr(pico_config, key)
588
+ }
589
+
590
+ if return_unused_kwargs:
591
+ return pico_config, unused_kwargs
592
+ return pico_config
593
+
594
+ @classmethod
595
+ def from_dataclass(cls, model_config: "ModelConfig"):
596
+ """Initialise from our custom config dataclass."""
597
+ return cls.from_dict(asdict(model_config))
598
+
599
+
600
+ class PicoDecoderHF(PreTrainedModel, GenerationMixin):
601
+ """
602
+ HuggingFace wrapper for the Pico model with generation support.
603
+
604
+ Many evaluation frameworks require that a model be set up as a HuggingFace model, so we provide a simple
605
+ wrapper that does just that. When we save checkpoints of the Pico model, we save both the normal
606
+ Pico model as well as the model wrapped in this HuggingFace class.
607
+
608
+ This also lets you do cool things like:
609
+
610
+ `model = AutoModelForCausalLM.from_pretrained("path/to/checkpoint")`
611
+ """
612
+
613
+ config_class = PicoDecoderHFConfig
614
+ _no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
615
+ main_input_name = "input_ids"
616
+
617
+ def __init__(self, config: PicoDecoderHFConfig):
618
+ super().__init__(config)
619
+ self.pico_decoder = PicoDecoder(config)
620
+ # Initialize generation config with defaults
621
+ self.generation_config = GenerationConfig()
622
+ # Set some reasonable defaults for the model
623
+ if hasattr(config, "max_position_embeddings"):
624
+ self.generation_config.max_length = config.max_position_embeddings
625
+ if hasattr(config, "vocab_size"):
626
+ self.generation_config.vocab_size = config.vocab_size
627
+
628
+ def forward(
629
+ self,
630
+ input_ids: torch.Tensor,
631
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
632
+ use_cache: bool = False,
633
+ **kwargs,
634
+ ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
635
+ """HuggingFace forward pass wrapper.
636
+
637
+ Forward pass for the HuggingFace version of the Pico model. Basic wrapper around the
638
+ Pico model's forward pass, and returns the output as a HuggingFace CausalLMOutput.
639
+ """
640
+ logits, past_key_values = self.pico_decoder(
641
+ input_ids, past_key_values, use_cache
642
+ )
643
+ if use_cache:
644
+ return CausalLMOutputWithPast(
645
+ logits=logits,
646
+ past_key_values=past_key_values,
647
+ )
648
+ else:
649
+ return CausalLMOutput(
650
+ logits=logits,
651
+ )
652
+
653
+ def prepare_inputs_for_generation(
654
+ self,
655
+ input_ids: torch.LongTensor,
656
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
657
+ attention_mask: Optional[torch.LongTensor] = None,
658
+ **kwargs,
659
+ ) -> Dict[str, Any]:
660
+ """
661
+ Prepare inputs for generation.
662
+
663
+ Args:
664
+ input_ids: Input token IDs
665
+ past_key_values: Cached key-value pairs from previous forward passes
666
+ attention_mask: Attention mask for the input
667
+ **kwargs: Additional arguments
668
+
669
+ Returns:
670
+ Dictionary containing prepared inputs
671
+ """
672
+ # If we have past_key_values, we only need the last token
673
+ if past_key_values is not None:
674
+ input_ids = input_ids[:, -1:]
675
+
676
+ return {
677
+ "input_ids": input_ids,
678
+ "past_key_values": past_key_values,
679
+ "use_cache": True,
680
+ }
681
+
682
+ def get_input_embeddings(self):
683
+ """Get the input embeddings layer."""
684
+ return self.pico_decoder.embedding_proj
685
+
686
+ def set_input_embeddings(self, value):
687
+ """Set the input embeddings layer."""
688
+ self.pico_decoder.embedding_proj = value
689
+
690
+ def get_output_embeddings(self):
691
+ """Get the output embeddings layer."""
692
+ return self.pico_decoder.de_embedding_proj
693
+
694
+ def set_output_embeddings(self, value):
695
+ """Set the output embeddings layer."""
696
+ self.pico_decoder.de_embedding_proj = value
697
+
698
+ def get_lm_head(self):
699
+ """Get the language model head."""
700
+ return self.pico_decoder.de_embedding_proj
701
+
702
+ def can_generate(self) -> bool:
703
+ """Check if the model can generate text."""
704
+ return True
705
+
706
+ @property
707
+ def is_encoder_decoder(self) -> bool:
708
+ """Check if the model is an encoder-decoder model."""
709
+ return False
710
+
711
+ @property
712
+ def can_use_cache(self) -> bool:
713
+ """Check if the model can use KV cache."""
714
+ return True
715
+
716
+ def resize_token_embeddings(
717
+ self, new_num_tokens: Optional[int] = None
718
+ ) -> torch.nn.Embedding:
719
+ """Resize token embeddings."""
720
+ old_embeddings = self.get_input_embeddings()
721
+ if new_num_tokens is None:
722
+ new_num_tokens = old_embeddings.num_embeddings
723
+
724
+ new_embeddings = torch.nn.Embedding(
725
+ new_num_tokens, old_embeddings.embedding_dim
726
+ )
727
+ new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
728
+ old_embeddings.weight.data
729
+ )
730
+
731
+ self.pico_decoder.embedding_proj = new_embeddings
732
+ self.pico_decoder.de_embedding_proj = torch.nn.Linear(
733
+ old_embeddings.embedding_dim, new_num_tokens, bias=False
734
+ )
735
+
736
+ return new_embeddings
737
+
738
+
739
+ # Register for auto classes
740
+ PicoDecoderHFConfig.register_for_auto_class()
741
+ PicoDecoderHF.register_for_auto_class("AutoModel")
742
+ PicoDecoderHF.register_for_auto_class("AutoModelForCausalLM")
743
+
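# --- Editorial sketch (hypothetical; not part of pico_decoder.py) ----------------
# Because the config and model register themselves with the auto classes above, a saved
# checkpoint directory can be loaded directly. The path is a placeholder, and
# trust_remote_code is required since the modeling code ships with the checkpoint.
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("path/to/checkpoint", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("path/to/checkpoint")

inputs = tokenizer("The tiny decoder", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
# ----------------------------------------------------------------------------------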
744
+
745
+ ########################################################
746
+ #
747
+ # New PicoDecoderForCausalLM class for generation support
748
+ #
749
+ ########################################################
750
+
751
+
752
+ class PicoDecoderForCausalLM(PreTrainedModel, GenerationMixin):
753
+ """
754
+ PicoDecoderForCausalLM: A HuggingFace-compatible model that properly supports generation.
755
+
756
+ This class is designed to work with existing checkpoints and provides full generation support.
757
+ It inherits from the right base classes that HuggingFace expects for text generation.
758
+ """
759
+
760
+ config_class = PicoDecoderHFConfig
761
+ _no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
762
+ main_input_name = "input_ids"
763
+
764
+ def __init__(self, config: PicoDecoderHFConfig):
765
+ super().__init__(config)
766
+ self.pico_decoder = PicoDecoder(config)
767
+ # Initialize generation config with defaults
768
+ self.generation_config = GenerationConfig()
769
+ # Set some reasonable defaults for the model
770
+ if hasattr(config, "max_position_embeddings"):
771
+ self.generation_config.max_length = config.max_position_embeddings
772
+ if hasattr(config, "vocab_size"):
773
+ self.generation_config.vocab_size = config.vocab_size
774
+
775
+ def forward(
776
+ self,
777
+ input_ids: torch.Tensor,
778
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
779
+ use_cache: bool = False,
780
+ **kwargs,
781
+ ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
782
+ """Forward pass for text generation."""
783
+ logits, past_key_values = self.pico_decoder(
784
+ input_ids, past_key_values, use_cache
785
+ )
786
+ if use_cache:
787
+ return CausalLMOutputWithPast(
788
+ logits=logits,
789
+ past_key_values=past_key_values,
790
+ )
791
+ else:
792
+ return CausalLMOutput(
793
+ logits=logits,
794
+ )
795
+
796
+ def prepare_inputs_for_generation(
797
+ self,
798
+ input_ids: torch.LongTensor,
799
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
800
+ attention_mask: Optional[torch.LongTensor] = None,
801
+ **kwargs,
802
+ ) -> Dict[str, Any]:
803
+ """Prepare inputs for generation."""
804
+ # If we have past_key_values, we only need the last token
805
+ if past_key_values is not None:
806
+ input_ids = input_ids[:, -1:]
807
+
808
+ return {
809
+ "input_ids": input_ids,
810
+ "past_key_values": past_key_values,
811
+ "use_cache": True,
812
+ }
813
+
814
+ def get_input_embeddings(self):
815
+ """Get the input embeddings layer."""
816
+ return self.pico_decoder.embedding_proj
817
+
818
+ def set_input_embeddings(self, value):
819
+ """Set the input embeddings layer."""
820
+ self.pico_decoder.embedding_proj = value
821
+
822
+ def get_output_embeddings(self):
823
+ """Get the output embeddings layer."""
824
+ return self.pico_decoder.de_embedding_proj
825
+
826
+ def set_output_embeddings(self, value):
827
+ """Set the output embeddings layer."""
828
+ self.pico_decoder.de_embedding_proj = value
829
+
830
+ def get_lm_head(self):
831
+ """Get the language model head."""
832
+ return self.pico_decoder.de_embedding_proj
833
+
834
+ def can_generate(self) -> bool:
835
+ """Check if the model can generate text."""
836
+ return True
837
+
838
+ @property
839
+ def is_encoder_decoder(self) -> bool:
840
+ """Check if the model is an encoder-decoder model."""
841
+ return False
842
+
843
+ @property
844
+ def can_use_cache(self) -> bool:
845
+ """Check if the model can use KV cache."""
846
+ return True
847
+
848
+ def resize_token_embeddings(
849
+ self, new_num_tokens: Optional[int] = None
850
+ ) -> torch.nn.Embedding:
851
+ """Resize token embeddings."""
852
+ old_embeddings = self.get_input_embeddings()
853
+ if new_num_tokens is None:
854
+ new_num_tokens = old_embeddings.num_embeddings
855
+
856
+ new_embeddings = torch.nn.Embedding(
857
+ new_num_tokens, old_embeddings.embedding_dim
858
+ )
859
+ new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
860
+ old_embeddings.weight.data
861
+ )
862
+
863
+ self.pico_decoder.embedding_proj = new_embeddings
864
+ self.pico_decoder.de_embedding_proj = torch.nn.Linear(
865
+ old_embeddings.embedding_dim, new_num_tokens, bias=False
866
+ )
867
+
868
+ return new_embeddings
869
+
870
+ @classmethod
871
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
872
+ """
873
+ Load a pretrained model from a checkpoint.
874
+
875
+ This method handles loading from both the old PicoDecoderHF format and the new format.
876
+ """
877
+ # First try to load with the new class
878
+ try:
879
+ return super().from_pretrained(
880
+ pretrained_model_name_or_path, *model_args, **kwargs
881
+ )
882
+ except Exception as e:
883
+ print(f"Failed to load with new class: {e}")
884
+ print("Attempting to load with legacy class and convert...")
885
+
886
+ # Try to load with the old class and convert
887
+ try:
888
+ from transformers import AutoModel
889
+
890
+ old_model = AutoModel.from_pretrained(
891
+ pretrained_model_name_or_path,
892
+ trust_remote_code=True,
893
+ *model_args,
894
+ **kwargs,
895
+ )
896
+
897
+ # Create new model instance
898
+ new_model = cls(old_model.config)
899
+
900
+ # Copy state dict
901
+ new_model.load_state_dict(old_model.state_dict(), strict=False)
902
+
903
+ return new_model
904
+
905
+ except Exception as e2:
906
+ print(f"Failed to convert from legacy format: {e2}")
907
+ raise e
908
+
909
+
910
+ # Register the new class
911
+ PicoDecoderForCausalLM.register_for_auto_class("AutoModelForCausalLM")
pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/special_tokens_map.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "eos_token": {
3
+ "content": "<|endoftext|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "pad_token": {
10
+ "content": "<|padding|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ }
16
+ }
pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/tokenizer_config.json ADDED
@@ -0,0 +1,239 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": false,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "|||IP_ADDRESS|||",
8
+ "lstrip": false,
9
+ "normalized": true,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": false
13
+ },
14
+ "1": {
15
+ "content": "<|padding|>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "50254": {
23
+ "content": " ",
24
+ "lstrip": false,
25
+ "normalized": true,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": false
29
+ },
30
+ "50255": {
31
+ "content": " ",
32
+ "lstrip": false,
33
+ "normalized": true,
34
+ "rstrip": false,
35
+ "single_word": false,
36
+ "special": false
37
+ },
38
+ "50256": {
39
+ "content": " ",
40
+ "lstrip": false,
41
+ "normalized": true,
42
+ "rstrip": false,
43
+ "single_word": false,
44
+ "special": false
45
+ },
46
+ "50257": {
47
+ "content": " ",
48
+ "lstrip": false,
49
+ "normalized": true,
50
+ "rstrip": false,
51
+ "single_word": false,
52
+ "special": false
53
+ },
54
+ "50258": {
55
+ "content": " ",
56
+ "lstrip": false,
57
+ "normalized": true,
58
+ "rstrip": false,
59
+ "single_word": false,
60
+ "special": false
61
+ },
62
+ "50259": {
63
+ "content": " ",
64
+ "lstrip": false,
65
+ "normalized": true,
66
+ "rstrip": false,
67
+ "single_word": false,
68
+ "special": false
69
+ },
70
+ "50260": {
71
+ "content": " ",
72
+ "lstrip": false,
73
+ "normalized": true,
74
+ "rstrip": false,
75
+ "single_word": false,
76
+ "special": false
77
+ },
78
+ "50261": {
79
+ "content": " ",
80
+ "lstrip": false,
81
+ "normalized": true,
82
+ "rstrip": false,
83
+ "single_word": false,
84
+ "special": false
85
+ },
86
+ "50262": {
87
+ "content": " ",
88
+ "lstrip": false,
89
+ "normalized": true,
90
+ "rstrip": false,
91
+ "single_word": false,
92
+ "special": false
93
+ },
94
+ "50263": {
95
+ "content": " ",
96
+ "lstrip": false,
97
+ "normalized": true,
98
+ "rstrip": false,
99
+ "single_word": false,
100
+ "special": false
101
+ },
102
+ "50264": {
103
+ "content": " ",
104
+ "lstrip": false,
105
+ "normalized": true,
106
+ "rstrip": false,
107
+ "single_word": false,
108
+ "special": false
109
+ },
110
+ "50265": {
111
+ "content": " ",
112
+ "lstrip": false,
113
+ "normalized": true,
114
+ "rstrip": false,
115
+ "single_word": false,
116
+ "special": false
117
+ },
118
+ "50266": {
119
+ "content": " ",
120
+ "lstrip": false,
121
+ "normalized": true,
122
+ "rstrip": false,
123
+ "single_word": false,
124
+ "special": false
125
+ },
126
+ "50267": {
127
+ "content": " ",
128
+ "lstrip": false,
129
+ "normalized": true,
130
+ "rstrip": false,
131
+ "single_word": false,
132
+ "special": false
133
+ },
134
+ "50268": {
135
+ "content": " ",
136
+ "lstrip": false,
137
+ "normalized": true,
138
+ "rstrip": false,
139
+ "single_word": false,
140
+ "special": false
141
+ },
142
+ "50269": {
143
+ "content": " ",
144
+ "lstrip": false,
145
+ "normalized": true,
146
+ "rstrip": false,
147
+ "single_word": false,
148
+ "special": false
149
+ },
150
+ "50270": {
151
+ "content": " ",
152
+ "lstrip": false,
153
+ "normalized": true,
154
+ "rstrip": false,
155
+ "single_word": false,
156
+ "special": false
157
+ },
158
+ "50271": {
159
+ "content": " ",
160
+ "lstrip": false,
161
+ "normalized": true,
162
+ "rstrip": false,
163
+ "single_word": false,
164
+ "special": false
165
+ },
166
+ "50272": {
167
+ "content": " ",
168
+ "lstrip": false,
169
+ "normalized": true,
170
+ "rstrip": false,
171
+ "single_word": false,
172
+ "special": false
173
+ },
174
+ "50273": {
175
+ "content": " ",
176
+ "lstrip": false,
177
+ "normalized": true,
178
+ "rstrip": false,
179
+ "single_word": false,
180
+ "special": false
181
+ },
182
+ "50274": {
183
+ "content": " ",
184
+ "lstrip": false,
185
+ "normalized": true,
186
+ "rstrip": false,
187
+ "single_word": false,
188
+ "special": false
189
+ },
190
+ "50275": {
191
+ "content": " ",
192
+ "lstrip": false,
193
+ "normalized": true,
194
+ "rstrip": false,
195
+ "single_word": false,
196
+ "special": false
197
+ },
198
+ "50276": {
199
+ "content": " ",
200
+ "lstrip": false,
201
+ "normalized": true,
202
+ "rstrip": false,
203
+ "single_word": false,
204
+ "special": false
205
+ },
206
+ "50277": {
207
+ "content": "|||EMAIL_ADDRESS|||",
208
+ "lstrip": false,
209
+ "normalized": true,
210
+ "rstrip": false,
211
+ "single_word": false,
212
+ "special": false
213
+ },
214
+ "50278": {
215
+ "content": "|||PHONE_NUMBER|||",
216
+ "lstrip": false,
217
+ "normalized": true,
218
+ "rstrip": false,
219
+ "single_word": false,
220
+ "special": false
221
+ },
222
+ "50279": {
223
+ "content": "<|endoftext|>",
224
+ "lstrip": false,
225
+ "normalized": false,
226
+ "rstrip": false,
227
+ "single_word": false,
228
+ "special": true
229
+ }
230
+ },
231
+ "bos_token": null,
232
+ "clean_up_tokenization_spaces": true,
233
+ "eos_token": "<|endoftext|>",
234
+ "extra_special_tokens": {},
235
+ "model_max_length": 1000000000000000019884624838656,
236
+ "pad_token": "<|padding|>",
237
+ "tokenizer_class": "GPTNeoXTokenizer",
238
+ "unk_token": null
239
+ }
pico-decoder-tiny-dolma20M-v1/checkpoints/step_11000/config.json ADDED
@@ -0,0 +1,22 @@
1
+ {
2
+ "activation_hidden_dim": 384,
3
+ "architectures": [
4
+ "PicoDecoderHF"
5
+ ],
6
+ "attention_n_heads": 12,
7
+ "attention_n_kv_heads": 4,
8
+ "auto_map": {
9
+ "AutoConfig": "pico_decoder.PicoDecoderHFConfig",
10
+ "AutoModelForCausalLM": "pico_decoder.PicoDecoderHF"
11
+ },
12
+ "batch_size": 1024,
13
+ "d_model": 96,
14
+ "max_seq_len": 2048,
15
+ "model_type": "pico_decoder",
16
+ "n_layers": 12,
17
+ "norm_eps": 1e-06,
18
+ "position_emb_theta": 10000.0,
19
+ "torch_dtype": "float32",
20
+ "transformers_version": "4.48.3",
21
+ "vocab_size": 50304
22
+ }
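A rough, editorial back-of-the-envelope estimate (not part of the repository) of how the hyperparameters above translate into a parameter count, assuming untied input/output embeddings, bias-free linear layers, and one weight vector of size d_model per RMSNorm, as in pico_decoder.py:

d_model, n_layers, vocab = 96, 12, 50304
n_heads, n_kv_heads, act_hidden = 12, 4, 384
head_dim = d_model // n_heads                                 # 8

attn = d_model * n_heads * head_dim                           # q_proj
attn += 2 * d_model * n_kv_heads * head_dim                   # k_proj + v_proj
attn += n_heads * head_dim * d_model                          # o_proj
mlp = 3 * d_model * act_hidden                                # w_0, w_1, w_2
per_layer = attn + mlp + 2 * d_model                          # + attention_norm, swiglu_norm

total = n_layers * per_layer + 2 * vocab * d_model + d_model  # + embeddings, output_norm
print(f"~{total / 1e6:.1f}M parameters")                      # ~11.3M, mostly embeddings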
pico-decoder-tiny-dolma20M-v1/checkpoints/step_11000/fabric_state/checkpoint.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aa32db20f88e728c91c0d3ebdbf6c73eb9ae88ce72e585d545f4918a4032bdf6
3
+ size 135543171
pico-decoder-tiny-dolma20M-v1/checkpoints/step_11000/generation_config.json ADDED
@@ -0,0 +1,4 @@
1
+ {
2
+ "transformers_version": "4.48.3",
3
+ "vocab_size": 50304
4
+ }
pico-decoder-tiny-dolma20M-v1/checkpoints/step_11000/learning_dynamics/train_activations.pt ADDED
Binary file (98.3 kB). View file
 
pico-decoder-tiny-dolma20M-v1/checkpoints/step_11000/learning_dynamics/train_data/data-00000-of-00001.arrow ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ec30c4780be6c0202b2e45eff9aafc355068121331fd4e15df007d6ffbcad98
3
+ size 271480
pico-decoder-tiny-dolma20M-v1/checkpoints/step_11000/learning_dynamics/train_data/dataset_info.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "citation": "",
3
+ "description": "",
4
+ "features": {
5
+ "input_ids": {
6
+ "feature": {
7
+ "dtype": "int32",
8
+ "_type": "Value"
9
+ },
10
+ "_type": "Sequence"
11
+ },
12
+ "text": {
13
+ "dtype": "string",
14
+ "_type": "Value"
15
+ }
16
+ },
17
+ "homepage": "",
18
+ "license": ""
19
+ }
pico-decoder-tiny-dolma20M-v1/checkpoints/step_11000/learning_dynamics/train_data/state.json ADDED
@@ -0,0 +1,13 @@
1
+ {
2
+ "_data_files": [
3
+ {
4
+ "filename": "data-00000-of-00001.arrow"
5
+ }
6
+ ],
7
+ "_fingerprint": "e60984724a7a4c9c",
8
+ "_format_columns": null,
9
+ "_format_kwargs": {},
10
+ "_format_type": null,
11
+ "_output_all_columns": false,
12
+ "_split": null
13
+ }