# start with: torchrun --nproc-per-node=<num_gpus> fine-tuning.py
import os

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    TrainingArguments,
    Trainer,
    BitsAndBytesConfig,
    TrainerCallback,
)
from datasets import load_from_disk
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from accelerate import Accelerator

batch_size = 2
checkpoint = "google/gemma-2b"
data_dir = "dataset_ro_small_v1/"
save_dir = "gemma-2b-romanian-1.6gb-finetuned-qlora"
log_dir = "training_logs/"

# load the pre-tokenized dataset and shuffle it deterministically
tokenized_datasets = load_from_disk(f"tokenized_{data_dir}")
tokenized_datasets = tokenized_datasets.shuffle(seed=42)
print(tokenized_datasets)

# 4-bit NF4 quantization config for QLoRA
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

# load the quantized base model; map each DDP process to its own GPU
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    quantization_config=bnb_config,
    device_map={"": Accelerator().process_index},  # see https://github.com/huggingface/trl/issues/1348
    torch_dtype=torch.float16,
    trust_remote_code=True,
    attn_implementation="sdpa",  # or 'flash_attention_2'
    use_cache=False,  # KV cache is incompatible with gradient checkpointing
)
model = prepare_model_for_kbit_training(model)

# QLoRA config: rank-8 adapters on all attention and MLP projections
lora_config = LoraConfig(
    lora_alpha=32,
    lora_dropout=0.1,
    r=8,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

# load tokenizer from checkpoint; reuse EOS as the padding token
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# training args
args = TrainingArguments(
    output_dir="training_checkpoints/",
    logging_dir=log_dir,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    evaluation_strategy="no",
    logging_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=10,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    num_train_epochs=1,
    warmup_steps=1_000,
    weight_decay=0.001,
    lr_scheduler_type="cosine",
    learning_rate=1e-4,
    max_grad_norm=0.3,
    fp16=True,
    ddp_find_unused_parameters=False,
)


# stop the training loop after 1000 updates
class StopCallback(TrainerCallback):
    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step != 0 and state.global_step % 1000 == 0:
            control.should_training_stop = True


# train as usual
trainer = Trainer(
    model=model,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    tokenizer=tokenizer,
)
trainer.add_callback(StopCallback)

print("Starting training...")
train_checkpoint = os.getenv("TRAIN_CHECKPOINT")
if train_checkpoint is not None:
    trainer.train(resume_from_checkpoint=train_checkpoint)  # resume training from checkpoint dir
else:
    trainer.train()

# save trainer state, adapter weights, and tokenizer at the end
torch.save(trainer.state.log_history, "trainer_log_history.pth")
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)