AIR-hl committed
Commit acb4410
1 Parent(s): 36d364f

Update README.md

Files changed (1)
  1. README.md +131 -3
---
license: llama3.2
datasets:
- HuggingFaceH4/ultrafeedback_binarized
base_model:
- AIR-hl/Llama-3.2-1B-ultrachat200k
pipeline_tag: text-generation
tags:
- trl
- llama
- dpo
- alignment
- transformers
- custom
- chat
---
# Llama-3.2-1B-DPO

## Model Details

- **Model type:** aligned model
- **License:** llama3.2
- **Finetuned from model:** [AIR-hl/Llama-3.2-1B-ultrachat200k](https://huggingface.co/AIR-hl/Llama-3.2-1B-ultrachat200k)
- **Training data:** [HuggingFaceH4/ultrafeedback_binarized](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
- **Training framework:** [trl](https://github.com/huggingface/trl)

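Assuming the checkpoint is published as `AIR-hl/Llama-3.2-1B-DPO` (matching this card's title), the snippet below is a minimal inference sketch with `transformers`; the generation settings are illustrative and are not taken from this card.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "AIR-hl/Llama-3.2-1B-DPO"  # assumed repo id, matching this card's title

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

# Build a chat prompt with the tokenizer's chat template
messages = [{"role": "user", "content": "Explain DPO in one paragraph."}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

# Illustrative generation settings
output = model.generate(input_ids, max_new_tokens=256, do_sample=True, temperature=0.7)
print(tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True))
```
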
## Training Details

### Training Hyperparameters

`attn_implementation`: flash_attention_2 \
`beta`: 0.05 \
`bf16`: True \
`learning_rate`: 1e-5 \
`lr_scheduler_type`: cosine \
`per_device_train_batch_size`: 4 \
`gradient_accumulation_steps`: 8 \
`torch_dtype`: bfloat16 \
`num_train_epochs`: 1 \
`max_prompt_length`: 512 \
`max_length`: 1024 \
`warmup_ratio`: 0.05

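These settings map directly onto trl's `DPOConfig`. The sketch below shows one way to express them in code; `output_dir` is a placeholder, and `attn_implementation` / `torch_dtype` are applied when loading the model (as in the script further down) rather than in the config.

```python
from trl import DPOConfig

# Mirrors the hyperparameters listed above; output_dir is a placeholder,
# not a value from the original run.
training_args = DPOConfig(
    output_dir="llama-3.2-1b-dpo",
    beta=0.05,
    bf16=True,
    learning_rate=1e-5,
    lr_scheduler_type="cosine",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    num_train_epochs=1,
    max_prompt_length=512,
    max_length=1024,
    warmup_ratio=0.05,
)
```
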
### Results

`init_train_loss`: 0.6929 \
`final_train_loss`: 0.5713 \
`accuracy`: 0.7188 \
`reward_margin`: 0.5971

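For context, DPO minimizes the objective below (Rafailov et al., 2023), with $\beta$ the value listed above; `accuracy` can be read as the fraction of preference pairs whose implicit chosen reward exceeds the rejected one, and `reward_margin` as the mean gap between the two.

$$
\mathcal{L}_{\mathrm{DPO}} = -\,\mathbb{E}_{(x,\,y_w,\,y_l)}\!\left[\log \sigma\!\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}\right)\right]
$$
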
### Training script

```python
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import (
    DPOConfig,
    DPOTrainer,
    ModelConfig,
    ScriptArguments,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE

if __name__ == "__main__":
    parser = TrlParser((ScriptArguments, DPOConfig, ModelConfig))
    script_args, training_args, model_config = parser.parse_args_and_config()

    torch_dtype = (
        model_config.torch_dtype
        if model_config.torch_dtype in ["auto", None]
        else getattr(torch, model_config.torch_dtype)
    )

    quantization_config = get_quantization_config(model_config)

    model_kwargs = dict(
        revision=model_config.model_revision,
        attn_implementation=model_config.attn_implementation,
        torch_dtype=torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
    )

    # Without PEFT, a frozen copy of the policy serves as the reference model
    peft_config = get_peft_config(model_config)
    if peft_config is None:
        ref_model = AutoModelForCausalLM.from_pretrained(
            model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
        )
    else:
        ref_model = None

    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.chat_template is None:
        tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
    if script_args.ignore_bias_buffers:
        model._ddp_params_and_buffers_to_ignore = [
            name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
        ]

    # Keep only the columns DPOTrainer expects
    dataset = load_dataset(script_args.dataset_name,
                           split=script_args.dataset_train_split)
    dataset = dataset.select_columns(['chosen', 'prompt', 'rejected'])

    trainer = DPOTrainer(
        model,
        ref_model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
        peft_config=peft_config,
    )

    trainer.train()

    trainer.save_model(training_args.output_dir)
```
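
Because the script reads its options through `TrlParser`, it is meant to be launched from the command line (for example with `accelerate launch`) and fed either flags or a YAML config mirroring the hyperparameters above; the exact launch command used for this run is not given on the card.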