Spaces:
Paused
Paused
# Copyright 2020-2025 The HuggingFace Team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import argparse
import os

from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    logging,
    set_seed,
)
from trl import SFTTrainer
from trl.trainer import ConstantLengthDataset
""" | |
Fine-tune Llama-7B with LoRA adapters on the StackExchange paired dataset.
""" | |
def get_args():
    """
    Parse the command-line options of the fine-tuning script.

    Every option is declared in the ``option_specs`` table below, in the same
    order as before, so ``--help`` output and defaults are unchanged.

    Returns:
        argparse.Namespace: parsed values for model/data location, dataset
        splitting, optimisation, precision, checkpointing and logging.
    """
    option_specs = [
        # Model and dataset location.
        ("--model_path", dict(type=str, default="")),
        ("--dataset_name", dict(type=str, default="lvwerra/stack-exchange-paired")),
        ("--subset", dict(type=str, default="data/finetune")),
        ("--split", dict(type=str, default="train")),
        ("--size_valid_set", dict(type=int, default=4000)),
        ("--streaming", dict(action="store_true")),
        ("--shuffle_buffer", dict(type=int, default=5000)),
        # Sequence packing and optimisation.
        ("--seq_length", dict(type=int, default=1024)),
        ("--max_steps", dict(type=int, default=10000)),
        ("--batch_size", dict(type=int, default=4)),
        ("--gradient_accumulation_steps", dict(type=int, default=1)),
        ("--eos_token_id", dict(type=int, default=49152)),
        ("--learning_rate", dict(type=float, default=1e-4)),
        ("--lr_scheduler_type", dict(type=str, default="cosine")),
        ("--num_warmup_steps", dict(type=int, default=100)),
        ("--weight_decay", dict(type=float, default=0.05)),
        # Runtime / precision.
        ("--local_rank", dict(type=int, default=0)),
        ("--fp16", dict(action="store_true", default=False)),
        ("--bf16", dict(action="store_true", default=False)),
        ("--gradient_checkpointing", dict(action="store_true", default=False)),
        ("--seed", dict(type=int, default=0)),
        ("--num_workers", dict(type=int, default=None)),
        # Output, logging and checkpoint cadence.
        ("--output_dir", dict(type=str, default="./checkpoints")),
        ("--log_freq", dict(default=1, type=int)),
        ("--eval_freq", dict(default=1000, type=int)),
        ("--save_freq", dict(default=1000, type=int)),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()
def chars_token_ratio(dataset, tokenizer, nb_examples=400):
    """
    Estimate the average number of characters per token in the dataset.

    Formats up to ``nb_examples`` samples from the start of ``dataset`` with
    ``prepare_sample_text``, then divides the total number of characters by
    the total number of tokens the tokenizer produces for those texts.

    Args:
        dataset: iterable of samples (a regular or streaming ``datasets`` dataset).
        tokenizer: Hugging Face tokenizer used to count tokens.
        nb_examples: maximum number of samples to inspect.

    Returns:
        float: characters per token over the inspected samples.

    Raises:
        ValueError: if the inspected samples yield no tokens at all (for
            example an empty dataset or ``nb_examples <= 0``). Previously this
            surfaced as an unexplained ZeroDivisionError.
    """
    total_characters, total_tokens = 0, 0
    # zip with range() stops after nb_examples samples or when the dataset
    # runs out, whichever comes first.
    for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):
        text = prepare_sample_text(example)
        total_characters += len(text)
        if tokenizer.is_fast:
            # Fast tokenizers expose the encoding's token list directly; this
            # call uses the tokenizer's default special-token handling.
            total_tokens += len(tokenizer(text).tokens())
        else:
            total_tokens += len(tokenizer.tokenize(text))
    if total_tokens == 0:
        raise ValueError(
            "Cannot estimate the character/token ratio: the inspected samples produced no tokens "
            f"(nb_examples={nb_examples}). Is the dataset empty?"
        )
    return total_characters / total_tokens
def print_trainable_parameters(model):
    """
    Print how many of the model's parameters are trainable.

    Walks ``model.named_parameters()`` once and collects each parameter's
    element count together with its ``requires_grad`` flag. It then prints the
    trainable count, the overall count and the trainable percentage on a
    single line.
    """
    sizes_and_flags = [(tensor.numel(), tensor.requires_grad) for _, tensor in model.named_parameters()]
    total = sum(size for size, _ in sizes_and_flags)
    trainable = sum(size for size, needs_grad in sizes_and_flags if needs_grad)
    print(f"trainable params: {trainable} || all params: {total} || trainable%: {100 * trainable / total}")
def prepare_sample_text(example):
    """
    Format one dataset row as a question/answer training text.

    Reads the ``question`` field and the ``response_j`` answer field of
    ``example`` and returns them as ``Question: ...`` and ``Answer: ...``,
    separated by a blank line.
    """
    question = example["question"]
    answer = example["response_j"]
    return f"Question: {question}\n\nAnswer: {answer}"
def create_datasets(tokenizer, args):
    """
    Load the fine-tuning split and wrap its train/validation parts as packed datasets.

    With ``args.streaming``, the first ``args.size_valid_set`` samples become
    the validation set, and the remaining stream is shuffled with a buffer of
    ``args.shuffle_buffer`` samples. Without streaming, 0.5% of the split is
    held out at random; ``args.size_valid_set`` is not used on that path.

    Both parts are wrapped in ``ConstantLengthDataset`` using
    ``prepare_sample_text``. The characters-per-token ratio is estimated on
    the training part. Only the training dataset is infinite.

    Args:
        tokenizer: tokenizer used for the ratio estimate and for packing.
        args: parsed command-line namespace (see ``get_args``).

    Returns:
        tuple: ``(train_dataset, valid_dataset)``.
    """
    dataset = load_dataset(
        args.dataset_name,
        data_dir=args.subset,
        split=args.split,
        # `token=True` authenticates with the locally saved Hugging Face token.
        # It replaces `use_auth_token`, which `datasets` deprecated in 2.14 and
        # removed in 3.0.
        token=True,
        # Multiprocess preparation is only requested when not streaming.
        num_proc=args.num_workers if not args.streaming else None,
        streaming=args.streaming,
    )
    if args.streaming:
        print("Loading the dataset in streaming mode")
        valid_data = dataset.take(args.size_valid_set)
        train_data = dataset.skip(args.size_valid_set)
        train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed)
    else:
        # NOTE(review): the held-out fraction is fixed at 0.5%; --size_valid_set
        # only takes effect in streaming mode.
        dataset = dataset.train_test_split(test_size=0.005, seed=args.seed)
        train_data = dataset["train"]
        valid_data = dataset["test"]
        print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}")

    chars_per_token = chars_token_ratio(train_data, tokenizer)
    print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}")

    train_dataset = ConstantLengthDataset(
        tokenizer,
        train_data,
        formatting_func=prepare_sample_text,
        infinite=True,
        seq_length=args.seq_length,
        chars_per_token=chars_per_token,
    )
    valid_dataset = ConstantLengthDataset(
        tokenizer,
        valid_data,
        formatting_func=prepare_sample_text,
        infinite=False,
        seq_length=args.seq_length,
        chars_per_token=chars_per_token,
    )
    return train_dataset, valid_dataset
def run_training(args, train_data, val_data):
    """
    Fine-tune the model at ``args.model_path`` with LoRA adapters and save the result.

    Loads the model quantized to 8 bits, placed entirely on this process's
    device. It then trains with ``SFTTrainer`` using the LoRA configuration
    below and writes the final checkpoint to
    ``<args.output_dir>/final_checkpoint/``.

    Args:
        args: parsed command-line namespace (see ``get_args``).
        train_data: packed training dataset (a ``ConstantLengthDataset``).
        val_data: packed evaluation dataset.
    """
    print("Loading the model")

    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )

    # NOTE(review): presumably resets an iteration counter on the packed
    # dataset before training; confirm ConstantLengthDataset still reads it.
    train_data.start_iteration = 0

    print("Starting main loop")

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        dataloader_drop_last=True,
        eval_strategy="steps",
        max_steps=args.max_steps,
        eval_steps=args.eval_freq,
        save_steps=args.save_freq,
        logging_steps=args.log_freq,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        lr_scheduler_type=args.lr_scheduler_type,
        warmup_steps=args.num_warmup_steps,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        gradient_checkpointing=args.gradient_checkpointing,
        fp16=args.fp16,
        bf16=args.bf16,
        weight_decay=args.weight_decay,
        run_name="llama-7b-finetuned",
        report_to="wandb",
        ddp_find_unused_parameters=False,
    )

    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        # Passing `load_in_8bit=True` directly is deprecated in transformers;
        # the supported form is an explicit BitsAndBytesConfig.
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        # Map the whole model ("" = root module) to this process's device index.
        device_map={"": Accelerator().process_index},
    )

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_data,
        eval_dataset=val_data,
        peft_config=lora_config,
        packing=True,
    )
    print_trainable_parameters(trainer.model)

    print("Training...")
    trainer.train()

    print("Saving last checkpoint of the model")
    # Trainer.save_model writes only from the main process. Calling
    # model.save_pretrained directly would make every rank in a multi-process
    # run write the same files concurrently.
    trainer.save_model(os.path.join(args.output_dir, "final_checkpoint/"))
def main(args):
    """
    Run the full pipeline: load the tokenizer, build datasets, then train.

    Args:
        args: parsed command-line namespace; ``args.model_path`` supplies the
            tokenizer as well as the model weights.
    """
    tok = AutoTokenizer.from_pretrained(args.model_path)
    train_and_eval = create_datasets(tok, args)
    run_training(args, *train_and_eval)
if __name__ == "__main__":
    args = get_args()
    # Validate explicitly rather than with `assert`, which Python strips when
    # run with -O; SystemExit prints the message and exits with status 1.
    if not args.model_path:
        raise SystemExit("Please provide the llama model path")

    set_seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)

    # transformers' logger: show errors only.
    logging.set_verbosity_error()

    main(args)