Vibi007 committed on
Commit
035761e
0 Parent(s):

first commit

Files changed (8)
  1. .gitignore +93 -0
  2. .gradio/certificate.pem +31 -0
  3. README.md +49 -0
  4. config_smollm2_135.yaml +128 -0
  5. inference.py +157 -0
  6. model.py +203 -0
  7. requirements.txt +0 -0
  8. train_smollm2.py +12 -0
.gitignore ADDED
@@ -0,0 +1,93 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # Virtual Environment
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Training artifacts
+ checkpoints/
+ runs/
+ logs/
+ *.ckpt
+ *.pt
+ *.pth
+ wandb/
+ lightning_logs/
+ final_model/
+
+ # IDE
+ .idea/
+ .vscode/
+ *.swp
+ *.swo
+ *~
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+ *.ipynb_checkpoints/
+ *.ipynb
+
+ # OS
+ .DS_Store
+ .DS_Store?
+ ._*
+ .Spotlight-V100
+ .Trashes
+ ehthumbs.db
+ Thumbs.db
+
+ # Logs
+ *.log
+ *.logs
+ log.txt
+ logs.txt
+
+ # Data
+ data/
+ datasets/
+ *.csv
+ *.h5
+ *.pkl
+ *.npz
+
+ # Environment (.env itself is already ignored above)
+ .env.local
+ .env.*.local
+ .env.development.local
+ .env.test.local
+ .env.production.local
+
+ # Misc
+ *.bak
+ *.tmp
+ *.temp
+ .coverage
+ htmlcov/
+ .pytest_cache/
+ .mypy_cache/
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
+ -----BEGIN CERTIFICATE-----
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
+ -----END CERTIFICATE-----
README.md ADDED
@@ -0,0 +1,49 @@
+ <!-- Use uv to create and activate a virtual environment -->
+ ```
+ uv venv
+ source .venv/bin/activate
+ ```
+ <!-- Train the SmolLM2 model -->
+ Use the dataset from https://huggingface.co/datasets/HuggingFaceTB/smollm-corpus/tree/main/cosmopedia-v2:
+ ```
+ dataset = load_dataset("HuggingFaceTB/smollm-corpus", "cosmopedia-v2")
+ ```
+
+ Use the tokenizer from https://huggingface.co/HuggingFaceTB/cosmo2-tokenizer:
+ ```
+ tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
+ ```
+ Use the config from https://huggingface.co/HuggingFaceTB/SmolLM2-135M/blob/main/config_smollm2_135M.yaml
+
+ Create the model from the above parameters, then train it with PyTorch Lightning (see the sketch below).
+
+ <!-- Model architecture -->
+
+ ```
+ LlamaForCausalLM(
+   (model): LlamaModel(
+     (embed_tokens): Embedding(49152, 576)
+     (layers): ModuleList(
+       (0-29): 30 x LlamaDecoderLayer(
+         (self_attn): LlamaAttention(
+           (q_proj): Linear(in_features=576, out_features=576, bias=False)
+           (k_proj): Linear(in_features=576, out_features=192, bias=False)
+           (v_proj): Linear(in_features=576, out_features=192, bias=False)
+           (o_proj): Linear(in_features=576, out_features=576, bias=False)
+         )
+         (mlp): LlamaMLP(
+           (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
+           (up_proj): Linear(in_features=576, out_features=1536, bias=False)
+           (down_proj): Linear(in_features=1536, out_features=576, bias=False)
+           (act_fn): SiLU()
+         )
+         (input_layernorm): LlamaRMSNorm((576,), eps=1e-05)
+         (post_attention_layernorm): LlamaRMSNorm((576,), eps=1e-05)
+       )
+     )
+     (norm): LlamaRMSNorm((576,), eps=1e-05)
+     (rotary_emb): LlamaRotaryEmbedding()
+   )
+   (lm_head): Linear(in_features=576, out_features=49152, bias=False)
+ )
+ ```
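+
+ <!-- Building the model (sketch) -->
+ A minimal sketch of constructing the model from the YAML above. It mirrors `create_model_config` in model.py; all fields are standard `LlamaConfig` parameters read from `config_smollm2_135.yaml`:
+ ```
+ import yaml
+ from transformers import AutoModelForCausalLM, LlamaConfig
+
+ with open("config_smollm2_135.yaml") as f:
+     mc = yaml.safe_load(f)["model"]["model_config"]
+
+ model_config = LlamaConfig(
+     vocab_size=mc["vocab_size"],
+     hidden_size=mc["hidden_size"],
+     intermediate_size=mc["intermediate_size"],
+     num_hidden_layers=mc["num_hidden_layers"],
+     num_attention_heads=mc["num_attention_heads"],
+     num_key_value_heads=mc["num_key_value_heads"],
+     max_position_embeddings=mc["max_position_embeddings"],
+     rms_norm_eps=mc["rms_norm_eps"],
+     tie_word_embeddings=mc["tie_word_embeddings"],
+ )
+ model = AutoModelForCausalLM.from_config(model_config)
+ print(sum(p.numel() for p in model.parameters()))  # ~134.5M with tied embeddings
+ ```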
config_smollm2_135.yaml ADDED
@@ -0,0 +1,128 @@
+ checkpoints:
+   checkpoint_interval: 2000
+   checkpoints_path: checkpoints
+   checkpoints_path_is_shared_file_system: false
+   resume_checkpoint_path: null
+   save_final_state: false
+   save_initial_state: false
+ data_stages:
+ - data:
+     dataset:
+       dataset_folder:
+       - datasets/smollm2-corpus
+       dataset_weights:
+       - 1.0
+     num_loading_workers: 0
+     seed: 8
+   name: stable phase
+   start_training_step: 1
+ general:
+   benchmark_csv_path: null
+   consumed_train_samples: null
+   ignore_sanity_checks: true
+   project: smollm2
+   run: smollm2-135M
+   seed: 8
+   step: null
+ logging:
+   iteration_step_info_interval: 1
+   log_level: info
+   log_level_replica: info
+ model:
+   ddp_bucket_cap_mb: 25
+   dtype: bfloat16
+   init_method:
+     std: 0.041666666666666664
+   make_vocab_size_divisible_by: 1
+   model_config:
+     bos_token_id: 0
+     eos_token_id: 0
+     hidden_act: silu
+     hidden_size: 576
+     initializer_range: 0.041666666666666664
+     intermediate_size: 1536
+     is_llama_config: true
+     max_position_embeddings: 2048
+     num_attention_heads: 9
+     num_hidden_layers: 30
+     num_key_value_heads: 3
+     pad_token_id: null
+     pretraining_tp: 1
+     rms_norm_eps: 1.0e-05
+     rope_interleaved: false
+     rope_scaling: null
+     rope_theta: 10000.0
+     tie_word_embeddings: true
+     use_cache: true
+     vocab_size: 49152
+ optimizer:
+   accumulate_grad_in_fp32: true
+   clip_grad: 1.0
+   learning_rate_scheduler:
+     learning_rate: 0.003
+     lr_decay_starting_step: 1600000
+     lr_decay_steps: 400000
+     lr_decay_style: linear
+     lr_warmup_steps: 2000
+     lr_warmup_style: linear
+     min_decay_lr: 0
+   optimizer_factory:
+     adam_beta1: 0.9
+     adam_beta2: 0.95
+     adam_eps: 1.0e-08
+     name: adamW
+     torch_adam_is_fused: true
+   weight_decay: 0.01
+   zero_stage: 0
+ parallelism:
+   dp: 64
+   expert_parallel_size: 1
+   pp: 1
+   pp_engine: 1f1b
+   recompute_layer: false
+   tp: 1
+   tp_linear_async_communication: true
+   tp_mode: REDUCE_SCATTER
+   tp_recompute_allgather: true
+ profiler: null
+ tokenizer:
+   tokenizer_max_length: null
+   tokenizer_name_or_path: HuggingFaceTB/cosmo2-tokenizer
+   tokenizer_revision: null
+ tokens:
+   batch_accumulation_per_replica: 1
+   limit_test_batches: 0
+   limit_val_batches: 0
+   micro_batch_size: 8
+   sequence_length: 2048
+   train_steps: 2000000
+   val_check_interval: 1000
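+
+ # Derived from the values above: effective global batch =
+ # dp (64) * micro_batch_size (8) * batch_accumulation_per_replica (1)
+ # = 512 sequences per step, i.e. 512 * 2048 = 1,048,576 tokens per optimizer step.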
+
+ # model:
+
+ # LlamaForCausalLM(
+ #   (model): LlamaModel(
+ #     (embed_tokens): Embedding(49152, 576)
+ #     (layers): ModuleList(
+ #       (0-29): 30 x LlamaDecoderLayer(
+ #         (self_attn): LlamaAttention(
+ #           (q_proj): Linear(in_features=576, out_features=576, bias=False)
+ #           (k_proj): Linear(in_features=576, out_features=192, bias=False)
+ #           (v_proj): Linear(in_features=576, out_features=192, bias=False)
+ #           (o_proj): Linear(in_features=576, out_features=576, bias=False)
+ #         )
+ #         (mlp): LlamaMLP(
+ #           (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
+ #           (up_proj): Linear(in_features=576, out_features=1536, bias=False)
+ #           (down_proj): Linear(in_features=1536, out_features=576, bias=False)
+ #           (act_fn): SiLU()
+ #         )
+ #         (input_layernorm): LlamaRMSNorm((576,), eps=1e-05)
+ #         (post_attention_layernorm): LlamaRMSNorm((576,), eps=1e-05)
+ #       )
+ #     )
+ #     (norm): LlamaRMSNorm((576,), eps=1e-05)
+ #     (rotary_emb): LlamaRotaryEmbedding()
+ #   )
+ #   (lm_head): Linear(in_features=576, out_features=49152, bias=False)
+ # )
inference.py ADDED
@@ -0,0 +1,157 @@
+ import os
+ import glob
+
+ import gradio as gr
+ import torch
+ import yaml
+ from transformers import AutoTokenizer
+
+ from model import SmolLMModule
+
+ # Load config
+ with open("config_smollm2_135.yaml", "r") as file:
+     config = yaml.safe_load(file)
+
+ # Load tokenizer
+ tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
+ tokenizer.pad_token = tokenizer.eos_token
+
+
+ def load_model_from_checkpoint(checkpoint_path):
+     """Load model from checkpoint"""
+     model = SmolLMModule.load_from_checkpoint(checkpoint_path, config=config)
+     model.eval()  # Set to evaluation mode
+     return model
+
+
+ def get_step_number(filepath):
+     """Extract the training step number from a checkpoint filename."""
+     try:
+         filename = os.path.basename(filepath).replace(".ckpt", "")
+         if "step=" in filename:
+             return int(filename.split("step=")[1])
+         elif "-step-" in filename:
+             return int(filename.split("-step-")[1])
+         else:
+             return int("".join(filter(str.isdigit, filename)))
+     except (ValueError, IndexError):
+         return 0
+
+
+ def get_available_checkpoints():
+     """Get list of available checkpoints sorted by step number"""
+     checkpoints = sorted(glob.glob("checkpoints/*.ckpt"), key=get_step_number)
+     display_names = [f"Step {get_step_number(x)}" for x in checkpoints]
+     return display_names, checkpoints
+
+
+ def generate_text(
+     prompt, checkpoint_choice, max_length=100, temperature=0.7, top_p=0.9
+ ):
+     """Generate text based on prompt using selected checkpoint"""
+     if not checkpoint_choice:
+         return "Please select a checkpoint first!"
+     if not prompt:
+         return "Please enter a prompt!"
+
+     try:
+         # Resolve the display name ("Step N") back to a checkpoint path by
+         # comparing parsed step numbers; plain substring matching would
+         # confuse e.g. step 500 with step 5000.
+         step_num = int("".join(filter(str.isdigit, checkpoint_choice)))
+         checkpoint_path = None
+         for ckpt in glob.glob("checkpoints/*.ckpt"):
+             if get_step_number(ckpt) == step_num:
+                 checkpoint_path = ckpt
+                 break
+
+         if not checkpoint_path or not os.path.exists(checkpoint_path):
+             return f"Checkpoint for step {step_num} not found!"
+
+         # Load model from checkpoint and move it to GPU if available
+         model = load_model_from_checkpoint(checkpoint_path)
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         model = model.to(device)
+
+         # Tokenize input and move it to the same device as the model
+         inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
+         inputs = {k: v.to(device) for k, v in inputs.items()}
+
+         # Generate; do_sample=True is required for temperature/top_p to apply
+         with torch.no_grad():
+             outputs = model.model.generate(
+                 inputs["input_ids"],
+                 max_length=max_length,
+                 do_sample=True,
+                 temperature=temperature,
+                 top_p=top_p,
+                 pad_token_id=tokenizer.pad_token_id,
+                 eos_token_id=tokenizer.eos_token_id,
+             )
+
+         # Decode and return generated text
+         return tokenizer.decode(outputs[0], skip_special_tokens=True)
+     except Exception as e:
+         return f"Error during generation: {str(e)}"
+
+
+ # Get available checkpoints
+ display_names, _ = get_available_checkpoints()
+
+ # Create Gradio interface
+ with gr.Blocks(title="SmolLM2 Inference") as demo:
+     gr.Markdown("# SmolLM2 Text Generation")
+
+     if not display_names:
+         gr.Markdown("⚠️ No checkpoints found! Please train the model first.")
+     else:
+         gr.Markdown(
+             f"Found {len(display_names)} checkpoints. Select one and enter a prompt to generate text."
+         )
+
+     with gr.Row():
+         with gr.Column():
+             checkpoint_dropdown = gr.Dropdown(
+                 choices=display_names,
+                 label="Select Checkpoint",
+                 value=display_names[-1] if display_names else None,
+                 interactive=True,
+             )
+             prompt = gr.Textbox(
+                 lines=3, placeholder="Enter your prompt here...", label="Input Prompt"
+             )
+             max_length = gr.Slider(
+                 minimum=10, maximum=500, value=100, step=10, label="Max Length"
+             )
+             temperature = gr.Slider(
+                 minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"
+             )
+             top_p = gr.Slider(
+                 minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top-p"
+             )
+             generate_btn = gr.Button("Generate")
+
+         with gr.Column():
+             output = gr.Textbox(lines=8, label="Generated Text")
+
+     generate_btn.click(
+         fn=generate_text,
+         inputs=[prompt, checkpoint_dropdown, max_length, temperature, top_p],
+         outputs=output,
+     )
+
+ if __name__ == "__main__":
+     demo.launch(share=True)
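+
+ # Note: share=True also opens a temporary public *.gradio.live URL in
+ # addition to the local server; drop the argument to serve locally only.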
model.py ADDED
@@ -0,0 +1,203 @@
+ # import libraries
+ from datasets import load_dataset
+ from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaConfig
+ import pytorch_lightning as pl
+ import yaml
+ from pytorch_lightning.callbacks import LearningRateMonitor
+ from pytorch_lightning.callbacks import RichProgressBar
+ import torch
+ from torch.utils.data import DataLoader
+
+ # The dataset is loaded inside the __main__ block below, so that importing
+ # this module (e.g. from inference.py) does not start streaming the corpus
+ # as a side effect.
+
+ # load tokenizer at module level, since collate_fn below closes over it
+ # use tokenizer from https://huggingface.co/HuggingFaceTB/cosmo2-tokenizer
+ tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
+ # Set padding token to be the same as EOS token
+ tokenizer.pad_token = tokenizer.eos_token
+
+ # load config
+ # use config from https://huggingface.co/HuggingFaceTB/SmolLM2-135M/blob/main/config_smollm2_135M.yaml
+ # config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")
+
+
+ def collate_fn(examples):
+     # Tokenize the texts
+     encoding = tokenizer(
+         [example["text"] for example in examples],
+         padding=True,
+         truncation=True,
+         max_length=512,
+         return_tensors="pt",
+     )
+
+     # Create labels (same as input_ids for causal language modeling)
+     encoding["labels"] = encoding["input_ids"].clone()
+
+     return encoding
+
+
+ def create_model_config(config):
+     model_config = config["model"]["model_config"]
+     return LlamaConfig(
+         vocab_size=model_config["vocab_size"],
+         hidden_size=model_config["hidden_size"],
+         intermediate_size=model_config["intermediate_size"],
+         num_hidden_layers=model_config["num_hidden_layers"],
+         num_attention_heads=model_config["num_attention_heads"],
+         num_key_value_heads=model_config["num_key_value_heads"],
+         hidden_act=model_config["hidden_act"],
+         max_position_embeddings=model_config["max_position_embeddings"],
+         initializer_range=model_config["initializer_range"],
+         rms_norm_eps=model_config["rms_norm_eps"],
+         rope_theta=model_config["rope_theta"],
+         tie_word_embeddings=model_config["tie_word_embeddings"],
+         use_cache=True,
+         pad_token_id=model_config["pad_token_id"],
+         bos_token_id=model_config["bos_token_id"],
+         eos_token_id=model_config["eos_token_id"],
+     )
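+
+
+ # Note: hidden_size=576 with 9 attention heads gives head_dim = 64; having
+ # only 3 key/value heads makes this grouped-query attention (3 query heads
+ # per KV head), which is why k_proj and v_proj project to 3 * 64 = 192 dims.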
+
+
+ # create model
+ class SmolLMModule(pl.LightningModule):
+     def __init__(self, config, learning_rate=1e-4):
+         super().__init__()
+         self.config = config
+         self.learning_rate = learning_rate
+         self.save_hyperparameters()  # Save hyperparameters for resuming
+
+         # Create model from config
+         model_config = create_model_config(config)
+         self.model = AutoModelForCausalLM.from_config(model_config)
+
+     def forward(self, **inputs):
+         return self.model(**inputs)
+
+     def training_step(self, batch, batch_idx):
+         outputs = self.model(**batch)
+         loss = outputs.loss
+         self.log("train_loss", loss, prog_bar=True)
+         return loss
+
+     def configure_optimizers(self):
+         optimizer = torch.optim.AdamW(
+             self.model.parameters(),
+             lr=self.learning_rate,
+             betas=(0.9, 0.95),
+             eps=1e-8,
+             weight_decay=0.1,
+         )
+         return optimizer
+
+     def on_save_checkpoint(self, checkpoint):
+         # Save additional info if needed
+         checkpoint["step"] = self.global_step
+         checkpoint["model_config"] = self.config
+
+     def on_load_checkpoint(self, checkpoint):
+         # Restore additional info if needed. Lightning restores global_step
+         # on its own; it is a read-only property, so assigning it here would
+         # raise an AttributeError.
+         self.config = checkpoint.get("model_config", self.config)
+
+
+ # training script: train the model, checkpointing as we go, then save it
+ if __name__ == "__main__":
+     import os
+     from pytorch_lightning.callbacks import ModelCheckpoint
+
+     # load parameters from config file
+     with open("config_smollm2_135.yaml", "r") as file:
+         config = yaml.safe_load(file)
+     max_steps = 5000  # Total training steps
+
+     # Create checkpoint directory if it doesn't exist
+     checkpoint_dir = "checkpoints"
+     os.makedirs(checkpoint_dir, exist_ok=True)
+
+     # Checkpoint callback
+     checkpoint_callback = ModelCheckpoint(
+         dirpath=checkpoint_dir,
+         filename="model-step={step}",
+         save_top_k=-1,  # Save all checkpoints
+         every_n_train_steps=500,  # Save every 500 steps
+         save_weights_only=False,  # Save the full model state
+     )
+
+     # load dataset (streaming, so nothing is downloaded up front)
+     dataset = load_dataset(
+         "HuggingFaceTB/smollm-corpus", "cosmopedia-v2", streaming=True
+     )
+     train_dataset = dataset["train"]
+
+     # Create DataLoader
+     train_loader = DataLoader(
+         train_dataset,
+         batch_size=4,  # Small batch size for testing
+         collate_fn=collate_fn,
+         num_workers=2,
+     )
+
+     # create model
+     model = SmolLMModule(config, learning_rate=1e-4)
+
+     # progress bar
+     progress_bar = RichProgressBar(leave=False, refresh_rate=1, console_kwargs=None)
+
+     # Find the latest checkpoint if one exists; filenames look like
+     # "model-step=500.ckpt", so parse the number after "step=".
+     latest_checkpoint = None
+     if os.path.exists(checkpoint_dir):
+         checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith(".ckpt")]
+         if checkpoints:
+             latest_checkpoint = os.path.join(
+                 checkpoint_dir,
+                 sorted(
+                     checkpoints,
+                     key=lambda x: int(x.split("step=")[-1].split(".")[0]),
+                 )[-1],
+             )
+             print(f"Resuming from checkpoint: {latest_checkpoint}")
+
+     # create trainer
+     trainer = pl.Trainer(
+         max_steps=max_steps,
+         accelerator="gpu",
+         devices=1,
+         precision="bf16-mixed",
+         callbacks=[
+             LearningRateMonitor(logging_interval="step"),
+             progress_bar,
+             checkpoint_callback,
+         ],
+         log_every_n_steps=1,
+         enable_progress_bar=True,
+         enable_model_summary=True,
+     )
+
+     # train model, resuming from the latest checkpoint if one was found
+     if latest_checkpoint:
+         trainer.fit(model, train_loader, ckpt_path=latest_checkpoint)
+     else:
+         trainer.fit(model, train_loader)
+
+     # Save final model and tokenizer
+     if trainer.is_global_zero:  # Only save on main process
+         output_dir = "final_model"
+         os.makedirs(output_dir, exist_ok=True)
+         model.model.save_pretrained(os.path.join(output_dir, "model"))
+         tokenizer.save_pretrained(os.path.join(output_dir, "tokenizer"))
requirements.txt ADDED
File without changes
train_smollm2.py ADDED
@@ -0,0 +1,12 @@
+ from transformers import AutoModelForCausalLM
+
+ # from datasets import load_dataset
+ # dataset = load_dataset("HuggingFaceTB/smollm-corpus", "cosmopedia-v2")
+
+ # use the tokenizer from https://huggingface.co/HuggingFaceTB/cosmo2-tokenizer
+ # from transformers import AutoTokenizer
+ # tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
+
+ # Load the reference model to inspect its architecture
+ model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
+
+ print(model)
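+
+ # Quick sanity check: with tied embeddings the 135M config should come to
+ # roughly 134.5M parameters.
+ n_params = sum(p.numel() for p in model.parameters())
+ print(f"{n_params / 1e6:.1f}M parameters")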