File size: 4,527 Bytes
6a0da58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments, pipeline, logging, TextStreamer
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
import os, torch, wandb, platform, warnings
from datasets import load_dataset
from trl import SFTTrainer

hf_token = ''
wnb_token = ''
wnb_name = 'vistral-chatml'
MODEL = 'Viet-Mistral/Vistral-7B-Chat'
resume_from_checkpoint = False
output_dir = 'vistral-chatml'
tokenizer_path = '.'

#######################################################
## DATASET


from datasets import load_dataset


def generate_system_prompt(i):
    system_prompt = "Bạn là một trợ lí Tiếng Việt nhiệt tình và trung thực. Hãy luôn trả lời một cách hữu ích nhất có thể, đồng thời giữ an toàn."
    if i % 2 == 0:
        system_prompt += "\nCâu trả lời của bạn không nên chứa bất kỳ nội dung gây hại, phân biệt chủng tộc, phân biệt giới tính, độc hại, nguy hiểm hoặc bất hợp pháp nào. Hãy đảm bảo rằng các câu trả lời của bạn không có thiên kiến xã hội và mang tính tích cực."
    if i % 5 == 0:
        system_prompt += "\nNếu một câu hỏi không có ý nghĩa hoặc không hợp lý về mặt thông tin, hãy giải thích tại sao thay vì trả lời một điều gì đó không chính xác. Nếu bạn không biết câu trả lời cho một câu hỏi, hãy trẳ lời là bạn không biết và vui lòng không chia sẻ thông tin sai lệch."
    return system_prompt

tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

def tokenize_chat(input, i):
    print(generate_system_prompt(i))
    conversation = [{'role': 'system', 'content': generate_system_prompt(i)}]
    for msg in input['conversations']:
        output = {'role': 'user', 'content': msg['value']}
        if msg['from'] == 'gpt':
            output['role'] = 'assistant'
        conversation.append(output)
    formatted = tokenizer.apply_chat_template(conversation, tokenize=False)
    return tokenizer(formatted)

sharegpt_dataset = load_dataset('bkai-foundation-models/vi-self-chat-sharegpt-format')
train_data = sharegpt_dataset['train'].shuffle(seed=42)\
    .select(range(800))\
    .map(lambda x, i: tokenize_chat(x, i), remove_columns=["conversations"], with_indices=True)


#######################################################
## SETUP

wandb.login(key=wnb_token)
wandb.init(name=wnb_name)
# use custom tokenizer instead of one comes from the model
#tokenizer = AutoTokenizer.from_pretrained(
#  MODEL,
#  add_eos_token=False,
#  add_bos_token=False,
#  token=hf_token,
#)
bnb_config = BitsAndBytesConfig(
  load_in_4bit=True,
  bnb_4bit_quant_type="nf4",
  bnb_4bit_compute_dtype=torch.bfloat16,
  bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
  MODEL,
  device_map="auto",
  token=hf_token,
  quantization_config=bnb_config,
  trust_remote_code=True,
)


#######################################################
## LORA CONFIG

model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "lm_head",
    ],
    bias="none",
    lora_dropout=0.05,  # Conventional
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

from accelerate import Accelerator
accelerator = Accelerator()
model = accelerator.prepare_model(model)


#######################################################
## TRAIN

from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling
trainer = Trainer(
    model=model,
    train_dataset=train_data,
    args=TrainingArguments(
        report_to='wandb',
        warmup_steps=1,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        gradient_checkpointing=True,
        num_train_epochs=4,
        learning_rate=2.5e-5,
        logging_steps=1,
        optim="paged_adamw_8bit",
        save_strategy="steps",
        save_steps=10,
        save_total_limit=4,
        output_dir=output_dir
    ),
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False)
)
model.config.use_cache = False

trainer.train(resume_from_checkpoint=resume_from_checkpoint)