RASMUS committed
Commit
d184967
1 Parent(s): 60cda28

Create train_unsloth_7b.py

Files changed (1)
  1. train_unsloth_7b.py +122 -0
train_unsloth_7b.py ADDED
from unsloth import FastLlamaModel
import torch
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from transformers import TrainingArguments
from datasets import load_from_disk
import math
import wandb
import os

max_seq_length = 2048  # Can change to any value <= 4096
dtype = None  # None for auto detection. Float16 for Tesla T4/V100, Bfloat16 for Ampere+
load_in_4bit = True  # Use 4bit quantization to reduce memory usage. Can be False.


revisions = [("250k", "8ee454fe392a0267c3dee21323b5cac233d67441"),
             ("500k", "12d3eec2d02533226c9cff719d4278967574ffcd"),
             ("750k", "845b8c6d8499c0e8fea0b8e5480d72e700385820"),
             ("1000k", "53669200ad7a6a6f1ac6a73e54c9e54c1d834a17")]


#for revision in revisions:
model, tokenizer = FastLlamaModel.from_pretrained(
    model_name = "Finnish-NLP/llama-7b-finnish",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    revision = '53669200ad7a6a6f1ac6a73e54c9e54c1d834a17'
)

tokenizer.clean_up_tokenization_spaces = True
tokenizer.add_tokens(["<|alku|>", "<PAD>", "<|ihminen|>", "<|avustaja|>"])
tokenizer.pad_token = "<PAD>"
tokenizer.add_special_tokens({'eos_token': '<|loppu|>'})
tokenizer.add_tokens('\n', special_tokens=True)
tokenizer.add_eos_token = True
model.resize_token_embeddings(new_num_tokens=len(tokenizer))
model.config.eos_token_id = tokenizer.eos_token_id
print(model.config.eos_token_id)
assert tokenizer.pad_token_id != tokenizer.eos_token_id
print(tokenizer.padding_side)
print(tokenizer.add_bos_token)
print(model)
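
# The added tokens suggest a chat-style prompt format along these lines (an assumption
# for illustration, not defined anywhere in this commit):
#   <|alku|><|ihminen|> {user message}
#   <|avustaja|> Vastauksesi: {assistant reply}<|loppu|>
# which matches the response_template used for the completion-only collator further down.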


model = FastLlamaModel.get_peft_model(
    model,
    r = 32,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"],
    lora_alpha = 32,
    lora_dropout = 0,
    bias = "none",
    use_gradient_checkpointing = True,
    modules_to_save = ["lm_head", "embed_tokens"],
    random_state = 3407,
    max_seq_length = max_seq_length,
    use_rslora = True
)


dataset = load_from_disk("deepl_kaannetyt_combined")
dataset = dataset.train_test_split(test_size=0.02)

bs = 2
ga = 4
epochs = 3
train_steps = math.ceil(len(dataset["train"]) / bs / ga * epochs)
print(train_steps)
eval_steps = math.ceil(train_steps / 10)
print(eval_steps)
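
# Worked example with a hypothetical dataset size (for illustration only): with
# 40,000 training rows, batch size 2 and gradient accumulation 4, one epoch is
# ceil(40000 / 2 / 4) = 5000 optimizer steps, so 3 epochs gives 15000 steps and
# evaluation roughly every 1500 steps.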


try:
    wandb.finish()
except Exception as e:
    wandb.init()

response_template = "\n<|avustaja|> Vastauksesi:"
response_template_ids = tokenizer.encode(response_template, add_special_tokens=False)


collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer, mlm=False)
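# DataCollatorForCompletionOnlyLM masks the labels up to and including the response
# template (sets them to -100), so the loss is computed only on the assistant's reply
# rather than on the full prompt.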

trainer = SFTTrainer(
    model = model,
    train_dataset = dataset["train"],
    eval_dataset = dataset["test"],
    dataset_text_field = "text",
    data_collator = collator,
    max_seq_length = max_seq_length,
    tokenizer = tokenizer,
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        per_device_eval_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 50,
        max_steps = train_steps,
        report_to = "wandb",
        eval_steps = eval_steps,
        evaluation_strategy = "steps",
        save_strategy = "steps",
        learning_rate = 2e-5,
        fp16 = not torch.cuda.is_bf16_supported(),
        bf16 = torch.cuda.is_bf16_supported(),
        logging_steps = 5,
        optim = "adamw_8bit",
        weight_decay = 0.001,
        lr_scheduler_type = "cosine",
        seed = 3407,
        output_dir = "llama7b-finnish-instruct-v0.1",
    ),
)

wandb.login()

trainer.train()
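
# Not part of the original commit: a minimal sketch of persisting the result after
# training. With a PEFT model, save_pretrained() stores the LoRA adapter weights
# (plus the modules_to_save layers); the output path below is hypothetical.
model.save_pretrained("llama7b-finnish-instruct-v0.1-lora")
tokenizer.save_pretrained("llama7b-finnish-instruct-v0.1-lora")
wandb.finish()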