File size: 1,944 Bytes
e112bff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, BitsAndBytesConfig, EarlyStoppingCallback, PreTrainedTokenizer
from peft import LoraConfig, get_peft_model, TaskType
import torch

def _build_lora_config():
    """Return the LoRA adapter configuration used for causal-LM fine-tuning."""
    return LoraConfig(
        r=16,  # Rank of the LoRA update matrices (memory vs. capacity trade-off)
        lora_alpha=32,  # Scaling factor; LoRA updates are scaled by lora_alpha / r
        lora_dropout=0.0,  # 0.0 disables dropout on the LoRA layers
        bias="none",  # Bias parameters are not trained
        task_type=TaskType.CAUSAL_LM,
        # Attention and MLP projection layers to adapt. These names assume a
        # LLaMA-style module layout -- confirm they exist in the wrapped model.
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
    )


def _use_bf16():
    """Return True only when a CUDA device with bfloat16 support is available."""
    # TrainingArguments raises ValueError if bf16=True on a GPU without bf16
    # support, so check the capability rather than just CUDA availability.
    # ``and`` short-circuits, so the bf16 probe never runs without CUDA.
    return torch.cuda.is_available() and torch.cuda.is_bf16_supported()


def _build_training_args(output_dir):
    """Return TrainingArguments for per-epoch eval/save with best-model reload."""
    return TrainingArguments(
        output_dir=output_dir,
        # Evaluation and saving must share a strategy for
        # load_best_model_at_end (and thus EarlyStoppingCallback) to work.
        eval_strategy="epoch",
        save_strategy="epoch",
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        # Effective train batch per device = 1 sample * 16 accumulated steps.
        gradient_accumulation_steps=16,
        num_train_epochs=10,  # Upper bound; early stopping may end training sooner
        learning_rate=5e-5,  # Learning rate for the trainable (LoRA) parameters
        weight_decay=0.001,
        logging_steps=50,  # Log training loss every 50 optimizer steps
        # Keep at most 2 checkpoints on disk; with load_best_model_at_end the
        # best checkpoint is never the one rotated out.
        save_total_limit=2,
        bf16=_use_bf16(),
        push_to_hub=False,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,  # Lower eval_loss is better
    )


def initialize_deepseek_model(model, device, tokenizer, train_dataset, val_dataset, MODEL_DIR):
    """Wrap ``model`` with LoRA adapters and build a Trainer for it.

    Training is not started here; call ``trainer.train()`` on the result.

    Args:
        model: Base causal-LM model to adapt (presumably a transformers
            ``PreTrainedModel`` -- TODO confirm against caller).
        device: Target passed to ``model.to()`` (e.g. ``"cuda"`` or a
            ``torch.device``).
        tokenizer: Tokenizer forwarded to the Trainer.
        train_dataset: Dataset used for training.
        val_dataset: Dataset evaluated every epoch; its ``eval_loss`` drives
            best-checkpoint selection and early stopping.
        MODEL_DIR: Output directory for checkpoints.

    Returns:
        Tuple ``(model, trainer)``: the LoRA-wrapped (PEFT) model moved to
        ``device``, and the configured ``Trainer``.
    """
    model = get_peft_model(model, _build_lora_config())
    # NOTE(review): ``.to()`` is not supported on bitsandbytes 4/8-bit
    # quantized models; if the caller loads the model with BitsAndBytesConfig,
    # this call likely needs to be skipped -- confirm against the caller.
    model = model.to(device)

    trainer = Trainer(
        model=model,
        args=_build_training_args(MODEL_DIR),
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        # Newer transformers releases deprecate ``tokenizer=`` in favour of
        # ``processing_class=``; kept here for compatibility with older ones.
        tokenizer=tokenizer,
        # Stop once eval_loss has not improved for 2 consecutive evaluations.
        callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
    )

    return model, trainer