Update README.md
Browse files
@@ -7,4 +7,149 @@ datasets:
7 |
8 |
## Model description
9 |
10 |
This is the base Yi-34B-200K XLCTX model treated with DPO on the adamo1139/rawrr_v2-2_stage1 dataset so that its outputs are completions rather than answers to a question. DPO was done using the ChatML format, but with no previous SFT step. If I were doing it now, I would use ORPO instead of DPO for this step to make the effect stronger, but it is too late for that. It can maybe be used to slightly decensor a model, but I don't think this idea works too well when DPO is applied before the SFT step; that was already widely known, but I did it anyway.
7 |
8 |
## Model description
9 |
10 |
This is the base Yi-34B-200K XLCTX model treated with DPO on the adamo1139/rawrr_v2-2_stage1 dataset so that its outputs are completions rather than answers to a question. DPO was done using the ChatML format, but with no previous SFT step. If I were doing it now, I would use ORPO instead of DPO for this step to make the effect stronger, but it is too late for that. It can maybe be used to slightly decensor a model, but I don't think this idea works too well when DPO is applied before the SFT step; that was already widely known, but I did it anyway.
11 |
12 |
13 |
14 |
[<img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20made%20with%20love.png" alt="made with Unsloth" width="400" height="64"/>](https://github.com/unslothai/unsloth)
15 |
16 |
## Training script for Unsloth
17 |
18 |
19 |
from unsloth import FastLanguageModel
20 |
from datasets import Dataset, load_dataset
21 |
from dataclasses import dataclass, field
22 |
from typing import Dict, Optional
23 |
import torch
24 |
max_seq_length = 4096 # Choose any! We auto support RoPE Scaling internally!
25 |
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
26 |
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.
27 |
28 |
model, tokenizer = FastLanguageModel.from_pretrained(
29 |
model_name = "adamo1139/Yi-34B-200K-XLCTX", # Choose ANY! eg mistralai/Mistral-7B-Instruct-v0.2
30 |
max_seq_length = max_seq_length,
31 |
32 |
dtype = dtype,
33 |
load_in_4bit = load_in_4bit,
34 |
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
35 |
36 |
37 |
38 |
39 |
#@title Alignment Handbook utils
40 |
import os
41 |
import re
42 |
from typing import List, Literal, Optional
43 |
44 |
from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk
45 |
from datasets.builder import DatasetGenerationError
46 |
47 |
48 |
#DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
49 |
tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
50 |
51 |
52 |
EOS_TOKEN = tokenizer.eos_token
53 |
54 |
def chatml_format(example):
    """Turn one preference row into ChatML-formatted DPO fields.

    Args:
        example: Mapping with the keys ``system``, ``prompt``, ``chosen``
            and ``rejected``, each holding plain text.

    Returns:
        A dict with three string fields:
        ``prompt``   - the optional system turn followed by the user turn,
                       ending with the assistant generation prompt;
        ``chosen``   - the preferred completion, closed with ``<|im_end|>``
                       and the tokenizer's EOS token;
        ``rejected`` - the dispreferred completion, closed the same way.

    Relies on the module-level ``tokenizer`` (with its ChatML
    ``chat_template`` set) and ``EOS_TOKEN``.
    """
    # Format system: only render a system turn when there is system text.
    # A truthiness check also covers a None value, which len() would reject.
    if example['system']:
        message = {"role": "system", "content": example['system']}
        system = tokenizer.apply_chat_template([message], tokenize=False)
    else:
        system = ""

    # Format instruction: add_generation_prompt=True appends the
    # "<|im_start|>assistant\n" header so the completions continue from it.
    message = {"role": "user", "content": example['prompt']}
    prompt = tokenizer.apply_chat_template(
        [message], tokenize=False, add_generation_prompt=True
    )

    # Format chosen answer: close the assistant turn by hand, since the
    # completion text is not passed through the chat template.
    chosen = example['chosen'] + "<|im_end|>\n" + EOS_TOKEN

    # Format rejected answer
    rejected = example['rejected'] + "<|im_end|>\n" + EOS_TOKEN

    return {
        "prompt": system + prompt,
        "chosen": chosen,
        "rejected": rejected,
    }
77 |
78 |
# Load dataset
79 |
dataset = load_dataset("adamo1139/rawrr_v2-2_stage1", split="train")
80 |
81 |
import pprint
82 |
pprint.pprint("""NOT a formatted dataset
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
# Save columns
91 |
original_columns = dataset.column_names
92 |
93 |
# Format dataset
94 |
dataset = dataset.map(
95 |
96 |
97 |
98 |
99 |
# Print sample
100 |
pprint.pprint("""formatted dataset""")
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
model = FastLanguageModel.get_peft_model(
109 |
110 |
r = 32, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
111 |
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
112 |
"gate_proj", "up_proj", "down_proj",],
113 |
lora_alpha = 32,
114 |
lora_dropout = 0, # Currently only supports dropout = 0
115 |
bias = "none", # Currently only supports bias = "none"
116 |
use_gradient_checkpointing = "unsloth",
117 |
random_state = 3407,
118 |
use_rslora = False, # We support rank stabilized LoRA
119 |
loftq_config = None, # And LoftQ
120 |
121 |
122 |
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments
123 |
from trl import DPOTrainer
124 |
125 |
dpo_trainer = DPOTrainer(
126 |
model = model,
127 |
ref_model = None,
128 |
args = TrainingArguments(
129 |
per_device_train_batch_size = 1,
130 |
gradient_accumulation_steps = 16,
131 |
warmup_ratio = 0.03,
132 |
num_train_epochs = 1,
133 |
learning_rate = 0.0001,
134 |
fp16 = not torch.cuda.is_bf16_supported(),
135 |
bf16 = torch.cuda.is_bf16_supported(),
136 |
logging_steps = 1,
137 |
optim = "adamw_8bit",
138 |
weight_decay = 0.0,
139 |
lr_scheduler_type = "cosine",
140 |
seed = 42,
141 |
save_strategy = "steps",
142 |
save_steps = 100,
143 |
save_total_limit = 20,
144 |
output_dir = "1904-yi-200k-xlctx-raw-intermediate",
145 |
146 |
beta = 0.1,
147 |
train_dataset = dataset,
148 |
# eval_dataset = raw_datasets["test"],
149 |
tokenizer = tokenizer,
150 |
max_length = 650,
151 |
max_prompt_length = 650,
152 |
153 |
154 |
model.save_pretrained("1904-yi-200k-xlctx-raw-final") # Local saving
155 |