"""Helpers for loading the model, tokenizer, TrainingArguments and SFTTrainer used for fine-tuning."""
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer

import config

def load_model(quantization_config):
    """Load the base causal LM with the given quantization config applied."""
    model = AutoModelForCausalLM.from_pretrained(
        config.MODEL_NAME,
        quantization_config=quantization_config,
        trust_remote_code=config.TRUST_REMOTE_CODE,
    )
    # KV caching is typically disabled here because it conflicts with
    # gradient checkpointing during training.
    model.config.use_cache = config.ENABLE_MODEL_CONFIG_CACHE
    return model

def load_tokenizers():
    """Load the tokenizer that matches the base model."""
    tokenizer = AutoTokenizer.from_pretrained(
        config.MODEL_NAME,
        trust_remote_code=config.TRUST_REMOTE_CODE,
    )
    return tokenizer

def load_training_arguments():
    """Build the Hugging Face TrainingArguments from the values in config.py."""
    training_arguments = TrainingArguments(
        output_dir=config.MODEL_OUTPUT_DIR,
        per_device_train_batch_size=config.PER_DEVICE_TRAIN_BATCH_SIZE,
        gradient_accumulation_steps=config.GRADIENT_ACCUMULATION_STEPS,
        optim=config.OPTIM,
        save_steps=config.SAVE_STEPS,
        logging_steps=config.LOGGING_STEPS,
        learning_rate=config.LEARNING_RATE,
        fp16=config.ENABLE_FP_16,
        max_grad_norm=config.MAX_GRAD_NORM,
        max_steps=config.MAX_STEPS,
        warmup_ratio=config.WARMUP_RATIO,
        gradient_checkpointing=config.ENABLE_GRADIENT_CHECKPOINTING,
    )
    return training_arguments

def load_trainer(model, training_dataset, peft_config, tokenizer, training_arguments):
    """Wrap the model, dataset, PEFT config and tokenizer in a TRL SFTTrainer.

    Note: this keyword layout (dataset_text_field, max_seq_length, tokenizer)
    matches older TRL releases; newer TRL versions move these options into
    SFTConfig and pass the tokenizer as processing_class.
    """
    trainer = SFTTrainer(
        model=model,
        train_dataset=training_dataset,
        peft_config=peft_config,
        dataset_text_field=config.DATASET_TEXT_FIELD,
        max_seq_length=config.MAX_SEQ_LENGTH,
        tokenizer=tokenizer,
        args=training_arguments,
    )
    return trainer
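

# --- Usage sketch (not part of the original module) -------------------------
# A minimal example of how these helpers might be wired together, assuming a
# 4-bit BitsAndBytesConfig and a LoraConfig with illustrative, hypothetical
# hyperparameters, plus a hypothetical config.DATASET_NAME entry. Adjust to
# the actual values defined in config.py before using.
#
# from datasets import load_dataset
# from peft import LoraConfig
# from transformers import BitsAndBytesConfig
#
# quantization_config = BitsAndBytesConfig(load_in_4bit=True)
# peft_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, task_type="CAUSAL_LM")
#
# model = load_model(quantization_config)
# tokenizer = load_tokenizers()
# training_arguments = load_training_arguments()
# training_dataset = load_dataset(config.DATASET_NAME, split="train")
#
# trainer = load_trainer(model, training_dataset, peft_config, tokenizer, training_arguments)
# trainer.train()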