from unsloth import FastLlamaModel
import torch
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from transformers import TrainingArguments
from datasets import load_from_disk
import math
import wandb
import os


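# Base loading settings: Unsloth's FastLlamaModel with 4-bit quantization (QLoRA-style fine-tuning).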
max_seq_length = 2048 # Can be any value <= 4096 for this model
dtype = None # None for auto-detection; float16 for Tesla T4/V100, bfloat16 for Ampere+ GPUs
load_in_4bit = True # Use 4-bit quantization to reduce memory usage. Can be False.


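# (label, commit hash) pairs, presumably intermediate pretraining checkpoints of the base model.
# The loop over them is commented out below; only the last ("1000k") revision is actually loaded.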
revisions = [("250k", "8ee454fe392a0267c3dee21323b5cac233d67441"), 
             ("500k", "12d3eec2d02533226c9cff719d4278967574ffcd"), ("750k", "845b8c6d8499c0e8fea0b8e5480d72e700385820"), ("1000k", "53669200ad7a6a6f1ac6a73e54c9e54c1d834a17")]



#for revision in revisions:
model, tokenizer = FastLlamaModel.from_pretrained(
    model_name = "Finnish-NLP/llama-7b-finnish",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    revision='53669200ad7a6a6f1ac6a73e54c9e54c1d834a17'    
)

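# Extend the tokenizer with the chat-template tokens (<|alku|> "start", <|ihminen|> "human",
# <|avustaja|> "assistant", <|loppu|> "end", used as EOS) plus a dedicated <PAD> token,
# then resize the embedding matrix so the new token IDs have rows.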
tokenizer.clean_up_tokenization_spaces=True
tokenizer.add_tokens(["<|alku|>", "<PAD>", "<|ihminen|>", "<|avustaja|>"])
tokenizer.pad_token = "<PAD>"
tokenizer.add_special_tokens({'eos_token': '<|loppu|>'})
tokenizer.add_tokens('\n', special_tokens=True)
tokenizer.add_eos_token=True
model.resize_token_embeddings(new_num_tokens=len(tokenizer))
model.config.eos_token_id = tokenizer.eos_token_id
print(model.config.eos_token_id)
assert tokenizer.pad_token_id != tokenizer.eos_token_id
print(tokenizer.padding_side)
print(tokenizer.add_bos_token)
print(model)



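# Attach LoRA adapters (rank 32, alpha 32, rank-stabilized) to the attention and MLP projections.
# lm_head and embed_tokens are trained in full because new tokens were added to the vocabulary.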
model = FastLlamaModel.get_peft_model(
    model,
    r = 32,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_alpha = 32,
    lora_dropout = 0,
    bias = "none",
    use_gradient_checkpointing = True,
    modules_to_save = ["lm_head", "embed_tokens"],
    random_state = 3407,
    max_seq_length = max_seq_length,
    use_rslora=True
)


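# Load the instruction dataset from disk ("deepl_kaannetyt_combined" ~ "DeepL-translated, combined")
# and hold out 2% of it as an evaluation split.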
dataset = load_from_disk("deepl_kaannetyt_combined")
dataset = dataset.train_test_split(test_size=0.02)


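# Effective batch size = bs * ga = 8. train_steps covers roughly 3 epochs of optimizer steps;
# evaluation runs about every 10% of training.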
bs = 2
ga = 4
epochs = 3
train_steps = math.ceil(len(dataset["train"]) / bs / ga * epochs)
print(train_steps)
eval_steps = math.ceil(train_steps/10)
print(eval_steps)



# Finish any lingering wandb run from a previous session; if that fails, start a fresh one.
try:
    wandb.finish()
except Exception:
    wandb.init()

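# Completion-only loss: the collator masks out everything up to and including the response
# template, so only the assistant's answer ("Vastauksesi:" = "Your answer:") contributes to the loss.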
response_template = "\n<|avustaja|> Vastauksesi:"
response_template_ids = tokenizer.encode(response_template, add_special_tokens=False)


collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer, mlm=False)

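# Supervised fine-tuning run: cosine LR schedule, 8-bit AdamW, bf16 on Ampere+ GPUs (fp16 otherwise),
# logging to Weights & Biases.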
trainer = SFTTrainer(
    model = model,
    train_dataset = dataset["train"],
    eval_dataset = dataset["test"],
    dataset_text_field = "text",
    data_collator=collator,
    max_seq_length = max_seq_length,
    tokenizer=tokenizer,
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        per_device_eval_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 50,
        max_steps = train_steps,
        report_to="wandb",
        eval_steps=eval_steps,
        evaluation_strategy="steps",
        save_strategy='steps',
        learning_rate = 2e-5,
        fp16 = not torch.cuda.is_bf16_supported(),
        bf16 = torch.cuda.is_bf16_supported(),
        logging_steps = 5,
        optim = "adamw_8bit",
        weight_decay = 0.001,
        lr_scheduler_type = "cosine",
        seed = 3407,
        output_dir = f"llama7b-finniish-instruct-v0.1",
    ),
)

wandb.login()

trainer.train()