import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments
from datasets import load_dataset
from trl import SFTTrainer
from peft import prepare_model_for_kbit_training, PeftModel

dataset = load_dataset("csv", data_files="nowhere_training_input.csv", delimiter=";", split="train")
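# The CSV is expected to provide "prompt" and "completion" columns
# (separated by ";"), as consumed by formatting_func below.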

if torch.cuda.is_available():
    print("Cuda is available")

base_model_id = "abacaj/phi-2-super"
base_peft_id = "./results"
output_dir = "./results_phi-2-super"

tokenizer = AutoTokenizer.from_pretrained(base_model_id)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    print("pad_token was missing and has been set to eos_token")

# Configuration to load the model with 4-bit NF4 quantization, computing in bfloat16
bnb_config = BitsAndBytesConfig(load_in_4bit=True,
                                bnb_4bit_quant_type='nf4',
                                #bnb_4bit_compute_dtype='float16',
                                bnb_4bit_compute_dtype=torch.bfloat16,
                                bnb_4bit_use_double_quant=False)

base_model = AutoModelForCausalLM.from_pretrained(base_model_id, attn_implementation="flash_attention_2", quantization_config=bnb_config, torch_dtype="auto")
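# Resume the LoRA adapter previously saved under ./results on top of the
# quantized base model; is_trainable=True keeps the adapter weights trainable.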
model = PeftModel.from_pretrained(base_model, base_peft_id, is_trainable=True)
print(model)

# Gradient checkpointing is enabled via TrainingArguments below, so the manual
# call is not needed here.
# model.gradient_checkpointing_enable()

# Freeze base model layers and cast layer norms to fp32
# model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

training_args = TrainingArguments(
    output_dir=output_dir,  # Output directory for checkpoints and predictions
    overwrite_output_dir=True, # Overwrite the content of the output directory
    per_device_train_batch_size=2,  # Batch size for training
    per_device_eval_batch_size=2,  # Batch size for evaluation
    gradient_accumulation_steps=5,  # Number of micro-batches accumulated per optimizer step
    gradient_checkpointing=True,   # Enable gradient checkpointing
    gradient_checkpointing_kwargs={"use_reentrant": False},
    warmup_steps=10,  # Number of warmup steps
    #max_steps=1000,  # Total number of training steps
    num_train_epochs=100,  # Number of training epochs
    learning_rate=5e-5,  # Learning rate
    weight_decay=0.01,  # Weight decay
    optim="paged_adamw_8bit", #Keep the optimizer state and quantize it
    bf16=True, #Use mixed precision training
    #For logging and saving
    logging_dir='./logs',
    logging_strategy="epoch",
    logging_steps=10,
    save_strategy="epoch",
    save_steps=10,
    save_total_limit=2,  # Limit the total number of checkpoints
    evaluation_strategy="epoch",
    eval_steps=10,
    load_best_model_at_end=True, # Load the best model at the end of training
    lr_scheduler_type="linear",
)
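
# Effective batch size: per_device_train_batch_size (2) * gradient_accumulation_steps (5)
# = 10 sequences per optimizer step on each device.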

def formatting_func(data):
    # Render each (prompt, completion) pair with the model's chat template.
    # Manual alternative: f"[INST] {data['prompt']} [/INST]{data['completion']}{tokenizer.eos_token} "
    chat = [
        { "role": "user", "content": data['prompt'] },
        { "role": "assistant", "content": data['completion'] },
    ]

    text = tokenizer.apply_chat_template(chat, tokenize=False)
    data['text'] = text

    return data

dataset = dataset.map(formatting_func)
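# Each row now carries a "text" column, which SFTTrainer consumes via dataset_text_field.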

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,  # pass the tokenizer so the pad_token fix above is used
    train_dataset=dataset,
    eval_dataset=dataset,  # note: evaluation runs on the training set (no held-out split)
    args=training_args,
    max_seq_length=1024,
    packing=True,  # pack multiple short examples into each 1024-token sequence
    dataset_text_field="text",
    neftune_noise_alpha=5,  # NEFTune: add noise to embeddings during fine-tuning
)

model.config.use_cache = False  # silence the warnings. Please re-enable for inference!

start_time = time.time()  # Record the start time
trainer.train()
end_time = time.time()  # Record the end time

training_time = end_time - start_time  # Calculate total training time

trainer.save_model(output_dir)
print(f"Training completed in {training_time} seconds.")