Commit 035761e: first commit

Files added:
- .gitignore (+93 -0)
- .gradio/certificate.pem (+31 -0)
- README.md (+49 -0)
- config_smollm2_135.yaml (+128 -0)
- inference.py (+157 -0)
- model.py (+203 -0)
- requirements.txt (+0 -0)
- train_smollm2.py (+12 -0)
.gitignore
ADDED
@@ -0,0 +1,93 @@
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# Virtual Environment
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Training artifacts
+checkpoints/
+runs/
+logs/
+*.ckpt
+*.pt
+*.pth
+wandb/
+lightning_logs/
+final_model/
+
+# IDE
+.idea/
+.vscode/
+*.swp
+*.swo
+*~
+
+# Jupyter Notebook
+.ipynb_checkpoints
+*.ipynb_checkpoints/
+*.ipynb
+
+# OS
+.DS_Store
+.DS_Store?
+._*
+.Spotlight-V100
+.Trashes
+ehthumbs.db
+Thumbs.db
+
+# Logs
+*.log
+*.logs
+log.txt
+logs.txt
+
+# Data
+data/
+datasets/
+*.csv
+*.h5
+*.pkl
+*.npz
+
+# Environment
+.env
+.env.local
+.env.*.local
+.env.development.local
+.env.test.local
+.env.production.local
+
+# Misc
+*.bak
+*.tmp
+*.temp
+.coverage
+htmlcov/
+.pytest_cache/
+.mypy_cache/
.gradio/certificate.pem
ADDED
@@ -0,0 +1,31 @@
+-----BEGIN CERTIFICATE-----
+MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
+TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
+cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
+WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
+ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
+MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
+h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
+0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
+A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
+T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
+B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
+B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
+KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
+OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
+jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
+qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
+rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
+HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
+hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
+ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
+3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
+NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
+ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
+TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
+jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
+oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
+4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
+mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
+emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
+-----END CERTIFICATE-----
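Note: this appears to be the ISRG Root X1 CA certificate (the base64 decodes to "Internet Security Research Group" / "ISRG Root X1"), which Gradio writes to `.gradio/` when creating share links. It was likely committed by accident; adding `.gradio/` to the `.gitignore` above would keep it out of future commits.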
README.md
ADDED
@@ -0,0 +1,49 @@
+<!-- Use uv to create a virtual environment -->
+```
+uv venv
+source .venv/bin/activate
+```
+<!-- Train the SmolLM2 model -->
+Use the dataset from https://huggingface.co/datasets/HuggingFaceTB/smollm-corpus/tree/main/cosmopedia-v2
+```
+dataset = load_dataset("HuggingFaceTB/smollm-corpus", "cosmopedia-v2")
+```
+
+Use the tokenizer from https://huggingface.co/HuggingFaceTB/cosmo2-tokenizer
+```
+tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
+```
+Use the config from https://huggingface.co/HuggingFaceTB/SmolLM2-135M/blob/main/config_smollm2_135M.yaml
+
+Create the model from the parameters above.
+
+Train it with PyTorch Lightning.
+
+<!-- Model architecture -->
+
+LlamaForCausalLM(
+  (model): LlamaModel(
+    (embed_tokens): Embedding(49152, 576)
+    (layers): ModuleList(
+      (0-29): 30 x LlamaDecoderLayer(
+        (self_attn): LlamaAttention(
+          (q_proj): Linear(in_features=576, out_features=576, bias=False)
+          (k_proj): Linear(in_features=576, out_features=192, bias=False)
+          (v_proj): Linear(in_features=576, out_features=192, bias=False)
+          (o_proj): Linear(in_features=576, out_features=576, bias=False)
+        )
+        (mlp): LlamaMLP(
+          (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
+          (up_proj): Linear(in_features=576, out_features=1536, bias=False)
+          (down_proj): Linear(in_features=1536, out_features=576, bias=False)
+          (act_fn): SiLU()
+        )
+        (input_layernorm): LlamaRMSNorm((576,), eps=1e-05)
+        (post_attention_layernorm): LlamaRMSNorm((576,), eps=1e-05)
+      )
+    )
+    (norm): LlamaRMSNorm((576,), eps=1e-05)
+    (rotary_emb): LlamaRotaryEmbedding()
+  )
+  (lm_head): Linear(in_features=576, out_features=49152, bias=False)
+)
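One way to read the dump above: attention is grouped-query (9 query heads but only 3 key/value heads, head_dim 64, hence the 576 → 192 k/v projections), and the lm_head is tied to the embedding matrix per the config below. A minimal sanity-check sketch (not part of this commit) that re-derives the ~135M parameter count from those shapes:
```
# Re-derive the parameter count from the printed architecture
# (values taken from config_smollm2_135.yaml below).
hidden, layers, inter, vocab = 576, 30, 1536, 49152
n_heads, n_kv_heads = 9, 3
head_dim = hidden // n_heads                    # 64
attn = 2 * hidden * hidden                      # q_proj + o_proj
attn += 2 * hidden * (n_kv_heads * head_dim)    # k_proj + v_proj (GQA: 576 -> 192)
mlp = 3 * hidden * inter                        # gate_proj, up_proj, down_proj
norms = 2 * hidden                              # two RMSNorms per layer
total = layers * (attn + mlp + norms) + vocab * hidden + hidden
print(f"{total:,}")  # 134,515,008, i.e. ~135M (lm_head is tied, so not counted)
```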
config_smollm2_135.yaml
ADDED
@@ -0,0 +1,128 @@
+checkpoints:
+  checkpoint_interval: 2000
+  checkpoints_path: checkpoints
+  checkpoints_path_is_shared_file_system: false
+  resume_checkpoint_path: null
+  save_final_state: false
+  save_initial_state: false
+data_stages:
+- data:
+    dataset:
+      dataset_folder:
+      - datasets/smollm2-corpus
+      dataset_weights:
+      - 1.0
+    num_loading_workers: 0
+    seed: 8
+  name: stable phase
+  start_training_step: 1
+general:
+  benchmark_csv_path: null
+  consumed_train_samples: null
+  ignore_sanity_checks: true
+  project: smollm2
+  run: smollm2-135M
+  seed: 8
+  step: null
+logging:
+  iteration_step_info_interval: 1
+  log_level: info
+  log_level_replica: info
+model:
+  ddp_bucket_cap_mb: 25
+  dtype: bfloat16
+  init_method:
+    std: 0.041666666666666664
+  make_vocab_size_divisible_by: 1
+  model_config:
+    bos_token_id: 0
+    eos_token_id: 0
+    hidden_act: silu
+    hidden_size: 576
+    initializer_range: 0.041666666666666664
+    intermediate_size: 1536
+    is_llama_config: true
+    max_position_embeddings: 2048
+    num_attention_heads: 9
+    num_hidden_layers: 30
+    num_key_value_heads: 3
+    pad_token_id: null
+    pretraining_tp: 1
+    rms_norm_eps: 1.0e-05
+    rope_interleaved: false
+    rope_scaling: null
+    rope_theta: 10000.0
+    tie_word_embeddings: true
+    use_cache: true
+    vocab_size: 49152
+optimizer:
+  accumulate_grad_in_fp32: true
+  clip_grad: 1.0
+  learning_rate_scheduler:
+    learning_rate: 0.003
+    lr_decay_starting_step: 1600000
+    lr_decay_steps: 400000
+    lr_decay_style: linear
+    lr_warmup_steps: 2000
+    lr_warmup_style: linear
+    min_decay_lr: 0
+  optimizer_factory:
+    adam_beta1: 0.9
+    adam_beta2: 0.95
+    adam_eps: 1.0e-08
+    name: adamW
+    torch_adam_is_fused: true
+  weight_decay: 0.01
+  zero_stage: 0
+parallelism:
+  dp: 64
+  expert_parallel_size: 1
+  pp: 1
+  pp_engine: 1f1b
+  recompute_layer: false
+  tp: 1
+  tp_linear_async_communication: true
+  tp_mode: REDUCE_SCATTER
+  tp_recompute_allgather: true
+profiler: null
+tokenizer:
+  tokenizer_max_length: null
+  tokenizer_name_or_path: HuggingFaceTB/cosmo2-tokenizer
+  tokenizer_revision: null
+tokens:
+  batch_accumulation_per_replica: 1
+  limit_test_batches: 0
+  limit_val_batches: 0
+  micro_batch_size: 8
+  sequence_length: 2048
+  train_steps: 2000000
+  val_check_interval: 1000
+
+# model:
+
+# LlamaForCausalLM(
+#   (model): LlamaModel(
+#     (embed_tokens): Embedding(49152, 576)
+#     (layers): ModuleList(
+#       (0-29): 30 x LlamaDecoderLayer(
+#         (self_attn): LlamaAttention(
+#           (q_proj): Linear(in_features=576, out_features=576, bias=False)
+#           (k_proj): Linear(in_features=576, out_features=192, bias=False)
+#           (v_proj): Linear(in_features=576, out_features=192, bias=False)
+#           (o_proj): Linear(in_features=576, out_features=576, bias=False)
+#         )
+#         (mlp): LlamaMLP(
+#           (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
+#           (up_proj): Linear(in_features=576, out_features=1536, bias=False)
+#           (down_proj): Linear(in_features=1536, out_features=576, bias=False)
+#           (act_fn): SiLU()
+#         )
+#         (input_layernorm): LlamaRMSNorm((576,), eps=1e-05)
+#         (post_attention_layernorm): LlamaRMSNorm((576,), eps=1e-05)
+#       )
+#     )
+#     (norm): LlamaRMSNorm((576,), eps=1e-05)
+#     (rotary_emb): LlamaRotaryEmbedding()
+#   )
+#   (lm_head): Linear(in_features=576, out_features=49152, bias=False)
+# )
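The `learning_rate_scheduler` block above describes a linear warmup to 3e-3 over 2,000 steps, a long constant plateau, then a linear decay to 0 between steps 1.6M and 2M. A small sketch of that schedule, assuming the usual semantics of these fields (this config is carried over from the upstream SmolLM2 recipe; the Lightning script in this commit uses a fixed LR instead):
```
# Sketch of the LR schedule the YAML above describes.
def lr_at(step, peak=0.003, warmup=2000,
          decay_start=1_600_000, decay_steps=400_000, min_lr=0.0):
    if step < warmup:
        return peak * step / warmup              # linear warmup
    if step < decay_start:
        return peak                              # constant plateau
    frac = min((step - decay_start) / decay_steps, 1.0)
    return peak + frac * (min_lr - peak)         # linear decay to min_decay_lr

for s in (0, 1000, 2000, 1_000_000, 1_800_000, 2_000_000):
    print(s, lr_at(s))
```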
inference.py
ADDED
@@ -0,0 +1,157 @@
+import os
+import gradio as gr
+import torch
+from model import SmolLMModule, create_model_config
+from transformers import AutoTokenizer
+import yaml
+import glob
+
+# Load config
+with open("config_smollm2_135.yaml", "r") as file:
+    config = yaml.safe_load(file)
+
+# Load tokenizer
+tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
+tokenizer.pad_token = tokenizer.eos_token
+
+
+def load_model_from_checkpoint(checkpoint_path):
+    """Load model from checkpoint"""
+    model = SmolLMModule.load_from_checkpoint(checkpoint_path, config=config)
+    model.eval()  # Set to evaluation mode
+    return model
+
+
+def get_available_checkpoints():
+    """Get list of available checkpoints sorted by step number"""
+    checkpoints = glob.glob("checkpoints/*.ckpt")
+    if not checkpoints:
+        return [], []
+
+    # Sort by step number
+    def get_step_number(filepath):
+        try:
+            # Extract step number from the filename
+            filename = os.path.basename(filepath)
+            # Remove .ckpt extension
+            filename = filename.replace(".ckpt", "")
+            # Get the step number (take the last "step=" segment so names like
+            # "model-step=500" parse correctly)
+            if "step=" in filename:
+                return int(filename.split("step=")[-1])
+            elif "-step-" in filename:
+                return int(filename.split("-step-")[-1])
+            else:
+                return int("".join(filter(str.isdigit, filename)))
+        except (ValueError, IndexError):
+            return 0
+
+    # Sort checkpoints by step number
+    checkpoints.sort(key=get_step_number)
+
+    # Create display names
+    display_names = [f"Step {get_step_number(x)}" for x in checkpoints]
+    return display_names, checkpoints
+
+
+def generate_text(
+    prompt, checkpoint_choice, max_length=100, temperature=0.7, top_p=0.9
+):
+    """Generate text based on prompt using selected checkpoint"""
+    # Check if checkpoint is selected
+    if not checkpoint_choice:
+        return "Please select a checkpoint first!"
+
+    if not prompt:
+        return "Please enter a prompt!"
+
+    try:
+        # Get actual checkpoint path
+        step_num = int("".join(filter(str.isdigit, checkpoint_choice)))
+        checkpoints = glob.glob("checkpoints/*.ckpt")
+        checkpoint_path = None
+
+        for ckpt in checkpoints:
+            if str(step_num) in ckpt:
+                checkpoint_path = ckpt
+                break
+
+        if not checkpoint_path or not os.path.exists(checkpoint_path):
+            return f"Checkpoint for step {step_num} not found!"
+
+        # Load model from checkpoint
+        model = load_model_from_checkpoint(checkpoint_path)
+
+        # Move model to GPU if available
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        model = model.to(device)
+
+        # Tokenize input
+        inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
+        # Move inputs to same device as model
+        inputs = {k: v.to(device) for k, v in inputs.items()}
+
+        # Generate (do_sample=True so temperature and top_p actually take effect;
+        # pass the attention mask so padded prompts are handled correctly)
+        with torch.no_grad():
+            outputs = model.model.generate(
+                inputs["input_ids"],
+                attention_mask=inputs["attention_mask"],
+                max_length=max_length,
+                do_sample=True,
+                temperature=temperature,
+                top_p=top_p,
+                pad_token_id=tokenizer.pad_token_id,
+                eos_token_id=tokenizer.eos_token_id,
+            )
+
+        # Decode and return generated text
+        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        return generated_text
+    except Exception as e:
+        return f"Error during generation: {str(e)}"
+
+
+# Get available checkpoints
+display_names, _ = get_available_checkpoints()
+
+# Create Gradio interface
+with gr.Blocks(title="SmolLM2 Inference") as demo:
+    gr.Markdown("# SmolLM2 Text Generation")
+
+    if not display_names:
+        gr.Markdown("⚠️ No checkpoints found! Please train the model first.")
+    else:
+        gr.Markdown(
+            f"Found {len(display_names)} checkpoints. Select one and enter a prompt to generate text."
+        )
+
+    with gr.Row():
+        with gr.Column():
+            checkpoint_dropdown = gr.Dropdown(
+                choices=display_names,
+                label="Select Checkpoint",
+                value=display_names[-1] if display_names else None,
+                interactive=True,
+            )
+            prompt = gr.Textbox(
+                lines=3, placeholder="Enter your prompt here...", label="Input Prompt"
+            )
+            max_length = gr.Slider(
+                minimum=10, maximum=500, value=100, step=10, label="Max Length"
+            )
+            temperature = gr.Slider(
+                minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"
+            )
+            top_p = gr.Slider(
+                minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top-p"
+            )
+            generate_btn = gr.Button("Generate")
+
+        with gr.Column():
+            output = gr.Textbox(lines=8, label="Generated Text")
+
+    generate_btn.click(
+        fn=generate_text,
+        inputs=[prompt, checkpoint_dropdown, max_length, temperature, top_p],
+        outputs=output,
+    )
+
+if __name__ == "__main__":
+    demo.launch(share=True)
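One cost worth noting in inference.py: `generate_text` reloads the selected checkpoint from disk on every button click. A hypothetical tweak (not in this commit) that memoizes loads per path:
```
# Hypothetical: cache loaded checkpoints so repeated clicks on the same
# checkpoint skip the disk load. Keyed on path; maxsize bounds memory use.
from functools import lru_cache

@lru_cache(maxsize=2)
def load_model_cached(checkpoint_path):
    return load_model_from_checkpoint(checkpoint_path)
```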
model.py
ADDED
@@ -0,0 +1,203 @@
+# import libraries
+from datasets import load_dataset
+from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaConfig
+import pytorch_lightning as pl
+import yaml
+from pytorch_lightning.callbacks import LearningRateMonitor
+from pytorch_lightning.callbacks import RichProgressBar
+from pytorch_lightning.loggers import TensorBoardLogger
+import torch
+from torch.utils.data import DataLoader
+
+# load tokenizer
+# use tokenizer from https://huggingface.co/HuggingFaceTB/cosmo2-tokenizer
+# (module level, since collate_fn needs it; the dataset itself is only loaded
+# in the training entrypoint below so that importing this module stays cheap)
+tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
+# Set padding token to be the same as EOS token
+tokenizer.pad_token = tokenizer.eos_token
+
+# load config
+# use config from https://huggingface.co/HuggingFaceTB/SmolLM2-135M/blob/main/config_smollm2_135M.yaml
+# config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")
+
+
+def collate_fn(examples):
+    # Tokenize the texts
+    encoding = tokenizer(
+        [example["text"] for example in examples],
+        padding=True,
+        truncation=True,
+        max_length=512,
+        return_tensors="pt",
+    )
+
+    # Create labels (same as input_ids for causal language modeling)
+    encoding["labels"] = encoding["input_ids"].clone()
+
+    return encoding
+
+
+def create_model_config(config):
+    model_config = config["model"]["model_config"]
+    return LlamaConfig(
+        vocab_size=49152,  # From the model architecture
+        hidden_size=model_config["hidden_size"],
+        intermediate_size=model_config["intermediate_size"],
+        num_hidden_layers=model_config["num_hidden_layers"],
+        num_attention_heads=model_config["num_attention_heads"],
+        num_key_value_heads=model_config["num_key_value_heads"],
+        hidden_act=model_config["hidden_act"],
+        max_position_embeddings=model_config["max_position_embeddings"],
+        initializer_range=model_config["initializer_range"],
+        rms_norm_eps=1e-5,  # From the model architecture
+        use_cache=True,
+        pad_token_id=model_config["pad_token_id"],
+        bos_token_id=model_config["bos_token_id"],
+        eos_token_id=model_config["eos_token_id"],
+    )
+
+
+# create model
+class SmolLMModule(pl.LightningModule):
+    def __init__(self, config, learning_rate=1e-4):
+        super().__init__()
+        self.config = config
+        self.learning_rate = learning_rate
+        self.save_hyperparameters()  # Save hyperparameters for resuming
+
+        # Create model from config
+        model_config = create_model_config(config)
+        self.model = AutoModelForCausalLM.from_config(model_config)
+
+    def forward(self, **inputs):
+        return self.model(**inputs)
+
+    def training_step(self, batch, batch_idx):
+        outputs = self.model(**batch)
+        loss = outputs.loss
+        self.log("train_loss", loss, prog_bar=True)
+        return loss
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.AdamW(
+            self.model.parameters(),
+            lr=self.learning_rate,
+            betas=(0.9, 0.95),
+            eps=1e-8,
+            weight_decay=0.1,
+        )
+        return optimizer
+
+    def on_save_checkpoint(self, checkpoint):
+        # Save additional info if needed
+        checkpoint["step"] = self.global_step
+        checkpoint["model_config"] = self.config
+
+    def on_load_checkpoint(self, checkpoint):
+        # Lightning restores global_step itself (it is a read-only property,
+        # so assigning to it would raise); only restore our stashed config
+        self.config = checkpoint["model_config"]
+
+
+# training script
+if __name__ == "__main__":
+    import os
+    from pytorch_lightning.callbacks import ModelCheckpoint
+
+    # parameters loaded from config file
+    with open("config_smollm2_135.yaml", "r") as file:
+        config = yaml.safe_load(file)
+    max_steps = 5000  # Total training steps
+
+    # Create checkpoint directory if it doesn't exist
+    checkpoint_dir = "checkpoints"
+    os.makedirs(checkpoint_dir, exist_ok=True)
+
+    # Checkpoint callback; Lightning expands "{step}" to "step=<n>",
+    # so this yields filenames like "model-step=500.ckpt"
+    checkpoint_callback = ModelCheckpoint(
+        dirpath=checkpoint_dir,
+        filename="model-{step}",
+        save_top_k=-1,  # Save all checkpoints
+        every_n_train_steps=500,  # Save every 500 steps
+        save_weights_only=False,  # Save the full model state
+    )
+
+    # load dataset (streaming, so nothing is downloaded up front)
+    dataset = load_dataset(
+        "HuggingFaceTB/smollm-corpus", "cosmopedia-v2", streaming=True
+    )
+    train_dataset = dataset["train"]
+
+    # Create DataLoader
+    train_loader = DataLoader(
+        train_dataset,
+        batch_size=4,  # Small batch size for testing
+        collate_fn=collate_fn,
+        num_workers=2,
+    )
+
+    # create model
+    model = SmolLMModule(config, learning_rate=1e-4)
+
+    # progress bar
+    progress_bar = RichProgressBar(leave=False, refresh_rate=1, console_kwargs=None)
+
+    # Find latest checkpoint if exists
+    latest_checkpoint = None
+    if os.path.exists(checkpoint_dir):
+        checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith(".ckpt")]
+        if checkpoints:
+            # Filenames look like "model-step=500.ckpt"; sort on the step number
+            latest_checkpoint = os.path.join(
+                checkpoint_dir,
+                sorted(
+                    checkpoints,
+                    key=lambda x: int(x.split("step=")[-1].split(".")[0]),
+                )[-1],
+            )
+            print(f"Resuming from checkpoint: {latest_checkpoint}")
+
+    # create trainer
+    trainer = pl.Trainer(
+        max_steps=max_steps,
+        accelerator="gpu",
+        devices=1,
+        precision="bf16-mixed",
+        callbacks=[
+            LearningRateMonitor(logging_interval="step"),
+            progress_bar,
+            checkpoint_callback,
+        ],
+        log_every_n_steps=1,
+        enable_progress_bar=True,
+        enable_model_summary=True,
+    )
+
+    # train model
+    if latest_checkpoint:
+        # Resume training from checkpoint if it exists
+        trainer.fit(model, train_loader, ckpt_path=latest_checkpoint)
+    else:
+        # Start training from scratch
+        trainer.fit(model, train_loader)
+
+    # Save final model and tokenizer
+    if trainer.is_global_zero:  # Only save on main process
+        output_dir = "final_model"
+        os.makedirs(output_dir, exist_ok=True)
+        model.model.save_pretrained(os.path.join(output_dir, "model"))
+        tokenizer.save_pretrained(os.path.join(output_dir, "tokenizer"))
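For scale: with the values hard-coded in model.py this run sees only a modest slice of the corpus. A rough, illustrative upper bound (padding to the longest sequence in each batch makes the real count lower):
```
# Upper bound on tokens seen in one run of model.py's training entrypoint:
# batch_size=4, max_length=512, max_steps=5000 (values from the script above).
batch_size, max_length, max_steps = 4, 512, 5000
print(batch_size * max_length * max_steps)  # 10,240,000 tokens at most
```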
requirements.txt
ADDED
File without changes (the committed file is empty; dependencies are not pinned yet)
train_smollm2.py
ADDED
@@ -0,0 +1,12 @@
+from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer
+# from datasets import load_dataset
+
+
+# dataset = load_dataset("HuggingFaceTB/smollm-corpus", "cosmopedia-v2")
+
+# use tokenizer from https://huggingface.co/HuggingFaceTB/cosmo2-tokenizer
+# tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
+
+# Download the reference SmolLM2-135M and print its architecture for comparison
+model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
+
+print(model)