nroggendorff committed on
Commit
d72e6ae
1 Parent(s): 928e52a

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +128 -0
train.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ import trl
5
+
6
+ from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, TrainingArguments, PreTrainedTokenizerFast
7
+ from datasets import load_dataset
8
+ from tokenizers import ByteLevelBPETokenizer
9
+
10
# --- Sequence / optimisation settings ---------------------------------------
# Used both as the model's max_position_embeddings and SFTTrainer's max_seq_length.
MAX_SEQ_LENGTH = 128
# Passed to TrainingArguments as per_device_train_batch_size.
BATCH_SIZE = 256
# Passed to TrainingArguments as num_train_epochs.
EPOCHS = 8
LEARNING_RATE = 1e-4
# Passed to TrainingArguments as fp16.
FP16 = True

# --- Model / tokenizer sizing -----------------------------------------------
# hidden_size is FACTOR, intermediate_size is FACTOR * 2, and the layer and
# attention-head counts are both max(1, FACTOR // 64).
FACTOR = 2
# Target vocabulary size handed to the BPE tokenizer trainer.
VOCAB_SIZE = 3200

# --- Data source and upload target ------------------------------------------
# Dataset identifier given to load_dataset.
INPUT_DATASET = "nroggendorff/elephant"
# Repository id given to push_to_hub for both the model and the tokenizer.
OUTPUT_REPO = "smallama"
19
+
20
def load_data():
    """Load and return the "train" split of the dataset named by INPUT_DATASET."""
    return load_dataset(INPUT_DATASET, split="train")
23
+
24
def create_tokenizer(training_corpus):
    """Train a byte-level BPE tokenizer and wrap it for use with transformers.

    Args:
        training_corpus: Iterator of text (or batches of text) handed to
            ``ByteLevelBPETokenizer.train_from_iterator``.

    Returns:
        A ``PreTrainedTokenizerFast`` built around the trained tokenizer.
    """
    # Tokens reserved in the vocabulary during training.
    reserved_tokens = ["<s>", "<pad>", "</s>", "<unk>", "<mask>", "<|user|>", "<|bot|>", "<|end|>"]

    bpe = ByteLevelBPETokenizer()
    bpe.train_from_iterator(
        training_corpus,
        vocab_size=VOCAB_SIZE,
        min_frequency=2,
        special_tokens=reserved_tokens,
    )

    # The fast wrapper takes the underlying tokenizer object held by the
    # ByteLevelBPETokenizer instance.
    return PreTrainedTokenizerFast(tokenizer_object=bpe._tokenizer)
35
+
36
def get_training_corpus(dataset, batch_size=1000):
    """Yield the dataset's "text" column in consecutive batches.

    Args:
        dataset: Object supporting ``len()`` and slice indexing, where a slice
            can in turn be indexed with ``"text"``.
        batch_size: Number of rows per yielded batch. Defaults to 1000, the
            value that was previously hard-coded.

    Yields:
        ``dataset[start : start + batch_size]["text"]`` for each batch; the
        final batch may be shorter.

    Raises:
        ValueError: If ``batch_size`` is not positive. Because this is a
            generator, the error surfaces when iteration starts.
    """
    # A zero step would make range() fail with an unrelated message, and a
    # negative step would silently yield nothing, so reject both up front.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    for start in range(0, len(dataset), batch_size):
        yield dataset[start : start + batch_size]["text"]
39
+
40
def format_prompts(examples, tokenizer):
    """Re-render each raw conversation string through the tokenizer's chat template.

    Args:
        examples: Mapping with a ``'text'`` entry holding an iterable of raw
            conversation strings whose turns are separated by ``<|end|>``.
        tokenizer: Object providing ``apply_chat_template``.

    Returns:
        ``{"text": [...]}`` with one templated string per input string.
    """

    def _render(raw_text):
        segments = raw_text.split('<|end|>')
        turns = []
        # Even-indexed segments are user turns and odd-indexed ones are the
        # replies; zip drops a trailing segment that has no partner.
        for user_part, bot_part in zip(segments[0::2], segments[1::2]):
            turns.append({"role": "user", "content": user_part.replace("<|user|>", "")})
            turns.append({"role": "assistant", "content": bot_part.replace("<|bot|>", "")})
        return tokenizer.apply_chat_template(turns, tokenize=False)

    return {"text": [_render(raw_text) for raw_text in examples['text']]}
53
+
54
def create_model(tokenizer):
    """Build a freshly initialised Llama causal LM sized from FACTOR.

    Args:
        tokenizer: Tokenizer supplying the vocabulary size (via ``len()``) and
            the pad/bos/eos token ids.

    Returns:
        A new ``LlamaForCausalLM`` created from the assembled ``LlamaConfig``.
    """
    config = LlamaConfig(
        # len(tokenizer) counts added tokens as well, whereas
        # tokenizer.vocab_size reports only the base vocabulary. Sizing the
        # embedding from the full length keeps every token id in range; the
        # two values are equal when no tokens were added.
        vocab_size=len(tokenizer),
        hidden_size=FACTOR,
        intermediate_size=FACTOR * 2,
        # Layer and head counts scale with FACTOR but never drop below 1.
        num_hidden_layers=max(1, FACTOR // 64),
        num_attention_heads=max(1, FACTOR // 64),
        max_position_embeddings=MAX_SEQ_LENGTH,
        rms_norm_eps=1e-6,
        initializer_range=0.02,
        use_cache=True,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        tie_word_embeddings=False,
    )

    return LlamaForCausalLM(config)
73
+
74
def configure_tokenizer(tokenizer):
    """Register special tokens, role-token ids and the chat template in place.

    Args:
        tokenizer: Tokenizer to mutate. It must provide ``add_special_tokens``
            and ``convert_tokens_to_ids``.

    Returns:
        None. The tokenizer gains ``user_token_id``, ``assistant_token_id``
        and ``chat_template`` attributes.
    """
    tokenizer.add_special_tokens(
        dict(
            bos_token="<s>",
            eos_token="</s>",
            unk_token="<unk>",
            pad_token="<pad>",
            mask_token="<mask>",
            additional_special_tokens=["<|user|>", "<|bot|>", "<|end|>"],
        )
    )

    # Look the role markers up only after the special tokens are registered.
    tokenizer.user_token_id = tokenizer.convert_tokens_to_ids("<|user|>")
    tokenizer.assistant_token_id = tokenizer.convert_tokens_to_ids("<|bot|>")

    # The template is one string; adjacent literals are used purely for
    # readability and concatenate to exactly the single-line form.
    tokenizer.chat_template = (
        "{{ bos_token }}"
        "{% for message in messages %}"
        "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
        "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
        "{% endif %}"
        "{% if message['role'] == 'user' %}"
        "{{ '<|user|>\n' + message['content'] + '<|end|>\n' }}"
        "{% elif message['role'] == 'assistant' %}"
        "{{ '<|bot|>\n' + message['content'] + '<|end|>\n' }}"
        "{% else %}"
        "{{ raise_exception('Only user and assistant roles are supported!') }}"
        "{% endif %}"
        "{% endfor %}"
        "{{ eos_token }}"
    )
90
+
91
def train_model(model, tokenizer, dataset):
    """Fine-tune the model with TRL's SFTTrainer and upload the results.

    Args:
        model: Model to train.
        tokenizer: Tokenizer used to template the data and given to the trainer.
        dataset: Dataset with a ``'text'`` column; it is re-rendered through
            ``format_prompts`` before training.

    Returns:
        None. The trained model and tokenizer are pushed to OUTPUT_REPO.
    """

    def _apply_template(batch):
        return format_prompts(batch, tokenizer)

    training_args = TrainingArguments(
        output_dir="model",
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        fp16=FP16,
        optim="sgd",
    )

    templated = dataset.map(_apply_template, batched=True)

    trainer = trl.SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=templated,
        dataset_text_field='text',
        max_seq_length=MAX_SEQ_LENGTH,
    )
    trainer.train()

    # Upload whatever the trainer holds after training, model first.
    trainer.model.push_to_hub(OUTPUT_REPO)
    trainer.tokenizer.push_to_hub(OUTPUT_REPO)
117
+
118
def main():
    """Run the pipeline: load data, build tokenizer and model, then train."""
    dataset = load_data()

    # The tokenizer is trained on batches of the dataset's text, then
    # configured in place with special tokens and the chat template.
    tokenizer = create_tokenizer(get_training_corpus(dataset))
    configure_tokenizer(tokenizer)

    train_model(create_model(tokenizer), tokenizer, dataset)
125
+
126
if __name__ == "__main__":
    main()
    # Raised unconditionally once main() has returned, so a successful run
    # still terminates with an exception.
    # NOTE(review): presumably intentional, to make the hosting environment
    # stop rather than rerun the script; confirm before changing.
    raise RuntimeError("The script is finished.")