Spaces:
Runtime error
Runtime error
Update train.py
Browse files
train.py
CHANGED
@@ -16,6 +16,7 @@ VOCAB_SIZE = 32000
|
|
16 |
INPUT_DATASET = "HuggingFaceTB/smollm-corpus"
|
17 |
INSTRUCT_DATASET = "nroggendorff/elephant"
|
18 |
OUTPUT_REPO = "smallama"
|
|
|
19 |
FP16 = False
|
20 |
WARMUP_STEPS = 0
|
21 |
DECAY = 0
|
@@ -24,9 +25,9 @@ PUSH_TO_HUB = True
|
|
24 |
|
25 |
def load_data():
|
26 |
pretrain = load_dataset(INPUT_DATASET, "cosmopedia-v2", split="train", streaming=True)
|
27 |
-
pretrain = Dataset.from_generator(lambda: pretrain.take(int(3e+
|
28 |
instruct = load_dataset(INSTRUCT_DATASET, split="train", streaming=True)
|
29 |
-
instruct = Dataset.from_generator(lambda: instruct.take(int(5e+
|
30 |
dataset_dict = DatasetDict({
|
31 |
'pretrain': pretrain,
|
32 |
'instruct': instruct
|
@@ -91,6 +92,10 @@ def create_model(tokenizer):
|
|
91 |
model = LlamaForCausalLM(config)
|
92 |
return model
|
93 |
|
|
|
|
|
|
|
|
|
94 |
def configure_tokenizer(tokenizer):
|
95 |
special_tokens = {
|
96 |
"bos_token": "<s>",
|
@@ -145,7 +150,10 @@ def train_model(model, tokenizer, dataset, push, isinst):
|
|
145 |
trained_tokenizer = trainer.tokenizer
|
146 |
|
147 |
if push:
|
148 |
-
|
|
|
|
|
|
|
149 |
msg = str(train.training_loss)
|
150 |
trained_model.push_to_hub(repo_id, commit_message=msg, force=True)
|
151 |
trained_tokenizer.push_to_hub(repo_id, commit_message=msg, force=True)
|
@@ -153,17 +161,20 @@ def train_model(model, tokenizer, dataset, push, isinst):
|
|
153 |
trained_model.save_pretrained("model")
|
154 |
trained_tokenizer.save_pretrained("tokenizer")
|
155 |
|
156 |
-
def main(push_to_hub=True):
|
157 |
dataset = load_data()
|
158 |
pretrain = dataset['pretrain']
|
159 |
instruct = dataset['instruct']
|
160 |
training_corpus = get_training_corpus(dataset)
|
161 |
tokenizer = create_tokenizer(training_corpus)
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
|
|
|
|
|
|
166 |
|
167 |
if __name__ == "__main__":
|
168 |
-
main(PUSH_TO_HUB)
|
169 |
raise RuntimeError("The script is finished.")
|
|
|
16 |
INPUT_DATASET = "HuggingFaceTB/smollm-corpus"
|
17 |
INSTRUCT_DATASET = "nroggendorff/elephant"
|
18 |
OUTPUT_REPO = "smallama"
|
19 |
+
INSTRUCT_FINETUNE_BOOL = False
|
20 |
FP16 = False
|
21 |
WARMUP_STEPS = 0
|
22 |
DECAY = 0
|
|
|
25 |
|
26 |
def load_data():
|
27 |
pretrain = load_dataset(INPUT_DATASET, "cosmopedia-v2", split="train", streaming=True)
|
28 |
+
pretrain = Dataset.from_generator(lambda: pretrain.take(int(3e+5)))
|
29 |
instruct = load_dataset(INSTRUCT_DATASET, split="train", streaming=True)
|
30 |
+
instruct = Dataset.from_generator(lambda: instruct.take(int(5e+5)))
|
31 |
dataset_dict = DatasetDict({
|
32 |
'pretrain': pretrain,
|
33 |
'instruct': instruct
|
|
|
92 |
model = LlamaForCausalLM(config)
|
93 |
return model
|
94 |
|
95 |
+
def load_model():
|
96 |
+
model = LlamaForCausalLM.from_pretrained(OUTPUT_REPO)
|
97 |
+
return model
|
98 |
+
|
99 |
def configure_tokenizer(tokenizer):
|
100 |
special_tokens = {
|
101 |
"bos_token": "<s>",
|
|
|
150 |
trained_tokenizer = trainer.tokenizer
|
151 |
|
152 |
if push:
|
153 |
+
if INSTRUCT_FINETUNE_BOOL:
|
154 |
+
repo_id = OUTPUT_REPO + "-it"
|
155 |
+
else:
|
156 |
+
repo_id = OUTPUT_REPO
|
157 |
msg = str(train.training_loss)
|
158 |
trained_model.push_to_hub(repo_id, commit_message=msg, force=True)
|
159 |
trained_tokenizer.push_to_hub(repo_id, commit_message=msg, force=True)
|
|
|
161 |
trained_model.save_pretrained("model")
|
162 |
trained_tokenizer.save_pretrained("tokenizer")
|
163 |
|
164 |
+
def main(push_to_hub=True, is_inst_finetune):
|
165 |
dataset = load_data()
|
166 |
pretrain = dataset['pretrain']
|
167 |
instruct = dataset['instruct']
|
168 |
training_corpus = get_training_corpus(dataset)
|
169 |
tokenizer = create_tokenizer(training_corpus)
|
170 |
+
if is_inst_finetune:
|
171 |
+
configure_tokenizer(tokenizer)
|
172 |
+
model = load_model()
|
173 |
+
train_model(model, tokenizer, instruct, push_to_hub, True)
|
174 |
+
else:
|
175 |
+
model = create_model(tokenizer)
|
176 |
+
train_model(model, tokenizer, pretrain, push_to_hub, False)
|
177 |
|
178 |
if __name__ == "__main__":
|
179 |
+
main(PUSH_TO_HUB, INSTRUCT_FINETUNE_BOOL)
|
180 |
raise RuntimeError("The script is finished.")
|