# codeInsight/models/peft_trainer.py
# Synced from GitHub Actions (commit c2af030).
import sys
from peft import get_peft_model, LoraConfig
from trl import SFTTrainer, SFTConfig, DataCollatorForCompletionOnlyLM
from transformers import EarlyStoppingCallback
from codeInsight.logger import logging
from codeInsight.exception import ExceptionHandle
class ModelTrainer:
def __init__(self, model, tokenizer, datasets: dict, config: dict):
self.model = model
self.tokenizer = tokenizer
self.datasets = datasets
self.lora_config = config['lora']
self.training_config = config['training']
self.paths_config = config['paths']
self.trainer = self._setup_trainer()
logging.info("ModelTrainer initialized.")
def _get_target_module(self, model) -> list:
try:
logging.info('Start Finding LoRA target module')
candidates = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
present = set()
for name, module in model.named_modules():
for cand in candidates:
if name.endswith(cand):
present.add(cand)
return list(present) if present else ["q_proj", "v_proj"]
except Exception as e:
logging.error(f"Something is wrong here")
raise ExceptionHandle(e, sys)
def _peft_model_setup(self):
try:
logging.info('Setting up PEFT LoRA model')
lora_config = LoraConfig(
r=self.lora_config['r'],
lora_alpha=self.lora_config['lora_alpha'],
target_modules=self._get_target_module(self.model),
lora_dropout=self.lora_config['lora_dropout'],
bias=self.lora_config['bias'],
task_type=self.lora_config['task_type'],
use_rslora=self.lora_config['use_rslora']
)
peft_model = get_peft_model(self.model, lora_config)
logging.info("PEFT model created successfully.")
peft_model.print_trainable_parameters()
return peft_model
except Exception as e:
logging.error("Failed to setup PEFT model")
raise ExceptionHandle(e, sys)
def _get_training_args(self) -> SFTConfig:
try:
return SFTConfig(
output_dir=self.paths_config['output_dir'],
per_device_train_batch_size=self.training_config['per_device_train_batch_size'],
per_device_eval_batch_siz=self.training_config['per_device_eval_batch_size'],
gradient_accumulation_steps=self.training_config['gradient_accumulation_steps'],
num_train_epochs=self.training_config['num_train_epochs'],
learning_rate=self.training_config['learning_rate'],
warmup_ratio=self.training_config['warmup_ratio'],
warmup_steps=self.training_config['warmup_steps'],
bf16=self.training_config['bf16'],
tf32=self.training_config['tf32'],
fp16=self.training_config['fp16'],
lr_scheduler_type=self.training_config['lr_scheduler_type'],
optim=self.training_config['optim'],
gradient_checkpointing=self.training_config['gradient_checkpointing'],
gradient_checkpointing_kwargs=self.training_config['gradient_checkpointing_kwargs'],
max_grad_norm=self.training_config['max_grad_norm'],
weight_decay=self.training_config['weight_decay'],
logging_steps=self.training_config['logging_steps'],
eval_steps=self.training_config['eval_steps'],
save_steps=self.training_config['save_steps'],
evaluation_strategy=self.training_config['eval_strategy'],
save_strategy=self.training_config['save_strategy'],
save_total_limit=self.training_config['save_total_limit'],
load_best_model_at_end=self.training_config['load_best_model_at_end'],
metric_for_best_model=self.training_config['metric_for_best_model'],
greater_is_better=self.training_config['greater_is_better'],
prediction_loss_only=self.training_config['prediction_loss_only'],
report_to=self.training_config['report_to'],
dataloader_num_workers=self.training_config['dataloader_num_workers'],
max_seq_length=self.training_config['max_seq_length'],
dataset_text_field=self.training_config['dataset_text_field'],
label_names=self.training_config['label_names'],
neftune_noise_alpha=self.training_config['neftune_noise_alpha']
)
except Exception as e:
logging.error("Failed to create TrainingArguments")
raise ExceptionHandle(e, sys)
def _data_collator(self):
try:
return DataCollatorForCompletionOnlyLM(
response_template="<|assistant|>",
tokenizer=self.tokenizer
)
except Exception as e:
logging.error("Failed to create Data Collator")
raise ExceptionHandle(e, sys)
def _setup_trainer(self) -> SFTTrainer:
logging.info("Initializing SFTTrainer")
peft_model = self._peft_model_setup()
training_args = self._get_training_args()
trainer = SFTTrainer(
model=peft_model,
train_dataset=self.datasets['train'],
eval_dataset=self.datasets['val'],
args=training_args,
data_collator=self._data_collator(),
callbacks=[EarlyStoppingCallback(early_stopping_patience=5, early_stopping_threshold=0.001)],
)
logging.info("SFTTrainer initialized successfully.")
return trainer
def save_apater(self):
try:
adapter_path = self.paths_config['adapter_save_dir']
self.trainer.model.save_pretrained(adapter_path)
logging.info(f"LoRA adapter saved successfully to {adapter_path}")
except Exception as e:
logging.error("Failed to save LoRA adapter")
raise ExceptionHandle(e, sys)