jic062 commited on
Commit
c9c6765
·
verified ·
1 Parent(s): ec1110c

Upload dpo-train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. dpo-train.py +122 -0
dpo-train.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from unsloth import PatchDPOTrainer # This line is from the DPO Zephyr example ******
2
+ PatchDPOTrainer()
3
+ from huggingface_hub import HfApi
4
+ from huggingface_hub import create_repo
5
+ from unsloth import FastLanguageModel
6
+ import torch
7
+ from datasets import load_dataset
8
+ import random
9
+
10
+ max_seq_length = 4096 # Choose any! We auto support RoPE Scaling internally!
11
+ dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
12
+ load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.
13
+ repo_name = "dpo-v1-Nemo"
14
+ # do wandb stuff
15
+ import wandb
16
+ import random
17
+ wandb.init(
18
+ project="huggingface",
19
+ name= repo_name,)
20
+
21
+
22
+ model, tokenizer = FastLanguageModel.from_pretrained(
23
+ model_name = "ijic062/Nemo-v1.1",
24
+ max_seq_length = max_seq_length,
25
+ dtype = dtype,
26
+ load_in_4bit = load_in_4bit,
27
+ token = "", # use one if using gated models like meta-llama/Llama-2-7b-hf
28
+ )
29
+
30
+ ########################################################################################################
31
+
32
+ model = FastLanguageModel.get_peft_model(
33
+
34
+ model,
35
+ r = 64, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
36
+ target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
37
+ "gate_proj", "up_proj", "down_proj",],
38
+ lora_alpha = 16,
39
+ lora_dropout = 0, # Supports any, but = 0 is optimized
40
+ bias = "none", # Supports any, but = "none" is optimized
41
+ # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
42
+ use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
43
+ random_state = 3407,
44
+ use_rslora = False, # We support rank stabilized LoRA
45
+ loftq_config = None, # And LoftQ
46
+
47
+ )
48
+
49
+ ######################################################################################################### ***
50
+
51
+ dataset = load_dataset(
52
+ "Chaser-cz/dpo-nice-prompt"
53
+ )
54
+
55
+ train_dataset = dataset['train'].shuffle(seed=random.randint(1, 9999))
56
+
57
+ # Shuffles data and take a small portion
58
+ # test_dataset = dataset['test_prefs']
59
+
60
+ column_names = list(dataset["train"].features)
61
+ print(f"This is column names: {column_names}")
62
+
63
+ import pprint
64
+ row = train_dataset[9]
65
+ pprint.pprint(row["prompt"])
66
+ pprint.pprint(row["chosen"])
67
+ pprint.pprint(row["rejected"])
68
+ ##########################################################################################################
69
+
70
+ from unsloth import PatchDPOTrainer
71
+ PatchDPOTrainer()
72
+ from trl import DPOTrainer
73
+ from transformers import TrainingArguments
74
+ from unsloth import is_bfloat16_supported
75
+
76
+ dpo_trainer = DPOTrainer(
77
+ model = model,
78
+ beta = 0.5,
79
+ tokenizer = tokenizer,
80
+ max_length = 1024,
81
+ max_prompt_length = 512,
82
+ train_dataset = train_dataset,
83
+ ref_model = None,
84
+ # dataset_text_field = "text",
85
+ # max_seq_length = max_seq_length,
86
+ # dataset_num_proc = 2,
87
+ # packing = False, # Can make training 5x faster for short sequences.
88
+ args = TrainingArguments(
89
+ # loss_type = "sigmoid",
90
+ per_device_train_batch_size = 2,
91
+ gradient_accumulation_steps = 32,
92
+ gradient_checkpointing= True,
93
+ warmup_steps = 5,
94
+ #num_train_epochs = 3,
95
+ max_steps = 1000,
96
+ learning_rate = 2.5e-4,
97
+ fp16 = not is_bfloat16_supported(),
98
+ bf16 = is_bfloat16_supported(),
99
+ logging_steps = 1,
100
+ optim = "adamw_8bit",
101
+ weight_decay = 0.07,
102
+ lr_scheduler_type = "cosine",
103
+ seed = 3407,
104
+ output_dir = "outputs/dpo-out-13b",
105
+ save_strategy = "steps",
106
+ save_steps = 500,
107
+ ),
108
+ )
109
+
110
+ dpo_trainer.train()
111
+
112
+ ########################################################################################################### ***
113
+ model.save_pretrained_merged("outputs/dpo-out-13b/merged", tokenizer, save_method = "merged_16bit")
114
+ api = HfApi()
115
+ create_repo(f"jic062/{repo_name}", repo_type="model",private=True,token="")
116
+ api.upload_folder(
117
+ folder_path="outputs/dpo-out-13b/merged",
118
+ repo_id=f"jic062/{repo_name}",
119
+ repo_type="model",
120
+ )
121
+ wandb.finish()
122
+