Crystalcareai committed
Commit
6ff6eb4
1 Parent(s): 21a96f6

Create train-dora-alpaca.py

Files changed (1)
  1. train-dora-alpaca.py +162 -0
train-dora-alpaca.py ADDED
@@ -0,0 +1,162 @@
+ import torch
+ torch.backends.cuda.matmul.allow_tf32 = True
+ import random
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline, AutoConfig
+ from datasets import load_dataset
+ from transformers import TrainingArguments
+ from trl import SFTTrainer
+ from peft import LoraConfig
+ # from accelerate import infer_auto_device_map, init_empty_weights, dispatch_model
+ from torch.nn import CrossEntropyLoss
+
+ import time
+ random_seed = 42
+ torch.manual_seed(random_seed)
+ random.seed(random_seed)
+
+ dataset = load_dataset("Vezora/Tested-22k-Python-Alpaca", split="train")
+
+ def chatml_format(example):
+     """Format the dataset for training, accounting for empty columns."""
+     return {
+         "instruction": example['instruction'] if 'instruction' in example else " \n",
+         "input": example['input'] if 'input' in example else " \n",
+         "system": example['system'] if 'system' in example else " \n",
+         "output": example['output'] if 'output' in example else " \n",
+     }
+
+ # Format dataset
+ dataset = dataset.map(chatml_format, remove_columns=dataset.column_names)
+
+ n_ahead_talk_global = 4
+ n_passes_global = 2
+ n_ahead_global = 8
+ n_examples = 0
+
+ def model_init(params):
+     """Load the Quiet-Star model and apply thought/talk-head settings from `params` (or the globals above)."""
+     original = False
+     if params is None:
+         params = {}
+     else:
+         params = params.params
+     # save params to file
+     n_ahead = params.get("n_ahead", n_ahead_global if not original else 1)
+     n_ahead_talk = params.get("n_ahead_talk", n_ahead_talk_global if not original else 1)
+     n_passes = params.get("n_passes", n_passes_global if not original else 1)
+     gumbel_temperature = params.get("gumbel_temperature", 1)
+     use_start_thought_token = params.get("use_start_thought_token", True)
+     use_end_thought_token = params.get("use_end_thought_token", True)
+     include_policy_loss = params.get("include_policy_loss", True)
+     gumbel_detach = params.get("gumbel_detach", True)
+     merged_talk_heads = params.get("merged_talk_heads", True)
+     residual_think_head = params.get("residual_think_head", False)
+     optimize_lm_head_only_at_start = params.get("optimize_lm_head_only_at_start", False)
+
+     model_id = "Crystalcareai/Quiet-Star-Custom"
+     tokenizer_id = model_id
+     print("Loading model")
+
+     model = AutoModelForCausalLM.from_pretrained(
+         model_id,
+         torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
+         max_thoughts=n_ahead + n_ahead_talk + 1,
+         merged_talk_heads=merged_talk_heads,
+         merged_lm_and_talk_heads=False,
+         merged_lm_and_think_heads=True,
+         use_concat_talk_head=True,
+         use_shallow_think=True,
+         use_shallow_talk=False,
+         use_complex_think_head=False,
+         use_complex_talk_head=True,
+         use_weighted_talk_head=True,
+         trust_remote_code=True,
+         device_map="auto",
+     )
+     print("Loaded model")
+
+     tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, truncation=True, padding_side="right")
+     tokenizer.pad_token_id = tokenizer.eos_token_id
+
+     # Register the thought-boundary special tokens and grow the embeddings to match.
+     special_tokens_to_add = []
+     if model.use_start_thought_token:
+         special_tokens_to_add.append("<|startthought|>")
+     if model.use_end_thought_token:
+         special_tokens_to_add.append("<|endthought|>")
+     if special_tokens_to_add:
+         tokenizer.add_special_tokens({"additional_special_tokens": special_tokens_to_add})
+         model.resize_token_embeddings(len(tokenizer))
+     model.tokenizer = tokenizer
+     for name, module in model.named_modules():
+         if "embed" in name:
+             print(module, flush=True)
+
+     model.gumbel_detach = gumbel_detach
+     model.include_policy_loss = include_policy_loss
+     model.use_end_thought_token = use_end_thought_token
+     model.use_start_thought_token = use_start_thought_token
+     model.n_ahead = n_ahead
+     model.n_ahead_talk = n_ahead_talk
+     model.n_passes = n_passes
+     model.residual_think_head = residual_think_head
+     model.optimize_lm_head_only_at_start = optimize_lm_head_only_at_start
+     model.gumbel_temperature = gumbel_temperature
+     model.original_mode = original
+     model.config_params = params
+     model.run_start = int(time.time())
+     model.train()
+     return model
+
+ max_seq_length = 1024
+ run_id = int(time.time())
+ training_args = TrainingArguments(
+     output_dir="./out",
+     num_train_epochs=3,
+     per_device_train_batch_size=1,
+     gradient_checkpointing=False,
+     gradient_accumulation_steps=8,
+     optim="lion_32bit",
+     logging_steps=1,
+     save_strategy="steps",
+     save_steps=300,
+     max_steps=1000,
+     bf16=True,
+     tf32=False,
+     learning_rate=6e-05,
+     max_grad_norm=0.3,
+     warmup_ratio=0.06,
+     lr_scheduler_type="cosine",
+     push_to_hub=False,
+     report_to="wandb",
+ )
+
+ peft_config = LoraConfig(
+     r=16,  # Choose any number > 0; suggested 8, 16, 32, 64, 128
+     target_modules=["q_proj", "k_proj"],
+     lora_alpha=16,
+     lora_dropout=0,  # Supports any value, but 0 is optimized
+     bias="none",  # "none" trains no bias parameters (the optimized setting)
+     use_dora=True,  # Enable the DoRA (weight-decomposed LoRA) method
+ )
+
+
+ torch.autograd.set_detect_anomaly(True)
+
+ # Set the device for each process
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ # torch.cuda.set_device(device)
+
+ model = model_init(None)  # Initialize the model
+
+ tokenizer = model.tokenizer
+
+ trainer = SFTTrainer(
+     args=training_args,
+     train_dataset=dataset,
+     model=model,
+     tokenizer=tokenizer,
+     max_seq_length=max_seq_length,
+     dataset_text_field="output",
+     peft_config=peft_config,
+ )
+
+ trainer.train()
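
A minimal usage sketch (hypothetical, not part of the committed file): the script calls model_init(None), so the defaults n_ahead_global, n_ahead_talk_global, and n_passes_global apply, but because the function also accepts any object exposing a .params dict, per-run overrides could be supplied like this:

    from types import SimpleNamespace

    # Hypothetical override object; model_init reads params.params and falls back
    # to the module-level globals for any key that is absent from the dict.
    overrides = SimpleNamespace(params={"n_ahead": 4, "n_ahead_talk": 2, "n_passes": 1})
    model = model_init(overrides)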